diff --git a/pyproject.toml b/pyproject.toml index 920af149..f990e839 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,17 +43,12 @@ dependencies = [ "scipy==1.14.1", "syrupy>=4.8.0", "torch>=2.2.2", - "torch==2.2.2; sys_platform == 'darwin' and platform_machine == 'x86_64'" + "torch==2.2.2; sys_platform == 'darwin' and platform_machine == 'x86_64'", + "click>=8.1.0" ] [project.scripts] -stimulus-shuffle-csv = "stimulus.cli.shuffle_csv:run" -stimulus-transform-csv = "stimulus.cli.transform_csv:run" -stimulus-split-csv = "stimulus.cli.split_csv:run" -stimulus-check-model = "stimulus.cli.check_model:run" -stimulus-tuning = "stimulus.cli.tuning:run" -stimulus-split-yaml-split = "stimulus.cli.split_split:run" -stimulus-split-yaml-transform = "stimulus.cli.split_transforms:run" +stimulus = "stimulus.cli.main:cli" [project.urls] Homepage = "https://mathysgrapotte.github.io/stimulus-py" diff --git a/src/stimulus/cli/check_model.py b/src/stimulus/cli/check_model.py index 2fec3967..7d01a2d6 100755 --- a/src/stimulus/cli/check_model.py +++ b/src/stimulus/cli/check_model.py @@ -1,9 +1,9 @@ #!/usr/bin/env python3 """CLI module for checking model configuration and running initial tests.""" -import argparse import logging +import click import ray import yaml from torch.utils.data import DataLoader @@ -15,130 +15,100 @@ logger = logging.getLogger(__name__) -def get_args() -> argparse.Namespace: - """Get the arguments when using from the commandline. - - Returns: - Parsed command line arguments. - """ - parser = argparse.ArgumentParser(description="Launch check_model.") - parser.add_argument( - "-d", - "--data", - type=str, - required=True, - metavar="FILE", - help="Path to input csv file.", - ) - parser.add_argument( - "-m", - "--model", - type=str, - required=True, - metavar="FILE", - help="Path to model file.", - ) - parser.add_argument( - "-e", - "--data_config", - type=str, - required=True, - metavar="FILE", - help="Path to data config file.", - ) - parser.add_argument( - "-c", - "--model_config", - type=str, - required=True, - metavar="FILE", - help="Path to yaml config training file.", - ) - parser.add_argument( - "-w", - "--initial_weights", - type=str, - required=False, - nargs="?", - const=None, - default=None, - metavar="FILE", - help="The path to the initial weights (optional).", - ) - - parser.add_argument( - "-n", - "--num_samples", - type=int, - required=False, - nargs="?", - const=3, - default=3, - metavar="NUM_SAMPLES", - help="Number of samples for tuning. Overwrites tune.tune_params.num_samples in config.", - ) - parser.add_argument( - "--ray_results_dirpath", - type=str, - required=False, - nargs="?", - const=None, - default=None, - metavar="DIR_PATH", - help="Location where ray_results output dir should be written. If None, uses ~/ray_results.", - ) - parser.add_argument( - "--debug_mode", - action="store_true", - help="Activate debug mode for tuning. Default false, no debug.", - ) - - return parser.parse_args() - - -def main( - model_path: str, - data_path: str, - data_config_path: str, - model_config_path: str, +@click.command() +@click.option( + "-d", + "--data", + type=click.Path(exists=True), + required=True, + help="Path to input csv file.", +) +@click.option( + "-m", + "--model", + type=click.Path(exists=True), + required=True, + help="Path to model file.", +) +@click.option( + "-e", + "--data-config", + type=click.Path(exists=True), + required=True, + help="Path to data config file.", +) +@click.option( + "-c", + "--model-config", + type=click.Path(exists=True), + required=True, + help="Path to yaml config training file.", +) +@click.option( + "-w", + "--initial-weights", + type=click.Path(exists=True), + help="The path to the initial weights (optional).", +) +@click.option( + "-n", + "--num-samples", + type=int, + default=3, + help="Number of samples for tuning. Overwrites tune.tune_params.num_samples in config.", +) +@click.option( + "--ray-results-dirpath", + type=click.Path(), + help="Location where ray_results output dir should be written. If None, uses ~/ray_results.", +) +@click.option( + "--debug-mode", + is_flag=True, + help="Activate debug mode for tuning. Default false, no debug.", +) +def check_model( + data: str, + model: str, + data_config: str, + model_config: str, initial_weights: str | None = None, # noqa: ARG001 num_samples: int = 3, ray_results_dirpath: str | None = None, - *, debug_mode: bool = False, ) -> None: """Run the main model checking pipeline. Args: - data_path: Path to input data file. - model_path: Path to model file. - data_config_path: Path to data config file. - model_config_path: Path to model config file. + data: Path to input data file. + model: Path to model file. + data_config: Path to data config file. + model_config: Path to model config file. initial_weights: Optional path to initial weights. num_samples: Number of samples for tuning. ray_results_dirpath: Directory for ray results. debug_mode: Whether to run in debug mode. """ - with open(data_config_path) as file: - data_config = yaml.safe_load(file) - data_config = yaml_data.YamlSplitTransformDict(**data_config) + with open(data_config) as file: + data_config_dict = yaml.safe_load(file) + data_config_obj = yaml_data.YamlSplitTransformDict(**data_config_dict) - with open(model_config_path) as file: - model_config = yaml.safe_load(file) - model_config = yaml_model_schema.Model(**model_config) + with open(model_config) as file: + model_config_dict = yaml.safe_load(file) + model_config_obj = yaml_model_schema.Model(**model_config_dict) encoder_loader = loaders.EncoderLoader() encoder_loader.initialize_column_encoders_from_config( - column_config=data_config.columns, + column_config=data_config_obj.columns, ) logger.info("Dataset loaded successfully.") - model_class = model_file_interface.import_class_from_file(model_path) + model_class = model_file_interface.import_class_from_file(model) logger.info("Model class loaded successfully.") - ray_config_loader = yaml_model_schema.YamlRayConfigLoader(model=model_config) + ray_config_loader = yaml_model_schema.YamlRayConfigLoader(model=model_config_obj) ray_config_dict = ray_config_loader.get_config().model_dump() ray_config_model = ray_config_loader.get_config() @@ -156,8 +126,8 @@ def main( logger.info("Model instance loaded successfully.") torch_dataset = handlertorch.TorchDataset( - data_config=data_config, - csv_path=data_path, + data_config=data_config_obj, + csv_path=data, encoder_loader=encoder_loader, ) @@ -183,13 +153,13 @@ def main( logger.info("Model checking single pass completed successfully.") # override num_samples - model_config.tune.tune_params.num_samples = num_samples + model_config_obj.tune.tune_params.num_samples = num_samples tuner = raytune_learner.TuneWrapper( model_config=ray_config_model, - data_config=data_config, + data_config=data_config_obj, model_class=model_class, - data_path=data_path, + data_path=data, encoder_loader=encoder_loader, seed=42, ray_results_dir=ray_results_dirpath, @@ -207,17 +177,7 @@ def main( def run() -> None: """Run the model checking script.""" ray.init(address="auto", ignore_reinit_error=True) - args = get_args() - main( - data_path=args.data, - model_path=args.model, - data_config_path=args.data_config, - model_config_path=args.model_config, - initial_weights=args.initial_weights, - num_samples=args.num_samples, - ray_results_dirpath=args.ray_results_dirpath, - debug_mode=args.debug_mode, - ) + check_model() if __name__ == "__main__": diff --git a/src/stimulus/cli/main.py b/src/stimulus/cli/main.py new file mode 100644 index 00000000..92eaedbb --- /dev/null +++ b/src/stimulus/cli/main.py @@ -0,0 +1,35 @@ +"""Main entry point for stimulus-py cli.""" + +import click +from importlib_metadata import version + +@click.group(context_settings=dict(help_option_names=["-h", "--help"])) +@click.version_option(version("stimulus-py"), "-v", "--version") +def cli(): + """Stimulus is an open-science framework for data processing and model training.""" + pass + + +@cli.command() +def check_model(): + """Check model configuration and run initial tests. + + check-model will connect to an existing ray cluster. Make sure you start a ray cluster before by running: + ray start --head + + \b + Required Options: + -d, --data PATH Path to input csv file + -m, --model PATH Path to model file + -e, --data-config PATH Path to data config file + -c, --model-config PATH Path to yaml config training file + + \b + Optional Options: + -w, --initial-weights PATH Path to initial weights + -n, --num-samples INTEGER Number of samples for tuning [default: 3] + --ray-results-dirpath PATH Location for ray_results output dir + --debug-mode Activate debug mode for tuning + """ + from stimulus.cli.check_model import main + main()