From 01d4532b4c4e2a9faee8c08004c2bba9f6b56480 Mon Sep 17 00:00:00 2001 From: mgrapotte Date: Fri, 21 Feb 2025 12:44:05 +0100 Subject: [PATCH 1/4] refactor(main_cli): added main cli. --- pyproject.toml | 11 +- src/stimulus/cli/check_model.py | 194 +++++++++++++------------------- src/stimulus/cli/main.py | 35 ++++++ 3 files changed, 115 insertions(+), 125 deletions(-) create mode 100644 src/stimulus/cli/main.py diff --git a/pyproject.toml b/pyproject.toml index 920af149..f990e839 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,17 +43,12 @@ dependencies = [ "scipy==1.14.1", "syrupy>=4.8.0", "torch>=2.2.2", - "torch==2.2.2; sys_platform == 'darwin' and platform_machine == 'x86_64'" + "torch==2.2.2; sys_platform == 'darwin' and platform_machine == 'x86_64'", + "click>=8.1.0" ] [project.scripts] -stimulus-shuffle-csv = "stimulus.cli.shuffle_csv:run" -stimulus-transform-csv = "stimulus.cli.transform_csv:run" -stimulus-split-csv = "stimulus.cli.split_csv:run" -stimulus-check-model = "stimulus.cli.check_model:run" -stimulus-tuning = "stimulus.cli.tuning:run" -stimulus-split-yaml-split = "stimulus.cli.split_split:run" -stimulus-split-yaml-transform = "stimulus.cli.split_transforms:run" +stimulus = "stimulus.cli.main:cli" [project.urls] Homepage = "https://mathysgrapotte.github.io/stimulus-py" diff --git a/src/stimulus/cli/check_model.py b/src/stimulus/cli/check_model.py index 2fec3967..7d01a2d6 100755 --- a/src/stimulus/cli/check_model.py +++ b/src/stimulus/cli/check_model.py @@ -1,9 +1,9 @@ #!/usr/bin/env python3 """CLI module for checking model configuration and running initial tests.""" -import argparse import logging +import click import ray import yaml from torch.utils.data import DataLoader @@ -15,130 +15,100 @@ logger = logging.getLogger(__name__) -def get_args() -> argparse.Namespace: - """Get the arguments when using from the commandline. - - Returns: - Parsed command line arguments. - """ - parser = argparse.ArgumentParser(description="Launch check_model.") - parser.add_argument( - "-d", - "--data", - type=str, - required=True, - metavar="FILE", - help="Path to input csv file.", - ) - parser.add_argument( - "-m", - "--model", - type=str, - required=True, - metavar="FILE", - help="Path to model file.", - ) - parser.add_argument( - "-e", - "--data_config", - type=str, - required=True, - metavar="FILE", - help="Path to data config file.", - ) - parser.add_argument( - "-c", - "--model_config", - type=str, - required=True, - metavar="FILE", - help="Path to yaml config training file.", - ) - parser.add_argument( - "-w", - "--initial_weights", - type=str, - required=False, - nargs="?", - const=None, - default=None, - metavar="FILE", - help="The path to the initial weights (optional).", - ) - - parser.add_argument( - "-n", - "--num_samples", - type=int, - required=False, - nargs="?", - const=3, - default=3, - metavar="NUM_SAMPLES", - help="Number of samples for tuning. Overwrites tune.tune_params.num_samples in config.", - ) - parser.add_argument( - "--ray_results_dirpath", - type=str, - required=False, - nargs="?", - const=None, - default=None, - metavar="DIR_PATH", - help="Location where ray_results output dir should be written. If None, uses ~/ray_results.", - ) - parser.add_argument( - "--debug_mode", - action="store_true", - help="Activate debug mode for tuning. Default false, no debug.", - ) - - return parser.parse_args() - - -def main( - model_path: str, - data_path: str, - data_config_path: str, - model_config_path: str, +@click.command() +@click.option( + "-d", + "--data", + type=click.Path(exists=True), + required=True, + help="Path to input csv file.", +) +@click.option( + "-m", + "--model", + type=click.Path(exists=True), + required=True, + help="Path to model file.", +) +@click.option( + "-e", + "--data-config", + type=click.Path(exists=True), + required=True, + help="Path to data config file.", +) +@click.option( + "-c", + "--model-config", + type=click.Path(exists=True), + required=True, + help="Path to yaml config training file.", +) +@click.option( + "-w", + "--initial-weights", + type=click.Path(exists=True), + help="The path to the initial weights (optional).", +) +@click.option( + "-n", + "--num-samples", + type=int, + default=3, + help="Number of samples for tuning. Overwrites tune.tune_params.num_samples in config.", +) +@click.option( + "--ray-results-dirpath", + type=click.Path(), + help="Location where ray_results output dir should be written. If None, uses ~/ray_results.", +) +@click.option( + "--debug-mode", + is_flag=True, + help="Activate debug mode for tuning. Default false, no debug.", +) +def check_model( + data: str, + model: str, + data_config: str, + model_config: str, initial_weights: str | None = None, # noqa: ARG001 num_samples: int = 3, ray_results_dirpath: str | None = None, - *, debug_mode: bool = False, ) -> None: """Run the main model checking pipeline. Args: - data_path: Path to input data file. - model_path: Path to model file. - data_config_path: Path to data config file. - model_config_path: Path to model config file. + data: Path to input data file. + model: Path to model file. + data_config: Path to data config file. + model_config: Path to model config file. initial_weights: Optional path to initial weights. num_samples: Number of samples for tuning. ray_results_dirpath: Directory for ray results. debug_mode: Whether to run in debug mode. """ - with open(data_config_path) as file: - data_config = yaml.safe_load(file) - data_config = yaml_data.YamlSplitTransformDict(**data_config) + with open(data_config) as file: + data_config_dict = yaml.safe_load(file) + data_config_obj = yaml_data.YamlSplitTransformDict(**data_config_dict) - with open(model_config_path) as file: - model_config = yaml.safe_load(file) - model_config = yaml_model_schema.Model(**model_config) + with open(model_config) as file: + model_config_dict = yaml.safe_load(file) + model_config_obj = yaml_model_schema.Model(**model_config_dict) encoder_loader = loaders.EncoderLoader() encoder_loader.initialize_column_encoders_from_config( - column_config=data_config.columns, + column_config=data_config_obj.columns, ) logger.info("Dataset loaded successfully.") - model_class = model_file_interface.import_class_from_file(model_path) + model_class = model_file_interface.import_class_from_file(model) logger.info("Model class loaded successfully.") - ray_config_loader = yaml_model_schema.YamlRayConfigLoader(model=model_config) + ray_config_loader = yaml_model_schema.YamlRayConfigLoader(model=model_config_obj) ray_config_dict = ray_config_loader.get_config().model_dump() ray_config_model = ray_config_loader.get_config() @@ -156,8 +126,8 @@ def main( logger.info("Model instance loaded successfully.") torch_dataset = handlertorch.TorchDataset( - data_config=data_config, - csv_path=data_path, + data_config=data_config_obj, + csv_path=data, encoder_loader=encoder_loader, ) @@ -183,13 +153,13 @@ def main( logger.info("Model checking single pass completed successfully.") # override num_samples - model_config.tune.tune_params.num_samples = num_samples + model_config_obj.tune.tune_params.num_samples = num_samples tuner = raytune_learner.TuneWrapper( model_config=ray_config_model, - data_config=data_config, + data_config=data_config_obj, model_class=model_class, - data_path=data_path, + data_path=data, encoder_loader=encoder_loader, seed=42, ray_results_dir=ray_results_dirpath, @@ -207,17 +177,7 @@ def main( def run() -> None: """Run the model checking script.""" ray.init(address="auto", ignore_reinit_error=True) - args = get_args() - main( - data_path=args.data, - model_path=args.model, - data_config_path=args.data_config, - model_config_path=args.model_config, - initial_weights=args.initial_weights, - num_samples=args.num_samples, - ray_results_dirpath=args.ray_results_dirpath, - debug_mode=args.debug_mode, - ) + check_model() if __name__ == "__main__": diff --git a/src/stimulus/cli/main.py b/src/stimulus/cli/main.py new file mode 100644 index 00000000..92eaedbb --- /dev/null +++ b/src/stimulus/cli/main.py @@ -0,0 +1,35 @@ +"""Main entry point for stimulus-py cli.""" + +import click +from importlib_metadata import version + +@click.group(context_settings=dict(help_option_names=["-h", "--help"])) +@click.version_option(version("stimulus-py"), "-v", "--version") +def cli(): + """Stimulus is an open-science framework for data processing and model training.""" + pass + + +@cli.command() +def check_model(): + """Check model configuration and run initial tests. + + check-model will connect to an existing ray cluster. Make sure you start a ray cluster before by running: + ray start --head + + \b + Required Options: + -d, --data PATH Path to input csv file + -m, --model PATH Path to model file + -e, --data-config PATH Path to data config file + -c, --model-config PATH Path to yaml config training file + + \b + Optional Options: + -w, --initial-weights PATH Path to initial weights + -n, --num-samples INTEGER Number of samples for tuning [default: 3] + --ray-results-dirpath PATH Location for ray_results output dir + --debug-mode Activate debug mode for tuning + """ + from stimulus.cli.check_model import main + main() From 09044c7b7b671d05e9df15d90679ce421b5b335f Mon Sep 17 00:00:00 2001 From: mgrapotte Date: Fri, 21 Feb 2025 14:57:36 +0100 Subject: [PATCH 2/4] refactor(check_model): added refactoring for check-model cli. --- src/stimulus/cli/check_model.py | 153 +++++++++++--------------------- src/stimulus/cli/main.py | 101 ++++++++++++++++----- tests/cli/test_check_model.py | 54 ++++++++++- 3 files changed, 182 insertions(+), 126 deletions(-) diff --git a/src/stimulus/cli/check_model.py b/src/stimulus/cli/check_model.py index 7d01a2d6..82b1f06c 100755 --- a/src/stimulus/cli/check_model.py +++ b/src/stimulus/cli/check_model.py @@ -3,112 +3,82 @@ import logging -import click -import ray import yaml from torch.utils.data import DataLoader -from stimulus.data import handlertorch, loaders +from stimulus.data import data_handlers +from stimulus.data.interface import data_config_parser from stimulus.learner import raytune_learner -from stimulus.utils import model_file_interface, yaml_data, yaml_model_schema +from stimulus.utils import model_file_interface, yaml_model_schema logger = logging.getLogger(__name__) -@click.command() -@click.option( - "-d", - "--data", - type=click.Path(exists=True), - required=True, - help="Path to input csv file.", -) -@click.option( - "-m", - "--model", - type=click.Path(exists=True), - required=True, - help="Path to model file.", -) -@click.option( - "-e", - "--data-config", - type=click.Path(exists=True), - required=True, - help="Path to data config file.", -) -@click.option( - "-c", - "--model-config", - type=click.Path(exists=True), - required=True, - help="Path to yaml config training file.", -) -@click.option( - "-w", - "--initial-weights", - type=click.Path(exists=True), - help="The path to the initial weights (optional).", -) -@click.option( - "-n", - "--num-samples", - type=int, - default=3, - help="Number of samples for tuning. Overwrites tune.tune_params.num_samples in config.", -) -@click.option( - "--ray-results-dirpath", - type=click.Path(), - help="Location where ray_results output dir should be written. If None, uses ~/ray_results.", -) -@click.option( - "--debug-mode", - is_flag=True, - help="Activate debug mode for tuning. Default false, no debug.", -) +def load_data_config_from_path(data_path: str, data_config_path: str, split: int) -> data_handlers.TorchDataset: + """Load the data config from a path. + + Args: + data_config_path: Path to the data config file. + + Returns: + A tuple of the parsed configuration. + """ + with open(data_config_path) as file: + data_config_dict = yaml.safe_load(file) + data_config_obj = data_config_parser.SplitTransformDict(**data_config_dict) + + encoders, input_columns, label_columns, meta_columns = data_config_parser.parse_split_transform_config( + data_config_obj, + ) + + return data_handlers.TorchDataset( + loader=data_handlers.DatasetLoader( + encoders=encoders, + input_columns=input_columns, + label_columns=label_columns, + meta_columns=meta_columns, + csv_path=data_path, + split=split, + ), + ) + + def check_model( - data: str, - model: str, - data_config: str, - model_config: str, + data_path: str, + model_path: str, + data_config_path: str, + model_config_path: str, initial_weights: str | None = None, # noqa: ARG001 num_samples: int = 3, ray_results_dirpath: str | None = None, + *, debug_mode: bool = False, ) -> None: """Run the main model checking pipeline. Args: - data: Path to input data file. - model: Path to model file. - data_config: Path to data config file. - model_config: Path to model config file. + data_path: Path to input data file. + model_path: Path to model file. + data_config_path: Path to data config file. + model_config_path: Path to model config file. initial_weights: Optional path to initial weights. num_samples: Number of samples for tuning. ray_results_dirpath: Directory for ray results. debug_mode: Whether to run in debug mode. """ - with open(data_config) as file: - data_config_dict = yaml.safe_load(file) - data_config_obj = yaml_data.YamlSplitTransformDict(**data_config_dict) - - with open(model_config) as file: - model_config_dict = yaml.safe_load(file) - model_config_obj = yaml_model_schema.Model(**model_config_dict) - - encoder_loader = loaders.EncoderLoader() - encoder_loader.initialize_column_encoders_from_config( - column_config=data_config_obj.columns, - ) - + train_dataset = load_data_config_from_path(data_path, data_config_path, split=0) + validation_dataset = load_data_config_from_path(data_path, data_config_path, split=1) logger.info("Dataset loaded successfully.") - model_class = model_file_interface.import_class_from_file(model) + model_class = model_file_interface.import_class_from_file(model_path) logger.info("Model class loaded successfully.") - ray_config_loader = yaml_model_schema.YamlRayConfigLoader(model=model_config_obj) + with open(model_config_path) as file: + model_config_content = yaml.safe_load(file) + model_config = yaml_model_schema.Model(**model_config_content) + + ray_config_loader = yaml_model_schema.RayConfigLoader(model=model_config) ray_config_dict = ray_config_loader.get_config().model_dump() ray_config_model = ray_config_loader.get_config() @@ -125,13 +95,7 @@ def check_model( logger.info("Model instance loaded successfully.") - torch_dataset = handlertorch.TorchDataset( - data_config=data_config_obj, - csv_path=data, - encoder_loader=encoder_loader, - ) - - torch_dataloader = DataLoader(torch_dataset, batch_size=10, shuffle=True) + torch_dataloader = DataLoader(train_dataset, batch_size=10, shuffle=True) logger.info("Torch dataloader loaded successfully.") @@ -153,14 +117,13 @@ def check_model( logger.info("Model checking single pass completed successfully.") # override num_samples - model_config_obj.tune.tune_params.num_samples = num_samples + model_config.tune.tune_params.num_samples = num_samples tuner = raytune_learner.TuneWrapper( model_config=ray_config_model, - data_config=data_config_obj, model_class=model_class, - data_path=data, - encoder_loader=encoder_loader, + train_dataset=train_dataset, + validation_dataset=validation_dataset, seed=42, ray_results_dir=ray_results_dirpath, debug=debug_mode, @@ -172,13 +135,3 @@ def check_model( logger.info("Tuning completed successfully.") logger.info("Checks complete") - - -def run() -> None: - """Run the model checking script.""" - ray.init(address="auto", ignore_reinit_error=True) - check_model() - - -if __name__ == "__main__": - run() diff --git a/src/stimulus/cli/main.py b/src/stimulus/cli/main.py index 92eaedbb..3e4d512a 100644 --- a/src/stimulus/cli/main.py +++ b/src/stimulus/cli/main.py @@ -3,33 +3,86 @@ import click from importlib_metadata import version -@click.group(context_settings=dict(help_option_names=["-h", "--help"])) + +@click.group(context_settings={"help_option_names": ["-h", "--help"]}) @click.version_option(version("stimulus-py"), "-v", "--version") -def cli(): +def cli() -> None: """Stimulus is an open-science framework for data processing and model training.""" - pass @cli.command() -def check_model(): - """Check model configuration and run initial tests. - - check-model will connect to an existing ray cluster. Make sure you start a ray cluster before by running: - ray start --head - - \b - Required Options: - -d, --data PATH Path to input csv file - -m, --model PATH Path to model file - -e, --data-config PATH Path to data config file - -c, --model-config PATH Path to yaml config training file +@click.option( + "-d", + "--data", + type=click.Path(exists=True), + required=True, + help="Path to input csv file", +) +@click.option( + "-m", + "--model", + type=click.Path(exists=True), + required=True, + help="Path to model file", +) +@click.option( + "-e", + "--data-config", + type=click.Path(exists=True), + required=True, + help="Path to data config file", +) +@click.option( + "-c", + "--model-config", + type=click.Path(exists=True), + required=True, + help="Path to yaml config training file", +) +@click.option( + "-w", + "--initial-weights", + type=click.Path(exists=True), + help="Path to initial weights", +) +@click.option( + "-n", + "--num-samples", + type=int, + default=3, + help="Number of samples for tuning [default: 3]", +) +@click.option( + "--ray-results-dirpath", + type=click.Path(), + help="Location for ray_results output dir", +) +@click.option( + "--debug-mode", + is_flag=True, + help="Activate debug mode for tuning", +) +def check_model( + data: str, + model: str, + data_config: str, + model_config: str, + initial_weights: str | None, + num_samples: int, + ray_results_dirpath: str | None, + *, + debug_mode: bool, +) -> None: + """Check model configuration and run initial tests.""" + from stimulus.cli.check_model import check_model as check_model_func - \b - Optional Options: - -w, --initial-weights PATH Path to initial weights - -n, --num-samples INTEGER Number of samples for tuning [default: 3] - --ray-results-dirpath PATH Location for ray_results output dir - --debug-mode Activate debug mode for tuning - """ - from stimulus.cli.check_model import main - main() + check_model_func( + data_path=data, + model_path=model, + data_config_path=data_config, + model_config_path=model_config, + initial_weights=initial_weights, + num_samples=num_samples, + ray_results_dirpath=ray_results_dirpath, + debug_mode=debug_mode, + ) diff --git a/tests/cli/test_check_model.py b/tests/cli/test_check_model.py index d7e05fbb..8d495ec4 100644 --- a/tests/cli/test_check_model.py +++ b/tests/cli/test_check_model.py @@ -6,8 +6,10 @@ import pytest import ray +from click.testing import CliRunner from stimulus.cli import check_model +from stimulus.cli.main import cli @pytest.fixture @@ -22,7 +24,7 @@ def data_path() -> str: def data_config() -> str: """Get path to test data config YAML.""" return str( - Path(__file__).parent.parent / "test_data" / "titanic" / "titanic_sub_config.yaml", + Path(__file__).parent.parent / "test_data" / "titanic" / "titanic_unique_transform.yaml", ) @@ -65,7 +67,7 @@ def test_check_model_main( try: # Run main function - should complete without errors - check_model.main( + check_model.check_model( model_path=model_path, data_path=data_path, data_config_path=data_config, @@ -85,3 +87,51 @@ def test_check_model_main( import shutil shutil.rmtree(ray_results_dir) + + +def test_cli_invocation( + data_path: str, + data_config: str, + model_path: str, + model_config: str, +) -> None: + """Test the CLI invocation of check-model command. + + Args: + data_path: Path to test CSV data. + data_config: Path to data config YAML. + model_path: Path to model implementation. + model_config: Path to model config YAML. + """ + ray.init(ignore_reinit_error=True) + runner = CliRunner() + try: + result = runner.invoke( + cli, + [ + "check-model", + "-d", + data_path, + "-m", + model_path, + "-e", + data_config, + "-c", + model_config, + "-n", + "1", + ], + ) + assert result.exit_code == 0 + + finally: + # Ensure Ray is shut down properly + if ray.is_initialized(): + ray.shutdown() + + # Clean up any ray files/directories that may have been created + ray_results_dir = os.path.expanduser("~/ray_results") + if os.path.exists(ray_results_dir): + import shutil + + shutil.rmtree(ray_results_dir) From e6a0e23800776898dbb7ac669faa0c049b645ad3 Mon Sep 17 00:00:00 2001 From: mgrapotte Date: Fri, 21 Feb 2025 15:48:28 +0100 Subject: [PATCH 3/4] refactor(shuffle_csv): updated shuffle_csv to the new paradigm. --- src/stimulus/cli/main.py | 37 +++++++++++ src/stimulus/cli/shuffle_csv.py | 86 ++++++++++--------------- tests/cli/test_shuffle_csv.py | 111 ++++++++++++++++++++++---------- 3 files changed, 147 insertions(+), 87 deletions(-) diff --git a/src/stimulus/cli/main.py b/src/stimulus/cli/main.py index 3e4d512a..3dc8e107 100644 --- a/src/stimulus/cli/main.py +++ b/src/stimulus/cli/main.py @@ -86,3 +86,40 @@ def check_model( ray_results_dirpath=ray_results_dirpath, debug_mode=debug_mode, ) + + +@cli.command() +@click.option( + "-c", + "--csv", + type=click.Path(exists=True), + required=True, + help="The file path for the csv containing the data in csv format", +) +@click.option( + "-y", + "--yaml", + type=click.Path(exists=True), + required=True, + help="The YAML data config", +) +@click.option( + "-o", + "--output", + type=click.Path(), + required=True, + help="The output file path to write the shuffled csv", +) +def shuffle_csv( + csv: str, + yaml: str, + output: str, +) -> None: + """Shuffle rows in a CSV data file.""" + from stimulus.cli.shuffle_csv import shuffle_csv as shuffle_csv_func + + shuffle_csv_func( + data_csv=csv, + config_yaml=yaml, + out_path=output, + ) diff --git a/src/stimulus/cli/shuffle_csv.py b/src/stimulus/cli/shuffle_csv.py index a1f61765..dc5a0fd1 100755 --- a/src/stimulus/cli/shuffle_csv.py +++ b/src/stimulus/cli/shuffle_csv.py @@ -1,79 +1,59 @@ #!/usr/bin/env python3 """CLI module for shuffling CSV data files.""" -import argparse +import logging import yaml -from stimulus.data.data_handlers import DatasetProcessor -from stimulus.utils.yaml_data import YamlSplitTransformDict +from stimulus.data import data_handlers +from stimulus.data.interface import data_config_parser +logger = logging.getLogger(__name__) -def get_args() -> argparse.Namespace: - """Get the arguments when using from the commandline. + +def load_data_config_from_path(data_path: str, data_config_path: str) -> data_handlers.DatasetProcessor: + """Load the data config from a path. + + Args: + data_config_path: Path to the data config file. Returns: - Parsed command line arguments. + A tuple of the parsed configuration. """ - parser = argparse.ArgumentParser(description="Shuffle rows in a CSV data file.") - parser.add_argument( - "-c", - "--csv", - type=str, - required=True, - metavar="FILE", - help="The file path for the csv containing all data", - ) - parser.add_argument( - "-y", - "--yaml", - type=str, - required=True, - metavar="FILE", - help="The YAML config file that hold all parameter info", - ) - parser.add_argument( - "-o", - "--output", - type=str, - required=True, - metavar="FILE", - help="The output file path to write the noised csv", - ) + with open(data_config_path) as file: + data_config_dict = yaml.safe_load(file) + data_config_obj = data_config_parser.SplitConfigDict(**data_config_dict) - return parser.parse_args() + splitters = data_config_parser.create_splitter(data_config_obj.split) + transforms = data_config_parser.create_transforms(data_config_obj.transforms) + split_columns = data_config_obj.split.split_input_columns + label_columns = [column.column_name for column in data_config_obj.columns if column.column_type == "label"] + return data_handlers.DatasetProcessor( + csv_path=data_path, + transforms=transforms, + split_columns=split_columns, + splitter=splitters, + ), label_columns -def main(data_csv: str, config_yaml: str, out_path: str) -> None: + +def shuffle_csv(data_csv: str, config_yaml: str, out_path: str) -> None: """Shuffle the data and split it according to the default split method. Args: data_csv: Path to input CSV file. config_yaml: Path to config YAML file. out_path: Path to output shuffled CSV. - - TODO major changes when this is going to select a given shuffle method and integration with split. """ - # read the yaml file - with open(config_yaml) as f: - data_config: YamlSplitTransformDict = YamlSplitTransformDict( - **yaml.safe_load(f), - ) # create a DatasetProcessor object from the config and the csv - processor = DatasetProcessor(data_config=data_config, csv_path=data_csv) + processor, label_columns = load_data_config_from_path(data_csv, config_yaml) + logger.info("Dataset processor initialized successfully.") - # shuffle the data with a default seed. TODO get the seed for the config if and when that is going to be set there. - processor.shuffle_labels(seed=42) + # shuffle the data with a default seed + # TODO: get the seed from the config if and when that is going to be set there + processor.shuffle_labels(label_columns, seed=42) + logger.info("Data shuffled successfully.") # save the modified csv processor.save(out_path) - - -def run() -> None: - """Run the CSV shuffling script.""" - args = get_args() - main(args.csv, args.yaml, args.output) - - -if __name__ == "__main__": - run() + logger.info("Shuffled data saved successfully.") diff --git a/tests/cli/test_shuffle_csv.py b/tests/cli/test_shuffle_csv.py index 8fef803c..a6773331 100644 --- a/tests/cli/test_shuffle_csv.py +++ b/tests/cli/test_shuffle_csv.py @@ -1,52 +1,95 @@ """Tests for the shuffle_csv CLI command.""" import hashlib -import pathlib +import os import tempfile -from typing import Any, Callable +from pathlib import Path import pytest +from click.testing import CliRunner -from src.stimulus.cli.shuffle_csv import main +from stimulus.cli.main import cli +from stimulus.cli.shuffle_csv import shuffle_csv -# Fixtures @pytest.fixture -def correct_yaml_path() -> str: - """Fixture that returns the path to a correct YAML file.""" - return "tests/test_data/titanic/titanic_sub_config.yaml" +def csv_path() -> str: + """Get path to test CSV file.""" + return str( + Path(__file__).parent.parent / "test_data" / "titanic" / "titanic_stimulus.csv", + ) @pytest.fixture -def correct_csv_path() -> str: - """Fixture that returns the path to a correct CSV file.""" - return "tests/test_data/titanic/titanic_stimulus.csv" +def yaml_path() -> str: + """Get path to test config YAML file.""" + return str( + Path(__file__).parent.parent / "test_data" / "titanic" / "titanic_unique_split.yaml", + ) -# Test cases -test_cases = [ - ("correct_csv_path", "correct_yaml_path", None), -] +def test_shuffle_csv_main( + csv_path: str, + yaml_path: str, +) -> None: + """Test that shuffle_csv.main runs without errors. + Args: + csv_path: Path to test CSV data. + yaml_path: Path to test config YAML. + """ + # Verify required files exist + assert os.path.exists(csv_path), f"CSV file not found at {csv_path}" + assert os.path.exists(yaml_path), f"YAML config not found at {yaml_path}" -# Tests -@pytest.mark.parametrize(("csv_type", "yaml_type", "error"), test_cases) -def test_shuffle_csv( - request: pytest.FixtureRequest, - snapshot: Callable[[], Any], - csv_type: str, - yaml_type: str, - error: Exception | None, -) -> None: - """Tests the CLI command with correct and wrong YAML files.""" - csv_path = request.getfixturevalue(csv_type) - yaml_path = request.getfixturevalue(yaml_type) - tmpdir = pathlib.Path(tempfile.gettempdir()) - if error: - with pytest.raises(error): # type: ignore[call-overload] - main(csv_path, yaml_path, str(tmpdir / "test.csv")) - else: - main(csv_path, yaml_path, str(tmpdir / "test.csv")) - with open(tmpdir / "test.csv") as file: + # Create temporary output file + with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as tmp_file: + output_path = tmp_file.name + + try: + # Run main function - should complete without errors + shuffle_csv( + data_csv=csv_path, + config_yaml=yaml_path, + out_path=output_path, + ) + + # Verify output file exists and has content + assert os.path.exists(output_path), "Output file was not created" + with open(output_path) as file: hash = hashlib.md5(file.read().encode()).hexdigest() # noqa: S324 - assert hash == snapshot + assert hash # Verify we got a hash (file not empty) + + finally: + # Clean up temporary file + if os.path.exists(output_path): + os.unlink(output_path) + + +def test_cli_invocation( + csv_path: str, + yaml_path: str, +) -> None: + """Test the CLI invocation of shuffle-csv command. + + Args: + csv_path: Path to test CSV data. + yaml_path: Path to test config YAML. + """ + runner = CliRunner() + with runner.isolated_filesystem(): + output_path = "output.csv" + result = runner.invoke( + cli, + [ + "shuffle-csv", + "-c", + csv_path, + "-y", + yaml_path, + "-o", + output_path, + ], + ) + assert result.exit_code == 0 + assert os.path.exists(output_path), "Output file was not created" From 92558c07acef4fa5a6f5622ec75e08393882226f Mon Sep 17 00:00:00 2001 From: mathysgrapotte Date: Mon, 24 Feb 2025 18:41:19 +0100 Subject: [PATCH 4/4] refactor(cli): added refactoring for split_csv cli. --- src/stimulus/cli/main.py | 47 +++++++++++ src/stimulus/cli/split_csv.py | 103 ++++++++--------------- tests/cli/test_split_csv.py | 154 ++++++++++++++++++++++++---------- 3 files changed, 192 insertions(+), 112 deletions(-) diff --git a/src/stimulus/cli/main.py b/src/stimulus/cli/main.py index 3dc8e107..9e8fd5c4 100644 --- a/src/stimulus/cli/main.py +++ b/src/stimulus/cli/main.py @@ -123,3 +123,50 @@ def shuffle_csv( config_yaml=yaml, out_path=output, ) + + +@cli.command() +@click.option( + "-c", + "--csv", + type=click.Path(exists=True), + required=True, + help="The file path for the csv containing all data", +) +@click.option( + "-y", + "--yaml", + type=click.Path(exists=True), + required=True, + help="The YAML config file that holds all parameter info", +) +@click.option( + "-o", + "--output", + type=click.Path(), + required=True, + help="The output file path to write the split csv", +) +@click.option( + "-f", + "--force", + is_flag=True, + default=False, + help="Overwrite the split column if it already exists in the csv", +) +def split_csv( + csv: str, + yaml: str, + output: str, + *, + force: bool, +) -> None: + """Split rows in a CSV data file.""" + from stimulus.cli.split_csv import split_csv as split_csv_func + + split_csv_func( + data_csv=csv, + config_yaml=yaml, + out_path=output, + force=force, + ) diff --git a/src/stimulus/cli/split_csv.py b/src/stimulus/cli/split_csv.py index b1daddff..c95668b4 100755 --- a/src/stimulus/cli/split_csv.py +++ b/src/stimulus/cli/split_csv.py @@ -1,62 +1,44 @@ #!/usr/bin/env python3 """CLI module for splitting CSV data files.""" -import argparse +import logging import yaml -from stimulus.data.data_handlers import DatasetProcessor, SplitManager -from stimulus.data.loaders import SplitLoader -from stimulus.utils.yaml_data import YamlSplitConfigDict +from stimulus.data import data_handlers +from stimulus.data.interface import data_config_parser +logger = logging.getLogger(__name__) -def get_args() -> argparse.Namespace: - """Get the arguments when using from the commandline.""" - parser = argparse.ArgumentParser(description="Split a CSV data file.") - parser.add_argument( - "-c", - "--csv", - type=str, - required=True, - metavar="FILE", - help="The file path for the csv containing all data", - ) - parser.add_argument( - "-y", - "--yaml", - type=str, - required=True, - metavar="FILE", - help="The YAML config file that hold all parameter info", - ) - parser.add_argument( - "-o", - "--output", - type=str, - required=True, - metavar="FILE", - help="The output file path to write the noised csv", - ) - parser.add_argument( - "-f", - "--force", - type=bool, - required=False, - default=False, - help="Overwrite the split column if it already exists in the csv", - ) - return parser.parse_args() +def load_data_config_from_path(data_path: str, data_config_path: str) -> data_handlers.DatasetProcessor: + """Load the data config from a path. + + Args: + data_path: Path to the data file. + data_config_path: Path to the data config file. + + Returns: + A DatasetProcessor instance configured with the data. + """ + with open(data_config_path) as file: + data_config_dict = yaml.safe_load(file) + data_config_obj = data_config_parser.SplitConfigDict(**data_config_dict) + + splitters = data_config_parser.create_splitter(data_config_obj.split) + transforms = data_config_parser.create_transforms(data_config_obj.transforms) + split_columns = data_config_obj.split.split_input_columns + + return data_handlers.DatasetProcessor( + csv_path=data_path, + transforms=transforms, + split_columns=split_columns, + splitter=splitters, + ) -def main( - data_csv: str, - config_yaml: str, - out_path: str, - *, - force: bool = False, -) -> None: - """Connect CSV and YAML configuration and handle sanity checks. +def split_csv(data_csv: str, config_yaml: str, out_path: str, *, force: bool = False) -> None: + """Split the data according to the configuration. Args: data_csv: Path to input CSV file. @@ -65,28 +47,13 @@ def main( force: Overwrite the split column if it already exists in the CSV. """ # create a DatasetProcessor object from the config and the csv - processor = DatasetProcessor(config_path=config_yaml, csv_path=data_csv) - - # create a split manager from the config - split_config = processor.dataset_manager.config.split - with open(config_yaml) as f: - yaml_config = YamlSplitConfigDict(**yaml.safe_load(f)) - split_loader = SplitLoader(seed=yaml_config.global_params.seed) - split_loader.initialize_splitter_from_config(split_config) - split_manager = SplitManager(split_loader) + processor = load_data_config_from_path(data_csv, config_yaml) + logger.info("Dataset processor initialized successfully.") # apply the split method to the data - processor.add_split(split_manager=split_manager, force=force) + processor.add_split(force=force) + logger.info("Split applied successfully.") # save the modified csv processor.save(out_path) - - -def run() -> None: - """Run the CSV splitting script.""" - args = get_args() - main(args.csv, args.yaml, args.output, force=args.force) - - -if __name__ == "__main__": - run() + logger.info("Split data saved successfully.") diff --git a/tests/cli/test_split_csv.py b/tests/cli/test_split_csv.py index 8e50bf9d..59236205 100644 --- a/tests/cli/test_split_csv.py +++ b/tests/cli/test_split_csv.py @@ -1,64 +1,130 @@ """Tests for the split_csv CLI command.""" import hashlib -import pathlib +import os import tempfile -from typing import Any, Callable +from pathlib import Path import pytest +from click.testing import CliRunner -from src.stimulus.cli.split_csv import main +from stimulus.cli.main import cli +from stimulus.cli.split_csv import split_csv -# Fixtures @pytest.fixture -def csv_path_no_split() -> str: - """Fixture that returns the path to a CSV file without split column.""" - return "tests/test_data/dna_experiment/test.csv" +def csv_path() -> str: + """Get path to test CSV file.""" + return str( + Path(__file__).parent.parent / "test_data" / "titanic" / "titanic_stimulus.csv", + ) @pytest.fixture def csv_path_with_split() -> str: - """Fixture that returns the path to a CSV file with split column.""" - return "tests/test_data/dna_experiment/test_with_split.csv" + """Get path to test CSV file with split column.""" + return str( + Path(__file__).parent.parent / "test_data" / "titanic" / "titanic_stimulus_split.csv", + ) @pytest.fixture def yaml_path() -> str: - """Fixture that returns the path to a YAML config file.""" - return "tests/test_data/dna_experiment/dna_experiment_config_template_0.yaml" - - -# Test cases -test_cases = [ - ("csv_path_no_split", "yaml_path", False, None), - ("csv_path_with_split", "yaml_path", False, ValueError), - ("csv_path_no_split", "yaml_path", True, None), - ("csv_path_with_split", "yaml_path", True, None), -] - - -# Tests -@pytest.mark.skip(reason="There is an issue with non-deterministic output") -@pytest.mark.parametrize(("csv_type", "yaml_type", "force", "error"), test_cases) -def test_split_csv( - request: pytest.FixtureRequest, - snapshot: Callable[[], Any], - csv_type: str, - yaml_type: str, - force: bool, - error: Exception | None, + """Get path to test config YAML file.""" + return str( + Path(__file__).parent.parent / "test_data" / "titanic" / "titanic_unique_split.yaml", + ) + + +def test_split_csv_main( + csv_path: str, + yaml_path: str, ) -> None: - """Tests the CLI command with correct and wrong YAML files.""" - csv_path = request.getfixturevalue(csv_type) - yaml_path = request.getfixturevalue(yaml_type) - tmpdir = pathlib.Path(tempfile.gettempdir()) - if error: - with pytest.raises(error): # type: ignore[call-overload] - main(csv_path, yaml_path, str(tmpdir / "test.csv"), force=force) - else: - filename = f"{csv_type}_{force}.csv" - main(csv_path, yaml_path, str(tmpdir / filename), force=force) - with open(tmpdir / filename) as file: + """Test that split_csv runs without errors. + + Args: + csv_path: Path to test CSV data. + yaml_path: Path to test config YAML. + """ + # Verify required files exist + assert os.path.exists(csv_path), f"CSV file not found at {csv_path}" + assert os.path.exists(yaml_path), f"YAML config not found at {yaml_path}" + + # Create temporary output file + with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as tmp_file: + output_path = tmp_file.name + + try: + # Run main function - should complete without errors + split_csv( + data_csv=csv_path, + config_yaml=yaml_path, + out_path=output_path, + ) + + # Verify output file exists and has content + assert os.path.exists(output_path), "Output file was not created" + with open(output_path) as file: hash = hashlib.md5(file.read().encode()).hexdigest() # noqa: S324 - assert hash == snapshot + assert hash # Verify we got a hash (file not empty) + + finally: + # Clean up temporary file + if os.path.exists(output_path): + os.unlink(output_path) + + +def test_split_csv_with_force( + csv_path_with_split: str, + yaml_path: str, +) -> None: + """Test split_csv with force flag on file that already has split column. + + Args: + csv_path_with_split: Path to test CSV data with split column. + yaml_path: Path to test config YAML. + """ + with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as tmp_file: + output_path = tmp_file.name + + try: + split_csv( + data_csv=csv_path_with_split, + config_yaml=yaml_path, + out_path=output_path, + force=True, + ) + assert os.path.exists(output_path) + + finally: + if os.path.exists(output_path): + os.unlink(output_path) + + +def test_cli_invocation( + csv_path: str, + yaml_path: str, +) -> None: + """Test the CLI invocation of split-csv command. + + Args: + csv_path: Path to test CSV data. + yaml_path: Path to test config YAML. + """ + runner = CliRunner() + with runner.isolated_filesystem(): + output_path = "output.csv" + result = runner.invoke( + cli, + [ + "split-csv", + "-c", + csv_path, + "-y", + yaml_path, + "-o", + output_path, + ], + ) + assert result.exit_code == 0 + assert os.path.exists(output_path), "Output file was not created"