From d5fb8e81b81f8f2e4373e42038163556ce816b3b Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 11:14:12 +0100 Subject: [PATCH 01/81] REPLACE: src/stimulus/typing/__init__ -> changed YamlGlobalConfig to GlobalConfig --- src/stimulus/typing/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/stimulus/typing/__init__.py b/src/stimulus/typing/__init__.py index ed970e73..2038a5c3 100644 --- a/src/stimulus/typing/__init__.py +++ b/src/stimulus/typing/__init__.py @@ -39,7 +39,7 @@ YamlColumns, YamlColumnsEncoder, YamlConfigDict, - YamlGlobalParams, + GlobalParams, YamlSchema, YamlSplit, YamlSplitConfigDict, @@ -82,7 +82,7 @@ YamlColumns | YamlColumnsEncoder | YamlConfigDict - | YamlGlobalParams + | GlobalParams | YamlSchema | YamlSplit | YamlSplitConfigDict From 89b5c26d4619a8baeb2be1511935fa159d946ee8 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 11:14:37 +0100 Subject: [PATCH 02/81] REPLACE: src/stimulus/utils/yaml_data.py -> changed YamlGlobalConfig to GlobalConfig --- src/stimulus/utils/yaml_data.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/stimulus/utils/yaml_data.py b/src/stimulus/utils/yaml_data.py index 9ed7c516..251e0c59 100644 --- a/src/stimulus/utils/yaml_data.py +++ b/src/stimulus/utils/yaml_data.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, ValidationError, field_validator -class YamlGlobalParams(BaseModel): +class GlobalParams(BaseModel): """Model for global parameters in YAML configuration.""" seed: int @@ -105,7 +105,7 @@ class YamlSplit(BaseModel): class YamlConfigDict(BaseModel): """Model for main YAML configuration.""" - global_params: YamlGlobalParams + global_params: GlobalParams columns: list[YamlColumns] transforms: list[YamlTransform] split: list[YamlSplit] @@ -115,7 +115,7 @@ class YamlConfigDict(BaseModel): class YamlSplitConfigDict(BaseModel): """Model for sub-configuration generated from main config.""" - global_params: YamlGlobalParams + global_params: GlobalParams columns: list[YamlColumns] transforms: list[YamlTransform] split: YamlSplit @@ -124,7 +124,7 @@ class YamlSplitConfigDict(BaseModel): class YamlSplitTransformDict(BaseModel): """Model for sub-configuration generated from main config.""" - global_params: YamlGlobalParams + global_params: GlobalParams columns: list[YamlColumns] transforms: YamlTransform split: YamlSplit From c6a38e7e5914cc3bbe2fc61422bb35fb1592f02b Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 11:15:15 +0100 Subject: [PATCH 03/81] REPLACE: tests/typing/test_typing.py -> changed YamlGlobalConfig to GlobalConfig --- tests/typing/test_typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/typing/test_typing.py b/tests/typing/test_typing.py index d2be37ff..3ca5cfa7 100644 --- a/tests/typing/test_typing.py +++ b/tests/typing/test_typing.py @@ -46,7 +46,7 @@ def test_yaml_data_types() -> None: YamlColumns, YamlColumnsEncoder, YamlConfigDict, - YamlGlobalParams, + GlobalParams, YamlSchema, YamlSplit, YamlSplitConfigDict, From 5ba13f00a9f4cd936db68d7f0ac254cc05f617f3 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 11:20:19 +0100 Subject: [PATCH 04/81] REPLACE: src/stimulus/typing/__init__.py -> Changed YamlColumnsEncoder to ColumnsEncoder --- src/stimulus/typing/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/stimulus/typing/__init__.py b/src/stimulus/typing/__init__.py index 2038a5c3..71034199 100644 --- a/src/stimulus/typing/__init__.py +++ b/src/stimulus/typing/__init__.py @@ -37,7 +37,7 @@ from stimulus.utils.performance import Performance from stimulus.utils.yaml_data import ( YamlColumns, - YamlColumnsEncoder, + ColumnsEncoder, YamlConfigDict, GlobalParams, YamlSchema, @@ -80,7 +80,7 @@ YamlData: TypeAlias = ( YamlColumns - | YamlColumnsEncoder + | ColumnsEncoder | YamlConfigDict | GlobalParams | YamlSchema From 1863ced36f194c57517f8a5567fdae9b74eed67a Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 11:20:52 +0100 Subject: [PATCH 05/81] REPLACE: src/stimulus/utils/yaml_data.py -> Changed YamlColumnsEncoder to ColumnsEncoder --- src/stimulus/utils/yaml_data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/stimulus/utils/yaml_data.py b/src/stimulus/utils/yaml_data.py index 251e0c59..5cd47186 100644 --- a/src/stimulus/utils/yaml_data.py +++ b/src/stimulus/utils/yaml_data.py @@ -12,7 +12,7 @@ class GlobalParams(BaseModel): seed: int -class YamlColumnsEncoder(BaseModel): +class ColumnsEncoder(BaseModel): """Model for column encoder configuration.""" name: str @@ -27,7 +27,7 @@ class YamlColumns(BaseModel): column_name: str column_type: str data_type: str - encoder: list[YamlColumnsEncoder] + encoder: list[ColumnsEncoder] class YamlTransformColumnsTransformation(BaseModel): From e9c40ebd68e4309aca7fad73635f21166994591c Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 11:21:22 +0100 Subject: [PATCH 06/81] REPLACE: tests/typing/test_typing.py -> Changed YamlColumnsEncoder to ColumnsEncoder --- tests/typing/test_typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/typing/test_typing.py b/tests/typing/test_typing.py index 3ca5cfa7..2b27e9f3 100644 --- a/tests/typing/test_typing.py +++ b/tests/typing/test_typing.py @@ -44,7 +44,7 @@ def test_yaml_data_types() -> None: try: from stimulus.typing import ( YamlColumns, - YamlColumnsEncoder, + ColumnsEncoder, YamlConfigDict, GlobalParams, YamlSchema, From f55b7f68ab5152b77f889f099f850cc8bfd6d9ed Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 11:22:50 +0100 Subject: [PATCH 07/81] REPLACE: tests/typing/test_typing.py -> Changed YamlColumns to Columns --- tests/typing/test_typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/typing/test_typing.py b/tests/typing/test_typing.py index 2b27e9f3..d1ccde9e 100644 --- a/tests/typing/test_typing.py +++ b/tests/typing/test_typing.py @@ -43,7 +43,7 @@ def test_yaml_data_types() -> None: """Test the YAML data types.""" try: from stimulus.typing import ( - YamlColumns, + Columns, ColumnsEncoder, YamlConfigDict, GlobalParams, From 11062b4c0ec3e0006b75b78e2ef880809b3a143f Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 11:28:17 +0100 Subject: [PATCH 08/81] REPLACE: tests/data/loaders.py -> Changed YamlColumns to Columns --- src/stimulus/data/loaders.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/stimulus/data/loaders.py b/src/stimulus/data/loaders.py index 4bea4a1b..b9cc0bcb 100644 --- a/src/stimulus/data/loaders.py +++ b/src/stimulus/data/loaders.py @@ -30,16 +30,15 @@ def __init__(self, seed: Optional[float] = None) -> None: self.seed = seed def initialize_column_encoders_from_config( - self, column_config: yaml_data.YamlColumns + self, column_config: yaml_data.Columns ) -> None: """Build the loader from a config dictionary. Args: - column_config (yaml_data.YamlColumns): Configuration dictionary containing field names (column_name) and their encoder specifications. + column_config (yaml_data.Columns): Configuration dictionary containing field names (column_name) and their encoder specifications. """ for field in column_config: - encoder = self.get_encoder( - field.encoder[0].name, field.encoder[0].params) + encoder = self.get_encoder(field.encoder[0].name, field.encoder[0].params) self.set_encoder_as_attribute(field.column_name, encoder) def get_function_encode_all(self, field_name: str) -> Any: @@ -276,6 +275,5 @@ def initialize_splitter_from_config( Args: split_config (yaml_data.YamlSplitConfigDict): Configuration dictionary containing split configurations. """ - splitter = self.get_splitter( - split_config.split_method, split_config.params) + splitter = self.get_splitter(split_config.split_method, split_config.params) self.set_splitter_as_attribute(splitter) From c370891772125addf85a00857da6959ee19632c5 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 11:28:58 +0100 Subject: [PATCH 09/81] REPLACE: tests/typing/__init__.py -> Changed YamlColumns to Columns --- src/stimulus/typing/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/stimulus/typing/__init__.py b/src/stimulus/typing/__init__.py index 71034199..6cd035f1 100644 --- a/src/stimulus/typing/__init__.py +++ b/src/stimulus/typing/__init__.py @@ -36,7 +36,7 @@ ) from stimulus.utils.performance import Performance from stimulus.utils.yaml_data import ( - YamlColumns, + Columns, ColumnsEncoder, YamlConfigDict, GlobalParams, @@ -79,7 +79,7 @@ # utils/yaml_data.py YamlData: TypeAlias = ( - YamlColumns + Columns | ColumnsEncoder | YamlConfigDict | GlobalParams From 61e578771143b774e6cf6163db8dab0836cabe67 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 11:29:31 +0100 Subject: [PATCH 10/81] REPLACE: src/stimulus/utils/yaml_data.py -> Changed YamlColumns to Columns --- src/stimulus/utils/yaml_data.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/stimulus/utils/yaml_data.py b/src/stimulus/utils/yaml_data.py index 5cd47186..11904766 100644 --- a/src/stimulus/utils/yaml_data.py +++ b/src/stimulus/utils/yaml_data.py @@ -21,7 +21,7 @@ class ColumnsEncoder(BaseModel): ] # Allow both string and list values -class YamlColumns(BaseModel): +class Columns(BaseModel): """Model for column configuration.""" column_name: str @@ -106,7 +106,7 @@ class YamlConfigDict(BaseModel): """Model for main YAML configuration.""" global_params: GlobalParams - columns: list[YamlColumns] + columns: list[Columns] transforms: list[YamlTransform] split: list[YamlSplit] @@ -116,7 +116,7 @@ class YamlSplitConfigDict(BaseModel): """Model for sub-configuration generated from main config.""" global_params: GlobalParams - columns: list[YamlColumns] + columns: list[Columns] transforms: list[YamlTransform] split: YamlSplit @@ -125,7 +125,7 @@ class YamlSplitTransformDict(BaseModel): """Model for sub-configuration generated from main config.""" global_params: GlobalParams - columns: list[YamlColumns] + columns: list[Columns] transforms: YamlTransform split: YamlSplit From 3c45094917125a03ff5f1cfaae34421800d16dfc Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 11:35:35 +0100 Subject: [PATCH 11/81] REPLACE: src/stimulus/utils/yaml_data.py -> Changed YamlTransformColumnsTransformation to TransformColumnsTransformation --- src/stimulus/utils/yaml_data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/stimulus/utils/yaml_data.py b/src/stimulus/utils/yaml_data.py index 11904766..a1c58015 100644 --- a/src/stimulus/utils/yaml_data.py +++ b/src/stimulus/utils/yaml_data.py @@ -30,7 +30,7 @@ class Columns(BaseModel): encoder: list[ColumnsEncoder] -class YamlTransformColumnsTransformation(BaseModel): +class TransformColumnsTransformation(BaseModel): """Model for column transformation configuration.""" name: str @@ -43,7 +43,7 @@ class YamlTransformColumns(BaseModel): """Model for transform columns configuration.""" column_name: str - transformations: list[YamlTransformColumnsTransformation] + transformations: list[TransformColumnsTransformation] class YamlTransform(BaseModel): From ec66cdf8c39e5ba01da2d603b25ca8c1eb846938 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 11:36:05 +0100 Subject: [PATCH 12/81] REPLACE: src/stimulus/typing/__init__.py -> Changed YamlTransformColumnsTransformation to TransformColumnsTransformation --- src/stimulus/typing/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/stimulus/typing/__init__.py b/src/stimulus/typing/__init__.py index 6cd035f1..2cb506ad 100644 --- a/src/stimulus/typing/__init__.py +++ b/src/stimulus/typing/__init__.py @@ -46,7 +46,7 @@ YamlSplitTransformDict, YamlTransform, YamlTransformColumns, - YamlTransformColumnsTransformation, + TransformColumnsTransformation, ) from stimulus.utils.yaml_model_schema import ( CustomTunableParameter, @@ -88,5 +88,5 @@ | YamlSplitConfigDict | YamlTransform | YamlTransformColumns - | YamlTransformColumnsTransformation + | TransformColumnsTransformation ) From 4746659d91551263ed3f47b7e39b517f2100caae Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 11:36:45 +0100 Subject: [PATCH 13/81] REPLACE: tests/data/test_data_handlers.py -> Changed YamlTransformColumnsTransformation to TransformColumnsTransformation --- tests/data/test_data_handlers.py | 34 +++++++++++--------------------- 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/tests/data/test_data_handlers.py b/tests/data/test_data_handlers.py index fd765192..0e7672e5 100644 --- a/tests/data/test_data_handlers.py +++ b/tests/data/test_data_handlers.py @@ -18,7 +18,7 @@ YamlSplitTransformDict, YamlTransform, YamlTransformColumns, - YamlTransformColumnsTransformation, + TransformColumnsTransformation, generate_split_configs, generate_split_transform_configs, ) @@ -70,8 +70,7 @@ def generate_sub_configs(base_config: YamlConfigDict) -> list[YamlConfigDict]: Returns: list[YamlConfigDict]: List of generated configurations """ - split_configs: list[YamlSplitConfigDict] = generate_split_configs( - base_config) + split_configs: list[YamlSplitConfigDict] = generate_split_configs(base_config) split_transform_list: list[YamlSplitTransformDict] = [] for split in split_configs: split_transform_list.extend(generate_split_transform_configs(split)) @@ -104,8 +103,7 @@ def encoder_loader(generate_sub_configs: list[YamlConfigDict]) -> loaders.Encode experiments.EncoderLoader: Initialized encoder loader """ loader = loaders.EncoderLoader() - loader.initialize_column_encoders_from_config( - generate_sub_configs[0].columns) + loader.initialize_column_encoders_from_config(generate_sub_configs[0].columns) return loader @@ -238,7 +236,7 @@ def test_transform_manager_transform_column() -> None: YamlTransformColumns( column_name="test_col", transformations=[ - YamlTransformColumnsTransformation( + TransformColumnsTransformation( name="GaussianNoise", params={"std": 0.1}, ), @@ -246,12 +244,10 @@ def test_transform_manager_transform_column() -> None: ), ], ) - transform_loader.initialize_column_data_transformers_from_config( - dummy_config) + transform_loader.initialize_column_data_transformers_from_config(dummy_config) manager = TransformManager(transform_loader) data = [1, 2, 3] - transformed, added_row = manager.transform_column( - "test_col", "GaussianNoise", data) + transformed, added_row = manager.transform_column("test_col", "GaussianNoise", data) assert len(transformed) == len(data) assert added_row is False @@ -334,28 +330,22 @@ def test_dataset_processor_apply_transformation_group( transform_manager=TransformManager(transform_loader) ) - assert processor.data["age"].to_list( - ) != processor_control.data["age"].to_list() - assert processor.data["fare"].to_list( - ) != processor_control.data["fare"].to_list() + assert processor.data["age"].to_list() != processor_control.data["age"].to_list() + assert processor.data["fare"].to_list() != processor_control.data["fare"].to_list() assert ( - processor.data["parch"].to_list( - ) == processor_control.data["parch"].to_list() + processor.data["parch"].to_list() == processor_control.data["parch"].to_list() ) assert ( - processor.data["sibsp"].to_list( - ) == processor_control.data["sibsp"].to_list() + processor.data["sibsp"].to_list() == processor_control.data["sibsp"].to_list() ) assert ( - processor.data["pclass"].to_list( - ) == processor_control.data["pclass"].to_list() + processor.data["pclass"].to_list() == processor_control.data["pclass"].to_list() ) assert ( processor.data["embarked"].to_list() == processor_control.data["embarked"].to_list() ) - assert processor.data["sex"].to_list( - ) == processor_control.data["sex"].to_list() + assert processor.data["sex"].to_list() == processor_control.data["sex"].to_list() # Test DatasetLoader From 1f99221f8b086573903a132b0dc5ad4823b45438 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 11:37:47 +0100 Subject: [PATCH 14/81] REPLACE: tests/typing/test_typing.py -> Changed YamlTransformColumnsTransformation to TransformColumnsTransformation --- tests/typing/test_typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/typing/test_typing.py b/tests/typing/test_typing.py index d1ccde9e..30f68e7f 100644 --- a/tests/typing/test_typing.py +++ b/tests/typing/test_typing.py @@ -53,7 +53,7 @@ def test_yaml_data_types() -> None: YamlSplitTransformDict, YamlTransform, YamlTransformColumns, - YamlTransformColumnsTransformation, + TransformColumnsTransformation, ) except ImportError: pytest.fail("Failed to import YAML Data types") From 3ffe7104cb9dd70ebc4ecdffa2e878632ebdfcff Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 11:41:48 +0100 Subject: [PATCH 15/81] REPLACE: src/stimulus/utils/yaml_data.py -> Changed YamlTransformColumns to TransformColumns --- src/stimulus/utils/yaml_data.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/stimulus/utils/yaml_data.py b/src/stimulus/utils/yaml_data.py index a1c58015..dca4baf3 100644 --- a/src/stimulus/utils/yaml_data.py +++ b/src/stimulus/utils/yaml_data.py @@ -39,7 +39,7 @@ class TransformColumnsTransformation(BaseModel): ] # Allow both list and float values -class YamlTransformColumns(BaseModel): +class TransformColumns(BaseModel): """Model for transform columns configuration.""" column_name: str @@ -50,13 +50,13 @@ class YamlTransform(BaseModel): """Model for transform configuration.""" transformation_name: str - columns: list[YamlTransformColumns] + columns: list[TransformColumns] @field_validator("columns") @classmethod def validate_param_lists_across_columns( - cls, columns: list[YamlTransformColumns] - ) -> list[YamlTransformColumns]: + cls, columns: list[TransformColumns] + ) -> list[TransformColumns]: """Validate that parameter lists across columns have consistent lengths. Args: From 5d0cfa24a74be2631719a1cf6cd63eeae9b82526 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 11:42:22 +0100 Subject: [PATCH 16/81] REPLACE: src/stimulus/typing/__init__.py -> Changed YamlTransformColumns to TransformColumns --- src/stimulus/typing/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/stimulus/typing/__init__.py b/src/stimulus/typing/__init__.py index 2cb506ad..eb79f502 100644 --- a/src/stimulus/typing/__init__.py +++ b/src/stimulus/typing/__init__.py @@ -45,7 +45,7 @@ YamlSplitConfigDict, YamlSplitTransformDict, YamlTransform, - YamlTransformColumns, + TransformColumns, TransformColumnsTransformation, ) from stimulus.utils.yaml_model_schema import ( @@ -87,6 +87,6 @@ | YamlSplit | YamlSplitConfigDict | YamlTransform - | YamlTransformColumns + | TransformColumns | TransformColumnsTransformation ) From edede13649360e947c1570601aeb785f50c3c143 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 11:42:56 +0100 Subject: [PATCH 17/81] REPLACE: tests/data/test_data_handlers.py -> Changed YamlTransformColumns to TransformColumns --- tests/data/test_data_handlers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/data/test_data_handlers.py b/tests/data/test_data_handlers.py index 0e7672e5..2bd7c596 100644 --- a/tests/data/test_data_handlers.py +++ b/tests/data/test_data_handlers.py @@ -17,7 +17,7 @@ YamlSplitConfigDict, YamlSplitTransformDict, YamlTransform, - YamlTransformColumns, + TransformColumns, TransformColumnsTransformation, generate_split_configs, generate_split_transform_configs, @@ -233,7 +233,7 @@ def test_transform_manager_transform_column() -> None: dummy_config = YamlTransform( transformation_name="GaussianNoise", columns=[ - YamlTransformColumns( + TransformColumns( column_name="test_col", transformations=[ TransformColumnsTransformation( From 53fe62e1dcac45fbd3622d9162c0a0cca12b4e6a Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 11:43:25 +0100 Subject: [PATCH 18/81] REPLACE: tests/typing/test_typing.py -> Changed YamlTransformColumns to TransformColumns --- tests/typing/test_typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/typing/test_typing.py b/tests/typing/test_typing.py index 30f68e7f..6b849d18 100644 --- a/tests/typing/test_typing.py +++ b/tests/typing/test_typing.py @@ -52,7 +52,7 @@ def test_yaml_data_types() -> None: YamlSplitConfigDict, YamlSplitTransformDict, YamlTransform, - YamlTransformColumns, + TransformColumns, TransformColumnsTransformation, ) except ImportError: From fd8f92a3a3931b9c925941d86a326e7898ff76a6 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 13:03:18 +0100 Subject: [PATCH 19/81] REPLACE: src/stimulus/utils/yaml_data.py -> Changed YamlTransform to Transform --- src/stimulus/utils/yaml_data.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/stimulus/utils/yaml_data.py b/src/stimulus/utils/yaml_data.py index dca4baf3..d7d139c0 100644 --- a/src/stimulus/utils/yaml_data.py +++ b/src/stimulus/utils/yaml_data.py @@ -46,7 +46,7 @@ class TransformColumns(BaseModel): transformations: list[TransformColumnsTransformation] -class YamlTransform(BaseModel): +class Transform(BaseModel): """Model for transform configuration.""" transformation_name: str @@ -107,7 +107,7 @@ class YamlConfigDict(BaseModel): global_params: GlobalParams columns: list[Columns] - transforms: list[YamlTransform] + transforms: list[Transform] split: list[YamlSplit] @@ -117,7 +117,7 @@ class YamlSplitConfigDict(BaseModel): global_params: GlobalParams columns: list[Columns] - transforms: list[YamlTransform] + transforms: list[Transform] split: YamlSplit @@ -126,7 +126,7 @@ class YamlSplitTransformDict(BaseModel): global_params: GlobalParams columns: list[Columns] - transforms: YamlTransform + transforms: Transform split: YamlSplit @@ -143,8 +143,8 @@ class YamlSplitSchema(BaseModel): def extract_transform_parameters_at_index( - transform: YamlTransform, index: int = 0 -) -> YamlTransform: + transform: Transform, index: int = 0 +) -> Transform: """Get a transform with parameters at the specified index. Args: @@ -155,7 +155,7 @@ def extract_transform_parameters_at_index( A new transform with single parameter values at the specified index """ # Create a copy of the transform - new_transform = YamlTransform(**transform.model_dump()) + new_transform = Transform(**transform.model_dump()) # Process each column and transformation for column in new_transform.columns: @@ -174,8 +174,8 @@ def extract_transform_parameters_at_index( def expand_transform_parameter_combinations( - transform: YamlTransform, -) -> list[YamlTransform]: + transform: Transform, +) -> list[Transform]: """Get all possible transforms by extracting parameters at each valid index. For a transform with parameter lists, creates multiple new transforms, each containing @@ -213,8 +213,8 @@ def expand_transform_parameter_combinations( def expand_transform_list_combinations( - transform_list: list[YamlTransform], -) -> list[YamlTransform]: + transform_list: list[Transform], +) -> list[Transform]: """Expands a list of transforms into all possible parameter combinations. Takes a list of transforms where each transform may contain parameter lists, @@ -223,11 +223,11 @@ def expand_transform_list_combinations( create two transforms: one with 0.1/1 and another with 0.2/2. Args: - transform_list: A list of YamlTransform objects containing parameter lists + transform_list: A list of Transform objects containing parameter lists that need to be expanded into individual transforms. Returns: - list[YamlTransform]: A flattened list of transforms where each transform + list[Transform]: A flattened list of transforms where each transform has single parameter values instead of parameter lists. The length of the returned list will be the sum of the number of parameter combinations for each input transform. From 93bfe2e39b46f12ef458bb8a86b48c6877486be8 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 13:03:49 +0100 Subject: [PATCH 20/81] REPLACE: src/stimulus/data/loaders.py -> Changed YamlTransform to Transform --- src/stimulus/data/loaders.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/stimulus/data/loaders.py b/src/stimulus/data/loaders.py index b9cc0bcb..6861bcfe 100644 --- a/src/stimulus/data/loaders.py +++ b/src/stimulus/data/loaders.py @@ -166,12 +166,12 @@ def set_data_transformer_as_attribute( field_value[data_transformer.__class__.__name__] = data_transformer def initialize_column_data_transformers_from_config( - self, transform_config: yaml_data.YamlTransform + self, transform_config: yaml_data.Transform ) -> None: """Build the loader from a config dictionary. Args: - transform_config (yaml_data.YamlTransform): Configuration dictionary containing transforms configurations. + transform_config (yaml_data.Transform): Configuration dictionary containing transforms configurations. Example: Given a YAML config like: From 0f7da0d63e9dfdd4e0a29757daa805e29d3d532c Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 13:04:20 +0100 Subject: [PATCH 21/81] REPLACE: src/stimulus/typing/__init__.py -> Changed YamlTransform to Transform --- src/stimulus/typing/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/stimulus/typing/__init__.py b/src/stimulus/typing/__init__.py index eb79f502..77ca6ba3 100644 --- a/src/stimulus/typing/__init__.py +++ b/src/stimulus/typing/__init__.py @@ -44,7 +44,7 @@ YamlSplit, YamlSplitConfigDict, YamlSplitTransformDict, - YamlTransform, + Transform, TransformColumns, TransformColumnsTransformation, ) @@ -86,7 +86,7 @@ | YamlSchema | YamlSplit | YamlSplitConfigDict - | YamlTransform + | Transform | TransformColumns | TransformColumnsTransformation ) From b9f1b3a2183b1297516df3f792dbcdd97806be1d Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 13:04:51 +0100 Subject: [PATCH 22/81] REPLACE: tests/data/test_data_handlers.py -> Changed YamlTransform to Transform --- tests/data/test_data_handlers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/data/test_data_handlers.py b/tests/data/test_data_handlers.py index 2bd7c596..4c24c8ed 100644 --- a/tests/data/test_data_handlers.py +++ b/tests/data/test_data_handlers.py @@ -16,7 +16,7 @@ YamlConfigDict, YamlSplitConfigDict, YamlSplitTransformDict, - YamlTransform, + Transform, TransformColumns, TransformColumnsTransformation, generate_split_configs, @@ -230,7 +230,7 @@ def test_transform_manager_initialize_transforms() -> None: def test_transform_manager_transform_column() -> None: """Test column transformation.""" transform_loader = loaders.TransformLoader() - dummy_config = YamlTransform( + dummy_config = Transform( transformation_name="GaussianNoise", columns=[ TransformColumns( From ac3aa009d94bb751f3ca95320c77373563bca025 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 13:07:39 +0100 Subject: [PATCH 23/81] REPLACE: tests/typing/test_typing.py -> Changed YamlTransform to Transform --- tests/typing/test_typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/typing/test_typing.py b/tests/typing/test_typing.py index 6b849d18..b0db3cd2 100644 --- a/tests/typing/test_typing.py +++ b/tests/typing/test_typing.py @@ -51,7 +51,7 @@ def test_yaml_data_types() -> None: YamlSplit, YamlSplitConfigDict, YamlSplitTransformDict, - YamlTransform, + Transform, TransformColumns, TransformColumnsTransformation, ) From b1cc3aabaf3714bdd720b7228b37d495530d8cea Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 13:08:02 +0100 Subject: [PATCH 24/81] REPLACE: tests/utils/test_data_yaml.py -> Changed YamlTransform to Transform --- tests/utils/test_data_yaml.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/utils/test_data_yaml.py b/tests/utils/test_data_yaml.py index a96b5458..9b255728 100644 --- a/tests/utils/test_data_yaml.py +++ b/tests/utils/test_data_yaml.py @@ -62,8 +62,7 @@ def test_sub_config_validation( load_split_config_yaml_from_file: YamlConfigDict, ) -> None: """Test sub-config validation.""" - split_config = generate_split_transform_configs( - load_split_config_yaml_from_file)[0] + split_config = generate_split_transform_configs(load_split_config_yaml_from_file)[0] print(f"{split_config=}") YamlSplitTransformDict.model_validate(split_config) @@ -77,8 +76,8 @@ def test_expand_transform_parameter_combinations( results = yaml_data.expand_transform_parameter_combinations(transform) assert len(results) == 1 # Only one transform returned assert isinstance( - results[0], yaml_data.YamlTransform - ) # Should return YamlTransform objects + results[0], yaml_data.Transform + ) # Should return Transform objects def test_expand_transform_list_combinations( @@ -90,9 +89,9 @@ def test_expand_transform_list_combinations( ) # 4 combinations from first transform x 2 from second assert len(results) == 8 - # Each result should be a YamlTransform + # Each result should be a Transform for result in results: - assert isinstance(result, yaml_data.YamlTransform) + assert isinstance(result, yaml_data.Transform) assert isinstance(result.transformation_name, str) assert isinstance(result.columns, list) From 1c7e996c8c41db85f5cdcfb3e75bbdc0cf7801ea Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 13:57:35 +0100 Subject: [PATCH 25/81] REPLACE: src/stimulus/cli/check_model.py -> Changed YamlSplit* to Split* --- src/stimulus/cli/check_model.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/stimulus/cli/check_model.py b/src/stimulus/cli/check_model.py index b3ef0150..70c1d0a1 100755 --- a/src/stimulus/cli/check_model.py +++ b/src/stimulus/cli/check_model.py @@ -121,7 +121,7 @@ def main( """ with open(data_config_path) as file: data_config = yaml.safe_load(file) - data_config = yaml_data.YamlSplitTransformDict(**data_config) + data_config = yaml_data.SplitTransformDict(**data_config) with open(model_config_path) as file: model_config = yaml.safe_load(file) @@ -138,8 +138,7 @@ def main( logger.info("Model class loaded successfully.") - ray_config_loader = yaml_model_schema.YamlRayConfigLoader( - model=model_config) + ray_config_loader = yaml_model_schema.YamlRayConfigLoader(model=model_config) ray_config_dict = ray_config_loader.get_config().model_dump() ray_config_model = ray_config_loader.get_config() From 940bb59382c986fc13dfd6dbf66e9f1af7e1cabd Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 13:58:02 +0100 Subject: [PATCH 26/81] REPLACE: src/stimulus/cli/split_csv.py -> Changed YamlSplit* to Split* --- src/stimulus/cli/split_csv.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/stimulus/cli/split_csv.py b/src/stimulus/cli/split_csv.py index 44afc03d..756f55b6 100755 --- a/src/stimulus/cli/split_csv.py +++ b/src/stimulus/cli/split_csv.py @@ -7,7 +7,7 @@ from stimulus.data.data_handlers import DatasetProcessor, SplitManager from stimulus.data.loaders import SplitLoader -from stimulus.utils.yaml_data import YamlSplitConfigDict +from stimulus.utils.yaml_data import SplitConfigDict def get_args() -> argparse.Namespace: @@ -66,7 +66,7 @@ def main( # create a split manager from the config split_config = processor.dataset_manager.config.split with open(config_yaml) as f: - yaml_config = YamlSplitConfigDict(**yaml.safe_load(f)) + yaml_config = SplitConfigDict(**yaml.safe_load(f)) split_loader = SplitLoader(seed=yaml_config.global_params.seed) split_loader.initialize_splitter_from_config(split_config) split_manager = SplitManager(split_loader) From 9e4f3686cc1db8eb18912acbf6de9b1aaf8ce2f2 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 13:58:16 +0100 Subject: [PATCH 27/81] REPLACE: src/stimulus/cli/split_split.py -> Changed YamlSplit* to Split* --- src/stimulus/cli/split_split.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/stimulus/cli/split_split.py b/src/stimulus/cli/split_split.py index 63e265d9..fa53b4ef 100755 --- a/src/stimulus/cli/split_split.py +++ b/src/stimulus/cli/split_split.py @@ -13,7 +13,7 @@ from stimulus.utils.yaml_data import ( YamlConfigDict, - YamlSplitConfigDict, + SplitConfigDict, check_yaml_schema, dump_yaml_list_into_files, generate_split_configs, @@ -70,7 +70,7 @@ def main(config_yaml: str, out_dir_path: str) -> None: check_yaml_schema(yaml_config_dict) # generate the yaml files per split - split_configs: list[YamlSplitConfigDict] = generate_split_configs(yaml_config_dict) + split_configs: list[SplitConfigDict] = generate_split_configs(yaml_config_dict) # dump all the YAML configs into files dump_yaml_list_into_files(split_configs, out_dir_path, "test_split") From a4a9e6bb7e1da3eddad26d02a45675df2be3110f Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 13:58:34 +0100 Subject: [PATCH 28/81] REPLACE: src/stimulus/cli/split_transforms.py -> Changed YamlSplit* to Split* --- src/stimulus/cli/split_transforms.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/stimulus/cli/split_transforms.py b/src/stimulus/cli/split_transforms.py index f3e57717..10a05c3c 100644 --- a/src/stimulus/cli/split_transforms.py +++ b/src/stimulus/cli/split_transforms.py @@ -12,8 +12,8 @@ import yaml from stimulus.utils.yaml_data import ( - YamlSplitConfigDict, - YamlSplitTransformDict, + SplitConfigDict, + SplitTransformDict, dump_yaml_list_into_files, generate_split_transform_configs, ) @@ -60,10 +60,10 @@ def main(config_yaml: str, out_dir_path: str) -> None: with open(config_yaml) as conf_file: yaml_config = yaml.safe_load(conf_file) - yaml_config_dict: YamlSplitConfigDict = YamlSplitConfigDict(**yaml_config) + yaml_config_dict: SplitConfigDict = SplitConfigDict(**yaml_config) # Generate the yaml files for each transform - split_transform_configs: list[YamlSplitTransformDict] = ( + split_transform_configs: list[SplitTransformDict] = ( generate_split_transform_configs(yaml_config_dict) ) From d9baee106eb94670aa2ea8f836cbe9bc6bfebf30 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 13:59:09 +0100 Subject: [PATCH 29/81] REPLACE: src/stimulus/cli/transform_csv.py -> Changed YamlSplit* to Split* --- src/stimulus/cli/transform_csv.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/stimulus/cli/transform_csv.py b/src/stimulus/cli/transform_csv.py index 55b25882..15c06e68 100755 --- a/src/stimulus/cli/transform_csv.py +++ b/src/stimulus/cli/transform_csv.py @@ -7,7 +7,7 @@ from stimulus.data.data_handlers import DatasetProcessor, TransformManager from stimulus.data.loaders import TransformLoader -from stimulus.utils.yaml_data import YamlSplitConfigDict +from stimulus.utils.yaml_data import SplitConfigDict def get_args() -> argparse.Namespace: @@ -55,7 +55,7 @@ def main(data_csv: str, config_yaml: str, out_path: str) -> None: # initialize the transform manager transform_config = processor.dataset_manager.config.transforms with open(config_yaml) as f: - yaml_config = YamlSplitConfigDict(**yaml.safe_load(f)) + yaml_config = SplitConfigDict(**yaml.safe_load(f)) transform_loader = TransformLoader(seed=yaml_config.global_params.seed) print(transform_config) transform_loader.initialize_column_data_transformers_from_config(transform_config) From 7a99a54a01d5c4b364638edb408690e18e4afac9 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 13:59:26 +0100 Subject: [PATCH 30/81] REPLACE: src/stimulus/cli/tuning.py -> Changed YamlSplit* to Split* --- src/stimulus/cli/tuning.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/stimulus/cli/tuning.py b/src/stimulus/cli/tuning.py index 5beecdb4..75957e25 100755 --- a/src/stimulus/cli/tuning.py +++ b/src/stimulus/cli/tuning.py @@ -151,7 +151,7 @@ def get_args() -> argparse.Namespace: def main( model_path: str, data_path: str, - data_config: yaml_data.YamlSplitTransformDict, + data_config: yaml_data.SplitTransformDict, model_config_path: str, initial_weights: str | None = None, # noqa: ARG001 ray_results_dirpath: str | None = None, @@ -167,7 +167,7 @@ def main( Args: data_path: Path to input data file. model_path: Path to model file. - data_config: A YamlSplitTransformObject + data_config: A SplitTransformObject model_config_path: Path to model config file. initial_weights: Optional path to initial weights. ray_results_dirpath: Directory for ray results. @@ -179,8 +179,7 @@ def main( """ with open(model_config_path) as file: model_config_dict: dict[str, Any] = yaml.safe_load(file) - model_config: yaml_model_schema.Model = yaml_model_schema.Model( - **model_config_dict) + model_config: yaml_model_schema.Model = yaml_model_schema.Model(**model_config_dict) encoder_loader = loaders.EncoderLoader() encoder_loader.initialize_column_encoders_from_config( @@ -189,8 +188,7 @@ def main( model_class = launch_utils.import_class_from_file(model_path) - ray_config_loader = yaml_model_schema.YamlRayConfigLoader( - model=model_config) + ray_config_loader = yaml_model_schema.YamlRayConfigLoader(model=model_config) ray_config_model = ray_config_loader.get_config() tuner = raytune_learner.TuneWrapper( @@ -241,10 +239,10 @@ def run() -> None: """Run the model checking script.""" ray.init(address="auto", ignore_reinit_error=True) args = get_args() - # Try to convert the configuration file to a YamlSplitTransformDict - config_dict: yaml_data.YamlSplitTransformDict + # Try to convert the configuration file to a SplitTransformDict + config_dict: yaml_data.SplitTransformDict with open(args.data_config) as f: - config_dict = yaml_data.YamlSplitTransformDict(**yaml.safe_load(f)) + config_dict = yaml_data.SplitTransformDict(**yaml.safe_load(f)) main( data_path=args.data, model_path=args.model, From 5bdeaa2a63097592a6f9bbd439a8486e1ab63e1b Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 14:00:21 +0100 Subject: [PATCH 31/81] REPLACE: src/stimulus/data/data_handlers.py -> Changed YamlSplit* to Split* --- src/stimulus/data/data_handlers.py | 43 ++++++++++++------------------ 1 file changed, 17 insertions(+), 26 deletions(-) diff --git a/src/stimulus/data/data_handlers.py b/src/stimulus/data/data_handlers.py index 4af2f786..bf314303 100644 --- a/src/stimulus/data/data_handlers.py +++ b/src/stimulus/data/data_handlers.py @@ -50,11 +50,11 @@ class DatasetManager: def __init__( self, - config_dict: yaml_data.YamlSplitConfigDict, + config_dict: yaml_data.SplitConfigDict, ) -> None: """Initialize the DatasetManager.""" # self.config = self._load_config(config_path) - self.config: yaml_data.YamlSplitTransformDict = config_dict + self.config: yaml_data.SplitTransformDict = config_dict self.column_categories = self.categorize_columns_by_type() def categorize_columns_by_type(self) -> dict: @@ -95,7 +95,7 @@ def categorize_columns_by_type(self) -> dict: return {"input": input_columns, "label": label_columns, "meta": meta_columns} # TODO: Remove or change this function as the config is now preloaded - def _load_config(self, config_path: str) -> yaml_data.YamlSplitConfigDict: + def _load_config(self, config_path: str) -> yaml_data.SplitConfigDict: """Loads and parses a YAML configuration file. Args: @@ -113,8 +113,8 @@ def _load_config(self, config_path: str) -> yaml_data.YamlSplitConfigDict: with open(config_path) as file: # FIXME: cette fonction est appellé pour test_shuffle_csv et test_tune - return yaml_data.YamlSplitConfigDict(**yaml.safe_load(file)) - return yaml_data.YamlSplitTransformDict(**yaml.safe_load(file)) + return yaml_data.SplitConfigDict(**yaml.safe_load(file)) + return yaml_data.SplitTransformDict(**yaml.safe_load(file)) def get_split_columns(self) -> list[str]: """Get the columns that are used for splitting.""" @@ -190,8 +190,7 @@ def encode_column(self, column_name: str, column_data: list) -> torch.Tensor: >>> print(encoded.shape) torch.Size([2, 4, 4]) # 2 sequences, length 4, one-hot encoded """ - encode_all_function = self.encoder_loader.get_function_encode_all( - column_name) + encode_all_function = self.encoder_loader.get_function_encode_all(column_name) return encode_all_function(column_data) def encode_columns(self, column_data: dict) -> dict: @@ -287,13 +286,13 @@ class DatasetHandler: def __init__( self, - data_config: yaml_data.YamlSplitTransformDict, + data_config: yaml_data.SplitTransformDict, csv_path: str, ) -> None: """Initialize the DatasetHandler with required config. Args: - data_config (yaml_data.YamlSplitTransformDict): A YamlSplitTransformDict object holding the config. + data_config (yaml_data.SplitTransformDict): A SplitTransformDict object holding the config. csv_path (str): Path to the CSV data file. """ self.dataset_manager = DatasetManager(data_config) @@ -370,8 +369,7 @@ def add_split(self, split_manager: SplitManager, *, force: bool = False) -> None split_input_data = self.select_columns(split_columns) # get the split indices - train, validation, test = split_manager.get_split_indices( - split_input_data) + train, validation, test = split_manager.get_split_indices(split_input_data) # add the split column to the data split_column = np.full(len(self.data), -1).astype(int) @@ -422,7 +420,7 @@ class DatasetLoader(DatasetHandler): def __init__( self, - data_config: yaml_data.YamlSplitTransformDict, + data_config: yaml_data.SplitTransformDict, csv_path: str, encoder_loader: loaders.EncoderLoader, split: Union[int, None] = None, @@ -460,10 +458,8 @@ def get_all_items(self) -> tuple[dict, dict, dict]: self.dataset_manager.column_categories["label"], self.dataset_manager.column_categories["meta"], ) - input_data = self.encoder_manager.encode_dataframe( - self.data[input_columns]) - label_data = self.encoder_manager.encode_dataframe( - self.data[label_columns]) + input_data = self.encoder_manager.encode_dataframe(self.data[input_columns]) + label_data = self.encoder_manager.encode_dataframe(self.data[label_columns]) meta_data = {key: self.data[key].to_list() for key in meta_columns} return input_data, label_data, meta_data @@ -481,8 +477,7 @@ def load_csv_per_split(self, csv_path: str, split: int) -> pl.DataFrame: we are gonna load only the relevant data for it. """ if "split" not in self.columns: - raise ValueError( - "The category split is not present in the csv file") + raise ValueError("The category split is not present in the csv file") if split not in [0, 1, 2]: raise ValueError( f"The split value should be 0, 1 or 2. The specified split value is {split}" @@ -519,18 +514,15 @@ def __getitem__( label_data = self.encoder_manager.encode_dataframe( data_at_index[label_columns] ) - meta_data = {key: data_at_index[key].to_list() - for key in meta_columns} + meta_data = {key: data_at_index[key].to_list() for key in meta_columns} elif isinstance(idx, int): # For single row, convert to dict with column names as keys row_dict = dict(zip(self.data.columns, self.data.row(idx))) # Create single-row DataFrames for encoding - input_df = pl.DataFrame( - {col: [row_dict[col]] for col in input_columns}) - label_df = pl.DataFrame( - {col: [row_dict[col]] for col in label_columns}) + input_df = pl.DataFrame({col: [row_dict[col]] for col in input_columns}) + label_df = pl.DataFrame({col: [row_dict[col]] for col in label_columns}) input_data = self.encoder_manager.encode_dataframe(input_df) label_data = self.encoder_manager.encode_dataframe(label_df) @@ -546,7 +538,6 @@ def __getitem__( label_data = self.encoder_manager.encode_dataframe( data_at_index[label_columns] ) - meta_data = {key: data_at_index[key].to_list() - for key in meta_columns} + meta_data = {key: data_at_index[key].to_list() for key in meta_columns} return input_data, label_data, meta_data From 843ed2d6488f8a32ab3b284e724130e13583ae5b Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 14:00:42 +0100 Subject: [PATCH 32/81] REPLACE: src/stimulus/data/handlertorch.py -> Changed YamlSplit* to Split* --- src/stimulus/data/handlertorch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/stimulus/data/handlertorch.py b/src/stimulus/data/handlertorch.py index 3b89140a..4160bae8 100644 --- a/src/stimulus/data/handlertorch.py +++ b/src/stimulus/data/handlertorch.py @@ -5,7 +5,7 @@ from torch.utils.data import Dataset from stimulus.data import data_handlers, loaders -from stimulus.utils.yaml_data import YamlSplitTransformDict +from stimulus.utils.yaml_data import SplitTransformDict class TorchDataset(Dataset): @@ -13,7 +13,7 @@ class TorchDataset(Dataset): def __init__( self, - data_config: YamlSplitTransformDict, + data_config: SplitTransformDict, csv_path: str, encoder_loader: loaders.EncoderLoader, split: Optional[int] = None, @@ -21,7 +21,7 @@ def __init__( """Initialize the TorchDataset. Args: - data_config: A YamlSplitTransformDict holding the configuration. + data_config: A SplitTransformDict holding the configuration. csv_path: Path to the CSV data file encoder_loader: Encoder loader instance split: Optional tuple containing split information From ac93ac7858ea80f1f920950e2bfa4cd5d668c849 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 14:03:09 +0100 Subject: [PATCH 33/81] REPLACE: src/stimulus/data/loaders.py -> Changed YamlSplit* to Split* --- src/stimulus/data/loaders.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/stimulus/data/loaders.py b/src/stimulus/data/loaders.py index 6861bcfe..affa1168 100644 --- a/src/stimulus/data/loaders.py +++ b/src/stimulus/data/loaders.py @@ -268,12 +268,12 @@ def set_splitter_as_attribute(self, splitter: Any) -> None: self.split = splitter def initialize_splitter_from_config( - self, split_config: yaml_data.YamlSplitConfigDict + self, split_config: yaml_data.SplitConfigDict ) -> None: """Build the loader from a config dictionary. Args: - split_config (yaml_data.YamlSplitConfigDict): Configuration dictionary containing split configurations. + split_config (yaml_data.SplitConfigDict): Configuration dictionary containing split configurations. """ splitter = self.get_splitter(split_config.split_method, split_config.params) self.set_splitter_as_attribute(splitter) From ac3b64df1890d2c8b8a14087753b18a5d929f083 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 14:03:32 +0100 Subject: [PATCH 34/81] REPLACE: src/stimulus/data/raytune_learner.py -> Changed YamlSplit* to Split* --- src/stimulus/learner/raytune_learner.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/src/stimulus/learner/raytune_learner.py b/src/stimulus/learner/raytune_learner.py index 95749b57..76e24150 100644 --- a/src/stimulus/learner/raytune_learner.py +++ b/src/stimulus/learner/raytune_learner.py @@ -21,7 +21,7 @@ from stimulus.learner.predict import PredictWrapper from stimulus.utils.generic_utils import set_general_seeds from stimulus.utils.yaml_model_schema import RayTuneModel -from stimulus.utils.yaml_data import YamlSplitTransformDict +from stimulus.utils.yaml_data import SplitTransformDict class CheckpointDict(TypedDict): @@ -36,7 +36,7 @@ class TuneWrapper: def __init__( self, model_config: RayTuneModel, - data_config: YamlSplitTransformDict, + data_config: SplitTransformDict, model_class: nn.Module, data_path: str, encoder_loader: EncoderLoader, @@ -87,8 +87,7 @@ def __init__( # add the data path to the config if not os.path.exists(data_path): - raise ValueError( - "Data path does not exist. Given path:" + data_path) + raise ValueError("Data path does not exist. Given path:" + data_path) self.config["data_path"] = os.path.abspath(data_path) # Set up tune_run path @@ -120,7 +119,7 @@ def __init__( def tuner_initialization( self, - data_config: YamlSplitTransformDict, + data_config: SplitTransformDict, data_path: str, encoder_loader: EncoderLoader, *, @@ -289,8 +288,7 @@ def step(self) -> dict: for _step_size in range(self.step_size): for x, y, _meta in self.training: # the loss dict could be unpacked with ** and the function declaration handle it differently like **kwargs. to be decided, personally find this more clean and understable. - self.model.batch( - x=x, y=y, optimizer=self.optimizer, **self.loss_dict) + self.model.batch(x=x, y=y, optimizer=self.optimizer, **self.loss_dict) return self.objective() def objective(self) -> dict[str, float]: @@ -327,8 +325,7 @@ def export_model(self, export_dir: str | None = None) -> None: """Export model to safetensors format.""" if export_dir is None: return - safe_save_model(self.model, os.path.join( - export_dir, "model.safetensors")) + safe_save_model(self.model, os.path.join(export_dir, "model.safetensors")) def load_checkpoint(self, checkpoint: dict[Any, Any] | None) -> None: """Load model and optimizer state from checkpoint.""" @@ -344,8 +341,7 @@ def load_checkpoint(self, checkpoint: dict[Any, Any] | None) -> None: def save_checkpoint(self, checkpoint_dir: str) -> dict[Any, Any]: """Save model and optimizer state to checkpoint.""" - safe_save_model(self.model, os.path.join( - checkpoint_dir, "model.safetensors")) + safe_save_model(self.model, os.path.join(checkpoint_dir, "model.safetensors")) torch.save( self.optimizer.state_dict(), os.path.join(checkpoint_dir, "optimizer.pt") ) From cb49240806d3b05a55c575eb3264ee8e45af42a8 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 14:06:43 +0100 Subject: [PATCH 35/81] REPLACE: src/stimulus/typing/__init.py -> Changed YamlSplit* to Split* --- src/stimulus/typing/__init__.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/stimulus/typing/__init__.py b/src/stimulus/typing/__init__.py index 77ca6ba3..811a7cf0 100644 --- a/src/stimulus/typing/__init__.py +++ b/src/stimulus/typing/__init__.py @@ -41,9 +41,9 @@ YamlConfigDict, GlobalParams, YamlSchema, - YamlSplit, - YamlSplitConfigDict, - YamlSplitTransformDict, + Split, + SplitConfigDict, + SplitTransformDict, Transform, TransformColumns, TransformColumnsTransformation, @@ -84,8 +84,8 @@ | YamlConfigDict | GlobalParams | YamlSchema - | YamlSplit - | YamlSplitConfigDict + | Split + | SplitConfigDict | Transform | TransformColumns | TransformColumnsTransformation From 6aff6b50c321790f3e078dbb9b6299d44acfc4a1 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 14:07:11 +0100 Subject: [PATCH 36/81] REPLACE: src/stimulus/utils/yaml_data.py -> Changed YamlSplit* to Split* --- src/stimulus/utils/yaml_data.py | 34 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/src/stimulus/utils/yaml_data.py b/src/stimulus/utils/yaml_data.py index d7d139c0..ce71c5ce 100644 --- a/src/stimulus/utils/yaml_data.py +++ b/src/stimulus/utils/yaml_data.py @@ -94,7 +94,7 @@ def validate_param_lists_across_columns( return columns -class YamlSplit(BaseModel): +class Split(BaseModel): """Model for split configuration.""" split_method: str @@ -108,26 +108,26 @@ class YamlConfigDict(BaseModel): global_params: GlobalParams columns: list[Columns] transforms: list[Transform] - split: list[YamlSplit] + split: list[Split] # TODO: Rename this class to SplitConfigDict -class YamlSplitConfigDict(BaseModel): +class SplitConfigDict(BaseModel): """Model for sub-configuration generated from main config.""" global_params: GlobalParams columns: list[Columns] transforms: list[Transform] - split: YamlSplit + split: Split -class YamlSplitTransformDict(BaseModel): +class SplitTransformDict(BaseModel): """Model for sub-configuration generated from main config.""" global_params: GlobalParams columns: list[Columns] transforms: Transform - split: YamlSplit + split: Split class YamlSchema(BaseModel): @@ -136,10 +136,10 @@ class YamlSchema(BaseModel): yaml_conf: YamlConfigDict -class YamlSplitSchema(BaseModel): +class SplitSchema(BaseModel): """Model for validating a Split YAML schema.""" - yaml_conf: YamlSplitConfigDict + yaml_conf: SplitConfigDict def extract_transform_parameters_at_index( @@ -238,7 +238,7 @@ def expand_transform_list_combinations( return sub_transforms -def generate_split_configs(yaml_config: YamlConfigDict) -> list[YamlSplitConfigDict]: +def generate_split_configs(yaml_config: YamlConfigDict) -> list[SplitConfigDict]: """Generates all possible split configuration from a YAML config. Takes a YAML configuration that may contain parameter lists and splits, @@ -273,7 +273,7 @@ def generate_split_configs(yaml_config: YamlConfigDict) -> list[YamlSplitConfigD sub_configs = [] for split in sub_splits: sub_configs.append( - YamlSplitConfigDict( + SplitConfigDict( global_params=yaml_config.global_params, columns=yaml_config.columns, transforms=yaml_config.transforms, @@ -284,8 +284,8 @@ def generate_split_configs(yaml_config: YamlConfigDict) -> list[YamlSplitConfigD def generate_split_transform_configs( - yaml_config: YamlSplitConfigDict, -) -> list[YamlSplitTransformDict]: + yaml_config: SplitConfigDict, +) -> list[SplitTransformDict]: """Generates all the transform configuration for a given split Takes a YAML configuration that may contain a transform or a list of transform, @@ -313,16 +313,14 @@ def generate_split_transform_configs( length will be the product of the number of parameter combinations and the number of splits. """ - if isinstance(yaml_config, dict) and not isinstance( - yaml_config, YamlSplitConfigDict - ): + if isinstance(yaml_config, dict) and not isinstance(yaml_config, SplitConfigDict): raise TypeError("Input must be a list of YamlSubConfigDict") sub_transforms = expand_transform_list_combinations(yaml_config.transforms) - split_transform_config: list[YamlSplitTransformDict] = [] + split_transform_config: list[SplitTransformDict] = [] for transform in sub_transforms: split_transform_config.append( - YamlSplitTransformDict( + SplitTransformDict( global_params=yaml_config.global_params, columns=yaml_config.columns, transforms=transform, @@ -333,7 +331,7 @@ def generate_split_transform_configs( def dump_yaml_list_into_files( - yaml_list: list[YamlSplitConfigDict], + yaml_list: list[SplitConfigDict], directory_path: str, base_name: str, ) -> None: From 1f3307268b39b0c073198d9f6762d14214055822 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 14:07:52 +0100 Subject: [PATCH 37/81] REPLACE: tests/cli/test_check_model.py -> Changed YamlSplit* to Split* --- tests/cli/test_check_model.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/cli/test_check_model.py b/tests/cli/test_check_model.py index 273ab63b..3f5a3014 100644 --- a/tests/cli/test_check_model.py +++ b/tests/cli/test_check_model.py @@ -8,7 +8,7 @@ import ray from stimulus.cli import check_model -from src.stimulus.utils.yaml_data import YamlSplitTransformDict +from src.stimulus.utils.yaml_data import SplitTransformDict @pytest.fixture @@ -63,11 +63,9 @@ def test_check_model_main( ray.init(ignore_reinit_error=True) # Verify all required files exist assert os.path.exists(data_path), f"Data file not found at {data_path}" - assert os.path.exists( - data_config), f"Data config not found at {data_config}" + assert os.path.exists(data_config), f"Data config not found at {data_config}" assert os.path.exists(model_path), f"Model file not found at {model_path}" - assert os.path.exists( - model_config), f"Model config not found at {model_config}" + assert os.path.exists(model_config), f"Model config not found at {model_config}" try: config_dict: yam From d661d30923c05a912babb90eabb0fabd854f5ab5 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 14:08:16 +0100 Subject: [PATCH 38/81] REPLACE: tests/cli/test_shuffle_csv.py -> Changed YamlSplit* to Split* --- tests/cli/test_shuffle_csv.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/cli/test_shuffle_csv.py b/tests/cli/test_shuffle_csv.py index 2c95d7f0..5d1dbbd3 100644 --- a/tests/cli/test_shuffle_csv.py +++ b/tests/cli/test_shuffle_csv.py @@ -9,7 +9,7 @@ import pytest from src.stimulus.cli.shuffle_csv import main -from src.stimulus.utils.yaml_data import YamlSplitTransformDict +from src.stimulus.utils.yaml_data import SplitTransformDict # Fixtures @@ -47,14 +47,12 @@ def test_shuffle_csv( with open(yaml_path) as f: if error: with pytest.raises(error): # type: ignore[call-overload] - config_dict: YamlSplitTransformDict = YamlSplitTransformDict( + config_dict: SplitTransformDict = SplitTransformDict( **yaml.safe_load(f) ) main(csv_path, config_dict, str(tmpdir / "test.csv")) else: - config_dict: YamlSplitTransformDict = YamlSplitTransformDict( - **yaml.safe_load(f) - ) + config_dict: SplitTransformDict = SplitTransformDict(**yaml.safe_load(f)) main(csv_path, config_dict, str(tmpdir / "test.csv")) with open(tmpdir / "test.csv") as file: hash = hashlib.md5(file.read().encode()).hexdigest() # noqa: S324 From a4f0595b81271f93f588f6460ec0a8a06d868387 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 14:08:30 +0100 Subject: [PATCH 39/81] REPLACE: tests/cli/test_tuning.py -> Changed YamlSplit* to Split* --- tests/cli/test_tuning.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/cli/test_tuning.py b/tests/cli/test_tuning.py index 1716de65..2a11e3ac 100644 --- a/tests/cli/test_tuning.py +++ b/tests/cli/test_tuning.py @@ -13,7 +13,7 @@ import yaml from src.stimulus.cli import tuning -from src.stimulus.utils.yaml_data import YamlSplitTransformDict +from src.stimulus.utils.yaml_data import SplitTransformDict @pytest.fixture @@ -122,9 +122,9 @@ def test_tuning_main( assert os.path.exists(model_config), f"Model config not found at {model_config}" try: - config_dict: YamlSplitTransformDict + config_dict: SplitTransformDict with open(data_config) as f: - config_dict = YamlSplitTransformDict(**yaml.safe_load(f)) + config_dict = SplitTransformDict(**yaml.safe_load(f)) results_dir = Path("tests/test_data/titanic/test_results/").resolve() results_dir.mkdir(parents=True, exist_ok=True) From 01e71026bd0108ca23fffbf33db70a6d26d5f689 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 14:09:14 +0100 Subject: [PATCH 40/81] REPLACE: tests/data/test_data_handlers.py -> Changed YamlSplit* to Split* --- tests/data/test_data_handlers.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/data/test_data_handlers.py b/tests/data/test_data_handlers.py index 4c24c8ed..c75d8d03 100644 --- a/tests/data/test_data_handlers.py +++ b/tests/data/test_data_handlers.py @@ -14,8 +14,8 @@ ) from stimulus.utils.yaml_data import ( YamlConfigDict, - YamlSplitConfigDict, - YamlSplitTransformDict, + SplitConfigDict, + SplitTransformDict, Transform, TransformColumns, TransformColumnsTransformation, @@ -70,24 +70,24 @@ def generate_sub_configs(base_config: YamlConfigDict) -> list[YamlConfigDict]: Returns: list[YamlConfigDict]: List of generated configurations """ - split_configs: list[YamlSplitConfigDict] = generate_split_configs(base_config) - split_transform_list: list[YamlSplitTransformDict] = [] + split_configs: list[SplitConfigDict] = generate_split_configs(base_config) + split_transform_list: list[SplitTransformDict] = [] for split in split_configs: split_transform_list.extend(generate_split_transform_configs(split)) return split_transform_list @pytest.fixture -def dump_single_split_config_to_disk() -> YamlSplitTransformDict: +def dump_single_split_config_to_disk() -> SplitTransformDict: """Get path for dumping single split config. Returns: str: Path to dump config file """ - config_dict: YamlSplitTransformDict + config_dict: SplitTransformDict path: str = "tests/test_data/titanic/titanic_sub_config.yaml" with open(path) as f: - config_dict = YamlSplitTransformDict(**yaml.safe_load(f)) + config_dict = SplitTransformDict(**yaml.safe_load(f)) return config_dict @@ -143,7 +143,7 @@ def split_loader(generate_sub_configs: list[YamlConfigDict]) -> loaders.SplitLoa # Test DatasetManager def test_dataset_manager_init( - dump_single_split_config_to_disk: YamlSplitTransformDict, + dump_single_split_config_to_disk: SplitTransformDict, ) -> None: """Test initialization of DatasetManager.""" manager = DatasetManager(dump_single_split_config_to_disk) From 60d146701d5ab33edd1277a385a1009656d06f0c Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 14:09:49 +0100 Subject: [PATCH 41/81] REPLACE: tests/data/test_handlertorch.py -> Changed YamlSplit* to Split* --- tests/data/test_handlertorch.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/data/test_handlertorch.py b/tests/data/test_handlertorch.py index dff24909..0e60d1a1 100644 --- a/tests/data/test_handlertorch.py +++ b/tests/data/test_handlertorch.py @@ -40,12 +40,12 @@ def titanic_yaml_config(titanic_config_path: str) -> dict: dict: Loaded YAML configuration """ with open(titanic_config_path) as file: - return yaml_data.YamlSplitTransformDict(**yaml.safe_load(file)) + return yaml_data.SplitTransformDict(**yaml.safe_load(file)) @pytest.fixture def titanic_encoder_loader( - titanic_yaml_config: yaml_data.YamlSplitTransformDict, + titanic_yaml_config: yaml_data.SplitTransformDict, ) -> loaders.EncoderLoader: """Get Titanic encoder loader.""" loader = loaders.EncoderLoader() @@ -59,9 +59,9 @@ def test_init_handlertorch( titanic_encoder_loader: loaders.EncoderLoader, ) -> None: """Test TorchDataset initialization.""" - data_config: yaml_data.YamlSplitTransformDict + data_config: yaml_data.SplitTransformDict with open(titanic_config_path) as f: - data_config = yaml_data.YamlSplitTransformDict(**yaml.safe_load(f)) + data_config = yaml_data.SplitTransformDict(**yaml.safe_load(f)) handlertorch.TorchDataset( data_config=data_config, csv_path=titanic_csv_path, @@ -81,9 +81,9 @@ def test_len_handlertorch( titanic_csv_path: Path to CSV file titanic_encoder_loader: Encoder loader instance """ - data_config: yaml_data.YamlSplitTransformDict + data_config: yaml_data.SplitTransformDict with open(titanic_config_path) as f: - data_config = yaml_data.YamlSplitTransformDict(**yaml.safe_load(f)) + data_config = yaml_data.SplitTransformDict(**yaml.safe_load(f)) dataset = handlertorch.TorchDataset( data_config=data_config, csv_path=titanic_csv_path, @@ -104,9 +104,9 @@ def test_getitem_handlertorch_slice( titanic_csv_path: Path to CSV file titanic_encoder_loader: Encoder loader instance """ - data_config: yaml_data.YamlSplitTransformDict + data_config: yaml_data.SplitTransformDict with open(titanic_config_path) as f: - data_config = yaml_data.YamlSplitTransformDict(**yaml.safe_load(f)) + data_config = yaml_data.SplitTransformDict(**yaml.safe_load(f)) dataset = handlertorch.TorchDataset( data_config=data_config, csv_path=titanic_csv_path, @@ -128,9 +128,9 @@ def test_getitem_handlertorch_int( titanic_csv_path: Path to CSV file titanic_encoder_loader: Encoder loader instance """ - data_config: yaml_data.YamlSplitTransformDict + data_config: yaml_data.SplitTransformDict with open(titanic_config_path) as f: - data_config = yaml_data.YamlSplitTransformDict(**yaml.safe_load(f)) + data_config = yaml_data.SplitTransformDict(**yaml.safe_load(f)) dataset = handlertorch.TorchDataset( data_config=data_config, csv_path=titanic_csv_path, From f12638a40e314545b917097640d8630be1f3a05c Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 14:10:27 +0100 Subject: [PATCH 42/81] REPLACE: tests/learner/test_raytune_learner.py -> Changed YamlSplit* to Split* --- tests/learner/test_raytune_learner.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/tests/learner/test_raytune_learner.py b/tests/learner/test_raytune_learner.py index 8ade0911..0e8dac17 100644 --- a/tests/learner/test_raytune_learner.py +++ b/tests/learner/test_raytune_learner.py @@ -10,7 +10,7 @@ from stimulus.data.handlertorch import TorchDataset from stimulus.data.loaders import EncoderLoader from stimulus.learner.raytune_learner import TuneWrapper -from stimulus.utils.yaml_data import YamlSplitConfigDict, YamlSplitTransformDict +from stimulus.utils.yaml_data import SplitConfigDict, SplitTransformDict from stimulus.utils.yaml_model_schema import Model, RayTuneModel, YamlRayConfigLoader from tests.test_model import titanic_model @@ -30,7 +30,7 @@ def encoder_loader() -> EncoderLoader: data_config = yaml.safe_load(file) encoder_loader = EncoderLoader() encoder_loader.initialize_column_encoders_from_config( - YamlSplitTransformDict(**data_config).columns + SplitTransformDict(**data_config).columns ) return encoder_loader @@ -57,9 +57,9 @@ def test_tunewrapper_init( ray.init(ignore_reinit_error=True) try: - data_config: YamlSplitTransformDict + data_config: SplitTransformDict with open("tests/test_data/titanic/titanic_sub_config.yaml") as f: - data_config = YamlSplitTransformDict(**yaml.safe_load(f)) + data_config = SplitTransformDict(**yaml.safe_load(f)) tune_wrapper = TuneWrapper( model_config=ray_config_loader, @@ -68,8 +68,7 @@ def test_tunewrapper_init( data_config=data_config, encoder_loader=encoder_loader, seed=42, - ray_results_dir=os.path.abspath( - "tests/test_data/titanic/ray_results"), + ray_results_dir=os.path.abspath("tests/test_data/titanic/ray_results"), tune_run_name="test_run", debug=False, autoscaler=False, @@ -83,8 +82,7 @@ def test_tunewrapper_init( if os.path.exists("tests/test_data/titanic/ray_results"): import shutil - shutil.rmtree("tests/test_data/titanic/ray_results", - ignore_errors=True) + shutil.rmtree("tests/test_data/titanic/ray_results", ignore_errors=True) def test_tune_wrapper_tune( @@ -98,9 +96,9 @@ def test_tune_wrapper_tune( ray.init(ignore_reinit_error=True) try: - data_config: YamlSplitTransformDict + data_config: SplitTransformDict with open("tests/test_data/titanic/titanic_sub_config.yaml") as f: - data_config = YamlSplitTransformDict(**yaml.safe_load(f)) + data_config = SplitTransformDict(**yaml.safe_load(f)) tune_wrapper = TuneWrapper( model_config=ray_config_loader, @@ -109,8 +107,7 @@ def test_tune_wrapper_tune( data_config=data_config, encoder_loader=encoder_loader, seed=42, - ray_results_dir=os.path.abspath( - "tests/test_data/titanic/ray_results"), + ray_results_dir=os.path.abspath("tests/test_data/titanic/ray_results"), tune_run_name="test_run", debug=False, autoscaler=False, @@ -125,5 +122,4 @@ def test_tune_wrapper_tune( if os.path.exists("tests/test_data/titanic/ray_results"): import shutil - shutil.rmtree("tests/test_data/titanic/ray_results", - ignore_errors=True) + shutil.rmtree("tests/test_data/titanic/ray_results", ignore_errors=True) From 8d7030dbac52da3b636c49aae133ec822ecfdd15 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 14:10:56 +0100 Subject: [PATCH 43/81] REPLACE: tests/typing/test_typing.py -> Changed YamlSplit* to Split* --- tests/typing/test_typing.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/typing/test_typing.py b/tests/typing/test_typing.py index b0db3cd2..9774506f 100644 --- a/tests/typing/test_typing.py +++ b/tests/typing/test_typing.py @@ -48,9 +48,9 @@ def test_yaml_data_types() -> None: YamlConfigDict, GlobalParams, YamlSchema, - YamlSplit, - YamlSplitConfigDict, - YamlSplitTransformDict, + Split, + SplitConfigDict, + SplitTransformDict, Transform, TransformColumns, TransformColumnsTransformation, From 721b7fadcc5a7080bbdbf8ecbb16fb013b9b1cda Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 14:11:20 +0100 Subject: [PATCH 44/81] REPLACE: tests/utils/test_data_yaml.py -> Changed YamlSplit* to Split* --- tests/utils/test_data_yaml.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/utils/test_data_yaml.py b/tests/utils/test_data_yaml.py index 9b255728..ab7a36c3 100644 --- a/tests/utils/test_data_yaml.py +++ b/tests/utils/test_data_yaml.py @@ -6,8 +6,8 @@ from src.stimulus.utils import yaml_data from src.stimulus.utils.yaml_data import ( YamlConfigDict, - YamlSplitConfigDict, - YamlSplitTransformDict, + SplitConfigDict, + SplitTransformDict, generate_split_configs, generate_split_transform_configs, ) @@ -28,11 +28,11 @@ def load_titanic_yaml_from_file() -> YamlConfigDict: @pytest.fixture -def load_split_config_yaml_from_file() -> YamlSplitConfigDict: +def load_split_config_yaml_from_file() -> SplitConfigDict: """Fixture that loads a test unique split YAML configuration file.""" with open("tests/test_data/titanic/titanic_unique_split.yaml") as f: yaml_dict = yaml.safe_load(f) - return YamlSplitConfigDict(**yaml_dict) + return SplitConfigDict(**yaml_dict) @pytest.fixture @@ -55,7 +55,7 @@ def load_wrong_type_yaml() -> dict: def test_split_config_validation(load_titanic_yaml_from_file: YamlConfigDict) -> None: """Test split configuration validation.""" split_config = generate_split_configs(load_titanic_yaml_from_file)[0] - YamlSplitConfigDict.model_validate(split_config) + SplitConfigDict.model_validate(split_config) def test_sub_config_validation( @@ -64,7 +64,7 @@ def test_sub_config_validation( """Test sub-config validation.""" split_config = generate_split_transform_configs(load_split_config_yaml_from_file)[0] print(f"{split_config=}") - YamlSplitTransformDict.model_validate(split_config) + SplitTransformDict.model_validate(split_config) def test_expand_transform_parameter_combinations( @@ -101,7 +101,7 @@ def test_generate_data_configs( ) -> None: """Tests generating all possible data configurations.""" split_configs = yaml_data.generate_split_configs(load_yaml_from_file) - configs: list[YamlSplitTransformDict] = [] + configs: list[SplitTransformDict] = [] for s_conf in split_configs: configs.extend(generate_split_transform_configs(s_conf)) @@ -111,7 +111,7 @@ def test_generate_data_configs( for i, config in enumerate(configs): assert isinstance( config, - yaml_data.YamlSplitTransformDict, + yaml_data.SplitTransformDict, ), f"Config {i} is type {type(config)}, expected YamlSubConfigDict" From 116a3f5c38a0ccb6f41d1b99b0be233f8b49e9ab Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 14:15:37 +0100 Subject: [PATCH 45/81] REPLACE: src/stimulus/utils/yaml_data.py -> Changed YamlConfigDict to ConfigDict --- src/stimulus/utils/yaml_data.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/stimulus/utils/yaml_data.py b/src/stimulus/utils/yaml_data.py index ce71c5ce..78f503fe 100644 --- a/src/stimulus/utils/yaml_data.py +++ b/src/stimulus/utils/yaml_data.py @@ -102,7 +102,7 @@ class Split(BaseModel): split_input_columns: list[str] -class YamlConfigDict(BaseModel): +class ConfigDict(BaseModel): """Model for main YAML configuration.""" global_params: GlobalParams @@ -133,7 +133,7 @@ class SplitTransformDict(BaseModel): class YamlSchema(BaseModel): """Model for validating YAML schema.""" - yaml_conf: YamlConfigDict + yaml_conf: ConfigDict class SplitSchema(BaseModel): @@ -238,7 +238,7 @@ def expand_transform_list_combinations( return sub_transforms -def generate_split_configs(yaml_config: YamlConfigDict) -> list[SplitConfigDict]: +def generate_split_configs(yaml_config: ConfigDict) -> list[SplitConfigDict]: """Generates all possible split configuration from a YAML config. Takes a YAML configuration that may contain parameter lists and splits, @@ -266,8 +266,8 @@ def generate_split_configs(yaml_config: YamlConfigDict) -> list[SplitConfigDict] length will be the product of the number of parameter combinations and the number of splits. """ - if isinstance(yaml_config, dict) and not isinstance(yaml_config, YamlConfigDict): - raise TypeError("Input must be a YamlConfigDict object") + if isinstance(yaml_config, dict) and not isinstance(yaml_config, ConfigDict): + raise TypeError("Input must be a ConfigDict object") sub_splits = yaml_config.split sub_configs = [] @@ -440,14 +440,14 @@ def fix_params(input_dict: dict[str, Any]) -> dict[str, Any]: ) -def check_yaml_schema(config_yaml: YamlConfigDict) -> str: +def check_yaml_schema(config_yaml: ConfigDict) -> str: """Validate YAML configuration fields have correct types. If the children field is specific to a parent, the children fields class is hosted in the parent fields class. If any field in not the right type, the function prints an error message explaining the problem and exits the python code. Args: - config_yaml: The YamlConfigDict containing the fields of the yaml configuration file + config_yaml: The ConfigDict containing the fields of the yaml configuration file Returns: str: Empty string if validation succeeds From d75b82f098e29365a0e040f9211900ff581a2fe0 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 14:16:19 +0100 Subject: [PATCH 46/81] REPLACE: src/stimulus/cli/split_split.py -> Changed YamlConfigDict to ConfigDict --- src/stimulus/cli/split_split.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/stimulus/cli/split_split.py b/src/stimulus/cli/split_split.py index fa53b4ef..94a760c0 100755 --- a/src/stimulus/cli/split_split.py +++ b/src/stimulus/cli/split_split.py @@ -12,7 +12,7 @@ import yaml from stimulus.utils.yaml_data import ( - YamlConfigDict, + ConfigDict, SplitConfigDict, check_yaml_schema, dump_yaml_list_into_files, @@ -64,7 +64,7 @@ def main(config_yaml: str, out_dir_path: str) -> None: with open(config_yaml) as conf_file: yaml_config = yaml.safe_load(conf_file) - yaml_config_dict: YamlConfigDict = YamlConfigDict(**yaml_config) + yaml_config_dict: ConfigDict = ConfigDict(**yaml_config) # check if the yaml schema is correct # FIXME: isn't it redundant to check and already class with pydantic ? check_yaml_schema(yaml_config_dict) From b13bd9f20ce8ed7c503e03e9d00a02f8a71dac00 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 14:16:52 +0100 Subject: [PATCH 47/81] REPLACE: src/stimulus/typing/__init__.py -> Changed YamlConfigDict to ConfigDict --- src/stimulus/typing/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/stimulus/typing/__init__.py b/src/stimulus/typing/__init__.py index 811a7cf0..b3a0105f 100644 --- a/src/stimulus/typing/__init__.py +++ b/src/stimulus/typing/__init__.py @@ -38,7 +38,7 @@ from stimulus.utils.yaml_data import ( Columns, ColumnsEncoder, - YamlConfigDict, + ConfigDict, GlobalParams, YamlSchema, Split, @@ -81,7 +81,7 @@ YamlData: TypeAlias = ( Columns | ColumnsEncoder - | YamlConfigDict + | ConfigDict | GlobalParams | YamlSchema | Split From 11978da4edcf7c088e75799690cb609d88de271f Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 14:17:28 +0100 Subject: [PATCH 48/81] REPLACE: tests/data/test_data_handlers.py -> Changed YamlConfigDict to ConfigDict --- tests/data/test_data_handlers.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/data/test_data_handlers.py b/tests/data/test_data_handlers.py index c75d8d03..a5f6c288 100644 --- a/tests/data/test_data_handlers.py +++ b/tests/data/test_data_handlers.py @@ -13,7 +13,7 @@ TransformManager, ) from stimulus.utils.yaml_data import ( - YamlConfigDict, + ConfigDict, SplitConfigDict, SplitTransformDict, Transform, @@ -47,28 +47,28 @@ def config_path() -> str: @pytest.fixture -def base_config(config_path: str) -> YamlConfigDict: +def base_config(config_path: str) -> ConfigDict: """Load base configuration from YAML file. Args: config_path: Path to config file Returns: - YamlConfigDict: Loaded configuration + ConfigDict: Loaded configuration """ with open(config_path) as f: - return YamlConfigDict(**yaml.safe_load(f)) + return ConfigDict(**yaml.safe_load(f)) @pytest.fixture -def generate_sub_configs(base_config: YamlConfigDict) -> list[YamlConfigDict]: +def generate_sub_configs(base_config: ConfigDict) -> list[ConfigDict]: """Generate all possible configurations from base config. Args: base_config: Base configuration to generate from Returns: - list[YamlConfigDict]: List of generated configurations + list[ConfigDict]: List of generated configurations """ split_configs: list[SplitConfigDict] = generate_split_configs(base_config) split_transform_list: list[SplitTransformDict] = [] @@ -93,7 +93,7 @@ def dump_single_split_config_to_disk() -> SplitTransformDict: # Loader fixtures @pytest.fixture -def encoder_loader(generate_sub_configs: list[YamlConfigDict]) -> loaders.EncoderLoader: +def encoder_loader(generate_sub_configs: list[ConfigDict]) -> loaders.EncoderLoader: """Create encoder loader with initialized encoders. Args: @@ -109,7 +109,7 @@ def encoder_loader(generate_sub_configs: list[YamlConfigDict]) -> loaders.Encode @pytest.fixture def transform_loader( - generate_sub_configs: list[YamlConfigDict], + generate_sub_configs: list[ConfigDict], ) -> loaders.TransformLoader: """Create transform loader with initialized transformers. @@ -127,7 +127,7 @@ def transform_loader( @pytest.fixture -def split_loader(generate_sub_configs: list[YamlConfigDict]) -> loaders.SplitLoader: +def split_loader(generate_sub_configs: list[ConfigDict]) -> loaders.SplitLoader: """Create split loader with initialized splitter. Args: From 6f3c4e9c9a44b2810c729e1f434e83df40be16e6 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 14:17:50 +0100 Subject: [PATCH 49/81] REPLACE: tests/data/test_experiment.py -> Changed YamlConfigDict to ConfigDict --- tests/data/test_experiment.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/data/test_experiment.py b/tests/data/test_experiment.py index aa261dc9..32244bcd 100644 --- a/tests/data/test_experiment.py +++ b/tests/data/test_experiment.py @@ -23,18 +23,18 @@ def dna_experiment_config_path() -> str: @pytest.fixture def dna_experiment_sub_yaml( dna_experiment_config_path: str, -) -> yaml_data.YamlConfigDict: +) -> yaml_data.ConfigDict: """Get a sub-configuration from the DNA experiment config. Args: dna_experiment_config_path: Path to the DNA experiment config file Returns: - yaml_data.YamlConfigDict: First generated sub-configuration + yaml_data.ConfigDict: First generated sub-configuration """ with open(dna_experiment_config_path) as f: yaml_dict = yaml.safe_load(f) - yaml_config = yaml_data.YamlConfigDict(**yaml_dict) + yaml_config = yaml_data.ConfigDict(**yaml_dict) yaml_split_configs = yaml_data.generate_split_configs(yaml_config) yaml_split_transform_configs = yaml_data.generate_split_transform_configs( @@ -103,7 +103,7 @@ def test_set_encoder_as_attribute( def test_build_experiment_class_encoder_dict( - dna_experiment_sub_yaml: yaml_data.YamlConfigDict, + dna_experiment_sub_yaml: yaml_data.ConfigDict, ) -> None: """Test the build_experiment_class_encoder_dict method. @@ -124,8 +124,7 @@ def test_get_data_transformer() -> None: """Test the get_data_transformer method of the TransformLoader class.""" experiment = loaders.TransformLoader() transformer = experiment.get_data_transformer("ReverseComplement") - assert isinstance( - transformer, data_transformation_generators.ReverseComplement) + assert isinstance(transformer, data_transformation_generators.ReverseComplement) def test_set_data_transformer_as_attribute() -> None: @@ -138,7 +137,7 @@ def test_set_data_transformer_as_attribute() -> None: def test_initialize_column_data_transformers_from_config( - dna_experiment_sub_yaml: yaml_data.YamlConfigDict, + dna_experiment_sub_yaml: yaml_data.ConfigDict, ) -> None: """Test initializing column data transformers from config. @@ -158,7 +157,7 @@ def test_initialize_column_data_transformers_from_config( def test_initialize_splitter_from_config( - dna_experiment_sub_yaml: yaml_data.YamlConfigDict, + dna_experiment_sub_yaml: yaml_data.ConfigDict, ) -> None: """Test initializing splitter from configuration. From b22b12477660f1b80e0f6aeb95781cc2eab257bd Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 14:18:16 +0100 Subject: [PATCH 50/81] REPLACE: tests/typing/test_typing.py -> Changed YamlConfigDict to ConfigDict --- tests/typing/test_typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/typing/test_typing.py b/tests/typing/test_typing.py index 9774506f..1bb0debe 100644 --- a/tests/typing/test_typing.py +++ b/tests/typing/test_typing.py @@ -45,7 +45,7 @@ def test_yaml_data_types() -> None: from stimulus.typing import ( Columns, ColumnsEncoder, - YamlConfigDict, + ConfigDict, GlobalParams, YamlSchema, Split, From fc402470549bc4c22dea8fa856ca32d6e58dfb18 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 14:18:44 +0100 Subject: [PATCH 51/81] REPLACE: tests/utils/test_data_yaml.py -> Changed YamlConfigDict to ConfigDict --- tests/utils/test_data_yaml.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/utils/test_data_yaml.py b/tests/utils/test_data_yaml.py index ab7a36c3..32b516e0 100644 --- a/tests/utils/test_data_yaml.py +++ b/tests/utils/test_data_yaml.py @@ -5,7 +5,7 @@ from src.stimulus.utils import yaml_data from src.stimulus.utils.yaml_data import ( - YamlConfigDict, + ConfigDict, SplitConfigDict, SplitTransformDict, generate_split_configs, @@ -20,11 +20,11 @@ def titanic_csv_path() -> str: @pytest.fixture -def load_titanic_yaml_from_file() -> YamlConfigDict: +def load_titanic_yaml_from_file() -> ConfigDict: """Fixture that loads a test YAML configuration file.""" with open("tests/test_data/titanic/titanic.yaml") as f: yaml_dict = yaml.safe_load(f) - return YamlConfigDict(**yaml_dict) + return ConfigDict(**yaml_dict) @pytest.fixture @@ -36,13 +36,13 @@ def load_split_config_yaml_from_file() -> SplitConfigDict: @pytest.fixture -def load_yaml_from_file() -> YamlConfigDict: +def load_yaml_from_file() -> ConfigDict: """Fixture that loads a test YAML configuration file.""" with open( "tests/test_data/dna_experiment/dna_experiment_config_template.yaml" ) as f: yaml_dict = yaml.safe_load(f) - return YamlConfigDict(**yaml_dict) + return ConfigDict(**yaml_dict) @pytest.fixture @@ -52,14 +52,14 @@ def load_wrong_type_yaml() -> dict: return yaml.safe_load(f) -def test_split_config_validation(load_titanic_yaml_from_file: YamlConfigDict) -> None: +def test_split_config_validation(load_titanic_yaml_from_file: ConfigDict) -> None: """Test split configuration validation.""" split_config = generate_split_configs(load_titanic_yaml_from_file)[0] SplitConfigDict.model_validate(split_config) def test_sub_config_validation( - load_split_config_yaml_from_file: YamlConfigDict, + load_split_config_yaml_from_file: ConfigDict, ) -> None: """Test sub-config validation.""" split_config = generate_split_transform_configs(load_split_config_yaml_from_file)[0] @@ -68,7 +68,7 @@ def test_sub_config_validation( def test_expand_transform_parameter_combinations( - load_yaml_from_file: YamlConfigDict, + load_yaml_from_file: ConfigDict, ) -> None: """Tests expanding transforms with parameter lists into individual transforms.""" # Test transform with multiple parameter lists @@ -81,7 +81,7 @@ def test_expand_transform_parameter_combinations( def test_expand_transform_list_combinations( - load_yaml_from_file: YamlConfigDict, + load_yaml_from_file: ConfigDict, ) -> None: """Tests expanding a list of transforms into all parameter combinations.""" results = yaml_data.expand_transform_list_combinations( @@ -97,7 +97,7 @@ def test_expand_transform_list_combinations( def test_generate_data_configs( - load_yaml_from_file: YamlConfigDict, + load_yaml_from_file: ConfigDict, ) -> None: """Tests generating all possible data configurations.""" split_configs = yaml_data.generate_split_configs(load_yaml_from_file) From 306750deb20a2cf231e819f0c161dca697c2b598 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 14:25:39 +0100 Subject: [PATCH 52/81] REPLACE: src/stimulus/utils/yaml_data.py -> Changed YamlSchema to Schema --- src/stimulus/utils/yaml_data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/stimulus/utils/yaml_data.py b/src/stimulus/utils/yaml_data.py index 78f503fe..d31ed051 100644 --- a/src/stimulus/utils/yaml_data.py +++ b/src/stimulus/utils/yaml_data.py @@ -130,7 +130,7 @@ class SplitTransformDict(BaseModel): split: Split -class YamlSchema(BaseModel): +class Schema(BaseModel): """Model for validating YAML schema.""" yaml_conf: ConfigDict @@ -456,7 +456,7 @@ def check_yaml_schema(config_yaml: ConfigDict) -> str: ValueError: If validation fails """ try: - YamlSchema(yaml_conf=config_yaml) + Schema(yaml_conf=config_yaml) except ValidationError as e: # Use logging instead of print for error handling raise ValueError("Wrong type on a field, see the pydantic report above") from e From 4a5a1e3a098bb644aef9399b6c5ccffac184415a Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 14:26:05 +0100 Subject: [PATCH 53/81] REPLACE: src/stimulus/typing/__init__.py -> Changed YamlSchema to Schema --- src/stimulus/typing/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/stimulus/typing/__init__.py b/src/stimulus/typing/__init__.py index b3a0105f..198a1d41 100644 --- a/src/stimulus/typing/__init__.py +++ b/src/stimulus/typing/__init__.py @@ -40,7 +40,7 @@ ColumnsEncoder, ConfigDict, GlobalParams, - YamlSchema, + Schema, Split, SplitConfigDict, SplitTransformDict, @@ -83,7 +83,7 @@ | ColumnsEncoder | ConfigDict | GlobalParams - | YamlSchema + | Schema | Split | SplitConfigDict | Transform From ff2c7f2808c9d1aa801f95f823b3ede33825a3aa Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 14:26:45 +0100 Subject: [PATCH 54/81] REPLACE: tests/typing/test_typing.py -> Changed YamlSchema to Schema --- tests/typing/test_typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/typing/test_typing.py b/tests/typing/test_typing.py index 1bb0debe..7365a131 100644 --- a/tests/typing/test_typing.py +++ b/tests/typing/test_typing.py @@ -47,7 +47,7 @@ def test_yaml_data_types() -> None: ColumnsEncoder, ConfigDict, GlobalParams, - YamlSchema, + Schema, Split, SplitConfigDict, SplitTransformDict, From fff0bb0641a7485c0bda2e4fd719ae4e0b8d27d0 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 15:07:29 +0100 Subject: [PATCH 55/81] CHANGE: src/stimulus/yaml_data.py -> Changed some variables to remove yaml --- src/stimulus/utils/yaml_data.py | 44 ++++++++++++++++----------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/src/stimulus/utils/yaml_data.py b/src/stimulus/utils/yaml_data.py index d31ed051..65bc24f7 100644 --- a/src/stimulus/utils/yaml_data.py +++ b/src/stimulus/utils/yaml_data.py @@ -133,13 +133,13 @@ class SplitTransformDict(BaseModel): class Schema(BaseModel): """Model for validating YAML schema.""" - yaml_conf: ConfigDict + conf: ConfigDict class SplitSchema(BaseModel): """Model for validating a Split YAML schema.""" - yaml_conf: SplitConfigDict + conf: SplitConfigDict def extract_transform_parameters_at_index( @@ -238,7 +238,7 @@ def expand_transform_list_combinations( return sub_transforms -def generate_split_configs(yaml_config: ConfigDict) -> list[SplitConfigDict]: +def generate_split_configs(config: ConfigDict) -> list[SplitConfigDict]: """Generates all possible split configuration from a YAML config. Takes a YAML configuration that may contain parameter lists and splits, @@ -257,7 +257,7 @@ def generate_split_configs(yaml_config: ConfigDict) -> list[SplitConfigDict]: split: [0.8, 0.2] Args: - yaml_config: The source YAML configuration containing transforms with + config: The source YAML configuration containing transforms with parameter lists and multiple splits. Returns: @@ -266,17 +266,17 @@ def generate_split_configs(yaml_config: ConfigDict) -> list[SplitConfigDict]: length will be the product of the number of parameter combinations and the number of splits. """ - if isinstance(yaml_config, dict) and not isinstance(yaml_config, ConfigDict): + if isinstance(config, dict) and not isinstance(yaml_config, ConfigDict): raise TypeError("Input must be a ConfigDict object") - sub_splits = yaml_config.split + sub_splits = config.split sub_configs = [] for split in sub_splits: sub_configs.append( SplitConfigDict( - global_params=yaml_config.global_params, - columns=yaml_config.columns, - transforms=yaml_config.transforms, + global_params=config.global_params, + columns=config.columns, + transforms=config.transforms, split=split, ), ) @@ -284,7 +284,7 @@ def generate_split_configs(yaml_config: ConfigDict) -> list[SplitConfigDict]: def generate_split_transform_configs( - yaml_config: SplitConfigDict, + config: SplitConfigDict, ) -> list[SplitTransformDict]: """Generates all the transform configuration for a given split @@ -304,7 +304,7 @@ def generate_split_transform_configs( split: [0.7, 0.3] Args: - yaml_config: The source YAML configuration containing each + config: The source YAML configuration containing each a split with transforms with parameters lists Returns: @@ -313,25 +313,25 @@ def generate_split_transform_configs( length will be the product of the number of parameter combinations and the number of splits. """ - if isinstance(yaml_config, dict) and not isinstance(yaml_config, SplitConfigDict): + if isinstance(config, dict) and not isinstance(yaml_config, SplitConfigDict): raise TypeError("Input must be a list of YamlSubConfigDict") - sub_transforms = expand_transform_list_combinations(yaml_config.transforms) + sub_transforms = expand_transform_list_combinations(config.transforms) split_transform_config: list[SplitTransformDict] = [] for transform in sub_transforms: split_transform_config.append( SplitTransformDict( - global_params=yaml_config.global_params, - columns=yaml_config.columns, + global_params=config.global_params, + columns=config.columns, transforms=transform, - split=yaml_config.split, + split=config.split, ) ) return split_transform_config def dump_yaml_list_into_files( - yaml_list: list[SplitConfigDict], + config_list: list[SplitConfigDict], directory_path: str, base_name: str, ) -> None: @@ -384,8 +384,8 @@ def increase_indent( yaml.add_representer(type(None), represent_none, Dumper=CustomDumper) yaml.add_representer(list, custom_representer, Dumper=CustomDumper) - for i, yaml_dict in enumerate(yaml_list): - dict_data = yaml_dict.model_dump(exclude_none=True) + for i, config_dict in enumerate(config_list): + dict_data = config_dict.model_dump(exclude_none=True) def fix_params(input_dict: dict[str, Any]) -> dict[str, Any]: """Recursively process dictionary to properly handle params fields.""" @@ -440,14 +440,14 @@ def fix_params(input_dict: dict[str, Any]) -> dict[str, Any]: ) -def check_yaml_schema(config_yaml: ConfigDict) -> str: +def check_yaml_schema(config: ConfigDict) -> str: """Validate YAML configuration fields have correct types. If the children field is specific to a parent, the children fields class is hosted in the parent fields class. If any field in not the right type, the function prints an error message explaining the problem and exits the python code. Args: - config_yaml: The ConfigDict containing the fields of the yaml configuration file + config: The ConfigDict containing the fields of the yaml configuration file Returns: str: Empty string if validation succeeds @@ -456,7 +456,7 @@ def check_yaml_schema(config_yaml: ConfigDict) -> str: ValueError: If validation fails """ try: - Schema(yaml_conf=config_yaml) + Schema(conf=config) except ValidationError as e: # Use logging instead of print for error handling raise ValueError("Wrong type on a field, see the pydantic report above") from e From 982e83f8f6f218cf4f262f9215a1e7dde24405db Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 15:09:52 +0100 Subject: [PATCH 56/81] CHANGE: src/stimulus/yaml_data.py -> Changed some comments to remove --- src/stimulus/utils/yaml_data.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/stimulus/utils/yaml_data.py b/src/stimulus/utils/yaml_data.py index 65bc24f7..60e7b47f 100644 --- a/src/stimulus/utils/yaml_data.py +++ b/src/stimulus/utils/yaml_data.py @@ -1,4 +1,4 @@ -"""Utility module for handling YAML configuration files and their validation.""" +"""Utility module for handling configuration files and their validation.""" from typing import Any, Optional, Union @@ -7,7 +7,7 @@ class GlobalParams(BaseModel): - """Model for global parameters in YAML configuration.""" + """Model for global parameters in configuration.""" seed: int @@ -103,7 +103,7 @@ class Split(BaseModel): class ConfigDict(BaseModel): - """Model for main YAML configuration.""" + """Model for main configuration.""" global_params: GlobalParams columns: list[Columns] @@ -131,13 +131,13 @@ class SplitTransformDict(BaseModel): class Schema(BaseModel): - """Model for validating YAML schema.""" + """Model for validating schema.""" conf: ConfigDict class SplitSchema(BaseModel): - """Model for validating a Split YAML schema.""" + """Model for validating a Split schema.""" conf: SplitConfigDict @@ -239,9 +239,9 @@ def expand_transform_list_combinations( def generate_split_configs(config: ConfigDict) -> list[SplitConfigDict]: - """Generates all possible split configuration from a YAML config. + """Generates all possible split configuration from a config. - Takes a YAML configuration that may contain parameter lists and splits, + Takes a configuration that may contain parameter lists and splits, and generates all unique splits into separate data configurations. For example, if the config has: @@ -257,7 +257,7 @@ def generate_split_configs(config: ConfigDict) -> list[SplitConfigDict]: split: [0.8, 0.2] Args: - config: The source YAML configuration containing transforms with + config: The source configuration containing transforms with parameter lists and multiple splits. Returns: @@ -288,7 +288,7 @@ def generate_split_transform_configs( ) -> list[SplitTransformDict]: """Generates all the transform configuration for a given split - Takes a YAML configuration that may contain a transform or a list of transform, + Takes a configuration that may contain a transform or a list of transform, and generates all unique transform for a split into separate data configurations. For example, if the config has: @@ -304,7 +304,7 @@ def generate_split_transform_configs( split: [0.7, 0.3] Args: - config: The source YAML configuration containing each + config: The source configuration containing each a split with transforms with parameters lists Returns: @@ -441,7 +441,7 @@ def fix_params(input_dict: dict[str, Any]) -> dict[str, Any]: def check_yaml_schema(config: ConfigDict) -> str: - """Validate YAML configuration fields have correct types. + """Validate configuration fields have correct types. If the children field is specific to a parent, the children fields class is hosted in the parent fields class. If any field in not the right type, the function prints an error message explaining the problem and exits the python code. From df5d4db09ed8e124c316ec97af4518368691b11a Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 15:14:00 +0100 Subject: [PATCH 57/81] CHANGE: src/stimulus/yaml_data.py -> Changed function check_yaml_schema to check_schema --- src/stimulus/utils/yaml_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/stimulus/utils/yaml_data.py b/src/stimulus/utils/yaml_data.py index 60e7b47f..d22748a4 100644 --- a/src/stimulus/utils/yaml_data.py +++ b/src/stimulus/utils/yaml_data.py @@ -440,7 +440,7 @@ def fix_params(input_dict: dict[str, Any]) -> dict[str, Any]: ) -def check_yaml_schema(config: ConfigDict) -> str: +def check_schema(config: ConfigDict) -> str: """Validate configuration fields have correct types. If the children field is specific to a parent, the children fields class is hosted in the parent fields class. From 36978bbb420fdec11d350fbbf93768c30bf6cbc2 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 15:15:15 +0100 Subject: [PATCH 58/81] CHANGE: src/stimulus/yaml_data.py -> Removed Yaml in classes in comments --- src/stimulus/utils/yaml_data.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/stimulus/utils/yaml_data.py b/src/stimulus/utils/yaml_data.py index d22748a4..db96508c 100644 --- a/src/stimulus/utils/yaml_data.py +++ b/src/stimulus/utils/yaml_data.py @@ -261,7 +261,7 @@ def generate_split_configs(config: ConfigDict) -> list[SplitConfigDict]: parameter lists and multiple splits. Returns: - list[YamlSubConfigDict]: A list of data configurations, where each + list[SubConfigDict]: A list of data configurations, where each config has a list of parameters and one split configuration. The length will be the product of the number of parameter combinations and the number of splits. @@ -308,13 +308,13 @@ def generate_split_transform_configs( a split with transforms with parameters lists Returns: - list[YamlSubConfigTransformDict]: A list of data configurations, where each + list[SubConfigTransformDict]: A list of data configurations, where each config has a list of parameters and one split configuration. The length will be the product of the number of parameter combinations and the number of splits. """ if isinstance(config, dict) and not isinstance(yaml_config, SplitConfigDict): - raise TypeError("Input must be a list of YamlSubConfigDict") + raise TypeError("Input must be a list of SubConfigDict") sub_transforms = expand_transform_list_combinations(config.transforms) split_transform_config: list[SplitTransformDict] = [] From e809194cf4278388205b19c5f30ef74ccb2ad096 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 15:16:35 +0100 Subject: [PATCH 59/81] CHANGE: src/stimulus/yaml_data.py -> Changed left variables to remove yaml mentions --- src/stimulus/utils/yaml_data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/stimulus/utils/yaml_data.py b/src/stimulus/utils/yaml_data.py index db96508c..5b4db4d1 100644 --- a/src/stimulus/utils/yaml_data.py +++ b/src/stimulus/utils/yaml_data.py @@ -266,7 +266,7 @@ def generate_split_configs(config: ConfigDict) -> list[SplitConfigDict]: length will be the product of the number of parameter combinations and the number of splits. """ - if isinstance(config, dict) and not isinstance(yaml_config, ConfigDict): + if isinstance(config, dict) and not isinstance(config, ConfigDict): raise TypeError("Input must be a ConfigDict object") sub_splits = config.split @@ -313,7 +313,7 @@ def generate_split_transform_configs( length will be the product of the number of parameter combinations and the number of splits. """ - if isinstance(config, dict) and not isinstance(yaml_config, SplitConfigDict): + if isinstance(config, dict) and not isinstance(config, SplitConfigDict): raise TypeError("Input must be a list of SubConfigDict") sub_transforms = expand_transform_list_combinations(config.transforms) From 3027d373c8b99f8b6a07814d9ab9ffa52e4f3798 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 15:28:23 +0100 Subject: [PATCH 60/81] FORMATTING: src/stimulus/cli/check_model.py --- src/stimulus/cli/check_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/stimulus/cli/check_model.py b/src/stimulus/cli/check_model.py index 70c1d0a1..c61e7211 100755 --- a/src/stimulus/cli/check_model.py +++ b/src/stimulus/cli/check_model.py @@ -129,7 +129,7 @@ def main( encoder_loader = loaders.EncoderLoader() encoder_loader.initialize_column_encoders_from_config( - column_config=data_config.columns + column_config=data_config.columns, ) logger.info("Dataset loaded successfully.") From 5418c006048af3b0c4a8bbfa6638f0dfb959c068 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 15:28:51 +0100 Subject: [PATCH 61/81] FORMATTING: src/stimulus/cli/split_csv.py --- src/stimulus/cli/split_csv.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/stimulus/cli/split_csv.py b/src/stimulus/cli/split_csv.py index 756f55b6..1e7a3774 100755 --- a/src/stimulus/cli/split_csv.py +++ b/src/stimulus/cli/split_csv.py @@ -50,7 +50,11 @@ def get_args() -> argparse.Namespace: def main( - data_csv: str, config_yaml: str, out_path: str, *, force: bool = False + data_csv: str, + config_yaml: str, + out_path: str, + *, + force: bool = False, ) -> None: """Connect CSV and YAML configuration and handle sanity checks. From 7dea223baccac200711526281be5d0e61e7c82c3 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 15:29:14 +0100 Subject: [PATCH 62/81] FORMATTING: src/stimulus/cli/split_transforms.py --- src/stimulus/cli/split_transforms.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/stimulus/cli/split_transforms.py b/src/stimulus/cli/split_transforms.py index 10a05c3c..d7c4ba23 100644 --- a/src/stimulus/cli/split_transforms.py +++ b/src/stimulus/cli/split_transforms.py @@ -63,9 +63,7 @@ def main(config_yaml: str, out_dir_path: str) -> None: yaml_config_dict: SplitConfigDict = SplitConfigDict(**yaml_config) # Generate the yaml files for each transform - split_transform_configs: list[SplitTransformDict] = ( - generate_split_transform_configs(yaml_config_dict) - ) + split_transform_configs: list[SplitTransformDict] = generate_split_transform_configs(yaml_config_dict) # Dump all the YAML configs into files dump_yaml_list_into_files(split_transform_configs, out_dir_path, "test_transforms") From 129bd046b67deb0a28b35976f3d895f8f498ece6 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 15:29:33 +0100 Subject: [PATCH 63/81] FORMATTING: src/stimulus/cli/transform_csv.py --- src/stimulus/cli/transform_csv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/stimulus/cli/transform_csv.py b/src/stimulus/cli/transform_csv.py index 15c06e68..106b8a23 100755 --- a/src/stimulus/cli/transform_csv.py +++ b/src/stimulus/cli/transform_csv.py @@ -13,7 +13,7 @@ def get_args() -> argparse.Namespace: """Get the arguments when using from the commandline.""" parser = argparse.ArgumentParser( - description="CLI for transforming CSV data files using YAML configuration." + description="CLI for transforming CSV data files using YAML configuration.", ) parser.add_argument( "-c", From 9f583c17691c8c6a68c3ee230f05baaaab2afcb4 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 15:29:50 +0100 Subject: [PATCH 64/81] FORMATTING: src/stimulus/cli/tuning.py --- src/stimulus/cli/tuning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/stimulus/cli/tuning.py b/src/stimulus/cli/tuning.py index 75957e25..f16d0938 100755 --- a/src/stimulus/cli/tuning.py +++ b/src/stimulus/cli/tuning.py @@ -183,7 +183,7 @@ def main( encoder_loader = loaders.EncoderLoader() encoder_loader.initialize_column_encoders_from_config( - column_config=data_config.columns + column_config=data_config.columns, ) model_class = launch_utils.import_class_from_file(model_path) From d52aae54b3214f0ace75386dc9f37241c3f28b95 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 15:30:10 +0100 Subject: [PATCH 65/81] FORMATTING: src/stimulus/data/data_handlers.py --- src/stimulus/data/data_handlers.py | 47 +++++++++++++----------------- 1 file changed, 20 insertions(+), 27 deletions(-) diff --git a/src/stimulus/data/data_handlers.py b/src/stimulus/data/data_handlers.py index bf314303..5e8067b8 100644 --- a/src/stimulus/data/data_handlers.py +++ b/src/stimulus/data/data_handlers.py @@ -110,7 +110,6 @@ def _load_config(self, config_path: str) -> yaml_data.SplitConfigDict: >>> print(config["columns"][0]["column_name"]) 'hello' """ - with open(config_path) as file: # FIXME: cette fonction est appellé pour test_shuffle_csv et test_tune return yaml_data.SplitConfigDict(**yaml.safe_load(file)) @@ -212,16 +211,11 @@ def encode_columns(self, column_data: dict) -> dict: >>> print(encoded["dna_seq"].shape) torch.Size([2, 4, 4]) # 2 sequences, length 4, one-hot encoded """ - return { - col: self.encode_column(col, values) for col, values in column_data.items() - } + return {col: self.encode_column(col, values) for col, values in column_data.items()} def encode_dataframe(self, dataframe: pl.DataFrame) -> dict[str, torch.Tensor]: """Encode the dataframe using the encoders.""" - return { - col: self.encode_column(col, dataframe[col].to_list()) - for col in dataframe.columns - } + return {col: self.encode_column(col, dataframe[col].to_list()) for col in dataframe.columns} class TransformManager: @@ -235,7 +229,10 @@ def __init__( self.transform_loader = transform_loader def transform_column( - self, column_name: str, transform_name: str, column_data: list + self, + column_name: str, + transform_name: str, + column_data: list, ) -> tuple[list, bool]: """Transform a column of data using the specified transformation. @@ -248,9 +245,7 @@ def transform_column( list: The transformed data. bool: Whether the transformation added new rows to the data. """ - transformer = self.transform_loader.__getattribute__(column_name)[ - transform_name - ] + transformer = self.transform_loader.__getattribute__(column_name)[transform_name] return transformer.transform_all(column_data), transformer.add_row @@ -265,7 +260,8 @@ def __init__( self.split_loader = split_loader def get_split_indices( - self, data: dict + self, + data: dict, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """Get the indices for train, validation, and test splits.""" return self.split_loader.get_function_split()(data) @@ -395,12 +391,12 @@ def apply_transformation_group(self, transform_manager: TransformManager) -> Non ) if add_row: new_rows = self.data.with_columns( - pl.Series(column_name, transformed_data) + pl.Series(column_name, transformed_data), ) self.data = pl.vstack(self.data, new_rows) else: self.data = self.data.with_columns( - pl.Series(column_name, transformed_data) + pl.Series(column_name, transformed_data), ) def shuffle_labels(self, seed: Optional[float] = None) -> None: @@ -411,7 +407,7 @@ def shuffle_labels(self, seed: Optional[float] = None) -> None: label_keys = self.dataset_manager.column_categories["label"] for key in label_keys: self.data = self.data.with_columns( - pl.Series(key, np.random.permutation(list(self.data[key]))) + pl.Series(key, np.random.permutation(list(self.data[key]))), ) @@ -428,11 +424,7 @@ def __init__( """Initialize the DatasetLoader.""" super().__init__(data_config, csv_path) self.encoder_manager = EncodeManager(encoder_loader) - self.data = ( - self.load_csv_per_split(csv_path, split) - if split is not None - else self.load_csv(csv_path) - ) + self.data = self.load_csv_per_split(csv_path, split) if split is not None else self.load_csv(csv_path) def get_all_items(self) -> tuple[dict, dict, dict]: """Get the full dataset as three separate dictionaries for inputs, labels and metadata. @@ -480,7 +472,7 @@ def load_csv_per_split(self, csv_path: str, split: int) -> pl.DataFrame: raise ValueError("The category split is not present in the csv file") if split not in [0, 1, 2]: raise ValueError( - f"The split value should be 0, 1 or 2. The specified split value is {split}" + f"The split value should be 0, 1 or 2. The specified split value is {split}", ) return pl.scan_csv(csv_path).filter(pl.col("split") == split).collect() @@ -489,7 +481,8 @@ def __len__(self) -> int: return len(self.data) def __getitem__( - self, idx: Any + self, + idx: Any, ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor], dict[str, list]]: """Get the data at a given index, and encodes the input and label, leaving meta as it is. @@ -509,10 +502,10 @@ def __getitem__( # Process DataFrame input_data = self.encoder_manager.encode_dataframe( - data_at_index[input_columns] + data_at_index[input_columns], ) label_data = self.encoder_manager.encode_dataframe( - data_at_index[label_columns] + data_at_index[label_columns], ) meta_data = {key: data_at_index[key].to_list() for key in meta_columns} @@ -533,10 +526,10 @@ def __getitem__( # Process DataFrame input_data = self.encoder_manager.encode_dataframe( - data_at_index[input_columns] + data_at_index[input_columns], ) label_data = self.encoder_manager.encode_dataframe( - data_at_index[label_columns] + data_at_index[label_columns], ) meta_data = {key: data_at_index[key].to_list() for key in meta_columns} From 1d096febd329436989f74f7eb0ca6518b608609e Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 15:31:54 +0100 Subject: [PATCH 66/81] FORMATTING: src/stimulus/data/loaders.py --- src/stimulus/data/loaders.py | 44 ++++++++++++++++++++++++------------ 1 file changed, 29 insertions(+), 15 deletions(-) diff --git a/src/stimulus/data/loaders.py b/src/stimulus/data/loaders.py index affa1168..b007950c 100644 --- a/src/stimulus/data/loaders.py +++ b/src/stimulus/data/loaders.py @@ -30,7 +30,8 @@ def __init__(self, seed: Optional[float] = None) -> None: self.seed = seed def initialize_column_encoders_from_config( - self, column_config: yaml_data.Columns + self, + column_config: yaml_data.Columns, ) -> None: """Build the loader from a config dictionary. @@ -53,7 +54,9 @@ def get_function_encode_all(self, field_name: str) -> Any: return getattr(self, field_name).encode_all def get_encoder( - self, encoder_name: str, encoder_params: Optional[dict] = None + self, + encoder_name: str, + encoder_params: Optional[dict] = None, ) -> Any: """Gets an encoder object from the encoders module and initializes it with the given parameters. @@ -68,7 +71,7 @@ def get_encoder( return getattr(encoders, encoder_name)(**encoder_params) except AttributeError: logging.exception( - f"Encoder '{encoder_name}' not found in the encoders module." + f"Encoder '{encoder_name}' not found in the encoders module.", ) logging.exception( f"Available encoders: {[name for name, obj in encoders.__dict__.items() if isinstance(obj, type) and name not in ('ABC', 'Any')]}", @@ -79,7 +82,7 @@ def get_encoder( if encoder_params is None: return getattr(encoders, encoder_name)() logging.exception( - f"Encoder '{encoder_name}' has incorrect parameters: {encoder_params}" + f"Encoder '{encoder_name}' has incorrect parameters: {encoder_params}", ) logging.exception( f"Expected parameters for '{encoder_name}': {inspect.signature(getattr(encoders, encoder_name))}", @@ -87,7 +90,9 @@ def get_encoder( raise def set_encoder_as_attribute( - self, field_name: str, encoder: encoders.AbstractEncoder + self, + field_name: str, + encoder: encoders.AbstractEncoder, ) -> None: """Sets the encoder as an attribute of the loader. @@ -110,7 +115,9 @@ def __init__(self, seed: Optional[float] = None) -> None: self.seed = seed def get_data_transformer( - self, transformation_name: str, transformation_params: Optional[dict] = None + self, + transformation_name: str, + transformation_params: Optional[dict] = None, ) -> Any: """Gets a transformer object from the transformers module. @@ -123,11 +130,11 @@ def get_data_transformer( """ try: return getattr(data_transformation_generators, transformation_name)( - **transformation_params + **transformation_params, ) except AttributeError: logging.exception( - f"Transformer '{transformation_name}' not found in the transformers module." + f"Transformer '{transformation_name}' not found in the transformers module.", ) logging.exception( f"Available transformers: {[name for name, obj in data_transformation_generators.__dict__.items() if isinstance(obj, type) and name not in ('ABC', 'Any')]}", @@ -138,7 +145,7 @@ def get_data_transformer( if transformation_params is None: return getattr(data_transformation_generators, transformation_name)() logging.exception( - f"Transformer '{transformation_name}' has incorrect parameters: {transformation_params}" + f"Transformer '{transformation_name}' has incorrect parameters: {transformation_params}", ) logging.exception( f"Expected parameters for '{transformation_name}': {inspect.signature(getattr(data_transformation_generators, transformation_name))}", @@ -146,7 +153,9 @@ def get_data_transformer( raise def set_data_transformer_as_attribute( - self, field_name: str, data_transformer: Any + self, + field_name: str, + data_transformer: Any, ) -> None: """Sets the data transformer as an attribute of the loader. @@ -166,7 +175,8 @@ def set_data_transformer_as_attribute( field_value[data_transformer.__class__.__name__] = data_transformer def initialize_column_data_transformers_from_config( - self, transform_config: yaml_data.Transform + self, + transform_config: yaml_data.Transform, ) -> None: """Build the loader from a config dictionary. @@ -201,7 +211,8 @@ def initialize_column_data_transformers_from_config( col_name = column.column_name for transform_spec in column.transformations: transformer = self.get_data_transformer( - transform_spec.name, transform_spec.params + transform_spec.name, + transform_spec.params, ) self.set_data_transformer_as_attribute(col_name, transformer) @@ -235,7 +246,9 @@ def get_function_split(self) -> Any: return self.split.get_split_indexes def get_splitter( - self, splitter_name: str, splitter_params: Optional[dict] = None + self, + splitter_name: str, + splitter_params: Optional[dict] = None, ) -> Any: """Gets a splitter object from the splitters module. @@ -252,7 +265,7 @@ def get_splitter( if splitter_params is None: return getattr(splitters, splitter_name)() logging.exception( - f"Splitter '{splitter_name}' has incorrect parameters: {splitter_params}" + f"Splitter '{splitter_name}' has incorrect parameters: {splitter_params}", ) logging.exception( f"Expected parameters for '{splitter_name}': {inspect.signature(getattr(splitters, splitter_name))}", @@ -268,7 +281,8 @@ def set_splitter_as_attribute(self, splitter: Any) -> None: self.split = splitter def initialize_splitter_from_config( - self, split_config: yaml_data.SplitConfigDict + self, + split_config: yaml_data.SplitConfigDict, ) -> None: """Build the loader from a config dictionary. From 5c263a056e4f20c5f23bfade924f98aae56a4ccd Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 15:32:21 +0100 Subject: [PATCH 67/81] FORMATTING: src/stimulus/learner/raytune_learner.py --- src/stimulus/learner/raytune_learner.py | 36 ++++++++++++------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/src/stimulus/learner/raytune_learner.py b/src/stimulus/learner/raytune_learner.py index 76e24150..ec59beb2 100644 --- a/src/stimulus/learner/raytune_learner.py +++ b/src/stimulus/learner/raytune_learner.py @@ -20,8 +20,8 @@ from stimulus.data.loaders import EncoderLoader from stimulus.learner.predict import PredictWrapper from stimulus.utils.generic_utils import set_general_seeds -from stimulus.utils.yaml_model_schema import RayTuneModel from stimulus.utils.yaml_data import SplitTransformDict +from stimulus.utils.yaml_model_schema import RayTuneModel class CheckpointDict(TypedDict): @@ -78,7 +78,7 @@ def __init__( if tune_run_name is not None else "TuneModel_" + datetime.datetime.now(tz=datetime.timezone.utc).strftime( - "%Y-%m-%d_%H-%M-%S" + "%Y-%m-%d_%H-%M-%S", ), storage_path=ray_results_dir, checkpoint_config=train.CheckpointConfig(checkpoint_at_end=True), @@ -99,7 +99,7 @@ def __init__( if tune_run_name is not None else "TuneModel_" + datetime.datetime.now(tz=datetime.timezone.utc).strftime( - "%Y-%m-%d_%H-%M-%S" + "%Y-%m-%d_%H-%M-%S", ), ) self.config["_debug"] = debug @@ -138,7 +138,7 @@ def tuner_initialization( ) except KeyError as err: logging.warning( - f"KeyError: {err}, no GPU resources available in the cluster: {cluster_res}" + f"KeyError: {err}, no GPU resources available in the cluster: {cluster_res}", ) if self.cpu_per_trial > cluster_res["CPU"] and not autoscaler: @@ -147,7 +147,7 @@ def tuner_initialization( ) logging.info( - f"PER_TRIAL resources -> GPU: {self.gpu_per_trial} CPU: {self.cpu_per_trial}" + f"PER_TRIAL resources -> GPU: {self.gpu_per_trial} CPU: {self.cpu_per_trial}", ) # Pre-load and encode datasets once, then put them in Ray's object store @@ -304,20 +304,18 @@ def objective(self) -> dict[str, float]: "spearmanr", ] # TODO maybe we report only a subset of metrics, given certain criteria (eg. if classification or regression) predict_val = PredictWrapper( - self.model, self.validation, loss_dict=self.loss_dict + self.model, + self.validation, + loss_dict=self.loss_dict, ) predict_train = PredictWrapper( - self.model, self.training, loss_dict=self.loss_dict + self.model, + self.training, + loss_dict=self.loss_dict, ) return { - **{ - "val_" + metric: value - for metric, value in predict_val.compute_metrics(metrics).items() - }, - **{ - "train_" + metric: value - for metric, value in predict_train.compute_metrics(metrics).items() - }, + **{"val_" + metric: value for metric, value in predict_val.compute_metrics(metrics).items()}, + **{"train_" + metric: value for metric, value in predict_train.compute_metrics(metrics).items()}, } # type: ignore[override] @@ -333,16 +331,18 @@ def load_checkpoint(self, checkpoint: dict[Any, Any] | None) -> None: return checkpoint_dir = checkpoint["checkpoint_dir"] self.model = safe_load_model( - self.model, os.path.join(checkpoint_dir, "model.safetensors") + self.model, + os.path.join(checkpoint_dir, "model.safetensors"), ) self.optimizer.load_state_dict( - torch.load(os.path.join(checkpoint_dir, "optimizer.pt")) + torch.load(os.path.join(checkpoint_dir, "optimizer.pt")), ) def save_checkpoint(self, checkpoint_dir: str) -> dict[Any, Any]: """Save model and optimizer state to checkpoint.""" safe_save_model(self.model, os.path.join(checkpoint_dir, "model.safetensors")) torch.save( - self.optimizer.state_dict(), os.path.join(checkpoint_dir, "optimizer.pt") + self.optimizer.state_dict(), + os.path.join(checkpoint_dir, "optimizer.pt"), ) return {"checkpoint_dir": checkpoint_dir} From f377f07d51a6ea213b45e716093e706eb0deaf61 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 15:33:02 +0100 Subject: [PATCH 68/81] FORMATTING: src/stimulus/typing/__init__.py --- src/stimulus/typing/__init__.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/stimulus/typing/__init__.py b/src/stimulus/typing/__init__.py index 198a1d41..87220792 100644 --- a/src/stimulus/typing/__init__.py +++ b/src/stimulus/typing/__init__.py @@ -64,9 +64,7 @@ # data/data_handlers.py -DataManager: TypeAlias = ( - DatasetManager | EncodeManager | SplitManager | TransformManager -) +DataManager: TypeAlias = DatasetManager | EncodeManager | SplitManager | TransformManager # data/experiments.py From c5dbac0dda965fa63492abe492d12236c5f6c9a6 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 15:33:29 +0100 Subject: [PATCH 69/81] FORMATTING: src/stimulus/utils/yaml_data.py --- src/stimulus/utils/yaml_data.py | 56 ++++++++++++++------------------- 1 file changed, 23 insertions(+), 33 deletions(-) diff --git a/src/stimulus/utils/yaml_data.py b/src/stimulus/utils/yaml_data.py index 5b4db4d1..6ce67842 100644 --- a/src/stimulus/utils/yaml_data.py +++ b/src/stimulus/utils/yaml_data.py @@ -16,9 +16,7 @@ class ColumnsEncoder(BaseModel): """Model for column encoder configuration.""" name: str - params: Optional[ - dict[str, Union[str, list[Any]]] - ] # Allow both string and list values + params: Optional[dict[str, Union[str, list[Any]]]] # Allow both string and list values class Columns(BaseModel): @@ -34,9 +32,7 @@ class TransformColumnsTransformation(BaseModel): """Model for column transformation configuration.""" name: str - params: Optional[ - dict[str, Union[list[Any], float]] - ] # Allow both list and float values + params: Optional[dict[str, Union[list[Any], float]]] # Allow both list and float values class TransformColumns(BaseModel): @@ -55,7 +51,8 @@ class Transform(BaseModel): @field_validator("columns") @classmethod def validate_param_lists_across_columns( - cls, columns: list[TransformColumns] + cls, + columns: list[TransformColumns], ) -> list[TransformColumns]: """Validate that parameter lists across columns have consistent lengths. @@ -143,7 +140,8 @@ class SplitSchema(BaseModel): def extract_transform_parameters_at_index( - transform: Transform, index: int = 0 + transform: Transform, + index: int = 0, ) -> Transform: """Get a transform with parameters at the specified index. @@ -193,15 +191,9 @@ def expand_transform_parameter_combinations( for column in transform.columns: for transformation in column.transformations: if transformation.params: - list_lengths = [ - len(v) - for v in transformation.params.values() - if isinstance(v, list) and len(v) > 1 - ] + list_lengths = [len(v) for v in transformation.params.values() if isinstance(v, list) and len(v) > 1] if list_lengths: - max_length = list_lengths[ - 0 - ] # All lists have same length due to validator + max_length = list_lengths[0] # All lists have same length due to validator break # Generate a transform for each index @@ -325,7 +317,7 @@ def generate_split_transform_configs( columns=config.columns, transforms=transform, split=config.split, - ) + ), ) return split_transform_config @@ -335,7 +327,7 @@ def dump_yaml_list_into_files( directory_path: str, base_name: str, ) -> None: - """Dumps a list of YAML configurations into separate files with custom formatting.""" + """Dumps a list of configurations into separate files with custom formatting.""" # Create a new class attribute rather than assigning to the method # Remove this line since we'll add ignore_aliases to CustomDumper instead @@ -350,11 +342,15 @@ def custom_representer(dumper: yaml.Dumper, data: Any) -> yaml.Node: return dumper.represent_scalar("tag:yaml.org,2002:null", "") if isinstance(data[0], dict): return dumper.represent_sequence( - "tag:yaml.org,2002:seq", data, flow_style=False + "tag:yaml.org,2002:seq", + data, + flow_style=False, ) if isinstance(data[0], list): return dumper.represent_sequence( - "tag:yaml.org,2002:seq", data, flow_style=True + "tag:yaml.org,2002:seq", + data, + flow_style=True, ) return dumper.represent_sequence("tag:yaml.org,2002:seq", data, flow_style=True) @@ -372,7 +368,10 @@ def write_line_break(self, _data: Any = None) -> None: super().write_line_break(_data) def increase_indent( - self, *, flow: bool = False, indentless: bool = False + self, + *, + flow: bool = False, + indentless: bool = False, ) -> None: # type: ignore[override] """Ensure consistent indentation by preventing indentless sequences.""" return super().increase_indent( @@ -396,30 +395,21 @@ def fix_params(input_dict: dict[str, Any]) -> dict[str, Any]: processed_dict[key] = [] for encoder in value: processed_encoder = dict(encoder) - if ( - "params" not in processed_encoder - or not processed_encoder["params"] - ): + if "params" not in processed_encoder or not processed_encoder["params"]: processed_encoder["params"] = {} processed_dict[key].append(processed_encoder) elif key == "transformations" and isinstance(value, list): processed_dict[key] = [] for transformation in value: processed_transformation = dict(transformation) - if ( - "params" not in processed_transformation - or not processed_transformation["params"] - ): + if "params" not in processed_transformation or not processed_transformation["params"]: processed_transformation["params"] = {} processed_dict[key].append(processed_transformation) elif isinstance(value, dict): processed_dict[key] = fix_params(value) elif isinstance(value, list): processed_dict[key] = [ - fix_params(list_item) - if isinstance(list_item, dict) - else list_item - for list_item in value + fix_params(list_item) if isinstance(list_item, dict) else list_item for list_item in value ] else: processed_dict[key] = value From 34191e23b57189c3bf7544a628d6793915fc6a33 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 15:33:57 +0100 Subject: [PATCH 70/81] FORMATTING: tests/cli/test_check_model.py --- tests/cli/test_check_model.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/tests/cli/test_check_model.py b/tests/cli/test_check_model.py index 3f5a3014..d4143256 100644 --- a/tests/cli/test_check_model.py +++ b/tests/cli/test_check_model.py @@ -8,17 +8,13 @@ import ray from stimulus.cli import check_model -from src.stimulus.utils.yaml_data import SplitTransformDict @pytest.fixture def data_path() -> str: """Get path to test data CSV file.""" return str( - Path(__file__).parent.parent - / "test_data" - / "titanic" - / "titanic_stimulus_split.csv" + Path(__file__).parent.parent / "test_data" / "titanic" / "titanic_stimulus_split.csv", ) @@ -26,10 +22,7 @@ def data_path() -> str: def data_config() -> str: """Get path to test data config YAML.""" return str( - Path(__file__).parent.parent - / "test_data" - / "titanic" - / "titanic_sub_config.yaml" + Path(__file__).parent.parent / "test_data" / "titanic" / "titanic_sub_config.yaml", ) @@ -46,7 +39,10 @@ def model_config() -> str: def test_check_model_main( - data_path: str, data_config: str, model_path: str, model_config: str + data_path: str, + data_config: str, + model_path: str, + model_config: str, ) -> None: """Test that check_model.main runs without errors. From 7f973d41289741c4bd935a9f8642d28efe29c42e Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 15:34:19 +0100 Subject: [PATCH 71/81] FORMATTING: tests/cli/test_shuffle_csv.py --- tests/cli/test_shuffle_csv.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/cli/test_shuffle_csv.py b/tests/cli/test_shuffle_csv.py index 5d1dbbd3..885b102e 100644 --- a/tests/cli/test_shuffle_csv.py +++ b/tests/cli/test_shuffle_csv.py @@ -1,12 +1,12 @@ """Tests for the shuffle_csv CLI command.""" -import yaml import hashlib import pathlib import tempfile from typing import Any, Callable import pytest +import yaml from src.stimulus.cli.shuffle_csv import main from src.stimulus.utils.yaml_data import SplitTransformDict @@ -48,7 +48,7 @@ def test_shuffle_csv( if error: with pytest.raises(error): # type: ignore[call-overload] config_dict: SplitTransformDict = SplitTransformDict( - **yaml.safe_load(f) + **yaml.safe_load(f), ) main(csv_path, config_dict, str(tmpdir / "test.csv")) else: From 8aafd463b14ca5ee139cbe6a6a4bfa9db53af94d Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 15:34:35 +0100 Subject: [PATCH 72/81] FORMATTING: tests/cli/test_split_split.py --- tests/cli/test_split_splits.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/cli/test_split_splits.py b/tests/cli/test_split_splits.py index e0cc445c..41a3910d 100644 --- a/tests/cli/test_split_splits.py +++ b/tests/cli/test_split_splits.py @@ -2,7 +2,6 @@ import hashlib import os -import tempfile from typing import Any, Callable import pytest @@ -53,6 +52,4 @@ def test_split_split( for f in test_out: with open(os.path.join(tmpdir, f)) as file: hashes.append(hashlib.md5(file.read().encode()).hexdigest()) # noqa: S324 - assert ( - sorted(hashes) == snapshot - ) # sorted ensures that the order of the hashes does not matter + assert sorted(hashes) == snapshot # sorted ensures that the order of the hashes does not matter From f97bc24307e4411d471211df076b8567b4c6c795 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 15:34:56 +0100 Subject: [PATCH 73/81] FORMATTING: tests/cli/test_split_transforms.py --- tests/cli/test_split_transforms.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/cli/test_split_transforms.py b/tests/cli/test_split_transforms.py index 770ec042..90b856c5 100644 --- a/tests/cli/test_split_transforms.py +++ b/tests/cli/test_split_transforms.py @@ -49,6 +49,4 @@ def test_split_transforms( for f in test_out: with open(os.path.join(tmpdir, f)) as file: hashes.append(hashlib.md5(file.read().encode()).hexdigest()) - assert ( - sorted(hashes) == snapshot - ) # Sorted ensures that the order of the hashes does not matter + assert sorted(hashes) == snapshot # Sorted ensures that the order of the hashes does not matter From e9c9561a4dba15690e7812c99b0e5bc7e4c4f4c3 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 15:35:15 +0100 Subject: [PATCH 74/81] FORMATTING: tests/cli/test_tuning.py --- tests/cli/test_tuning.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/tests/cli/test_tuning.py b/tests/cli/test_tuning.py index 2a11e3ac..5c7162be 100644 --- a/tests/cli/test_tuning.py +++ b/tests/cli/test_tuning.py @@ -20,10 +20,7 @@ def data_path() -> str: """Get path to test data CSV file.""" return str( - Path(__file__).parent.parent - / "test_data" - / "titanic" - / "titanic_stimulus_split.csv", + Path(__file__).parent.parent / "test_data" / "titanic" / "titanic_stimulus_split.csv", ) @@ -31,10 +28,7 @@ def data_path() -> str: def data_config() -> str: """Get path to test data config YAML.""" return str( - Path(__file__).parent.parent - / "test_data" - / "titanic" - / "titanic_sub_config.yaml", + Path(__file__).parent.parent / "test_data" / "titanic" / "titanic_sub_config.yaml", ) From b80560a2f49e758735ba3eff9d842a713dccd234 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 15:35:42 +0100 Subject: [PATCH 75/81] FORMATTING: tests/data/test_data_handlers.py --- tests/data/test_data_handlers.py | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/tests/data/test_data_handlers.py b/tests/data/test_data_handlers.py index a5f6c288..b5e5644c 100644 --- a/tests/data/test_data_handlers.py +++ b/tests/data/test_data_handlers.py @@ -121,7 +121,7 @@ def transform_loader( """ loader = loaders.TransformLoader() loader.initialize_column_data_transformers_from_config( - generate_sub_configs[0].transforms + generate_sub_configs[0].transforms, ) return loader @@ -327,24 +327,15 @@ def test_dataset_processor_apply_transformation_group( processor_control.data = processor_control.load_csv(titanic_csv_path) processor.apply_transformation_group( - transform_manager=TransformManager(transform_loader) + transform_manager=TransformManager(transform_loader), ) assert processor.data["age"].to_list() != processor_control.data["age"].to_list() assert processor.data["fare"].to_list() != processor_control.data["fare"].to_list() - assert ( - processor.data["parch"].to_list() == processor_control.data["parch"].to_list() - ) - assert ( - processor.data["sibsp"].to_list() == processor_control.data["sibsp"].to_list() - ) - assert ( - processor.data["pclass"].to_list() == processor_control.data["pclass"].to_list() - ) - assert ( - processor.data["embarked"].to_list() - == processor_control.data["embarked"].to_list() - ) + assert processor.data["parch"].to_list() == processor_control.data["parch"].to_list() + assert processor.data["sibsp"].to_list() == processor_control.data["sibsp"].to_list() + assert processor.data["pclass"].to_list() == processor_control.data["pclass"].to_list() + assert processor.data["embarked"].to_list() == processor_control.data["embarked"].to_list() assert processor.data["sex"].to_list() == processor_control.data["sex"].to_list() From 40cb0455c856f91810a817fab67dcde84e52effe Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 15:36:04 +0100 Subject: [PATCH 76/81] FORMATTING: tests/data/test_experiment.py --- tests/data/test_experiment.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/data/test_experiment.py b/tests/data/test_experiment.py index 32244bcd..9c37520b 100644 --- a/tests/data/test_experiment.py +++ b/tests/data/test_experiment.py @@ -38,7 +38,7 @@ def dna_experiment_sub_yaml( yaml_split_configs = yaml_data.generate_split_configs(yaml_config) yaml_split_transform_configs = yaml_data.generate_split_transform_configs( - yaml_split_configs[0] + yaml_split_configs[0], ) return yaml_split_transform_configs[0] @@ -150,10 +150,7 @@ def test_initialize_column_data_transformers_from_config( assert hasattr(experiment, "col1") column_transformers = experiment.col1 - assert any( - isinstance(t, data_transformation_generators.ReverseComplement) - for t in column_transformers.values() - ) + assert any(isinstance(t, data_transformation_generators.ReverseComplement) for t in column_transformers.values()) def test_initialize_splitter_from_config( From d1991efa6c059f2d8177c93738bd7c34d193f145 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 15:36:28 +0100 Subject: [PATCH 77/81] FORMATTING: tests/data/transform/test_data_transformers.py --- .../data/transform/test_data_transformers.py | 47 ++++++++++--------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/tests/data/transform/test_data_transformers.py b/tests/data/transform/test_data_transformers.py index 116b37ea..053a1719 100644 --- a/tests/data/transform/test_data_transformers.py +++ b/tests/data/transform/test_data_transformers.py @@ -142,7 +142,8 @@ def test_transform_single(self, request: Any, test_data_name: str) -> None: """ test_data = request.getfixturevalue(test_data_name) transformed_data = test_data.transformer.transform( - test_data.single_input, **test_data.params + test_data.single_input, + **test_data.params, ) assert isinstance(transformed_data, str) assert transformed_data == test_data.expected_single_output @@ -151,10 +152,7 @@ def test_transform_single(self, request: Any, test_data_name: str) -> None: def test_transform_multiple(self, request: Any, test_data_name: str) -> None: """Test masking multiple strings.""" test_data = request.getfixturevalue(test_data_name) - transformed_data = [ - test_data.transformer.transform(x, **test_data.params) - for x in test_data.multiple_inputs - ] + transformed_data = [test_data.transformer.transform(x, **test_data.params) for x in test_data.multiple_inputs] assert isinstance(transformed_data, list) for item in transformed_data: assert isinstance(item, str) @@ -169,28 +167,31 @@ def test_transform_single(self, request: Any, test_data_name: str) -> None: """Test transforming a single float.""" test_data = request.getfixturevalue(test_data_name) transformed_data = test_data.transformer.transform( - test_data.single_input, **test_data.params + test_data.single_input, + **test_data.params, ) assert isinstance(transformed_data, float) - assert round(transformed_data, 7) == round( - test_data.expected_single_output, 7) + assert round(transformed_data, 7) == round(test_data.expected_single_output, 7) @pytest.mark.parametrize("test_data_name", ["gaussian_noise"]) def test_transform_multiple( - self, request: Any, test_data_name: DataTransformerTest + self, + request: Any, + test_data_name: DataTransformerTest, ) -> None: """Test transforming multiple floats.""" test_data = request.getfixturevalue(test_data_name) transformed_data = test_data.transformer.transform_all( - test_data.multiple_inputs, **test_data.params + test_data.multiple_inputs, + **test_data.params, ) assert isinstance(transformed_data, list) for item in transformed_data: assert isinstance(item, float) - assert len(transformed_data) == len( - test_data.expected_multiple_outputs) + assert len(transformed_data) == len(test_data.expected_multiple_outputs) for item, expected in zip( - transformed_data, test_data.expected_multiple_outputs + transformed_data, + test_data.expected_multiple_outputs, ): assert round(item, 7) == round(expected, 7) @@ -202,8 +203,7 @@ class TestGaussianChunk: def test_transform_single(self, request: Any, test_data_name: str) -> None: """Test transforming a single string.""" test_data = request.getfixturevalue(test_data_name) - transformed_data = test_data.transformer.transform( - test_data.single_input) + transformed_data = test_data.transformer.transform(test_data.single_input) assert isinstance(transformed_data, str) assert len(transformed_data) == 2 @@ -211,9 +211,7 @@ def test_transform_single(self, request: Any, test_data_name: str) -> None: def test_transform_multiple(self, request: Any, test_data_name: str) -> None: """Test transforming multiple strings.""" test_data = request.getfixturevalue(test_data_name) - transformed_data = [ - test_data.transformer.transform(x) for x in test_data.multiple_inputs - ] + transformed_data = [test_data.transformer.transform(x) for x in test_data.multiple_inputs] assert isinstance(transformed_data, list) for item in transformed_data: assert isinstance(item, str) @@ -226,7 +224,8 @@ def test_chunk_size_excessive(self, request: Any, test_data_name: str) -> None: test_data = request.getfixturevalue(test_data_name) transformer = GaussianChunk(chunk_size=100) with pytest.raises( - ValueError, match="The input data is shorter than the chunk size" + ValueError, + match="The input data is shorter than the chunk size", ): transformer.transform(test_data.single_input) @@ -239,7 +238,8 @@ def test_transform_single(self, request: Any, test_data_name: str) -> None: """Test transforming a single string.""" test_data = request.getfixturevalue(test_data_name) transformed_data = test_data.transformer.transform( - test_data.single_input, **test_data.params + test_data.single_input, + **test_data.params, ) assert isinstance(transformed_data, str) assert transformed_data == test_data.expected_single_output @@ -249,7 +249,8 @@ def test_transform_multiple(self, request: Any, test_data_name: str) -> None: """Test transforming multiple strings.""" test_data = request.getfixturevalue(test_data_name) transformed_data = test_data.transformer.transform_all( - test_data.multiple_inputs, **test_data.params + test_data.multiple_inputs, + **test_data.params, ) assert isinstance(transformed_data, list) for item in transformed_data: @@ -265,7 +266,9 @@ def titanic_config_path(base_config: dict) -> str: if not os.path.exists(config_path): configs = generate_split_transform_configs(base_config) dump_yaml_list_into_files( - [configs[0]], "tests/test_data/titanic/", "titanic_sub_config" + [configs[0]], + "tests/test_data/titanic/", + "titanic_sub_config", ) return os.path.abspath(config_path) From 530644579d5640f2134bc0b3ec1b11c0fbcd8aa4 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 15:37:00 +0100 Subject: [PATCH 78/81] FORMATTING: tests/learmes/test_raytune_learner.py --- tests/learner/test_raytune_learner.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/learner/test_raytune_learner.py b/tests/learner/test_raytune_learner.py index 0e8dac17..918d3d2a 100644 --- a/tests/learner/test_raytune_learner.py +++ b/tests/learner/test_raytune_learner.py @@ -10,7 +10,7 @@ from stimulus.data.handlertorch import TorchDataset from stimulus.data.loaders import EncoderLoader from stimulus.learner.raytune_learner import TuneWrapper -from stimulus.utils.yaml_data import SplitConfigDict, SplitTransformDict +from stimulus.utils.yaml_data import SplitTransformDict from stimulus.utils.yaml_model_schema import Model, RayTuneModel, YamlRayConfigLoader from tests.test_model import titanic_model @@ -30,7 +30,7 @@ def encoder_loader() -> EncoderLoader: data_config = yaml.safe_load(file) encoder_loader = EncoderLoader() encoder_loader.initialize_column_encoders_from_config( - SplitTransformDict(**data_config).columns + SplitTransformDict(**data_config).columns, ) return encoder_loader @@ -47,7 +47,8 @@ def titanic_dataset(encoder_loader: EncoderLoader) -> TorchDataset: def test_tunewrapper_init( - ray_config_loader: RayTuneModel, encoder_loader: EncoderLoader + ray_config_loader: RayTuneModel, + encoder_loader: EncoderLoader, ) -> None: """Test the initialization of the TuneWrapper class.""" # Filter ResourceWarning during Ray shutdown @@ -86,7 +87,8 @@ def test_tunewrapper_init( def test_tune_wrapper_tune( - ray_config_loader: RayTuneModel, encoder_loader: EncoderLoader + ray_config_loader: RayTuneModel, + encoder_loader: EncoderLoader, ) -> None: """Test the tune method of TuneWrapper class.""" # Filter ResourceWarning during Ray shutdown From d7d565665df3d40dc4380925e272f1dc9cc03430 Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 15:37:24 +0100 Subject: [PATCH 79/81] FORMATTING: tests/utils/test_data_yaml.py --- tests/utils/test_data_yaml.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/utils/test_data_yaml.py b/tests/utils/test_data_yaml.py index 32b516e0..aaabbba8 100644 --- a/tests/utils/test_data_yaml.py +++ b/tests/utils/test_data_yaml.py @@ -39,7 +39,7 @@ def load_split_config_yaml_from_file() -> SplitConfigDict: def load_yaml_from_file() -> ConfigDict: """Fixture that loads a test YAML configuration file.""" with open( - "tests/test_data/dna_experiment/dna_experiment_config_template.yaml" + "tests/test_data/dna_experiment/dna_experiment_config_template.yaml", ) as f: yaml_dict = yaml.safe_load(f) return ConfigDict(**yaml_dict) @@ -76,7 +76,8 @@ def test_expand_transform_parameter_combinations( results = yaml_data.expand_transform_parameter_combinations(transform) assert len(results) == 1 # Only one transform returned assert isinstance( - results[0], yaml_data.Transform + results[0], + yaml_data.Transform, ) # Should return Transform objects @@ -85,7 +86,7 @@ def test_expand_transform_list_combinations( ) -> None: """Tests expanding a list of transforms into all parameter combinations.""" results = yaml_data.expand_transform_list_combinations( - load_yaml_from_file.transforms + load_yaml_from_file.transforms, ) # 4 combinations from first transform x 2 from second assert len(results) == 8 @@ -127,7 +128,8 @@ def test_check_yaml_schema( data = request.getfixturevalue(test_input[0]) if test_input[1]: with pytest.raises( - ValueError, match="Wrong type on a field, see the pydantic report above" + ValueError, + match="Wrong type on a field, see the pydantic report above", ): yaml_data.check_yaml_schema(data) else: From 26f8b3ab21acdbec409edf12f61a728b4f3f199b Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 16:30:05 +0100 Subject: [PATCH 80/81] FIX: references to old yaml functions --- src/stimulus/cli/split_split.py | 4 ++-- tests/utils/test_data_yaml.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/stimulus/cli/split_split.py b/src/stimulus/cli/split_split.py index 94a760c0..7ad398a2 100755 --- a/src/stimulus/cli/split_split.py +++ b/src/stimulus/cli/split_split.py @@ -14,7 +14,7 @@ from stimulus.utils.yaml_data import ( ConfigDict, SplitConfigDict, - check_yaml_schema, + check_schema, dump_yaml_list_into_files, generate_split_configs, ) @@ -67,7 +67,7 @@ def main(config_yaml: str, out_dir_path: str) -> None: yaml_config_dict: ConfigDict = ConfigDict(**yaml_config) # check if the yaml schema is correct # FIXME: isn't it redundant to check and already class with pydantic ? - check_yaml_schema(yaml_config_dict) + check_schema(yaml_config_dict) # generate the yaml files per split split_configs: list[SplitConfigDict] = generate_split_configs(yaml_config_dict) diff --git a/tests/utils/test_data_yaml.py b/tests/utils/test_data_yaml.py index aaabbba8..5b6467c1 100644 --- a/tests/utils/test_data_yaml.py +++ b/tests/utils/test_data_yaml.py @@ -120,7 +120,7 @@ def test_generate_data_configs( "test_input", [("load_yaml_from_file", False), ("load_wrong_type_yaml", True)], ) -def test_check_yaml_schema( +def test_check_schema( request: pytest.FixtureRequest, test_input: tuple[str, bool], ) -> None: @@ -131,6 +131,6 @@ def test_check_yaml_schema( ValueError, match="Wrong type on a field, see the pydantic report above", ): - yaml_data.check_yaml_schema(data) + yaml_data.check_schema(data) else: - yaml_data.check_yaml_schema(data) + yaml_data.check_schema(data) From 216298d3b3c5eb2873f0363e377fa13a196c673e Mon Sep 17 00:00:00 2001 From: Julien Raynal Date: Wed, 19 Feb 2025 17:32:53 +0100 Subject: [PATCH 81/81] FORMATTING: formatting some files --- src/stimulus/cli/split_transforms.py | 4 +--- src/stimulus/data/data_handlers.py | 23 +++++++---------------- src/stimulus/learner/raytune_learner.py | 14 ++++---------- src/stimulus/utils/yaml_data.py | 22 ++++------------------ tests/cli/test_shuffle_csv.py | 2 -- tests/data/test_data_handlers.py | 17 ++++------------- 6 files changed, 20 insertions(+), 62 deletions(-) diff --git a/src/stimulus/cli/split_transforms.py b/src/stimulus/cli/split_transforms.py index 10a05c3c..d7c4ba23 100755 --- a/src/stimulus/cli/split_transforms.py +++ b/src/stimulus/cli/split_transforms.py @@ -63,9 +63,7 @@ def main(config_yaml: str, out_dir_path: str) -> None: yaml_config_dict: SplitConfigDict = SplitConfigDict(**yaml_config) # Generate the yaml files for each transform - split_transform_configs: list[SplitTransformDict] = ( - generate_split_transform_configs(yaml_config_dict) - ) + split_transform_configs: list[SplitTransformDict] = generate_split_transform_configs(yaml_config_dict) # Dump all the YAML configs into files dump_yaml_list_into_files(split_transform_configs, out_dir_path, "test_transforms") diff --git a/src/stimulus/data/data_handlers.py b/src/stimulus/data/data_handlers.py index 3b9a11d9..e8454a06 100644 --- a/src/stimulus/data/data_handlers.py +++ b/src/stimulus/data/data_handlers.py @@ -211,16 +211,11 @@ def encode_columns(self, column_data: dict) -> dict: >>> print(encoded["dna_seq"].shape) torch.Size([2, 4, 4]) # 2 sequences, length 4, one-hot encoded """ - return { - col: self.encode_column(col, values) for col, values in column_data.items() - } + return {col: self.encode_column(col, values) for col, values in column_data.items()} def encode_dataframe(self, dataframe: pl.DataFrame) -> dict[str, torch.Tensor]: """Encode the dataframe using the encoders.""" - return { - col: self.encode_column(col, dataframe[col].to_list()) - for col in dataframe.columns - } + return {col: self.encode_column(col, dataframe[col].to_list()) for col in dataframe.columns} class TransformManager: @@ -250,9 +245,7 @@ def transform_column( list: The transformed data. bool: Whether the transformation added new rows to the data. """ - transformer = self.transform_loader.__getattribute__(column_name)[ - transform_name - ] + transformer = self.transform_loader.__getattribute__(column_name)[transform_name] return transformer.transform_all(column_data), transformer.add_row @@ -351,7 +344,9 @@ class DatasetProcessor(DatasetHandler): """Class for loading dataset, applying transformations and splitting.""" def __init__( - self, data_config: yaml_data.SplitTransformDict, csv_path: str + self, + data_config: yaml_data.SplitTransformDict, + csv_path: str, ) -> None: """Initialize the DatasetProcessor.""" super().__init__(data_config, csv_path) @@ -433,11 +428,7 @@ def __init__( """Initialize the DatasetLoader.""" super().__init__(data_config, csv_path) self.encoder_manager = EncodeManager(encoder_loader) - self.data = ( - self.load_csv_per_split(csv_path, split) - if split is not None - else self.load_csv(csv_path) - ) + self.data = self.load_csv_per_split(csv_path, split) if split is not None else self.load_csv(csv_path) def get_all_items(self) -> tuple[dict, dict, dict]: """Get the full dataset as three separate dictionaries for inputs, labels and metadata. diff --git a/src/stimulus/learner/raytune_learner.py b/src/stimulus/learner/raytune_learner.py index 5e9d9e3e..ec59beb2 100644 --- a/src/stimulus/learner/raytune_learner.py +++ b/src/stimulus/learner/raytune_learner.py @@ -1,7 +1,5 @@ """Ray Tune wrapper and trainable model classes for hyperparameter optimization.""" -from stimulus.utils.yaml_model_schema import RayTuneModel -from stimulus.utils.yaml_data import SplitTransformDict import datetime import logging import os @@ -22,6 +20,8 @@ from stimulus.data.loaders import EncoderLoader from stimulus.learner.predict import PredictWrapper from stimulus.utils.generic_utils import set_general_seeds +from stimulus.utils.yaml_data import SplitTransformDict +from stimulus.utils.yaml_model_schema import RayTuneModel class CheckpointDict(TypedDict): @@ -314,14 +314,8 @@ def objective(self) -> dict[str, float]: loss_dict=self.loss_dict, ) return { - **{ - "val_" + metric: value - for metric, value in predict_val.compute_metrics(metrics).items() - }, - **{ - "train_" + metric: value - for metric, value in predict_train.compute_metrics(metrics).items() - }, + **{"val_" + metric: value for metric, value in predict_val.compute_metrics(metrics).items()}, + **{"train_" + metric: value for metric, value in predict_train.compute_metrics(metrics).items()}, } # type: ignore[override] diff --git a/src/stimulus/utils/yaml_data.py b/src/stimulus/utils/yaml_data.py index 73fb6daa..aa7ee003 100644 --- a/src/stimulus/utils/yaml_data.py +++ b/src/stimulus/utils/yaml_data.py @@ -193,11 +193,7 @@ def expand_transform_parameter_combinations( for column in transform.columns: for transformation in column.transformations: if transformation.params: - list_lengths = [ - len(v) - for v in transformation.params.values() - if isinstance(v, list) and len(v) > 1 - ] + list_lengths = [len(v) for v in transformation.params.values() if isinstance(v, list) and len(v) > 1] if list_lengths: # All lists have same length due to validator max_length = list_lengths[0] @@ -312,7 +308,6 @@ def generate_split_transform_configs( length will be the product of the number of parameter combinations and the number of splits. """ - if isinstance(config, dict) and not isinstance(config, SplitConfigDict): raise TypeError("Input must be a list of SubConfigDict") @@ -403,30 +398,21 @@ def fix_params(input_dict: dict[str, Any]) -> dict[str, Any]: processed_dict[key] = [] for encoder in value: processed_encoder = dict(encoder) - if ( - "params" not in processed_encoder - or not processed_encoder["params"] - ): + if "params" not in processed_encoder or not processed_encoder["params"]: processed_encoder["params"] = {} processed_dict[key].append(processed_encoder) elif key == "transformations" and isinstance(value, list): processed_dict[key] = [] for transformation in value: processed_transformation = dict(transformation) - if ( - "params" not in processed_transformation - or not processed_transformation["params"] - ): + if "params" not in processed_transformation or not processed_transformation["params"]: processed_transformation["params"] = {} processed_dict[key].append(processed_transformation) elif isinstance(value, dict): processed_dict[key] = fix_params(value) elif isinstance(value, list): processed_dict[key] = [ - fix_params(list_item) - if isinstance(list_item, dict) - else list_item - for list_item in value + fix_params(list_item) if isinstance(list_item, dict) else list_item for list_item in value ] else: processed_dict[key] = value diff --git a/tests/cli/test_shuffle_csv.py b/tests/cli/test_shuffle_csv.py index 653287de..8fef803c 100644 --- a/tests/cli/test_shuffle_csv.py +++ b/tests/cli/test_shuffle_csv.py @@ -6,10 +6,8 @@ from typing import Any, Callable import pytest -import yaml from src.stimulus.cli.shuffle_csv import main -from src.stimulus.utils.yaml_data import SplitTransformDict # Fixtures diff --git a/tests/data/test_data_handlers.py b/tests/data/test_data_handlers.py index 23eb6585..2716ee32 100644 --- a/tests/data/test_data_handlers.py +++ b/tests/data/test_data_handlers.py @@ -332,19 +332,10 @@ def test_dataset_processor_apply_transformation_group( assert processor.data["age"].to_list() != processor_control.data["age"].to_list() assert processor.data["fare"].to_list() != processor_control.data["fare"].to_list() - assert ( - processor.data["parch"].to_list() == processor_control.data["parch"].to_list() - ) - assert ( - processor.data["sibsp"].to_list() == processor_control.data["sibsp"].to_list() - ) - assert ( - processor.data["pclass"].to_list() == processor_control.data["pclass"].to_list() - ) - assert ( - processor.data["embarked"].to_list() - == processor_control.data["embarked"].to_list() - ) + assert processor.data["parch"].to_list() == processor_control.data["parch"].to_list() + assert processor.data["sibsp"].to_list() == processor_control.data["sibsp"].to_list() + assert processor.data["pclass"].to_list() == processor_control.data["pclass"].to_list() + assert processor.data["embarked"].to_list() == processor_control.data["embarked"].to_list() assert processor.data["sex"].to_list() == processor_control.data["sex"].to_list()