-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor(shuffle_csv): updated shuffle_csv to the new paradigm.
- Loading branch information
1 parent
09044c7
commit e6a0e23
Showing
3 changed files
with
147 additions
and
87 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,79 +1,59 @@ | ||
#!/usr/bin/env python3 | ||
"""CLI module for shuffling CSV data files.""" | ||
|
||
import argparse | ||
import logging | ||
|
||
import yaml | ||
|
||
from stimulus.data.data_handlers import DatasetProcessor | ||
from stimulus.utils.yaml_data import YamlSplitTransformDict | ||
from stimulus.data import data_handlers | ||
from stimulus.data.interface import data_config_parser | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
def get_args() -> argparse.Namespace: | ||
"""Get the arguments when using from the commandline. | ||
|
||
def load_data_config_from_path(data_path: str, data_config_path: str) -> data_handlers.DatasetProcessor: | ||
"""Load the data config from a path. | ||
Args: | ||
data_config_path: Path to the data config file. | ||
Returns: | ||
Parsed command line arguments. | ||
A tuple of the parsed configuration. | ||
""" | ||
parser = argparse.ArgumentParser(description="Shuffle rows in a CSV data file.") | ||
parser.add_argument( | ||
"-c", | ||
"--csv", | ||
type=str, | ||
required=True, | ||
metavar="FILE", | ||
help="The file path for the csv containing all data", | ||
) | ||
parser.add_argument( | ||
"-y", | ||
"--yaml", | ||
type=str, | ||
required=True, | ||
metavar="FILE", | ||
help="The YAML config file that hold all parameter info", | ||
) | ||
parser.add_argument( | ||
"-o", | ||
"--output", | ||
type=str, | ||
required=True, | ||
metavar="FILE", | ||
help="The output file path to write the noised csv", | ||
) | ||
with open(data_config_path) as file: | ||
data_config_dict = yaml.safe_load(file) | ||
data_config_obj = data_config_parser.SplitConfigDict(**data_config_dict) | ||
|
||
return parser.parse_args() | ||
splitters = data_config_parser.create_splitter(data_config_obj.split) | ||
transforms = data_config_parser.create_transforms(data_config_obj.transforms) | ||
split_columns = data_config_obj.split.split_input_columns | ||
label_columns = [column.column_name for column in data_config_obj.columns if column.column_type == "label"] | ||
|
||
return data_handlers.DatasetProcessor( | ||
csv_path=data_path, | ||
transforms=transforms, | ||
split_columns=split_columns, | ||
splitter=splitters, | ||
), label_columns | ||
|
||
def main(data_csv: str, config_yaml: str, out_path: str) -> None: | ||
|
||
def shuffle_csv(data_csv: str, config_yaml: str, out_path: str) -> None: | ||
"""Shuffle the data and split it according to the default split method. | ||
Args: | ||
data_csv: Path to input CSV file. | ||
config_yaml: Path to config YAML file. | ||
out_path: Path to output shuffled CSV. | ||
TODO major changes when this is going to select a given shuffle method and integration with split. | ||
""" | ||
# read the yaml file | ||
with open(config_yaml) as f: | ||
data_config: YamlSplitTransformDict = YamlSplitTransformDict( | ||
**yaml.safe_load(f), | ||
) | ||
# create a DatasetProcessor object from the config and the csv | ||
processor = DatasetProcessor(data_config=data_config, csv_path=data_csv) | ||
processor, label_columns = load_data_config_from_path(data_csv, config_yaml) | ||
logger.info("Dataset processor initialized successfully.") | ||
|
||
# shuffle the data with a default seed. TODO get the seed for the config if and when that is going to be set there. | ||
processor.shuffle_labels(seed=42) | ||
# shuffle the data with a default seed | ||
# TODO: get the seed from the config if and when that is going to be set there | ||
processor.shuffle_labels(label_columns, seed=42) | ||
logger.info("Data shuffled successfully.") | ||
|
||
# save the modified csv | ||
processor.save(out_path) | ||
|
||
|
||
def run() -> None: | ||
"""Run the CSV shuffling script.""" | ||
args = get_args() | ||
main(args.csv, args.yaml, args.output) | ||
|
||
|
||
if __name__ == "__main__": | ||
run() | ||
logger.info("Shuffled data saved successfully.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,52 +1,95 @@ | ||
"""Tests for the shuffle_csv CLI command.""" | ||
|
||
import hashlib | ||
import pathlib | ||
import os | ||
import tempfile | ||
from typing import Any, Callable | ||
from pathlib import Path | ||
|
||
import pytest | ||
from click.testing import CliRunner | ||
|
||
from src.stimulus.cli.shuffle_csv import main | ||
from stimulus.cli.main import cli | ||
from stimulus.cli.shuffle_csv import shuffle_csv | ||
|
||
|
||
# Fixtures | ||
@pytest.fixture | ||
def correct_yaml_path() -> str: | ||
"""Fixture that returns the path to a correct YAML file.""" | ||
return "tests/test_data/titanic/titanic_sub_config.yaml" | ||
def csv_path() -> str: | ||
"""Get path to test CSV file.""" | ||
return str( | ||
Path(__file__).parent.parent / "test_data" / "titanic" / "titanic_stimulus.csv", | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def correct_csv_path() -> str: | ||
"""Fixture that returns the path to a correct CSV file.""" | ||
return "tests/test_data/titanic/titanic_stimulus.csv" | ||
def yaml_path() -> str: | ||
"""Get path to test config YAML file.""" | ||
return str( | ||
Path(__file__).parent.parent / "test_data" / "titanic" / "titanic_unique_split.yaml", | ||
) | ||
|
||
|
||
# Test cases | ||
test_cases = [ | ||
("correct_csv_path", "correct_yaml_path", None), | ||
] | ||
def test_shuffle_csv_main( | ||
csv_path: str, | ||
yaml_path: str, | ||
) -> None: | ||
"""Test that shuffle_csv.main runs without errors. | ||
Args: | ||
csv_path: Path to test CSV data. | ||
yaml_path: Path to test config YAML. | ||
""" | ||
# Verify required files exist | ||
assert os.path.exists(csv_path), f"CSV file not found at {csv_path}" | ||
assert os.path.exists(yaml_path), f"YAML config not found at {yaml_path}" | ||
|
||
# Tests | ||
@pytest.mark.parametrize(("csv_type", "yaml_type", "error"), test_cases) | ||
def test_shuffle_csv( | ||
request: pytest.FixtureRequest, | ||
snapshot: Callable[[], Any], | ||
csv_type: str, | ||
yaml_type: str, | ||
error: Exception | None, | ||
) -> None: | ||
"""Tests the CLI command with correct and wrong YAML files.""" | ||
csv_path = request.getfixturevalue(csv_type) | ||
yaml_path = request.getfixturevalue(yaml_type) | ||
tmpdir = pathlib.Path(tempfile.gettempdir()) | ||
if error: | ||
with pytest.raises(error): # type: ignore[call-overload] | ||
main(csv_path, yaml_path, str(tmpdir / "test.csv")) | ||
else: | ||
main(csv_path, yaml_path, str(tmpdir / "test.csv")) | ||
with open(tmpdir / "test.csv") as file: | ||
# Create temporary output file | ||
with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as tmp_file: | ||
output_path = tmp_file.name | ||
|
||
try: | ||
# Run main function - should complete without errors | ||
shuffle_csv( | ||
data_csv=csv_path, | ||
config_yaml=yaml_path, | ||
out_path=output_path, | ||
) | ||
|
||
# Verify output file exists and has content | ||
assert os.path.exists(output_path), "Output file was not created" | ||
with open(output_path) as file: | ||
hash = hashlib.md5(file.read().encode()).hexdigest() # noqa: S324 | ||
assert hash == snapshot | ||
assert hash # Verify we got a hash (file not empty) | ||
|
||
finally: | ||
# Clean up temporary file | ||
if os.path.exists(output_path): | ||
os.unlink(output_path) | ||
|
||
|
||
def test_cli_invocation( | ||
csv_path: str, | ||
yaml_path: str, | ||
) -> None: | ||
"""Test the CLI invocation of shuffle-csv command. | ||
Args: | ||
csv_path: Path to test CSV data. | ||
yaml_path: Path to test config YAML. | ||
""" | ||
runner = CliRunner() | ||
with runner.isolated_filesystem(): | ||
output_path = "output.csv" | ||
result = runner.invoke( | ||
cli, | ||
[ | ||
"shuffle-csv", | ||
"-c", | ||
csv_path, | ||
"-y", | ||
yaml_path, | ||
"-o", | ||
output_path, | ||
], | ||
) | ||
assert result.exit_code == 0 | ||
assert os.path.exists(output_path), "Output file was not created" |