Skip to content

Commit

Permalink
Merge branch 'dev' into split_behaviour
Browse files Browse the repository at this point in the history
  • Loading branch information
Julien Raynal authored and Julien Raynal committed Feb 18, 2025
2 parents 10092ee + e5d1d1e commit 417aadd
Showing 1 changed file with 35 additions and 32 deletions.
67 changes: 35 additions & 32 deletions src/stimulus/utils/yaml_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
from pydantic import BaseModel, ValidationError, field_validator


class YamlGlobalParams(BaseModel):
class GlobalParams(BaseModel):
"""Model for global parameters in YAML configuration."""

seed: int


class YamlColumnsEncoder(BaseModel):
class ColumnsEncoder(BaseModel):
"""Model for column encoder configuration."""

name: str
Expand All @@ -21,16 +21,16 @@ class YamlColumnsEncoder(BaseModel):
] # Allow both string and list values


class YamlColumns(BaseModel):
class Columns(BaseModel):
"""Model for column configuration."""

column_name: str
column_type: str
data_type: str
encoder: list[YamlColumnsEncoder]
encoder: list[ColumnsEncoder]


class YamlTransformColumnsTransformation(BaseModel):
class TransformColumnsTransformation(BaseModel):
"""Model for column transformation configuration."""

name: str
Expand All @@ -39,18 +39,18 @@ class YamlTransformColumnsTransformation(BaseModel):
] # Allow both list and float values


class YamlTransformColumns(BaseModel):
class TransformColumns(BaseModel):
"""Model for transform columns configuration."""

column_name: str
transformations: list[YamlTransformColumnsTransformation]
transformations: list[TransformColumnsTransformation]


class YamlTransform(BaseModel):
class Transform(BaseModel):
"""Model for transform configuration."""

transformation_name: str
columns: list[YamlTransformColumns]
columns: list[TransformColumns]

@field_validator("columns")
@classmethod
Expand Down Expand Up @@ -94,21 +94,21 @@ def validate_param_lists_across_columns(
return columns


class YamlSplit(BaseModel):
class Split(BaseModel):
"""Model for split configuration."""

split_method: str
params: dict[str, list[float]] # More specific type for split parameters
split_input_columns: list[str]


class YamlConfigDict(BaseModel):
class ConfigDict(BaseModel):
"""Model for main YAML configuration."""

global_params: YamlGlobalParams
columns: list[YamlColumns]
transforms: list[YamlTransform]
split: list[YamlSplit]
global_params: GlobalParams
columns: list[Columns]
transforms: list[Transform]
split: list[Split]


# TODO: Rename this class to SplitConfigDict
Expand All @@ -124,16 +124,16 @@ class YamlSplitConfigDict(BaseModel):
class YamlSplitTransformDict(BaseModel):
"""Model for sub-configuration generated from main config."""

global_params: YamlGlobalParams
columns: list[YamlColumns]
transforms: YamlTransform
split: YamlSplit
global_params: GlobalParams
columns: list[Columns]
transforms: list[Transform]
split: Split


class YamlSchema(BaseModel):
class Schema(BaseModel):
"""Model for validating YAML schema."""

yaml_conf: YamlConfigDict
yaml_conf: ConfigDict


class YamlSplitSchema(BaseModel):
Expand All @@ -155,7 +155,7 @@ def extract_transform_parameters_at_index(
A new transform with single parameter values at the specified index
"""
# Create a copy of the transform
new_transform = YamlTransform(**transform.model_dump())
new_transform = Transform(**transform.model_dump())

# Process each column and transformation
for column in new_transform.columns:
Expand Down Expand Up @@ -223,18 +223,19 @@ def expand_transform_list_combinations(
create two transforms: one with 0.1/1 and another with 0.2/2.
Args:
transform_list: A list of YamlTransform objects containing parameter lists
transform_list: A list of Transform objects containing parameter lists
that need to be expanded into individual transforms.
Returns:
list[YamlTransform]: A flattened list of transforms where each transform
list[Transform]: A flattened list of transforms where each transform
has single parameter values instead of parameter lists. The length of
the returned list will be the sum of the number of parameter combinations
for each input transform.
"""
sub_transforms = []
for transform in transform_list:
sub_transforms.extend(expand_transform_parameter_combinations(transform))
sub_transforms.extend(
expand_transform_parameter_combinations(transform))
return sub_transforms


Expand Down Expand Up @@ -266,8 +267,8 @@ def generate_split_configs(yaml_config: YamlConfigDict) -> list[YamlSplitConfigD
length will be the product of the number of parameter combinations
and the number of splits.
"""
if isinstance(yaml_config, dict) and not isinstance(yaml_config, YamlConfigDict):
raise TypeError("Input must be a YamlConfigDict object")
if isinstance(yaml_config, dict) and not isinstance(yaml_config, ConfigDict):
raise TypeError("Input must be a ConfigDict object")

sub_splits = yaml_config.split
sub_configs = []
Expand Down Expand Up @@ -413,7 +414,8 @@ def fix_params(input_dict: dict[str, Any]) -> dict[str, Any]:
or not processed_transformation["params"]
):
processed_transformation["params"] = {}
processed_dict[key].append(processed_transformation)
processed_dict[key].append(
processed_transformation)
elif isinstance(value, dict):
processed_dict[key] = fix_params(value)
elif isinstance(value, list):
Expand Down Expand Up @@ -442,14 +444,14 @@ def fix_params(input_dict: dict[str, Any]) -> dict[str, Any]:
)


def check_yaml_schema(config_yaml: YamlConfigDict) -> str:
def check_yaml_schema(config_yaml: ConfigDict) -> str:
"""Validate YAML configuration fields have correct types.
If a child field is specific to a parent, the child field's class is hosted in the parent field's class.
If any field is not the right type, the function prints an error message explaining the problem and exits the Python code.
Args:
config_yaml: The YamlConfigDict containing the fields of the yaml configuration file
config_yaml: The ConfigDict containing the fields of the yaml configuration file
Returns:
str: Empty string if validation succeeds
Expand All @@ -458,8 +460,9 @@ def check_yaml_schema(config_yaml: YamlConfigDict) -> str:
ValueError: If validation fails
"""
try:
YamlSchema(yaml_conf=config_yaml)
Schema(yaml_conf=config_yaml)
except ValidationError as e:
# Use logging instead of print for error handling
raise ValueError("Wrong type on a field, see the pydantic report above") from e
raise ValueError(
"Wrong type on a field, see the pydantic report above") from e
return ""

0 comments on commit 417aadd

Please sign in to comment.