
Commit

Merge pull request #30 from mathysgrapotte/yaml-refactor-auto-class-build

Yaml refactor auto class build
mathysgrapotte authored Jan 15, 2025
2 parents 34016e7 + 65613eb commit 6e0e7b9
Showing 13 changed files with 1,513 additions and 220 deletions.
403 changes: 313 additions & 90 deletions src/stimulus/data/csv.py

Large diffs are not rendered by default.

285 changes: 215 additions & 70 deletions src/stimulus/data/experiments.py
@@ -1,83 +1,228 @@
"""Experiments are classes parsed by CSV master classes to run experiments.
Conceptually, experiment classes contain data types, transformations etc and are used to duplicate the input data into many datasets.
Here we provide standard experiments as well as an absctract class for users to implement their own.
"""Loaders serve as interfaces between the CSV master class and custom methods.
Mainly, three types of custom methods are supported:
- Encoders: methods for encoding data before it is fed into the model
- Data transformers: methods for transforming data (i.e. augmenting, noising...)
- Splitters: methods for splitting data into train, validation and test sets
# TODO implement noise schemes and splitting schemes.
Loaders are built from an input config YAML file which format is described in the documentation, you can find an example here: tests/test_data/dna_experiment/dna_experiment_config_template.yaml
"""

from abc import ABC
from typing import Any
from collections import defaultdict

from .encoding import encoders as encoders
from .splitters import splitters as splitters
from .transform import data_transformation_generators as data_transformation_generators

import inspect
import yaml

class AbstractExperiment(ABC):
"""Abstract class for experiments.

WARNING, DATA_TYPES ARGUMENT NAMES SHOULD BE ALL LOWERCASE, CHECK THE DATA_TYPES MODULE FOR THE TYPES THAT HAVE BEEN IMPLEMENTED.
"""

from stimulus.data.encoding import encoders as encoders
from stimulus.data.splitters import splitters as splitters
from stimulus.data.transform import data_transformation_generators as data_transformation_generators
from stimulus.utils.yaml_data import YamlConfigDict

class AbstractLoader(ABC):
"""Abstract base class for defining loaders."""

def get_config_from_yaml(self, yaml_path: str) -> YamlConfigDict:
"""Loads the loader configuration from a YAML file.
Args:
yaml_path (str): Path to the YAML config file
Returns:
YamlConfigDict: The loaded and validated configuration
"""
with open(yaml_path, "r") as file:
config = YamlConfigDict(**yaml.safe_load(file))
return config
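A minimal usage sketch for the YAML loading step (any concrete loader, such as the EncoderLoader defined next, inherits this method; the template path comes from the module docstring):

loader = EncoderLoader()
config = loader.get_config_from_yaml(
    "tests/test_data/dna_experiment/dna_experiment_config_template.yaml"
)  # parsed with yaml.safe_load and validated into a YamlConfigDict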

class EncoderLoader(AbstractLoader):
"""Class for loading encoders from a config file."""

def __init__(self, seed: float = None) -> None:
self.seed = seed

def initialize_column_encoders_from_config(self, config: YamlConfigDict) -> None:
"""Build the loader from a config dictionary.
Args:
config (YamlConfigDict): Configuration dictionary containing field names (column_name) and their encoder specifications.
"""
for field in config:
encoder = self.get_encoder(field.encoder[0].name, field.encoder[0].params)
self.set_encoder_as_attribute(field.column_name, encoder)

def get_function_encode_all(self, field_name: str) -> Any:
"""Gets the encoding function for a specific field.
Args:
field_name (str): The field name to get the encoder for
Returns:
Any: The encode_all function for the specified field
"""
return getattr(self, field_name)["encoder"].encode_all

def get_encoder(self, encoder_name: str, encoder_params: dict = None) -> Any:
"""Gets an encoder object from the encoders module and initializes it with the given parametersß.
Args:
encoder_name (str): The name of the encoder to get
encoder_params (dict): The parameters for the encoder
Returns:
Any: The encoder function for the specified field and parameters
"""

try:
return getattr(encoders, encoder_name)(**encoder_params)
except AttributeError:
print(f"Encoder '{encoder_name}' not found in the encoders module.")
print(f"Available encoders: {[name for name, obj in encoders.__dict__.items() if isinstance(obj, type) and name not in ('ABC', 'Any')]}")
raise

except TypeError:
if encoder_params is None:
return getattr(encoders, encoder_name)()
else:
print(f"Encoder '{encoder_name}' has incorrect parameters: {encoder_params}")
print(f"Expected parameters for '{encoder_name}': {inspect.signature(getattr(encoders, encoder_name))}")
raise
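For instance, with the TextOneHotEncoder that appears further down in this diff, a call could look like this (a sketch; it assumes that encoder accepts an alphabet parameter, as the removed DnaToFloatExperiment below suggests):

encoder = EncoderLoader().get_encoder("TextOneHotEncoder", {"alphabet": "acgt"})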

def set_encoder_as_attribute(self, field_name: str, encoder: encoders.AbstractEncoder) -> None:
"""Sets the encoder as an attribute of the loader.
Args:
field_name (str): The name of the field to set the encoder for
encoder (encoders.AbstractEncoder): The encoder to set
"""
setattr(self, field_name, {"encoder": encoder})
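Putting the EncoderLoader pieces together, a hedged end-to-end sketch (the columns sub-config and the "dna" column name are illustrative, not confirmed by this diff):

encoder_loader = EncoderLoader()
# assuming the per-column section of the validated config is what gets passed here
encoder_loader.initialize_column_encoders_from_config(config.columns)
encode_all = encoder_loader.get_function_encode_all("dna")  # the encoder's encode_all
encoded = encode_all(["acgt", "ttga"])  # accepted input type depends on the encoder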

class TransformLoader(AbstractLoader):
"""Class for loading transformations from a config file."""

def __init__(self, seed: float = None) -> None:
# allow a seed to be set for reproducibility
self.seed = seed
# set a default so that get_function_split does not crash when the user defines no split; a random split works for every class after all
self.split = {"RandomSplitter": splitters.RandomSplitter()}

def get_function_encode_all(self, data_type: str) -> Any:
"""This method gets the encoding function for a specific data type."""
return getattr(self, data_type)["encoder"].encode_all

def get_data_transformer(self, data_type: str, transformation_generator: str) -> Any:
"""This method transforms the data (noising, data augmentation etc)."""
return getattr(self, data_type)["data_transformation_generators"][transformation_generator]

def get_function_split(self, split_method: str) -> Any:
"""This method returns the function for splitting the data."""
return self.split[split_method].get_split_indexes


class DnaToFloatExperiment(AbstractExperiment):
"""Class for dealing with DNA to float predictions (for instance regression from DNA sequence to CAGE value)"""

def __init__(self) -> None:
super().__init__()
self.dna = {
"encoder": encoders.TextOneHotEncoder(alphabet="acgt"),
"data_transformation_generators": {
"UniformTextMasker": data_transformation_generators.UniformTextMasker(mask="N"),
"ReverseComplement": data_transformation_generators.ReverseComplement(),
"GaussianChunk": data_transformation_generators.GaussianChunk(),
},
}
self.float = {
"encoder": encoders.FloatEncoder(),
"data_transformation_generators": {"GaussianNoise": data_transformation_generators.GaussianNoise()},
}
self.split = {"RandomSplitter": splitters.RandomSplitter()}


class ProtDnaToFloatExperiment(DnaToFloatExperiment):
"""Class for dealing with Protein and DNA to float predictions (for instance regression from Protein sequence + DNA sequence to binding score)"""

def __init__(self) -> None:
super().__init__()
self.prot = {
"encoder": encoders.TextOneHotEncoder(alphabet="acdefghiklmnpqrstvwy"),
"data_transformation_generators": {
"UniformTextMasker": data_transformation_generators.UniformTextMasker(mask="X"),
},
}


class TitanicExperiment(AbstractExperiment):
"""Class for dealing with the Titanic dataset as a test format."""

def __init__(self) -> None:
super().__init__()
self.int_class = {"encoder": encoders.IntEncoder(), "data_transformation_generators": {}}
self.str_class = {"encoder": encoders.StrClassificationIntEncoder(), "data_transformation_generators": {}}
self.int_reg = {"encoder": encoders.IntRankEncoder(), "data_transformation_generators": {}}
self.float_rank = {"encoder": encoders.FloatRankEncoder(), "data_transformation_generators": {}}

def get_data_transformer(self, transformation_name: str, transformation_params: dict = None) -> Any:
"""Gets a transformer object from the transformers module.
Args:
transformation_name (str): The name of the transformer to get
transformation_params (dict): Optional parameters for the transformer
Returns:
Any: The transformer function for the specified transformation
"""
try:
return getattr(data_transformation_generators, transformation_name)(**transformation_params)
except AttributeError:
print(f"Transformer '{transformation_name}' not found in the transformers module.")
print(f"Available transformers: {[name for name, obj in data_transformation_generators.__dict__.items() if isinstance(obj, type) and name not in ('ABC', 'Any')]}")
raise

except TypeError:
if transformation_params is None:
return getattr(data_transformation_generators, transformation_name)()
else:
print(f"Transformer '{transformation_name}' has incorrect parameters: {transformation_params}")
print(f"Expected parameters for '{transformation_name}': {inspect.signature(getattr(data_transformation_generators, transformation_name))}")
raise

def set_data_transformer_as_attribute(self, field_name: str, data_transformer: Any) -> None:
"""Sets the data transformer as an attribute of the loader.
Args:
field_name (str): The name of the field to set the data transformer for
data_transformer (Any): The data transformer to set
"""
setattr(self, field_name, {"data_transformation_generators": data_transformer})

def initialize_column_data_transformers_from_config(self, config: YamlConfigDict) -> None:
"""Build the loader from a config dictionary.
Args:
config (YamlConfigDict): Configuration dictionary containing transforms configurations.
Each transform can specify multiple columns and their transformations.
The method will organize transformers by column, ensuring each column
has all its required transformations.
"""

# Use defaultdict to automatically initialize empty lists
column_transformers = defaultdict(list)

# First pass: collect all transformations by column
for transform_group in config:
for column in transform_group.columns:
col_name = column.column_name

# Process each transformation for this column
for transform_spec in column.transformations:
# Create transformer instance
transformer = self.get_data_transformer(transform_spec.name)

# Get transformer class for comparison
transformer_type = type(transformer)

# Add transformer if its type isn't already present
if not any(isinstance(existing, transformer_type)
for existing in column_transformers[col_name]):
column_transformers[col_name].append(transformer)

# Second pass: set all collected transformers as attributes
for col_name, transformers in column_transformers.items():
self.set_data_transformer_as_attribute(col_name, transformers)
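A sketch of the resulting state (the transforms sub-config key is an assumption; the transform names appear elsewhere in this diff). Note that the second pass stores, per column, a list holding at most one instance of each transformer class:

transform_loader = TransformLoader(seed=42)
transform_loader.initialize_column_data_transformers_from_config(config.transforms)
# each configured column now carries its de-duplicated transformer list:
dna_transforms = getattr(transform_loader, "dna")["data_transformation_generators"]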

class SplitLoader(AbstractLoader):
"""Class for loading splitters from a config file."""

def __init__(self, seed: float = None) -> None:
self.seed = seed

def get_function_split(self) -> Any:
"""Gets the function for splitting the data.
Returns:
Any: The get_split_indexes function of the configured splitter
"""
return self.split.get_split_indexes

def get_splitter(self, splitter_name: str, splitter_params: dict = None) -> Any:
"""Gets a splitter object from the splitters module.
Args:
splitter_name (str): The name of the splitter to get
Returns:
Any: The splitter function for the specified splitter
"""
try:
return getattr(splitters, splitter_name)(**splitter_params)
except TypeError:
if splitter_params is None:
return getattr(splitters, splitter_name)()
else:
print(f"Splitter '{splitter_name}' has incorrect parameters: {splitter_params}")
print(f"Expected parameters for '{splitter_name}': {inspect.signature(getattr(splitters, splitter_name))}")
raise

def set_splitter_as_attribute(self, splitter: Any) -> None:
"""Sets the splitter as an attribute of the loader.
Args:
splitter (Any): The splitter to set
"""
setattr(self, "split", splitter)

def initialize_splitter_from_config(self, config: YamlConfigDict, split_index: int = 0) -> None:
"""Build the loader from a config dictionary.
Args:
config (YamlConfigDict): Configuration dictionary containing split configurations.
split_index (int): Index of the split configuration to use; defaults to 0.
"""
splitter = self.get_splitter(config.split[split_index].split_method, config.split[split_index].params)
self.set_splitter_as_attribute(splitter)
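And a corresponding SplitLoader sketch (a hedged example; the data is assumed to be a polars DataFrame, as in the splitters module below):

import polars as pl

split_loader = SplitLoader(seed=0)
split_loader.initialize_splitter_from_config(config)  # uses config.split[0] by default
get_split_indexes = split_loader.get_function_split()
train_idx, val_idx, test_idx = get_split_indexes(pl.DataFrame({"x": range(10)}))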
3 changes: 3 additions & 0 deletions src/stimulus/data/splitters/__init__.py
@@ -0,0 +1,3 @@
from .splitters import AbstractSplitter, RandomSplit

__all__ = ["AbstractSplitter", "RandomSplit"]
36 changes: 18 additions & 18 deletions src/stimulus/data/splitters/splitters.py
@@ -17,15 +17,17 @@ class AbstractSplitter(ABC):
distance: calculates the distance between two elements of the data
"""

def __init__(self, seed: float = 42) -> None:
self.seed = seed

@abstractmethod
def get_split_indexes(self, data: pl.DataFrame, seed: float = None) -> list:
def get_split_indexes(self, data: pl.DataFrame) -> list:
"""Splits the data. Always return indices mapping to the original list.
This is an abstract method that should be implemented by the child class.
Args:
data (pl.DataFrame): the data to be split
seed (float): the seed for reproducibility
Returns:
split_indices (list): the indices for train, validation, and test sets
@@ -48,55 +50,53 @@ def distance(self, data_one: Any, data_two: Any) -> float:
raise NotImplementedError


class RandomSplitter(AbstractSplitter):
class RandomSplit(AbstractSplitter):
"""This splitter randomly splits the data."""

def __init__(self) -> None:
def __init__(self, split: list = [0.7, 0.2, 0.1], seed: float = None) -> None:
super().__init__()
self.split = split
self.seed = seed
if len(self.split) != 3:
raise ValueError(
"The split argument should be a list with length 3 that contains the proportions for [train, validation, test] splits.",
)

def get_split_indexes(
self,
data: pl.DataFrame,
split: list = [0.7, 0.2, 0.1],
seed: float = None,
) -> tuple[list, list, list]:
"""Splits the data indices into train, validation, and test sets.
One can use these lists of indices to parse the data afterwards.
Args:
data (pl.DataFrame): The data loaded with polars.
split (list): The proportions for [train, validation, test] splits.
seed (float): The seed for reproducibility.
Returns:
train (list): The indices for the training set.
validation (list): The indices for the validation set.
test (list): he indices for the test set.
test (list): The indices for the test set.
Raises:
ValueError: If the split argument is not a list with length 3.
ValueError: If the sum of the split proportions is not 1.
"""
if len(split) != 3:
raise ValueError(
"The split argument should be a list with length 3 that contains the proportions for [train, validation, test] splits.",
)
# Use round to avoid errors due to floating-point imprecision
if round(sum(split), 3) < 1.0:
raise ValueError(f"The sum of the split proportions should be 1. Instead, it is {sum(split)}.")
if round(sum(self.split), 3) < 1.0:
raise ValueError(f"The sum of the split proportions should be 1. Instead, it is {sum(self.split)}.")

# compute the length of the data
length_of_data = len(data)

# Generate a list of indices and shuffle it
indices = np.arange(length_of_data)
np.random.seed(seed)
np.random.seed(self.seed)
np.random.shuffle(indices)

# Calculate the sizes of the train, validation, and test sets
train_size = int(split[0] * length_of_data)
validation_size = int(split[1] * length_of_data)
train_size = int(self.split[0] * length_of_data)
validation_size = int(self.split[1] * length_of_data)

# Split the shuffled indices according to the calculated sizes
train = indices[:train_size].tolist()
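As a worked example of the size arithmetic above (a sketch using the default proportions; polars imported as pl, as elsewhere in this module):

# With 10 rows and split = [0.7, 0.2, 0.1]:
#   train_size      = int(0.7 * 10) = 7
#   validation_size = int(0.2 * 10) = 2
#   test receives the remaining 10 - 7 - 2 = 1 index
splitter = RandomSplit(seed=42)
train, validation, test = splitter.get_split_indexes(pl.DataFrame({"x": range(10)}))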
