From 049c937bc37a87b1a92a13c692518c6308a26267 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 12 Feb 2025 16:42:03 +0100 Subject: [PATCH 01/35] FIX: utils/yaml_data.py:generate_data_configs -> generates only a config per split and not all the config possible between splits and transforms --- src/stimulus/utils/yaml_data.py | 79 +++++++++++++++++++++++---------- 1 file changed, 55 insertions(+), 24 deletions(-) diff --git a/src/stimulus/utils/yaml_data.py b/src/stimulus/utils/yaml_data.py index cfb550fc..0d49c0b2 100644 --- a/src/stimulus/utils/yaml_data.py +++ b/src/stimulus/utils/yaml_data.py @@ -16,7 +16,9 @@ class YamlColumnsEncoder(BaseModel): """Model for column encoder configuration.""" name: str - params: Optional[dict[str, Union[str, list[Any]]]] # Allow both string and list values + params: Optional[ + dict[str, Union[str, list[Any]]] + ] # Allow both string and list values class YamlColumns(BaseModel): @@ -32,7 +34,9 @@ class YamlTransformColumnsTransformation(BaseModel): """Model for column transformation configuration.""" name: str - params: Optional[dict[str, Union[list[Any], float]]] # Allow both list and float values + params: Optional[ + dict[str, Union[list[Any], float]] + ] # Allow both list and float values class YamlTransformColumns(BaseModel): @@ -50,7 +54,9 @@ class YamlTransform(BaseModel): @field_validator("columns") @classmethod - def validate_param_lists_across_columns(cls, columns: list[YamlTransformColumns]) -> list[YamlTransformColumns]: + def validate_param_lists_across_columns( + cls, columns: list[YamlTransformColumns] + ) -> list[YamlTransformColumns]: """Validate that parameter lists across columns have consistent lengths. Args: @@ -120,7 +126,9 @@ class YamlSchema(BaseModel): yaml_conf: YamlConfigDict -def extract_transform_parameters_at_index(transform: YamlTransform, index: int = 0) -> YamlTransform: +def extract_transform_parameters_at_index( + transform: YamlTransform, index: int = 0 +) -> YamlTransform: """Get a transform with parameters at the specified index. Args: @@ -149,7 +157,9 @@ def extract_transform_parameters_at_index(transform: YamlTransform, index: int = return new_transform -def expand_transform_parameter_combinations(transform: YamlTransform) -> list[YamlTransform]: +def expand_transform_parameter_combinations( + transform: YamlTransform, +) -> list[YamlTransform]: """Get all possible transforms by extracting parameters at each valid index. 
For a transform with parameter lists, creates multiple new transforms, each containing @@ -167,9 +177,15 @@ def expand_transform_parameter_combinations(transform: YamlTransform) -> list[Ya for column in transform.columns: for transformation in column.transformations: if transformation.params: - list_lengths = [len(v) for v in transformation.params.values() if isinstance(v, list) and len(v) > 1] + list_lengths = [ + len(v) + for v in transformation.params.values() + if isinstance(v, list) and len(v) > 1 + ] if list_lengths: - max_length = list_lengths[0] # All lists have same length due to validator + max_length = list_lengths[ + 0 + ] # All lists have same length due to validator break # Generate a transform for each index @@ -180,7 +196,9 @@ def expand_transform_parameter_combinations(transform: YamlTransform) -> list[Ya return transforms -def expand_transform_list_combinations(transform_list: list[YamlTransform]) -> list[YamlTransform]: +def expand_transform_list_combinations( + transform_list: list[YamlTransform], +) -> list[YamlTransform]: """Expands a list of transforms into all possible parameter combinations. Takes a list of transforms where each transform may contain parameter lists, @@ -229,19 +247,17 @@ def generate_data_configs(yaml_config: YamlConfigDict) -> list[YamlSubConfigDict if isinstance(yaml_config, dict) and not isinstance(yaml_config, YamlConfigDict): raise TypeError("Input must be a YamlConfigDict object") - sub_transforms = expand_transform_list_combinations(yaml_config.transforms) sub_splits = yaml_config.split sub_configs = [] for split in sub_splits: - for transform in sub_transforms: - sub_configs.append( - YamlSubConfigDict( - global_params=yaml_config.global_params, - columns=yaml_config.columns, - transforms=transform, - split=split, - ), - ) + sub_configs.append( + YamlSubConfigDict( + global_params=yaml_config.global_params, + columns=yaml_config.columns, + transforms=yaml_config.transform, + split=split, + ), + ) return sub_configs @@ -264,9 +280,13 @@ def custom_representer(dumper: yaml.Dumper, data: Any) -> yaml.Node: if len(data) == 0: return dumper.represent_scalar("tag:yaml.org,2002:null", "") if isinstance(data[0], dict): - return dumper.represent_sequence("tag:yaml.org,2002:seq", data, flow_style=False) + return dumper.represent_sequence( + "tag:yaml.org,2002:seq", data, flow_style=False + ) if isinstance(data[0], list): - return dumper.represent_sequence("tag:yaml.org,2002:seq", data, flow_style=True) + return dumper.represent_sequence( + "tag:yaml.org,2002:seq", data, flow_style=True + ) return dumper.represent_sequence("tag:yaml.org,2002:seq", data, flow_style=True) class CustomDumper(yaml.Dumper): @@ -282,7 +302,9 @@ def write_line_break(self, _data: Any = None) -> None: if len(self.indents) <= 1: # At root level super().write_line_break(_data) - def increase_indent(self, *, flow: bool = False, indentless: bool = False) -> None: # type: ignore[override] + def increase_indent( + self, *, flow: bool = False, indentless: bool = False + ) -> None: # type: ignore[override] """Ensure consistent indentation by preventing indentless sequences.""" return super().increase_indent( flow=flow, @@ -305,21 +327,30 @@ def fix_params(input_dict: dict[str, Any]) -> dict[str, Any]: processed_dict[key] = [] for encoder in value: processed_encoder = dict(encoder) - if "params" not in processed_encoder or not processed_encoder["params"]: + if ( + "params" not in processed_encoder + or not processed_encoder["params"] + ): processed_encoder["params"] = {} 
processed_dict[key].append(processed_encoder) elif key == "transformations" and isinstance(value, list): processed_dict[key] = [] for transformation in value: processed_transformation = dict(transformation) - if "params" not in processed_transformation or not processed_transformation["params"]: + if ( + "params" not in processed_transformation + or not processed_transformation["params"] + ): processed_transformation["params"] = {} processed_dict[key].append(processed_transformation) elif isinstance(value, dict): processed_dict[key] = fix_params(value) elif isinstance(value, list): processed_dict[key] = [ - fix_params(list_item) if isinstance(list_item, dict) else list_item for list_item in value + fix_params(list_item) + if isinstance(list_item, dict) + else list_item + for list_item in value ] else: processed_dict[key] = value From 7e9ad638e0bc18227be911fdbf9d734acf2c8ecd Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 12 Feb 2025 16:42:57 +0100 Subject: [PATCH 02/35] FIX: utils/yaml_data.py:generate_data_configs -> changed transform to the right variable transforms --- src/stimulus/utils/yaml_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/stimulus/utils/yaml_data.py b/src/stimulus/utils/yaml_data.py index 0d49c0b2..5f64981e 100644 --- a/src/stimulus/utils/yaml_data.py +++ b/src/stimulus/utils/yaml_data.py @@ -254,7 +254,7 @@ def generate_data_configs(yaml_config: YamlConfigDict) -> list[YamlSubConfigDict YamlSubConfigDict( global_params=yaml_config.global_params, columns=yaml_config.columns, - transforms=yaml_config.transform, + transforms=yaml_config.transforms, split=split, ), ) From 3be595cd7966921c77c9a27c27feb5dbd0a412aa Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 12 Feb 2025 16:44:57 +0100 Subject: [PATCH 03/35] FIX: utils/yaml_data.py:YamlSubConfigDict -> transforms is now a list of YamlTransform --- src/stimulus/utils/yaml_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/stimulus/utils/yaml_data.py b/src/stimulus/utils/yaml_data.py index 5f64981e..7e9eb72a 100644 --- a/src/stimulus/utils/yaml_data.py +++ b/src/stimulus/utils/yaml_data.py @@ -116,7 +116,7 @@ class YamlSubConfigDict(BaseModel): global_params: YamlGlobalParams columns: list[YamlColumns] - transforms: YamlTransform + transforms: list[YamlTransform] split: YamlSplit From acda7643656f38a64ab1111d80e54f83930ad6f4 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 12 Feb 2025 16:47:48 +0100 Subject: [PATCH 04/35] FIX: utils/yaml_data.py:generate_data_configs -> updated docstring --- src/stimulus/utils/yaml_data.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/stimulus/utils/yaml_data.py b/src/stimulus/utils/yaml_data.py index 7e9eb72a..ea8532b0 100644 --- a/src/stimulus/utils/yaml_data.py +++ b/src/stimulus/utils/yaml_data.py @@ -230,9 +230,16 @@ def generate_data_configs(yaml_config: YamlConfigDict) -> list[YamlSubConfigDict separate data configurations. For example, if the config has: - - A transform with parameters [0.1, 0.2] + - Two transforms with parameters [0.1, 0.2], [0.3, 0.4] - Two splits [0.7/0.3] and [0.8/0.2] - This will generate 4 configs, 2 for each split. + This will generate 2 configs, 2 for each split. 
+ config_1: + transform: [[0.1, 0.2], [0.3, 0.4]] + split: [0.7, 0.3] + + config_2: + transform: [[0.1, 0.2], [0.3, 0.4]] + split: [0.8, 0.2] Args: yaml_config: The source YAML configuration containing transforms with From 5ab7ca04ab4ec15b0c73abc4a5ea55cdb76a9f84 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 12 Feb 2025 16:49:42 +0100 Subject: [PATCH 05/35] FIX: utils/yaml_data.py -> Removed 'Yaml' --- src/stimulus/utils/yaml_data.py | 84 ++++++++++++++++----------------- 1 file changed, 42 insertions(+), 42 deletions(-) diff --git a/src/stimulus/utils/yaml_data.py b/src/stimulus/utils/yaml_data.py index ea8532b0..38b81c1a 100644 --- a/src/stimulus/utils/yaml_data.py +++ b/src/stimulus/utils/yaml_data.py @@ -6,13 +6,13 @@ from pydantic import BaseModel, ValidationError, field_validator -class YamlGlobalParams(BaseModel): +class GlobalParams(BaseModel): """Model for global parameters in YAML configuration.""" seed: int -class YamlColumnsEncoder(BaseModel): +class ColumnsEncoder(BaseModel): """Model for column encoder configuration.""" name: str @@ -21,16 +21,16 @@ class YamlColumnsEncoder(BaseModel): ] # Allow both string and list values -class YamlColumns(BaseModel): +class Columns(BaseModel): """Model for column configuration.""" column_name: str column_type: str data_type: str - encoder: list[YamlColumnsEncoder] + encoder: list[ColumnsEncoder] -class YamlTransformColumnsTransformation(BaseModel): +class TransformColumnsTransformation(BaseModel): """Model for column transformation configuration.""" name: str @@ -39,24 +39,24 @@ class YamlTransformColumnsTransformation(BaseModel): ] # Allow both list and float values -class YamlTransformColumns(BaseModel): +class TransformColumns(BaseModel): """Model for transform columns configuration.""" column_name: str - transformations: list[YamlTransformColumnsTransformation] + transformations: list[TransformColumnsTransformation] -class YamlTransform(BaseModel): +class Transform(BaseModel): """Model for transform configuration.""" transformation_name: str - columns: list[YamlTransformColumns] + columns: list[TransformColumns] @field_validator("columns") @classmethod def validate_param_lists_across_columns( - cls, columns: list[YamlTransformColumns] - ) -> list[YamlTransformColumns]: + cls, columns: list[TransformColumns] + ) -> list[TransformColumns]: """Validate that parameter lists across columns have consistent lengths. 
Args: @@ -94,7 +94,7 @@ def validate_param_lists_across_columns( return columns -class YamlSplit(BaseModel): +class Split(BaseModel): """Model for split configuration.""" split_method: str @@ -102,33 +102,33 @@ class YamlSplit(BaseModel): split_input_columns: list[str] -class YamlConfigDict(BaseModel): +class ConfigDict(BaseModel): """Model for main YAML configuration.""" - global_params: YamlGlobalParams - columns: list[YamlColumns] - transforms: list[YamlTransform] - split: list[YamlSplit] + global_params: GlobalParams + columns: list[Columns] + transforms: list[Transform] + split: list[Split] -class YamlSubConfigDict(BaseModel): +class SubConfigDict(BaseModel): """Model for sub-configuration generated from main config.""" - global_params: YamlGlobalParams - columns: list[YamlColumns] - transforms: list[YamlTransform] - split: YamlSplit + global_params: GlobalParams + columns: list[Columns] + transforms: list[Transform] + split: Split -class YamlSchema(BaseModel): +class Schema(BaseModel): """Model for validating YAML schema.""" - yaml_conf: YamlConfigDict + yaml_conf: ConfigDict def extract_transform_parameters_at_index( - transform: YamlTransform, index: int = 0 -) -> YamlTransform: + transform: Transform, index: int = 0 +) -> Transform: """Get a transform with parameters at the specified index. Args: @@ -139,7 +139,7 @@ def extract_transform_parameters_at_index( A new transform with single parameter values at the specified index """ # Create a copy of the transform - new_transform = YamlTransform(**transform.model_dump()) + new_transform = Transform(**transform.model_dump()) # Process each column and transformation for column in new_transform.columns: @@ -158,8 +158,8 @@ def extract_transform_parameters_at_index( def expand_transform_parameter_combinations( - transform: YamlTransform, -) -> list[YamlTransform]: + transform: Transform, +) -> list[Transform]: """Get all possible transforms by extracting parameters at each valid index. For a transform with parameter lists, creates multiple new transforms, each containing @@ -197,8 +197,8 @@ def expand_transform_parameter_combinations( def expand_transform_list_combinations( - transform_list: list[YamlTransform], -) -> list[YamlTransform]: + transform_list: list[Transform], +) -> list[Transform]: """Expands a list of transforms into all possible parameter combinations. Takes a list of transforms where each transform may contain parameter lists, @@ -207,11 +207,11 @@ def expand_transform_list_combinations( create two transforms: one with 0.1/1 and another with 0.2/2. Args: - transform_list: A list of YamlTransform objects containing parameter lists + transform_list: A list of Transform objects containing parameter lists that need to be expanded into individual transforms. Returns: - list[YamlTransform]: A flattened list of transforms where each transform + list[Transform]: A flattened list of transforms where each transform has single parameter values instead of parameter lists. The length of the returned list will be the sum of the number of parameter combinations for each input transform. @@ -222,7 +222,7 @@ def expand_transform_list_combinations( return sub_transforms -def generate_data_configs(yaml_config: YamlConfigDict) -> list[YamlSubConfigDict]: +def generate_data_configs(yaml_config: ConfigDict) -> list[SubConfigDict]: """Generates all possible data configurations from a YAML config. 
Takes a YAML configuration that may contain parameter lists and splits, @@ -246,19 +246,19 @@ def generate_data_configs(yaml_config: YamlConfigDict) -> list[YamlSubConfigDict parameter lists and multiple splits. Returns: - list[YamlSubConfigDict]: A list of data configurations, where each + list[SubConfigDict]: A list of data configurations, where each config has single parameter values and one split configuration. The length will be the product of the number of parameter combinations and the number of splits. """ - if isinstance(yaml_config, dict) and not isinstance(yaml_config, YamlConfigDict): - raise TypeError("Input must be a YamlConfigDict object") + if isinstance(yaml_config, dict) and not isinstance(yaml_config, ConfigDict): + raise TypeError("Input must be a ConfigDict object") sub_splits = yaml_config.split sub_configs = [] for split in sub_splits: sub_configs.append( - YamlSubConfigDict( + SubConfigDict( global_params=yaml_config.global_params, columns=yaml_config.columns, transforms=yaml_config.transforms, @@ -269,7 +269,7 @@ def generate_data_configs(yaml_config: YamlConfigDict) -> list[YamlSubConfigDict def dump_yaml_list_into_files( - yaml_list: list[YamlSubConfigDict], + yaml_list: list[SubConfigDict], directory_path: str, base_name: str, ) -> None: @@ -378,14 +378,14 @@ def fix_params(input_dict: dict[str, Any]) -> dict[str, Any]: ) -def check_yaml_schema(config_yaml: YamlConfigDict) -> str: +def check_yaml_schema(config_yaml: ConfigDict) -> str: """Validate YAML configuration fields have correct types. If the children field is specific to a parent, the children fields class is hosted in the parent fields class. If any field in not the right type, the function prints an error message explaining the problem and exits the python code. 
Args: - config_yaml: The YamlConfigDict containing the fields of the yaml configuration file + config_yaml: The ConfigDict containing the fields of the yaml configuration file Returns: str: Empty string if validation succeeds @@ -394,7 +394,7 @@ def check_yaml_schema(config_yaml: YamlConfigDict) -> str: ValueError: If validation fails """ try: - YamlSchema(yaml_conf=config_yaml) + Schema(yaml_conf=config_yaml) except ValidationError as e: # Use logging instead of print for error handling raise ValueError("Wrong type on a field, see the pydantic report above") from e From b213b2f24555b97f38b2cb6b3670dd784a31795e Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 12 Feb 2025 17:13:22 +0100 Subject: [PATCH 06/35] FIX: utils/yaml_data.py -> Put 'Yaml' back to modify that in another branch --- src/stimulus/utils/yaml_data.py | 84 ++++++++++++++++----------------- 1 file changed, 42 insertions(+), 42 deletions(-) diff --git a/src/stimulus/utils/yaml_data.py b/src/stimulus/utils/yaml_data.py index 38b81c1a..ea8532b0 100644 --- a/src/stimulus/utils/yaml_data.py +++ b/src/stimulus/utils/yaml_data.py @@ -6,13 +6,13 @@ from pydantic import BaseModel, ValidationError, field_validator -class GlobalParams(BaseModel): +class YamlGlobalParams(BaseModel): """Model for global parameters in YAML configuration.""" seed: int -class ColumnsEncoder(BaseModel): +class YamlColumnsEncoder(BaseModel): """Model for column encoder configuration.""" name: str @@ -21,16 +21,16 @@ class ColumnsEncoder(BaseModel): ] # Allow both string and list values -class Columns(BaseModel): +class YamlColumns(BaseModel): """Model for column configuration.""" column_name: str column_type: str data_type: str - encoder: list[ColumnsEncoder] + encoder: list[YamlColumnsEncoder] -class TransformColumnsTransformation(BaseModel): +class YamlTransformColumnsTransformation(BaseModel): """Model for column transformation configuration.""" name: str @@ -39,24 +39,24 @@ class TransformColumnsTransformation(BaseModel): ] # Allow both list and float values -class TransformColumns(BaseModel): +class YamlTransformColumns(BaseModel): """Model for transform columns configuration.""" column_name: str - transformations: list[TransformColumnsTransformation] + transformations: list[YamlTransformColumnsTransformation] -class Transform(BaseModel): +class YamlTransform(BaseModel): """Model for transform configuration.""" transformation_name: str - columns: list[TransformColumns] + columns: list[YamlTransformColumns] @field_validator("columns") @classmethod def validate_param_lists_across_columns( - cls, columns: list[TransformColumns] - ) -> list[TransformColumns]: + cls, columns: list[YamlTransformColumns] + ) -> list[YamlTransformColumns]: """Validate that parameter lists across columns have consistent lengths. 
Args: @@ -94,7 +94,7 @@ def validate_param_lists_across_columns( return columns -class Split(BaseModel): +class YamlSplit(BaseModel): """Model for split configuration.""" split_method: str @@ -102,33 +102,33 @@ class Split(BaseModel): split_input_columns: list[str] -class ConfigDict(BaseModel): +class YamlConfigDict(BaseModel): """Model for main YAML configuration.""" - global_params: GlobalParams - columns: list[Columns] - transforms: list[Transform] - split: list[Split] + global_params: YamlGlobalParams + columns: list[YamlColumns] + transforms: list[YamlTransform] + split: list[YamlSplit] -class SubConfigDict(BaseModel): +class YamlSubConfigDict(BaseModel): """Model for sub-configuration generated from main config.""" - global_params: GlobalParams - columns: list[Columns] - transforms: list[Transform] - split: Split + global_params: YamlGlobalParams + columns: list[YamlColumns] + transforms: list[YamlTransform] + split: YamlSplit -class Schema(BaseModel): +class YamlSchema(BaseModel): """Model for validating YAML schema.""" - yaml_conf: ConfigDict + yaml_conf: YamlConfigDict def extract_transform_parameters_at_index( - transform: Transform, index: int = 0 -) -> Transform: + transform: YamlTransform, index: int = 0 +) -> YamlTransform: """Get a transform with parameters at the specified index. Args: @@ -139,7 +139,7 @@ def extract_transform_parameters_at_index( A new transform with single parameter values at the specified index """ # Create a copy of the transform - new_transform = Transform(**transform.model_dump()) + new_transform = YamlTransform(**transform.model_dump()) # Process each column and transformation for column in new_transform.columns: @@ -158,8 +158,8 @@ def extract_transform_parameters_at_index( def expand_transform_parameter_combinations( - transform: Transform, -) -> list[Transform]: + transform: YamlTransform, +) -> list[YamlTransform]: """Get all possible transforms by extracting parameters at each valid index. For a transform with parameter lists, creates multiple new transforms, each containing @@ -197,8 +197,8 @@ def expand_transform_parameter_combinations( def expand_transform_list_combinations( - transform_list: list[Transform], -) -> list[Transform]: + transform_list: list[YamlTransform], +) -> list[YamlTransform]: """Expands a list of transforms into all possible parameter combinations. Takes a list of transforms where each transform may contain parameter lists, @@ -207,11 +207,11 @@ def expand_transform_list_combinations( create two transforms: one with 0.1/1 and another with 0.2/2. Args: - transform_list: A list of Transform objects containing parameter lists + transform_list: A list of YamlTransform objects containing parameter lists that need to be expanded into individual transforms. Returns: - list[Transform]: A flattened list of transforms where each transform + list[YamlTransform]: A flattened list of transforms where each transform has single parameter values instead of parameter lists. The length of the returned list will be the sum of the number of parameter combinations for each input transform. @@ -222,7 +222,7 @@ def expand_transform_list_combinations( return sub_transforms -def generate_data_configs(yaml_config: ConfigDict) -> list[SubConfigDict]: +def generate_data_configs(yaml_config: YamlConfigDict) -> list[YamlSubConfigDict]: """Generates all possible data configurations from a YAML config. 
Takes a YAML configuration that may contain parameter lists and splits, @@ -246,19 +246,19 @@ def generate_data_configs(yaml_config: ConfigDict) -> list[SubConfigDict]: parameter lists and multiple splits. Returns: - list[SubConfigDict]: A list of data configurations, where each + list[YamlSubConfigDict]: A list of data configurations, where each config has single parameter values and one split configuration. The length will be the product of the number of parameter combinations and the number of splits. """ - if isinstance(yaml_config, dict) and not isinstance(yaml_config, ConfigDict): - raise TypeError("Input must be a ConfigDict object") + if isinstance(yaml_config, dict) and not isinstance(yaml_config, YamlConfigDict): + raise TypeError("Input must be a YamlConfigDict object") sub_splits = yaml_config.split sub_configs = [] for split in sub_splits: sub_configs.append( - SubConfigDict( + YamlSubConfigDict( global_params=yaml_config.global_params, columns=yaml_config.columns, transforms=yaml_config.transforms, @@ -269,7 +269,7 @@ def generate_data_configs(yaml_config: ConfigDict) -> list[SubConfigDict]: def dump_yaml_list_into_files( - yaml_list: list[SubConfigDict], + yaml_list: list[YamlSubConfigDict], directory_path: str, base_name: str, ) -> None: @@ -378,14 +378,14 @@ def fix_params(input_dict: dict[str, Any]) -> dict[str, Any]: ) -def check_yaml_schema(config_yaml: ConfigDict) -> str: +def check_yaml_schema(config_yaml: YamlConfigDict) -> str: """Validate YAML configuration fields have correct types. If the children field is specific to a parent, the children fields class is hosted in the parent fields class. If any field in not the right type, the function prints an error message explaining the problem and exits the python code. Args: - config_yaml: The ConfigDict containing the fields of the yaml configuration file + config_yaml: The YamlConfigDict containing the fields of the yaml configuration file Returns: str: Empty string if validation succeeds @@ -394,7 +394,7 @@ def check_yaml_schema(config_yaml: ConfigDict) -> str: ValueError: If validation fails """ try: - Schema(yaml_conf=config_yaml) + YamlSchema(yaml_conf=config_yaml) except ValidationError as e: # Use logging instead of print for error handling raise ValueError("Wrong type on a field, see the pydantic report above") from e From beac23503fb89dcfa014c0d87197a57a9dd9b894 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 12 Feb 2025 17:17:09 +0100 Subject: [PATCH 07/35] FIX: tests/cli/__snapshots__/test_split_yaml.ambr -> updated the snapshot for the unique generated file --- tests/cli/__snapshots__/test_split_yaml.ambr | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/cli/__snapshots__/test_split_yaml.ambr b/tests/cli/__snapshots__/test_split_yaml.ambr index e4e7731c..99779398 100644 --- a/tests/cli/__snapshots__/test_split_yaml.ambr +++ b/tests/cli/__snapshots__/test_split_yaml.ambr @@ -1,8 +1,6 @@ # serializer version: 1 # name: test_split_yaml[correct_yaml_path-None] list([ - '0e43b7cdcd8d458cc4e6ff80e06ba7ea', - '43a7f9fbac5c32f51fa51680c7679a57', - 'edf8dd2d39b74619d17b298e3b010c77', + '42139ca7745259e09d1e56e24570d2c7', ]) # --- From cf0f2343d3595700317e64c6b0c3fcf2f1cd6057 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Fri, 14 Feb 2025 17:36:34 +0100 Subject: [PATCH 08/35] NEW: cli/split_split.py -> New file to split the config only on the splits, gives one file per split --- src/stimulus/cli/split_split.py | 81 +++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) 
create mode 100755 src/stimulus/cli/split_split.py

diff --git a/src/stimulus/cli/split_split.py b/src/stimulus/cli/split_split.py
new file mode 100755
index 00000000..5d80d1de
--- /dev/null
+++ b/src/stimulus/cli/split_split.py
@@ -0,0 +1,81 @@
+#!/usr/bin/env python3
+"""CLI module for splitting YAML configuration files into unique files for each split.
+
+This module provides functionality to split a single YAML configuration file into multiple
+YAML files, each containing a unique split.
+The resulting YAML files can be used as input configurations for the stimulus package.
+"""
+
+import argparse
+from typing import Any
+
+import yaml
+
+from stimulus.utils.yaml_data import (
+    YamlConfigDict,
+    YamlSubConfigDict,
+    check_yaml_schema,
+    dump_yaml_list_into_files,
+    generate_split_configs,
+)
+
+
+def get_args() -> argparse.Namespace:
+    """Get the arguments when using from the command line."""
+    parser = argparse.ArgumentParser(description="")
+    parser.add_argument(
+        "-j",
+        "--yaml",
+        type=str,
+        required=True,
+        metavar="FILE",
+        help="The YAML config file that holds all transform - split - parameter info",
+    )
+    parser.add_argument(
+        "-d",
+        "--out_dir",
+        type=str,
+        required=False,
+        nargs="?",
+        const="./",
+        default="./",
+        metavar="DIR",
+        # TODO: Change the output name
+        help="The output dir where all the YAMLs are written to. Output YAML will be called split-#[number].yaml transform-#[number].yaml. Default -> ./",
+    )
+
+    return parser.parse_args()
+
+
+def main(config_yaml: str, out_dir_path: str) -> None:
+    """Reads a YAML config file and generates a file per unique split.
+
+    This script reads a YAML with a defined structure and creates all the YAML files ready to be passed to
+    the stimulus package.
+
+    The structure of the YAML is described here -> TODO paste here link to documentation.
+    This YAML and its structure summarize how to generate unique splits and all the transformations associated with each split.
+
+    This script will always generate at least one YAML file that represents the combination that does not touch the data (no transform)
+    and uses the default split behavior.
+    """
+    # read the yaml experiment config and load it into a dictionary
+    yaml_config: dict[str, Any] = {}
+    with open(config_yaml) as conf_file:
+        yaml_config = yaml.safe_load(conf_file)
+
+    yaml_config_dict: YamlConfigDict = YamlConfigDict(**yaml_config)
+    # check if the yaml schema is correct
+    # FIXME: isn't it redundant to validate a config that pydantic has already validated?
+    check_yaml_schema(yaml_config_dict)
+
+    # generate the yaml files per split
+    split_configs: list[YamlSubConfigDict] = generate_split_configs(yaml_config_dict)
+
+    # dump all the YAML configs into files
+    dump_yaml_list_into_files(split_configs, out_dir_path, "test")
+
+
+if __name__ == "__main__":
+    args = get_args()
+    main(args.yaml, args.out_dir)

From cfdbee6558785250b0878ade436042cb101a7b65 Mon Sep 17 00:00:00 2001
From: Julien Raynal
Date: Fri, 14 Feb 2025 17:38:42 +0100
Subject: [PATCH 09/35] NEW: cli/split_transforms.py -> New file to split the config for each transform, gives one file per transform
---
 src/stimulus/cli/split_transforms.py | 77 ++++++++++++++++++++++++++++
 1 file changed, 77 insertions(+)
 create mode 100644 src/stimulus/cli/split_transforms.py

diff --git a/src/stimulus/cli/split_transforms.py b/src/stimulus/cli/split_transforms.py
new file mode 100644
index 00000000..70e85f1e
--- /dev/null
+++ b/src/stimulus/cli/split_transforms.py
@@ -0,0 +1,77 @@
+#!/usr/bin/env python3
+"""CLI module for splitting YAML configuration files into unique files for each transform.
+
+This module provides functionality to split a single YAML configuration file into multiple
+YAML files, each containing a unique transform associated with a unique split.
+The resulting YAML files can be used as input configurations for the stimulus package.
+"""
+
+import argparse
+from typing import Any
+
+import yaml
+
+from stimulus.utils.yaml_data import (
+    YamlSubConfigDict,
+    YamlSubConfigTransformDict,
+    check_yaml_schema,
+    dump_yaml_list_into_files,
+    generate_split_transform_configs,
+)
+
+
+def get_args() -> argparse.Namespace:
+    """Get the arguments when using the command line."""
+    parser = argparse.ArgumentParser(description="")
+    parser.add_argument(
+        "-j",
+        "--yaml",
+        type=str,
+        required=True,
+        metavar="FILE",
+        help="The YAML config file that holds all the transform per split parameter info",
+    )
+    parser.add_argument(
+        "-d",
+        "--out-dir",
+        type=str,
+        required=False,
+        nargs="?",
+        const="./",
+        default="./",
+        metavar="DIR",
+        help="The output dir where all the YAMLs are written to. Output YAML will be called split_transform-#[number].yaml. Default -> ./",
+    )
+
+    return parser.parse_args()
+
+
+def main(config_yaml: str, out_dir_path: str) -> None:
+    """Reads a YAML config and generates files for all possible split - transform combinations.
+
+    This script reads a YAML with a defined structure and creates all the YAML files ready to be passed to the stimulus package.
+
+    The structure of the YAML is described here -> TODO: paste here the link to documentation
+    This YAML and its structure summarize how to generate all the transforms for the split and their respective parameter combinations.
+
+    This script will always generate at least one YAML file that represents the combination that does not touch the data (no transform).
+ """ + # read the yaml experiment config and load its dictionnary + yaml_config: dict[str, Any] = {} + with open(config_yaml) as conf_file: + yaml_config = yaml.safe_load(conf_file) + + yaml_config_dict: YamlSubConfigDict = YamlSubConfigDict(**yaml_config) + + # Generate the yaml files for each transform + split_transform_configs: list[YamlSubConfigTransformDict] = ( + generate_split_transform_configs(yaml_config_dict) + ) + + # Dump all the YAML configs into files + dump_yaml_list_into_files(split_transform_configs, out_dir_path, "test") + + +if __name__ == "__main__": + args = get_args() + main(args.yaml, args.out_dir) From 722137117123e92c9d684694c2771d79a23eb0d6 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Fri, 14 Feb 2025 17:39:37 +0100 Subject: [PATCH 10/35] REMOVE: cli/split_yaml -> deleted the old splitter file because two new ones exist now --- src/stimulus/cli/split_yaml.py | 79 ---------------------------------- 1 file changed, 79 deletions(-) delete mode 100755 src/stimulus/cli/split_yaml.py diff --git a/src/stimulus/cli/split_yaml.py b/src/stimulus/cli/split_yaml.py deleted file mode 100755 index 1a492bad..00000000 --- a/src/stimulus/cli/split_yaml.py +++ /dev/null @@ -1,79 +0,0 @@ -#!/usr/bin/env python3 -"""CLI module for splitting YAML configuration files. - -This module provides functionality to split a single YAML configuration file into multiple -YAML files, each containing a specific combination of data transformations and splits. -The resulting YAML files can be used as input configurations for the stimulus package. -""" - -import argparse -from typing import Any - -import yaml - -from stimulus.utils.yaml_data import ( - YamlConfigDict, - check_yaml_schema, - dump_yaml_list_into_files, - generate_data_configs, -) - - -def get_args() -> argparse.Namespace: - """Get the arguments when using from the command line.""" - parser = argparse.ArgumentParser(description="") - parser.add_argument( - "-j", - "--yaml", - type=str, - required=True, - metavar="FILE", - help="The YAML config file that hold all transform - split - parameter info", - ) - parser.add_argument( - "-d", - "--out_dir", - type=str, - required=False, - nargs="?", - const="./", - default="./", - metavar="DIR", - help="The output dir where all the YAMLs are written to. Output YAML will be called split-#[number].yaml transform-#[number].yaml. Default -> ./", - ) - - return parser.parse_args() - - -def main(config_yaml: str, out_dir_path: str) -> None: - """Reads a YAML config file and generates all possible data configurations. - - This script reads a YAML with a defined structure and creates all the YAML files ready to be passed to - the stimulus package. - - The structure of the YAML is described here -> TODO paste here link to documentation. - This YAML and it's structure summarize how to generate all the transform - split and respective parameter combinations. - Each resulting YAML will hold only one combination of the above three things. - - This script will always generate at least one YAML file that represent the combination that does not touch the data (no transform) - and uses the default split behavior. 
- """ - # read the yaml experiment config and load it to dictionary - yaml_config: dict[str, Any] = {} - with open(config_yaml) as conf_file: - yaml_config = yaml.safe_load(conf_file) - - yaml_config_dict: YamlConfigDict = YamlConfigDict(**yaml_config) - # check if the yaml schema is correct - check_yaml_schema(yaml_config_dict) - - # generate all the YAML configs - data_configs = generate_data_configs(yaml_config_dict) - - # dump all the YAML configs into files - dump_yaml_list_into_files(data_configs, out_dir_path, "test") - - -if __name__ == "__main__": - args = get_args() - main(args.yaml, args.out_dir) From b8251d388c12e8331d04d4d609da7030d03a4ecb Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Fri, 14 Feb 2025 17:40:33 +0100 Subject: [PATCH 11/35] UPDATE: utils/yaml_data.py -> new functions for the transform splitter --- src/stimulus/utils/yaml_data.py | 75 ++++++++++++++++++++++++++++++--- 1 file changed, 69 insertions(+), 6 deletions(-) diff --git a/src/stimulus/utils/yaml_data.py b/src/stimulus/utils/yaml_data.py index ea8532b0..6330dc87 100644 --- a/src/stimulus/utils/yaml_data.py +++ b/src/stimulus/utils/yaml_data.py @@ -111,12 +111,22 @@ class YamlConfigDict(BaseModel): split: list[YamlSplit] +# TODO: Rename this class to SplitConfigDict class YamlSubConfigDict(BaseModel): """Model for sub-configuration generated from main config.""" global_params: YamlGlobalParams columns: list[YamlColumns] - transforms: list[YamlTransform] + transforms: Union[list[YamlTransform]] + split: YamlSplit + + +class YamlSubConfigTransformDict(BaseModel): + """Model for sub-configuration generated from main config.""" + + global_params: YamlGlobalParams + columns: list[YamlColumns] + transform: Union[YamlTransform] split: YamlSplit @@ -126,6 +136,12 @@ class YamlSchema(BaseModel): yaml_conf: YamlConfigDict +class YamlSplitSchema(BaseModel): + """Model for validating a Split YAML schema.""" + + yaml_conf: YamlSubConfigDict + + def extract_transform_parameters_at_index( transform: YamlTransform, index: int = 0 ) -> YamlTransform: @@ -222,12 +238,11 @@ def expand_transform_list_combinations( return sub_transforms -def generate_data_configs(yaml_config: YamlConfigDict) -> list[YamlSubConfigDict]: - """Generates all possible data configurations from a YAML config. +def generate_split_configs(yaml_config: YamlConfigDict) -> list[YamlSubConfigDict]: + """Generates all possible split configuration from a YAML config. Takes a YAML configuration that may contain parameter lists and splits, - and generates all possible combinations of parameters and splits into - separate data configurations. + and generates all unique splits into separate data configurations. For example, if the config has: - Two transforms with parameters [0.1, 0.2], [0.3, 0.4] @@ -247,7 +262,7 @@ def generate_data_configs(yaml_config: YamlConfigDict) -> list[YamlSubConfigDict Returns: list[YamlSubConfigDict]: A list of data configurations, where each - config has single parameter values and one split configuration. The + config has a list of parameters and one split configuration. The length will be the product of the number of parameter combinations and the number of splits. 
""" @@ -268,6 +283,54 @@ def generate_data_configs(yaml_config: YamlConfigDict) -> list[YamlSubConfigDict return sub_configs +def generate_split_transform_configs( + yaml_config: YamlSubConfigDict, +) -> list[YamlSubConfigTransformDict]: + """Generates all the transform configuration for a given split + + Takes a YAML configuration that may contain a transform or a list of transform, + and generates all unique transform for a split into separate data configurations. + + For example, if the config has: + - Two transforms with parameters [0.1, 0.2], [0.3, 0.4] + - A split [0.7, 0.3] + This will generate 2 configs, 2 for each split. + transform_config_1: + transform: [0.1, 0.2] + split: [0.7, 0.3] + + transform_config_2: + transform: [0.3, 0.4] + split: [0.7, 0.3] + + Args: + yaml_config: The source YAML configuration containing each + a split with transforms with parameters lists + + Returns: + list[YamlSubConfigTransformDict]: A list of data configurations, where each + config has a list of parameters and one split configuration. The + length will be the product of the number of parameter combinations + and the number of splits. + """ + if isinstance(yaml_config, dict) and not isinstance(yaml_config, YamlSubConfigDict): + raise TypeError("Input must be a list of YamlSubConfigDict") + + split_configs = yaml_config.split + split_transform_config: list[YamlSubConfigTransformDict] = [] + for split_config in split_configs: + for transform in split_configs.get("transforms"): + split_transform_config.append( + YamlSubConfigTransformDict( + global_params=split_config.get("global_params"), + columns=split_config.get("columns"), + transform=transform, + split=split_config.get("split"), + ) + ) + return split_transform_config + + def dump_yaml_list_into_files( yaml_list: list[YamlSubConfigDict], directory_path: str, From a34754e46e50a46baee95829d88180433aa27342 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Fri, 14 Feb 2025 17:58:02 +0100 Subject: [PATCH 12/35] FIX: data/data_handlers.py -> YamlSubConfigDict takes a transform or a list of transforms because the titanic test uses a unique value and the dna uses a list --- src/stimulus/utils/yaml_data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/stimulus/utils/yaml_data.py b/src/stimulus/utils/yaml_data.py index 6330dc87..64610265 100644 --- a/src/stimulus/utils/yaml_data.py +++ b/src/stimulus/utils/yaml_data.py @@ -117,7 +117,7 @@ class YamlSubConfigDict(BaseModel): global_params: YamlGlobalParams columns: list[YamlColumns] - transforms: Union[list[YamlTransform]] + transforms: Union[YamlTransform, list[YamlTransform]] split: YamlSplit @@ -126,7 +126,7 @@ class YamlSubConfigTransformDict(BaseModel): global_params: YamlGlobalParams columns: list[YamlColumns] - transform: Union[YamlTransform] + transform: YamlTransform split: YamlSplit From 6437bd2cb14cf38d11a314d0368de5c3107a37f3 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Mon, 17 Feb 2025 11:26:11 +0100 Subject: [PATCH 13/35] NEW: src/stimulus/cli/split_split.py -> A file that splits all the splits in a yaml config in x files holding unique splits --- src/stimulus/cli/split_split.py | 6 +- .../cli/__snapshots__/test_split_splits.ambr | 6 ++ tests/cli/test_split_splits.py | 58 +++++++++++++++++++ 3 files changed, 67 insertions(+), 3 deletions(-) create mode 100644 tests/cli/__snapshots__/test_split_splits.ambr create mode 100644 tests/cli/test_split_splits.py diff --git a/src/stimulus/cli/split_split.py b/src/stimulus/cli/split_split.py index 5d80d1de..e5ba5d8d 
100755 --- a/src/stimulus/cli/split_split.py +++ b/src/stimulus/cli/split_split.py @@ -13,7 +13,7 @@ from stimulus.utils.yaml_data import ( YamlConfigDict, - YamlSubConfigDict, + YamlSplitConfigDict, check_yaml_schema, dump_yaml_list_into_files, generate_split_configs, @@ -70,10 +70,10 @@ def main(config_yaml: str, out_dir_path: str) -> None: check_yaml_schema(yaml_config_dict) # generate the yaml files per split - split_configs: list[YamlSubConfigDict] = generate_split_configs(yaml_config_dict) + split_configs: list[YamlSplitConfigDict] = generate_split_configs(yaml_config_dict) # dump all the YAML configs into files - dump_yaml_list_into_files(split_configs, out_dir_path, "test") + dump_yaml_list_into_files(split_configs, out_dir_path, "test_split") if __name__ == "__main__": diff --git a/tests/cli/__snapshots__/test_split_splits.ambr b/tests/cli/__snapshots__/test_split_splits.ambr new file mode 100644 index 00000000..7ab227de --- /dev/null +++ b/tests/cli/__snapshots__/test_split_splits.ambr @@ -0,0 +1,6 @@ +# serializer version: 1 +# name: test_split_split[correct_yaml_path-None] + list([ + '42139ca7745259e09d1e56e24570d2c7', + ]) +# --- diff --git a/tests/cli/test_split_splits.py b/tests/cli/test_split_splits.py new file mode 100644 index 00000000..e0cc445c --- /dev/null +++ b/tests/cli/test_split_splits.py @@ -0,0 +1,58 @@ +"""Tests for the split_split CLI command.""" + +import hashlib +import os +import tempfile +from typing import Any, Callable + +import pytest + +from src.stimulus.cli import split_split + + +# Fixtures +@pytest.fixture +def correct_yaml_path() -> str: + """Fixture that returns the path to a correct YAML file.""" + return "tests/test_data/titanic/titanic.yaml" + + +@pytest.fixture +def wrong_yaml_path() -> str: + """Fixture that returns the path to a wrong YAML file.""" + return "tests/test_data/yaml_files/wrong_field_type.yaml" + + +# Test cases +test_cases = [ + ("correct_yaml_path", None), + ("wrong_yaml_path", ValueError), +] + + +# Tests +@pytest.mark.parametrize(("yaml_type", "error"), test_cases) +def test_split_split( + request: pytest.FixtureRequest, + snapshot: Callable[[], Any], + yaml_type: str, + error: Exception | None, + tmp_path, # Pytest tmp file system +) -> None: + """Tests the CLI command with correct and wrong YAML files.""" + yaml_path = request.getfixturevalue(yaml_type) + tmpdir = tmp_path + if error: + with pytest.raises(error): # type: ignore[call-overload] + split_split.main(yaml_path, tmpdir) + else: + split_split.main(yaml_path, tmpdir) # main() returns None, no need to assert + files = os.listdir(tmpdir) + test_out = [f for f in files if f.startswith("test_")] + hashes = [] + for f in test_out: + with open(os.path.join(tmpdir, f)) as file: + hashes.append(hashlib.md5(file.read().encode()).hexdigest()) # noqa: S324 + assert ( + sorted(hashes) == snapshot + ) # sorted ensures that the order of the hashes does not matter From ac1d70c27bbcebb7f21b4452b7ee3d8af3107b58 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Mon, 17 Feb 2025 11:27:54 +0100 Subject: [PATCH 14/35] NEW: src/stimulus/cli/split_transforms.py -> Cli function to split unique yaml config with splits and multiple transforms in configs with one unique split and one unique transform --- src/stimulus/cli/split_transforms.py | 11 ++-- .../__snapshots__/test_split_transforms.ambr | 7 +++ tests/cli/test_split_transforms.py | 54 +++++++++++++++++++ 3 files changed, 66 insertions(+), 6 deletions(-) create mode 100644 tests/cli/__snapshots__/test_split_transforms.ambr create mode 
100644 tests/cli/test_split_transforms.py diff --git a/src/stimulus/cli/split_transforms.py b/src/stimulus/cli/split_transforms.py index 70e85f1e..f3e57717 100644 --- a/src/stimulus/cli/split_transforms.py +++ b/src/stimulus/cli/split_transforms.py @@ -12,9 +12,8 @@ import yaml from stimulus.utils.yaml_data import ( - YamlSubConfigDict, - YamlSubConfigTransformDict, - check_yaml_schema, + YamlSplitConfigDict, + YamlSplitTransformDict, dump_yaml_list_into_files, generate_split_transform_configs, ) @@ -61,15 +60,15 @@ def main(config_yaml: str, out_dir_path: str) -> None: with open(config_yaml) as conf_file: yaml_config = yaml.safe_load(conf_file) - yaml_config_dict: YamlSubConfigDict = YamlSubConfigDict(**yaml_config) + yaml_config_dict: YamlSplitConfigDict = YamlSplitConfigDict(**yaml_config) # Generate the yaml files for each transform - split_transform_configs: list[YamlSubConfigTransformDict] = ( + split_transform_configs: list[YamlSplitTransformDict] = ( generate_split_transform_configs(yaml_config_dict) ) # Dump all the YAML configs into files - dump_yaml_list_into_files(split_transform_configs, out_dir_path, "test") + dump_yaml_list_into_files(split_transform_configs, out_dir_path, "test_transforms") if __name__ == "__main__": diff --git a/tests/cli/__snapshots__/test_split_transforms.ambr b/tests/cli/__snapshots__/test_split_transforms.ambr new file mode 100644 index 00000000..2743eb3f --- /dev/null +++ b/tests/cli/__snapshots__/test_split_transforms.ambr @@ -0,0 +1,7 @@ +# serializer version: 1 +# name: test_split_transforms[correct_yaml_path-None] + list([ + '0e43b7cdcd8d458cc4e6ff80e06ba7ea', + 'e213d72c1df7eda1e0fdee6ccb4bec7f', + ]) +# --- diff --git a/tests/cli/test_split_transforms.py b/tests/cli/test_split_transforms.py new file mode 100644 index 00000000..770ec042 --- /dev/null +++ b/tests/cli/test_split_transforms.py @@ -0,0 +1,54 @@ +"""Test for the split_transforms CLI command""" + +import hashlib +import os +from typing import Any, Callable + +import pytest + +from src.stimulus.cli import split_transforms + + +# Fixtures +@pytest.fixture +def correct_yaml_path() -> str: + """Fixture that returns the path to a correct YAML file with one split only""" + return "tests/test_data/titanic/titanic_unique_split.yaml" + + +@pytest.fixture +def wrong_yaml_path() -> str: + """Fixture that returns the path to a wrong YAML file""" + return "tests/test_data/yaml_files/wrong_field_type.yaml" + + +# Test cases +test_cases = [("correct_yaml_path", None), ("wrong_yaml_path", ValueError)] + + +# Tests +@pytest.mark.parametrize(("yaml_type", "error"), test_cases) +def test_split_transforms( + request: pytest.FixtureRequest, + snapshot: Callable[[], Any], + yaml_type: str, + error: Exception | None, + tmp_path, # Pytest tmp file system +) -> None: + """Tests the CLI command with correct and wrong YAML files.""" + yaml_path: str = request.getfixturevalue(yaml_type) + tmpdir = tmp_path + if error: + with pytest.raises(error): + split_transforms.main(yaml_path, tmpdir) + else: + split_transforms.main(yaml_path, tmpdir) + files = os.listdir(tmpdir) + test_out = [f for f in files if f.startswith("test_")] + hashes = [] + for f in test_out: + with open(os.path.join(tmpdir, f)) as file: + hashes.append(hashlib.md5(file.read().encode()).hexdigest()) + assert ( + sorted(hashes) == snapshot + ) # Sorted ensures that the order of the hashes does not matter From 359b86eb7d77e44c98a9a121b0ab29be36dba1c8 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Mon, 17 Feb 2025 11:29:05 +0100 Subject: [PATCH 
15/35] NEW: tests/test_data/titanic/titanic_unique_split.yaml -> A config file with a unique split and multiple Transforms --- .../titanic/titanic_unique_split.yaml | 91 +++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 tests/test_data/titanic/titanic_unique_split.yaml diff --git a/tests/test_data/titanic/titanic_unique_split.yaml b/tests/test_data/titanic/titanic_unique_split.yaml new file mode 100644 index 00000000..3ef7d17d --- /dev/null +++ b/tests/test_data/titanic/titanic_unique_split.yaml @@ -0,0 +1,91 @@ + +global_params: + seed: 42 + +columns: + - column_name: passenger_id + column_type: meta + data_type: int + encoder: + - name: NumericEncoder + params: {} + - column_name: survived + column_type: label + data_type: int + encoder: + - name: NumericEncoder + params: {} + - column_name: pclass + column_type: input + data_type: int + encoder: + - name: NumericEncoder + params: {} + - column_name: sex + column_type: input + data_type: str + encoder: + - name: StrClassificationEncoder + params: {} + - column_name: age + column_type: input + data_type: float + encoder: + - name: NumericEncoder + params: {} + - column_name: sibsp + column_type: input + data_type: int + encoder: + - name: NumericEncoder + params: {} + - column_name: parch + column_type: input + data_type: int + encoder: + - name: NumericEncoder + params: {} + - column_name: fare + column_type: input + data_type: float + encoder: + - name: NumericEncoder + params: {} + - column_name: embarked + column_type: input + data_type: str + encoder: + - name: StrClassificationEncoder + params: {} + +transforms: + - transformation_name: noise + columns: + - column_name: age + transformations: + - name: GaussianNoise + params: + std: 0.1 + - column_name: fare + transformations: + - name: GaussianNoise + params: + std: 0.1 + - transformation_name: noise2 + columns: + - column_name: age + transformations: + - name: GaussianNoise + params: + std: 0.1 + - column_name: fare + transformations: + - name: GaussianNoise + params: + std: 0.1 + +split: + split_method: RandomSplit + params: + split: [0.7, 0.15, 0.15] + split_input_columns: [age] From 4aca687bc7e448f559317d801e1ae589bdb74846 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Mon, 17 Feb 2025 11:31:11 +0100 Subject: [PATCH 16/35] DELETE: tests/cli/__snapshots__/test_split_yaml.ambr -> Removed this file because now the the config are first separated per unique split and then per unique transforms --- tests/cli/__snapshots__/test_split_yaml.ambr | 6 ------ 1 file changed, 6 deletions(-) delete mode 100644 tests/cli/__snapshots__/test_split_yaml.ambr diff --git a/tests/cli/__snapshots__/test_split_yaml.ambr b/tests/cli/__snapshots__/test_split_yaml.ambr deleted file mode 100644 index 99779398..00000000 --- a/tests/cli/__snapshots__/test_split_yaml.ambr +++ /dev/null @@ -1,6 +0,0 @@ -# serializer version: 1 -# name: test_split_yaml[correct_yaml_path-None] - list([ - '42139ca7745259e09d1e56e24570d2c7', - ]) -# --- From 887d9f1b7cb3d16116403304ed367bcc0c16c3a5 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Mon, 17 Feb 2025 11:33:49 +0100 Subject: [PATCH 17/35] {src/stimulus,tests}/cli/check_model.py -> Input is now a yaml with one split and one transform and changed the test file is just formatting --- src/stimulus/cli/check_model.py | 25 +++++++++++++++++++++---- tests/cli/test_check_model.py | 18 +++++++++++++++--- 2 files changed, 36 insertions(+), 7 deletions(-) diff --git a/src/stimulus/cli/check_model.py b/src/stimulus/cli/check_model.py index 
a127f186..1fe13fa3 100755 --- a/src/stimulus/cli/check_model.py +++ b/src/stimulus/cli/check_model.py @@ -21,8 +21,22 @@ def get_args() -> argparse.Namespace: Parsed command line arguments. """ parser = argparse.ArgumentParser(description="Launch check_model.") - parser.add_argument("-d", "--data", type=str, required=True, metavar="FILE", help="Path to input csv file.") - parser.add_argument("-m", "--model", type=str, required=True, metavar="FILE", help="Path to model file.") + parser.add_argument( + "-d", + "--data", + type=str, + required=True, + metavar="FILE", + help="Path to input csv file.", + ) + parser.add_argument( + "-m", + "--model", + type=str, + required=True, + metavar="FILE", + help="Path to model file.", + ) parser.add_argument( "-e", "--data_config", @@ -106,14 +120,17 @@ def main( """ with open(data_config_path) as file: data_config = yaml.safe_load(file) - data_config = yaml_data.YamlSubConfigDict(**data_config) + # FIXME: LEQUEL DES DEUX ? + data_config = yaml_data.YamlSplitTransformDict(**data_config) with open(model_config_path) as file: model_config = yaml.safe_load(file) model_config = yaml_model_schema.Model(**model_config) encoder_loader = loaders.EncoderLoader() - encoder_loader.initialize_column_encoders_from_config(column_config=data_config.columns) + encoder_loader.initialize_column_encoders_from_config( + column_config=data_config.columns + ) logger.info("Dataset loaded successfully.") diff --git a/tests/cli/test_check_model.py b/tests/cli/test_check_model.py index 7115e809..e00a1a5d 100644 --- a/tests/cli/test_check_model.py +++ b/tests/cli/test_check_model.py @@ -13,13 +13,23 @@ @pytest.fixture def data_path() -> str: """Get path to test data CSV file.""" - return str(Path(__file__).parent.parent / "test_data" / "titanic" / "titanic_stimulus_split.csv") + return str( + Path(__file__).parent.parent + / "test_data" + / "titanic" + / "titanic_stimulus_split.csv" + ) @pytest.fixture def data_config() -> str: """Get path to test data config YAML.""" - return str(Path(__file__).parent.parent / "test_data" / "titanic" / "titanic_sub_config.yaml") + return str( + Path(__file__).parent.parent + / "test_data" + / "titanic" + / "titanic_sub_config.yaml" + ) @pytest.fixture @@ -34,7 +44,9 @@ def model_config() -> str: return str(Path(__file__).parent.parent / "test_model" / "titanic_model_cpu.yaml") -def test_check_model_main(data_path: str, data_config: str, model_path: str, model_config: str) -> None: +def test_check_model_main( + data_path: str, data_config: str, model_path: str, model_config: str +) -> None: """Test that check_model.main runs without errors. 
Args: From ff0881cca3f82be96420088b35cc528a4c36631a Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Mon, 17 Feb 2025 11:36:04 +0100 Subject: [PATCH 18/35] FIX: src/stimulus/cli/split_csv.py -> Updated to take a YamlSplitConfigDict --- src/stimulus/cli/split_csv.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/stimulus/cli/split_csv.py b/src/stimulus/cli/split_csv.py index 04de0e56..44afc03d 100755 --- a/src/stimulus/cli/split_csv.py +++ b/src/stimulus/cli/split_csv.py @@ -7,7 +7,7 @@ from stimulus.data.data_handlers import DatasetProcessor, SplitManager from stimulus.data.loaders import SplitLoader -from stimulus.utils.yaml_data import YamlSubConfigDict +from stimulus.utils.yaml_data import YamlSplitConfigDict def get_args() -> argparse.Namespace: @@ -49,7 +49,9 @@ def get_args() -> argparse.Namespace: return parser.parse_args() -def main(data_csv: str, config_yaml: str, out_path: str, *, force: bool = False) -> None: +def main( + data_csv: str, config_yaml: str, out_path: str, *, force: bool = False +) -> None: """Connect CSV and YAML configuration and handle sanity checks. Args: @@ -64,7 +66,7 @@ def main(data_csv: str, config_yaml: str, out_path: str, *, force: bool = False) # create a split manager from the config split_config = processor.dataset_manager.config.split with open(config_yaml) as f: - yaml_config = YamlSubConfigDict(**yaml.safe_load(f)) + yaml_config = YamlSplitConfigDict(**yaml.safe_load(f)) split_loader = SplitLoader(seed=yaml_config.global_params.seed) split_loader.initialize_splitter_from_config(split_config) split_manager = SplitManager(split_loader) From 96b3b5bd97e2a6cd98dca7dc085ebd1978c05864 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Mon, 17 Feb 2025 14:22:08 +0100 Subject: [PATCH 19/35] FIX: tests/cli/test_shuffle_csv.py -> changed the test file to be the one with only a unique split and multiple transforms --- tests/cli/test_shuffle_csv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cli/test_shuffle_csv.py b/tests/cli/test_shuffle_csv.py index 8fef803c..dff78887 100644 --- a/tests/cli/test_shuffle_csv.py +++ b/tests/cli/test_shuffle_csv.py @@ -14,7 +14,7 @@ @pytest.fixture def correct_yaml_path() -> str: """Fixture that returns the path to a correct YAML file.""" - return "tests/test_data/titanic/titanic_sub_config.yaml" + return "tests/test_data/titanic/titanic_unique_split.yaml" @pytest.fixture From f4edfabbf57fa8c7666d5f26da13908c75647ec1 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Tue, 18 Feb 2025 10:57:48 +0100 Subject: [PATCH 20/35] FIX: {src/stimulus, tests}/data/data_handlers.py -> Changed the input to not be a file but a YamlSplitTransformDict object --- src/stimulus/data/data_handlers.py | 109 +++++++++++++++++++++-------- tests/data/test_data_handlers.py | 90 ++++++++++++++++++------ 2 files changed, 147 insertions(+), 52 deletions(-) diff --git a/src/stimulus/data/data_handlers.py b/src/stimulus/data/data_handlers.py index e7d6fc9f..7e1edd8f 100644 --- a/src/stimulus/data/data_handlers.py +++ b/src/stimulus/data/data_handlers.py @@ -50,10 +50,11 @@ class DatasetManager: def __init__( self, - config_path: str, + config_dict: yaml_data.YamlSplitConfigDict, ) -> None: """Initialize the DatasetManager.""" - self.config = self._load_config(config_path) + # self.config = self._load_config(config_path) + self.config: yaml_data.YamlSplitTransformDict = config_dict self.column_categories = self.categorize_columns_by_type() def categorize_columns_by_type(self) -> dict: @@ -93,7 +94,8 
@@ def categorize_columns_by_type(self) -> dict: return {"input": input_columns, "label": label_columns, "meta": meta_columns} - def _load_config(self, config_path: str) -> yaml_data.YamlConfigDict: + # TODO: Remove or change this function as the config is now preloaded + def _load_config(self, config_path: str) -> yaml_data.YamlSplitConfigDict: """Loads and parses a YAML configuration file. Args: @@ -108,8 +110,11 @@ def _load_config(self, config_path: str) -> yaml_data.YamlConfigDict: >>> print(config["columns"][0]["column_name"]) 'hello' """ + with open(config_path) as file: - return yaml_data.YamlSubConfigDict(**yaml.safe_load(file)) + # FIXME: this function is still called by test_shuffle_csv and test_tune + return yaml_data.YamlSplitConfigDict(**yaml.safe_load(file)) + return yaml_data.YamlSplitTransformDict(**yaml.safe_load(file)) def get_split_columns(self) -> list[str]: """Get the columns that are used for splitting.""" @@ -185,7 +190,8 @@ def encode_column(self, column_name: str, column_data: list) -> torch.Tensor: >>> print(encoded.shape) torch.Size([2, 4, 4]) # 2 sequences, length 4, one-hot encoded """ - encode_all_function = self.encoder_loader.get_function_encode_all(column_name) + encode_all_function = self.encoder_loader.get_function_encode_all( + column_name) return encode_all_function(column_data) def encode_columns(self, column_data: dict) -> dict: @@ -207,11 +213,16 @@ def encode_columns(self, column_data: dict) -> dict: >>> print(encoded["dna_seq"].shape) torch.Size([2, 4, 4]) # 2 sequences, length 4, one-hot encoded """ - return {col: self.encode_column(col, values) for col, values in column_data.items()} + return { + col: self.encode_column(col, values) for col, values in column_data.items() + } def encode_dataframe(self, dataframe: pl.DataFrame) -> dict[str, torch.Tensor]: """Encode the dataframe using the encoders.""" - return {col: self.encode_column(col, dataframe[col].to_list()) for col in dataframe.columns} + return { + col: self.encode_column(col, dataframe[col].to_list()) + for col in dataframe.columns + } class TransformManager: @@ -224,7 +235,9 @@ def __init__( """Initialize the TransformManager.""" self.transform_loader = transform_loader - def transform_column(self, column_name: str, transform_name: str, column_data: list) -> tuple[list, bool]: + def transform_column( + self, column_name: str, transform_name: str, column_data: list + ) -> tuple[list, bool]: """Transform a column of data using the specified transformation. Args: @@ -236,7 +249,9 @@ def transform_column(self, column_name: str, transform_name: str, column_data: l list: The transformed data. bool: Whether the transformation added new rows to the data.
""" - transformer = self.transform_loader.__getattribute__(column_name)[transform_name] + transformer = self.transform_loader.__getattribute__(column_name)[ + transform_name + ] return transformer.transform_all(column_data), transformer.add_row @@ -250,7 +265,9 @@ def __init__( """Initialize the SplitManager.""" self.split_loader = split_loader - def get_split_indices(self, data: dict) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + def get_split_indices( + self, data: dict + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """Get the indices for train, validation, and test splits.""" return self.split_loader.get_function_split()(data) @@ -353,7 +370,8 @@ def add_split(self, split_manager: SplitManager, *, force: bool = False) -> None split_input_data = self.select_columns(split_columns) # get the split indices - train, validation, test = split_manager.get_split_indices(split_input_data) + train, validation, test = split_manager.get_split_indices( + split_input_data) # add the split column to the data split_column = np.full(len(self.data), -1).astype(int) @@ -367,17 +385,25 @@ def add_split(self, split_manager: SplitManager, *, force: bool = False) -> None def apply_transformation_group(self, transform_manager: TransformManager) -> None: """Apply the transformation group to the data.""" - for column_name, transform_name, _params in self.dataset_manager.get_transform_logic()["transformations"]: + for ( + column_name, + transform_name, + _params, + ) in self.dataset_manager.get_transform_logic()["transformations"]: transformed_data, add_row = transform_manager.transform_column( column_name, transform_name, self.data[column_name], ) if add_row: - new_rows = self.data.with_columns(pl.Series(column_name, transformed_data)) + new_rows = self.data.with_columns( + pl.Series(column_name, transformed_data) + ) self.data = pl.vstack(self.data, new_rows) else: - self.data = self.data.with_columns(pl.Series(column_name, transformed_data)) + self.data = self.data.with_columns( + pl.Series(column_name, transformed_data) + ) def shuffle_labels(self, seed: Optional[float] = None) -> None: """Shuffles the labels in the data.""" @@ -386,7 +412,9 @@ def shuffle_labels(self, seed: Optional[float] = None) -> None: label_keys = self.dataset_manager.column_categories["label"] for key in label_keys: - self.data = self.data.with_columns(pl.Series(key, np.random.permutation(list(self.data[key])))) + self.data = self.data.with_columns( + pl.Series(key, np.random.permutation(list(self.data[key]))) + ) class DatasetLoader(DatasetHandler): @@ -402,7 +430,11 @@ def __init__( """Initialize the DatasetLoader.""" super().__init__(config_path, csv_path) self.encoder_manager = EncodeManager(encoder_loader) - self.data = self.load_csv_per_split(csv_path, split) if split is not None else self.load_csv(csv_path) + self.data = ( + self.load_csv_per_split(csv_path, split) + if split is not None + else self.load_csv(csv_path) + ) def get_all_items(self) -> tuple[dict, dict, dict]: """Get the full dataset as three separate dictionaries for inputs, labels and metadata. 
@@ -428,8 +460,10 @@ def get_all_items(self) -> tuple[dict, dict, dict]: self.dataset_manager.column_categories["label"], self.dataset_manager.column_categories["meta"], ) - input_data = self.encoder_manager.encode_dataframe(self.data[input_columns]) - label_data = self.encoder_manager.encode_dataframe(self.data[label_columns]) + input_data = self.encoder_manager.encode_dataframe( + self.data[input_columns]) + label_data = self.encoder_manager.encode_dataframe( + self.data[label_columns]) meta_data = {key: self.data[key].to_list() for key in meta_columns} return input_data, label_data, meta_data @@ -447,16 +481,21 @@ def load_csv_per_split(self, csv_path: str, split: int) -> pl.DataFrame: we are gonna load only the relevant data for it. """ if "split" not in self.columns: - raise ValueError("The category split is not present in the csv file") + raise ValueError( + "The category split is not present in the csv file") if split not in [0, 1, 2]: - raise ValueError(f"The split value should be 0, 1 or 2. The specified split value is {split}") + raise ValueError( + f"The split value should be 0, 1 or 2. The specified split value is {split}" + ) return pl.scan_csv(csv_path).filter(pl.col("split") == split).collect() def __len__(self) -> int: """Return the length of the first list in input, assumes that all are the same length.""" return len(self.data) - def __getitem__(self, idx: Any) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor], dict[str, list]]: + def __getitem__( + self, idx: Any + ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor], dict[str, list]]: """Get the data at a given index, and encodes the input and label, leaving meta as it is. Args: @@ -474,17 +513,24 @@ def __getitem__(self, idx: Any) -> tuple[dict[str, torch.Tensor], dict[str, torc data_at_index = self.data.slice(start, stop - start) # Process DataFrame - input_data = self.encoder_manager.encode_dataframe(data_at_index[input_columns]) - label_data = self.encoder_manager.encode_dataframe(data_at_index[label_columns]) - meta_data = {key: data_at_index[key].to_list() for key in meta_columns} + input_data = self.encoder_manager.encode_dataframe( + data_at_index[input_columns] + ) + label_data = self.encoder_manager.encode_dataframe( + data_at_index[label_columns] + ) + meta_data = {key: data_at_index[key].to_list() + for key in meta_columns} elif isinstance(idx, int): # For single row, convert to dict with column names as keys row_dict = dict(zip(self.data.columns, self.data.row(idx))) # Create single-row DataFrames for encoding - input_df = pl.DataFrame({col: [row_dict[col]] for col in input_columns}) - label_df = pl.DataFrame({col: [row_dict[col]] for col in label_columns}) + input_df = pl.DataFrame( + {col: [row_dict[col]] for col in input_columns}) + label_df = pl.DataFrame( + {col: [row_dict[col]] for col in label_columns}) input_data = self.encoder_manager.encode_dataframe(input_df) label_data = self.encoder_manager.encode_dataframe(label_df) @@ -494,8 +540,13 @@ def __getitem__(self, idx: Any) -> tuple[dict[str, torch.Tensor], dict[str, torc data_at_index = self.data.select(idx) # Process DataFrame - input_data = self.encoder_manager.encode_dataframe(data_at_index[input_columns]) - label_data = self.encoder_manager.encode_dataframe(data_at_index[label_columns]) - meta_data = {key: data_at_index[key].to_list() for key in meta_columns} + input_data = self.encoder_manager.encode_dataframe( + data_at_index[input_columns] + ) + label_data = self.encoder_manager.encode_dataframe( + data_at_index[label_columns] 
+ ) + meta_data = {key: data_at_index[key].to_list() + for key in meta_columns} return input_data, label_data, meta_data diff --git a/tests/data/test_data_handlers.py b/tests/data/test_data_handlers.py index 725f275b..144f4416 100644 --- a/tests/data/test_data_handlers.py +++ b/tests/data/test_data_handlers.py @@ -14,15 +14,18 @@ ) from stimulus.utils.yaml_data import ( YamlConfigDict, + YamlSplitConfigDict, + YamlSplitTransformDict, YamlTransform, YamlTransformColumns, YamlTransformColumnsTransformation, - generate_data_configs, + generate_split_configs, + generate_split_transform_configs, ) # Fixtures -## Data fixtures +# Data fixtures @pytest.fixture def titanic_csv_path() -> str: """Get path to test Titanic CSV file. @@ -67,20 +70,29 @@ def generate_sub_configs(base_config: YamlConfigDict) -> list[YamlConfigDict]: Returns: list[YamlConfigDict]: List of generated configurations """ - return generate_data_configs(base_config) + split_configs: list[YamlSplitConfigDict] = generate_split_configs( + base_config) + split_transform_list: list[YamlSplitTransformDict] = [] + for split in split_configs: + split_transform_list.extend(generate_split_transform_configs(split)) + return split_transform_list @pytest.fixture -def dump_single_split_config_to_disk() -> str: +def dump_single_split_config_to_disk() -> YamlSplitTransformDict: """Get path for dumping single split config. Returns: str: Path to dump config file """ - return "tests/test_data/titanic/titanic_sub_config.yaml" + config_dict: YamlSplitTransformDict + path: str = "tests/test_data/titanic/titanic_sub_config.yaml" + with open(path) as f: + config_dict = YamlSplitTransformDict(**yaml.safe_load(f)) + return config_dict -## Loader fixtures +# Loader fixtures @pytest.fixture def encoder_loader(generate_sub_configs: list[YamlConfigDict]) -> loaders.EncoderLoader: """Create encoder loader with initialized encoders. @@ -92,12 +104,15 @@ def encoder_loader(generate_sub_configs: list[YamlConfigDict]) -> loaders.Encode experiments.EncoderLoader: Initialized encoder loader """ loader = loaders.EncoderLoader() - loader.initialize_column_encoders_from_config(generate_sub_configs[0].columns) + loader.initialize_column_encoders_from_config( + generate_sub_configs[0].columns) return loader @pytest.fixture -def transform_loader(generate_sub_configs: list[YamlConfigDict]) -> loaders.TransformLoader: +def transform_loader( + generate_sub_configs: list[YamlConfigDict], +) -> loaders.TransformLoader: """Create transform loader with initialized transformers. 
Args: @@ -107,7 +122,9 @@ def transform_loader(generate_sub_configs: list[YamlConfigDict]) -> loaders.Tran experiments.TransformLoader: Initialized transform loader """ loader = loaders.TransformLoader() - loader.initialize_column_data_transformers_from_config(generate_sub_configs[0].transforms) + loader.initialize_column_data_transformers_from_config( + generate_sub_configs[0].transforms + ) return loader @@ -127,14 +144,18 @@ def split_loader(generate_sub_configs: list[YamlConfigDict]) -> loaders.SplitLoa # Test DatasetManager -def test_dataset_manager_init(dump_single_split_config_to_disk: str) -> None: +def test_dataset_manager_init( + dump_single_split_config_to_disk: YamlSplitTransformDict, +) -> None: """Test initialization of DatasetManager.""" manager = DatasetManager(dump_single_split_config_to_disk) assert hasattr(manager, "config") assert hasattr(manager, "column_categories") -def test_dataset_manager_organize_columns(dump_single_split_config_to_disk: str) -> None: +def test_dataset_manager_organize_columns( + dump_single_split_config_to_disk: str, +) -> None: """Test column organization by type.""" manager = DatasetManager(dump_single_split_config_to_disk) categories = manager.categorize_columns_by_type() @@ -146,7 +167,9 @@ def test_dataset_manager_organize_columns(dump_single_split_config_to_disk: str) assert "passenger_id" in categories["meta"] -def test_dataset_manager_organize_transforms(dump_single_split_config_to_disk: str) -> None: +def test_dataset_manager_organize_transforms( + dump_single_split_config_to_disk: str, +) -> None: """Test transform organization.""" manager = DatasetManager(dump_single_split_config_to_disk) categories = manager.categorize_columns_by_type() @@ -155,7 +178,9 @@ def test_dataset_manager_organize_transforms(dump_single_split_config_to_disk: s assert all(key in categories for key in ["input", "label", "meta"]) -def test_dataset_manager_get_transform_logic(dump_single_split_config_to_disk: str) -> None: +def test_dataset_manager_get_transform_logic( + dump_single_split_config_to_disk: str, +) -> None: """Test getting transform logic from config.""" manager = DatasetManager(dump_single_split_config_to_disk) transform_logic = manager.get_transform_logic() @@ -221,10 +246,12 @@ def test_transform_manager_transform_column() -> None: ), ], ) - transform_loader.initialize_column_data_transformers_from_config(dummy_config) + transform_loader.initialize_column_data_transformers_from_config( + dummy_config) manager = TransformManager(transform_loader) data = [1, 2, 3] - transformed, added_row = manager.transform_column("test_col", "GaussianNoise", data) + transformed, added_row = manager.transform_column( + "test_col", "GaussianNoise", data) assert len(transformed) == len(data) assert added_row is False @@ -303,15 +330,32 @@ def test_dataset_processor_apply_transformation_group( ) processor_control.data = processor_control.load_csv(titanic_csv_path) - processor.apply_transformation_group(transform_manager=TransformManager(transform_loader)) + processor.apply_transformation_group( + transform_manager=TransformManager(transform_loader) + ) - assert processor.data["age"].to_list() != processor_control.data["age"].to_list() - assert processor.data["fare"].to_list() != processor_control.data["fare"].to_list() - assert processor.data["parch"].to_list() == processor_control.data["parch"].to_list() - assert processor.data["sibsp"].to_list() == processor_control.data["sibsp"].to_list() - assert processor.data["pclass"].to_list() == 
processor_control.data["pclass"].to_list() - assert processor.data["embarked"].to_list() == processor_control.data["embarked"].to_list() - assert processor.data["sex"].to_list() == processor_control.data["sex"].to_list() + assert processor.data["age"].to_list( + ) != processor_control.data["age"].to_list() + assert processor.data["fare"].to_list( + ) != processor_control.data["fare"].to_list() + assert ( + processor.data["parch"].to_list( + ) == processor_control.data["parch"].to_list() + ) + assert ( + processor.data["sibsp"].to_list( + ) == processor_control.data["sibsp"].to_list() + ) + assert ( + processor.data["pclass"].to_list( + ) == processor_control.data["pclass"].to_list() + ) + assert ( + processor.data["embarked"].to_list() + == processor_control.data["embarked"].to_list() + ) + assert processor.data["sex"].to_list( + ) == processor_control.data["sex"].to_list() # Test DatasetLoader From 1dd858c35c675d4995979728c080715020997a8a Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Tue, 18 Feb 2025 13:52:34 +0100 Subject: [PATCH 21/35] FIX: check_model.py -> takes a yaml and then passes on only YamlSplitTransformDict and not the yaml path anymore --- src/stimulus/cli/check_model.py | 8 +-- src/stimulus/data/data_handlers.py | 8 +-- src/stimulus/data/handlertorch.py | 5 +- src/stimulus/learner/raytune_learner.py | 86 ++++++++++++++++++------- tests/cli/test_check_model.py | 8 ++- 5 files changed, 80 insertions(+), 35 deletions(-) diff --git a/src/stimulus/cli/check_model.py b/src/stimulus/cli/check_model.py index 1fe13fa3..8a39c2a2 100755 --- a/src/stimulus/cli/check_model.py +++ b/src/stimulus/cli/check_model.py @@ -120,7 +120,6 @@ def main( """ with open(data_config_path) as file: data_config = yaml.safe_load(file) - # FIXME: LEQUEL DES DEUX ? data_config = yaml_data.YamlSplitTransformDict(**data_config) with open(model_config_path) as file: @@ -138,7 +137,8 @@ def main( logger.info("Model class loaded successfully.") - ray_config_loader = yaml_model_schema.YamlRayConfigLoader(model=model_config) + ray_config_loader = yaml_model_schema.YamlRayConfigLoader( + model=model_config) ray_config_dict = ray_config_loader.get_config().model_dump() ray_config_model = ray_config_loader.get_config() @@ -156,7 +156,7 @@ def main( logger.info("Model instance loaded successfully.") torch_dataset = handlertorch.TorchDataset( - config_path=data_config_path, + data_config=data_config, csv_path=data_path, encoder_loader=encoder_loader, ) @@ -187,7 +187,7 @@ def main( tuner = raytune_learner.TuneWrapper( model_config=ray_config_model, - data_config_path=data_config_path, + data_config=data_config, model_class=model_class, data_path=data_path, encoder_loader=encoder_loader, diff --git a/src/stimulus/data/data_handlers.py b/src/stimulus/data/data_handlers.py index 7e1edd8f..559e7d53 100644 --- a/src/stimulus/data/data_handlers.py +++ b/src/stimulus/data/data_handlers.py @@ -287,7 +287,7 @@ class DatasetHandler: def __init__( self, - config_path: str, + data_config: yaml_data.YamlSplitTransformDict, csv_path: str, ) -> None: """Initialize the DatasetHandler with required config. @@ -296,7 +296,7 @@ def __init__( config_path (str): Path to the dataset configuration file. csv_path (str): Path to the CSV data file. 
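
The torch wrapper follows the same pattern; a rough usage sketch, reusing the split CSV fixture referenced elsewhere in these patches:

import yaml

from stimulus.data import handlertorch, loaders
from stimulus.utils.yaml_data import YamlSplitTransformDict

with open("tests/test_data/titanic/titanic_sub_config.yaml") as f:
    data_config = YamlSplitTransformDict(**yaml.safe_load(f))

encoder_loader = loaders.EncoderLoader()
encoder_loader.initialize_column_encoders_from_config(data_config.columns)

# TorchDataset now takes the parsed config; split=0 keeps only the training rows.
train_dataset = handlertorch.TorchDataset(
    data_config=data_config,
    csv_path="tests/test_data/titanic/titanic_stimulus_split.csv",
    encoder_loader=encoder_loader,
    split=0,
)
x, y, meta = train_dataset[0]
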
""" - self.dataset_manager = DatasetManager(config_path) + self.dataset_manager = DatasetManager(data_config) self.columns = self.read_csv_header(csv_path) self.data = self.load_csv(csv_path) @@ -422,13 +422,13 @@ class DatasetLoader(DatasetHandler): def __init__( self, - config_path: str, + data_config: yaml_data.YamlSplitTransformDict, csv_path: str, encoder_loader: loaders.EncoderLoader, split: Union[int, None] = None, ) -> None: """Initialize the DatasetLoader.""" - super().__init__(config_path, csv_path) + super().__init__(data_config, csv_path) self.encoder_manager = EncodeManager(encoder_loader) self.data = ( self.load_csv_per_split(csv_path, split) diff --git a/src/stimulus/data/handlertorch.py b/src/stimulus/data/handlertorch.py index e47f38c9..6d8e0641 100644 --- a/src/stimulus/data/handlertorch.py +++ b/src/stimulus/data/handlertorch.py @@ -5,6 +5,7 @@ from torch.utils.data import Dataset from stimulus.data import data_handlers, loaders +from stimulus.utils.yaml_data import YamlSplitTransformDict class TorchDataset(Dataset): @@ -12,7 +13,7 @@ class TorchDataset(Dataset): def __init__( self, - config_path: str, + data_config: YamlSplitTransformDict, csv_path: str, encoder_loader: loaders.EncoderLoader, split: Optional[int] = None, @@ -26,7 +27,7 @@ def __init__( split: Optional tuple containing split information """ self.loader = data_handlers.DatasetLoader( - config_path=config_path, + data_config=data_config, csv_path=csv_path, encoder_loader=encoder_loader, split=split, diff --git a/src/stimulus/learner/raytune_learner.py b/src/stimulus/learner/raytune_learner.py index 4d86ca47..95749b57 100644 --- a/src/stimulus/learner/raytune_learner.py +++ b/src/stimulus/learner/raytune_learner.py @@ -21,6 +21,7 @@ from stimulus.learner.predict import PredictWrapper from stimulus.utils.generic_utils import set_general_seeds from stimulus.utils.yaml_model_schema import RayTuneModel +from stimulus.utils.yaml_data import YamlSplitTransformDict class CheckpointDict(TypedDict): @@ -35,7 +36,7 @@ class TuneWrapper: def __init__( self, model_config: RayTuneModel, - data_config_path: str, + data_config: YamlSplitTransformDict, model_class: nn.Module, data_path: str, encoder_loader: EncoderLoader, @@ -75,7 +76,10 @@ def __init__( self.run_config = train.RunConfig( name=tune_run_name if tune_run_name is not None - else "TuneModel_" + datetime.datetime.now(tz=datetime.timezone.utc).strftime("%Y-%m-%d_%H-%M-%S"), + else "TuneModel_" + + datetime.datetime.now(tz=datetime.timezone.utc).strftime( + "%Y-%m-%d_%H-%M-%S" + ), storage_path=ray_results_dir, checkpoint_config=train.CheckpointConfig(checkpoint_at_end=True), stop=model_config.tune.run_params.stop, @@ -83,7 +87,8 @@ def __init__( # add the data path to the config if not os.path.exists(data_path): - raise ValueError("Data path does not exist. Given path:" + data_path) + raise ValueError( + "Data path does not exist. 
Given path:" + data_path) self.config["data_path"] = os.path.abspath(data_path) # Set up tune_run path @@ -93,7 +98,10 @@ def __init__( ray_results_dir, tune_run_name if tune_run_name is not None - else "TuneModel_" + datetime.datetime.now(tz=datetime.timezone.utc).strftime("%Y-%m-%d_%H-%M-%S"), + else "TuneModel_" + + datetime.datetime.now(tz=datetime.timezone.utc).strftime( + "%Y-%m-%d_%H-%M-%S" + ), ) self.config["_debug"] = debug self.config["model"] = model_class @@ -104,7 +112,7 @@ def __init__( self.cpu_per_trial = model_config.tune.cpu_per_trial self.tuner = self.tuner_initialization( - data_config_path=data_config_path, + data_config=data_config, data_path=data_path, encoder_loader=encoder_loader, autoscaler=autoscaler, @@ -112,7 +120,7 @@ def __init__( def tuner_initialization( self, - data_config_path: str, + data_config: YamlSplitTransformDict, data_path: str, encoder_loader: EncoderLoader, *, @@ -130,25 +138,29 @@ def tuner_initialization( "GPU per trial is more than what is available in the cluster, set autoscaler to True to allow for autoscaler to be used.", ) except KeyError as err: - logging.warning(f"KeyError: {err}, no GPU resources available in the cluster: {cluster_res}") + logging.warning( + f"KeyError: {err}, no GPU resources available in the cluster: {cluster_res}" + ) if self.cpu_per_trial > cluster_res["CPU"] and not autoscaler: raise ValueError( "CPU per trial is more than what is available in the cluster, set autoscaler to True to allow for autoscaler to be used.", ) - logging.info(f"PER_TRIAL resources -> GPU: {self.gpu_per_trial} CPU: {self.cpu_per_trial}") + logging.info( + f"PER_TRIAL resources -> GPU: {self.gpu_per_trial} CPU: {self.cpu_per_trial}" + ) # Pre-load and encode datasets once, then put them in Ray's object store training = TorchDataset( - config_path=data_config_path, + data_config=data_config, csv_path=data_path, encoder_loader=encoder_loader, split=0, ) validation = TorchDataset( - config_path=data_config_path, + data_config=data_config, csv_path=data_path, encoder_loader=encoder_loader, split=1, @@ -182,7 +194,12 @@ def tuner_initialization( resources={"cpu": self.cpu_per_trial, "gpu": self.gpu_per_trial}, ) - return tune.Tuner(trainable, tune_config=self.tune_config, param_space=self.config, run_config=self.run_config) + return tune.Tuner( + trainable, + tune_config=self.tune_config, + param_space=self.config, + run_config=self.run_config, + ) def tune(self) -> ray.tune.ResultGrid: """Run the tuning process.""" @@ -221,7 +238,10 @@ def setup(self, config: dict[Any, Any]) -> None: self.step_size = config["tune"]["step_size"] # Get datasets from Ray's object store - training, validation = ray.get(self.config["_training_ref"]), ray.get(self.config["_validation_ref"]) + training, validation = ( + ray.get(self.config["_training_ref"]), + ray.get(self.config["_validation_ref"]), + ) # use dataloader on training/validation data self.batch_size = config["data_params"]["batch_size"] @@ -269,7 +289,8 @@ def step(self) -> dict: for _step_size in range(self.step_size): for x, y, _meta in self.training: # the loss dict could be unpacked with ** and the function declaration handle it differently like **kwargs. to be decided, personally find this more clean and understable. 
- self.model.batch(x=x, y=y, optimizer=self.optimizer, **self.loss_dict) + self.model.batch( + x=x, y=y, optimizer=self.optimizer, **self.loss_dict) return self.objective() def objective(self) -> dict[str, float]: @@ -284,29 +305,48 @@ def objective(self) -> dict[str, float]: "recall", "spearmanr", ] # TODO maybe we report only a subset of metrics, given certain criteria (eg. if classification or regression) - predict_val = PredictWrapper(self.model, self.validation, loss_dict=self.loss_dict) - predict_train = PredictWrapper(self.model, self.training, loss_dict=self.loss_dict) + predict_val = PredictWrapper( + self.model, self.validation, loss_dict=self.loss_dict + ) + predict_train = PredictWrapper( + self.model, self.training, loss_dict=self.loss_dict + ) return { - **{"val_" + metric: value for metric, value in predict_val.compute_metrics(metrics).items()}, - **{"train_" + metric: value for metric, value in predict_train.compute_metrics(metrics).items()}, + **{ + "val_" + metric: value + for metric, value in predict_val.compute_metrics(metrics).items() + }, + **{ + "train_" + metric: value + for metric, value in predict_train.compute_metrics(metrics).items() + }, } - def export_model(self, export_dir: str | None = None) -> None: # type: ignore[override] + # type: ignore[override] + def export_model(self, export_dir: str | None = None) -> None: """Export model to safetensors format.""" if export_dir is None: return - safe_save_model(self.model, os.path.join(export_dir, "model.safetensors")) + safe_save_model(self.model, os.path.join( + export_dir, "model.safetensors")) def load_checkpoint(self, checkpoint: dict[Any, Any] | None) -> None: """Load model and optimizer state from checkpoint.""" if checkpoint is None: return checkpoint_dir = checkpoint["checkpoint_dir"] - self.model = safe_load_model(self.model, os.path.join(checkpoint_dir, "model.safetensors")) - self.optimizer.load_state_dict(torch.load(os.path.join(checkpoint_dir, "optimizer.pt"))) + self.model = safe_load_model( + self.model, os.path.join(checkpoint_dir, "model.safetensors") + ) + self.optimizer.load_state_dict( + torch.load(os.path.join(checkpoint_dir, "optimizer.pt")) + ) def save_checkpoint(self, checkpoint_dir: str) -> dict[Any, Any]: """Save model and optimizer state to checkpoint.""" - safe_save_model(self.model, os.path.join(checkpoint_dir, "model.safetensors")) - torch.save(self.optimizer.state_dict(), os.path.join(checkpoint_dir, "optimizer.pt")) + safe_save_model(self.model, os.path.join( + checkpoint_dir, "model.safetensors")) + torch.save( + self.optimizer.state_dict(), os.path.join(checkpoint_dir, "optimizer.pt") + ) return {"checkpoint_dir": checkpoint_dir} diff --git a/tests/cli/test_check_model.py b/tests/cli/test_check_model.py index e00a1a5d..273ab63b 100644 --- a/tests/cli/test_check_model.py +++ b/tests/cli/test_check_model.py @@ -8,6 +8,7 @@ import ray from stimulus.cli import check_model +from src.stimulus.utils.yaml_data import YamlSplitTransformDict @pytest.fixture @@ -62,11 +63,14 @@ def test_check_model_main( ray.init(ignore_reinit_error=True) # Verify all required files exist assert os.path.exists(data_path), f"Data file not found at {data_path}" - assert os.path.exists(data_config), f"Data config not found at {data_config}" + assert os.path.exists( + data_config), f"Data config not found at {data_config}" assert os.path.exists(model_path), f"Model file not found at {model_path}" - assert os.path.exists(model_config), f"Model config not found at {model_config}" + assert os.path.exists( + 
model_config), f"Model config not found at {model_config}" try: + config_dict: yam # Run main function - should complete without errors check_model.main( model_path=model_path, From f4ad7ab957164b4fc260f7c6120e3968901ee32f Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Tue, 18 Feb 2025 13:56:51 +0100 Subject: [PATCH 22/35] FIX: tests/data/test_data_handlers.py -> changed the call from 'config_path' to 'data_config' for the DatasetLoader --- tests/data/test_data_handlers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/data/test_data_handlers.py b/tests/data/test_data_handlers.py index 144f4416..fd765192 100644 --- a/tests/data/test_data_handlers.py +++ b/tests/data/test_data_handlers.py @@ -366,7 +366,7 @@ def test_dataset_loader_init( ) -> None: """Test initialization of DatasetLoader.""" loader = DatasetLoader( - config_path=dump_single_split_config_to_disk, + data_config=dump_single_split_config_to_disk, csv_path=titanic_csv_path, encoder_loader=encoder_loader, ) @@ -384,7 +384,7 @@ def test_dataset_loader_get_dataset( ) -> None: """Test getting dataset from loader.""" loader = DatasetLoader( - config_path=dump_single_split_config_to_disk, + data_config=dump_single_split_config_to_disk, csv_path=titanic_csv_path, encoder_loader=encoder_loader, ) From 131388d92b57f3d953d9e088633eb595834a8566 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Tue, 18 Feb 2025 14:02:41 +0100 Subject: [PATCH 23/35] FIX: tests/data/transform/test_data_transformers.py -> changed 'generate_data_configs' to 'generate_split_transform_configs' as the awaited output is the sub config file --- .../data/transform/test_data_transformers.py | 67 ++++++++++++++----- 1 file changed, 50 insertions(+), 17 deletions(-) diff --git a/tests/data/transform/test_data_transformers.py b/tests/data/transform/test_data_transformers.py index 91215a89..116b37ea 100644 --- a/tests/data/transform/test_data_transformers.py +++ b/tests/data/transform/test_data_transformers.py @@ -13,7 +13,10 @@ ReverseComplement, UniformTextMasker, ) -from stimulus.utils.yaml_data import dump_yaml_list_into_files, generate_data_configs +from stimulus.utils.yaml_data import ( + dump_yaml_list_into_files, + generate_split_transform_configs, +) class DataTransformerTest: @@ -74,7 +77,11 @@ def gaussian_noise() -> DataTransformerTest: single_input = 5.0 expected_single_output = 5.4967141530112327 multiple_inputs = [1.0, 2.0, 3.0] - expected_multiple_outputs = [1.4967141530112327, 1.8617356988288154, 3.6476885381006925] + expected_multiple_outputs = [ + 1.4967141530112327, + 1.8617356988288154, + 3.6476885381006925, + ] return DataTransformerTest( transformer=transformer, params=params, @@ -134,7 +141,9 @@ def test_transform_single(self, request: Any, test_data_name: str) -> None: test_data: The test data to use. 
""" test_data = request.getfixturevalue(test_data_name) - transformed_data = test_data.transformer.transform(test_data.single_input, **test_data.params) + transformed_data = test_data.transformer.transform( + test_data.single_input, **test_data.params + ) assert isinstance(transformed_data, str) assert transformed_data == test_data.expected_single_output @@ -142,7 +151,10 @@ def test_transform_single(self, request: Any, test_data_name: str) -> None: def test_transform_multiple(self, request: Any, test_data_name: str) -> None: """Test masking multiple strings.""" test_data = request.getfixturevalue(test_data_name) - transformed_data = [test_data.transformer.transform(x, **test_data.params) for x in test_data.multiple_inputs] + transformed_data = [ + test_data.transformer.transform(x, **test_data.params) + for x in test_data.multiple_inputs + ] assert isinstance(transformed_data, list) for item in transformed_data: assert isinstance(item, str) @@ -156,20 +168,30 @@ class TestGaussianNoise: def test_transform_single(self, request: Any, test_data_name: str) -> None: """Test transforming a single float.""" test_data = request.getfixturevalue(test_data_name) - transformed_data = test_data.transformer.transform(test_data.single_input, **test_data.params) + transformed_data = test_data.transformer.transform( + test_data.single_input, **test_data.params + ) assert isinstance(transformed_data, float) - assert round(transformed_data, 7) == round(test_data.expected_single_output, 7) + assert round(transformed_data, 7) == round( + test_data.expected_single_output, 7) @pytest.mark.parametrize("test_data_name", ["gaussian_noise"]) - def test_transform_multiple(self, request: Any, test_data_name: DataTransformerTest) -> None: + def test_transform_multiple( + self, request: Any, test_data_name: DataTransformerTest + ) -> None: """Test transforming multiple floats.""" test_data = request.getfixturevalue(test_data_name) - transformed_data = test_data.transformer.transform_all(test_data.multiple_inputs, **test_data.params) + transformed_data = test_data.transformer.transform_all( + test_data.multiple_inputs, **test_data.params + ) assert isinstance(transformed_data, list) for item in transformed_data: assert isinstance(item, float) - assert len(transformed_data) == len(test_data.expected_multiple_outputs) - for item, expected in zip(transformed_data, test_data.expected_multiple_outputs): + assert len(transformed_data) == len( + test_data.expected_multiple_outputs) + for item, expected in zip( + transformed_data, test_data.expected_multiple_outputs + ): assert round(item, 7) == round(expected, 7) @@ -180,7 +202,8 @@ class TestGaussianChunk: def test_transform_single(self, request: Any, test_data_name: str) -> None: """Test transforming a single string.""" test_data = request.getfixturevalue(test_data_name) - transformed_data = test_data.transformer.transform(test_data.single_input) + transformed_data = test_data.transformer.transform( + test_data.single_input) assert isinstance(transformed_data, str) assert len(transformed_data) == 2 @@ -188,7 +211,9 @@ def test_transform_single(self, request: Any, test_data_name: str) -> None: def test_transform_multiple(self, request: Any, test_data_name: str) -> None: """Test transforming multiple strings.""" test_data = request.getfixturevalue(test_data_name) - transformed_data = [test_data.transformer.transform(x) for x in test_data.multiple_inputs] + transformed_data = [ + test_data.transformer.transform(x) for x in test_data.multiple_inputs + ] assert 
isinstance(transformed_data, list) for item in transformed_data: assert isinstance(item, str) @@ -200,7 +225,9 @@ def test_chunk_size_excessive(self, request: Any, test_data_name: str) -> None: """Test that the transform fails if chunk size is greater than the length of the input string.""" test_data = request.getfixturevalue(test_data_name) transformer = GaussianChunk(chunk_size=100) - with pytest.raises(ValueError, match="The input data is shorter than the chunk size"): + with pytest.raises( + ValueError, match="The input data is shorter than the chunk size" + ): transformer.transform(test_data.single_input) @@ -211,7 +238,9 @@ class TestReverseComplement: def test_transform_single(self, request: Any, test_data_name: str) -> None: """Test transforming a single string.""" test_data = request.getfixturevalue(test_data_name) - transformed_data = test_data.transformer.transform(test_data.single_input, **test_data.params) + transformed_data = test_data.transformer.transform( + test_data.single_input, **test_data.params + ) assert isinstance(transformed_data, str) assert transformed_data == test_data.expected_single_output @@ -219,7 +248,9 @@ def test_transform_single(self, request: Any, test_data_name: str) -> None: def test_transform_multiple(self, request: Any, test_data_name: str) -> None: """Test transforming multiple strings.""" test_data = request.getfixturevalue(test_data_name) - transformed_data = test_data.transformer.transform_all(test_data.multiple_inputs, **test_data.params) + transformed_data = test_data.transformer.transform_all( + test_data.multiple_inputs, **test_data.params + ) assert isinstance(transformed_data, list) for item in transformed_data: assert isinstance(item, str) @@ -232,7 +263,9 @@ def titanic_config_path(base_config: dict) -> str: config_path = "tests/test_data/titanic/titanic_sub_config_0.yaml" if not os.path.exists(config_path): - configs = generate_data_configs(base_config) - dump_yaml_list_into_files([configs[0]], "tests/test_data/titanic/", "titanic_sub_config") + configs = generate_split_transform_configs(base_config) + dump_yaml_list_into_files( + [configs[0]], "tests/test_data/titanic/", "titanic_sub_config" + ) return os.path.abspath(config_path) From 2c110adee5c466d374c260dc4d6478b5d18eb55a Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Tue, 18 Feb 2025 14:29:07 +0100 Subject: [PATCH 24/35] FIX: tests/utils/test_data_yaml.py -> added tests for the new generate_split and generate_split_transform function + fixed old tests --- tests/utils/test_data_yaml.py | 55 +++++++++++++++++++++++++++-------- 1 file changed, 43 insertions(+), 12 deletions(-) diff --git a/tests/utils/test_data_yaml.py b/tests/utils/test_data_yaml.py index e7398d44..a96b5458 100644 --- a/tests/utils/test_data_yaml.py +++ b/tests/utils/test_data_yaml.py @@ -6,8 +6,10 @@ from src.stimulus.utils import yaml_data from src.stimulus.utils.yaml_data import ( YamlConfigDict, - YamlSubConfigDict, - generate_data_configs, + YamlSplitConfigDict, + YamlSplitTransformDict, + generate_split_configs, + generate_split_transform_configs, ) @@ -25,10 +27,20 @@ def load_titanic_yaml_from_file() -> YamlConfigDict: return YamlConfigDict(**yaml_dict) +@pytest.fixture +def load_split_config_yaml_from_file() -> YamlSplitConfigDict: + """Fixture that loads a test unique split YAML configuration file.""" + with open("tests/test_data/titanic/titanic_unique_split.yaml") as f: + yaml_dict = yaml.safe_load(f) + return YamlSplitConfigDict(**yaml_dict) + + @pytest.fixture def load_yaml_from_file() -> 
YamlConfigDict: """Fixture that loads a test YAML configuration file.""" - with open("tests/test_data/dna_experiment/dna_experiment_config_template.yaml") as f: + with open( + "tests/test_data/dna_experiment/dna_experiment_config_template.yaml" + ) as f: yaml_dict = yaml.safe_load(f) return YamlConfigDict(**yaml_dict) @@ -40,12 +52,20 @@ def load_wrong_type_yaml() -> dict: return yaml.safe_load(f) +def test_split_config_validation(load_titanic_yaml_from_file: YamlConfigDict) -> None: + """Test split configuration validation.""" + split_config = generate_split_configs(load_titanic_yaml_from_file)[0] + YamlSplitConfigDict.model_validate(split_config) + + def test_sub_config_validation( - load_titanic_yaml_from_file: YamlConfigDict, + load_split_config_yaml_from_file: YamlConfigDict, ) -> None: """Test sub-config validation.""" - sub_config = generate_data_configs(load_titanic_yaml_from_file)[0] - YamlSubConfigDict.model_validate(sub_config) + split_config = generate_split_transform_configs( + load_split_config_yaml_from_file)[0] + print(f"{split_config=}") + YamlSplitTransformDict.model_validate(split_config) def test_expand_transform_parameter_combinations( @@ -56,15 +76,20 @@ def test_expand_transform_parameter_combinations( transform = load_yaml_from_file.transforms[0] results = yaml_data.expand_transform_parameter_combinations(transform) assert len(results) == 1 # Only one transform returned - assert isinstance(results[0], yaml_data.YamlTransform) # Should return YamlTransform objects + assert isinstance( + results[0], yaml_data.YamlTransform + ) # Should return YamlTransform objects def test_expand_transform_list_combinations( load_yaml_from_file: YamlConfigDict, ) -> None: """Tests expanding a list of transforms into all parameter combinations.""" - results = yaml_data.expand_transform_list_combinations(load_yaml_from_file.transforms) - assert len(results) == 8 # 4 combinations from first transform x 2 from second + results = yaml_data.expand_transform_list_combinations( + load_yaml_from_file.transforms + ) + # 4 combinations from first transform x 2 from second + assert len(results) == 8 # Each result should be a YamlTransform for result in results: assert isinstance(result, yaml_data.YamlTransform) @@ -76,14 +101,18 @@ def test_generate_data_configs( load_yaml_from_file: YamlConfigDict, ) -> None: """Tests generating all possible data configurations.""" - configs = yaml_data.generate_data_configs(load_yaml_from_file) + split_configs = yaml_data.generate_split_configs(load_yaml_from_file) + configs: list[YamlSplitTransformDict] = [] + for s_conf in split_configs: + configs.extend(generate_split_transform_configs(s_conf)) + assert len(configs) == 16 # 8 transform combinations x 2 splits # Check each config individually to help debug for i, config in enumerate(configs): assert isinstance( config, - yaml_data.YamlSubConfigDict, + yaml_data.YamlSplitTransformDict, ), f"Config {i} is type {type(config)}, expected YamlSubConfigDict" @@ -98,7 +127,9 @@ def test_check_yaml_schema( """Tests the Pydantic schema validation.""" data = request.getfixturevalue(test_input[0]) if test_input[1]: - with pytest.raises(ValueError, match="Wrong type on a field, see the pydantic report above"): + with pytest.raises( + ValueError, match="Wrong type on a field, see the pydantic report above" + ): yaml_data.check_yaml_schema(data) else: yaml_data.check_yaml_schema(data) From 1b4339b7faa92b161e8d2db9283e846d06f71fc8 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Tue, 18 Feb 2025 14:35:54 +0100 Subject: 
[PATCH 25/35] FIX: tests/data/test_experiment.py -> Fixed tests to work with the new generate_split_config and generate_split_transform_config functions --- tests/data/test_experiment.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/tests/data/test_experiment.py b/tests/data/test_experiment.py index 15954217..aa261dc9 100644 --- a/tests/data/test_experiment.py +++ b/tests/data/test_experiment.py @@ -21,7 +21,9 @@ def dna_experiment_config_path() -> str: @pytest.fixture -def dna_experiment_sub_yaml(dna_experiment_config_path: str) -> yaml_data.YamlConfigDict: +def dna_experiment_sub_yaml( + dna_experiment_config_path: str, +) -> yaml_data.YamlConfigDict: """Get a sub-configuration from the DNA experiment config. Args: @@ -34,8 +36,11 @@ def dna_experiment_sub_yaml(dna_experiment_config_path: str) -> yaml_data.YamlCo yaml_dict = yaml.safe_load(f) yaml_config = yaml_data.YamlConfigDict(**yaml_dict) - yaml_configs = yaml_data.generate_data_configs(yaml_config) - return yaml_configs[0] + yaml_split_configs = yaml_data.generate_split_configs(yaml_config) + yaml_split_transform_configs = yaml_data.generate_split_transform_configs( + yaml_split_configs[0] + ) + return yaml_split_transform_configs[0] @pytest.fixture @@ -80,7 +85,9 @@ def test_get_encoder(text_onehot_encoder_params: tuple[str, dict[str, str]]) -> assert isinstance(encoder, AbstractEncoder) -def test_set_encoder_as_attribute(text_onehot_encoder_params: tuple[str, dict[str, str]]) -> None: +def test_set_encoder_as_attribute( + text_onehot_encoder_params: tuple[str, dict[str, str]], +) -> None: """Test the set_encoder_as_attribute method of the AbstractExperiment class. Args: @@ -95,7 +102,9 @@ def test_set_encoder_as_attribute(text_onehot_encoder_params: tuple[str, dict[st assert experiment.get_function_encode_all("ciao") == encoder.encode_all -def test_build_experiment_class_encoder_dict(dna_experiment_sub_yaml: yaml_data.YamlConfigDict) -> None: +def test_build_experiment_class_encoder_dict( + dna_experiment_sub_yaml: yaml_data.YamlConfigDict, +) -> None: """Test the build_experiment_class_encoder_dict method. 
Args: @@ -115,7 +124,8 @@ def test_get_data_transformer() -> None: """Test the get_data_transformer method of the TransformLoader class.""" experiment = loaders.TransformLoader() transformer = experiment.get_data_transformer("ReverseComplement") - assert isinstance(transformer, data_transformation_generators.ReverseComplement) + assert isinstance( + transformer, data_transformation_generators.ReverseComplement) def test_set_data_transformer_as_attribute() -> None: @@ -141,7 +151,10 @@ def test_initialize_column_data_transformers_from_config( assert hasattr(experiment, "col1") column_transformers = experiment.col1 - assert any(isinstance(t, data_transformation_generators.ReverseComplement) for t in column_transformers.values()) + assert any( + isinstance(t, data_transformation_generators.ReverseComplement) + for t in column_transformers.values() + ) def test_initialize_splitter_from_config( From 19d8d16cec5deb2e6d5703a9a15da1d30bb0baf2 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Tue, 18 Feb 2025 14:51:14 +0100 Subject: [PATCH 26/35] FIX: tests/data/test_handlertorch.py -> changed test to use the YamlSplitTransformDict directly instead of yaml files as argument for the class: --- tests/data/test_handlertorch.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/tests/data/test_handlertorch.py b/tests/data/test_handlertorch.py index d121ed39..dff24909 100644 --- a/tests/data/test_handlertorch.py +++ b/tests/data/test_handlertorch.py @@ -40,11 +40,13 @@ def titanic_yaml_config(titanic_config_path: str) -> dict: dict: Loaded YAML configuration """ with open(titanic_config_path) as file: - return yaml_data.YamlSubConfigDict(**yaml.safe_load(file)) + return yaml_data.YamlSplitTransformDict(**yaml.safe_load(file)) @pytest.fixture -def titanic_encoder_loader(titanic_yaml_config: yaml_data.YamlSubConfigDict) -> loaders.EncoderLoader: +def titanic_encoder_loader( + titanic_yaml_config: yaml_data.YamlSplitTransformDict, +) -> loaders.EncoderLoader: """Get Titanic encoder loader.""" loader = loaders.EncoderLoader() loader.initialize_column_encoders_from_config(titanic_yaml_config.columns) @@ -57,8 +59,11 @@ def test_init_handlertorch( titanic_encoder_loader: loaders.EncoderLoader, ) -> None: """Test TorchDataset initialization.""" + data_config: yaml_data.YamlSplitTransformDict + with open(titanic_config_path) as f: + data_config = yaml_data.YamlSplitTransformDict(**yaml.safe_load(f)) handlertorch.TorchDataset( - config_path=titanic_config_path, + data_config=data_config, csv_path=titanic_csv_path, encoder_loader=titanic_encoder_loader, ) @@ -76,8 +81,11 @@ def test_len_handlertorch( titanic_csv_path: Path to CSV file titanic_encoder_loader: Encoder loader instance """ + data_config: yaml_data.YamlSplitTransformDict + with open(titanic_config_path) as f: + data_config = yaml_data.YamlSplitTransformDict(**yaml.safe_load(f)) dataset = handlertorch.TorchDataset( - config_path=titanic_config_path, + data_config=data_config, csv_path=titanic_csv_path, encoder_loader=titanic_encoder_loader, ) @@ -96,8 +104,11 @@ def test_getitem_handlertorch_slice( titanic_csv_path: Path to CSV file titanic_encoder_loader: Encoder loader instance """ + data_config: yaml_data.YamlSplitTransformDict + with open(titanic_config_path) as f: + data_config = yaml_data.YamlSplitTransformDict(**yaml.safe_load(f)) dataset = handlertorch.TorchDataset( - config_path=titanic_config_path, + data_config=data_config, csv_path=titanic_csv_path, encoder_loader=titanic_encoder_loader, ) @@ 
-117,8 +128,11 @@ def test_getitem_handlertorch_int( titanic_csv_path: Path to CSV file titanic_encoder_loader: Encoder loader instance """ + data_config: yaml_data.YamlSplitTransformDict + with open(titanic_config_path) as f: + data_config = yaml_data.YamlSplitTransformDict(**yaml.safe_load(f)) dataset = handlertorch.TorchDataset( - config_path=titanic_config_path, + data_config=data_config, csv_path=titanic_csv_path, encoder_loader=titanic_encoder_loader, ) From 5ef82ff0349a9ddae49f38bb52f1068d2e81eabc Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Tue, 18 Feb 2025 15:20:04 +0100 Subject: [PATCH 27/35] ADD: src/stimulus/utils/yaml_data.py -> Added a class SplitConfigDict --- src/stimulus/utils/yaml_data.py | 45 +++++++++++++++++---------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/src/stimulus/utils/yaml_data.py b/src/stimulus/utils/yaml_data.py index 64610265..9ed7c516 100644 --- a/src/stimulus/utils/yaml_data.py +++ b/src/stimulus/utils/yaml_data.py @@ -112,21 +112,21 @@ class YamlConfigDict(BaseModel): # TODO: Rename this class to SplitConfigDict -class YamlSubConfigDict(BaseModel): +class YamlSplitConfigDict(BaseModel): """Model for sub-configuration generated from main config.""" global_params: YamlGlobalParams columns: list[YamlColumns] - transforms: Union[YamlTransform, list[YamlTransform]] + transforms: list[YamlTransform] split: YamlSplit -class YamlSubConfigTransformDict(BaseModel): +class YamlSplitTransformDict(BaseModel): """Model for sub-configuration generated from main config.""" global_params: YamlGlobalParams columns: list[YamlColumns] - transform: YamlTransform + transforms: YamlTransform split: YamlSplit @@ -139,7 +139,7 @@ class YamlSchema(BaseModel): class YamlSplitSchema(BaseModel): """Model for validating a Split YAML schema.""" - yaml_conf: YamlSubConfigDict + yaml_conf: YamlSplitConfigDict def extract_transform_parameters_at_index( @@ -238,7 +238,7 @@ def expand_transform_list_combinations( return sub_transforms -def generate_split_configs(yaml_config: YamlConfigDict) -> list[YamlSubConfigDict]: +def generate_split_configs(yaml_config: YamlConfigDict) -> list[YamlSplitConfigDict]: """Generates all possible split configuration from a YAML config. Takes a YAML configuration that may contain parameter lists and splits, @@ -273,7 +273,7 @@ def generate_split_configs(yaml_config: YamlConfigDict) -> list[YamlSubConfigDic sub_configs = [] for split in sub_splits: sub_configs.append( - YamlSubConfigDict( + YamlSplitConfigDict( global_params=yaml_config.global_params, columns=yaml_config.columns, transforms=yaml_config.transforms, @@ -284,8 +284,8 @@ def generate_split_configs(yaml_config: YamlConfigDict) -> list[YamlSubConfigDic def generate_split_transform_configs( - yaml_config: YamlSubConfigDict, -) -> list[YamlSubConfigTransformDict]: + yaml_config: YamlSplitConfigDict, +) -> list[YamlSplitTransformDict]: """Generates all the transform configuration for a given split Takes a YAML configuration that may contain a transform or a list of transform, @@ -313,26 +313,27 @@ def generate_split_transform_configs( length will be the product of the number of parameter combinations and the number of splits. 
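
Read together with generate_split_configs above, the intended composition is a two-step expansion: first one config per split, then one config per transform combination within each split. A short sketch, using the Titanic test config referenced in the test suite:

import yaml

from stimulus.utils.yaml_data import (
    YamlConfigDict,
    YamlSplitTransformDict,
    generate_split_configs,
    generate_split_transform_configs,
)

with open("tests/test_data/titanic/titanic.yaml") as f:
    full_config = YamlConfigDict(**yaml.safe_load(f))

# One YamlSplitConfigDict per split ...
split_configs = generate_split_configs(full_config)

# ... then one YamlSplitTransformDict per (split, transform combination) pair.
split_transform_configs: list[YamlSplitTransformDict] = []
for split_config in split_configs:
    split_transform_configs.extend(generate_split_transform_configs(split_config))
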
""" - if isinstance(yaml_config, dict) and not isinstance(yaml_config, YamlSubConfigDict): + if isinstance(yaml_config, dict) and not isinstance( + yaml_config, YamlSplitConfigDict + ): raise TypeError("Input must be a list of YamlSubConfigDict") - split_configs = yaml_config.split - split_transform_config: list[YamlSubConfigTransformDict] = [] - for split_config in split_configs: - for transform in split_configs.get("transforms"): - split_transform_config.append( - YamlSubConfigTransformDict( - global_params=split_config.get("global_params"), - columns=split_config.get("columns"), - transform=transform, - split=split_config.get("split"), - ) + sub_transforms = expand_transform_list_combinations(yaml_config.transforms) + split_transform_config: list[YamlSplitTransformDict] = [] + for transform in sub_transforms: + split_transform_config.append( + YamlSplitTransformDict( + global_params=yaml_config.global_params, + columns=yaml_config.columns, + transforms=transform, + split=yaml_config.split, ) + ) return split_transform_config def dump_yaml_list_into_files( - yaml_list: list[YamlSubConfigDict], + yaml_list: list[YamlSplitConfigDict], directory_path: str, base_name: str, ) -> None: From cda0a6e329100a80ea1c198526115cbb9439e5cc Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Tue, 18 Feb 2025 15:23:40 +0100 Subject: [PATCH 28/35] FIX: tests/cli/test_shuffle_csv.py -> Uses a YamlSplitTransformDict now --- tests/cli/test_shuffle_csv.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/tests/cli/test_shuffle_csv.py b/tests/cli/test_shuffle_csv.py index dff78887..2c95d7f0 100644 --- a/tests/cli/test_shuffle_csv.py +++ b/tests/cli/test_shuffle_csv.py @@ -1,5 +1,6 @@ """Tests for the shuffle_csv CLI command.""" +import yaml import hashlib import pathlib import tempfile @@ -8,13 +9,14 @@ import pytest from src.stimulus.cli.shuffle_csv import main +from src.stimulus.utils.yaml_data import YamlSplitTransformDict # Fixtures @pytest.fixture def correct_yaml_path() -> str: """Fixture that returns the path to a correct YAML file.""" - return "tests/test_data/titanic/titanic_unique_split.yaml" + return "tests/test_data/titanic/titanic_sub_config.yaml" @pytest.fixture @@ -42,11 +44,18 @@ def test_shuffle_csv( csv_path = request.getfixturevalue(csv_type) yaml_path = request.getfixturevalue(yaml_type) tmpdir = pathlib.Path(tempfile.gettempdir()) - if error: - with pytest.raises(error): # type: ignore[call-overload] - main(csv_path, yaml_path, str(tmpdir / "test.csv")) - else: - main(csv_path, yaml_path, str(tmpdir / "test.csv")) - with open(tmpdir / "test.csv") as file: - hash = hashlib.md5(file.read().encode()).hexdigest() # noqa: S324 - assert hash == snapshot + with open(yaml_path) as f: + if error: + with pytest.raises(error): # type: ignore[call-overload] + config_dict: YamlSplitTransformDict = YamlSplitTransformDict( + **yaml.safe_load(f) + ) + main(csv_path, config_dict, str(tmpdir / "test.csv")) + else: + config_dict: YamlSplitTransformDict = YamlSplitTransformDict( + **yaml.safe_load(f) + ) + main(csv_path, config_dict, str(tmpdir / "test.csv")) + with open(tmpdir / "test.csv") as file: + hash = hashlib.md5(file.read().encode()).hexdigest() # noqa: S324 + assert hash == snapshot From ac672f05d87c4ef0735a998f8be36410da70486d Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Tue, 18 Feb 2025 15:28:18 +0100 Subject: [PATCH 29/35] {src/stimulus,tests}/typing/{__init__.py, test_typing.py} -> Added the two new class YamlSplitConfigDict and 
YamlSplitTransformDict --- src/stimulus/typing/__init__.py | 20 +++++++++++++++----- tests/typing/test_typing.py | 3 ++- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/src/stimulus/typing/__init__.py b/src/stimulus/typing/__init__.py index cc59e9e2..ed970e73 100644 --- a/src/stimulus/typing/__init__.py +++ b/src/stimulus/typing/__init__.py @@ -23,10 +23,17 @@ from stimulus.data.handlertorch import TorchDataset from stimulus.data.loaders import EncoderLoader, SplitLoader, TransformLoader from stimulus.data.splitters.splitters import AbstractSplitter as Splitter -from stimulus.data.transform.data_transformation_generators import AbstractDataTransformer as Transform +from stimulus.data.transform.data_transformation_generators import ( + AbstractDataTransformer as Transform, +) from stimulus.learner.predict import PredictWrapper from stimulus.learner.raytune_learner import CheckpointDict, TuneModel, TuneWrapper -from stimulus.learner.raytune_parser import RayTuneMetrics, RayTuneOptimizer, RayTuneResult, TuneParser +from stimulus.learner.raytune_parser import ( + RayTuneMetrics, + RayTuneOptimizer, + RayTuneResult, + TuneParser, +) from stimulus.utils.performance import Performance from stimulus.utils.yaml_data import ( YamlColumns, @@ -35,7 +42,8 @@ YamlGlobalParams, YamlSchema, YamlSplit, - YamlSubConfigDict, + YamlSplitConfigDict, + YamlSplitTransformDict, YamlTransform, YamlTransformColumns, YamlTransformColumnsTransformation, @@ -56,7 +64,9 @@ # data/data_handlers.py -DataManager: TypeAlias = DatasetManager | EncodeManager | SplitManager | TransformManager +DataManager: TypeAlias = ( + DatasetManager | EncodeManager | SplitManager | TransformManager +) # data/experiments.py @@ -75,7 +85,7 @@ | YamlGlobalParams | YamlSchema | YamlSplit - | YamlSubConfigDict + | YamlSplitConfigDict | YamlTransform | YamlTransformColumns | YamlTransformColumnsTransformation diff --git a/tests/typing/test_typing.py b/tests/typing/test_typing.py index 22bc8577..d2be37ff 100644 --- a/tests/typing/test_typing.py +++ b/tests/typing/test_typing.py @@ -49,7 +49,8 @@ def test_yaml_data_types() -> None: YamlGlobalParams, YamlSchema, YamlSplit, - YamlSubConfigDict, + YamlSplitConfigDict, + YamlSplitTransformDict, YamlTransform, YamlTransformColumns, YamlTransformColumnsTransformation, From 4fa3abdb2789d7f5d118cceb89a29781fbb8fae9 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Tue, 18 Feb 2025 15:32:35 +0100 Subject: [PATCH 30/35] FIX: src/stimulus/cli/transform_csv.py -> When the file is called it creates a YamlSplitConfigDict from the path given as argument --- src/stimulus/cli/transform_csv.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/stimulus/cli/transform_csv.py b/src/stimulus/cli/transform_csv.py index 2e2ff5fd..55b25882 100755 --- a/src/stimulus/cli/transform_csv.py +++ b/src/stimulus/cli/transform_csv.py @@ -7,12 +7,14 @@ from stimulus.data.data_handlers import DatasetProcessor, TransformManager from stimulus.data.loaders import TransformLoader -from stimulus.utils.yaml_data import YamlSubConfigDict +from stimulus.utils.yaml_data import YamlSplitConfigDict def get_args() -> argparse.Namespace: """Get the arguments when using from the commandline.""" - parser = argparse.ArgumentParser(description="CLI for transforming CSV data files using YAML configuration.") + parser = argparse.ArgumentParser( + description="CLI for transforming CSV data files using YAML configuration." 
+ ) parser.add_argument( "-c", "--csv", @@ -53,8 +55,9 @@ def main(data_csv: str, config_yaml: str, out_path: str) -> None: # initialize the transform manager transform_config = processor.dataset_manager.config.transforms with open(config_yaml) as f: - yaml_config = YamlSubConfigDict(**yaml.safe_load(f)) + yaml_config = YamlSplitConfigDict(**yaml.safe_load(f)) transform_loader = TransformLoader(seed=yaml_config.global_params.seed) + print(transform_config) transform_loader.initialize_column_data_transformers_from_config(transform_config) transform_manager = TransformManager(transform_loader) From d03ac07cedb44da448e1add5aee6b25e686e3103 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Tue, 18 Feb 2025 15:36:04 +0100 Subject: [PATCH 31/35] FIX: tests/learner/test_raytune_learner -> Uses a YamlSplitTransformDict now --- tests/learner/test_raytune_learner.py | 38 ++++++++++++++++++++------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/tests/learner/test_raytune_learner.py b/tests/learner/test_raytune_learner.py index 54ef6e27..8ade0911 100644 --- a/tests/learner/test_raytune_learner.py +++ b/tests/learner/test_raytune_learner.py @@ -10,7 +10,7 @@ from stimulus.data.handlertorch import TorchDataset from stimulus.data.loaders import EncoderLoader from stimulus.learner.raytune_learner import TuneWrapper -from stimulus.utils.yaml_data import YamlSubConfigDict +from stimulus.utils.yaml_data import YamlSplitConfigDict, YamlSplitTransformDict from stimulus.utils.yaml_model_schema import Model, RayTuneModel, YamlRayConfigLoader from tests.test_model import titanic_model @@ -29,7 +29,9 @@ def encoder_loader() -> EncoderLoader: with open("tests/test_data/titanic/titanic_sub_config.yaml") as file: data_config = yaml.safe_load(file) encoder_loader = EncoderLoader() - encoder_loader.initialize_column_encoders_from_config(YamlSubConfigDict(**data_config).columns) + encoder_loader.initialize_column_encoders_from_config( + YamlSplitTransformDict(**data_config).columns + ) return encoder_loader @@ -44,7 +46,9 @@ def titanic_dataset(encoder_loader: EncoderLoader) -> TorchDataset: ) -def test_tunewrapper_init(ray_config_loader: RayTuneModel, encoder_loader: EncoderLoader) -> None: +def test_tunewrapper_init( + ray_config_loader: RayTuneModel, encoder_loader: EncoderLoader +) -> None: """Test the initialization of the TuneWrapper class.""" # Filter ResourceWarning during Ray shutdown warnings.filterwarnings("ignore", category=ResourceWarning) @@ -53,14 +57,19 @@ def test_tunewrapper_init(ray_config_loader: RayTuneModel, encoder_loader: Encod ray.init(ignore_reinit_error=True) try: + data_config: YamlSplitTransformDict + with open("tests/test_data/titanic/titanic_sub_config.yaml") as f: + data_config = YamlSplitTransformDict(**yaml.safe_load(f)) + tune_wrapper = TuneWrapper( model_config=ray_config_loader, model_class=titanic_model.ModelTitanic, data_path="tests/test_data/titanic/titanic_stimulus_split.csv", - data_config_path="tests/test_data/titanic/titanic_sub_config.yaml", + data_config=data_config, encoder_loader=encoder_loader, seed=42, - ray_results_dir=os.path.abspath("tests/test_data/titanic/ray_results"), + ray_results_dir=os.path.abspath( + "tests/test_data/titanic/ray_results"), tune_run_name="test_run", debug=False, autoscaler=False, @@ -74,10 +83,13 @@ def test_tunewrapper_init(ray_config_loader: RayTuneModel, encoder_loader: Encod if os.path.exists("tests/test_data/titanic/ray_results"): import shutil - shutil.rmtree("tests/test_data/titanic/ray_results", ignore_errors=True) + 
shutil.rmtree("tests/test_data/titanic/ray_results", + ignore_errors=True) -def test_tune_wrapper_tune(ray_config_loader: RayTuneModel, encoder_loader: EncoderLoader) -> None: +def test_tune_wrapper_tune( + ray_config_loader: RayTuneModel, encoder_loader: EncoderLoader +) -> None: """Test the tune method of TuneWrapper class.""" # Filter ResourceWarning during Ray shutdown warnings.filterwarnings("ignore", category=ResourceWarning) @@ -86,14 +98,19 @@ def test_tune_wrapper_tune(ray_config_loader: RayTuneModel, encoder_loader: Enco ray.init(ignore_reinit_error=True) try: + data_config: YamlSplitTransformDict + with open("tests/test_data/titanic/titanic_sub_config.yaml") as f: + data_config = YamlSplitTransformDict(**yaml.safe_load(f)) + tune_wrapper = TuneWrapper( model_config=ray_config_loader, model_class=titanic_model.ModelTitanic, data_path="tests/test_data/titanic/titanic_stimulus_split.csv", - data_config_path="tests/test_data/titanic/titanic_sub_config.yaml", + data_config=data_config, encoder_loader=encoder_loader, seed=42, - ray_results_dir=os.path.abspath("tests/test_data/titanic/ray_results"), + ray_results_dir=os.path.abspath( + "tests/test_data/titanic/ray_results"), tune_run_name="test_run", debug=False, autoscaler=False, @@ -108,4 +125,5 @@ def test_tune_wrapper_tune(ray_config_loader: RayTuneModel, encoder_loader: Enco if os.path.exists("tests/test_data/titanic/ray_results"): import shutil - shutil.rmtree("tests/test_data/titanic/ray_results", ignore_errors=True) + shutil.rmtree("tests/test_data/titanic/ray_results", + ignore_errors=True) From 3c31d67b26cda521e22e50eaae0c2bfa20e2a89a Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Tue, 18 Feb 2025 15:39:44 +0100 Subject: [PATCH 32/35] FIX: {src/stimulus,tests}/cli/tuning.py -> the main function now takes directly YamlSplitTransformDict, the tests have been fixed to work with this --- src/stimulus/cli/tuning.py | 45 ++++++++++++++++++++++++++------------ tests/cli/test_tuning.py | 22 ++++++++++++++----- 2 files changed, 48 insertions(+), 19 deletions(-) diff --git a/src/stimulus/cli/tuning.py b/src/stimulus/cli/tuning.py index a0eb12ab..0ba80209 100755 --- a/src/stimulus/cli/tuning.py +++ b/src/stimulus/cli/tuning.py @@ -28,8 +28,22 @@ def get_args() -> argparse.Namespace: Parsed command line arguments. """ parser = argparse.ArgumentParser(description="Launch check_model.") - parser.add_argument("-d", "--data", type=str, required=True, metavar="FILE", help="Path to input csv file.") - parser.add_argument("-m", "--model", type=str, required=True, metavar="FILE", help="Path to model file.") + parser.add_argument( + "-d", + "--data", + type=str, + required=True, + metavar="FILE", + help="Path to input csv file.", + ) + parser.add_argument( + "-m", + "--model", + type=str, + required=True, + metavar="FILE", + help="Path to model file.", + ) parser.add_argument( "-e", "--data_config", @@ -136,7 +150,7 @@ def get_args() -> argparse.Namespace: def main( model_path: str, data_path: str, - data_config_path: str, + data_config: yaml_data.YamlSplitTransformDict, model_config_path: str, initial_weights: str | None = None, # noqa: ARG001 ray_results_dirpath: str | None = None, @@ -152,7 +166,7 @@ def main( Args: data_path: Path to input data file. model_path: Path to model file. - data_config_path: Path to data config file. + data_config: A YamlSplitTransformObject model_config_path: Path to model config file. initial_weights: Optional path to initial weights. ray_results_dirpath: Directory for ray results. 
@@ -162,26 +176,25 @@ def main( best_metrics_path: Path to write the best metrics to. best_config_path: Path to write the best config to. """ - # Convert data config to proper type - with open(data_config_path) as file: - data_config_dict: dict[str, Any] = yaml.safe_load(file) - data_config: yaml_data.YamlSubConfigDict = yaml_data.YamlSubConfigDict(**data_config_dict) - with open(model_config_path) as file: model_config_dict: dict[str, Any] = yaml.safe_load(file) - model_config: yaml_model_schema.Model = yaml_model_schema.Model(**model_config_dict) + model_config: yaml_model_schema.Model = yaml_model_schema.Model( + **model_config_dict) encoder_loader = loaders.EncoderLoader() - encoder_loader.initialize_column_encoders_from_config(column_config=data_config.columns) + encoder_loader.initialize_column_encoders_from_config( + column_config=data_config.columns + ) model_class = launch_utils.import_class_from_file(model_path) - ray_config_loader = yaml_model_schema.YamlRayConfigLoader(model=model_config) + ray_config_loader = yaml_model_schema.YamlRayConfigLoader( + model=model_config) ray_config_model = ray_config_loader.get_config() tuner = raytune_learner.TuneWrapper( model_config=ray_config_model, - data_config_path=data_config_path, + data_config=data_config, model_class=model_class, data_path=data_path, encoder_loader=encoder_loader, @@ -226,10 +239,14 @@ def main( def run() -> None: """Run the model checking script.""" args = get_args() + # Try to convert the configuration file to a YamlSplitTransformDict + config_dict: yaml_data.YamlSplitTransformDict + with open(args.data_config) as f: + config_dict = yaml_data.YamlSplitTransformDict(**yaml.safe_load(f)) main( data_path=args.data, model_path=args.model, - data_config_path=args.data_config, + data_config=config_dict, model_config_path=args.model_config, initial_weights=args.initial_weights, ray_results_dirpath=args.ray_results_dirpath, diff --git a/tests/cli/test_tuning.py b/tests/cli/test_tuning.py index e05d6213..1716de65 100644 --- a/tests/cli/test_tuning.py +++ b/tests/cli/test_tuning.py @@ -12,14 +12,18 @@ import ray import yaml -from stimulus.cli import tuning +from src.stimulus.cli import tuning +from src.stimulus.utils.yaml_data import YamlSplitTransformDict @pytest.fixture def data_path() -> str: """Get path to test data CSV file.""" return str( - Path(__file__).parent.parent / "test_data" / "titanic" / "titanic_stimulus_split.csv", + Path(__file__).parent.parent + / "test_data" + / "titanic" + / "titanic_stimulus_split.csv", ) @@ -27,7 +31,10 @@ def data_path() -> str: def data_config() -> str: """Get path to test data config YAML.""" return str( - Path(__file__).parent.parent / "test_data" / "titanic" / "titanic_sub_config.yaml", + Path(__file__).parent.parent + / "test_data" + / "titanic" + / "titanic_sub_config.yaml", ) @@ -115,6 +122,10 @@ def test_tuning_main( assert os.path.exists(model_config), f"Model config not found at {model_config}" try: + config_dict: YamlSplitTransformDict + with open(data_config) as f: + config_dict = YamlSplitTransformDict(**yaml.safe_load(f)) + results_dir = Path("tests/test_data/titanic/test_results/").resolve() results_dir.mkdir(parents=True, exist_ok=True) @@ -122,10 +133,11 @@ def test_tuning_main( tuning.main( model_path=model_path, data_path=data_path, - data_config_path=data_config, + data_config=config_dict, model_config_path=model_config, initial_weights=None, - ray_results_dirpath=str(results_dir), # Directory path without URI scheme + # Directory path without URI scheme + 
ray_results_dirpath=str(results_dir),
                 output_path=str(results_dir / "best_model.safetensors"),
                 best_optimizer_path=str(results_dir / "best_optimizer.pt"),
                 best_metrics_path=str(results_dir / "best_metrics.csv"),

From cb7d55fdc2bc6892c98472e339c87636843147f9 Mon Sep 17 00:00:00 2001
From: Julien Raynal
Date: Tue, 18 Feb 2025 15:43:19 +0100
Subject: [PATCH 33/35] DELETED: tests/cli/test_split_yaml.py -> Deleted as it has been replaced with two new files for splitting into YamlSplitConfigDict and YamlSplitTransformDict

---
 tests/cli/test_split_yaml.py | 55 ------------------------------------
 1 file changed, 55 deletions(-)
 delete mode 100644 tests/cli/test_split_yaml.py

diff --git a/tests/cli/test_split_yaml.py b/tests/cli/test_split_yaml.py
deleted file mode 100644
index ad56f2b7..00000000
--- a/tests/cli/test_split_yaml.py
+++ /dev/null
@@ -1,55 +0,0 @@
-"""Tests for the split_yaml CLI command."""
-
-import hashlib
-import os
-import tempfile
-from typing import Any, Callable
-
-import pytest
-
-from src.stimulus.cli.split_yaml import main
-
-
-# Fixtures
-@pytest.fixture
-def correct_yaml_path() -> str:
-    """Fixture that returns the path to a correct YAML file."""
-    return "tests/test_data/titanic/titanic.yaml"
-
-
-@pytest.fixture
-def wrong_yaml_path() -> str:
-    """Fixture that returns the path to a wrong YAML file."""
-    return "tests/test_data/yaml_files/wrong_field_type.yaml"
-
-
-# Test cases
-test_cases = [
-    ("correct_yaml_path", None),
-    ("wrong_yaml_path", ValueError),
-]
-
-
-# Tests
-@pytest.mark.parametrize(("yaml_type", "error"), test_cases)
-def test_split_yaml(
-    request: pytest.FixtureRequest,
-    snapshot: Callable[[], Any],
-    yaml_type: str,
-    error: Exception | None,
-) -> None:
-    """Tests the CLI command with correct and wrong YAML files."""
-    yaml_path = request.getfixturevalue(yaml_type)
-    tmpdir = tempfile.gettempdir()
-    if error:
-        with pytest.raises(error):  # type: ignore[call-overload]
-            main(yaml_path, tmpdir)
-    else:
-        main(yaml_path, tmpdir)  # main() returns None, no need to assert
-        files = os.listdir(tmpdir)
-        test_out = [f for f in files if f.startswith("test_")]
-        hashes = []
-        for f in test_out:
-            with open(os.path.join(tmpdir, f)) as file:
-                hashes.append(hashlib.md5(file.read().encode()).hexdigest())  # noqa: S324
-        assert sorted(hashes) == snapshot  # sorted ensures that the order of the hashes does not matter

From 10092ee079e9b77c55b1b6e27c07eae5373e754e Mon Sep 17 00:00:00 2001
From: Julien Raynal
Date: Tue, 18 Feb 2025 15:46:02 +0100
Subject: [PATCH 34/35] FIX: src/stimulus/data/loaders.py -> initialize_splitter_from_config uses a YamlSplitConfigDict as input now

---
 src/stimulus/data/loaders.py | 74 +++++++++++++++++++++++++++---------
 1 file changed, 55 insertions(+), 19 deletions(-)

diff --git a/src/stimulus/data/loaders.py b/src/stimulus/data/loaders.py
index d962ac94..4bea4a1b 100644
--- a/src/stimulus/data/loaders.py
+++ b/src/stimulus/data/loaders.py
@@ -29,14 +29,17 @@ def __init__(self, seed: Optional[float] = None) -> None:
         """
         self.seed = seed

-    def initialize_column_encoders_from_config(self, column_config: yaml_data.YamlColumns) -> None:
+    def initialize_column_encoders_from_config(
+        self, column_config: yaml_data.YamlColumns
+    ) -> None:
         """Build the loader from a config dictionary.

         Args:
             column_config (yaml_data.YamlColumns): Configuration dictionary containing field names (column_name) and their encoder specifications.
""" for field in column_config: - encoder = self.get_encoder(field.encoder[0].name, field.encoder[0].params) + encoder = self.get_encoder( + field.encoder[0].name, field.encoder[0].params) self.set_encoder_as_attribute(field.column_name, encoder) def get_function_encode_all(self, field_name: str) -> Any: @@ -50,7 +53,9 @@ def get_function_encode_all(self, field_name: str) -> Any: """ return getattr(self, field_name).encode_all - def get_encoder(self, encoder_name: str, encoder_params: Optional[dict] = None) -> Any: + def get_encoder( + self, encoder_name: str, encoder_params: Optional[dict] = None + ) -> Any: """Gets an encoder object from the encoders module and initializes it with the given parameters. Args: @@ -63,7 +68,9 @@ def get_encoder(self, encoder_name: str, encoder_params: Optional[dict] = None) try: return getattr(encoders, encoder_name)(**encoder_params) except AttributeError: - logging.exception(f"Encoder '{encoder_name}' not found in the encoders module.") + logging.exception( + f"Encoder '{encoder_name}' not found in the encoders module." + ) logging.exception( f"Available encoders: {[name for name, obj in encoders.__dict__.items() if isinstance(obj, type) and name not in ('ABC', 'Any')]}", ) @@ -72,13 +79,17 @@ def get_encoder(self, encoder_name: str, encoder_params: Optional[dict] = None) except TypeError: if encoder_params is None: return getattr(encoders, encoder_name)() - logging.exception(f"Encoder '{encoder_name}' has incorrect parameters: {encoder_params}") + logging.exception( + f"Encoder '{encoder_name}' has incorrect parameters: {encoder_params}" + ) logging.exception( f"Expected parameters for '{encoder_name}': {inspect.signature(getattr(encoders, encoder_name))}", ) raise - def set_encoder_as_attribute(self, field_name: str, encoder: encoders.AbstractEncoder) -> None: + def set_encoder_as_attribute( + self, field_name: str, encoder: encoders.AbstractEncoder + ) -> None: """Sets the encoder as an attribute of the loader. Args: @@ -99,7 +110,9 @@ def __init__(self, seed: Optional[float] = None) -> None: """ self.seed = seed - def get_data_transformer(self, transformation_name: str, transformation_params: Optional[dict] = None) -> Any: + def get_data_transformer( + self, transformation_name: str, transformation_params: Optional[dict] = None + ) -> Any: """Gets a transformer object from the transformers module. Args: @@ -110,9 +123,13 @@ def get_data_transformer(self, transformation_name: str, transformation_params: Any: The transformer function for the specified transformation """ try: - return getattr(data_transformation_generators, transformation_name)(**transformation_params) + return getattr(data_transformation_generators, transformation_name)( + **transformation_params + ) except AttributeError: - logging.exception(f"Transformer '{transformation_name}' not found in the transformers module.") + logging.exception( + f"Transformer '{transformation_name}' not found in the transformers module." 
+ ) logging.exception( f"Available transformers: {[name for name, obj in data_transformation_generators.__dict__.items() if isinstance(obj, type) and name not in ('ABC', 'Any')]}", ) @@ -121,13 +138,17 @@ def get_data_transformer(self, transformation_name: str, transformation_params: except TypeError: if transformation_params is None: return getattr(data_transformation_generators, transformation_name)() - logging.exception(f"Transformer '{transformation_name}' has incorrect parameters: {transformation_params}") + logging.exception( + f"Transformer '{transformation_name}' has incorrect parameters: {transformation_params}" + ) logging.exception( f"Expected parameters for '{transformation_name}': {inspect.signature(getattr(data_transformation_generators, transformation_name))}", ) raise - def set_data_transformer_as_attribute(self, field_name: str, data_transformer: Any) -> None: + def set_data_transformer_as_attribute( + self, field_name: str, data_transformer: Any + ) -> None: """Sets the data transformer as an attribute of the loader. Args: @@ -136,12 +157,18 @@ def set_data_transformer_as_attribute(self, field_name: str, data_transformer: A """ # check if the field already exists, if it does not, initialize it to an empty dict if not hasattr(self, field_name): - setattr(self, field_name, {data_transformer.__class__.__name__: data_transformer}) + setattr( + self, + field_name, + {data_transformer.__class__.__name__: data_transformer}, + ) else: field_value = getattr(self, field_name) field_value[data_transformer.__class__.__name__] = data_transformer - def initialize_column_data_transformers_from_config(self, transform_config: yaml_data.YamlTransform) -> None: + def initialize_column_data_transformers_from_config( + self, transform_config: yaml_data.YamlTransform + ) -> None: """Build the loader from a config dictionary. Args: @@ -174,7 +201,9 @@ def initialize_column_data_transformers_from_config(self, transform_config: yaml for column in transform_config.columns: col_name = column.column_name for transform_spec in column.transformations: - transformer = self.get_data_transformer(transform_spec.name, transform_spec.params) + transformer = self.get_data_transformer( + transform_spec.name, transform_spec.params + ) self.set_data_transformer_as_attribute(col_name, transformer) @@ -206,7 +235,9 @@ def get_function_split(self) -> Any: ) return self.split.get_split_indexes - def get_splitter(self, splitter_name: str, splitter_params: Optional[dict] = None) -> Any: + def get_splitter( + self, splitter_name: str, splitter_params: Optional[dict] = None + ) -> Any: """Gets a splitter object from the splitters module. Args: @@ -221,7 +252,9 @@ def get_splitter(self, splitter_name: str, splitter_params: Optional[dict] = Non except TypeError: if splitter_params is None: return getattr(splitters, splitter_name)() - logging.exception(f"Splitter '{splitter_name}' has incorrect parameters: {splitter_params}") + logging.exception( + f"Splitter '{splitter_name}' has incorrect parameters: {splitter_params}" + ) logging.exception( f"Expected parameters for '{splitter_name}': {inspect.signature(getattr(splitters, splitter_name))}", ) @@ -235,11 +268,14 @@ def set_splitter_as_attribute(self, splitter: Any) -> None: """ self.split = splitter - def initialize_splitter_from_config(self, split_config: yaml_data.YamlSplit) -> None: + def initialize_splitter_from_config( + self, split_config: yaml_data.YamlSplitConfigDict + ) -> None: """Build the loader from a config dictionary. 
Args: - split_config (yaml_data.YamlSplit): Configuration dictionary containing split configurations. + split_config (yaml_data.YamlSplitConfigDict): Configuration dictionary containing split configurations. """ - splitter = self.get_splitter(split_config.split_method, split_config.params) + splitter = self.get_splitter( + split_config.split_method, split_config.params) self.set_splitter_as_attribute(splitter) From 0012c727c69be896766538cf937c4efaa35751d5 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Tue, 18 Feb 2025 18:11:20 +0100 Subject: [PATCH 35/35] FIX: src/stimulus/data/{data_handlers.py,handlertorch.py} -> Change the docstring to match the right parameters --- src/stimulus/data/data_handlers.py | 2 +- src/stimulus/data/handlertorch.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/stimulus/data/data_handlers.py b/src/stimulus/data/data_handlers.py index 559e7d53..4af2f786 100644 --- a/src/stimulus/data/data_handlers.py +++ b/src/stimulus/data/data_handlers.py @@ -293,7 +293,7 @@ def __init__( """Initialize the DatasetHandler with required config. Args: - config_path (str): Path to the dataset configuration file. + data_config (yaml_data.YamlSplitTransformDict): A YamlSplitTransformDict object holding the config. csv_path (str): Path to the CSV data file. """ self.dataset_manager = DatasetManager(data_config) diff --git a/src/stimulus/data/handlertorch.py b/src/stimulus/data/handlertorch.py index 6d8e0641..3b89140a 100644 --- a/src/stimulus/data/handlertorch.py +++ b/src/stimulus/data/handlertorch.py @@ -21,7 +21,7 @@ def __init__( """Initialize the TorchDataset. Args: - config_path: Path to the configuration file + data_config: A YamlSplitTransformDict holding the configuration. csv_path: Path to the CSV data file encoder_loader: Encoder loader instance split: Optional tuple containing split information