From 57cc87aad984ec5accd30f4045733ed01d62c0ec Mon Sep 17 00:00:00 2001 From: Igor Trujnara Date: Wed, 22 Jan 2025 15:56:08 +0100 Subject: [PATCH 01/11] Refactor shuffle_csv to use new classes --- src/stimulus/cli/shuffle_csv.py | 41 +++++++++------------------------ 1 file changed, 11 insertions(+), 30 deletions(-) diff --git a/src/stimulus/cli/shuffle_csv.py b/src/stimulus/cli/shuffle_csv.py index 7caf7983..be7a12ca 100755 --- a/src/stimulus/cli/shuffle_csv.py +++ b/src/stimulus/cli/shuffle_csv.py @@ -2,11 +2,8 @@ """CLI module for shuffling CSV data files.""" import argparse -import json -import os -from stimulus.data.csv import CsvProcessing -from stimulus.utils.launch_utils import get_experiment +from stimulus.data.csv import DatasetProcessor def get_args() -> argparse.Namespace: @@ -25,12 +22,12 @@ def get_args() -> argparse.Namespace: help="The file path for the csv containing all data", ) parser.add_argument( - "-j", - "--json", + "-y", + "--yaml", type=str, required=True, metavar="FILE", - help="The json config file that hold all parameter info", + help="The YAML config file that hold all parameter info", ) parser.add_argument( "-o", @@ -44,46 +41,30 @@ def get_args() -> argparse.Namespace: return parser.parse_args() -def main(data_csv: str, config_json: str, out_path: str) -> None: +def main(data_csv: str, config_yaml: str, out_path: str) -> None: """Shuffle the data and split it according to the default split method. Args: data_csv: Path to input CSV file. - config_json: Path to config JSON file. + config_yaml: Path to config YAML file. out_path: Path to output shuffled CSV. TODO major changes when this is going to select a given shuffle method and integration with split. """ - # open and read Json, just to extract the experiment name, so all other fields are scratched - config = None - with open(config_json) as in_json: - tmp = json.load(in_json) - config = tmp - # add fake transform informations - config["transform"] = "shuffle (special case)" - - # write the config modified, this will be associated to the shuffled data. TODO better solution to renaming like this - modified_json = os.path.splitext(os.path.basename(data_csv))[0].split("-split")[0] + "-shuffled-experiment.json" - with open(modified_json, "w") as out_json: - json.dump(config, out_json) - - # initialize the experiment class - exp_obj = get_experiment(config["experiment"]) - - # initialize the csv processing class, it open and reads the csv in automatic - csv_obj = CsvProcessing(exp_obj, data_csv) + # create a DatasetProcessor object from the config and the csv + processor = DatasetProcessor(config_path=config_yaml, csv_path=data_csv) # shuffle the data with a default seed. TODO get the seed for the config if and when that is going to be set there. - csv_obj.shuffle_labels(seed=42) + processor.shuffle_labels(seed=42) # save the modified csv - csv_obj.save(out_path) + processor.save(out_path) def run() -> None: """Run the CSV shuffling script.""" args = get_args() - main(args.csv, args.json, args.output) + main(args.csv, args.yaml, args.output) if __name__ == "__main__": From e23f3ed3230c80da8e7c29cbe618d776e0efac88 Mon Sep 17 00:00:00 2001 From: Igor Trujnara Date: Wed, 22 Jan 2025 15:56:19 +0100 Subject: [PATCH 02/11] Add test for shuffle_csv --- tests/cli/test_shuffle_csv.py | 47 +++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 tests/cli/test_shuffle_csv.py diff --git a/tests/cli/test_shuffle_csv.py b/tests/cli/test_shuffle_csv.py new file mode 100644 index 00000000..51009141 --- /dev/null +++ b/tests/cli/test_shuffle_csv.py @@ -0,0 +1,47 @@ +"""Tests for the shuffle_csv CLI command.""" + +import hashlib +import pathlib +import tempfile + +import pytest + +from src.stimulus.cli.shuffle_csv import main + + +# Fixtures +@pytest.fixture +def correct_yaml_path() -> str: + """Fixture that returns the path to a correct YAML file.""" + return "tests/test_data/titanic/titanic_sub_config.yaml" + + +@pytest.fixture +def correct_csv_path() -> str: + """Fixture that returns the path to a correct CSV file.""" + return "tests/test_data/titanic/titanic_stimulus.csv" + + +# Test cases +test_cases = [ + ("correct_csv_path", "correct_yaml_path", None), +] + + +# Tests +@pytest.mark.parametrize(("csv_type", "yaml_type", "error"), test_cases) +def test_shuffle_csv( + request: pytest.FixtureRequest, snapshot: pytest.fixture, csv_type: str, yaml_type: str, error: Exception | None, +) -> None: + """Tests the CLI command with correct and wrong YAML files.""" + csv_path = request.getfixturevalue(csv_type) + yaml_path = request.getfixturevalue(yaml_type) + tmpdir = pathlib.Path(tempfile.gettempdir()) + if error: + with pytest.raises(error): + main(csv_path, yaml_path, tmpdir / "test.csv") + else: + assert main(csv_path, yaml_path, tmpdir / "test.csv") is None + with open(tmpdir / "test.csv") as file: + hash = hashlib.md5(file.read().encode()).hexdigest() # noqa: S324 + assert hash == snapshot From ee64bcdb7a81f94803718a78f36f4a6ab1d522d5 Mon Sep 17 00:00:00 2001 From: Igor Trujnara Date: Wed, 22 Jan 2025 15:56:31 +0100 Subject: [PATCH 03/11] Fix minor issues in DatasetProcessor --- src/stimulus/data/csv.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/stimulus/data/csv.py b/src/stimulus/data/csv.py index d479cb8e..cfb4826e 100644 --- a/src/stimulus/data/csv.py +++ b/src/stimulus/data/csv.py @@ -337,6 +337,7 @@ class DatasetProcessor(DatasetHandler): def __init__(self, config_path: str, csv_path: str) -> None: """Initialize the DatasetProcessor.""" super().__init__(config_path, csv_path) + self.data = self.load_csv(csv_path) def add_split(self, split_manager: SplitManager, *, force: bool = False) -> None: """Add a column specifying the train, validation, test splits of the data. @@ -389,7 +390,7 @@ def shuffle_labels(self, seed: Optional[float] = None) -> None: # set the np seed np.random.seed(seed) - label_keys = self.dataset_manager.get_label_columns()["label"] + label_keys = self.dataset_manager.categorize_columns_by_type()["label"] for key in label_keys: self.data = self.data.with_columns(pl.Series(key, np.random.permutation(list(self.data[key])))) From cf709d18f966892a699dc13c366e3bcda2bdee61 Mon Sep 17 00:00:00 2001 From: Igor Trujnara Date: Wed, 22 Jan 2025 15:56:46 +0100 Subject: [PATCH 04/11] Auto fix formatting --- tests/cli/__snapshots__/test_shuffle_csv.ambr | 4 ++++ tests/cli/test_split_yaml.py | 6 ++++-- tests/data/test_experiment.py | 2 +- 3 files changed, 9 insertions(+), 3 deletions(-) create mode 100644 tests/cli/__snapshots__/test_shuffle_csv.ambr diff --git a/tests/cli/__snapshots__/test_shuffle_csv.ambr b/tests/cli/__snapshots__/test_shuffle_csv.ambr new file mode 100644 index 00000000..cc57089b --- /dev/null +++ b/tests/cli/__snapshots__/test_shuffle_csv.ambr @@ -0,0 +1,4 @@ +# serializer version: 1 +# name: test_shuffle_csv[correct_csv_path-correct_yaml_path-None] + '874d4fc87d68eb8972fea5667ecf6712' +# --- diff --git a/tests/cli/test_split_yaml.py b/tests/cli/test_split_yaml.py index aaab6d8f..95e65e0b 100644 --- a/tests/cli/test_split_yaml.py +++ b/tests/cli/test_split_yaml.py @@ -31,7 +31,9 @@ def wrong_yaml_path() -> str: # Tests @pytest.mark.parametrize(("yaml_type", "error"), test_cases) -def test_split_yaml(request: pytest.FixtureRequest, snapshot: pytest.fixture, yaml_type: str, error: Exception | None) -> None: +def test_split_yaml( + request: pytest.FixtureRequest, snapshot: pytest.fixture, yaml_type: str, error: Exception | None +) -> None: """Tests the CLI command with correct and wrong YAML files.""" yaml_path = request.getfixturevalue(yaml_type) tmpdir = tempfile.gettempdir() @@ -46,4 +48,4 @@ def test_split_yaml(request: pytest.FixtureRequest, snapshot: pytest.fixture, ya for f in test_out: with open(os.path.join(tmpdir, f)) as file: hashes.append(hashlib.md5(file.read().encode()).hexdigest()) # noqa: S324 - assert sorted(hashes) == snapshot # sorted ensures that the order of the hashes does not matter + assert sorted(hashes) == snapshot # sorted ensures that the order of the hashes does not matter diff --git a/tests/data/test_experiment.py b/tests/data/test_experiment.py index e27d54e7..68096e96 100644 --- a/tests/data/test_experiment.py +++ b/tests/data/test_experiment.py @@ -2,8 +2,8 @@ import yaml from stimulus.data import experiments -from stimulus.data.splitters import splitters from stimulus.data.encoding.encoders import AbstractEncoder +from stimulus.data.splitters import splitters from stimulus.data.transform import data_transformation_generators from stimulus.utils import yaml_data From 045f80523080d82e49219a26d32585039a9d1dd7 Mon Sep 17 00:00:00 2001 From: Igor Trujnara Date: Wed, 22 Jan 2025 17:04:36 +0100 Subject: [PATCH 05/11] Update DNA CSV files --- tests/test_data/dna_experiment/test.csv | 2 +- tests/test_data/dna_experiment/test_with_split.csv | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_data/dna_experiment/test.csv b/tests/test_data/dna_experiment/test.csv index bcf7c9c1..745281c5 100644 --- a/tests/test_data/dna_experiment/test.csv +++ b/tests/test_data/dna_experiment/test.csv @@ -1,3 +1,3 @@ -hello:input:dna,hola:label:float,pet:meta:str +hello,hola,pet ACTGACTGATCGATGC,12,cat ACTGACTGATCGATGC,12,dog \ No newline at end of file diff --git a/tests/test_data/dna_experiment/test_with_split.csv b/tests/test_data/dna_experiment/test_with_split.csv index 9de3a6a1..9f14ac53 100644 --- a/tests/test_data/dna_experiment/test_with_split.csv +++ b/tests/test_data/dna_experiment/test_with_split.csv @@ -1,4 +1,4 @@ -hello:input:dna,hola:label:float,split:split:int,pet:meta:str +hello,hola,split,pet CACTGACTGATCGAG,0,0,dog CCACTGACTGATCAT,0,0,dog CCACTGACTGATGAT,0,0,dog From 1d2ad929b9faae31009a6ff503c74a280226a199 Mon Sep 17 00:00:00 2001 From: Igor Trujnara Date: Wed, 22 Jan 2025 17:05:00 +0100 Subject: [PATCH 06/11] Refactor split_csv to work with new classes --- src/stimulus/cli/split_csv.py | 81 ++++++++++++++++------------------- 1 file changed, 37 insertions(+), 44 deletions(-) diff --git a/src/stimulus/cli/split_csv.py b/src/stimulus/cli/split_csv.py index 435e5be4..166c21ca 100755 --- a/src/stimulus/cli/split_csv.py +++ b/src/stimulus/cli/split_csv.py @@ -2,11 +2,10 @@ """CLI module for splitting CSV data files.""" import argparse -import json -import logging +from typing import Optional -from stimulus.data.csv import CsvProcessing -from stimulus.utils.launch_utils import get_experiment +from stimulus.data.csv import DatasetProcessor, SplitManager +from stimulus.data.experiments import SplitLoader def get_args() -> argparse.Namespace: @@ -21,12 +20,12 @@ def get_args() -> argparse.Namespace: help="The file path for the csv containing all data", ) parser.add_argument( - "-j", - "--json", + "-y", + "--yaml", type=str, required=True, metavar="FILE", - help="The json config file that hold all parameter info", + help="The YAML config file that hold all parameter info", ) parser.add_argument( "-o", @@ -36,61 +35,55 @@ def get_args() -> argparse.Namespace: metavar="FILE", help="The output file path to write the noised csv", ) + parser.add_argument( + "-f", + "--force", + type=bool, + required=False, + default=False, + help="Overwrite the split column if it already exists in the csv", + ) + parser.add_argument( + "-s", + "--seed", + type=int, + required=False, + default=None, + help="Seed for the random number generator", + ) return parser.parse_args() -def main(data_csv: str, config_json: str, out_path: str) -> None: - """Connect CSV and JSON configuration and handle sanity checks. +def main(data_csv: str, config_yaml: str, out_path: str, *, force: bool = False, seed: Optional[int] = None) -> None: + """Connect CSV and YAML configuration and handle sanity checks. Args: data_csv: Path to input CSV file. - config_json: Path to config JSON file. + config_json: Path to config YAML file. out_path: Path to output split CSV. - - TODO what happens when the user write his own experiment class? how should he do it ? how does it integrates here? + force: Overwrite the split column if it already exists in the CSV. """ - # open and read Json - config = {} - with open(config_json) as in_json: - config = json.load(in_json) - - # initialize the experiment class - exp_obj = get_experiment(config["experiment"]) - - # initialize the csv processing class, it open and reads the csv in automatic - csv_obj = CsvProcessing(exp_obj, data_csv) - - # CASE 1: SPLIT in csv, not in json --> keep the split from the csv - if "split" in csv_obj.check_and_get_categories() and config["split"] is None: - pass - - # CASE 2: SPLIT in csv and in json --> use the split from the json - # TODO change this behaviour to do both, maybe - elif "split" in csv_obj.check_and_get_categories() and config["split"]: - logging.info("SPLIT present in both csv and json --> use the split from the json") - csv_obj.add_split(config["split"], force=True) + # create a DatasetProcessor object from the config and the csv + processor = DatasetProcessor(config_path=config_yaml, csv_path=data_csv) - # CASE 3: SPLIT nor in csv and or json --> use the default RandomSplitter - elif "split" not in csv_obj.check_and_get_categories() and config["split"] is None: - # In case no split is provided, we use the default RandomSplitter - logging.warning("SPLIT nor in csv and or json --> use the default RandomSplitter") - # if the user config is None then set to default splitter -> RandomSplitter. - config_default = {"name": "RandomSplitter", "params": {}} - csv_obj.add_split(config_default) + # create a split manager from the config + split_config = processor.dataset_manager.config.split + split_loader = SplitLoader(seed=seed) + split_loader.initialize_splitter_from_config(split_config) + split_manager = SplitManager(split_loader) - # CASE 4: SPLIT in json, not in csv --> use the split from the json - else: - csv_obj.add_split(config["split"], force=True) + # apply the split method to the data + processor.add_split(split_manager=split_manager, force=force) # save the modified csv - csv_obj.save(out_path) + processor.save(out_path) def run() -> None: """Run the CSV splitting script.""" args = get_args() - main(args.csv, args.json, args.output) + main(args.csv, args.json, args.output, force=args.force, seed=args.seed) if __name__ == "__main__": From e58f52b640466aa230d88b66e7e0be8f715a88e6 Mon Sep 17 00:00:00 2001 From: Igor Trujnara Date: Wed, 22 Jan 2025 17:05:22 +0100 Subject: [PATCH 07/11] Create test_split_csv and snapshot --- tests/cli/__snapshots__/test_split_csv.ambr | 10 ++++ tests/cli/test_split_csv.py | 63 +++++++++++++++++++++ 2 files changed, 73 insertions(+) create mode 100644 tests/cli/__snapshots__/test_split_csv.ambr create mode 100644 tests/cli/test_split_csv.py diff --git a/tests/cli/__snapshots__/test_split_csv.ambr b/tests/cli/__snapshots__/test_split_csv.ambr new file mode 100644 index 00000000..9367ff6d --- /dev/null +++ b/tests/cli/__snapshots__/test_split_csv.ambr @@ -0,0 +1,10 @@ +# serializer version: 1 +# name: test_split_csv[csv_path_no_split-yaml_path-False-None] + '1181dc120be24ceb54a6026d18d5e87c' +# --- +# name: test_split_csv[csv_path_no_split-yaml_path-True-None] + '1181dc120be24ceb54a6026d18d5e87c' +# --- +# name: test_split_csv[csv_path_with_split-yaml_path-True-None] + 'f5f903e92470dec26f00b25f92a017cd' +# --- diff --git a/tests/cli/test_split_csv.py b/tests/cli/test_split_csv.py new file mode 100644 index 00000000..664955ca --- /dev/null +++ b/tests/cli/test_split_csv.py @@ -0,0 +1,63 @@ +"""Tests for the split_csv CLI command.""" + +import hashlib +import pathlib +import tempfile + +import pytest + +from src.stimulus.cli.split_csv import main + + +# Fixtures +@pytest.fixture +def csv_path_no_split() -> str: + """Fixture that returns the path to a CSV file without split column.""" + return "tests/test_data/dna_experiment/test.csv" + + +@pytest.fixture +def csv_path_with_split() -> str: + """Fixture that returns the path to a CSV file with split column.""" + return "tests/test_data/dna_experiment/test_with_split.csv" + + +@pytest.fixture +def yaml_path() -> str: + """Fixture that returns the path to a YAML config file.""" + return "tests/test_data/dna_experiment/dna_experiment_config_template_0.yaml" + + +# Test cases +test_cases = [ + ("csv_path_no_split", "yaml_path", False, None), + ("csv_path_with_split", "yaml_path", False, ValueError), + ("csv_path_no_split", "yaml_path", True, None), + ("csv_path_with_split", "yaml_path", True, None), +] + + +# Tests +@pytest.mark.skip(reason="There is an issue with non-deterministic output") +@pytest.mark.parametrize(("csv_type", "yaml_type", "force", "error"), test_cases) +def test_split_csv( + request: pytest.FixtureRequest, + snapshot: pytest.fixture, + csv_type: str, + yaml_type: str, + force: bool, + error: Exception | None, +) -> None: + """Tests the CLI command with correct and wrong YAML files.""" + csv_path = request.getfixturevalue(csv_type) + yaml_path = request.getfixturevalue(yaml_type) + tmpdir = pathlib.Path(tempfile.gettempdir()) + if error: + with pytest.raises(error): + main(csv_path, yaml_path, tmpdir / "test.csv", force=force, seed=42) + else: + filename = f"{csv_type}_{force}.csv" + assert main(csv_path, yaml_path, tmpdir / filename, force=force, seed=42) is None + with open(tmpdir / filename) as file: + hash = hashlib.md5(file.read().encode()).hexdigest() # noqa: S324 + assert hash == snapshot From 0f9b5aafbfedea8b3204051541d8903891a47e24 Mon Sep 17 00:00:00 2001 From: Igor Trujnara Date: Wed, 22 Jan 2025 17:05:24 +0100 Subject: [PATCH 08/11] Auto formatting --- tests/cli/test_shuffle_csv.py | 6 +++++- tests/cli/test_split_yaml.py | 5 ++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/cli/test_shuffle_csv.py b/tests/cli/test_shuffle_csv.py index 51009141..d2690914 100644 --- a/tests/cli/test_shuffle_csv.py +++ b/tests/cli/test_shuffle_csv.py @@ -31,7 +31,11 @@ def correct_csv_path() -> str: # Tests @pytest.mark.parametrize(("csv_type", "yaml_type", "error"), test_cases) def test_shuffle_csv( - request: pytest.FixtureRequest, snapshot: pytest.fixture, csv_type: str, yaml_type: str, error: Exception | None, + request: pytest.FixtureRequest, + snapshot: pytest.fixture, + csv_type: str, + yaml_type: str, + error: Exception | None, ) -> None: """Tests the CLI command with correct and wrong YAML files.""" csv_path = request.getfixturevalue(csv_type) diff --git a/tests/cli/test_split_yaml.py b/tests/cli/test_split_yaml.py index 95e65e0b..f44d5c93 100644 --- a/tests/cli/test_split_yaml.py +++ b/tests/cli/test_split_yaml.py @@ -32,7 +32,10 @@ def wrong_yaml_path() -> str: # Tests @pytest.mark.parametrize(("yaml_type", "error"), test_cases) def test_split_yaml( - request: pytest.FixtureRequest, snapshot: pytest.fixture, yaml_type: str, error: Exception | None + request: pytest.FixtureRequest, + snapshot: pytest.fixture, + yaml_type: str, + error: Exception | None, ) -> None: """Tests the CLI command with correct and wrong YAML files.""" yaml_path = request.getfixturevalue(yaml_type) From 01bef45e6700f233b06763d52a2affd8a36b38f9 Mon Sep 17 00:00:00 2001 From: Igor Trujnara Date: Wed, 22 Jan 2025 17:08:50 +0100 Subject: [PATCH 09/11] More auto formatting --- tests/cli/test_split_csv.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/cli/test_split_csv.py b/tests/cli/test_split_csv.py index 664955ca..22a18751 100644 --- a/tests/cli/test_split_csv.py +++ b/tests/cli/test_split_csv.py @@ -57,7 +57,7 @@ def test_split_csv( main(csv_path, yaml_path, tmpdir / "test.csv", force=force, seed=42) else: filename = f"{csv_type}_{force}.csv" - assert main(csv_path, yaml_path, tmpdir / filename, force=force, seed=42) is None + assert main(csv_path, yaml_path, tmpdir / filename, force=force, seed=42) is None with open(tmpdir / filename) as file: - hash = hashlib.md5(file.read().encode()).hexdigest() # noqa: S324 + hash = hashlib.md5(file.read().encode()).hexdigest() # noqa: S324 assert hash == snapshot From f245fe13408c2fc8c903ca20af8b92cacaacb6f1 Mon Sep 17 00:00:00 2001 From: Igor Trujnara Date: Thu, 23 Jan 2025 15:59:13 +0100 Subject: [PATCH 10/11] Update import paths in CLI functions --- src/stimulus/cli/shuffle_csv.py | 2 +- src/stimulus/cli/split_csv.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/stimulus/cli/shuffle_csv.py b/src/stimulus/cli/shuffle_csv.py index be7a12ca..794de133 100755 --- a/src/stimulus/cli/shuffle_csv.py +++ b/src/stimulus/cli/shuffle_csv.py @@ -3,7 +3,7 @@ import argparse -from stimulus.data.csv import DatasetProcessor +from stimulus.data.data_handlers import DatasetProcessor def get_args() -> argparse.Namespace: diff --git a/src/stimulus/cli/split_csv.py b/src/stimulus/cli/split_csv.py index 166c21ca..a73328c9 100755 --- a/src/stimulus/cli/split_csv.py +++ b/src/stimulus/cli/split_csv.py @@ -4,7 +4,7 @@ import argparse from typing import Optional -from stimulus.data.csv import DatasetProcessor, SplitManager +from stimulus.data.data_handlers import DatasetProcessor, SplitManager from stimulus.data.experiments import SplitLoader @@ -60,7 +60,7 @@ def main(data_csv: str, config_yaml: str, out_path: str, *, force: bool = False, Args: data_csv: Path to input CSV file. - config_json: Path to config YAML file. + config_yaml: Path to config YAML file. out_path: Path to output split CSV. force: Overwrite the split column if it already exists in the CSV. """ From 7b6abe0d6ba3d30ed22b4aa9165c3a11d6f82daa Mon Sep 17 00:00:00 2001 From: Igor Trujnara Date: Thu, 23 Jan 2025 15:59:37 +0100 Subject: [PATCH 11/11] Fix linting in shuffle_csv and split_csv tests --- tests/cli/test_shuffle_csv.py | 9 +++++---- tests/cli/test_split_csv.py | 9 +++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/cli/test_shuffle_csv.py b/tests/cli/test_shuffle_csv.py index d2690914..0168c427 100644 --- a/tests/cli/test_shuffle_csv.py +++ b/tests/cli/test_shuffle_csv.py @@ -3,6 +3,7 @@ import hashlib import pathlib import tempfile +from typing import Any, Callable import pytest @@ -32,7 +33,7 @@ def correct_csv_path() -> str: @pytest.mark.parametrize(("csv_type", "yaml_type", "error"), test_cases) def test_shuffle_csv( request: pytest.FixtureRequest, - snapshot: pytest.fixture, + snapshot: Callable[[], Any], csv_type: str, yaml_type: str, error: Exception | None, @@ -42,10 +43,10 @@ def test_shuffle_csv( yaml_path = request.getfixturevalue(yaml_type) tmpdir = pathlib.Path(tempfile.gettempdir()) if error: - with pytest.raises(error): - main(csv_path, yaml_path, tmpdir / "test.csv") + with pytest.raises(error): # type: ignore[call-overload] + main(csv_path, yaml_path, str(tmpdir / "test.csv")) else: - assert main(csv_path, yaml_path, tmpdir / "test.csv") is None + main(csv_path, yaml_path, str(tmpdir / "test.csv")) with open(tmpdir / "test.csv") as file: hash = hashlib.md5(file.read().encode()).hexdigest() # noqa: S324 assert hash == snapshot diff --git a/tests/cli/test_split_csv.py b/tests/cli/test_split_csv.py index 22a18751..c4ffb4ef 100644 --- a/tests/cli/test_split_csv.py +++ b/tests/cli/test_split_csv.py @@ -3,6 +3,7 @@ import hashlib import pathlib import tempfile +from typing import Any, Callable import pytest @@ -42,7 +43,7 @@ def yaml_path() -> str: @pytest.mark.parametrize(("csv_type", "yaml_type", "force", "error"), test_cases) def test_split_csv( request: pytest.FixtureRequest, - snapshot: pytest.fixture, + snapshot: Callable[[], Any], csv_type: str, yaml_type: str, force: bool, @@ -53,11 +54,11 @@ def test_split_csv( yaml_path = request.getfixturevalue(yaml_type) tmpdir = pathlib.Path(tempfile.gettempdir()) if error: - with pytest.raises(error): - main(csv_path, yaml_path, tmpdir / "test.csv", force=force, seed=42) + with pytest.raises(error): # type: ignore[call-overload] + main(csv_path, yaml_path, str(tmpdir / "test.csv"), force=force, seed=42) else: filename = f"{csv_type}_{force}.csv" - assert main(csv_path, yaml_path, tmpdir / filename, force=force, seed=42) is None + main(csv_path, yaml_path, str(tmpdir / filename), force=force, seed=42) with open(tmpdir / filename) as file: hash = hashlib.md5(file.read().encode()).hexdigest() # noqa: S324 assert hash == snapshot