Skip to content

Commit

Permalink
refactor(shuffle_csv): update shuffle_csv to the new paradigm.
Browse files Browse the repository at this point in the history
  • Loading branch information
mathysgrapotte committed Feb 21, 2025
1 parent 09044c7 commit e6a0e23
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 87 deletions.
37 changes: 37 additions & 0 deletions src/stimulus/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,40 @@ def check_model(
ray_results_dirpath=ray_results_dirpath,
debug_mode=debug_mode,
)


@cli.command()
@click.option(
    "-c",
    "--csv",
    type=click.Path(exists=True),
    required=True,
    help="The file path for the csv containing the data in csv format",
)
@click.option(
    "-y",
    "--yaml",
    type=click.Path(exists=True),
    required=True,
    help="The YAML data config",
)
@click.option(
    "-o",
    "--output",
    type=click.Path(),
    required=True,
    help="The output file path to write the shuffled csv",
)
def shuffle_csv(
    csv: str,
    yaml: str,
    output: str,
) -> None:
    """Shuffle rows in a CSV data file."""
    # Imported lazily inside the command so that loading the CLI group
    # does not pull in the data-handling stack until this command runs.
    from stimulus.cli import shuffle_csv as _shuffle_csv_module

    _shuffle_csv_module.shuffle_csv(
        data_csv=csv,
        config_yaml=yaml,
        out_path=output,
    )
86 changes: 33 additions & 53 deletions src/stimulus/cli/shuffle_csv.py
Original file line number Diff line number Diff line change
@@ -1,79 +1,59 @@
#!/usr/bin/env python3
"""CLI module for shuffling CSV data files."""

import argparse
import logging

import yaml

from stimulus.data.data_handlers import DatasetProcessor
from stimulus.utils.yaml_data import YamlSplitTransformDict
from stimulus.data import data_handlers
from stimulus.data.interface import data_config_parser

logger = logging.getLogger(__name__)

def load_data_config_from_path(data_path: str, data_config_path: str) -> tuple[data_handlers.DatasetProcessor, list[str]]:
    """Load a data config from a path and build a dataset processor from it.

    Args:
        data_path: Path to the CSV data file to process.
        data_config_path: Path to the YAML data config file.

    Returns:
        A tuple of the configured ``DatasetProcessor`` and the list of
        label column names declared in the config.
    """
    with open(data_config_path) as file:
        data_config_dict = yaml.safe_load(file)
    data_config_obj = data_config_parser.SplitConfigDict(**data_config_dict)

    splitters = data_config_parser.create_splitter(data_config_obj.split)
    transforms = data_config_parser.create_transforms(data_config_obj.transforms)
    split_columns = data_config_obj.split.split_input_columns
    # Label columns are the ones shuffled downstream by shuffle_labels.
    label_columns = [column.column_name for column in data_config_obj.columns if column.column_type == "label"]

    return data_handlers.DatasetProcessor(
        csv_path=data_path,
        transforms=transforms,
        split_columns=split_columns,
        splitter=splitters,
    ), label_columns

def shuffle_csv(data_csv: str, config_yaml: str, out_path: str) -> None:
    """Shuffle the data and split it according to the default split method.

    Args:
        data_csv: Path to input CSV file.
        config_yaml: Path to config YAML file.
        out_path: Path to output shuffled CSV.

    TODO: major changes when this is going to select a given shuffle
    method and integrate with split.
    """
    processor, label_columns = load_data_config_from_path(data_csv, config_yaml)
    logger.info("Dataset processor initialized successfully.")

    # Shuffle with a fixed default seed so runs are reproducible.
    # TODO: get the seed from the config if and when that is going to be set there.
    processor.shuffle_labels(label_columns, seed=42)
    logger.info("Data shuffled successfully.")

    # save the modified csv
    processor.save(out_path)
    logger.info("Shuffled data saved successfully.")
111 changes: 77 additions & 34 deletions tests/cli/test_shuffle_csv.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,95 @@
"""Tests for the shuffle_csv CLI command."""

import hashlib
import pathlib
import os
import tempfile
from typing import Any, Callable
from pathlib import Path

import pytest
from click.testing import CliRunner

from src.stimulus.cli.shuffle_csv import main
from stimulus.cli.main import cli
from stimulus.cli.shuffle_csv import shuffle_csv


# Fixtures
@pytest.fixture
def csv_path() -> str:
    """Get path to test CSV file."""
    # Resolve relative to this test file so the test works from any cwd.
    return str(
        Path(__file__).parent.parent / "test_data" / "titanic" / "titanic_stimulus.csv",
    )


@pytest.fixture
def yaml_path() -> str:
    """Get path to test config YAML file."""
    # Resolve relative to this test file so the test works from any cwd.
    return str(
        Path(__file__).parent.parent / "test_data" / "titanic" / "titanic_unique_split.yaml",
    )


def test_shuffle_csv_main(
    csv_path: str,
    yaml_path: str,
) -> None:
    """Test that shuffle_csv runs without errors and writes output.

    Args:
        csv_path: Path to test CSV data.
        yaml_path: Path to test config YAML.
    """
    # Verify required files exist
    assert os.path.exists(csv_path), f"CSV file not found at {csv_path}"
    assert os.path.exists(yaml_path), f"YAML config not found at {yaml_path}"

    # Create a temporary output path; delete=False so shuffle_csv can reopen it.
    with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as tmp_file:
        output_path = tmp_file.name

    try:
        # Run function under test - should complete without errors
        shuffle_csv(
            data_csv=csv_path,
            config_yaml=yaml_path,
            out_path=output_path,
        )

        # Verify output file exists and is non-empty (the hexdigest of even an
        # empty file is truthy, so check the file size rather than the hash).
        assert os.path.exists(output_path), "Output file was not created"
        assert os.path.getsize(output_path) > 0, "Output file is empty"
        with open(output_path) as file:
            digest = hashlib.md5(file.read().encode()).hexdigest()  # noqa: S324
        assert digest

    finally:
        # Clean up temporary file
        if os.path.exists(output_path):
            os.unlink(output_path)


def test_cli_invocation(
    csv_path: str,
    yaml_path: str,
) -> None:
    """Test the CLI invocation of the shuffle-csv command.

    Args:
        csv_path: Path to test CSV data.
        yaml_path: Path to test config YAML.
    """
    runner = CliRunner()
    # isolated_filesystem changes cwd; the fixture paths are absolute so the
    # -c/-y inputs still resolve, and the output lands in the temp dir.
    with runner.isolated_filesystem():
        output_path = "output.csv"
        result = runner.invoke(
            cli,
            [
                "shuffle-csv",
                "-c",
                csv_path,
                "-y",
                yaml_path,
                "-o",
                output_path,
            ],
        )
        # Attach the captured output so failures show the CLI's error message.
        assert result.exit_code == 0, result.output
        assert os.path.exists(output_path), "Output file was not created"

0 comments on commit e6a0e23

Please sign in to comment.