Skip to content

Commit

Permalink
refactor(main_cli): added main cli.
Browse files Browse the repository at this point in the history
  • Loading branch information
mathysgrapotte committed Feb 21, 2025
1 parent a3abc79 commit 01d4532
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 125 deletions.
11 changes: 3 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,12 @@ dependencies = [
"scipy==1.14.1",
"syrupy>=4.8.0",
"torch>=2.2.2",
"torch==2.2.2; sys_platform == 'darwin' and platform_machine == 'x86_64'"
"torch==2.2.2; sys_platform == 'darwin' and platform_machine == 'x86_64'",
"click>=8.1.0"
]

[project.scripts]
stimulus-shuffle-csv = "stimulus.cli.shuffle_csv:run"
stimulus-transform-csv = "stimulus.cli.transform_csv:run"
stimulus-split-csv = "stimulus.cli.split_csv:run"
stimulus-check-model = "stimulus.cli.check_model:run"
stimulus-tuning = "stimulus.cli.tuning:run"
stimulus-split-yaml-split = "stimulus.cli.split_split:run"
stimulus-split-yaml-transform = "stimulus.cli.split_transforms:run"
stimulus = "stimulus.cli.main:cli"

[project.urls]
Homepage = "https://mathysgrapotte.github.io/stimulus-py"
Expand Down
194 changes: 77 additions & 117 deletions src/stimulus/cli/check_model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#!/usr/bin/env python3
"""CLI module for checking model configuration and running initial tests."""

import argparse
import logging

import click
import ray
import yaml
from torch.utils.data import DataLoader
Expand All @@ -15,130 +15,100 @@
logger = logging.getLogger(__name__)


def get_args() -> argparse.Namespace:
    """Parse the check_model options from the command line.

    The options are declared in a table (flags plus ``add_argument`` keyword
    settings) and registered in order, so the generated help lists them in
    the same order as the table.

    Returns:
        Namespace holding the parsed command line options.
    """
    option_table: list[tuple[tuple[str, ...], dict]] = [
        # Required input files.
        (
            ("-d", "--data"),
            {"type": str, "required": True, "metavar": "FILE", "help": "Path to input csv file."},
        ),
        (
            ("-m", "--model"),
            {"type": str, "required": True, "metavar": "FILE", "help": "Path to model file."},
        ),
        (
            ("-e", "--data_config"),
            {"type": str, "required": True, "metavar": "FILE", "help": "Path to data config file."},
        ),
        (
            ("-c", "--model_config"),
            {"type": str, "required": True, "metavar": "FILE", "help": "Path to yaml config training file."},
        ),
        # Optional settings; nargs="?" lets the flag appear without a value,
        # in which case ``const`` is stored.
        (
            ("-w", "--initial_weights"),
            {
                "type": str,
                "required": False,
                "nargs": "?",
                "const": None,
                "default": None,
                "metavar": "FILE",
                "help": "The path to the initial weights (optional).",
            },
        ),
        (
            ("-n", "--num_samples"),
            {
                "type": int,
                "required": False,
                "nargs": "?",
                "const": 3,
                "default": 3,
                "metavar": "NUM_SAMPLES",
                "help": "Number of samples for tuning. Overwrites tune.tune_params.num_samples in config.",
            },
        ),
        (
            ("--ray_results_dirpath",),
            {
                "type": str,
                "required": False,
                "nargs": "?",
                "const": None,
                "default": None,
                "metavar": "DIR_PATH",
                "help": "Location where ray_results output dir should be written. If None, uses ~/ray_results.",
            },
        ),
        (
            ("--debug_mode",),
            {
                "action": "store_true",
                "help": "Activate debug mode for tuning. Default false, no debug.",
            },
        ),
    ]

    cli_parser = argparse.ArgumentParser(description="Launch check_model.")
    for flags, settings in option_table:
        cli_parser.add_argument(*flags, **settings)

    return cli_parser.parse_args()


def main(
model_path: str,
data_path: str,
data_config_path: str,
model_config_path: str,
@click.command()
@click.option(
"-d",
"--data",
type=click.Path(exists=True),
required=True,
help="Path to input csv file.",
)
@click.option(
"-m",
"--model",
type=click.Path(exists=True),
required=True,
help="Path to model file.",
)
@click.option(
"-e",
"--data-config",
type=click.Path(exists=True),
required=True,
help="Path to data config file.",
)
@click.option(
"-c",
"--model-config",
type=click.Path(exists=True),
required=True,
help="Path to yaml config training file.",
)
@click.option(
"-w",
"--initial-weights",
type=click.Path(exists=True),
help="The path to the initial weights (optional).",
)
@click.option(
"-n",
"--num-samples",
type=int,
default=3,
help="Number of samples for tuning. Overwrites tune.tune_params.num_samples in config.",
)
@click.option(
"--ray-results-dirpath",
type=click.Path(),
help="Location where ray_results output dir should be written. If None, uses ~/ray_results.",
)
@click.option(
"--debug-mode",
is_flag=True,
help="Activate debug mode for tuning. Default false, no debug.",
)
def check_model(
data: str,
model: str,
data_config: str,
model_config: str,
initial_weights: str | None = None, # noqa: ARG001
num_samples: int = 3,
ray_results_dirpath: str | None = None,
*,
debug_mode: bool = False,
) -> None:
"""Run the main model checking pipeline.
Args:
data_path: Path to input data file.
model_path: Path to model file.
data_config_path: Path to data config file.
model_config_path: Path to model config file.
data: Path to input data file.
model: Path to model file.
data_config: Path to data config file.
model_config: Path to model config file.
initial_weights: Optional path to initial weights.
num_samples: Number of samples for tuning.
ray_results_dirpath: Directory for ray results.
debug_mode: Whether to run in debug mode.
"""
with open(data_config_path) as file:
data_config = yaml.safe_load(file)
data_config = yaml_data.YamlSplitTransformDict(**data_config)
with open(data_config) as file:
data_config_dict = yaml.safe_load(file)
data_config_obj = yaml_data.YamlSplitTransformDict(**data_config_dict)

with open(model_config_path) as file:
model_config = yaml.safe_load(file)
model_config = yaml_model_schema.Model(**model_config)
with open(model_config) as file:
model_config_dict = yaml.safe_load(file)
model_config_obj = yaml_model_schema.Model(**model_config_dict)

encoder_loader = loaders.EncoderLoader()
encoder_loader.initialize_column_encoders_from_config(
column_config=data_config.columns,
column_config=data_config_obj.columns,
)

logger.info("Dataset loaded successfully.")

model_class = model_file_interface.import_class_from_file(model_path)
model_class = model_file_interface.import_class_from_file(model)

logger.info("Model class loaded successfully.")

ray_config_loader = yaml_model_schema.YamlRayConfigLoader(model=model_config)
ray_config_loader = yaml_model_schema.YamlRayConfigLoader(model=model_config_obj)
ray_config_dict = ray_config_loader.get_config().model_dump()
ray_config_model = ray_config_loader.get_config()

Expand All @@ -156,8 +126,8 @@ def main(
logger.info("Model instance loaded successfully.")

torch_dataset = handlertorch.TorchDataset(
data_config=data_config,
csv_path=data_path,
data_config=data_config_obj,
csv_path=data,
encoder_loader=encoder_loader,
)

Expand All @@ -183,13 +153,13 @@ def main(
logger.info("Model checking single pass completed successfully.")

# override num_samples
model_config.tune.tune_params.num_samples = num_samples
model_config_obj.tune.tune_params.num_samples = num_samples

tuner = raytune_learner.TuneWrapper(
model_config=ray_config_model,
data_config=data_config,
data_config=data_config_obj,
model_class=model_class,
data_path=data_path,
data_path=data,
encoder_loader=encoder_loader,
seed=42,
ray_results_dir=ray_results_dirpath,
Expand All @@ -207,17 +177,7 @@ def main(
def run() -> None:
"""Run the model checking script."""
ray.init(address="auto", ignore_reinit_error=True)
args = get_args()
main(
data_path=args.data,
model_path=args.model,
data_config_path=args.data_config,
model_config_path=args.model_config,
initial_weights=args.initial_weights,
num_samples=args.num_samples,
ray_results_dirpath=args.ray_results_dirpath,
debug_mode=args.debug_mode,
)
check_model()


if __name__ == "__main__":
Expand Down
35 changes: 35 additions & 0 deletions src/stimulus/cli/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Main entry point for stimulus-py cli."""

import click
from importlib_metadata import version

# Root command group: every stimulus sub-command is registered on it.
# "-h" is accepted as a short alias for "--help"; "-v"/"--version" print the
# version reported for the "stimulus-py" distribution.
@click.group(context_settings={"help_option_names": ["-h", "--help"]})
@click.version_option(version("stimulus-py"), "-v", "--version")
def cli() -> None:
    """Stimulus is an open-science framework for data processing and model training."""


# The wrapper declares no options of its own: every token after
# "check-model" is collected untouched and handed to the real click command.
# This keeps "stimulus --help" fast, because the heavy imports behind
# stimulus.cli.check_model are only paid when the sub-command actually runs.
@cli.command(context_settings={"ignore_unknown_options": True})
@click.argument("args", nargs=-1, type=click.UNPROCESSED)
@click.pass_context
def check_model(ctx: click.Context, args: tuple[str, ...]) -> None:
    """Check model configuration and run initial tests.

    check-model will connect to an existing ray cluster. Make sure you start a ray cluster before by running:
    ray start --head

    \b
    Required Options:
        -d, --data PATH             Path to input csv file
        -m, --model PATH            Path to model file
        -e, --data-config PATH      Path to data config file
        -c, --model-config PATH     Path to yaml config training file

    \b
    Optional Options:
        -w, --initial-weights PATH  Path to initial weights
        -n, --num-samples INTEGER   Number of samples for tuning [default: 3]
        --ray-results-dirpath PATH  Location for ray_results output dir
        --debug-mode                Activate debug mode for tuning
    """
    # Imported lazily on purpose (see the comment above the decorators). The
    # click command in that module is named ``check_model``; alias it so it
    # does not shadow this wrapper.
    from stimulus.cli.check_model import check_model as check_model_command

    # Parse the forwarded arguments with the real command's option
    # declarations. standalone_mode=False lets usage errors propagate to the
    # outer group, which reports them and sets the exit code once.
    check_model_command.main(
        args=list(args),
        prog_name=ctx.command_path,
        standalone_mode=False,
    )

0 comments on commit 01d4532

Please sign in to comment.