diff --git a/src/stimulus/data/csv.py b/src/stimulus/data/csv.py index a1df1000..ad3a36cf 100644 --- a/src/stimulus/data/csv.py +++ b/src/stimulus/data/csv.py @@ -12,102 +12,269 @@ from functools import partial from typing import Any, Tuple, Union +from abc import ABC import numpy as np import polars as pl +import yaml +import stimulus.data.experiments as experiments +import stimulus.utils.yaml_data as yaml_data +import torch + +class DatasetManager: + """Class for managing the dataset. + + This class handles loading and organizing dataset configuration from YAML files. + It manages column categorization into input, label and meta types based on the config. + + Attributes: + config (dict): The loaded configuration dictionary from YAML + column_categories (dict): Dictionary mapping column types to lists of column names + + Methods: + _load_config(config_path: str) -> dict: Loads the config from a YAML file. + categorize_columns_by_type() -> dict: Organizes the columns into input, label, meta based on the config. + """ + def __init__(self, + config_path: str, + ) -> None: + self.config = self._load_config(config_path) + self.column_categories = self.categorize_columns_by_type() + + def categorize_columns_by_type(self) -> dict: + """Organizes columns from config into input, label, and meta categories. + + Reads the column definitions from the config and sorts them into categories + based on their column_type field. + + Returns: + dict: Dictionary containing lists of column names for each category: + { + "input": ["col1", "col2"], # Input columns + "label": ["target"], # Label/output columns + "meta": ["id"] # Metadata columns + } + + Example: + >>> manager = DatasetManager("config.yaml") + >>> categories = manager.categorize_columns_by_type() + >>> print(categories) + { + 'input': ['hello', 'bonjour'], + 'label': ['ciao'], + 'meta': ["id"] + } + """ + input_columns = [] + label_columns = [] + meta_columns = [] + for column in self.config.columns: + if column.column_type == "input": + input_columns.append(column.column_name) + elif column.column_type == "label": + label_columns.append(column.column_name) + elif column.column_type == "meta": + meta_columns.append(column.column_name) + + return {"input": input_columns, "label": label_columns, "meta": meta_columns} + + def _load_config(self, config_path: str) -> dict: + """Loads and parses a YAML configuration file. -class CsvHandler: - """Meta class for handling CSV files.""" + Args: + config_path (str): Path to the YAML config file - def __init__(self, experiment: Any, csv_path: str) -> None: - self.experiment = experiment - self.csv_path = csv_path - self.categories = self.check_and_get_categories() - self.check_compulsory_categories_exist() + Returns: + dict: Parsed configuration dictionary - def check_and_get_categories(self) -> list: - """Returns the categories contained in the csv file.""" - with open(self.csv_path) as f: - header = f.readline().strip().split(",") - categories = [] - for colname in header: - category = colname.split(":")[1].lower() - if category not in ["input", "label", "split", "meta"]: - raise ValueError( - f"Unknown category {category}, category (the second element of the csv column, seperated by ':') should be input, label, split or meta. 
The specified csv column is {colname}.", - ) - categories.append(category) - return categories + Example: + >>> manager = DatasetManager() + >>> config = manager._load_config("config.yaml") + >>> print(config["columns"][0]["column_name"]) + 'hello' + """ + with open(config_path, "r") as file: + return yaml_data.YamlConfigDict(**yaml.safe_load(file)) + + def get_split_columns(self) -> str: + """Get the columns that are used for splitting.""" + return self.config.split.split_input_columns + + +class EncodeManager: + """Manages the encoding of data columns using configured encoders. + + This class handles encoding of data columns based on the encoders specified in the + configuration. It uses an EncoderLoader to get the appropriate encoder for each column + and applies the encoding. + + Attributes: + encoder_loader (experiments.EncoderLoader): Loader that provides encoders based on config. + + Example: + >>> encoder_loader = EncoderLoader(config) + >>> encode_manager = EncodeManager(encoder_loader) + >>> data = ["ACGT", "TGCA", "GCTA"] + >>> encoded = encode_manager.encode_column("dna_seq", data) + >>> print(encoded.shape) + torch.Size([3, 4, 4]) # 3 sequences, length 4, one-hot encoded + """ - def update_categories(self) -> None: - """Updates the categories of the csv file. - Checks colnames in header and updates the categories that are present. + def __init__(self, + encoder_loader: experiments.EncoderLoader, + ) -> None: + """Initializes the EncodeManager. + + Args: + encoder_loader: Loader that provides encoders based on configuration. """ - for colname in self.data.columns: - category = colname.split(":")[1].lower() - if category not in self.categories: - self.categories.append(category) - - def extract_header(self) -> list: - """Extracts the header of the csv file.""" - with open(self.csv_path) as f: - header = f.readline().strip().split(",") - return header + self.encoder_loader = encoder_loader - def get_keys_from_header( - self, - header, - column_name: str = None, - category: str = None, - data_type: str = None, - ) -> list: - keys = [] - for key in header: - current_name, current_category, current_dtype = key.split(":") - if ( - (column_name is None or column_name == current_name) - and (category is None or category == current_category) - and (data_type is None or data_type == current_dtype) - ): - keys.append(key) - if len(keys) == 0: - raise ValueError( - f"No keys found with the specified column_name={column_name}, category={category}, data_type={data_type}", - ) - return keys - - def get_keys_based_on_name_category_dtype( - self, - column_name: str = None, - category: str = None, - data_type: str = None, - ) -> list: - """Returns the keys that are of a specific type, name or category. 
Or a combination of those.""" - if (column_name is None) and (category is None) and (data_type is None): - raise ValueError("At least one of the arguments column_name, category or data_type should be provided") - header = self.extract_header() - keys = self.get_keys_from_header(header, column_name, category, data_type) - return keys - - def check_compulsory_categories_exist(self) -> None: - """Checks if the compulsory categories exist in the csv file.""" - if "input" not in self.categories: - raise ValueError("The category input is not present in the csv file") - - def load_csv(self) -> pl.DataFrame: - """Loads the csv file into a polars dataframe.""" - return pl.read_csv(self.csv_path) + def encode_column(self, column_name: str, column_data: list) -> torch.Tensor: + """Encodes a column of data using the configured encoder. + Gets the appropriate encoder for the column from the encoder_loader and uses it + to encode all the data in the column. -class CsvProcessing(CsvHandler): - """Class to load the input csv data and add noise accordingly.""" + Args: + column_name: Name of the column to encode. + column_data: List of data values from the column to encode. - def __init__(self, experiment: Any, csv_path: str) -> None: - super().__init__(experiment, csv_path) - self.data = self.load_csv() + Returns: + Encoded data as a torch.Tensor. The exact shape depends on the encoder used. + + Example: + >>> data = ["ACGT", "TGCA"] + >>> encoded = encode_manager.encode_column("dna_seq", data) + >>> print(encoded.shape) + torch.Size([2, 4, 4]) # 2 sequences, length 4, one-hot encoded + """ + encode_all_function = self.encoder_loader.get_function_encode_all(column_name) + return encode_all_function(column_data) + + def encode_columns(self, column_data: dict) -> dict: + """Encodes multiple columns of data using the configured encoders. + + Gets the appropriate encoder for each column from the encoder_loader and encodes + all data values in those columns. + + Args: + column_data: Dict mapping column names to lists of data values to encode. + + Returns: + Dict mapping column names to their encoded tensors. The exact shape of each + tensor depends on the encoder used for that column. + + Example: + >>> data = { + ... "dna_seq": ["ACGT", "TGCA"], + ... "labels": ["1", "2"] + ... } + >>> encoded = encode_manager.encode_columns(data) + >>> print(encoded["dna_seq"].shape) + torch.Size([2, 4, 4]) # 2 sequences, length 4, one-hot encoded + """ + return {col: self.encode_column(col, values) for col, values in column_data.items()} + +class TransformManager: + """Class for managing the transformations.""" + + def __init__(self, + transform_loader: experiments.TransformLoader, + ) -> None: + self.transform_loader = transform_loader + +class SplitManager: + """Class for managing the splitting.""" + + def __init__(self, + split_loader: experiments.SplitLoader, + ) -> None: + self.split_loader = split_loader + + def get_split_indices(self, data: dict) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Get the indices for train, validation, and test splits.""" + return self.split_loader.get_function_split()(data) + +class DatasetHandler: + """Main class for handling dataset loading, encoding, transformation and splitting. + + This class coordinates the interaction between different managers to process + CSV datasets according to the provided configuration. + + Attributes: + encoder_manager (EncodeManager): Manager for handling data encoding operations. 
+ transform_manager (TransformManager): Manager for handling data transformations. + split_manager (SplitManager): Manager for handling dataset splitting. + dataset_manager (DatasetManager): Manager for organizing dataset columns and config. + """ - def add_split(self, config: dict, force=False) -> None: + def __init__(self, + encoder_loader: experiments.EncoderLoader, + transform_loader: experiments.TransformLoader, + split_loader: experiments.SplitLoader, + config_path: str, + csv_path: str, + ) -> None: + """Initialize the DatasetHandler with required loaders and config. + + Args: + encoder_loader (experiments.EncoderLoader): Loader for getting column encoders. + transform_loader (experiments.TransformLoader): Loader for getting data transformations. + split_loader (experiments.SplitLoader): Loader for getting dataset split configurations. + config_path (str): Path to the dataset configuration file. + csv_path (str): Path to the CSV data file. + """ + self.encoder_manager = EncodeManager(encoder_loader) + self.transform_manager = TransformManager(transform_loader) + self.split_manager = SplitManager(split_loader) + self.dataset_manager = DatasetManager(config_path) + self.data = self.load_csv(csv_path) + self.columns = self.read_csv_header(csv_path) + + def read_csv_header(self, csv_path: str) -> list: + """Get the column names from the header of the CSV file. + + Args: + csv_path (str): Path to the CSV file to read headers from. + + Returns: + list: List of column names from the CSV header. + """ + with open(csv_path) as f: + header = f.readline().strip().split(",") + return header + + def load_csv(self, csv_path: str) -> pl.DataFrame: + """Load the CSV file into a polars DataFrame. + + Args: + csv_path (str): Path to the CSV file to load. + + Returns: + pl.DataFrame: Polars DataFrame containing the loaded CSV data. + """ + return pl.read_csv(csv_path) + + def select_columns(self, columns: list) -> dict: + """Select specific columns from the DataFrame and return as a dictionary. + + Args: + columns (list): List of column names to select. + + Returns: + dict: A dictionary where keys are column names and values are lists containing the column data. + + Example: + >>> handler = DatasetHandler(...) + >>> data_dict = handler.select_columns(["col1", "col2"]) + >>> # Returns {'col1': [1, 2, 3], 'col2': [4, 5, 6]} + """ + df = self.data.select(columns) + return {col: df[col].to_list() for col in columns} + + def add_split(self, force=False) -> None: """Add a column specifying the train, validation, test splits of the data. An error exception is raised if the split column is already present in the csv file. This behaviour can be overriden by setting force=True. @@ -117,24 +284,80 @@ def add_split(self, config: dict, force=False) -> None: "parameters" (dict) : the split_function specific optional parameters, passed here as a dict with keys named as in the split function definition. force (bool) : If True, the split column will be added even if it is already present in the csv file. """ - if ("split" in self.categories) and (not force): + if ("split" in self.columns) and (not force): raise ValueError( "The category split is already present in the csv file. 
If you want to still use this function, set force=True", ) + # get relevant split columns from the dataset_manager + split_columns = self.dataset_manager.get_split_columns() + + # if split_columns is none, build an empty dictionary + if split_columns is None: + split_input_data = {} + else: + split_input_data = self.select_columns(split_columns) - # set the split name method - split_method = config["name"] - - # get the indices for train, validation and test using the specified split method - train, validation, test = self.experiment.get_function_split(split_method)(self.data, **config["params"]) + # get the split indices + train, validation, test = self.split_manager.get_split_indices(split_input_data) # add the split column to the data split_column = np.full(len(self.data), -1).astype(int) split_column[train] = 0 split_column[validation] = 1 split_column[test] = 2 - self.data = self.data.with_columns(pl.Series("split:split:int", split_column)) - self.update_categories() + self.data = self.data.with_columns(pl.Series("split", split_column)) + + if "split" not in self.columns: + self.columns.append("split") + + def get_all_items(self) -> tuple[dict, dict, dict]: + """Get the full dataset as three separate dictionaries for inputs, labels and metadata. + + Returns: + tuple[dict, dict, dict]: Three dictionaries containing: + - Input dictionary mapping input column names to encoded input data + - Label dictionary mapping label column names to encoded label data + - Meta dictionary mapping meta column names to meta data + + Example: + >>> handler = DatasetHandler(...) + >>> input_dict, label_dict, meta_dict = handler.get_dataset() + >>> print(input_dict.keys()) + dict_keys(['age', 'fare']) + >>> print(label_dict.keys()) + dict_keys(['survived']) + >>> print(meta_dict.keys()) + dict_keys(['passenger_id']) + """ + # Get columns for each category from dataset manager + input_cols = self.dataset_manager.column_categories["input"] + label_cols = self.dataset_manager.column_categories["label"] + meta_cols = self.dataset_manager.column_categories["meta"] + + # Select and organize data by category + input_data = self.select_columns(input_cols) if input_cols else {} + label_data = self.select_columns(label_cols) if label_cols else {} + meta_data = self.select_columns(meta_cols) if meta_cols else {} + + # Encode input and label data + encoded_input = self.encoder_manager.encode_columns(input_data) if input_data else {} + encoded_label = self.encoder_manager.encode_columns(label_data) if label_data else {} + + return encoded_input, encoded_label, meta_data + +class CsvHandler: + """Meta class for handling CSV files.""" + + def __init__(self, experiment: Any, csv_path: str) -> None: + self.experiment = experiment + self.csv_path = csv_path + +class CsvProcessing(CsvHandler): + """Class to load the input csv data and add noise accordingly.""" + + def __init__(self, experiment: Any, csv_path: str) -> None: + super().__init__(experiment, csv_path) + self.data = self.load_csv() def transform(self, transformations: list) -> None: """Transforms the data using the specified configuration.""" diff --git a/src/stimulus/data/experiments.py b/src/stimulus/data/experiments.py index b42ba97d..efd4321b 100644 --- a/src/stimulus/data/experiments.py +++ b/src/stimulus/data/experiments.py @@ -1,83 +1,228 @@ -"""Experiments are classes parsed by CSV master classes to run experiments. -Conceptually, experiment classes contain data types, transformations etc and are used to duplicate the input data into many datasets. 
-Here we provide standard experiments as well as an absctract class for users to implement their own. +"""Loaders serve as interfaces between the CSV master class and custom methods. +Mainly, three types of custom methods are supported: +- Encoders: methods for encoding data before it is fed into the model +- Data transformers: methods for transforming data (i.e. augmenting, noising...) +- Splitters: methods for splitting data into train, validation and test sets -# TODO implement noise schemes and splitting schemes. +Loaders are built from an input config YAML file which format is described in the documentation, you can find an example here: tests/test_data/dna_experiment/dna_experiment_config_template.yaml """ from abc import ABC from typing import Any +from collections import defaultdict -from .encoding import encoders as encoders -from .splitters import splitters as splitters -from .transform import data_transformation_generators as data_transformation_generators +import inspect +import yaml -class AbstractExperiment(ABC): - """Abstract class for experiments. +from stimulus.data.encoding import encoders as encoders +from stimulus.data.splitters import splitters as splitters +from stimulus.data.transform import data_transformation_generators as data_transformation_generators +from stimulus.utils.yaml_data import YamlConfigDict - WARNING, DATA_TYPES ARGUMENT NAMES SHOULD BE ALL LOWERCASE, CHECK THE DATA_TYPES MODULE FOR THE TYPES THAT HAVE BEEN IMPLEMENTED. - """ +class AbstractLoader(ABC): + """Abstract base class for defining loaders.""" + + def get_config_from_yaml(self, yaml_path: str) -> dict: + """Loads experiment configuration from a YAML file. + + Args: + yaml_path (str): Path to the YAML config file + + Returns: + dict: The loaded configuration dictionary + """ + with open(yaml_path, "r") as file: + config = YamlConfigDict(**yaml.safe_load(file)) + return config + +class EncoderLoader(AbstractLoader): + """Class for loading encoders from a config file.""" + + def __init__(self, seed: float = None) -> None: + self.seed = seed + + def initialize_column_encoders_from_config(self, config: YamlConfigDict) -> None: + """Build the loader from a config dictionary. + + Args: + config (YamlConfigDict): Configuration dictionary containing field names (column_name) and their encoder specifications. + """ + for field in config: + encoder = self.get_encoder(field.encoder[0].name, field.encoder[0].params) + self.set_encoder_as_attribute(field.column_name, encoder) + + def get_function_encode_all(self, field_name: str) -> Any: + """Gets the encoding function for a specific field. + + Args: + field_name (str): The field name to get the encoder for + + Returns: + Any: The encode_all function for the specified field + """ + return getattr(self, field_name)["encoder"].encode_all + + def get_encoder(self, encoder_name: str, encoder_params: dict = None) -> Any: + """Gets an encoder object from the encoders module and initializes it with the given parametersß. 
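A minimal usage sketch of the new EncoderLoader API (not part of the patch): it assumes the existing TextOneHotEncoder and uses the column name "hello" purely as an illustration, mirroring the test config template.

```python
# Illustrative sketch only -- assumes the EncoderLoader API introduced in this diff
# and the existing TextOneHotEncoder; the column name "hello" mirrors the test config.
from stimulus.data.experiments import EncoderLoader

loader = EncoderLoader()

# Instantiate an encoder by name and register it under a column name,
# as initialize_column_encoders_from_config does for every configured column.
encoder = loader.get_encoder("TextOneHotEncoder", {"alphabet": "acgt"})
loader.set_encoder_as_attribute("hello", encoder)

# Retrieve the column's encode_all callable and encode a batch of sequences.
encode_all = loader.get_function_encode_all("hello")
encoded = encode_all(["acgt", "tgca"])
print(type(encoded))
```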
+ + Args: + encoder_name (str): The name of the encoder to get + encoder_params (dict): The parameters for the encoder + + Returns: + Any: The encoder function for the specified field and parameters + """ + + try: + return getattr(encoders, encoder_name)(**encoder_params) + except AttributeError: + print(f"Encoder '{encoder_name}' not found in the encoders module.") + print(f"Available encoders: {[name for name, obj in encoders.__dict__.items() if isinstance(obj, type) and name not in ('ABC', 'Any')]}") + raise + + except TypeError: + if encoder_params is None: + return getattr(encoders, encoder_name)() + else: + print(f"Encoder '{encoder_name}' has incorrect parameters: {encoder_params}") + print(f"Expected parameters for '{encoder_name}': {inspect.signature(getattr(encoders, encoder_name))}") + raise + + def set_encoder_as_attribute(self, field_name: str, encoder: encoders.AbstractEncoder) -> None: + """Sets the encoder as an attribute of the loader. + + Args: + field_name (str): The name of the field to set the encoder for + encoder (encoders.AbstractEncoder): The encoder to set + """ + setattr(self, field_name, {"encoder": encoder}) + +class TransformLoader(AbstractLoader): + """Class for loading transformations from a config file.""" def __init__(self, seed: float = None) -> None: - # allow ability to add a seed for reproducibility self.seed = seed - # added because if the user does not define this it does not crach the get_function_split, random split works for every class afteralll - self.split = {"RandomSplitter": splitters.RandomSplitter()} - - def get_function_encode_all(self, data_type: str) -> Any: - """This method gets the encoding function for a specific data type.""" - return getattr(self, data_type)["encoder"].encode_all - - def get_data_transformer(self, data_type: str, transformation_generator: str) -> Any: - """This method transforms the data (noising, data augmentation etc).""" - return getattr(self, data_type)["data_transformation_generators"][transformation_generator] - - def get_function_split(self, split_method: str) -> Any: - """This method returns the function for splitting the data.""" - return self.split[split_method].get_split_indexes - - -class DnaToFloatExperiment(AbstractExperiment): - """Class for dealing with DNA to float predictions (for instance regression from DNA sequence to CAGE value)""" - - def __init__(self) -> None: - super().__init__() - self.dna = { - "encoder": encoders.TextOneHotEncoder(alphabet="acgt"), - "data_transformation_generators": { - "UniformTextMasker": data_transformation_generators.UniformTextMasker(mask="N"), - "ReverseComplement": data_transformation_generators.ReverseComplement(), - "GaussianChunk": data_transformation_generators.GaussianChunk(), - }, - } - self.float = { - "encoder": encoders.FloatEncoder(), - "data_transformation_generators": {"GaussianNoise": data_transformation_generators.GaussianNoise()}, - } - self.split = {"RandomSplitter": splitters.RandomSplitter()} - - -class ProtDnaToFloatExperiment(DnaToFloatExperiment): - """Class for dealing with Protein and DNA to float predictions (for instance regression from Protein sequence + DNA sequence to binding score)""" - - def __init__(self) -> None: - super().__init__() - self.prot = { - "encoder": encoders.TextOneHotEncoder(alphabet="acdefghiklmnpqrstvwy"), - "data_transformation_generators": { - "UniformTextMasker": data_transformation_generators.UniformTextMasker(mask="X"), - }, - } - - -class TitanicExperiment(AbstractExperiment): - """Class for dealing with the 
Titanic dataset as a test format.""" - - def __init__(self) -> None: - super().__init__() - self.int_class = {"encoder": encoders.IntEncoder(), "data_transformation_generators": {}} - self.str_class = {"encoder": encoders.StrClassificationIntEncoder(), "data_transformation_generators": {}} - self.int_reg = {"encoder": encoders.IntRankEncoder(), "data_transformation_generators": {}} - self.float_rank = {"encoder": encoders.FloatRankEncoder(), "data_transformation_generators": {}} + + def get_data_transformer(self, transformation_name: str, transformation_params: dict = None) -> Any: + """Gets a transformer object from the transformers module. + + Args: + transformation_name (str): The name of the transformer to get + + Returns: + Any: The transformer function for the specified transformation + """ + try: + return getattr(data_transformation_generators, transformation_name)(**transformation_params) + except AttributeError: + print(f"Transformer '{transformation_name}' not found in the transformers module.") + print(f"Available transformers: {[name for name, obj in data_transformation_generators.__dict__.items() if isinstance(obj, type) and name not in ('ABC', 'Any')]}") + raise + + except TypeError: + if transformation_params is None: + return getattr(data_transformation_generators, transformation_name)() + else: + print(f"Transformer '{transformation_name}' has incorrect parameters: {transformation_params}") + print(f"Expected parameters for '{transformation_name}': {inspect.signature(getattr(data_transformation_generators, transformation_name))}") + raise + + def set_data_transformer_as_attribute(self, field_name: str, data_transformer: Any) -> None: + """Sets the data transformer as an attribute of the loader. + + Args: + field_name (str): The name of the field to set the data transformer for + data_transformer (Any): The data transformer to set + """ + setattr(self, field_name, {"data_transformation_generators": data_transformer}) + + def initialize_column_data_transformers_from_config(self, config: YamlConfigDict) -> None: + """Build the loader from a config dictionary. + + Args: + config (YamlConfigDict): Configuration dictionary containing transforms configurations. + Each transform can specify multiple columns and their transformations. + The method will organize transformers by column, ensuring each column + has all its required transformations. 
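A minimal sketch of the new TransformLoader API, assuming the ReverseComplement and UniformTextMasker transformers shipped in data_transformation_generators; the column name "col1" and the parameter values are illustrative only.

```python
# Illustrative sketch only -- assumes the TransformLoader API added in this diff
# and the ReverseComplement / UniformTextMasker transformers it looks up by name.
from stimulus.data.experiments import TransformLoader

loader = TransformLoader()

# Instantiate transformers by name (with optional constructor params) and
# register them under a column, mirroring the config-driven initializer.
rc = loader.get_data_transformer("ReverseComplement")
masker = loader.get_data_transformer(
    "UniformTextMasker", {"probability": 0.1, "mask": "N", "seed": 42}
)
loader.set_data_transformer_as_attribute("col1", [rc, masker])

# Apply the registered transformers to a single value, one after the other.
value = "acgtacgt"
for transformer in loader.col1["data_transformation_generators"]:
    value = transformer.transform(value)
print(value)
```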
+ """ + + # Use defaultdict to automatically initialize empty lists + column_transformers = defaultdict(list) + + # First pass: collect all transformations by column + for transform_group in config: + for column in transform_group.columns: + col_name = column.column_name + + # Process each transformation for this column + for transform_spec in column.transformations: + # Create transformer instance + transformer = self.get_data_transformer(transform_spec.name) + + # Get transformer class for comparison + transformer_type = type(transformer) + + # Add transformer if its type isn't already present + if not any(isinstance(existing, transformer_type) + for existing in column_transformers[col_name]): + column_transformers[col_name].append(transformer) + + # Second pass: set all collected transformers as attributes + for col_name, transformers in column_transformers.items(): + self.set_data_transformer_as_attribute(col_name, transformers) + +class SplitLoader(AbstractLoader): + """Class for loading splitters from a config file.""" + + def __init__(self, seed: float = None) -> None: + self.seed = seed + + def get_function_split(self) -> Any: + """Gets the function for splitting the data. + + Args: + split_method (str): Name of the split method to use + + Returns: + Any: The split function for the specified method + """ + return self.split.get_split_indexes + + def get_splitter(self, splitter_name: str, splitter_params: dict = None) -> Any: + """Gets a splitter object from the splitters module. + + Args: + splitter_name (str): The name of the splitter to get + + Returns: + Any: The splitter function for the specified splitter + """ + try: + return getattr(splitters, splitter_name)(**splitter_params) + except TypeError: + if splitter_params is None: + return getattr(splitters, splitter_name)() + else: + print(f"Splitter '{splitter_name}' has incorrect parameters: {splitter_params}") + print(f"Expected parameters for '{splitter_name}': {inspect.signature(getattr(splitters, splitter_name))}") + raise + + def set_splitter_as_attribute(self, splitter: Any) -> None: + """Sets the splitter as an attribute of the loader. + + Args: + field_name (str): The name of the field to set the splitter for + splitter (Any): The splitter to set + """ + setattr(self, "split", splitter) + + def initialize_splitter_from_config(self, config: YamlConfigDict, split_index: int = 0) -> None: + """Build the loader from a config dictionary. + + Args: + config (dict): Configuration dictionary containing split configurations. 
+ """ + splitter = self.get_splitter(config.split[split_index].split_method, config.split[split_index].params) + self.set_splitter_as_attribute(splitter) diff --git a/src/stimulus/data/splitters/__init__.py b/src/stimulus/data/splitters/__init__.py index e69de29b..a8c5e4a4 100644 --- a/src/stimulus/data/splitters/__init__.py +++ b/src/stimulus/data/splitters/__init__.py @@ -0,0 +1,3 @@ +from .splitters import AbstractSplitter, RandomSplit + +__all__ = ["AbstractSplitter", "RandomSplit"] diff --git a/src/stimulus/data/splitters/splitters.py b/src/stimulus/data/splitters/splitters.py index 3b5b4098..e1f7b03e 100644 --- a/src/stimulus/data/splitters/splitters.py +++ b/src/stimulus/data/splitters/splitters.py @@ -17,15 +17,17 @@ class AbstractSplitter(ABC): distance: calculates the distance between two elements of the data """ + def __init__(self, seed: float = 42) -> None: + self.seed = seed + @abstractmethod - def get_split_indexes(self, data: pl.DataFrame, seed: float = None) -> list: + def get_split_indexes(self, data: pl.DataFrame) -> list: """Splits the data. Always return indices mapping to the original list. This is an abstract method that should be implemented by the child class. Args: data (pl.DataFrame): the data to be split - seed (float): the seed for reproducibility Returns: split_indices (list): the indices for train, validation, and test sets @@ -48,17 +50,21 @@ def distance(self, data_one: Any, data_two: Any) -> float: raise NotImplementedError -class RandomSplitter(AbstractSplitter): +class RandomSplit(AbstractSplitter): """This splitter randomly splits the data.""" - def __init__(self) -> None: + def __init__(self, split: list = [0.7, 0.2, 0.1], seed: float = None) -> None: super().__init__() + self.split = split + self.seed = seed + if len(self.split) != 3: + raise ValueError( + "The split argument should be a list with length 3 that contains the proportions for [train, validation, test] splits.", + ) def get_split_indexes( self, data: pl.DataFrame, - split: list = [0.7, 0.2, 0.1], - seed: float = None, ) -> tuple[list, list, list]: """Splits the data indices into train, validation, and test sets. @@ -66,37 +72,31 @@ def get_split_indexes( Args: data (pl.DataFrame): The data loaded with polars. - split (list): The proportions for [train, validation, test] splits. - seed (float): The seed for reproducibility. Returns: train (list): The indices for the training set. validation (list): The indices for the validation set. - test (list): he indices for the test set. + test (list): The indices for the test set. Raises: ValueError: If the split argument is not a list with length 3. ValueError: If the sum of the split proportions is not 1. """ - if len(split) != 3: - raise ValueError( - "The split argument should be a list with length 3 that contains the proportions for [train, validation, test] splits.", - ) # Use round to avoid errors due to floating point imprecisions - if round(sum(split), 3) < 1.0: - raise ValueError(f"The sum of the split proportions should be 1. Instead, it is {sum(split)}.") + if round(sum(self.split), 3) < 1.0: + raise ValueError(f"The sum of the split proportions should be 1. 
Instead, it is {sum(self.split)}.") # compute the length of the data length_of_data = len(data) # Generate a list of indices and shuffle it indices = np.arange(length_of_data) - np.random.seed(seed) + np.random.seed(self.seed) np.random.shuffle(indices) # Calculate the sizes of the train, validation, and test sets - train_size = int(split[0] * length_of_data) - validation_size = int(split[1] * length_of_data) + train_size = int(self.split[0] * length_of_data) + validation_size = int(self.split[1] * length_of_data) # Split the shuffled indices according to the calculated sizes train = indices[:train_size].tolist() diff --git a/src/stimulus/data/transform/data_transformation_generators.py b/src/stimulus/data/transform/data_transformation_generators.py index cf30764f..210d71d1 100644 --- a/src/stimulus/data/transform/data_transformation_generators.py +++ b/src/stimulus/data/transform/data_transformation_generators.py @@ -30,37 +30,36 @@ class AbstractDataTransformer(ABC): def __init__(self): self.add_row = None + self.seed = 42 @abstractmethod - def transform(self, data: Any, seed: float = None) -> Any: + def transform(self, data: Any) -> Any: """Transforms a single data point. This is an abstract method that should be implemented by the child class. Args: data (Any): the data to be transformed - seed (float): the seed for reproducibility Returns: transformed_data (Any): the transformed data """ - # np.random.seed(seed) + # np.random.seed(self.seed) raise NotImplementedError @abstractmethod - def transform_all(self, data: list, seed: float = None) -> list: + def transform_all(self, data: list) -> list: """Transforms a list of data points. This is an abstract method that should be implemented by the child class. Args: data (list): the data to be transformed - seed (float): the seed for reproducibility Returns: transformed_data (list): the transformed data """ - # np.random.seed(seed) + # np.random.seed(self.seed) raise NotImplementedError @@ -91,45 +90,41 @@ class UniformTextMasker(AbstractNoiseGenerator): This noise generators replace characters with a masking character with a given probability. - Attributes: - mask (str): the character to use for masking - Methods: transform: adds character masking to a single data point transform_all: adds character masking to a list of data points """ - def __init__(self, mask: str) -> None: + def __init__(self, probability: float = 0.1, mask: str = "*", seed: float = 42) -> None: super().__init__() + self.probability = probability self.mask = mask + self.seed = seed - def transform(self, data: str, probability: float = 0.1, seed: float = None) -> str: + def transform(self, data: str) -> str: """Adds character masking to the data. Args: data (str): the data to be transformed - probability (float): the probability of adding noise - seed (float): the seed for reproducibility Returns: transformed_data (str): the transformed data point """ - np.random.seed(seed) - return "".join([c if np.random.rand() > probability else self.mask for c in data]) + np.random.seed(self.seed) + return "".join([c if np.random.rand() > self.probability else self.mask for c in data]) - def transform_all(self, data: list, probability: float = 0.1, seed: float = None) -> list: + def transform_all(self, data: list) -> list: """Adds character masking to multiple data points using multiprocessing. 
Args: data (list): the data to be transformed - probability (float): the probability of adding noise - seed (float): the seed for reproducibility + Returns: transformed_data (list): the transformed data points """ with mp.Pool(mp.cpu_count()) as pool: - function_specific_input = [(item, probability, seed) for item in data] + function_specific_input = [(item) for item in data] return pool.starmap(self.transform, function_specific_input) @@ -143,35 +138,35 @@ class GaussianNoise(AbstractNoiseGenerator): transform_all: adds noise to a list of data points """ - def transform(self, data: float, mean: float = 0, std: float = 1, seed: float = None) -> float: + def __init__(self, mean: float = 0, std: float = 1, seed: float = 42) -> None: + super().__init__() + self.mean = mean + self.std = std + self.seed = seed + + def transform(self, data: float) -> float: """Adds Gaussian noise to a single point of data. Args: data (float): the data to be transformed - mean (float): the mean of the Gaussian distribution - std (float): the standard deviation of the Gaussian distribution - seed (float): the seed for reproducibility Returns: transformed_data (float): the transformed data point """ - np.random.seed(seed) - return data + np.random.normal(mean, std) + np.random.seed(self.seed) + return data + np.random.normal(self.mean, self.std) - def transform_all(self, data: list, mean: float = 0, std: float = 0, seed: float = None) -> np.array: + def transform_all(self, data: list) -> np.array: """Adds Gaussian noise to a list of data points Args: data (list): the data to be transformed - mean (float): the mean of the Gaussian distribution - std (float): the standard deviation of the Gaussian distribution - seed (float): the seed for reproducibility Returns: transformed_data (np.array): the transformed data points """ - np.random.seed(seed) - return np.array(np.array(data) + np.random.normal(mean, std, len(data))) + np.random.seed(self.seed) + return np.array(np.array(data) + np.random.normal(self.mean, self.std, len(data))) class ReverseComplement(AbstractAugmentationGenerator): @@ -237,7 +232,13 @@ class GaussianChunk(AbstractAugmentationGenerator): transform_all: chunks multiple lists """ - def transform(self, data: str, chunk_size: int, seed: float = None, std: float = 1) -> str: + def __init__(self, chunk_size: int, seed: float = 42, std: float = 1) -> None: + super().__init__() + self.chunk_size = chunk_size + self.seed = seed + self.std = std + + def transform(self, data: str) -> str: """Chunks a sequence of size chunk_size from the middle position +/- a value obtained through a gaussian distribution. 
Args: @@ -252,31 +253,31 @@ def transform(self, data: str, chunk_size: int, seed: float = None, std: float = Raises: AssertionError: if the input data is shorter than the chunk size """ - np.random.seed(seed) + np.random.seed(self.seed) # make sure that the data is longer than chunk_size otherwise raise an error - assert len(data) > chunk_size, "The input data is shorter than the chunk size" + assert len(data) > self.chunk_size, "The input data is shorter than the chunk size" # Get the middle position of the input sequence middle_position = len(data) // 2 # Change the middle position by a value obtained through a gaussian distribution - new_middle_position = int(middle_position + np.random.normal(0, std)) + new_middle_position = int(middle_position + np.random.normal(0, self.std)) # Get the start and end position of the chunk - start_position = new_middle_position - chunk_size // 2 - end_position = new_middle_position + chunk_size // 2 + start_position = new_middle_position - self.chunk_size // 2 + end_position = new_middle_position + self.chunk_size // 2 # if the start position is negative, set it to 0 start_position = max(start_position, 0) # Get the chunk of size chunk_size from the start position if the end position is smaller than the length of the data if end_position < len(data): - return data[start_position : start_position + chunk_size] + return data[start_position : start_position + self.chunk_size] # Otherwise return the chunk of the sequence from the end of the sequence of size chunk_size - return data[-chunk_size:] + return data[-self.chunk_size:] - def transform_all(self, data: list, chunk_size: int, seed: float = None, std: float = 1) -> list: + def transform_all(self, data: list) -> list: """Adds chunks to multiple lists using multiprocessing. 
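A minimal sketch of the constructor-configured transformer API introduced above: parameters such as probability, mean/std, and chunk_size now live on the instance rather than being passed to each transform() call. The sequences and parameter values below are arbitrary.

```python
# Illustrative sketch only -- shows the new constructor-configured call pattern
# introduced by this diff; the example data and parameter values are arbitrary.
from stimulus.data.transform.data_transformation_generators import (
    UniformTextMasker,
    GaussianNoise,
    GaussianChunk,
)

# Parameters are fixed at construction time; transform() only takes the data.
masker = UniformTextMasker(probability=0.2, mask="N", seed=42)
print(masker.transform("ACGTACGTACGT"))

noise = GaussianNoise(mean=0.0, std=1.0, seed=42)
print(noise.transform(5.0))

chunker = GaussianChunk(chunk_size=6, seed=42, std=1)
print(chunker.transform("ACGTACGTACGTACGT"))  # input must be longer than chunk_size
```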
Args: @@ -289,5 +290,5 @@ def transform_all(self, data: list, chunk_size: int, seed: float = None, std: fl transformed_data (list): the transformed sequences """ with mp.Pool(mp.cpu_count()) as pool: - function_specific_input = [(item, chunk_size, seed, std) for item in data] + function_specific_input = [(item) for item in data] return pool.starmap(self.transform, function_specific_input) diff --git a/src/stimulus/utils/yaml_data.py b/src/stimulus/utils/yaml_data.py new file mode 100644 index 00000000..a1e65211 --- /dev/null +++ b/src/stimulus/utils/yaml_data.py @@ -0,0 +1,326 @@ +import yaml +from pydantic import BaseModel, ValidationError, field_validator +from typing import List, Optional, Dict, Union, Any + +class YamlGlobalParams(BaseModel): + seed: int + +class YamlColumnsEncoder(BaseModel): + name: str + params: Optional[Dict[str, Union[str, list]]] # Allow both string and list values + +class YamlColumns(BaseModel): + column_name: str + column_type: str + data_type: str + encoder: List[YamlColumnsEncoder] + + +class YamlTransformColumnsTransformation(BaseModel): + name: str + params: Optional[Dict[str, Union[list, float]]] # Allow both list and float values + + +class YamlTransformColumns(BaseModel): + column_name: str + transformations: List[YamlTransformColumnsTransformation] + + +class YamlTransform(BaseModel): + transformation_name: str + columns: List[YamlTransformColumns] + + @field_validator('columns') + @classmethod + def validate_param_lists_across_columns(cls, columns) -> List[YamlTransformColumns]: + # Get all parameter list lengths across all columns and transformations + all_list_lengths = set() + + for column in columns: + for transformation in column.transformations: + if transformation.params: + for param_value in transformation.params.values(): + if isinstance(param_value, list): + if len(param_value) > 0: # Non-empty list + all_list_lengths.add(len(param_value)) + + # Skip validation if no lists found + if not all_list_lengths: + return columns + + # Check if all lists either have length 1, or all have the same length + all_list_lengths.discard(1) # Remove length 1 as it's always valid + if len(all_list_lengths) > 1: # Multiple different lengths found, since sets do not allow duplicates + raise ValueError("All parameter lists across columns must either contain one element or have the same length") + + return columns + + +class YamlSplit(BaseModel): + split_method: str + params: Dict[str, List[float]] # More specific type for split parameters + split_input_columns: Optional[List[str]] + +class YamlConfigDict(BaseModel): + global_params: YamlGlobalParams + columns: List[YamlColumns] + transforms: List[YamlTransform] + split: List[YamlSplit] + +class YamlSubConfigDict(BaseModel): + global_params: YamlGlobalParams + columns: List[YamlColumns] + transforms: YamlTransform + split: YamlSplit + +class YamlSchema(BaseModel): + yaml_conf: YamlConfigDict + +def extract_transform_parameters_at_index(transform: YamlTransform, index: int = 0) -> YamlTransform: + """Get a transform with parameters at the specified index. 
+ + Args: + transform: The original transform containing parameter lists + index: Index to extract parameters from (default 0) + + Returns: + A new transform with single parameter values at the specified index + """ + # Create a copy of the transform + new_transform = YamlTransform(**transform.model_dump()) + + # Process each column and transformation + for column in new_transform.columns: + for transformation in column.transformations: + if transformation.params: + # Convert each parameter list to single value at index + new_params = {} + for param_name, param_value in transformation.params.items(): + if isinstance(param_value, list): + new_params[param_name] = param_value[index] + else: + new_params[param_name] = param_value + transformation.params = new_params + + return new_transform + +def expand_transform_parameter_combinations(transform: YamlTransform) -> list[YamlTransform]: + """Get all possible transforms by extracting parameters at each valid index. + + For a transform with parameter lists, creates multiple new transforms, each containing + single parameter values from the corresponding indices of the parameter lists. + + Args: + transform: The original transform containing parameter lists + + Returns: + A list of transforms, each with single parameter values from sequential indices + """ + # Find the length of parameter lists - we only need to check the first list we find + # since all lists must have the same length (enforced by pydantic validator) + max_length = 1 + for column in transform.columns: + for transformation in column.transformations: + if transformation.params: + list_lengths = [len(v) for v in transformation.params.values() + if isinstance(v, list) and len(v) > 1] + if list_lengths: + max_length = list_lengths[0] # All lists have same length due to validator + break + + # Generate a transform for each index + transforms = [] + for i in range(max_length): + transforms.append(extract_transform_parameters_at_index(transform, i)) + + return transforms + +def expand_transform_list_combinations(transform_list: list[YamlTransform]) -> list[YamlTransform]: + """Expands a list of transforms into all possible parameter combinations. + + Takes a list of transforms where each transform may contain parameter lists, + and expands them into separate transforms with single parameter values. + For example, if a transform has parameters [0.1, 0.2] and [1, 2], this will + create two transforms: one with 0.1/1 and another with 0.2/2. + + Args: + transform_list: A list of YamlTransform objects containing parameter lists + that need to be expanded into individual transforms. + + Returns: + list[YamlTransform]: A flattened list of transforms where each transform + has single parameter values instead of parameter lists. The length of + the returned list will be the sum of the number of parameter combinations + for each input transform. + """ + sub_transforms = [] + for transform in transform_list: + sub_transforms.extend(expand_transform_parameter_combinations(transform)) + return sub_transforms + +def generate_data_configs(yaml_config: YamlConfigDict) -> list[YamlSubConfigDict]: + """Generates all possible data configurations from a YAML config. + + Takes a YAML configuration that may contain parameter lists and splits, + and generates all possible combinations of parameters and splits into + separate data configurations. 
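A minimal sketch of the parameter-expansion helpers, building a YamlTransform by hand; the transform name, column name, and parameter values are illustrative only.

```python
# Illustrative sketch only -- constructs a small YamlTransform in code (values are
# arbitrary) and expands its parameter lists with the helpers defined in this file.
from stimulus.utils.yaml_data import (
    YamlTransform,
    YamlTransformColumns,
    YamlTransformColumnsTransformation,
    expand_transform_parameter_combinations,
)

transform = YamlTransform(
    transformation_name="noise",
    columns=[
        YamlTransformColumns(
            column_name="age",
            transformations=[
                YamlTransformColumnsTransformation(
                    name="GaussianNoise",
                    params={"std": [0.1, 0.2, 0.3]},
                )
            ],
        )
    ],
)

# One sub-transform per index of the parameter list: std=0.1, std=0.2, std=0.3.
for sub in expand_transform_parameter_combinations(transform):
    print(sub.columns[0].transformations[0].params)
```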
+ + For example, if the config has: + - A transform with parameters [0.1, 0.2] + - Two splits [0.7/0.3] and [0.8/0.2] + This will generate 4 configs, 2 for each split. + + Args: + yaml_config: The source YAML configuration containing transforms with + parameter lists and multiple splits. + + Returns: + list[YamlSubConfigDict]: A list of data configurations, where each + config has single parameter values and one split configuration. The + length will be the product of the number of parameter combinations + and the number of splits. + """ + sub_transforms = expand_transform_list_combinations(yaml_config.transforms) + sub_splits = yaml_config.split + sub_configs = [] + for split in sub_splits: + for transform in sub_transforms: + sub_configs.append(YamlSubConfigDict( + global_params=yaml_config.global_params, + columns=yaml_config.columns, + transforms=transform, + split=split + )) + return sub_configs + +def dump_yaml_list_into_files( + yaml_list: list[YamlSubConfigDict], directory_path: str, base_name: str +) -> None: + """Dumps a list of YAML configurations into separate files with custom formatting. + + This function takes a list of YAML configurations and writes each one to a separate file, + applying custom formatting rules to ensure consistent and readable output. It handles + special cases like None values, nested lists, and proper indentation. + + Args: + yaml_list: List of YamlSubConfigDict objects to be written to files + directory_path: Directory path where the files should be created + base_name: Base name for the output files. Files will be named {base_name}_{index}.yaml + + The function applies several custom formatting rules: + - None values are represented as empty strings + - Nested lists use appropriate flow styles based on content type + - Extra newlines are added between root-level elements + - Proper indentation is maintained throughout + """ + # Disable YAML aliases to prevent reference-style output + yaml.Dumper.ignore_aliases = lambda *args : True + + def represent_none(dumper, _): + """Custom representer to format None values as empty strings in YAML output.""" + return dumper.represent_scalar('tag:yaml.org,2002:null', '') + + def custom_representer(dumper, data): + """Custom representer to handle different types of lists with appropriate formatting. + + Applies different flow styles based on content: + - Empty lists -> empty string + - Lists of dicts (e.g. columns) -> block style (vertical) + - Lists of lists (e.g. 
split params) -> flow style (inline) + - Other lists -> flow style + """ + if isinstance(data, list): + if len(data) == 0: + return dumper.represent_scalar('tag:yaml.org,2002:null', '') + if isinstance(data[0], dict): + # Use block style for structured data like columns + return dumper.represent_sequence('tag:yaml.org,2002:seq', data, flow_style=False) + elif isinstance(data[0], list): + # Use flow style for numeric data like split ratios + return dumper.represent_sequence('tag:yaml.org,2002:seq', data, flow_style=True) + return dumper.represent_sequence('tag:yaml.org,2002:seq', data, flow_style=True) + + class CustomDumper(yaml.Dumper): + """Custom YAML dumper that adds extra formatting controls.""" + + def write_line_break(self, data=None): + """Add extra newline after root-level elements.""" + super().write_line_break(data) + if len(self.indents) <= 1: # At root level + super().write_line_break(data) + + def increase_indent(self, flow=False, indentless=False): + """Ensure consistent indentation by preventing indentless sequences.""" + return super().increase_indent(flow, False) + + # Register the custom representers with our dumper + yaml.add_representer(type(None), represent_none, Dumper=CustomDumper) + yaml.add_representer(list, custom_representer, Dumper=CustomDumper) + + for i, yaml_dict in enumerate(yaml_list): + # Convert Pydantic model to dict, excluding None values + dict_data = yaml_dict.model_dump(exclude_none=True) + + def fix_params(input_dict): + """Recursively process dictionary to properly handle params fields. + + Special handling for: + - Empty params fields -> None + - Transformation params -> None if empty + - Nested dicts and lists -> recursive processing + """ + if isinstance(input_dict, dict): + processed_dict = {} + for key, value in input_dict.items(): + if key == 'params' and (value is None or value == {}): + processed_dict[key] = None # Convert empty params to None + elif key == 'transformations' and isinstance(value, list): + # Handle transformation params specially + processed_dict[key] = [] + for transformation in value: + processed_transformation = dict(transformation) + if 'params' not in processed_transformation or processed_transformation['params'] is None or processed_transformation['params'] == {}: + processed_transformation['params'] = None + processed_dict[key].append(processed_transformation) + elif isinstance(value, dict): + processed_dict[key] = fix_params(value) # Recurse into nested dicts + elif isinstance(value, list): + # Process lists, recursing into dict elements + processed_dict[key] = [fix_params(list_item) if isinstance(list_item, dict) else list_item for list_item in value] + else: + processed_dict[key] = value + return processed_dict + return input_dict + + dict_data = fix_params(dict_data) + + # Write the processed data to file with custom formatting + with open(f"{directory_path}/{base_name}_{i}.yaml", "w") as f: + yaml.dump( + dict_data, + f, + Dumper=CustomDumper, + sort_keys=False, + default_flow_style=False, + indent=2, + width=float("inf") # Prevent line wrapping + ) + +def check_yaml_schema(config_yaml: str) -> str: + """ + Using pydantic this function confirms that the fields have the correct input type + If the children field is specific to a parent, the children fields class is hosted in the parent fields class + + If any field in not the right type, the function prints an error message explaining the problem and exits the python code + + Args: + config_yaml (dict): The dict containing the fields of the yaml configuration 
file + + Returns: + None + """ + try: + YamlSchema(yaml_conf=config_yaml) + except ValidationError as e: + print(e) + raise ValueError("Wrong type on a field, see the pydantic report above") # Crashes in case of an error raised diff --git a/tests/data/test_csv.py b/tests/data/test_csv.py new file mode 100644 index 00000000..ad88f6a1 --- /dev/null +++ b/tests/data/test_csv.py @@ -0,0 +1,196 @@ +import pytest +import polars as pl +import numpy as np +from pathlib import Path +import yaml + +from stimulus.data.csv import DatasetHandler, DatasetManager, EncodeManager, TransformManager, SplitManager +from stimulus.utils.yaml_data import generate_data_configs, YamlConfigDict +from stimulus.data import experiments + +# Fixtures +@pytest.fixture +def titanic_csv_path(): + return "tests/test_data/titanic/titanic_stimulus.csv" + +@pytest.fixture +def config_path(): + return "tests/test_data/titanic/titanic.yaml" + +@pytest.fixture +def base_config(config_path): + with open(config_path, 'r') as f: + return YamlConfigDict(**yaml.safe_load(f)) + +@pytest.fixture +def split_configs(base_config): + """Generate all possible configurations from base config""" + return generate_data_configs(base_config) + + +# Test DatasetHandler Integration +@pytest.fixture +def encoder_loader(base_config): + loader = experiments.EncoderLoader() + loader.initialize_column_encoders_from_config(base_config.columns) + return loader + +@pytest.fixture +def transform_loader(base_config): + loader = experiments.TransformLoader() + if "transforms" in base_config: + loader.initialize_column_data_transformers_from_config(base_config["transforms"]) + return loader + +@pytest.fixture +def split_loader(base_config): + loader = experiments.SplitLoader() + if "split" in base_config: + # Get first split configuration + split_config = base_config["split"][0] + splitter = loader.get_splitter(split_config["split_method"]) + loader.set_splitter_as_attribute("split", splitter) + return loader + +# Test DatasetManager +def test_dataset_manager_init(config_path): + manager = DatasetManager(config_path) + assert hasattr(manager, "config") + assert hasattr(manager, "column_categories") + +def test_dataset_manager_organize_columns(config_path): + manager = DatasetManager(config_path) + categories = manager.categorize_columns_by_type() + + assert "pclass" in categories["input"] + assert "sex" in categories["input"] + assert "age" in categories["input"] + assert "survived" in categories["label"] + assert "passenger_id" in categories["meta"] + +def test_dataset_manager_organize_transforms(config_path): + manager = DatasetManager(config_path) + categories = manager.categorize_columns_by_type() + + assert len(categories) == 3 + assert all(key in categories for key in ["input", "label", "meta"]) + +# Test EncodeManager +def test_encode_manager_init(): + encoder_loader = experiments.EncoderLoader() + manager = EncodeManager(encoder_loader) + assert hasattr(manager, "encoder_loader") + +def test_encode_manager_initialize_encoders(): + encoder_loader = experiments.EncoderLoader() + manager = EncodeManager(encoder_loader) + assert hasattr(manager, "encoder_loader") + +def test_encode_manager_encode_numeric(): + encoder_loader = experiments.EncoderLoader() + intencoder = encoder_loader.get_encoder("IntEncoder") + encoder_loader.set_encoder_as_attribute("test_col", intencoder) + manager = EncodeManager(encoder_loader) + data = [1, 2, 3] + encoded = manager.encode_column("test_col", data) + assert encoded is not None + +# Test TransformManager +def 
test_transform_manager_init(): + transform_loader = experiments.TransformLoader() + manager = TransformManager(transform_loader) + assert hasattr(manager, "transform_loader") + +def test_transform_manager_initialize_transforms(): + transform_loader = experiments.TransformLoader() + manager = TransformManager(transform_loader) + assert hasattr(manager, "transform_loader") + +def test_transform_manager_apply_transforms(): + transform_loader = experiments.TransformLoader() + manager = TransformManager(transform_loader) + assert hasattr(manager, "transform_loader") + +# Test SplitManager +def test_split_manager_init(): + split_loader = experiments.SplitLoader() + manager = SplitManager(split_loader) + assert hasattr(manager, "split_loader") + +def test_split_manager_initialize_splits(): + split_loader = experiments.SplitLoader() + manager = SplitManager(split_loader) + assert hasattr(manager, "split_loader") + +def test_split_manager_apply_split(): + split_loader = experiments.SplitLoader(seed=42) + manager = SplitManager(split_loader) + data = pl.DataFrame({"col": range(100)}) + split_indices = manager.get_split_indices(data) + assert len(split_indices) == 100 + +def test_dataset_handler_init(config_path, titanic_csv_path, encoder_loader, transform_loader, split_loader): + handler = DatasetHandler( + config_path=config_path, + encoder_loader=encoder_loader, + transform_loader=transform_loader, + split_loader=split_loader, + csv_path=titanic_csv_path + ) + + assert isinstance(handler.encoder_manager, EncodeManager) + assert isinstance(handler.transform_manager, TransformManager) + assert isinstance(handler.split_manager, SplitManager) + +def test_dataset_handler_get_dataset(config_path, titanic_csv_path, encoder_loader): + transform_loader = experiments.TransformLoader() + split_loader = experiments.SplitLoader() + + handler = DatasetHandler( + config_path=config_path, + encoder_loader=encoder_loader, + transform_loader=transform_loader, + split_loader=split_loader, + csv_path=titanic_csv_path + ) + + dataset = handler.get_all_items() + assert isinstance(dataset, tuple) + +def test_dataset_handler_print_dataset_info(config_path, titanic_csv_path, encoder_loader): + transform_loader = experiments.TransformLoader() + split_loader = experiments.SplitLoader() + + handler = DatasetHandler( + config_path=config_path, + encoder_loader=encoder_loader, + transform_loader=transform_loader, + split_loader=split_loader, + csv_path=titanic_csv_path + ) + + input_dict, label_dict, meta_dict = handler.get_all_items() + # Print input dict keys and first 5 elements of each value + print("\nInput dictionary contents:") + for key, value in input_dict.items(): + print(f"\n{key}:") + if isinstance(value, np.ndarray): + print(value[:5]) # Print first 5 elements if numpy array + else: + print(value[:5]) # Print first 5 elements if list + +@pytest.mark.parametrize("config_idx", [0, 1]) # Test both split configs +def test_dataset_handler_different_configs(config_path, titanic_csv_path, config_idx, encoder_loader): + transform_loader = experiments.TransformLoader() + split_loader = experiments.SplitLoader() + + handler = DatasetHandler( + config_path=config_path, + encoder_loader=encoder_loader, + transform_loader=transform_loader, + split_loader=split_loader, + csv_path=titanic_csv_path + ) + + dataset = handler.get_all_items() + assert isinstance(dataset, tuple) diff --git a/tests/data/test_experiment.py b/tests/data/test_experiment.py new file mode 100644 index 00000000..0e6ea02e --- /dev/null +++ 
diff --git a/tests/data/test_experiment.py b/tests/data/test_experiment.py
new file mode 100644
index 00000000..0e6ea02e
--- /dev/null
+++ b/tests/data/test_experiment.py
@@ -0,0 +1,129 @@
+import pytest
+
+import numpy as np
+
+from stimulus.data.transform import data_transformation_generators
+from stimulus.data.encoding.encoders import AbstractEncoder
+import stimulus.data.experiments as experiments
+import stimulus.data.splitters as splitters
+@pytest.fixture
+def dna_experiment_config_path():
+    """Fixture that provides the path to the DNA experiment config template YAML file.
+
+    This fixture returns the path to a YAML configuration file containing DNA experiment
+    parameters, including column definitions and transformation specifications.
+
+    Returns:
+        str: Path to the DNA experiment config template YAML file
+    """
+    return "tests/test_data/dna_experiment/dna_experiment_config_template.yaml"
+
+@pytest.fixture
+def titanic_yaml_path():
+    return "tests/test_data/titanic/titanic.yaml"
+
+@pytest.fixture
+def TextOneHotEncoder_name_and_params():
+    return "TextOneHotEncoder", {"alphabet": "acgt"}
+
+
+def test_get_config_from_yaml(dna_experiment_config_path):
+    """Test the get_config_from_yaml method of the AbstractLoader class.
+
+    This test checks if the get_config_from_yaml method correctly parses the YAML configuration file.
+    """
+    experiment = experiments.AbstractLoader()
+    config = experiment.get_config_from_yaml(dna_experiment_config_path)
+    assert config is not None
+
+def test_get_encoder(TextOneHotEncoder_name_and_params):
+    """Test the get_encoder method of the EncoderLoader class.
+
+    This test checks if the get_encoder method correctly returns an encoder instance.
+    """
+    experiment = experiments.EncoderLoader()
+    encoder_name, encoder_params = TextOneHotEncoder_name_and_params
+    encoder = experiment.get_encoder(encoder_name, encoder_params)
+    assert isinstance(encoder, AbstractEncoder)
+
+def test_set_encoder_as_attribute(TextOneHotEncoder_name_and_params):
+    """Test the set_encoder_as_attribute method of the EncoderLoader class.
+
+    This test checks if the set_encoder_as_attribute method correctly sets the encoder as an attribute of the loader class.
+    """
+    experiment = experiments.EncoderLoader()
+    encoder_name, encoder_params = TextOneHotEncoder_name_and_params
+    encoder = experiment.get_encoder(encoder_name, encoder_params)
+    experiment.set_encoder_as_attribute("ciao", encoder)
+    assert hasattr(experiment, "ciao")
+    assert experiment.ciao["encoder"] == encoder
+    assert experiment.get_function_encode_all("ciao") == encoder.encode_all
+
+def test_build_experiment_class_encoder_dict(dna_experiment_config_path):
+    """Test the initialize_column_encoders_from_config method of the EncoderLoader class.
+
+    This test checks if the initialize_column_encoders_from_config method correctly builds the loader class from a config dictionary.
+    """
+    experiment = experiments.EncoderLoader()
+    config = experiment.get_config_from_yaml(dna_experiment_config_path).columns
+    experiment.initialize_column_encoders_from_config(config)
+    assert hasattr(experiment, "hello")
+    assert hasattr(experiment, "bonjour")
+    assert hasattr(experiment, "ciao")
+
+    # call encoder from "hello", check that it completes successfully
+    assert experiment.hello["encoder"].encode_all(["a", "c", "g", "t"]) is not None
+
+ """ + experiment = experiments.TransformLoader() + transformer = experiment.get_data_transformer("ReverseComplement") + assert isinstance(transformer, data_transformation_generators.ReverseComplement) + +def test_set_data_transformer_as_attribute(): + """Test the set_data_transformer_as_attribute method of the TransformLoader class. + + This test checks if the set_data_transformer_as_attribute method correctly sets the transformer + as an attribute of the experiment class. + """ + experiment = experiments.TransformLoader() + transformer = experiment.get_data_transformer("ReverseComplement") + experiment.set_data_transformer_as_attribute("col1", transformer) + assert hasattr(experiment, "col1") + assert experiment.col1["data_transformation_generators"] == transformer + +def test_initialize_column_data_transformers_from_config(dna_experiment_config_path): + """Test the initialize_column_data_transformers_from_config method of the TransformLoader class. + + This test checks if the initialize_column_data_transformers_from_config method correctly builds + the experiment class from a config dictionary. + """ + experiment = experiments.TransformLoader() + config = experiment.get_config_from_yaml(dna_experiment_config_path).transforms + experiment.initialize_column_data_transformers_from_config(config) + + # Check columns have transformers set + assert hasattr(experiment, "col1") + assert hasattr(experiment, "col2") + + # Check transformers were properly initialized + col1_transformers = experiment.col1["data_transformation_generators"] + col2_transformers = experiment.col2["data_transformation_generators"] + + # Verify col1 has the expected transformers + assert any(isinstance(t, data_transformation_generators.ReverseComplement) for t in col1_transformers) + assert any(isinstance(t, data_transformation_generators.UniformTextMasker) for t in col1_transformers) + assert any(isinstance(t, data_transformation_generators.GaussianNoise) for t in col1_transformers) + + # Verify col2 has the expected transformer + assert any(isinstance(t, data_transformation_generators.GaussianNoise) for t in col2_transformers) + +def test_initialize_splitter_from_config(titanic_yaml_path): + experiment = experiments.SplitLoader() + config = experiment.get_config_from_yaml(titanic_yaml_path) + experiment.initialize_splitter_from_config(config) + assert hasattr(experiment, "split") + assert isinstance(experiment.split, splitters.RandomSplit) \ No newline at end of file diff --git a/tests/test_data/dna_experiment/dna_experiment_config_template.yaml b/tests/test_data/dna_experiment/dna_experiment_config_template.yaml new file mode 100644 index 00000000..793aa831 --- /dev/null +++ b/tests/test_data/dna_experiment/dna_experiment_config_template.yaml @@ -0,0 +1,68 @@ +global_params: + seed: 0 + +columns: + - column_name : "hello" + column_type : "input" + data_type : str + encoder: + - name: TextOneHotEncoder + params: + alphabet: "acgt" + + - column_name : "bonjour" + column_type : "input" + data_type : str + encoder: + - name: TextOneHotEncoder + params: + alphabet: "acgt" + + - column_name : "ciao" + column_type : "label" + data_type : int + encoder: + - name: IntEncoder + params: + +transforms: + - transformation_name: 'A' + columns: + - column_name: "col1" + transformations: + - name: ReverseComplement + params: + - transformation_name: 'B' + columns: + - column_name: "col1" + transformations: + - name: UniformTextMasker + params: + probability: [0.1, 0.2, 0.3] + - transformation_name: 'C' + columns: + - column_name: 
"col1" + transformations: + - name: ReverseComplement + params: + - name: UniformTextMasker + params: + probability: [0.1, 0.2, 0.3, 0.4] + - name: "GaussianNoise" + params: + std: [0.1, 0.2, 0.3, 0.4] + - column_name: "col2" + transformations: + - name: "GaussianNoise" + params: + std: [0.1, 0.2, 0.1, 0.2] + +split: + - split_method: random_split + split_input_columns: + params: + split: [0.6, 0.2, 0.2] + - split_method: random_split + split_input_columns: + params: + split: [0.7, 0.15, 0.15] \ No newline at end of file diff --git a/tests/test_data/titanic/titanic.yaml b/tests/test_data/titanic/titanic.yaml new file mode 100644 index 00000000..e6d1fef9 --- /dev/null +++ b/tests/test_data/titanic/titanic.yaml @@ -0,0 +1,85 @@ +global_params: + seed: 42 + +columns: + - column_name: "passenger_id" + column_type: "meta" + data_type: "int" + encoder: + - name: IntEncoder + params: + + - column_name: "survived" + column_type: "label" + data_type: "int" + encoder: + - name: IntEncoder + params: + + - column_name: "pclass" + column_type: "input" + data_type: "int" + encoder: + - name: IntEncoder + params: + + - column_name: "sex" + column_type: "input" + data_type: "str" + encoder: + - name: StrClassificationIntEncoder + params: + + - column_name: "age" + column_type: "input" + data_type: "float" + encoder: + - name: FloatRankEncoder + params: + + - column_name: "sibsp" + column_type: "input" + data_type: "int" + encoder: + - name: IntEncoder + params: + + - column_name: "parch" + column_type: "input" + data_type: "int" + encoder: + - name: IntEncoder + params: + + - column_name: "fare" + column_type: "input" + data_type: "float" + encoder: + - name: FloatRankEncoder + params: + + - column_name: "embarked" + column_type: "input" + data_type: "str" + encoder: + - name: StrClassificationIntEncoder + params: +transforms: + - transformation_name: "noise" + columns: + - column_name: "age" + transformations: + - name: GaussianNoise + params: + std: [0.1, 0.2, 0.3] + - column_name: "fare" + transformations: + - name: GaussianNoise + params: + std: [0.1, 0.2, 0.3] + +split: + - split_method: RandomSplit + split_input_columns: + params: + split: [0.7, 0.15, 0.15] diff --git a/tests/test_data/titanic/titanic_stimulus.csv b/tests/test_data/titanic/titanic_stimulus.csv index 5a127337..e47a2cdd 100644 --- a/tests/test_data/titanic/titanic_stimulus.csv +++ b/tests/test_data/titanic/titanic_stimulus.csv @@ -1,4 +1,4 @@ -passenger_id:meta:int,survived:label:int_class,pclass:input:int_class,sex:input:str_class,age:input:int_reg,sibsp:input:int_class,parch:input:int_class,fare:input:float_rank,embarked:input:str_class +passenger_id,survived,pclass,sex,age,sibsp,parch,fare,embarked 1,0,3,male,22.0,1,0,7.25,S 2,1,1,female,38.0,1,0,71.2833,C 3,1,3,female,26.0,0,0,7.925,S diff --git a/tests/test_data/yaml_files/wrong_field_type.yaml b/tests/test_data/yaml_files/wrong_field_type.yaml new file mode 100644 index 00000000..57642563 --- /dev/null +++ b/tests/test_data/yaml_files/wrong_field_type.yaml @@ -0,0 +1,33 @@ +global_params: + seed: 0 + +columns: + - column_name: "hi" + column_type: "type" + data_type: str + encoder: + - name: OneHotEncoder + params: + alphabet: 2 # Wrong + + - column_name: "Guten Tag" + column_type: "type" + data_type: 35.6 # Wrong + encoder: + - name: 9 # Wrong + params: + + +transforms: + - transformation_name: "D" + columns: + - column_name: "col2" + transformations: + - name: 4 # Wrong + params: + probability: "error" # Wrong + +split: + - split_method: 3 # Wrong + params: + split: [[0.6, 0.2, 
diff --git a/tests/test_data/yaml_files/wrong_field_type.yaml b/tests/test_data/yaml_files/wrong_field_type.yaml
new file mode 100644
index 00000000..57642563
--- /dev/null
+++ b/tests/test_data/yaml_files/wrong_field_type.yaml
@@ -0,0 +1,33 @@
+global_params:
+  seed: 0
+
+columns:
+  - column_name: "hi"
+    column_type: "type"
+    data_type: str
+    encoder:
+      - name: OneHotEncoder
+        params:
+          alphabet: 2 # Wrong
+
+  - column_name: "Guten Tag"
+    column_type: "type"
+    data_type: 35.6 # Wrong
+    encoder:
+      - name: 9 # Wrong
+        params:
+
+
+transforms:
+  - transformation_name: "D"
+    columns:
+      - column_name: "col2"
+        transformations:
+          - name: 4 # Wrong
+            params:
+              probability: "error" # Wrong
+
+split:
+  - split_method: 3 # Wrong
+    params:
+      split: [[0.6, 0.2, 0.2]]
\ No newline at end of file
diff --git a/tests/utils/test_data_yaml.py b/tests/utils/test_data_yaml.py
new file mode 100644
index 00000000..e3c941af
--- /dev/null
+++ b/tests/utils/test_data_yaml.py
@@ -0,0 +1,84 @@
+import pytest
+import yaml
+from stimulus.utils import yaml_data
+from stimulus.utils.yaml_data import YamlConfigDict, YamlTransform, YamlSubConfigDict
+
+@pytest.fixture
+def load_yaml_from_file() -> YamlConfigDict:
+    """Fixture that loads a test YAML configuration file."""
+    with open("tests/test_data/dna_experiment/dna_experiment_config_template.yaml") as f:
+        yaml_dict = yaml.safe_load(f)
+    return YamlConfigDict(**yaml_dict)
+
+@pytest.fixture
+def load_wrong_type_yaml() -> dict:
+    """Fixture that loads a YAML configuration file with wrong typing."""
+    with open("tests/test_data/yaml_files/wrong_field_type.yaml") as f:
+        return yaml.safe_load(f)
+
+def test_extract_transform_parameters_at_index(load_yaml_from_file):
+    """Tests extracting parameters at specific indices from transforms."""
+    # Test transform with parameter lists
+    transform = load_yaml_from_file.transforms[1]  # Transform 'B' with probability list
+
+    # Extract first parameter set
+    result = yaml_data.extract_transform_parameters_at_index(transform, 0)
+    assert result.columns[0].transformations[0].params["probability"] == 0.1
+
+    # Extract second parameter set
+    result = yaml_data.extract_transform_parameters_at_index(transform, 1)
+    assert result.columns[0].transformations[0].params["probability"] == 0.2
+
+def test_expand_transform_parameter_combinations(load_yaml_from_file):
+    """Tests expanding transforms with parameter lists into individual transforms."""
+    # Test transform with multiple parameter lists
+    transform = load_yaml_from_file.transforms[2]  # Transform 'C' with multiple lists
+
+    results = yaml_data.expand_transform_parameter_combinations(transform)
+    assert len(results) == 4  # Should create 4 transforms (longest parameter list length)
+
+    # Check first and last transforms
+    assert results[0].columns[0].transformations[1].params["probability"] == 0.1
+    assert results[3].columns[0].transformations[1].params["probability"] == 0.4
+
+def test_expand_transform_list_combinations(load_yaml_from_file):
+    """Tests expanding a list of transforms into all parameter combinations."""
+    results = yaml_data.expand_transform_list_combinations(load_yaml_from_file.transforms)
+
+    # Count expected transforms:
+    # Transform A: 1 (no parameters)
+    # Transform B: 3 (probability list length 3)
+    # Transform C: 4 (probability and std lists length 4)
+    assert len(results) == 8
+
+def test_generate_data_configs(load_yaml_from_file):
+    """Tests generating all possible data configurations."""
+    configs = yaml_data.generate_data_configs(load_yaml_from_file)
+
+    # Expected configs = (transforms combinations) × (number of splits)
+    # 8 transform combinations × 2 splits = 16 configs
+    assert len(configs) == 16
+
+    # Check that each config is a valid YamlSubConfigDict
+    for config in configs:
+        assert isinstance(config, YamlSubConfigDict)
+        assert config.global_params == load_yaml_from_file.global_params
+        assert config.columns == load_yaml_from_file.columns
+
+def test_dump_yaml_list_into_files(load_yaml_from_file):
+    """Tests dumping a list of YAML configurations into separate files."""
+    configs = yaml_data.generate_data_configs(load_yaml_from_file)
+    yaml_data.dump_yaml_list_into_files(configs, "scratch/", "dna_experiment_config_template")
+
+@pytest.mark.parametrize("test_input", [("load_yaml_from_file", False), ("load_wrong_type_yaml", True)])
+def test_check_yaml_schema(request, test_input):
+    """Tests the Pydantic schema validation."""
+    data = request.getfixturevalue(test_input[0])
+    expect_value_error = test_input[1]
+
+    if not expect_value_error:
+        yaml_data.check_yaml_schema(data)
+        assert True
+    else:
+        with pytest.raises(ValueError, match="Wrong type on a field, see the pydantic report above"):
+            yaml_data.check_yaml_schema(data)