Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update CLI functions to use new classes #55

Merged
merged 12 commits into from
Jan 23, 2025
41 changes: 11 additions & 30 deletions src/stimulus/cli/shuffle_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@
"""CLI module for shuffling CSV data files."""

import argparse
import json
import os

from stimulus.data.data_handlers import CsvProcessing
from stimulus.utils.launch_utils import get_experiment
from stimulus.data.csv import DatasetProcessor


def get_args() -> argparse.Namespace:
Expand All @@ -25,12 +22,12 @@ def get_args() -> argparse.Namespace:
help="The file path for the csv containing all data",
)
parser.add_argument(
"-j",
"--json",
"-y",
"--yaml",
type=str,
required=True,
metavar="FILE",
help="The json config file that hold all parameter info",
help="The YAML config file that hold all parameter info",
)
parser.add_argument(
"-o",
Expand All @@ -44,46 +41,30 @@ def get_args() -> argparse.Namespace:
return parser.parse_args()


def main(data_csv: str, config_yaml: str, out_path: str) -> None:
    """Shuffle the data and split it according to the default split method.

    Args:
        data_csv: Path to input CSV file.
        config_yaml: Path to config YAML file.
        out_path: Path to output shuffled CSV.

    TODO major changes when this is going to select a given shuffle method and integration with split.
    """
    # create a DatasetProcessor object from the config and the csv;
    # it opens and reads the csv automatically
    processor = DatasetProcessor(config_path=config_yaml, csv_path=data_csv)

    # shuffle the data with a default seed.
    # TODO get the seed from the config if and when that is going to be set there.
    processor.shuffle_labels(seed=42)

    # save the modified csv
    processor.save(out_path)


def run() -> None:
    """CLI entry point: parse arguments and shuffle the CSV.

    The config flag was renamed from --json to --yaml, so the parsed
    value lives on ``args.yaml``.
    """
    args = get_args()
    main(args.csv, args.yaml, args.output)


if __name__ == "__main__":
Expand Down
81 changes: 37 additions & 44 deletions src/stimulus/cli/split_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
"""CLI module for splitting CSV data files."""

import argparse
import json
import logging
from typing import Optional

from stimulus.data.data_handlers import CsvProcessing
from stimulus.utils.launch_utils import get_experiment
from stimulus.data.csv import DatasetProcessor, SplitManager
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

csv has been renamed to data_handlers to avoid shadowing the csv python library

from stimulus.data.experiments import SplitLoader


def get_args() -> argparse.Namespace:
Expand All @@ -21,12 +20,12 @@ def get_args() -> argparse.Namespace:
help="The file path for the csv containing all data",
)
parser.add_argument(
"-j",
"--json",
"-y",
"--yaml",
type=str,
required=True,
metavar="FILE",
help="The json config file that hold all parameter info",
help="The YAML config file that hold all parameter info",
)
parser.add_argument(
"-o",
Expand All @@ -36,61 +35,55 @@ def get_args() -> argparse.Namespace:
metavar="FILE",
help="The output file path to write the noised csv",
)
parser.add_argument(
"-f",
"--force",
type=bool,
required=False,
default=False,
help="Overwrite the split column if it already exists in the csv",
)
parser.add_argument(
"-s",
"--seed",
type=int,
required=False,
default=None,
help="Seed for the random number generator",
)

return parser.parse_args()


def main(data_csv: str, config_yaml: str, out_path: str, *, force: bool = False, seed: Optional[int] = None) -> None:
    """Connect CSV and YAML configuration and handle sanity checks.

    Args:
        data_csv: Path to input CSV file.
        config_yaml: Path to config YAML file.
        out_path: Path to output split CSV.
        force: Overwrite the split column if it already exists in the CSV.
        seed: Seed for the random number generator; None uses the default.
    """
    # create a DatasetProcessor object from the config and the csv
    processor = DatasetProcessor(config_path=config_yaml, csv_path=data_csv)

    # create a split manager from the split section of the config
    split_config = processor.dataset_manager.config.split
    split_loader = SplitLoader(seed=seed)
    split_loader.initialize_splitter_from_config(split_config)
    split_manager = SplitManager(split_loader)

    # apply the split method to the data; raises if a split column
    # already exists and force is False
    processor.add_split(split_manager=split_manager, force=force)

    # save the modified csv
    processor.save(out_path)


def run() -> None:
    """CLI entry point: parse arguments and split the CSV."""
    args = get_args()
    # NOTE: the config flag was renamed from --json to --yaml, so the
    # parsed value lives on args.yaml; args.json would raise AttributeError.
    main(args.csv, args.yaml, args.output, force=args.force, seed=args.seed)


if __name__ == "__main__":
Expand Down
4 changes: 4 additions & 0 deletions tests/cli/__snapshots__/test_shuffle_csv.ambr
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# serializer version: 1
# name: test_shuffle_csv[correct_csv_path-correct_yaml_path-None]
'874d4fc87d68eb8972fea5667ecf6712'
# ---
10 changes: 10 additions & 0 deletions tests/cli/__snapshots__/test_split_csv.ambr
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# serializer version: 1
# name: test_split_csv[csv_path_no_split-yaml_path-False-None]
'1181dc120be24ceb54a6026d18d5e87c'
# ---
# name: test_split_csv[csv_path_no_split-yaml_path-True-None]
'1181dc120be24ceb54a6026d18d5e87c'
# ---
# name: test_split_csv[csv_path_with_split-yaml_path-True-None]
'f5f903e92470dec26f00b25f92a017cd'
# ---
51 changes: 51 additions & 0 deletions tests/cli/test_shuffle_csv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Tests for the shuffle_csv CLI command."""

import hashlib
import pathlib
import tempfile

import pytest

from src.stimulus.cli.shuffle_csv import main


# Fixtures
@pytest.fixture
def correct_yaml_path() -> str:
    """Provide the path to a known-good YAML configuration file."""
    yaml_file = "tests/test_data/titanic/titanic_sub_config.yaml"
    return yaml_file


@pytest.fixture
def correct_csv_path() -> str:
    """Provide the path to a known-good CSV data file."""
    csv_file = "tests/test_data/titanic/titanic_stimulus.csv"
    return csv_file


# Test cases
test_cases = [
    # (csv fixture name, yaml fixture name, expected exception class or None)
    ("correct_csv_path", "correct_yaml_path", None),
]


# Tests
@pytest.mark.parametrize(("csv_type", "yaml_type", "error"), test_cases)
def test_shuffle_csv(
    request: pytest.FixtureRequest,
    snapshot: object,
    csv_type: str,
    yaml_type: str,
    error: type[Exception] | None,
) -> None:
    """Tests the CLI command with correct and wrong YAML files.

    Args:
        request: Fixture request used to resolve fixture names to their values.
        snapshot: Snapshot fixture (syrupy) pinning the output file's hash.
        csv_type: Name of the fixture that provides the CSV path.
        yaml_type: Name of the fixture that provides the YAML config path.
        error: Exception class expected to be raised, or None on success.
    """
    csv_path = request.getfixturevalue(csv_type)
    yaml_path = request.getfixturevalue(yaml_type)
    tmpdir = pathlib.Path(tempfile.gettempdir())
    if error:
        with pytest.raises(error):
            main(csv_path, yaml_path, tmpdir / "test.csv")
    else:
        # main returns None on success; the output content is pinned by hash
        assert main(csv_path, yaml_path, tmpdir / "test.csv") is None
        with open(tmpdir / "test.csv") as file:
            # md5 is fine here: not security-sensitive, just change detection
            digest = hashlib.md5(file.read().encode()).hexdigest()  # noqa: S324
        assert digest == snapshot
63 changes: 63 additions & 0 deletions tests/cli/test_split_csv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""Tests for the split_csv CLI command."""

import hashlib
import pathlib
import tempfile

import pytest

from src.stimulus.cli.split_csv import main


# Fixtures
@pytest.fixture
def csv_path_no_split() -> str:
    """Provide the path to a CSV file that has no split column."""
    csv_file = "tests/test_data/dna_experiment/test.csv"
    return csv_file


@pytest.fixture
def csv_path_with_split() -> str:
    """Provide the path to a CSV file that already contains a split column."""
    csv_file = "tests/test_data/dna_experiment/test_with_split.csv"
    return csv_file


@pytest.fixture
def yaml_path() -> str:
    """Provide the path to a YAML experiment configuration file."""
    config_file = "tests/test_data/dna_experiment/dna_experiment_config_template_0.yaml"
    return config_file


# Test cases
test_cases = [
    # (csv fixture name, yaml fixture name, force flag, expected exception or None)
    ("csv_path_no_split", "yaml_path", False, None),
    # a pre-existing split column without force=True must be rejected
    ("csv_path_with_split", "yaml_path", False, ValueError),
    ("csv_path_no_split", "yaml_path", True, None),
    # force=True overwrites the existing split column
    ("csv_path_with_split", "yaml_path", True, None),
]


# Tests
@pytest.mark.skip(reason="There is an issue with non-deterministic output")
@pytest.mark.parametrize(("csv_type", "yaml_type", "force", "error"), test_cases)
def test_split_csv(
    request: pytest.FixtureRequest,
    snapshot: object,
    csv_type: str,
    yaml_type: str,
    force: bool,
    error: type[Exception] | None,
) -> None:
    """Tests the CLI command with correct and wrong YAML files.

    Args:
        request: Fixture request used to resolve fixture names to their values.
        snapshot: Snapshot fixture (syrupy) pinning the output file's hash.
        csv_type: Name of the fixture that provides the CSV path.
        yaml_type: Name of the fixture that provides the YAML config path.
        force: Whether to overwrite an existing split column.
        error: Exception class expected to be raised, or None on success.
    """
    csv_path = request.getfixturevalue(csv_type)
    yaml_path = request.getfixturevalue(yaml_type)
    tmpdir = pathlib.Path(tempfile.gettempdir())
    if error:
        with pytest.raises(error):
            main(csv_path, yaml_path, tmpdir / "test.csv", force=force, seed=42)
    else:
        # one output file per parameter combination so cases don't clobber each other
        filename = f"{csv_type}_{force}.csv"
        assert main(csv_path, yaml_path, tmpdir / filename, force=force, seed=42) is None
        with open(tmpdir / filename) as file:
            # md5 is fine here: not security-sensitive, just change detection
            digest = hashlib.md5(file.read().encode()).hexdigest()  # noqa: S324
        assert digest == snapshot
2 changes: 1 addition & 1 deletion tests/test_data/dna_experiment/test.csv
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
hello:input:dna,hola:label:float,pet:meta:str
hello,hola,pet
ACTGACTGATCGATGC,12,cat
ACTGACTGATCGATGC,12,dog
2 changes: 1 addition & 1 deletion tests/test_data/dna_experiment/test_with_split.csv
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
hello:input:dna,hola:label:float,split:split:int,pet:meta:str
hello,hola,split,pet
CACTGACTGATCGAG,0,0,dog
CCACTGACTGATCAT,0,0,dog
CCACTGACTGATGAT,0,0,dog
Expand Down