diff --git a/src/stimulus/cli/shuffle_csv.py b/src/stimulus/cli/shuffle_csv.py
index 26ed0edc..794de133 100755
--- a/src/stimulus/cli/shuffle_csv.py
+++ b/src/stimulus/cli/shuffle_csv.py
@@ -2,11 +2,8 @@
 """CLI module for shuffling CSV data files."""
 
 import argparse
-import json
-import os
 
-from stimulus.data.data_handlers import CsvProcessing
-from stimulus.utils.launch_utils import get_experiment
+from stimulus.data.data_handlers import DatasetProcessor
 
 
 def get_args() -> argparse.Namespace:
@@ -25,12 +22,12 @@ def get_args() -> argparse.Namespace:
         help="The file path for the csv containing all data",
     )
     parser.add_argument(
-        "-j",
-        "--json",
+        "-y",
+        "--yaml",
         type=str,
         required=True,
         metavar="FILE",
-        help="The json config file that hold all parameter info",
+        help="The YAML config file that holds all parameter info",
     )
     parser.add_argument(
         "-o",
@@ -44,46 +41,30 @@ def get_args() -> argparse.Namespace:
     return parser.parse_args()
 
 
-def main(data_csv: str, config_json: str, out_path: str) -> None:
+def main(data_csv: str, config_yaml: str, out_path: str) -> None:
     """Shuffle the data and split it according to the default split method.
 
     Args:
         data_csv: Path to input CSV file.
-        config_json: Path to config JSON file.
+        config_yaml: Path to config YAML file.
         out_path: Path to output shuffled CSV.
 
     TODO major changes when this is going to select a given shuffle method and integration with split.
     """
-    # open and read Json, just to extract the experiment name, so all other fields are scratched
-    config = None
-    with open(config_json) as in_json:
-        tmp = json.load(in_json)
-        config = tmp
-        # add fake transform informations
-        config["transform"] = "shuffle (special case)"
-
-    # write the config modified, this will be associated to the shuffled data. TODO better solution to renaming like this
-    modified_json = os.path.splitext(os.path.basename(data_csv))[0].split("-split")[0] + "-shuffled-experiment.json"
-    with open(modified_json, "w") as out_json:
-        json.dump(config, out_json)
-
-    # initialize the experiment class
-    exp_obj = get_experiment(config["experiment"])
-
-    # initialize the csv processing class, it open and reads the csv in automatic
-    csv_obj = CsvProcessing(exp_obj, data_csv)
+    # create a DatasetProcessor object from the config and the csv
+    processor = DatasetProcessor(config_path=config_yaml, csv_path=data_csv)
 
     # shuffle the data with a default seed. TODO get the seed for the config if and when that is going to be set there.
-    csv_obj.shuffle_labels(seed=42)
+    processor.shuffle_labels(seed=42)
 
     # save the modified csv
-    csv_obj.save(out_path)
+    processor.save(out_path)
 
 
 def run() -> None:
     """Run the CSV shuffling script."""
     args = get_args()
-    main(args.csv, args.json, args.output)
+    main(args.csv, args.yaml, args.output)
 
 
 if __name__ == "__main__":
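For reference, the refactored shuffle entry point can also be driven directly from Python rather than through the CLI wrapper. A minimal sketch, assuming the titanic fixtures referenced by the new tests later in this diff; the output path is hypothetical:

```python
# Sketch: drive the refactored shuffle entry point directly (signature as in the diff above).
from stimulus.cli.shuffle_csv import main

main(
    data_csv="tests/test_data/titanic/titanic_stimulus.csv",  # fixture used in tests/cli/test_shuffle_csv.py
    config_yaml="tests/test_data/titanic/titanic_sub_config.yaml",
    out_path="titanic_shuffled.csv",  # hypothetical output location
)
```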
diff --git a/src/stimulus/cli/split_csv.py b/src/stimulus/cli/split_csv.py
index 00ab5737..a73328c9 100755
--- a/src/stimulus/cli/split_csv.py
+++ b/src/stimulus/cli/split_csv.py
@@ -2,11 +2,10 @@
 """CLI module for splitting CSV data files."""
 
 import argparse
-import json
-import logging
+from typing import Optional
 
-from stimulus.data.data_handlers import CsvProcessing
-from stimulus.utils.launch_utils import get_experiment
+from stimulus.data.data_handlers import DatasetProcessor, SplitManager
+from stimulus.data.experiments import SplitLoader
 
 
 def get_args() -> argparse.Namespace:
@@ -21,12 +20,12 @@ def get_args() -> argparse.Namespace:
         help="The file path for the csv containing all data",
     )
     parser.add_argument(
-        "-j",
-        "--json",
+        "-y",
+        "--yaml",
        type=str,
         required=True,
         metavar="FILE",
-        help="The json config file that hold all parameter info",
+        help="The YAML config file that holds all parameter info",
     )
     parser.add_argument(
         "-o",
@@ -36,61 +35,55 @@ def get_args() -> argparse.Namespace:
         metavar="FILE",
         help="The output file path to write the noised csv",
     )
+    parser.add_argument(
+        "-f",
+        "--force",
+        type=bool,
+        required=False,
+        default=False,
+        help="Overwrite the split column if it already exists in the csv",
+    )
+    parser.add_argument(
+        "-s",
+        "--seed",
+        type=int,
+        required=False,
+        default=None,
+        help="Seed for the random number generator",
+    )
 
     return parser.parse_args()
 
 
-def main(data_csv: str, config_json: str, out_path: str) -> None:
-    """Connect CSV and JSON configuration and handle sanity checks.
+def main(data_csv: str, config_yaml: str, out_path: str, *, force: bool = False, seed: Optional[int] = None) -> None:
+    """Connect CSV and YAML configuration and handle sanity checks.
 
     Args:
         data_csv: Path to input CSV file.
-        config_json: Path to config JSON file.
+        config_yaml: Path to config YAML file.
         out_path: Path to output split CSV.
-
-    TODO what happens when the user write his own experiment class? how should he do it ? how does it integrates here?
+        force: Overwrite the split column if it already exists in the CSV.
+        seed: Seed for the random number generator.
     """
-    # open and read Json
-    config = {}
-    with open(config_json) as in_json:
-        config = json.load(in_json)
-
-    # initialize the experiment class
-    exp_obj = get_experiment(config["experiment"])
-
-    # initialize the csv processing class, it open and reads the csv in automatic
-    csv_obj = CsvProcessing(exp_obj, data_csv)
-
-    # CASE 1: SPLIT in csv, not in json --> keep the split from the csv
-    if "split" in csv_obj.check_and_get_categories() and config["split"] is None:
-        pass
-
-    # CASE 2: SPLIT in csv and in json --> use the split from the json
-    # TODO change this behaviour to do both, maybe
-    elif "split" in csv_obj.check_and_get_categories() and config["split"]:
-        logging.info("SPLIT present in both csv and json --> use the split from the json")
-        csv_obj.add_split(config["split"], force=True)
+    # create a DatasetProcessor object from the config and the csv
+    processor = DatasetProcessor(config_path=config_yaml, csv_path=data_csv)
 
-    # CASE 3: SPLIT nor in csv and or json --> use the default RandomSplitter
-    elif "split" not in csv_obj.check_and_get_categories() and config["split"] is None:
-        # In case no split is provided, we use the default RandomSplitter
-        logging.warning("SPLIT nor in csv and or json --> use the default RandomSplitter")
-        # if the user config is None then set to default splitter -> RandomSplitter.
-        config_default = {"name": "RandomSplitter", "params": {}}
-        csv_obj.add_split(config_default)
+    # create a split manager from the config
+    split_config = processor.dataset_manager.config.split
+    split_loader = SplitLoader(seed=seed)
+    split_loader.initialize_splitter_from_config(split_config)
+    split_manager = SplitManager(split_loader)
 
-    # CASE 4: SPLIT in json, not in csv --> use the split from the json
-    else:
-        csv_obj.add_split(config["split"], force=True)
+    # apply the split method to the data
+    processor.add_split(split_manager=split_manager, force=force)
 
     # save the modified csv
-    csv_obj.save(out_path)
+    processor.save(out_path)
 
 
 def run() -> None:
     """Run the CSV splitting script."""
     args = get_args()
-    main(args.csv, args.json, args.output)
+    main(args.csv, args.yaml, args.output, force=args.force, seed=args.seed)
 
 
 if __name__ == "__main__":
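Likewise for the split entry point; a minimal sketch mirroring the test cases added further down, using the dna_experiment fixture paths and an illustrative output path and seed:

```python
# Sketch: add a split column to a CSV that lacks one, using the split section of the
# YAML config (wired through SplitLoader/SplitManager as in the diff above).
from stimulus.cli.split_csv import main

main(
    "tests/test_data/dna_experiment/test.csv",
    "tests/test_data/dna_experiment/dna_experiment_config_template_0.yaml",
    "test_split.csv",  # hypothetical output path
    force=False,  # set True to overwrite an existing split column
    seed=42,
)
```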
""" - # open and read Json - config = {} - with open(config_json) as in_json: - config = json.load(in_json) - - # initialize the experiment class - exp_obj = get_experiment(config["experiment"]) - - # initialize the csv processing class, it open and reads the csv in automatic - csv_obj = CsvProcessing(exp_obj, data_csv) - - # CASE 1: SPLIT in csv, not in json --> keep the split from the csv - if "split" in csv_obj.check_and_get_categories() and config["split"] is None: - pass - - # CASE 2: SPLIT in csv and in json --> use the split from the json - # TODO change this behaviour to do both, maybe - elif "split" in csv_obj.check_and_get_categories() and config["split"]: - logging.info("SPLIT present in both csv and json --> use the split from the json") - csv_obj.add_split(config["split"], force=True) + # create a DatasetProcessor object from the config and the csv + processor = DatasetProcessor(config_path=config_yaml, csv_path=data_csv) - # CASE 3: SPLIT nor in csv and or json --> use the default RandomSplitter - elif "split" not in csv_obj.check_and_get_categories() and config["split"] is None: - # In case no split is provided, we use the default RandomSplitter - logging.warning("SPLIT nor in csv and or json --> use the default RandomSplitter") - # if the user config is None then set to default splitter -> RandomSplitter. - config_default = {"name": "RandomSplitter", "params": {}} - csv_obj.add_split(config_default) + # create a split manager from the config + split_config = processor.dataset_manager.config.split + split_loader = SplitLoader(seed=seed) + split_loader.initialize_splitter_from_config(split_config) + split_manager = SplitManager(split_loader) - # CASE 4: SPLIT in json, not in csv --> use the split from the json - else: - csv_obj.add_split(config["split"], force=True) + # apply the split method to the data + processor.add_split(split_manager=split_manager, force=force) # save the modified csv - csv_obj.save(out_path) + processor.save(out_path) def run() -> None: """Run the CSV splitting script.""" args = get_args() - main(args.csv, args.json, args.output) + main(args.csv, args.json, args.output, force=args.force, seed=args.seed) if __name__ == "__main__": diff --git a/tests/cli/__snapshots__/test_shuffle_csv.ambr b/tests/cli/__snapshots__/test_shuffle_csv.ambr new file mode 100644 index 00000000..cc57089b --- /dev/null +++ b/tests/cli/__snapshots__/test_shuffle_csv.ambr @@ -0,0 +1,4 @@ +# serializer version: 1 +# name: test_shuffle_csv[correct_csv_path-correct_yaml_path-None] + '874d4fc87d68eb8972fea5667ecf6712' +# --- diff --git a/tests/cli/__snapshots__/test_split_csv.ambr b/tests/cli/__snapshots__/test_split_csv.ambr new file mode 100644 index 00000000..9367ff6d --- /dev/null +++ b/tests/cli/__snapshots__/test_split_csv.ambr @@ -0,0 +1,10 @@ +# serializer version: 1 +# name: test_split_csv[csv_path_no_split-yaml_path-False-None] + '1181dc120be24ceb54a6026d18d5e87c' +# --- +# name: test_split_csv[csv_path_no_split-yaml_path-True-None] + '1181dc120be24ceb54a6026d18d5e87c' +# --- +# name: test_split_csv[csv_path_with_split-yaml_path-True-None] + 'f5f903e92470dec26f00b25f92a017cd' +# --- diff --git a/tests/cli/test_shuffle_csv.py b/tests/cli/test_shuffle_csv.py new file mode 100644 index 00000000..0168c427 --- /dev/null +++ b/tests/cli/test_shuffle_csv.py @@ -0,0 +1,52 @@ +"""Tests for the shuffle_csv CLI command.""" + +import hashlib +import pathlib +import tempfile +from typing import Any, Callable + +import pytest + +from src.stimulus.cli.shuffle_csv import main + + +# 
diff --git a/tests/cli/test_split_csv.py b/tests/cli/test_split_csv.py
new file mode 100644
index 00000000..c4ffb4ef
--- /dev/null
+++ b/tests/cli/test_split_csv.py
@@ -0,0 +1,64 @@
+"""Tests for the split_csv CLI command."""
+
+import hashlib
+import pathlib
+import tempfile
+from typing import Any, Callable
+
+import pytest
+
+from src.stimulus.cli.split_csv import main
+
+
+# Fixtures
+@pytest.fixture
+def csv_path_no_split() -> str:
+    """Fixture that returns the path to a CSV file without split column."""
+    return "tests/test_data/dna_experiment/test.csv"
+
+
+@pytest.fixture
+def csv_path_with_split() -> str:
+    """Fixture that returns the path to a CSV file with split column."""
+    return "tests/test_data/dna_experiment/test_with_split.csv"
+
+
+@pytest.fixture
+def yaml_path() -> str:
+    """Fixture that returns the path to a YAML config file."""
+    return "tests/test_data/dna_experiment/dna_experiment_config_template_0.yaml"
+
+
+# Test cases
+test_cases = [
+    ("csv_path_no_split", "yaml_path", False, None),
+    ("csv_path_with_split", "yaml_path", False, ValueError),
+    ("csv_path_no_split", "yaml_path", True, None),
+    ("csv_path_with_split", "yaml_path", True, None),
+]
+
+
+# Tests
+@pytest.mark.skip(reason="There is an issue with non-deterministic output")
+@pytest.mark.parametrize(("csv_type", "yaml_type", "force", "error"), test_cases)
+def test_split_csv(
+    request: pytest.FixtureRequest,
+    snapshot: Callable[[], Any],
+    csv_type: str,
+    yaml_type: str,
+    force: bool,
+    error: Exception | None,
+) -> None:
+    """Tests the CLI command with CSVs that do and do not already contain a split column."""
+    csv_path = request.getfixturevalue(csv_type)
+    yaml_path = request.getfixturevalue(yaml_type)
+    tmpdir = pathlib.Path(tempfile.gettempdir())
+    if error:
+        with pytest.raises(error):  # type: ignore[call-overload]
+            main(csv_path, yaml_path, str(tmpdir / "test.csv"), force=force, seed=42)
+    else:
+        filename = f"{csv_type}_{force}.csv"
+        main(csv_path, yaml_path, str(tmpdir / filename), force=force, seed=42)
+        with open(tmpdir / filename) as file:
+            hash = hashlib.md5(file.read().encode()).hexdigest()  # noqa: S324
+        assert hash == snapshot
diff --git a/tests/test_data/dna_experiment/test.csv b/tests/test_data/dna_experiment/test.csv
index bcf7c9c1..745281c5 100644
--- a/tests/test_data/dna_experiment/test.csv
+++ b/tests/test_data/dna_experiment/test.csv
@@ -1,3 +1,3 @@
-hello:input:dna,hola:label:float,pet:meta:str
+hello,hola,pet
 ACTGACTGATCGATGC,12,cat
 ACTGACTGATCGATGC,12,dog
\ No newline at end of file
diff --git a/tests/test_data/dna_experiment/test_with_split.csv b/tests/test_data/dna_experiment/test_with_split.csv
index 9de3a6a1..9f14ac53 100644
--- a/tests/test_data/dna_experiment/test_with_split.csv
+++ b/tests/test_data/dna_experiment/test_with_split.csv
@@ -1,4 +1,4 @@
-hello:input:dna,hola:label:float,split:split:int,pet:meta:str
+hello,hola,split,pet
 CACTGACTGATCGAG,0,0,dog
 CCACTGACTGATCAT,0,0,dog
 CCACTGACTGATGAT,0,0,dog