Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split behaviour #86

Merged
merged 36 commits into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
049c937
FIX: utils/yaml_data.py:generate_data_configs -> generates only a con…
Feb 12, 2025
7e9ad63
FIX: utils/yaml_data.py:generate_data_configs -> changed transform to…
Feb 12, 2025
3be595c
FIX: utils/yaml_data.py:YamlSubConfigDict -> transforms is now a list…
Feb 12, 2025
acda764
FIX: utils/yaml_data.py:generate_data_configs -> updated docstring
Feb 12, 2025
5ab7ca0
FIX: utils/yaml_data.py -> Removed 'Yaml'
Feb 12, 2025
b213b2f
FIX: utils/yaml_data.py -> Put 'Yaml' back to modify that in another …
Feb 12, 2025
beac235
FIX: tests/cli/__snapshots__/test_split_yaml.ambr -> updated the snap…
Feb 12, 2025
cf0f234
NEW: cli/split_split.py -> New file to split the config only on the s…
Feb 14, 2025
cfdbee6
NEW: cli/split_transforms.py -> New file to split the config for each…
Feb 14, 2025
7221371
REMOVE: cli/split_yaml -> deleted the old splitter file because two n…
Feb 14, 2025
b8251d3
UPDATE: utils/yaml_data.py -> new functions for the transform splitter
Feb 14, 2025
a34754e
FIX: data/data_handlers.py -> YamlSubConfigDict takes a transform or …
Feb 14, 2025
6437bd2
NEW: src/stimulus/cli/split_split.py -> A file that splits all the sp…
Feb 17, 2025
ac1d70c
NEW: src/stimulus/cli/split_transforms.py -> Cli function to split un…
Feb 17, 2025
359b86e
NEW: tests/test_data/titanic/titanic_unique_split.yaml -> A config fi…
Feb 17, 2025
4aca687
DELETE: tests/cli/__snapshots__/test_split_yaml.ambr -> Removed this …
Feb 17, 2025
887d9f1
{src/stimulus,tests}/cli/check_model.py -> Input is now a yaml with o…
Feb 17, 2025
ff0881c
FIX: src/stimulus/cli/split_csv.py -> Updated to take a YamlSplitConf…
Feb 17, 2025
96b3b5b
FIX: tests/cli/test_shuffle_csv.py -> changed the test file to be the…
Feb 17, 2025
f4edfab
FIX: {src/stimulus, tests}/data/data_handlers.py -> Changed the input…
Feb 18, 2025
1dd858c
FIX: check_model.py -> takes a yaml and then passes on only YamlSplit…
Feb 18, 2025
f4ad7ab
FIX: tests/data/test_data_handlers.py -> changed the call from 'confi…
Feb 18, 2025
131388d
FIX: tests/data/transform/test_data_transformers.py -> changed 'gener…
Feb 18, 2025
2c110ad
FIX: tests/utils/test_data_yaml.py -> added tests for the new generat…
Feb 18, 2025
1b4339b
FIX: tests/data/test_experiment.py -> Fixed tests to work with the ne…
Feb 18, 2025
19d8d16
FIX: tests/data/test_handlertorch.py -> changed test to use the YamlS…
Feb 18, 2025
5ef82ff
ADD: src/stimulus/utils/yaml_data.py -> Added a class SplitConfigDict
Feb 18, 2025
cda0a6e
FIX: tests/cli/test_shuffle_csv.py -> Uses a YamlSplitTransformDict now
Feb 18, 2025
ac672f0
{src/stimulus,tests}/typing/{__init__.py, test_typing.py} -> Added th…
Feb 18, 2025
4fa3abd
FIX: src/stimulus/cli/transform_csv.py -> When the file is called it …
Feb 18, 2025
d03ac07
FIX: tests/learner/test_raytune_learner -> Uses a YamlSplitTransformD…
Feb 18, 2025
3c31d67
FIX: {src/stimulus,tests}/cli/tuning.py -> the main function now take…
Feb 18, 2025
cb7d55f
DELETED: tests/cli/test_split_yaml.py -> Deleted it has it has been r…
Feb 18, 2025
10092ee
FIX: src/stimulus/loaders.py -> initialize_splitter_from_config uses …
Feb 18, 2025
eaaf40c
Merge branch 'dev' into split_behaviour
Feb 18, 2025
0012c72
FIX: src/stimulus/data/{data_handlers.py,handlertorch.py} -> Change t…
Feb 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions src/stimulus/cli/check_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,22 @@ def get_args() -> argparse.Namespace:
Parsed command line arguments.
"""
parser = argparse.ArgumentParser(description="Launch check_model.")
parser.add_argument("-d", "--data", type=str, required=True, metavar="FILE", help="Path to input csv file.")
parser.add_argument("-m", "--model", type=str, required=True, metavar="FILE", help="Path to model file.")
parser.add_argument(
"-d",
"--data",
type=str,
required=True,
metavar="FILE",
help="Path to input csv file.",
)
parser.add_argument(
"-m",
"--model",
type=str,
required=True,
metavar="FILE",
help="Path to model file.",
)
parser.add_argument(
"-e",
"--data_config",
Expand Down Expand Up @@ -107,22 +121,25 @@ def main(
"""
with open(data_config_path) as file:
data_config = yaml.safe_load(file)
data_config = yaml_data.YamlSubConfigDict(**data_config)
data_config = yaml_data.YamlSplitTransformDict(**data_config)

with open(model_config_path) as file:
model_config = yaml.safe_load(file)
model_config = yaml_model_schema.Model(**model_config)

encoder_loader = loaders.EncoderLoader()
encoder_loader.initialize_column_encoders_from_config(column_config=data_config.columns)
encoder_loader.initialize_column_encoders_from_config(
column_config=data_config.columns
)

logger.info("Dataset loaded successfully.")

model_class = launch_utils.import_class_from_file(model_path)

logger.info("Model class loaded successfully.")

ray_config_loader = yaml_model_schema.YamlRayConfigLoader(model=model_config)
ray_config_loader = yaml_model_schema.YamlRayConfigLoader(
model=model_config)
ray_config_dict = ray_config_loader.get_config().model_dump()
ray_config_model = ray_config_loader.get_config()

Expand All @@ -140,7 +157,7 @@ def main(
logger.info("Model instance loaded successfully.")

torch_dataset = handlertorch.TorchDataset(
config_path=data_config_path,
data_config=data_config,
csv_path=data_path,
encoder_loader=encoder_loader,
)
Expand Down Expand Up @@ -171,7 +188,7 @@ def main(

tuner = raytune_learner.TuneWrapper(
model_config=ray_config_model,
data_config_path=data_config_path,
data_config=data_config,
model_class=model_class,
data_path=data_path,
encoder_loader=encoder_loader,
Expand Down
8 changes: 5 additions & 3 deletions src/stimulus/cli/split_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from stimulus.data.data_handlers import DatasetProcessor, SplitManager
from stimulus.data.loaders import SplitLoader
from stimulus.utils.yaml_data import YamlSubConfigDict
from stimulus.utils.yaml_data import YamlSplitConfigDict


def get_args() -> argparse.Namespace:
Expand Down Expand Up @@ -49,7 +49,9 @@ def get_args() -> argparse.Namespace:
return parser.parse_args()


def main(data_csv: str, config_yaml: str, out_path: str, *, force: bool = False) -> None:
def main(
data_csv: str, config_yaml: str, out_path: str, *, force: bool = False
) -> None:
"""Connect CSV and YAML configuration and handle sanity checks.

Args:
Expand All @@ -64,7 +66,7 @@ def main(data_csv: str, config_yaml: str, out_path: str, *, force: bool = False)
# create a split manager from the config
split_config = processor.dataset_manager.config.split
with open(config_yaml) as f:
yaml_config = YamlSubConfigDict(**yaml.safe_load(f))
yaml_config = YamlSplitConfigDict(**yaml.safe_load(f))
split_loader = SplitLoader(seed=yaml_config.global_params.seed)
split_loader.initialize_splitter_from_config(split_config)
split_manager = SplitManager(split_loader)
Expand Down
22 changes: 12 additions & 10 deletions src/stimulus/cli/split_yaml.py → src/stimulus/cli/split_split.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#!/usr/bin/env python3
"""CLI module for splitting YAML configuration files.
"""CLI module for splitting YAML configuration files into unique files for each split.

This module provides functionality to split a single YAML configuration file into multiple
YAML files, each containing a specific combination of data transformations and splits.
YAML files, each containing a unique split.
The resulting YAML files can be used as input configurations for the stimulus package.
"""

Expand All @@ -13,9 +13,10 @@

from stimulus.utils.yaml_data import (
YamlConfigDict,
YamlSplitConfigDict,
check_yaml_schema,
dump_yaml_list_into_files,
generate_data_configs,
generate_split_configs,
)


Expand All @@ -39,39 +40,40 @@ def get_args() -> argparse.Namespace:
const="./",
default="./",
metavar="DIR",
# TODO: Change the output name
help="The output dir where all the YAMLs are written to. Output YAML will be called split-#[number].yaml transform-#[number].yaml. Default -> ./",
)

return parser.parse_args()


def main(config_yaml: str, out_dir_path: str) -> None:
"""Reads a YAML config file and generates all possible data configurations.
"""Reads a YAML config file and generates a file per unique split.

This script reads a YAML with a defined structure and creates all the YAML files ready to be passed to
the stimulus package.

The structure of the YAML is described here -> TODO paste here link to documentation.
This YAML and it's structure summarize how to generate all the transform - split and respective parameter combinations.
Each resulting YAML will hold only one combination of the above three things.
This YAML and its structure summarize how to generate unique splits and all the transformations associated to this split.

This script will always generate at least one YAML file that represent the combination that does not touch the data (no transform)
and uses the default split behavior.
"""
# read the yaml experiment config and load it to dictionary
# read the YAML experiment config and load it into a dictionary
yaml_config: dict[str, Any] = {}
with open(config_yaml) as conf_file:
yaml_config = yaml.safe_load(conf_file)

yaml_config_dict: YamlConfigDict = YamlConfigDict(**yaml_config)
# check if the yaml schema is correct
# FIXME: isn't it redundant to check and already class with pydantic ?
check_yaml_schema(yaml_config_dict)

# generate all the YAML configs
data_configs = generate_data_configs(yaml_config_dict)
# generate the yaml files per split
split_configs: list[YamlSplitConfigDict] = generate_split_configs(yaml_config_dict)

# dump all the YAML configs into files
dump_yaml_list_into_files(data_configs, out_dir_path, "test")
dump_yaml_list_into_files(split_configs, out_dir_path, "test_split")


def run() -> None:
Expand Down
76 changes: 76 additions & 0 deletions src/stimulus/cli/split_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#!/usr/bin/env python3
"""CLI module for splitting YAML configuration files into unique files for each transform.

This module provides functionality to split a single YAML configuration file into multiple
YAML files, each containing a unique transform associated to a unique split.
The resulting YAML files can be used as input configurations for the stimulus package.
"""

import argparse
from typing import Any

import yaml

from stimulus.utils.yaml_data import (
YamlSplitConfigDict,
YamlSplitTransformDict,
dump_yaml_list_into_files,
generate_split_transform_configs,
)


def get_args() -> argparse.Namespace:
    """Parse command-line arguments for the split-transforms CLI.

    Returns:
        Parsed command line arguments with attributes:
            yaml: Path to the input YAML config file (required).
            out_dir: Output directory for the generated YAMLs (defaults to "./").
    """
    # FIX: description was empty; give the CLI a meaningful --help header.
    parser = argparse.ArgumentParser(
        description="Split a YAML config into one file per transform of a split.",
    )
    parser.add_argument(
        "-j",
        "--yaml",
        type=str,
        required=True,
        metavar="FILE",
        help="The YAML config file that holds all the transform-per-split parameter info",
    )
    parser.add_argument(
        "-d",
        "--out-dir",
        type=str,
        required=False,
        # nargs="?" with const lets a bare `-d` fall back to "./" as well.
        nargs="?",
        const="./",
        default="./",
        metavar="DIR",
        help="The output dir where all the YAMLs are written to. Output YAML will be called split_transform-#[number].yaml. Default -> ./",
    )

    return parser.parse_args()


def main(config_yaml: str, out_dir_path: str) -> None:
    """Generate one YAML file per split/transform combination from a single config.

    Reads a split-level YAML config with the expected stimulus structure,
    expands it into every transform combination associated with the split,
    and writes each combination to its own YAML file, ready to be passed
    to the stimulus package.

    At least one file is always produced: the combination that leaves the
    data untouched (no transform).

    The structure of the YAML is described here -> TODO: paste here the link to documentation

    Args:
        config_yaml: Path to the input YAML configuration file.
        out_dir_path: Directory the generated YAML files are written to.
    """
    # Load the raw experiment config into a plain dictionary.
    with open(config_yaml) as config_file:
        raw_config: dict[str, Any] = yaml.safe_load(config_file)

    # Validate and structure the raw dict as a split-level config.
    split_config: YamlSplitConfigDict = YamlSplitConfigDict(**raw_config)

    # Expand the split config into one entry per transform combination.
    split_transform_configs: list[YamlSplitTransformDict] = generate_split_transform_configs(
        split_config,
    )

    # Write each combination out as its own YAML file.
    dump_yaml_list_into_files(split_transform_configs, out_dir_path, "test_transforms")


if __name__ == "__main__":
    # CLI entry point: parse the command-line arguments and run the splitter.
    args = get_args()
    main(args.yaml, args.out_dir)
9 changes: 6 additions & 3 deletions src/stimulus/cli/transform_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@

from stimulus.data.data_handlers import DatasetProcessor, TransformManager
from stimulus.data.loaders import TransformLoader
from stimulus.utils.yaml_data import YamlSubConfigDict
from stimulus.utils.yaml_data import YamlSplitConfigDict


def get_args() -> argparse.Namespace:
"""Get the arguments when using from the commandline."""
parser = argparse.ArgumentParser(description="CLI for transforming CSV data files using YAML configuration.")
parser = argparse.ArgumentParser(
description="CLI for transforming CSV data files using YAML configuration."
)
parser.add_argument(
"-c",
"--csv",
Expand Down Expand Up @@ -53,8 +55,9 @@ def main(data_csv: str, config_yaml: str, out_path: str) -> None:
# initialize the transform manager
transform_config = processor.dataset_manager.config.transforms
with open(config_yaml) as f:
yaml_config = YamlSubConfigDict(**yaml.safe_load(f))
yaml_config = YamlSplitConfigDict(**yaml.safe_load(f))
transform_loader = TransformLoader(seed=yaml_config.global_params.seed)
print(transform_config)
transform_loader.initialize_column_data_transformers_from_config(transform_config)
transform_manager = TransformManager(transform_loader)

Expand Down
45 changes: 31 additions & 14 deletions src/stimulus/cli/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,22 @@ def get_args() -> argparse.Namespace:
Parsed command line arguments.
"""
parser = argparse.ArgumentParser(description="Launch check_model.")
parser.add_argument("-d", "--data", type=str, required=True, metavar="FILE", help="Path to input csv file.")
parser.add_argument("-m", "--model", type=str, required=True, metavar="FILE", help="Path to model file.")
parser.add_argument(
"-d",
"--data",
type=str,
required=True,
metavar="FILE",
help="Path to input csv file.",
)
parser.add_argument(
"-m",
"--model",
type=str,
required=True,
metavar="FILE",
help="Path to model file.",
)
parser.add_argument(
"-e",
"--data_config",
Expand Down Expand Up @@ -137,7 +151,7 @@ def get_args() -> argparse.Namespace:
def main(
model_path: str,
data_path: str,
data_config_path: str,
data_config: yaml_data.YamlSplitTransformDict,
model_config_path: str,
initial_weights: str | None = None, # noqa: ARG001
ray_results_dirpath: str | None = None,
Expand All @@ -153,7 +167,7 @@ def main(
Args:
data_path: Path to input data file.
model_path: Path to model file.
data_config_path: Path to data config file.
data_config: A YamlSplitTransformDict object
model_config_path: Path to model config file.
initial_weights: Optional path to initial weights.
ray_results_dirpath: Directory for ray results.
Expand All @@ -163,26 +177,25 @@ def main(
best_metrics_path: Path to write the best metrics to.
best_config_path: Path to write the best config to.
"""
# Convert data config to proper type
with open(data_config_path) as file:
data_config_dict: dict[str, Any] = yaml.safe_load(file)
data_config: yaml_data.YamlSubConfigDict = yaml_data.YamlSubConfigDict(**data_config_dict)

with open(model_config_path) as file:
model_config_dict: dict[str, Any] = yaml.safe_load(file)
model_config: yaml_model_schema.Model = yaml_model_schema.Model(**model_config_dict)
model_config: yaml_model_schema.Model = yaml_model_schema.Model(
**model_config_dict)

encoder_loader = loaders.EncoderLoader()
encoder_loader.initialize_column_encoders_from_config(column_config=data_config.columns)
encoder_loader.initialize_column_encoders_from_config(
column_config=data_config.columns
)

model_class = launch_utils.import_class_from_file(model_path)

ray_config_loader = yaml_model_schema.YamlRayConfigLoader(model=model_config)
ray_config_loader = yaml_model_schema.YamlRayConfigLoader(
model=model_config)
ray_config_model = ray_config_loader.get_config()

tuner = raytune_learner.TuneWrapper(
model_config=ray_config_model,
data_config_path=data_config_path,
data_config=data_config,
model_class=model_class,
data_path=data_path,
encoder_loader=encoder_loader,
Expand Down Expand Up @@ -228,10 +241,14 @@ def run() -> None:
"""Run the model checking script."""
ray.init(address="auto", ignore_reinit_error=True)
args = get_args()
# Try to convert the configuration file to a YamlSplitTransformDict
config_dict: yaml_data.YamlSplitTransformDict
with open(args.data_config) as f:
config_dict = yaml_data.YamlSplitTransformDict(**yaml.safe_load(f))
main(
data_path=args.data,
model_path=args.model,
data_config_path=args.data_config,
data_config=config_dict,
model_config_path=args.model_config,
initial_weights=args.initial_weights,
ray_results_dirpath=args.ray_results_dirpath,
Expand Down
Loading
Loading