Skip to content

Commit

Permalink
refactor(cli): added refactoring for split_csv cli.
Browse files Browse the repository at this point in the history
  • Loading branch information
mathysgrapotte committed Feb 24, 2025
1 parent e6a0e23 commit 92558c0
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 112 deletions.
47 changes: 47 additions & 0 deletions src/stimulus/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,50 @@ def shuffle_csv(
config_yaml=yaml,
out_path=output,
)


@cli.command()
@click.option(
    "-c",
    "--csv",
    required=True,
    type=click.Path(exists=True),
    help="The file path for the csv containing all data",
)
@click.option(
    "-y",
    "--yaml",
    required=True,
    type=click.Path(exists=True),
    help="The YAML config file that holds all parameter info",
)
@click.option(
    "-o",
    "--output",
    required=True,
    type=click.Path(),
    help="The output file path to write the split csv",
)
@click.option(
    "-f",
    "--force",
    default=False,
    is_flag=True,
    help="Overwrite the split column if it already exists in the csv",
)
def split_csv(csv: str, yaml: str, output: str, *, force: bool) -> None:
    """Split rows in a CSV data file."""
    # NOTE(review): the docstring above is kept verbatim; click presumably
    # surfaces it as the command's help text, so treat it as runtime text.
    #
    # The implementation is imported inside the command body, so the
    # stimulus.cli.split_csv module is only loaded when this command runs.
    from stimulus.cli.split_csv import split_csv as run_split

    # Map the CLI option names onto the keyword names of the implementation.
    run_split(data_csv=csv, config_yaml=yaml, out_path=output, force=force)
103 changes: 35 additions & 68 deletions src/stimulus/cli/split_csv.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,44 @@
#!/usr/bin/env python3
"""CLI module for splitting CSV data files."""

import argparse
import logging

import yaml

from stimulus.data.data_handlers import DatasetProcessor, SplitManager
from stimulus.data.loaders import SplitLoader
from stimulus.utils.yaml_data import YamlSplitConfigDict
from stimulus.data import data_handlers
from stimulus.data.interface import data_config_parser

logger = logging.getLogger(__name__)

def _parse_bool(value: str) -> bool:
    """Convert a command-line token into a boolean.

    Args:
        value: The raw string given on the command line.

    Returns:
        True for "true", "t", "yes", "y", "1", "on"; False for "false", "f",
        "no", "n", "0", "off" and the empty string (case-insensitive).

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1", "on"}:
        return True
    if normalized in {"false", "f", "no", "n", "0", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args() -> argparse.Namespace:
    """Get the arguments when using from the commandline.

    Returns:
        The parsed namespace with the attributes ``csv``, ``yaml``, ``output``
        and ``force``.
    """
    parser = argparse.ArgumentParser(description="Split a CSV data file.")
    parser.add_argument(
        "-c",
        "--csv",
        type=str,
        required=True,
        metavar="FILE",
        help="The file path for the csv containing all data",
    )
    parser.add_argument(
        "-y",
        "--yaml",
        type=str,
        required=True,
        metavar="FILE",
        help="The YAML config file that holds all parameter info",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        required=True,
        metavar="FILE",
        help="The output file path to write the split csv",
    )
    # ``type=bool`` would turn every non-empty string (including "False") into
    # True and would reject a bare ``-f``. Accept both a bare flag
    # (``-f`` -> True) and an explicit value (``-f False`` -> False).
    parser.add_argument(
        "-f",
        "--force",
        type=_parse_bool,
        nargs="?",
        const=True,
        required=False,
        default=False,
        metavar="BOOL",
        help="Overwrite the split column if it already exists in the csv",
    )

    return parser.parse_args()
def load_data_config_from_path(data_path: str, data_config_path: str) -> data_handlers.DatasetProcessor:
    """Load the data config from a path.

    The YAML file is parsed with ``yaml.safe_load`` and its top-level keys are
    passed to ``data_config_parser.SplitConfigDict``. The ``split`` and
    ``transforms`` entries of that object are then used to create the splitter
    and the transforms handed to the processor.

    Args:
        data_path: Path to the data file.
        data_config_path: Path to the data config file.

    Returns:
        A DatasetProcessor instance configured with the data.
    """
    # Read the YAML explicitly as UTF-8 so parsing does not depend on the
    # platform's locale default encoding.
    with open(data_config_path, encoding="utf-8") as file:
        data_config_dict = yaml.safe_load(file)

    data_config_obj = data_config_parser.SplitConfigDict(**data_config_dict)

    splitters = data_config_parser.create_splitter(data_config_obj.split)
    transforms = data_config_parser.create_transforms(data_config_obj.transforms)
    split_columns = data_config_obj.split.split_input_columns

    return data_handlers.DatasetProcessor(
        csv_path=data_path,
        transforms=transforms,
        split_columns=split_columns,
        splitter=splitters,
    )


def main(
data_csv: str,
config_yaml: str,
out_path: str,
*,
force: bool = False,
) -> None:
"""Connect CSV and YAML configuration and handle sanity checks.
def split_csv(data_csv: str, config_yaml: str, out_path: str, *, force: bool = False) -> None:
"""Split the data according to the configuration.
Args:
data_csv: Path to input CSV file.
Expand All @@ -65,28 +47,13 @@ def main(
force: Overwrite the split column if it already exists in the CSV.
"""
# create a DatasetProcessor object from the config and the csv
processor = DatasetProcessor(config_path=config_yaml, csv_path=data_csv)

# create a split manager from the config
split_config = processor.dataset_manager.config.split
with open(config_yaml) as f:
yaml_config = YamlSplitConfigDict(**yaml.safe_load(f))
split_loader = SplitLoader(seed=yaml_config.global_params.seed)
split_loader.initialize_splitter_from_config(split_config)
split_manager = SplitManager(split_loader)
processor = load_data_config_from_path(data_csv, config_yaml)
logger.info("Dataset processor initialized successfully.")

# apply the split method to the data
processor.add_split(split_manager=split_manager, force=force)
processor.add_split(force=force)
logger.info("Split applied successfully.")

# save the modified csv
processor.save(out_path)


def run() -> None:
    """Parse the command-line arguments and pass them on to ``main``."""
    cli_args = get_args()
    main(
        cli_args.csv,
        cli_args.yaml,
        cli_args.output,
        force=cli_args.force,
    )


if __name__ == "__main__":
run()
logger.info("Split data saved successfully.")
154 changes: 110 additions & 44 deletions tests/cli/test_split_csv.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,130 @@
"""Tests for the split_csv CLI command."""

import hashlib
import pathlib
import os
import tempfile
from typing import Any, Callable
from pathlib import Path

import pytest
from click.testing import CliRunner

from src.stimulus.cli.split_csv import main
from stimulus.cli.main import cli
from stimulus.cli.split_csv import split_csv


# Fixtures
@pytest.fixture
def csv_path_no_split() -> str:
"""Fixture that returns the path to a CSV file without split column."""
return "tests/test_data/dna_experiment/test.csv"
def csv_path() -> str:
"""Get path to test CSV file."""
return str(
Path(__file__).parent.parent / "test_data" / "titanic" / "titanic_stimulus.csv",
)


@pytest.fixture
def csv_path_with_split() -> str:
"""Fixture that returns the path to a CSV file with split column."""
return "tests/test_data/dna_experiment/test_with_split.csv"
"""Get path to test CSV file with split column."""
return str(
Path(__file__).parent.parent / "test_data" / "titanic" / "titanic_stimulus_split.csv",
)


@pytest.fixture
def yaml_path() -> str:
"""Fixture that returns the path to a YAML config file."""
return "tests/test_data/dna_experiment/dna_experiment_config_template_0.yaml"


# Test cases
test_cases = [
("csv_path_no_split", "yaml_path", False, None),
("csv_path_with_split", "yaml_path", False, ValueError),
("csv_path_no_split", "yaml_path", True, None),
("csv_path_with_split", "yaml_path", True, None),
]


# Tests
@pytest.mark.skip(reason="There is an issue with non-deterministic output")
@pytest.mark.parametrize(("csv_type", "yaml_type", "force", "error"), test_cases)
def test_split_csv(
request: pytest.FixtureRequest,
snapshot: Callable[[], Any],
csv_type: str,
yaml_type: str,
force: bool,
error: Exception | None,
"""Get path to test config YAML file."""
return str(
Path(__file__).parent.parent / "test_data" / "titanic" / "titanic_unique_split.yaml",
)


def test_split_csv_main(
csv_path: str,
yaml_path: str,
) -> None:
"""Tests the CLI command with correct and wrong YAML files."""
csv_path = request.getfixturevalue(csv_type)
yaml_path = request.getfixturevalue(yaml_type)
tmpdir = pathlib.Path(tempfile.gettempdir())
if error:
with pytest.raises(error): # type: ignore[call-overload]
main(csv_path, yaml_path, str(tmpdir / "test.csv"), force=force)
else:
filename = f"{csv_type}_{force}.csv"
main(csv_path, yaml_path, str(tmpdir / filename), force=force)
with open(tmpdir / filename) as file:
"""Test that split_csv runs without errors.
Args:
csv_path: Path to test CSV data.
yaml_path: Path to test config YAML.
"""
# Verify required files exist
assert os.path.exists(csv_path), f"CSV file not found at {csv_path}"
assert os.path.exists(yaml_path), f"YAML config not found at {yaml_path}"

# Create temporary output file
with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as tmp_file:
output_path = tmp_file.name

try:
# Run main function - should complete without errors
split_csv(
data_csv=csv_path,
config_yaml=yaml_path,
out_path=output_path,
)

# Verify output file exists and has content
assert os.path.exists(output_path), "Output file was not created"
with open(output_path) as file:
hash = hashlib.md5(file.read().encode()).hexdigest() # noqa: S324
assert hash == snapshot
assert hash # Verify we got a hash (file not empty)

finally:
# Clean up temporary file
if os.path.exists(output_path):
os.unlink(output_path)


def test_split_csv_with_force(
    csv_path_with_split: str,
    yaml_path: str,
) -> None:
    """Test split_csv with force flag on file that already has split column.

    Args:
        csv_path_with_split: Path to test CSV data with split column.
        yaml_path: Path to test config YAML.
    """
    # Write into a fresh temporary directory so the output path does not exist
    # before the call. A pre-created NamedTemporaryFile would make the
    # existence check pass even if split_csv wrote nothing. The directory and
    # its contents are removed when the context exits.
    with tempfile.TemporaryDirectory() as tmp_dir:
        output_path = os.path.join(tmp_dir, "split_forced.csv")

        split_csv(
            data_csv=csv_path_with_split,
            config_yaml=yaml_path,
            out_path=output_path,
            force=True,
        )

        assert os.path.exists(output_path), "Output file was not created"
        assert os.path.getsize(output_path) > 0, "Output file is empty"


def test_cli_invocation(
    csv_path: str,
    yaml_path: str,
) -> None:
    """Invoke the ``split-csv`` command through click's test runner.

    Args:
        csv_path: Path to test CSV data.
        yaml_path: Path to test config YAML.
    """
    runner = CliRunner()
    destination = "output.csv"
    command_line = ["split-csv", "-c", csv_path, "-y", yaml_path, "-o", destination]

    # The relative destination is resolved inside the isolated filesystem, so
    # both checks have to run before the context exits.
    with runner.isolated_filesystem():
        outcome = runner.invoke(cli, command_line)

        assert outcome.exit_code == 0
        assert os.path.exists(destination), "Output file was not created"

0 comments on commit 92558c0

Please sign in to comment.