Skip to content

Commit

Permalink
FIX: src/stimulus/loaders.py -> initialize_splitter_from_config uses …
Browse files Browse the repository at this point in the history
…a YamlSplitConfigDict as input now
  • Loading branch information
Julien Raynal authored and Julien Raynal committed Feb 18, 2025
1 parent cb7d55f commit 10092ee
Showing 1 changed file with 55 additions and 19 deletions.
74 changes: 55 additions & 19 deletions src/stimulus/data/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,17 @@ def __init__(self, seed: Optional[float] = None) -> None:
"""
self.seed = seed

def initialize_column_encoders_from_config(
    self, column_config: yaml_data.YamlColumns
) -> None:
    """Build the loader from a config dictionary.

    For every entry in ``column_config`` an encoder is obtained through
    ``get_encoder`` and registered on the loader with
    ``set_encoder_as_attribute`` under the entry's ``column_name``.

    Args:
        column_config (yaml_data.YamlColumns): Configuration dictionary containing field names (column_name) and their encoder specifications.
    """
    for field in column_config:
        # Only the first encoder entry of each column is used; any further
        # entries in ``field.encoder`` are ignored here.
        encoder_spec = field.encoder[0]
        encoder = self.get_encoder(encoder_spec.name, encoder_spec.params)
        self.set_encoder_as_attribute(field.column_name, encoder)

def get_function_encode_all(self, field_name: str) -> Any:
Expand All @@ -50,7 +53,9 @@ def get_function_encode_all(self, field_name: str) -> Any:
"""
return getattr(self, field_name).encode_all

def get_encoder(self, encoder_name: str, encoder_params: Optional[dict] = None) -> Any:
def get_encoder(
self, encoder_name: str, encoder_params: Optional[dict] = None
) -> Any:
"""Gets an encoder object from the encoders module and initializes it with the given parameters.
Args:
Expand All @@ -63,7 +68,9 @@ def get_encoder(self, encoder_name: str, encoder_params: Optional[dict] = None)
try:
return getattr(encoders, encoder_name)(**encoder_params)
except AttributeError:
logging.exception(f"Encoder '{encoder_name}' not found in the encoders module.")
logging.exception(
f"Encoder '{encoder_name}' not found in the encoders module."
)
logging.exception(
f"Available encoders: {[name for name, obj in encoders.__dict__.items() if isinstance(obj, type) and name not in ('ABC', 'Any')]}",
)
Expand All @@ -72,13 +79,17 @@ def get_encoder(self, encoder_name: str, encoder_params: Optional[dict] = None)
except TypeError:
if encoder_params is None:
return getattr(encoders, encoder_name)()
logging.exception(f"Encoder '{encoder_name}' has incorrect parameters: {encoder_params}")
logging.exception(
f"Encoder '{encoder_name}' has incorrect parameters: {encoder_params}"
)
logging.exception(
f"Expected parameters for '{encoder_name}': {inspect.signature(getattr(encoders, encoder_name))}",
)
raise

def set_encoder_as_attribute(self, field_name: str, encoder: encoders.AbstractEncoder) -> None:
def set_encoder_as_attribute(
self, field_name: str, encoder: encoders.AbstractEncoder
) -> None:
"""Sets the encoder as an attribute of the loader.
Args:
Expand All @@ -99,7 +110,9 @@ def __init__(self, seed: Optional[float] = None) -> None:
"""
self.seed = seed

def get_data_transformer(self, transformation_name: str, transformation_params: Optional[dict] = None) -> Any:
def get_data_transformer(
self, transformation_name: str, transformation_params: Optional[dict] = None
) -> Any:
"""Gets a transformer object from the transformers module.
Args:
Expand All @@ -110,9 +123,13 @@ def get_data_transformer(self, transformation_name: str, transformation_params:
Any: The transformer function for the specified transformation
"""
try:
return getattr(data_transformation_generators, transformation_name)(**transformation_params)
return getattr(data_transformation_generators, transformation_name)(
**transformation_params
)
except AttributeError:
logging.exception(f"Transformer '{transformation_name}' not found in the transformers module.")
logging.exception(
f"Transformer '{transformation_name}' not found in the transformers module."
)
logging.exception(
f"Available transformers: {[name for name, obj in data_transformation_generators.__dict__.items() if isinstance(obj, type) and name not in ('ABC', 'Any')]}",
)
Expand All @@ -121,13 +138,17 @@ def get_data_transformer(self, transformation_name: str, transformation_params:
except TypeError:
if transformation_params is None:
return getattr(data_transformation_generators, transformation_name)()
logging.exception(f"Transformer '{transformation_name}' has incorrect parameters: {transformation_params}")
logging.exception(
f"Transformer '{transformation_name}' has incorrect parameters: {transformation_params}"
)
logging.exception(
f"Expected parameters for '{transformation_name}': {inspect.signature(getattr(data_transformation_generators, transformation_name))}",
)
raise

def set_data_transformer_as_attribute(self, field_name: str, data_transformer: Any) -> None:
def set_data_transformer_as_attribute(
self, field_name: str, data_transformer: Any
) -> None:
"""Sets the data transformer as an attribute of the loader.
Args:
Expand All @@ -136,12 +157,18 @@ def set_data_transformer_as_attribute(self, field_name: str, data_transformer: A
"""
# check if the field already exists, if it does not, initialize it to an empty dict
if not hasattr(self, field_name):
setattr(self, field_name, {data_transformer.__class__.__name__: data_transformer})
setattr(
self,
field_name,
{data_transformer.__class__.__name__: data_transformer},
)
else:
field_value = getattr(self, field_name)
field_value[data_transformer.__class__.__name__] = data_transformer

def initialize_column_data_transformers_from_config(self, transform_config: yaml_data.YamlTransform) -> None:
def initialize_column_data_transformers_from_config(
self, transform_config: yaml_data.YamlTransform
) -> None:
"""Build the loader from a config dictionary.
Args:
Expand Down Expand Up @@ -174,7 +201,9 @@ def initialize_column_data_transformers_from_config(self, transform_config: yaml
for column in transform_config.columns:
col_name = column.column_name
for transform_spec in column.transformations:
transformer = self.get_data_transformer(transform_spec.name, transform_spec.params)
transformer = self.get_data_transformer(
transform_spec.name, transform_spec.params
)
self.set_data_transformer_as_attribute(col_name, transformer)


Expand Down Expand Up @@ -206,7 +235,9 @@ def get_function_split(self) -> Any:
)
return self.split.get_split_indexes

def get_splitter(self, splitter_name: str, splitter_params: Optional[dict] = None) -> Any:
def get_splitter(
self, splitter_name: str, splitter_params: Optional[dict] = None
) -> Any:
"""Gets a splitter object from the splitters module.
Args:
Expand All @@ -221,7 +252,9 @@ def get_splitter(self, splitter_name: str, splitter_params: Optional[dict] = Non
except TypeError:
if splitter_params is None:
return getattr(splitters, splitter_name)()
logging.exception(f"Splitter '{splitter_name}' has incorrect parameters: {splitter_params}")
logging.exception(
f"Splitter '{splitter_name}' has incorrect parameters: {splitter_params}"
)
logging.exception(
f"Expected parameters for '{splitter_name}': {inspect.signature(getattr(splitters, splitter_name))}",
)
Expand All @@ -235,11 +268,14 @@ def set_splitter_as_attribute(self, splitter: Any) -> None:
"""
self.split = splitter

def initialize_splitter_from_config(
    self, split_config: yaml_data.YamlSplitConfigDict
) -> None:
    """Build the loader from a config dictionary.

    The splitter is obtained through ``get_splitter`` from the config's
    ``split_method`` and ``params`` attributes, then stored on the loader
    with ``set_splitter_as_attribute``.

    Args:
        split_config (yaml_data.YamlSplitConfigDict): Configuration dictionary containing split configurations.
    """
    splitter = self.get_splitter(split_config.split_method, split_config.params)
    self.set_splitter_as_attribute(splitter)

0 comments on commit 10092ee

Please sign in to comment.