Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor split_split and split_transform to use Click API #124

Merged
merged 6 commits into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions src/stimulus/cli/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""Main entry point for stimulus-py cli."""

import click
from importlib_metadata import version


@click.group(context_settings={"help_option_names": ["-h", "--help"]})
@click.version_option(version("stimulus-py"), "-v", "--version")
def cli() -> None:
"""Stimulus is an open-science framework for data processing and model training."""


@cli.command()
@click.option(
"-y",
"--yaml",
type=click.Path(exists=True),
required=True,
help="The YAML config file that hold all transform - split - parameter info",
)
@click.option(
"-d",
"--out-dir",
type=click.Path(),
required=False,
default="./",
help="The output dir where all the YAMLs are written to. Output YAML will be called split-#[number].yaml transform-#[number].yaml. Default -> ./",
)
def split_split(
yaml: str,
out_dir: str,
) -> None:
"""Split a YAML configuration file into multiple YAML files, each containing a unique split."""
from stimulus.cli.split_split import split_split as split_split_func

split_split_func(config_yaml=yaml, out_dir_path=out_dir)


@cli.command()
@click.option(
"-j",
"--yaml",
type=click.Path(exists=True),
required=True,
help="The YAML config file that hold all the transform per split parameter info",
)
@click.option(
"-d",
"--out-dir",
type=click.Path(),
required=False,
default="./",
help="The output dir where all the YAMLs are written to. Output YAML will be called split_transform-#[number].yaml. Default -> ./",
)
def split_transforms(
yaml: str,
out_dir: str,
) -> None:
"""Split a YAML configuration file into multiple YAML files, each containing a unique transform."""
from stimulus.cli.split_transforms import split_transforms as split_transforms_func

split_transforms_func(config_yaml=yaml, out_dir_path=out_dir)
68 changes: 15 additions & 53 deletions src/stimulus/cli/split_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,48 +6,17 @@
The resulting YAML files can be used as input configurations for the stimulus package.
"""

import argparse
import logging
from typing import Any

import yaml

from stimulus.utils.yaml_data import (
YamlConfigDict,
YamlSplitConfigDict,
check_yaml_schema,
dump_yaml_list_into_files,
generate_split_configs,
)


def get_args() -> argparse.Namespace:
"""Get the arguments when using from the command line."""
parser = argparse.ArgumentParser(description="")
parser.add_argument(
"-y",
"--yaml",
type=str,
required=True,
metavar="FILE",
help="The YAML config file that hold all transform - split - parameter info",
)
parser.add_argument(
"-d",
"--out_dir",
type=str,
required=False,
nargs="?",
const="./",
default="./",
metavar="DIR",
# TODO: Change the output name
help="The output dir where all the YAMLs are written to. Output YAML will be called split-#[number].yaml transform-#[number].yaml. Default -> ./",
)

return parser.parse_args()


def main(config_yaml: str, out_dir_path: str) -> None:
from stimulus.data.interface import data_config_parser

logger = logging.getLogger(__name__)


def split_split(config_yaml: str, out_dir_path: str) -> None:
"""Reads a YAML config file and generates a file per unique split.

This script reads a YAML with a defined structure and creates all the YAML files ready to be passed to
Expand All @@ -64,23 +33,16 @@ def main(config_yaml: str, out_dir_path: str) -> None:
with open(config_yaml) as conf_file:
yaml_config = yaml.safe_load(conf_file)

yaml_config_dict: YamlConfigDict = YamlConfigDict(**yaml_config)
# check if the yaml schema is correct
# FIXME: isn't it redundant to check and already class with pydantic ?
check_yaml_schema(yaml_config_dict)

# generate the yaml files per split
split_configs: list[YamlSplitConfigDict] = generate_split_configs(yaml_config_dict)
yaml_config_dict = data_config_parser.ConfigDict(**yaml_config)

# dump all the YAML configs into files
dump_yaml_list_into_files(split_configs, out_dir_path, "test_split")
logger.info("YAML config loaded successfully.")

# generate the yaml files per split
split_configs = data_config_parser.generate_split_configs(yaml_config_dict)

def run() -> None:
"""Run the split_yaml CLI."""
args = get_args()
main(args.yaml, args.out_dir)
logger.info("Splits generated successfully.")

# dump all the YAML configs into files
data_config_parser.dump_yaml_list_into_files(split_configs, out_dir_path, "test_split")

if __name__ == "__main__":
run()
logger.info("YAML files saved successfully.")
47 changes: 5 additions & 42 deletions src/stimulus/cli/split_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,46 +6,14 @@
The resulting YAML files can be used as input configurations for the stimulus package.
"""

import argparse
from typing import Any

import yaml

from stimulus.utils.yaml_data import (
YamlSplitConfigDict,
YamlSplitTransformDict,
dump_yaml_list_into_files,
generate_split_transform_configs,
)
from stimulus.data.interface import data_config_parser


def get_args() -> argparse.Namespace:
"""Get the arguments when using the command line."""
parser = argparse.ArgumentParser(description="")
parser.add_argument(
"-j",
"--yaml",
type=str,
required=True,
metavar="FILE",
help="The YAML config file that hold all the transform per split parameter info",
)
parser.add_argument(
"-d",
"--out-dir",
type=str,
required=False,
nargs="?",
const="./",
default="./",
metavar="DIR",
help="The output dir where all the YAMLs are written to. Output YAML will be called split_transform-#[number].yaml. Default -> ./",
)

return parser.parse_args()


def main(config_yaml: str, out_dir_path: str) -> None:
def split_transforms(config_yaml: str, out_dir_path: str) -> None:
"""Reads a YAML config and generates files for all split - transform possible combinations.

This script reads a YAML with a defined structure and creates all the YAML files ready to be passed to the stimulus package.
Expand All @@ -60,15 +28,10 @@ def main(config_yaml: str, out_dir_path: str) -> None:
with open(config_yaml) as conf_file:
yaml_config = yaml.safe_load(conf_file)

yaml_config_dict: YamlSplitConfigDict = YamlSplitConfigDict(**yaml_config)
yaml_config_dict = data_config_parser.SplitConfigDict(**yaml_config)

# Generate the yaml files for each transform
split_transform_configs: list[YamlSplitTransformDict] = generate_split_transform_configs(yaml_config_dict)
split_transform_configs = data_config_parser.generate_split_transform_configs(yaml_config_dict)

# Dump all the YAML configs into files
dump_yaml_list_into_files(split_transform_configs, out_dir_path, "test_transforms")


if __name__ == "__main__":
args = get_args()
main(args.yaml, args.out_dir)
data_config_parser.dump_yaml_list_into_files(split_transform_configs, out_dir_path, "test_transforms")
4 changes: 2 additions & 2 deletions tests/cli/__snapshots__/test_split_splits.ambr
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# serializer version: 1
# name: test_split_split[correct_yaml_path-None]
# name: test_split_split_main[correct_yaml_path-None]
list([
'42139ca7745259e09d1e56e24570d2c7',
'8bca0bebb576d5ce5bb9de5641f627e4',
])
# ---
5 changes: 3 additions & 2 deletions tests/cli/__snapshots__/test_split_transforms.ambr
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# serializer version: 1
# name: test_split_transforms[correct_yaml_path-None]
list([
'26d921a1580fb16ef597e91f4defec5607c11a7e822590f39e9ca456cc29819b',
'db0c550bfe6027c91981423afdc6115c9dcfba61409f961b98831db9e664fd57',
'3a4b035c71d6ef1a46cea25c318634f2331df8993876b71849de80f8c8d0db70',
'7b03cb09d862e4fb884dc0cf4da93ac7efe8068b72a4b086f532a93befb925f1',
'd9013844e136dfeb3969eefb6fb9787e6bf6886c796c3bca5ed1e0dd3400dd29',
])
# ---
58 changes: 47 additions & 11 deletions tests/cli/test_split_splits.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,32 @@

import hashlib
import os
import tempfile
from pathlib import Path
from typing import Any, Callable
from typing import Any, Optional

import pytest
from click.testing import CliRunner

from src.stimulus.cli import split_split
from stimulus.cli import split_split
from stimulus.cli.main import cli


# Fixtures
@pytest.fixture
def correct_yaml_path() -> str:
"""Fixture that returns the path to a correct YAML file."""
return "tests/test_data/titanic/titanic.yaml"
return str(
Path(__file__).parent.parent / "test_data" / "titanic" / "titanic.yaml",
)


@pytest.fixture
def wrong_yaml_path() -> str:
"""Fixture that returns the path to a wrong YAML file."""
return "tests/test_data/yaml_files/wrong_field_type.yaml"
return str(
Path(__file__).parent.parent / "test_data" / "yaml_files" / "wrong_field_type.yaml",
)


# Test cases
Expand All @@ -32,25 +39,54 @@ def wrong_yaml_path() -> str:

# Tests
@pytest.mark.parametrize(("yaml_type", "error"), test_cases)
def test_split_split(
request: pytest.FixtureRequest,
snapshot: Callable[[], Any],
def test_split_split_main(
yaml_type: str,
error: Exception | None,
tmp_path: Path, # Pytest tmp file system
error: Optional[Exception],
request: Any,
snapshot: Any,
tmp_path: Path,
) -> None:
"""Tests the CLI command with correct and wrong YAML files."""
yaml_path = request.getfixturevalue(yaml_type)
tmpdir = str(tmp_path)
if error:
with pytest.raises(error): # type: ignore[call-overload]
split_split.main(yaml_path, tmpdir)
split_split.split_split(yaml_path, tmpdir)
else:
split_split.main(yaml_path, tmpdir) # main() returns None, no need to assert
split_split.split_split(yaml_path, tmpdir) # split_split() returns None, no need to assert
files = os.listdir(tmpdir)
test_out = [f for f in files if f.startswith("test_")]
assert len(test_out) > 0, "No output files were generated"
hashes = []
for f in test_out:
with open(os.path.join(tmpdir, f)) as file:
hashes.append(hashlib.md5(file.read().encode()).hexdigest()) # noqa: S324
assert sorted(hashes) == snapshot # sorted ensures that the order of the hashes does not matter


def test_cli_invocation(
correct_yaml_path: str,
) -> None:
"""Test the CLI invocation of split-split command.

Args:
config_yaml: Path to the YAML config file.
out_dir: Path to the output directory.
"""
runner = CliRunner()
with runner.isolated_filesystem():
output_path = tempfile.gettempdir()
result = runner.invoke(
cli,
[
"split-split",
"-y",
correct_yaml_path,
"-d",
output_path,
],
)
files = os.listdir(output_path)
test_out = [f for f in files if f.startswith("test_")]
assert result.exit_code == 0
assert len(test_out) > 0, "No output files were generated"
Loading
Loading