diff --git a/src/stimulus/cli/check_model.py b/src/stimulus/cli/check_model.py index 74c3ea3e..b3ef0150 100755 --- a/src/stimulus/cli/check_model.py +++ b/src/stimulus/cli/check_model.py @@ -22,8 +22,22 @@ def get_args() -> argparse.Namespace: Parsed command line arguments. """ parser = argparse.ArgumentParser(description="Launch check_model.") - parser.add_argument("-d", "--data", type=str, required=True, metavar="FILE", help="Path to input csv file.") - parser.add_argument("-m", "--model", type=str, required=True, metavar="FILE", help="Path to model file.") + parser.add_argument( + "-d", + "--data", + type=str, + required=True, + metavar="FILE", + help="Path to input csv file.", + ) + parser.add_argument( + "-m", + "--model", + type=str, + required=True, + metavar="FILE", + help="Path to model file.", + ) parser.add_argument( "-e", "--data_config", @@ -107,14 +121,16 @@ def main( """ with open(data_config_path) as file: data_config = yaml.safe_load(file) - data_config = yaml_data.YamlSubConfigDict(**data_config) + data_config = yaml_data.YamlSplitTransformDict(**data_config) with open(model_config_path) as file: model_config = yaml.safe_load(file) model_config = yaml_model_schema.Model(**model_config) encoder_loader = loaders.EncoderLoader() - encoder_loader.initialize_column_encoders_from_config(column_config=data_config.columns) + encoder_loader.initialize_column_encoders_from_config( + column_config=data_config.columns + ) logger.info("Dataset loaded successfully.") @@ -122,7 +138,8 @@ def main( logger.info("Model class loaded successfully.") - ray_config_loader = yaml_model_schema.YamlRayConfigLoader(model=model_config) + ray_config_loader = yaml_model_schema.YamlRayConfigLoader( + model=model_config) ray_config_dict = ray_config_loader.get_config().model_dump() ray_config_model = ray_config_loader.get_config() @@ -140,7 +157,7 @@ def main( logger.info("Model instance loaded successfully.") torch_dataset = handlertorch.TorchDataset( - config_path=data_config_path, + data_config=data_config, csv_path=data_path, encoder_loader=encoder_loader, ) @@ -171,7 +188,7 @@ def main( tuner = raytune_learner.TuneWrapper( model_config=ray_config_model, - data_config_path=data_config_path, + data_config=data_config, model_class=model_class, data_path=data_path, encoder_loader=encoder_loader, diff --git a/src/stimulus/cli/split_csv.py b/src/stimulus/cli/split_csv.py index 04de0e56..44afc03d 100755 --- a/src/stimulus/cli/split_csv.py +++ b/src/stimulus/cli/split_csv.py @@ -7,7 +7,7 @@ from stimulus.data.data_handlers import DatasetProcessor, SplitManager from stimulus.data.loaders import SplitLoader -from stimulus.utils.yaml_data import YamlSubConfigDict +from stimulus.utils.yaml_data import YamlSplitConfigDict def get_args() -> argparse.Namespace: @@ -49,7 +49,9 @@ def get_args() -> argparse.Namespace: return parser.parse_args() -def main(data_csv: str, config_yaml: str, out_path: str, *, force: bool = False) -> None: +def main( + data_csv: str, config_yaml: str, out_path: str, *, force: bool = False +) -> None: """Connect CSV and YAML configuration and handle sanity checks. 
Args: @@ -64,7 +66,7 @@ def main(data_csv: str, config_yaml: str, out_path: str, *, force: bool = False) # create a split manager from the config split_config = processor.dataset_manager.config.split with open(config_yaml) as f: - yaml_config = YamlSubConfigDict(**yaml.safe_load(f)) + yaml_config = YamlSplitConfigDict(**yaml.safe_load(f)) split_loader = SplitLoader(seed=yaml_config.global_params.seed) split_loader.initialize_splitter_from_config(split_config) split_manager = SplitManager(split_loader) diff --git a/src/stimulus/cli/split_yaml.py b/src/stimulus/cli/split_split.py similarity index 73% rename from src/stimulus/cli/split_yaml.py rename to src/stimulus/cli/split_split.py index 69c166a7..63e265d9 100755 --- a/src/stimulus/cli/split_yaml.py +++ b/src/stimulus/cli/split_split.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 -"""CLI module for splitting YAML configuration files. +"""CLI module for splitting YAML configuration files into unique files for each split. This module provides functionality to split a single YAML configuration file into multiple -YAML files, each containing a specific combination of data transformations and splits. +YAML files, each containing a unique split. The resulting YAML files can be used as input configurations for the stimulus package. """ @@ -13,9 +13,10 @@ from stimulus.utils.yaml_data import ( YamlConfigDict, + YamlSplitConfigDict, check_yaml_schema, dump_yaml_list_into_files, - generate_data_configs, + generate_split_configs, ) @@ -39,6 +40,7 @@ def get_args() -> argparse.Namespace: const="./", default="./", metavar="DIR", + # TODO: Change the output name help="The output dir where all the YAMLs are written to. Output YAML will be called split-#[number].yaml transform-#[number].yaml. Default -> ./", ) @@ -46,32 +48,32 @@ def main(config_yaml: str, out_dir_path: str) -> None: - """Reads a YAML config file and generates all possible data configurations. + """Reads a YAML config file and generates a file per unique split. This script reads a YAML with a defined structure and creates all the YAML files ready to be passed to the stimulus package. The structure of the YAML is described here -> TODO paste here link to documentation. - This YAML and it's structure summarize how to generate all the transform - split and respective parameter combinations. - Each resulting YAML will hold only one combination of the above three things. + This YAML and its structure summarize how to generate unique splits and all the transformations associated with each split. This script will always generate at least one YAML file that represent the combination that does not touch the data (no transform) and uses the default split behavior. """ - # read the yaml experiment config and load it to dictionary + # read the yaml experiment config and load it into a dictionary yaml_config: dict[str, Any] = {} with open(config_yaml) as conf_file: yaml_config = yaml.safe_load(conf_file) yaml_config_dict: YamlConfigDict = YamlConfigDict(**yaml_config) # check if the yaml schema is correct + # FIXME: isn't this check redundant, since pydantic already validates the class?
check_yaml_schema(yaml_config_dict) - # generate all the YAML configs - data_configs = generate_data_configs(yaml_config_dict) + # generate the yaml files per split + split_configs: list[YamlSplitConfigDict] = generate_split_configs(yaml_config_dict) # dump all the YAML configs into files - dump_yaml_list_into_files(data_configs, out_dir_path, "test") + dump_yaml_list_into_files(split_configs, out_dir_path, "test_split") def run() -> None: diff --git a/src/stimulus/cli/split_transforms.py b/src/stimulus/cli/split_transforms.py new file mode 100644 index 00000000..f3e57717 --- /dev/null +++ b/src/stimulus/cli/split_transforms.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 +"""CLI module for splitting YAML configuration files into unique files for each transform. + +This module provides functionality to split a single YAML configuration file into multiple +YAML files, each containing a unique transform associated with a unique split. +The resulting YAML files can be used as input configurations for the stimulus package. +""" + +import argparse +from typing import Any + +import yaml + +from stimulus.utils.yaml_data import ( + YamlSplitConfigDict, + YamlSplitTransformDict, + dump_yaml_list_into_files, + generate_split_transform_configs, +) + + +def get_args() -> argparse.Namespace: + """Get the arguments when using the command line.""" + parser = argparse.ArgumentParser(description="") + parser.add_argument( + "-j", + "--yaml", + type=str, + required=True, + metavar="FILE", + help="The YAML config file that holds all the transform-per-split parameter info", + ) + parser.add_argument( + "-d", + "--out-dir", + type=str, + required=False, + nargs="?", + const="./", + default="./", + metavar="DIR", + help="The output dir where all the YAMLs are written to. Output YAML will be called split_transform-#[number].yaml. Default -> ./", + ) + + return parser.parse_args() + + +def main(config_yaml: str, out_dir_path: str) -> None: + """Reads a YAML config and generates files for all possible split-transform combinations. + + This script reads a YAML with a defined structure and creates all the YAML files ready to be passed to the stimulus package. + + The structure of the YAML is described here -> TODO: paste here the link to documentation + This YAML and its structure summarize how to generate all the transforms for the split and the respective parameter combinations. + + This script will always generate at least one YAML file that represents the combination that does not touch the data (no transform).
+ """ + # read the yaml experiment config and load its dictionnary + yaml_config: dict[str, Any] = {} + with open(config_yaml) as conf_file: + yaml_config = yaml.safe_load(conf_file) + + yaml_config_dict: YamlSplitConfigDict = YamlSplitConfigDict(**yaml_config) + + # Generate the yaml files for each transform + split_transform_configs: list[YamlSplitTransformDict] = ( + generate_split_transform_configs(yaml_config_dict) + ) + + # Dump all the YAML configs into files + dump_yaml_list_into_files(split_transform_configs, out_dir_path, "test_transforms") + + +if __name__ == "__main__": + args = get_args() + main(args.yaml, args.out_dir) diff --git a/src/stimulus/cli/transform_csv.py b/src/stimulus/cli/transform_csv.py index 2e2ff5fd..55b25882 100755 --- a/src/stimulus/cli/transform_csv.py +++ b/src/stimulus/cli/transform_csv.py @@ -7,12 +7,14 @@ from stimulus.data.data_handlers import DatasetProcessor, TransformManager from stimulus.data.loaders import TransformLoader -from stimulus.utils.yaml_data import YamlSubConfigDict +from stimulus.utils.yaml_data import YamlSplitConfigDict def get_args() -> argparse.Namespace: """Get the arguments when using from the commandline.""" - parser = argparse.ArgumentParser(description="CLI for transforming CSV data files using YAML configuration.") + parser = argparse.ArgumentParser( + description="CLI for transforming CSV data files using YAML configuration." + ) parser.add_argument( "-c", "--csv", @@ -53,8 +55,9 @@ def main(data_csv: str, config_yaml: str, out_path: str) -> None: # initialize the transform manager transform_config = processor.dataset_manager.config.transforms with open(config_yaml) as f: - yaml_config = YamlSubConfigDict(**yaml.safe_load(f)) + yaml_config = YamlSplitConfigDict(**yaml.safe_load(f)) transform_loader = TransformLoader(seed=yaml_config.global_params.seed) + print(transform_config) transform_loader.initialize_column_data_transformers_from_config(transform_config) transform_manager = TransformManager(transform_loader) diff --git a/src/stimulus/cli/tuning.py b/src/stimulus/cli/tuning.py index 19369b37..5beecdb4 100755 --- a/src/stimulus/cli/tuning.py +++ b/src/stimulus/cli/tuning.py @@ -29,8 +29,22 @@ def get_args() -> argparse.Namespace: Parsed command line arguments. """ parser = argparse.ArgumentParser(description="Launch check_model.") - parser.add_argument("-d", "--data", type=str, required=True, metavar="FILE", help="Path to input csv file.") - parser.add_argument("-m", "--model", type=str, required=True, metavar="FILE", help="Path to model file.") + parser.add_argument( + "-d", + "--data", + type=str, + required=True, + metavar="FILE", + help="Path to input csv file.", + ) + parser.add_argument( + "-m", + "--model", + type=str, + required=True, + metavar="FILE", + help="Path to model file.", + ) parser.add_argument( "-e", "--data_config", @@ -137,7 +151,7 @@ def get_args() -> argparse.Namespace: def main( model_path: str, data_path: str, - data_config_path: str, + data_config: yaml_data.YamlSplitTransformDict, model_config_path: str, initial_weights: str | None = None, # noqa: ARG001 ray_results_dirpath: str | None = None, @@ -153,7 +167,7 @@ def main( Args: data_path: Path to input data file. model_path: Path to model file. - data_config_path: Path to data config file. + data_config: A YamlSplitTransformObject model_config_path: Path to model config file. initial_weights: Optional path to initial weights. ray_results_dirpath: Directory for ray results. 
@@ -163,26 +177,25 @@ def main( best_metrics_path: Path to write the best metrics to. best_config_path: Path to write the best config to. """ - # Convert data config to proper type - with open(data_config_path) as file: - data_config_dict: dict[str, Any] = yaml.safe_load(file) - data_config: yaml_data.YamlSubConfigDict = yaml_data.YamlSubConfigDict(**data_config_dict) - with open(model_config_path) as file: model_config_dict: dict[str, Any] = yaml.safe_load(file) - model_config: yaml_model_schema.Model = yaml_model_schema.Model(**model_config_dict) + model_config: yaml_model_schema.Model = yaml_model_schema.Model( + **model_config_dict) encoder_loader = loaders.EncoderLoader() - encoder_loader.initialize_column_encoders_from_config(column_config=data_config.columns) + encoder_loader.initialize_column_encoders_from_config( + column_config=data_config.columns + ) model_class = launch_utils.import_class_from_file(model_path) - ray_config_loader = yaml_model_schema.YamlRayConfigLoader(model=model_config) + ray_config_loader = yaml_model_schema.YamlRayConfigLoader( + model=model_config) ray_config_model = ray_config_loader.get_config() tuner = raytune_learner.TuneWrapper( model_config=ray_config_model, - data_config_path=data_config_path, + data_config=data_config, model_class=model_class, data_path=data_path, encoder_loader=encoder_loader, @@ -228,10 +241,14 @@ def run() -> None: """Run the model checking script.""" ray.init(address="auto", ignore_reinit_error=True) args = get_args() + # Try to convert the configuration file to a YamlSplitTransformDict + config_dict: yaml_data.YamlSplitTransformDict + with open(args.data_config) as f: + config_dict = yaml_data.YamlSplitTransformDict(**yaml.safe_load(f)) main( data_path=args.data, model_path=args.model, - data_config_path=args.data_config, + data_config=config_dict, model_config_path=args.model_config, initial_weights=args.initial_weights, ray_results_dirpath=args.ray_results_dirpath, diff --git a/src/stimulus/data/data_handlers.py b/src/stimulus/data/data_handlers.py index e7d6fc9f..4af2f786 100644 --- a/src/stimulus/data/data_handlers.py +++ b/src/stimulus/data/data_handlers.py @@ -50,10 +50,11 @@ class DatasetManager: def __init__( self, - config_path: str, + config_dict: yaml_data.YamlSplitConfigDict, ) -> None: """Initialize the DatasetManager.""" - self.config = self._load_config(config_path) + # self.config = self._load_config(config_path) + self.config: yaml_data.YamlSplitTransformDict = config_dict self.column_categories = self.categorize_columns_by_type() def categorize_columns_by_type(self) -> dict: @@ -93,7 +94,8 @@ def categorize_columns_by_type(self) -> dict: return {"input": input_columns, "label": label_columns, "meta": meta_columns} - def _load_config(self, config_path: str) -> yaml_data.YamlConfigDict: + # TODO: Remove or change this function as the config is now preloaded + def _load_config(self, config_path: str) -> yaml_data.YamlSplitConfigDict: """Loads and parses a YAML configuration file. 
Args: @@ -108,8 +110,11 @@ def _load_config(self, config_path: str) -> yaml_data.YamlConfigDict: >>> print(config["columns"][0]["column_name"]) 'hello' """ + with open(config_path) as file: - return yaml_data.YamlSubConfigDict(**yaml.safe_load(file)) + # FIXME: this function is called for test_shuffle_csv and test_tune + return yaml_data.YamlSplitConfigDict(**yaml.safe_load(file)) + return yaml_data.YamlSplitTransformDict(**yaml.safe_load(file)) def get_split_columns(self) -> list[str]: """Get the columns that are used for splitting.""" @@ -185,7 +190,8 @@ def encode_column(self, column_name: str, column_data: list) -> torch.Tensor: >>> print(encoded.shape) torch.Size([2, 4, 4]) # 2 sequences, length 4, one-hot encoded """ - encode_all_function = self.encoder_loader.get_function_encode_all(column_name) + encode_all_function = self.encoder_loader.get_function_encode_all( + column_name) return encode_all_function(column_data) def encode_columns(self, column_data: dict) -> dict: @@ -207,11 +213,16 @@ def encode_columns(self, column_data: dict) -> dict: >>> print(encoded["dna_seq"].shape) torch.Size([2, 4, 4]) # 2 sequences, length 4, one-hot encoded """ - return {col: self.encode_column(col, values) for col, values in column_data.items()} + return { + col: self.encode_column(col, values) for col, values in column_data.items() + } def encode_dataframe(self, dataframe: pl.DataFrame) -> dict[str, torch.Tensor]: """Encode the dataframe using the encoders.""" - return {col: self.encode_column(col, dataframe[col].to_list()) for col in dataframe.columns} + return { + col: self.encode_column(col, dataframe[col].to_list()) + for col in dataframe.columns + } class TransformManager: @@ -224,7 +235,9 @@ def __init__( """Initialize the TransformManager.""" self.transform_loader = transform_loader - def transform_column(self, column_name: str, transform_name: str, column_data: list) -> tuple[list, bool]: + def transform_column( + self, column_name: str, transform_name: str, column_data: list + ) -> tuple[list, bool]: """Transform a column of data using the specified transformation. Args: @@ -236,7 +249,9 @@ def transform_column(self, column_name: str, transform_name: str, column_data: l list: The transformed data. bool: Whether the transformation added new rows to the data. """ - transformer = self.transform_loader.__getattribute__(column_name)[transform_name] + transformer = self.transform_loader.__getattribute__(column_name)[ + transform_name + ] return transformer.transform_all(column_data), transformer.add_row @@ -250,7 +265,9 @@ def __init__( """Initialize the SplitManager.""" self.split_loader = split_loader - def get_split_indices(self, data: dict) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + def get_split_indices( + self, data: dict + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """Get the indices for train, validation, and test splits.""" return self.split_loader.get_function_split()(data) @@ -270,16 +287,16 @@ class DatasetHandler: def __init__( self, - config_path: str, + data_config: yaml_data.YamlSplitTransformDict, csv_path: str, ) -> None: """Initialize the DatasetHandler with required config. Args: - config_path (str): Path to the dataset configuration file. + data_config (yaml_data.YamlSplitTransformDict): A YamlSplitTransformDict object holding the config. csv_path (str): Path to the CSV data file.
""" - self.dataset_manager = DatasetManager(config_path) + self.dataset_manager = DatasetManager(data_config) self.columns = self.read_csv_header(csv_path) self.data = self.load_csv(csv_path) @@ -353,7 +370,8 @@ def add_split(self, split_manager: SplitManager, *, force: bool = False) -> None split_input_data = self.select_columns(split_columns) # get the split indices - train, validation, test = split_manager.get_split_indices(split_input_data) + train, validation, test = split_manager.get_split_indices( + split_input_data) # add the split column to the data split_column = np.full(len(self.data), -1).astype(int) @@ -367,17 +385,25 @@ def add_split(self, split_manager: SplitManager, *, force: bool = False) -> None def apply_transformation_group(self, transform_manager: TransformManager) -> None: """Apply the transformation group to the data.""" - for column_name, transform_name, _params in self.dataset_manager.get_transform_logic()["transformations"]: + for ( + column_name, + transform_name, + _params, + ) in self.dataset_manager.get_transform_logic()["transformations"]: transformed_data, add_row = transform_manager.transform_column( column_name, transform_name, self.data[column_name], ) if add_row: - new_rows = self.data.with_columns(pl.Series(column_name, transformed_data)) + new_rows = self.data.with_columns( + pl.Series(column_name, transformed_data) + ) self.data = pl.vstack(self.data, new_rows) else: - self.data = self.data.with_columns(pl.Series(column_name, transformed_data)) + self.data = self.data.with_columns( + pl.Series(column_name, transformed_data) + ) def shuffle_labels(self, seed: Optional[float] = None) -> None: """Shuffles the labels in the data.""" @@ -386,7 +412,9 @@ def shuffle_labels(self, seed: Optional[float] = None) -> None: label_keys = self.dataset_manager.column_categories["label"] for key in label_keys: - self.data = self.data.with_columns(pl.Series(key, np.random.permutation(list(self.data[key])))) + self.data = self.data.with_columns( + pl.Series(key, np.random.permutation(list(self.data[key]))) + ) class DatasetLoader(DatasetHandler): @@ -394,15 +422,19 @@ class DatasetLoader(DatasetHandler): def __init__( self, - config_path: str, + data_config: yaml_data.YamlSplitTransformDict, csv_path: str, encoder_loader: loaders.EncoderLoader, split: Union[int, None] = None, ) -> None: """Initialize the DatasetLoader.""" - super().__init__(config_path, csv_path) + super().__init__(data_config, csv_path) self.encoder_manager = EncodeManager(encoder_loader) - self.data = self.load_csv_per_split(csv_path, split) if split is not None else self.load_csv(csv_path) + self.data = ( + self.load_csv_per_split(csv_path, split) + if split is not None + else self.load_csv(csv_path) + ) def get_all_items(self) -> tuple[dict, dict, dict]: """Get the full dataset as three separate dictionaries for inputs, labels and metadata. 
@@ -428,8 +460,10 @@ def get_all_items(self) -> tuple[dict, dict, dict]: self.dataset_manager.column_categories["label"], self.dataset_manager.column_categories["meta"], ) - input_data = self.encoder_manager.encode_dataframe(self.data[input_columns]) - label_data = self.encoder_manager.encode_dataframe(self.data[label_columns]) + input_data = self.encoder_manager.encode_dataframe( + self.data[input_columns]) + label_data = self.encoder_manager.encode_dataframe( + self.data[label_columns]) meta_data = {key: self.data[key].to_list() for key in meta_columns} return input_data, label_data, meta_data @@ -447,16 +481,21 @@ def load_csv_per_split(self, csv_path: str, split: int) -> pl.DataFrame: we are gonna load only the relevant data for it. """ if "split" not in self.columns: - raise ValueError("The category split is not present in the csv file") + raise ValueError( + "The category split is not present in the csv file") if split not in [0, 1, 2]: - raise ValueError(f"The split value should be 0, 1 or 2. The specified split value is {split}") + raise ValueError( + f"The split value should be 0, 1 or 2. The specified split value is {split}" + ) return pl.scan_csv(csv_path).filter(pl.col("split") == split).collect() def __len__(self) -> int: """Return the length of the first list in input, assumes that all are the same length.""" return len(self.data) - def __getitem__(self, idx: Any) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor], dict[str, list]]: + def __getitem__( + self, idx: Any + ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor], dict[str, list]]: """Get the data at a given index, and encodes the input and label, leaving meta as it is. Args: @@ -474,17 +513,24 @@ def __getitem__(self, idx: Any) -> tuple[dict[str, torch.Tensor], dict[str, torc data_at_index = self.data.slice(start, stop - start) # Process DataFrame - input_data = self.encoder_manager.encode_dataframe(data_at_index[input_columns]) - label_data = self.encoder_manager.encode_dataframe(data_at_index[label_columns]) - meta_data = {key: data_at_index[key].to_list() for key in meta_columns} + input_data = self.encoder_manager.encode_dataframe( + data_at_index[input_columns] + ) + label_data = self.encoder_manager.encode_dataframe( + data_at_index[label_columns] + ) + meta_data = {key: data_at_index[key].to_list() + for key in meta_columns} elif isinstance(idx, int): # For single row, convert to dict with column names as keys row_dict = dict(zip(self.data.columns, self.data.row(idx))) # Create single-row DataFrames for encoding - input_df = pl.DataFrame({col: [row_dict[col]] for col in input_columns}) - label_df = pl.DataFrame({col: [row_dict[col]] for col in label_columns}) + input_df = pl.DataFrame( + {col: [row_dict[col]] for col in input_columns}) + label_df = pl.DataFrame( + {col: [row_dict[col]] for col in label_columns}) input_data = self.encoder_manager.encode_dataframe(input_df) label_data = self.encoder_manager.encode_dataframe(label_df) @@ -494,8 +540,13 @@ def __getitem__(self, idx: Any) -> tuple[dict[str, torch.Tensor], dict[str, torc data_at_index = self.data.select(idx) # Process DataFrame - input_data = self.encoder_manager.encode_dataframe(data_at_index[input_columns]) - label_data = self.encoder_manager.encode_dataframe(data_at_index[label_columns]) - meta_data = {key: data_at_index[key].to_list() for key in meta_columns} + input_data = self.encoder_manager.encode_dataframe( + data_at_index[input_columns] + ) + label_data = self.encoder_manager.encode_dataframe( + data_at_index[label_columns] 
+ ) + meta_data = {key: data_at_index[key].to_list() + for key in meta_columns} return input_data, label_data, meta_data diff --git a/src/stimulus/data/handlertorch.py b/src/stimulus/data/handlertorch.py index e47f38c9..3b89140a 100644 --- a/src/stimulus/data/handlertorch.py +++ b/src/stimulus/data/handlertorch.py @@ -5,6 +5,7 @@ from torch.utils.data import Dataset from stimulus.data import data_handlers, loaders +from stimulus.utils.yaml_data import YamlSplitTransformDict class TorchDataset(Dataset): @@ -12,7 +13,7 @@ class TorchDataset(Dataset): def __init__( self, - config_path: str, + data_config: YamlSplitTransformDict, csv_path: str, encoder_loader: loaders.EncoderLoader, split: Optional[int] = None, @@ -20,13 +21,13 @@ def __init__( """Initialize the TorchDataset. Args: - config_path: Path to the configuration file + data_config: A YamlSplitTransformDict holding the configuration. csv_path: Path to the CSV data file encoder_loader: Encoder loader instance split: Optional tuple containing split information """ self.loader = data_handlers.DatasetLoader( - config_path=config_path, + data_config=data_config, csv_path=csv_path, encoder_loader=encoder_loader, split=split, diff --git a/src/stimulus/data/loaders.py b/src/stimulus/data/loaders.py index d962ac94..4bea4a1b 100644 --- a/src/stimulus/data/loaders.py +++ b/src/stimulus/data/loaders.py @@ -29,14 +29,17 @@ def __init__(self, seed: Optional[float] = None) -> None: """ self.seed = seed - def initialize_column_encoders_from_config(self, column_config: yaml_data.YamlColumns) -> None: + def initialize_column_encoders_from_config( + self, column_config: yaml_data.YamlColumns + ) -> None: """Build the loader from a config dictionary. Args: column_config (yaml_data.YamlColumns): Configuration dictionary containing field names (column_name) and their encoder specifications. """ for field in column_config: - encoder = self.get_encoder(field.encoder[0].name, field.encoder[0].params) + encoder = self.get_encoder( + field.encoder[0].name, field.encoder[0].params) self.set_encoder_as_attribute(field.column_name, encoder) def get_function_encode_all(self, field_name: str) -> Any: @@ -50,7 +53,9 @@ def get_function_encode_all(self, field_name: str) -> Any: """ return getattr(self, field_name).encode_all - def get_encoder(self, encoder_name: str, encoder_params: Optional[dict] = None) -> Any: + def get_encoder( + self, encoder_name: str, encoder_params: Optional[dict] = None + ) -> Any: """Gets an encoder object from the encoders module and initializes it with the given parameters. Args: @@ -63,7 +68,9 @@ def get_encoder(self, encoder_name: str, encoder_params: Optional[dict] = None) try: return getattr(encoders, encoder_name)(**encoder_params) except AttributeError: - logging.exception(f"Encoder '{encoder_name}' not found in the encoders module.") + logging.exception( + f"Encoder '{encoder_name}' not found in the encoders module." 
+ ) logging.exception( f"Available encoders: {[name for name, obj in encoders.__dict__.items() if isinstance(obj, type) and name not in ('ABC', 'Any')]}", ) @@ -72,13 +79,17 @@ def get_encoder(self, encoder_name: str, encoder_params: Optional[dict] = None) except TypeError: if encoder_params is None: return getattr(encoders, encoder_name)() - logging.exception(f"Encoder '{encoder_name}' has incorrect parameters: {encoder_params}") + logging.exception( + f"Encoder '{encoder_name}' has incorrect parameters: {encoder_params}" + ) logging.exception( f"Expected parameters for '{encoder_name}': {inspect.signature(getattr(encoders, encoder_name))}", ) raise - def set_encoder_as_attribute(self, field_name: str, encoder: encoders.AbstractEncoder) -> None: + def set_encoder_as_attribute( + self, field_name: str, encoder: encoders.AbstractEncoder + ) -> None: """Sets the encoder as an attribute of the loader. Args: @@ -99,7 +110,9 @@ def __init__(self, seed: Optional[float] = None) -> None: """ self.seed = seed - def get_data_transformer(self, transformation_name: str, transformation_params: Optional[dict] = None) -> Any: + def get_data_transformer( + self, transformation_name: str, transformation_params: Optional[dict] = None + ) -> Any: """Gets a transformer object from the transformers module. Args: @@ -110,9 +123,13 @@ def get_data_transformer(self, transformation_name: str, transformation_params: Any: The transformer function for the specified transformation """ try: - return getattr(data_transformation_generators, transformation_name)(**transformation_params) + return getattr(data_transformation_generators, transformation_name)( + **transformation_params + ) except AttributeError: - logging.exception(f"Transformer '{transformation_name}' not found in the transformers module.") + logging.exception( + f"Transformer '{transformation_name}' not found in the transformers module." + ) logging.exception( f"Available transformers: {[name for name, obj in data_transformation_generators.__dict__.items() if isinstance(obj, type) and name not in ('ABC', 'Any')]}", ) @@ -121,13 +138,17 @@ def get_data_transformer(self, transformation_name: str, transformation_params: except TypeError: if transformation_params is None: return getattr(data_transformation_generators, transformation_name)() - logging.exception(f"Transformer '{transformation_name}' has incorrect parameters: {transformation_params}") + logging.exception( + f"Transformer '{transformation_name}' has incorrect parameters: {transformation_params}" + ) logging.exception( f"Expected parameters for '{transformation_name}': {inspect.signature(getattr(data_transformation_generators, transformation_name))}", ) raise - def set_data_transformer_as_attribute(self, field_name: str, data_transformer: Any) -> None: + def set_data_transformer_as_attribute( + self, field_name: str, data_transformer: Any + ) -> None: """Sets the data transformer as an attribute of the loader. 
Args: @@ -136,12 +157,18 @@ def set_data_transformer_as_attribute(self, field_name: str, data_transformer: A """ # check if the field already exists, if it does not, initialize it to an empty dict if not hasattr(self, field_name): - setattr(self, field_name, {data_transformer.__class__.__name__: data_transformer}) + setattr( + self, + field_name, + {data_transformer.__class__.__name__: data_transformer}, + ) else: field_value = getattr(self, field_name) field_value[data_transformer.__class__.__name__] = data_transformer - def initialize_column_data_transformers_from_config(self, transform_config: yaml_data.YamlTransform) -> None: + def initialize_column_data_transformers_from_config( + self, transform_config: yaml_data.YamlTransform + ) -> None: """Build the loader from a config dictionary. Args: @@ -174,7 +201,9 @@ def initialize_column_data_transformers_from_config(self, transform_config: yaml for column in transform_config.columns: col_name = column.column_name for transform_spec in column.transformations: - transformer = self.get_data_transformer(transform_spec.name, transform_spec.params) + transformer = self.get_data_transformer( + transform_spec.name, transform_spec.params + ) self.set_data_transformer_as_attribute(col_name, transformer) @@ -206,7 +235,9 @@ def get_function_split(self) -> Any: ) return self.split.get_split_indexes - def get_splitter(self, splitter_name: str, splitter_params: Optional[dict] = None) -> Any: + def get_splitter( + self, splitter_name: str, splitter_params: Optional[dict] = None + ) -> Any: """Gets a splitter object from the splitters module. Args: @@ -221,7 +252,9 @@ def get_splitter(self, splitter_name: str, splitter_params: Optional[dict] = Non except TypeError: if splitter_params is None: return getattr(splitters, splitter_name)() - logging.exception(f"Splitter '{splitter_name}' has incorrect parameters: {splitter_params}") + logging.exception( + f"Splitter '{splitter_name}' has incorrect parameters: {splitter_params}" + ) logging.exception( f"Expected parameters for '{splitter_name}': {inspect.signature(getattr(splitters, splitter_name))}", ) @@ -235,11 +268,14 @@ def set_splitter_as_attribute(self, splitter: Any) -> None: """ self.split = splitter - def initialize_splitter_from_config(self, split_config: yaml_data.YamlSplit) -> None: + def initialize_splitter_from_config( + self, split_config: yaml_data.YamlSplitConfigDict + ) -> None: """Build the loader from a config dictionary. Args: - split_config (yaml_data.YamlSplit): Configuration dictionary containing split configurations. + split_config (yaml_data.YamlSplitConfigDict): Configuration dictionary containing split configurations. 
""" - splitter = self.get_splitter(split_config.split_method, split_config.params) + splitter = self.get_splitter( + split_config.split_method, split_config.params) self.set_splitter_as_attribute(splitter) diff --git a/src/stimulus/learner/raytune_learner.py b/src/stimulus/learner/raytune_learner.py index 4d86ca47..95749b57 100644 --- a/src/stimulus/learner/raytune_learner.py +++ b/src/stimulus/learner/raytune_learner.py @@ -21,6 +21,7 @@ from stimulus.learner.predict import PredictWrapper from stimulus.utils.generic_utils import set_general_seeds from stimulus.utils.yaml_model_schema import RayTuneModel +from stimulus.utils.yaml_data import YamlSplitTransformDict class CheckpointDict(TypedDict): @@ -35,7 +36,7 @@ class TuneWrapper: def __init__( self, model_config: RayTuneModel, - data_config_path: str, + data_config: YamlSplitTransformDict, model_class: nn.Module, data_path: str, encoder_loader: EncoderLoader, @@ -75,7 +76,10 @@ def __init__( self.run_config = train.RunConfig( name=tune_run_name if tune_run_name is not None - else "TuneModel_" + datetime.datetime.now(tz=datetime.timezone.utc).strftime("%Y-%m-%d_%H-%M-%S"), + else "TuneModel_" + + datetime.datetime.now(tz=datetime.timezone.utc).strftime( + "%Y-%m-%d_%H-%M-%S" + ), storage_path=ray_results_dir, checkpoint_config=train.CheckpointConfig(checkpoint_at_end=True), stop=model_config.tune.run_params.stop, @@ -83,7 +87,8 @@ def __init__( # add the data path to the config if not os.path.exists(data_path): - raise ValueError("Data path does not exist. Given path:" + data_path) + raise ValueError( + "Data path does not exist. Given path:" + data_path) self.config["data_path"] = os.path.abspath(data_path) # Set up tune_run path @@ -93,7 +98,10 @@ def __init__( ray_results_dir, tune_run_name if tune_run_name is not None - else "TuneModel_" + datetime.datetime.now(tz=datetime.timezone.utc).strftime("%Y-%m-%d_%H-%M-%S"), + else "TuneModel_" + + datetime.datetime.now(tz=datetime.timezone.utc).strftime( + "%Y-%m-%d_%H-%M-%S" + ), ) self.config["_debug"] = debug self.config["model"] = model_class @@ -104,7 +112,7 @@ def __init__( self.cpu_per_trial = model_config.tune.cpu_per_trial self.tuner = self.tuner_initialization( - data_config_path=data_config_path, + data_config=data_config, data_path=data_path, encoder_loader=encoder_loader, autoscaler=autoscaler, @@ -112,7 +120,7 @@ def __init__( def tuner_initialization( self, - data_config_path: str, + data_config: YamlSplitTransformDict, data_path: str, encoder_loader: EncoderLoader, *, @@ -130,25 +138,29 @@ def tuner_initialization( "GPU per trial is more than what is available in the cluster, set autoscaler to True to allow for autoscaler to be used.", ) except KeyError as err: - logging.warning(f"KeyError: {err}, no GPU resources available in the cluster: {cluster_res}") + logging.warning( + f"KeyError: {err}, no GPU resources available in the cluster: {cluster_res}" + ) if self.cpu_per_trial > cluster_res["CPU"] and not autoscaler: raise ValueError( "CPU per trial is more than what is available in the cluster, set autoscaler to True to allow for autoscaler to be used.", ) - logging.info(f"PER_TRIAL resources -> GPU: {self.gpu_per_trial} CPU: {self.cpu_per_trial}") + logging.info( + f"PER_TRIAL resources -> GPU: {self.gpu_per_trial} CPU: {self.cpu_per_trial}" + ) # Pre-load and encode datasets once, then put them in Ray's object store training = TorchDataset( - config_path=data_config_path, + data_config=data_config, csv_path=data_path, encoder_loader=encoder_loader, split=0, ) validation = 
TorchDataset( - config_path=data_config_path, + data_config=data_config, csv_path=data_path, encoder_loader=encoder_loader, split=1, @@ -182,7 +194,12 @@ def tuner_initialization( resources={"cpu": self.cpu_per_trial, "gpu": self.gpu_per_trial}, ) - return tune.Tuner(trainable, tune_config=self.tune_config, param_space=self.config, run_config=self.run_config) + return tune.Tuner( + trainable, + tune_config=self.tune_config, + param_space=self.config, + run_config=self.run_config, + ) def tune(self) -> ray.tune.ResultGrid: """Run the tuning process.""" @@ -221,7 +238,10 @@ def setup(self, config: dict[Any, Any]) -> None: self.step_size = config["tune"]["step_size"] # Get datasets from Ray's object store - training, validation = ray.get(self.config["_training_ref"]), ray.get(self.config["_validation_ref"]) + training, validation = ( + ray.get(self.config["_training_ref"]), + ray.get(self.config["_validation_ref"]), + ) # use dataloader on training/validation data self.batch_size = config["data_params"]["batch_size"] @@ -269,7 +289,8 @@ def step(self) -> dict: for _step_size in range(self.step_size): for x, y, _meta in self.training: # the loss dict could be unpacked with ** and the function declaration handle it differently like **kwargs. to be decided, personally find this more clean and understable. - self.model.batch(x=x, y=y, optimizer=self.optimizer, **self.loss_dict) + self.model.batch( + x=x, y=y, optimizer=self.optimizer, **self.loss_dict) return self.objective() def objective(self) -> dict[str, float]: @@ -284,29 +305,48 @@ def objective(self) -> dict[str, float]: "recall", "spearmanr", ] # TODO maybe we report only a subset of metrics, given certain criteria (eg. if classification or regression) - predict_val = PredictWrapper(self.model, self.validation, loss_dict=self.loss_dict) - predict_train = PredictWrapper(self.model, self.training, loss_dict=self.loss_dict) + predict_val = PredictWrapper( + self.model, self.validation, loss_dict=self.loss_dict + ) + predict_train = PredictWrapper( + self.model, self.training, loss_dict=self.loss_dict + ) return { - **{"val_" + metric: value for metric, value in predict_val.compute_metrics(metrics).items()}, - **{"train_" + metric: value for metric, value in predict_train.compute_metrics(metrics).items()}, + **{ + "val_" + metric: value + for metric, value in predict_val.compute_metrics(metrics).items() + }, + **{ + "train_" + metric: value + for metric, value in predict_train.compute_metrics(metrics).items() + }, } - def export_model(self, export_dir: str | None = None) -> None: # type: ignore[override] + # type: ignore[override] + def export_model(self, export_dir: str | None = None) -> None: """Export model to safetensors format.""" if export_dir is None: return - safe_save_model(self.model, os.path.join(export_dir, "model.safetensors")) + safe_save_model(self.model, os.path.join( + export_dir, "model.safetensors")) def load_checkpoint(self, checkpoint: dict[Any, Any] | None) -> None: """Load model and optimizer state from checkpoint.""" if checkpoint is None: return checkpoint_dir = checkpoint["checkpoint_dir"] - self.model = safe_load_model(self.model, os.path.join(checkpoint_dir, "model.safetensors")) - self.optimizer.load_state_dict(torch.load(os.path.join(checkpoint_dir, "optimizer.pt"))) + self.model = safe_load_model( + self.model, os.path.join(checkpoint_dir, "model.safetensors") + ) + self.optimizer.load_state_dict( + torch.load(os.path.join(checkpoint_dir, "optimizer.pt")) + ) def save_checkpoint(self, checkpoint_dir: str) -> 
dict[Any, Any]: """Save model and optimizer state to checkpoint.""" - safe_save_model(self.model, os.path.join(checkpoint_dir, "model.safetensors")) - torch.save(self.optimizer.state_dict(), os.path.join(checkpoint_dir, "optimizer.pt")) + safe_save_model(self.model, os.path.join( + checkpoint_dir, "model.safetensors")) + torch.save( + self.optimizer.state_dict(), os.path.join(checkpoint_dir, "optimizer.pt") + ) return {"checkpoint_dir": checkpoint_dir} diff --git a/src/stimulus/typing/__init__.py b/src/stimulus/typing/__init__.py index cc59e9e2..ed970e73 100644 --- a/src/stimulus/typing/__init__.py +++ b/src/stimulus/typing/__init__.py @@ -23,10 +23,17 @@ from stimulus.data.handlertorch import TorchDataset from stimulus.data.loaders import EncoderLoader, SplitLoader, TransformLoader from stimulus.data.splitters.splitters import AbstractSplitter as Splitter -from stimulus.data.transform.data_transformation_generators import AbstractDataTransformer as Transform +from stimulus.data.transform.data_transformation_generators import ( + AbstractDataTransformer as Transform, +) from stimulus.learner.predict import PredictWrapper from stimulus.learner.raytune_learner import CheckpointDict, TuneModel, TuneWrapper -from stimulus.learner.raytune_parser import RayTuneMetrics, RayTuneOptimizer, RayTuneResult, TuneParser +from stimulus.learner.raytune_parser import ( + RayTuneMetrics, + RayTuneOptimizer, + RayTuneResult, + TuneParser, +) from stimulus.utils.performance import Performance from stimulus.utils.yaml_data import ( YamlColumns, @@ -35,7 +42,8 @@ YamlGlobalParams, YamlSchema, YamlSplit, - YamlSubConfigDict, + YamlSplitConfigDict, + YamlSplitTransformDict, YamlTransform, YamlTransformColumns, YamlTransformColumnsTransformation, @@ -56,7 +64,9 @@ # data/data_handlers.py -DataManager: TypeAlias = DatasetManager | EncodeManager | SplitManager | TransformManager +DataManager: TypeAlias = ( + DatasetManager | EncodeManager | SplitManager | TransformManager +) # data/experiments.py @@ -75,7 +85,7 @@ | YamlGlobalParams | YamlSchema | YamlSplit - | YamlSubConfigDict + | YamlSplitConfigDict | YamlTransform | YamlTransformColumns | YamlTransformColumnsTransformation diff --git a/src/stimulus/utils/yaml_data.py b/src/stimulus/utils/yaml_data.py index cfb550fc..9ed7c516 100644 --- a/src/stimulus/utils/yaml_data.py +++ b/src/stimulus/utils/yaml_data.py @@ -16,7 +16,9 @@ class YamlColumnsEncoder(BaseModel): """Model for column encoder configuration.""" name: str - params: Optional[dict[str, Union[str, list[Any]]]] # Allow both string and list values + params: Optional[ + dict[str, Union[str, list[Any]]] + ] # Allow both string and list values class YamlColumns(BaseModel): @@ -32,7 +34,9 @@ class YamlTransformColumnsTransformation(BaseModel): """Model for column transformation configuration.""" name: str - params: Optional[dict[str, Union[list[Any], float]]] # Allow both list and float values + params: Optional[ + dict[str, Union[list[Any], float]] + ] # Allow both list and float values class YamlTransformColumns(BaseModel): @@ -50,7 +54,9 @@ class YamlTransform(BaseModel): @field_validator("columns") @classmethod - def validate_param_lists_across_columns(cls, columns: list[YamlTransformColumns]) -> list[YamlTransformColumns]: + def validate_param_lists_across_columns( + cls, columns: list[YamlTransformColumns] + ) -> list[YamlTransformColumns]: """Validate that parameter lists across columns have consistent lengths. 
Args: @@ -105,7 +111,17 @@ class YamlConfigDict(BaseModel): split: list[YamlSplit] -class YamlSubConfigDict(BaseModel): +# TODO: Rename this class to SplitConfigDict +class YamlSplitConfigDict(BaseModel): + """Model for sub-configuration generated from main config.""" + + global_params: YamlGlobalParams + columns: list[YamlColumns] + transforms: list[YamlTransform] + split: YamlSplit + + +class YamlSplitTransformDict(BaseModel): """Model for sub-configuration generated from main config.""" global_params: YamlGlobalParams @@ -120,7 +136,15 @@ class YamlSchema(BaseModel): yaml_conf: YamlConfigDict -def extract_transform_parameters_at_index(transform: YamlTransform, index: int = 0) -> YamlTransform: +class YamlSplitSchema(BaseModel): + """Model for validating a Split YAML schema.""" + + yaml_conf: YamlSplitConfigDict + + +def extract_transform_parameters_at_index( + transform: YamlTransform, index: int = 0 +) -> YamlTransform: """Get a transform with parameters at the specified index. Args: @@ -149,7 +173,9 @@ def extract_transform_parameters_at_index(transform: YamlTransform, index: int = return new_transform -def expand_transform_parameter_combinations(transform: YamlTransform) -> list[YamlTransform]: +def expand_transform_parameter_combinations( + transform: YamlTransform, +) -> list[YamlTransform]: """Get all possible transforms by extracting parameters at each valid index. For a transform with parameter lists, creates multiple new transforms, each containing @@ -167,9 +193,15 @@ def expand_transform_parameter_combinations(transform: YamlTransform) -> list[Ya for column in transform.columns: for transformation in column.transformations: if transformation.params: - list_lengths = [len(v) for v in transformation.params.values() if isinstance(v, list) and len(v) > 1] + list_lengths = [ + len(v) + for v in transformation.params.values() + if isinstance(v, list) and len(v) > 1 + ] if list_lengths: - max_length = list_lengths[0] # All lists have same length due to validator + max_length = list_lengths[ + 0 + ] # All lists have same length due to validator break # Generate a transform for each index @@ -180,7 +212,9 @@ def expand_transform_parameter_combinations(transform: YamlTransform) -> list[Ya return transforms -def expand_transform_list_combinations(transform_list: list[YamlTransform]) -> list[YamlTransform]: +def expand_transform_list_combinations( + transform_list: list[YamlTransform], +) -> list[YamlTransform]: """Expands a list of transforms into all possible parameter combinations. Takes a list of transforms where each transform may contain parameter lists, @@ -204,17 +238,23 @@ def expand_transform_list_combinations(transform_list: list[YamlTransform]) -> l return sub_transforms -def generate_data_configs(yaml_config: YamlConfigDict) -> list[YamlSubConfigDict]: - """Generates all possible data configurations from a YAML config. +def generate_split_configs(yaml_config: YamlConfigDict) -> list[YamlSplitConfigDict]: + """Generates all possible split configuration from a YAML config. Takes a YAML configuration that may contain parameter lists and splits, - and generates all possible combinations of parameters and splits into - separate data configurations. + and generates all unique splits into separate data configurations. For example, if the config has: - - A transform with parameters [0.1, 0.2] + - Two transforms with parameters [0.1, 0.2], [0.3, 0.4] - Two splits [0.7/0.3] and [0.8/0.2] - This will generate 4 configs, 2 for each split. 
+ This will generate 2 configs, one for each split. + config_1: + transform: [[0.1, 0.2], [0.3, 0.4]] + split: [0.7, 0.3] + + config_2: + transform: [[0.1, 0.2], [0.3, 0.4]] + split: [0.8, 0.2] Args: yaml_config: The source YAML configuration containing transforms with @@ -222,31 +262,78 @@ Returns: list[YamlSubConfigDict]: A list of data configurations, where each - config has single parameter values and one split configuration. The + config has a list of parameters and one split configuration. The length will be the product of the number of parameter combinations and the number of splits. """ if isinstance(yaml_config, dict) and not isinstance(yaml_config, YamlConfigDict): raise TypeError("Input must be a YamlConfigDict object") - sub_transforms = expand_transform_list_combinations(yaml_config.transforms) sub_splits = yaml_config.split sub_configs = [] for split in sub_splits: - for transform in sub_transforms: - sub_configs.append( - YamlSubConfigDict( - global_params=yaml_config.global_params, - columns=yaml_config.columns, - transforms=transform, - split=split, - ), - ) + sub_configs.append( + YamlSplitConfigDict( + global_params=yaml_config.global_params, + columns=yaml_config.columns, + transforms=yaml_config.transforms, + split=split, + ), + ) return sub_configs +def generate_split_transform_configs( + yaml_config: YamlSplitConfigDict, +) -> list[YamlSplitTransformDict]: + """Generates all the transform configurations for a given split. + + Takes a YAML configuration that may contain a transform or a list of transforms, + and generates all unique transforms for a split into separate data configurations. + + For example, if the config has: + - Two transforms with parameters [0.1, 0.2], [0.3, 0.4] + - A split [0.7, 0.3] + This will generate 2 configs, one for each transform. + transform_config_1: + transform: [0.1, 0.2] + split: [0.7, 0.3] + + transform_config_2: + transform: [0.3, 0.4] + split: [0.7, 0.3] + + Args: + yaml_config: The source YAML configuration containing + a split with its transforms and parameter lists + + Returns: + list[YamlSplitTransformDict]: A list of data configurations, where each + config has single parameter values and one split configuration. The + length will be the number of transform parameter combinations + for the given split.
+ """ + if isinstance(yaml_config, dict) and not isinstance( + yaml_config, YamlSplitConfigDict + ): + raise TypeError("Input must be a list of YamlSubConfigDict") + + sub_transforms = expand_transform_list_combinations(yaml_config.transforms) + split_transform_config: list[YamlSplitTransformDict] = [] + for transform in sub_transforms: + split_transform_config.append( + YamlSplitTransformDict( + global_params=yaml_config.global_params, + columns=yaml_config.columns, + transforms=transform, + split=yaml_config.split, + ) + ) + return split_transform_config + + def dump_yaml_list_into_files( - yaml_list: list[YamlSubConfigDict], + yaml_list: list[YamlSplitConfigDict], directory_path: str, base_name: str, ) -> None: @@ -264,9 +351,13 @@ def custom_representer(dumper: yaml.Dumper, data: Any) -> yaml.Node: if len(data) == 0: return dumper.represent_scalar("tag:yaml.org,2002:null", "") if isinstance(data[0], dict): - return dumper.represent_sequence("tag:yaml.org,2002:seq", data, flow_style=False) + return dumper.represent_sequence( + "tag:yaml.org,2002:seq", data, flow_style=False + ) if isinstance(data[0], list): - return dumper.represent_sequence("tag:yaml.org,2002:seq", data, flow_style=True) + return dumper.represent_sequence( + "tag:yaml.org,2002:seq", data, flow_style=True + ) return dumper.represent_sequence("tag:yaml.org,2002:seq", data, flow_style=True) class CustomDumper(yaml.Dumper): @@ -282,7 +373,9 @@ def write_line_break(self, _data: Any = None) -> None: if len(self.indents) <= 1: # At root level super().write_line_break(_data) - def increase_indent(self, *, flow: bool = False, indentless: bool = False) -> None: # type: ignore[override] + def increase_indent( + self, *, flow: bool = False, indentless: bool = False + ) -> None: # type: ignore[override] """Ensure consistent indentation by preventing indentless sequences.""" return super().increase_indent( flow=flow, @@ -305,21 +398,30 @@ def fix_params(input_dict: dict[str, Any]) -> dict[str, Any]: processed_dict[key] = [] for encoder in value: processed_encoder = dict(encoder) - if "params" not in processed_encoder or not processed_encoder["params"]: + if ( + "params" not in processed_encoder + or not processed_encoder["params"] + ): processed_encoder["params"] = {} processed_dict[key].append(processed_encoder) elif key == "transformations" and isinstance(value, list): processed_dict[key] = [] for transformation in value: processed_transformation = dict(transformation) - if "params" not in processed_transformation or not processed_transformation["params"]: + if ( + "params" not in processed_transformation + or not processed_transformation["params"] + ): processed_transformation["params"] = {} processed_dict[key].append(processed_transformation) elif isinstance(value, dict): processed_dict[key] = fix_params(value) elif isinstance(value, list): processed_dict[key] = [ - fix_params(list_item) if isinstance(list_item, dict) else list_item for list_item in value + fix_params(list_item) + if isinstance(list_item, dict) + else list_item + for list_item in value ] else: processed_dict[key] = value diff --git a/tests/cli/__snapshots__/test_split_splits.ambr b/tests/cli/__snapshots__/test_split_splits.ambr new file mode 100644 index 00000000..7ab227de --- /dev/null +++ b/tests/cli/__snapshots__/test_split_splits.ambr @@ -0,0 +1,6 @@ +# serializer version: 1 +# name: test_split_split[correct_yaml_path-None] + list([ + '42139ca7745259e09d1e56e24570d2c7', + ]) +# --- diff --git a/tests/cli/__snapshots__/test_split_transforms.ambr 
b/tests/cli/__snapshots__/test_split_transforms.ambr new file mode 100644 index 00000000..2743eb3f --- /dev/null +++ b/tests/cli/__snapshots__/test_split_transforms.ambr @@ -0,0 +1,7 @@ +# serializer version: 1 +# name: test_split_transforms[correct_yaml_path-None] + list([ + '0e43b7cdcd8d458cc4e6ff80e06ba7ea', + 'e213d72c1df7eda1e0fdee6ccb4bec7f', + ]) +# --- diff --git a/tests/cli/__snapshots__/test_split_yaml.ambr b/tests/cli/__snapshots__/test_split_yaml.ambr deleted file mode 100644 index e4e7731c..00000000 --- a/tests/cli/__snapshots__/test_split_yaml.ambr +++ /dev/null @@ -1,8 +0,0 @@ -# serializer version: 1 -# name: test_split_yaml[correct_yaml_path-None] - list([ - '0e43b7cdcd8d458cc4e6ff80e06ba7ea', - '43a7f9fbac5c32f51fa51680c7679a57', - 'edf8dd2d39b74619d17b298e3b010c77', - ]) -# --- diff --git a/tests/cli/test_check_model.py b/tests/cli/test_check_model.py index 7115e809..273ab63b 100644 --- a/tests/cli/test_check_model.py +++ b/tests/cli/test_check_model.py @@ -8,18 +8,29 @@ import ray from stimulus.cli import check_model +from src.stimulus.utils.yaml_data import YamlSplitTransformDict @pytest.fixture def data_path() -> str: """Get path to test data CSV file.""" - return str(Path(__file__).parent.parent / "test_data" / "titanic" / "titanic_stimulus_split.csv") + return str( + Path(__file__).parent.parent + / "test_data" + / "titanic" + / "titanic_stimulus_split.csv" + ) @pytest.fixture def data_config() -> str: """Get path to test data config YAML.""" - return str(Path(__file__).parent.parent / "test_data" / "titanic" / "titanic_sub_config.yaml") + return str( + Path(__file__).parent.parent + / "test_data" + / "titanic" + / "titanic_sub_config.yaml" + ) @pytest.fixture @@ -34,7 +45,9 @@ def model_config() -> str: return str(Path(__file__).parent.parent / "test_model" / "titanic_model_cpu.yaml") -def test_check_model_main(data_path: str, data_config: str, model_path: str, model_config: str) -> None: +def test_check_model_main( + data_path: str, data_config: str, model_path: str, model_config: str +) -> None: """Test that check_model.main runs without errors. 
Args: @@ -50,11 +63,14 @@ def test_check_model_main(data_path: str, data_config: str, model_path: str, mod ray.init(ignore_reinit_error=True) # Verify all required files exist assert os.path.exists(data_path), f"Data file not found at {data_path}" - assert os.path.exists(data_config), f"Data config not found at {data_config}" + assert os.path.exists( + data_config), f"Data config not found at {data_config}" assert os.path.exists(model_path), f"Model file not found at {model_path}" - assert os.path.exists(model_config), f"Model config not found at {model_config}" + assert os.path.exists( + model_config), f"Model config not found at {model_config}" try: + config_dict: yam # Run main function - should complete without errors check_model.main( model_path=model_path, diff --git a/tests/cli/test_shuffle_csv.py b/tests/cli/test_shuffle_csv.py index 8fef803c..2c95d7f0 100644 --- a/tests/cli/test_shuffle_csv.py +++ b/tests/cli/test_shuffle_csv.py @@ -1,5 +1,6 @@ """Tests for the shuffle_csv CLI command.""" +import yaml import hashlib import pathlib import tempfile @@ -8,6 +9,7 @@ import pytest from src.stimulus.cli.shuffle_csv import main +from src.stimulus.utils.yaml_data import YamlSplitTransformDict # Fixtures @@ -42,11 +44,18 @@ def test_shuffle_csv( csv_path = request.getfixturevalue(csv_type) yaml_path = request.getfixturevalue(yaml_type) tmpdir = pathlib.Path(tempfile.gettempdir()) - if error: - with pytest.raises(error): # type: ignore[call-overload] - main(csv_path, yaml_path, str(tmpdir / "test.csv")) - else: - main(csv_path, yaml_path, str(tmpdir / "test.csv")) - with open(tmpdir / "test.csv") as file: - hash = hashlib.md5(file.read().encode()).hexdigest() # noqa: S324 - assert hash == snapshot + with open(yaml_path) as f: + if error: + with pytest.raises(error): # type: ignore[call-overload] + config_dict: YamlSplitTransformDict = YamlSplitTransformDict( + **yaml.safe_load(f) + ) + main(csv_path, config_dict, str(tmpdir / "test.csv")) + else: + config_dict: YamlSplitTransformDict = YamlSplitTransformDict( + **yaml.safe_load(f) + ) + main(csv_path, config_dict, str(tmpdir / "test.csv")) + with open(tmpdir / "test.csv") as file: + hash = hashlib.md5(file.read().encode()).hexdigest() # noqa: S324 + assert hash == snapshot diff --git a/tests/cli/test_split_yaml.py b/tests/cli/test_split_splits.py similarity index 74% rename from tests/cli/test_split_yaml.py rename to tests/cli/test_split_splits.py index ad56f2b7..e0cc445c 100644 --- a/tests/cli/test_split_yaml.py +++ b/tests/cli/test_split_splits.py @@ -1,4 +1,4 @@ -"""Tests for the split_yaml CLI command.""" +"""Tests for the split_split CLI command.""" import hashlib import os @@ -7,7 +7,7 @@ import pytest -from src.stimulus.cli.split_yaml import main +from src.stimulus.cli import split_split # Fixtures @@ -32,24 +32,27 @@ def wrong_yaml_path() -> str: # Tests @pytest.mark.parametrize(("yaml_type", "error"), test_cases) -def test_split_yaml( +def test_split_split( request: pytest.FixtureRequest, snapshot: Callable[[], Any], yaml_type: str, error: Exception | None, + tmp_path, # Pytest tmp file system ) -> None: """Tests the CLI command with correct and wrong YAML files.""" yaml_path = request.getfixturevalue(yaml_type) - tmpdir = tempfile.gettempdir() + tmpdir = tmp_path if error: with pytest.raises(error): # type: ignore[call-overload] - main(yaml_path, tmpdir) + split_split.main(yaml_path, tmpdir) else: - main(yaml_path, tmpdir) # main() returns None, no need to assert + split_split.main(yaml_path, tmpdir) # main() returns None, no need 
to assert files = os.listdir(tmpdir) test_out = [f for f in files if f.startswith("test_")] hashes = [] for f in test_out: with open(os.path.join(tmpdir, f)) as file: hashes.append(hashlib.md5(file.read().encode()).hexdigest()) # noqa: S324 - assert sorted(hashes) == snapshot # sorted ensures that the order of the hashes does not matter + assert ( + sorted(hashes) == snapshot + ) # sorted ensures that the order of the hashes does not matter diff --git a/tests/cli/test_split_transforms.py b/tests/cli/test_split_transforms.py new file mode 100644 index 00000000..770ec042 --- /dev/null +++ b/tests/cli/test_split_transforms.py @@ -0,0 +1,54 @@ +"""Tests for the split_transforms CLI command.""" + +import hashlib +import os +from typing import Any, Callable + +import pytest + +from src.stimulus.cli import split_transforms + + +# Fixtures +@pytest.fixture +def correct_yaml_path() -> str: + """Fixture that returns the path to a correct YAML file with one split only.""" + return "tests/test_data/titanic/titanic_unique_split.yaml" + + +@pytest.fixture +def wrong_yaml_path() -> str: + """Fixture that returns the path to a wrong YAML file.""" + return "tests/test_data/yaml_files/wrong_field_type.yaml" + + +# Test cases +test_cases = [("correct_yaml_path", None), ("wrong_yaml_path", ValueError)] + + +# Tests +@pytest.mark.parametrize(("yaml_type", "error"), test_cases) +def test_split_transforms( + request: pytest.FixtureRequest, + snapshot: Callable[[], Any], + yaml_type: str, + error: Exception | None, + tmp_path, # Pytest tmp file system +) -> None: + """Tests the CLI command with correct and wrong YAML files.""" + yaml_path: str = request.getfixturevalue(yaml_type) + tmpdir = tmp_path + if error: + with pytest.raises(error): + split_transforms.main(yaml_path, tmpdir) + else: + split_transforms.main(yaml_path, tmpdir) + files = os.listdir(tmpdir) + test_out = [f for f in files if f.startswith("test_")] + hashes = [] + for f in test_out: + with open(os.path.join(tmpdir, f)) as file: + hashes.append(hashlib.md5(file.read().encode()).hexdigest()) # noqa: S324 + assert ( + sorted(hashes) == snapshot + ) # Sorted ensures that the order of the hashes does not matter diff --git a/tests/cli/test_tuning.py b/tests/cli/test_tuning.py index e05d6213..1716de65 100644 --- a/tests/cli/test_tuning.py +++ b/tests/cli/test_tuning.py @@ -12,14 +12,18 @@ import ray import yaml -from stimulus.cli import tuning +from src.stimulus.cli import tuning +from src.stimulus.utils.yaml_data import YamlSplitTransformDict @pytest.fixture def data_path() -> str: """Get path to test data CSV file.""" return str( - Path(__file__).parent.parent / "test_data" / "titanic" / "titanic_stimulus_split.csv", + Path(__file__).parent.parent + / "test_data" + / "titanic" + / "titanic_stimulus_split.csv", ) @@ -27,7 +31,10 @@ def data_path() -> str: def data_config() -> str: """Get path to test data config YAML.""" return str( - Path(__file__).parent.parent / "test_data" / "titanic" / "titanic_sub_config.yaml", + Path(__file__).parent.parent + / "test_data" + / "titanic" + / "titanic_sub_config.yaml", ) @@ -115,6 +122,10 @@ def test_tuning_main( assert os.path.exists(model_config), f"Model config not found at {model_config}" try: + config_dict: YamlSplitTransformDict + with open(data_config) as f: + config_dict = YamlSplitTransformDict(**yaml.safe_load(f)) + results_dir = Path("tests/test_data/titanic/test_results/").resolve() results_dir.mkdir(parents=True, exist_ok=True) @@ -122,10 +133,11 @@ def test_tuning_main( tuning.main( model_path=model_path,
data_path=data_path, - data_config_path=data_config, + data_config=config_dict, model_config_path=model_config, initial_weights=None, - ray_results_dirpath=str(results_dir), # Directory path without URI scheme + # Directory path without URI scheme + ray_results_dirpath=str(results_dir), output_path=str(results_dir / "best_model.safetensors"), best_optimizer_path=str(results_dir / "best_optimizer.pt"), best_metrics_path=str(results_dir / "best_metrics.csv"), diff --git a/tests/data/test_data_handlers.py b/tests/data/test_data_handlers.py index 725f275b..fd765192 100644 --- a/tests/data/test_data_handlers.py +++ b/tests/data/test_data_handlers.py @@ -14,15 +14,18 @@ ) from stimulus.utils.yaml_data import ( YamlConfigDict, + YamlSplitConfigDict, + YamlSplitTransformDict, YamlTransform, YamlTransformColumns, YamlTransformColumnsTransformation, - generate_data_configs, + generate_split_configs, + generate_split_transform_configs, ) # Fixtures -## Data fixtures +# Data fixtures @pytest.fixture def titanic_csv_path() -> str: """Get path to test Titanic CSV file. @@ -67,20 +70,29 @@ def generate_sub_configs(base_config: YamlConfigDict) -> list[YamlConfigDict]: Returns: list[YamlConfigDict]: List of generated configurations """ - return generate_data_configs(base_config) + split_configs: list[YamlSplitConfigDict] = generate_split_configs( + base_config) + split_transform_list: list[YamlSplitTransformDict] = [] + for split in split_configs: + split_transform_list.extend(generate_split_transform_configs(split)) + return split_transform_list @pytest.fixture -def dump_single_split_config_to_disk() -> str: +def dump_single_split_config_to_disk() -> YamlSplitTransformDict: """Get path for dumping single split config. Returns: str: Path to dump config file """ - return "tests/test_data/titanic/titanic_sub_config.yaml" + config_dict: YamlSplitTransformDict + path: str = "tests/test_data/titanic/titanic_sub_config.yaml" + with open(path) as f: + config_dict = YamlSplitTransformDict(**yaml.safe_load(f)) + return config_dict -## Loader fixtures +# Loader fixtures @pytest.fixture def encoder_loader(generate_sub_configs: list[YamlConfigDict]) -> loaders.EncoderLoader: """Create encoder loader with initialized encoders. @@ -92,12 +104,15 @@ def encoder_loader(generate_sub_configs: list[YamlConfigDict]) -> loaders.Encode experiments.EncoderLoader: Initialized encoder loader """ loader = loaders.EncoderLoader() - loader.initialize_column_encoders_from_config(generate_sub_configs[0].columns) + loader.initialize_column_encoders_from_config( + generate_sub_configs[0].columns) return loader @pytest.fixture -def transform_loader(generate_sub_configs: list[YamlConfigDict]) -> loaders.TransformLoader: +def transform_loader( + generate_sub_configs: list[YamlConfigDict], +) -> loaders.TransformLoader: """Create transform loader with initialized transformers. 
Args: @@ -107,7 +122,9 @@ def transform_loader(generate_sub_configs: list[YamlConfigDict]) -> loaders.Tran experiments.TransformLoader: Initialized transform loader """ loader = loaders.TransformLoader() - loader.initialize_column_data_transformers_from_config(generate_sub_configs[0].transforms) + loader.initialize_column_data_transformers_from_config( + generate_sub_configs[0].transforms + ) return loader @@ -127,14 +144,18 @@ def split_loader(generate_sub_configs: list[YamlConfigDict]) -> loaders.SplitLoa # Test DatasetManager -def test_dataset_manager_init(dump_single_split_config_to_disk: str) -> None: +def test_dataset_manager_init( + dump_single_split_config_to_disk: YamlSplitTransformDict, +) -> None: """Test initialization of DatasetManager.""" manager = DatasetManager(dump_single_split_config_to_disk) assert hasattr(manager, "config") assert hasattr(manager, "column_categories") -def test_dataset_manager_organize_columns(dump_single_split_config_to_disk: str) -> None: +def test_dataset_manager_organize_columns( + dump_single_split_config_to_disk: str, +) -> None: """Test column organization by type.""" manager = DatasetManager(dump_single_split_config_to_disk) categories = manager.categorize_columns_by_type() @@ -146,7 +167,9 @@ def test_dataset_manager_organize_columns(dump_single_split_config_to_disk: str) assert "passenger_id" in categories["meta"] -def test_dataset_manager_organize_transforms(dump_single_split_config_to_disk: str) -> None: +def test_dataset_manager_organize_transforms( + dump_single_split_config_to_disk: str, +) -> None: """Test transform organization.""" manager = DatasetManager(dump_single_split_config_to_disk) categories = manager.categorize_columns_by_type() @@ -155,7 +178,9 @@ def test_dataset_manager_organize_transforms(dump_single_split_config_to_disk: s assert all(key in categories for key in ["input", "label", "meta"]) -def test_dataset_manager_get_transform_logic(dump_single_split_config_to_disk: str) -> None: +def test_dataset_manager_get_transform_logic( + dump_single_split_config_to_disk: str, +) -> None: """Test getting transform logic from config.""" manager = DatasetManager(dump_single_split_config_to_disk) transform_logic = manager.get_transform_logic() @@ -221,10 +246,12 @@ def test_transform_manager_transform_column() -> None: ), ], ) - transform_loader.initialize_column_data_transformers_from_config(dummy_config) + transform_loader.initialize_column_data_transformers_from_config( + dummy_config) manager = TransformManager(transform_loader) data = [1, 2, 3] - transformed, added_row = manager.transform_column("test_col", "GaussianNoise", data) + transformed, added_row = manager.transform_column( + "test_col", "GaussianNoise", data) assert len(transformed) == len(data) assert added_row is False @@ -303,15 +330,32 @@ def test_dataset_processor_apply_transformation_group( ) processor_control.data = processor_control.load_csv(titanic_csv_path) - processor.apply_transformation_group(transform_manager=TransformManager(transform_loader)) + processor.apply_transformation_group( + transform_manager=TransformManager(transform_loader) + ) - assert processor.data["age"].to_list() != processor_control.data["age"].to_list() - assert processor.data["fare"].to_list() != processor_control.data["fare"].to_list() - assert processor.data["parch"].to_list() == processor_control.data["parch"].to_list() - assert processor.data["sibsp"].to_list() == processor_control.data["sibsp"].to_list() - assert processor.data["pclass"].to_list() == 
processor_control.data["pclass"].to_list() - assert processor.data["embarked"].to_list() == processor_control.data["embarked"].to_list() - assert processor.data["sex"].to_list() == processor_control.data["sex"].to_list() + assert processor.data["age"].to_list( + ) != processor_control.data["age"].to_list() + assert processor.data["fare"].to_list( + ) != processor_control.data["fare"].to_list() + assert ( + processor.data["parch"].to_list( + ) == processor_control.data["parch"].to_list() + ) + assert ( + processor.data["sibsp"].to_list( + ) == processor_control.data["sibsp"].to_list() + ) + assert ( + processor.data["pclass"].to_list( + ) == processor_control.data["pclass"].to_list() + ) + assert ( + processor.data["embarked"].to_list() + == processor_control.data["embarked"].to_list() + ) + assert processor.data["sex"].to_list( + ) == processor_control.data["sex"].to_list() # Test DatasetLoader @@ -322,7 +366,7 @@ def test_dataset_loader_init( ) -> None: """Test initialization of DatasetLoader.""" loader = DatasetLoader( - config_path=dump_single_split_config_to_disk, + data_config=dump_single_split_config_to_disk, csv_path=titanic_csv_path, encoder_loader=encoder_loader, ) @@ -340,7 +384,7 @@ def test_dataset_loader_get_dataset( ) -> None: """Test getting dataset from loader.""" loader = DatasetLoader( - config_path=dump_single_split_config_to_disk, + data_config=dump_single_split_config_to_disk, csv_path=titanic_csv_path, encoder_loader=encoder_loader, ) diff --git a/tests/data/test_experiment.py b/tests/data/test_experiment.py index 15954217..aa261dc9 100644 --- a/tests/data/test_experiment.py +++ b/tests/data/test_experiment.py @@ -21,7 +21,9 @@ def dna_experiment_config_path() -> str: @pytest.fixture -def dna_experiment_sub_yaml(dna_experiment_config_path: str) -> yaml_data.YamlConfigDict: +def dna_experiment_sub_yaml( + dna_experiment_config_path: str, +) -> yaml_data.YamlConfigDict: """Get a sub-configuration from the DNA experiment config. Args: @@ -34,8 +36,11 @@ def dna_experiment_sub_yaml(dna_experiment_config_path: str) -> yaml_data.YamlCo yaml_dict = yaml.safe_load(f) yaml_config = yaml_data.YamlConfigDict(**yaml_dict) - yaml_configs = yaml_data.generate_data_configs(yaml_config) - return yaml_configs[0] + yaml_split_configs = yaml_data.generate_split_configs(yaml_config) + yaml_split_transform_configs = yaml_data.generate_split_transform_configs( + yaml_split_configs[0] + ) + return yaml_split_transform_configs[0] @pytest.fixture @@ -80,7 +85,9 @@ def test_get_encoder(text_onehot_encoder_params: tuple[str, dict[str, str]]) -> assert isinstance(encoder, AbstractEncoder) -def test_set_encoder_as_attribute(text_onehot_encoder_params: tuple[str, dict[str, str]]) -> None: +def test_set_encoder_as_attribute( + text_onehot_encoder_params: tuple[str, dict[str, str]], +) -> None: """Test the set_encoder_as_attribute method of the AbstractExperiment class. Args: @@ -95,7 +102,9 @@ def test_set_encoder_as_attribute(text_onehot_encoder_params: tuple[str, dict[st assert experiment.get_function_encode_all("ciao") == encoder.encode_all -def test_build_experiment_class_encoder_dict(dna_experiment_sub_yaml: yaml_data.YamlConfigDict) -> None: +def test_build_experiment_class_encoder_dict( + dna_experiment_sub_yaml: yaml_data.YamlConfigDict, +) -> None: """Test the build_experiment_class_encoder_dict method. 
Args: @@ -115,7 +124,8 @@ def test_get_data_transformer() -> None: """Test the get_data_transformer method of the TransformLoader class.""" experiment = loaders.TransformLoader() transformer = experiment.get_data_transformer("ReverseComplement") - assert isinstance(transformer, data_transformation_generators.ReverseComplement) + assert isinstance( + transformer, data_transformation_generators.ReverseComplement) def test_set_data_transformer_as_attribute() -> None: @@ -141,7 +151,10 @@ def test_initialize_column_data_transformers_from_config( assert hasattr(experiment, "col1") column_transformers = experiment.col1 - assert any(isinstance(t, data_transformation_generators.ReverseComplement) for t in column_transformers.values()) + assert any( + isinstance(t, data_transformation_generators.ReverseComplement) + for t in column_transformers.values() + ) def test_initialize_splitter_from_config( diff --git a/tests/data/test_handlertorch.py b/tests/data/test_handlertorch.py index d121ed39..dff24909 100644 --- a/tests/data/test_handlertorch.py +++ b/tests/data/test_handlertorch.py @@ -40,11 +40,13 @@ def titanic_yaml_config(titanic_config_path: str) -> dict: dict: Loaded YAML configuration """ with open(titanic_config_path) as file: - return yaml_data.YamlSubConfigDict(**yaml.safe_load(file)) + return yaml_data.YamlSplitTransformDict(**yaml.safe_load(file)) @pytest.fixture -def titanic_encoder_loader(titanic_yaml_config: yaml_data.YamlSubConfigDict) -> loaders.EncoderLoader: +def titanic_encoder_loader( + titanic_yaml_config: yaml_data.YamlSplitTransformDict, +) -> loaders.EncoderLoader: """Get Titanic encoder loader.""" loader = loaders.EncoderLoader() loader.initialize_column_encoders_from_config(titanic_yaml_config.columns) @@ -57,8 +59,11 @@ def test_init_handlertorch( titanic_encoder_loader: loaders.EncoderLoader, ) -> None: """Test TorchDataset initialization.""" + data_config: yaml_data.YamlSplitTransformDict + with open(titanic_config_path) as f: + data_config = yaml_data.YamlSplitTransformDict(**yaml.safe_load(f)) handlertorch.TorchDataset( - config_path=titanic_config_path, + data_config=data_config, csv_path=titanic_csv_path, encoder_loader=titanic_encoder_loader, ) @@ -76,8 +81,11 @@ def test_len_handlertorch( titanic_csv_path: Path to CSV file titanic_encoder_loader: Encoder loader instance """ + data_config: yaml_data.YamlSplitTransformDict + with open(titanic_config_path) as f: + data_config = yaml_data.YamlSplitTransformDict(**yaml.safe_load(f)) dataset = handlertorch.TorchDataset( - config_path=titanic_config_path, + data_config=data_config, csv_path=titanic_csv_path, encoder_loader=titanic_encoder_loader, ) @@ -96,8 +104,11 @@ def test_getitem_handlertorch_slice( titanic_csv_path: Path to CSV file titanic_encoder_loader: Encoder loader instance """ + data_config: yaml_data.YamlSplitTransformDict + with open(titanic_config_path) as f: + data_config = yaml_data.YamlSplitTransformDict(**yaml.safe_load(f)) dataset = handlertorch.TorchDataset( - config_path=titanic_config_path, + data_config=data_config, csv_path=titanic_csv_path, encoder_loader=titanic_encoder_loader, ) @@ -117,8 +128,11 @@ def test_getitem_handlertorch_int( titanic_csv_path: Path to CSV file titanic_encoder_loader: Encoder loader instance """ + data_config: yaml_data.YamlSplitTransformDict + with open(titanic_config_path) as f: + data_config = yaml_data.YamlSplitTransformDict(**yaml.safe_load(f)) dataset = handlertorch.TorchDataset( - config_path=titanic_config_path, + data_config=data_config, 
csv_path=titanic_csv_path, encoder_loader=titanic_encoder_loader, ) diff --git a/tests/data/transform/test_data_transformers.py b/tests/data/transform/test_data_transformers.py index 91215a89..116b37ea 100644 --- a/tests/data/transform/test_data_transformers.py +++ b/tests/data/transform/test_data_transformers.py @@ -13,7 +13,10 @@ ReverseComplement, UniformTextMasker, ) -from stimulus.utils.yaml_data import dump_yaml_list_into_files, generate_data_configs +from stimulus.utils.yaml_data import ( + dump_yaml_list_into_files, + generate_split_transform_configs, +) class DataTransformerTest: @@ -74,7 +77,11 @@ def gaussian_noise() -> DataTransformerTest: single_input = 5.0 expected_single_output = 5.4967141530112327 multiple_inputs = [1.0, 2.0, 3.0] - expected_multiple_outputs = [1.4967141530112327, 1.8617356988288154, 3.6476885381006925] + expected_multiple_outputs = [ + 1.4967141530112327, + 1.8617356988288154, + 3.6476885381006925, + ] return DataTransformerTest( transformer=transformer, params=params, @@ -134,7 +141,9 @@ def test_transform_single(self, request: Any, test_data_name: str) -> None: test_data: The test data to use. """ test_data = request.getfixturevalue(test_data_name) - transformed_data = test_data.transformer.transform(test_data.single_input, **test_data.params) + transformed_data = test_data.transformer.transform( + test_data.single_input, **test_data.params + ) assert isinstance(transformed_data, str) assert transformed_data == test_data.expected_single_output @@ -142,7 +151,10 @@ def test_transform_single(self, request: Any, test_data_name: str) -> None: def test_transform_multiple(self, request: Any, test_data_name: str) -> None: """Test masking multiple strings.""" test_data = request.getfixturevalue(test_data_name) - transformed_data = [test_data.transformer.transform(x, **test_data.params) for x in test_data.multiple_inputs] + transformed_data = [ + test_data.transformer.transform(x, **test_data.params) + for x in test_data.multiple_inputs + ] assert isinstance(transformed_data, list) for item in transformed_data: assert isinstance(item, str) @@ -156,20 +168,30 @@ class TestGaussianNoise: def test_transform_single(self, request: Any, test_data_name: str) -> None: """Test transforming a single float.""" test_data = request.getfixturevalue(test_data_name) - transformed_data = test_data.transformer.transform(test_data.single_input, **test_data.params) + transformed_data = test_data.transformer.transform( + test_data.single_input, **test_data.params + ) assert isinstance(transformed_data, float) - assert round(transformed_data, 7) == round(test_data.expected_single_output, 7) + assert round(transformed_data, 7) == round( + test_data.expected_single_output, 7) @pytest.mark.parametrize("test_data_name", ["gaussian_noise"]) - def test_transform_multiple(self, request: Any, test_data_name: DataTransformerTest) -> None: + def test_transform_multiple( + self, request: Any, test_data_name: DataTransformerTest + ) -> None: """Test transforming multiple floats.""" test_data = request.getfixturevalue(test_data_name) - transformed_data = test_data.transformer.transform_all(test_data.multiple_inputs, **test_data.params) + transformed_data = test_data.transformer.transform_all( + test_data.multiple_inputs, **test_data.params + ) assert isinstance(transformed_data, list) for item in transformed_data: assert isinstance(item, float) - assert len(transformed_data) == len(test_data.expected_multiple_outputs) - for item, expected in zip(transformed_data, 
test_data.expected_multiple_outputs): + assert len(transformed_data) == len( + test_data.expected_multiple_outputs) + for item, expected in zip( + transformed_data, test_data.expected_multiple_outputs + ): assert round(item, 7) == round(expected, 7) @@ -180,7 +202,8 @@ class TestGaussianChunk: def test_transform_single(self, request: Any, test_data_name: str) -> None: """Test transforming a single string.""" test_data = request.getfixturevalue(test_data_name) - transformed_data = test_data.transformer.transform(test_data.single_input) + transformed_data = test_data.transformer.transform( + test_data.single_input) assert isinstance(transformed_data, str) assert len(transformed_data) == 2 @@ -188,7 +211,9 @@ def test_transform_single(self, request: Any, test_data_name: str) -> None: def test_transform_multiple(self, request: Any, test_data_name: str) -> None: """Test transforming multiple strings.""" test_data = request.getfixturevalue(test_data_name) - transformed_data = [test_data.transformer.transform(x) for x in test_data.multiple_inputs] + transformed_data = [ + test_data.transformer.transform(x) for x in test_data.multiple_inputs + ] assert isinstance(transformed_data, list) for item in transformed_data: assert isinstance(item, str) @@ -200,7 +225,9 @@ def test_chunk_size_excessive(self, request: Any, test_data_name: str) -> None: """Test that the transform fails if chunk size is greater than the length of the input string.""" test_data = request.getfixturevalue(test_data_name) transformer = GaussianChunk(chunk_size=100) - with pytest.raises(ValueError, match="The input data is shorter than the chunk size"): + with pytest.raises( + ValueError, match="The input data is shorter than the chunk size" + ): transformer.transform(test_data.single_input) @@ -211,7 +238,9 @@ class TestReverseComplement: def test_transform_single(self, request: Any, test_data_name: str) -> None: """Test transforming a single string.""" test_data = request.getfixturevalue(test_data_name) - transformed_data = test_data.transformer.transform(test_data.single_input, **test_data.params) + transformed_data = test_data.transformer.transform( + test_data.single_input, **test_data.params + ) assert isinstance(transformed_data, str) assert transformed_data == test_data.expected_single_output @@ -219,7 +248,9 @@ def test_transform_single(self, request: Any, test_data_name: str) -> None: def test_transform_multiple(self, request: Any, test_data_name: str) -> None: """Test transforming multiple strings.""" test_data = request.getfixturevalue(test_data_name) - transformed_data = test_data.transformer.transform_all(test_data.multiple_inputs, **test_data.params) + transformed_data = test_data.transformer.transform_all( + test_data.multiple_inputs, **test_data.params + ) assert isinstance(transformed_data, list) for item in transformed_data: assert isinstance(item, str) @@ -232,7 +263,9 @@ def titanic_config_path(base_config: dict) -> str: config_path = "tests/test_data/titanic/titanic_sub_config_0.yaml" if not os.path.exists(config_path): - configs = generate_data_configs(base_config) - dump_yaml_list_into_files([configs[0]], "tests/test_data/titanic/", "titanic_sub_config") + configs = generate_split_transform_configs(base_config) + dump_yaml_list_into_files( + [configs[0]], "tests/test_data/titanic/", "titanic_sub_config" + ) return os.path.abspath(config_path) diff --git a/tests/learner/test_raytune_learner.py b/tests/learner/test_raytune_learner.py index 54ef6e27..8ade0911 100644 --- a/tests/learner/test_raytune_learner.py 
+++ b/tests/learner/test_raytune_learner.py @@ -10,7 +10,7 @@ from stimulus.data.handlertorch import TorchDataset from stimulus.data.loaders import EncoderLoader from stimulus.learner.raytune_learner import TuneWrapper -from stimulus.utils.yaml_data import YamlSubConfigDict +from stimulus.utils.yaml_data import YamlSplitConfigDict, YamlSplitTransformDict from stimulus.utils.yaml_model_schema import Model, RayTuneModel, YamlRayConfigLoader from tests.test_model import titanic_model @@ -29,7 +29,9 @@ def encoder_loader() -> EncoderLoader: with open("tests/test_data/titanic/titanic_sub_config.yaml") as file: data_config = yaml.safe_load(file) encoder_loader = EncoderLoader() - encoder_loader.initialize_column_encoders_from_config(YamlSubConfigDict(**data_config).columns) + encoder_loader.initialize_column_encoders_from_config( + YamlSplitTransformDict(**data_config).columns + ) return encoder_loader @@ -44,7 +46,9 @@ def titanic_dataset(encoder_loader: EncoderLoader) -> TorchDataset: ) -def test_tunewrapper_init(ray_config_loader: RayTuneModel, encoder_loader: EncoderLoader) -> None: +def test_tunewrapper_init( + ray_config_loader: RayTuneModel, encoder_loader: EncoderLoader +) -> None: """Test the initialization of the TuneWrapper class.""" # Filter ResourceWarning during Ray shutdown warnings.filterwarnings("ignore", category=ResourceWarning) @@ -53,14 +57,19 @@ def test_tunewrapper_init(ray_config_loader: RayTuneModel, encoder_loader: Encod ray.init(ignore_reinit_error=True) try: + data_config: YamlSplitTransformDict + with open("tests/test_data/titanic/titanic_sub_config.yaml") as f: + data_config = YamlSplitTransformDict(**yaml.safe_load(f)) + tune_wrapper = TuneWrapper( model_config=ray_config_loader, model_class=titanic_model.ModelTitanic, data_path="tests/test_data/titanic/titanic_stimulus_split.csv", - data_config_path="tests/test_data/titanic/titanic_sub_config.yaml", + data_config=data_config, encoder_loader=encoder_loader, seed=42, - ray_results_dir=os.path.abspath("tests/test_data/titanic/ray_results"), + ray_results_dir=os.path.abspath( + "tests/test_data/titanic/ray_results"), tune_run_name="test_run", debug=False, autoscaler=False, @@ -74,10 +83,13 @@ def test_tunewrapper_init(ray_config_loader: RayTuneModel, encoder_loader: Encod if os.path.exists("tests/test_data/titanic/ray_results"): import shutil - shutil.rmtree("tests/test_data/titanic/ray_results", ignore_errors=True) + shutil.rmtree("tests/test_data/titanic/ray_results", + ignore_errors=True) -def test_tune_wrapper_tune(ray_config_loader: RayTuneModel, encoder_loader: EncoderLoader) -> None: +def test_tune_wrapper_tune( + ray_config_loader: RayTuneModel, encoder_loader: EncoderLoader +) -> None: """Test the tune method of TuneWrapper class.""" # Filter ResourceWarning during Ray shutdown warnings.filterwarnings("ignore", category=ResourceWarning) @@ -86,14 +98,19 @@ def test_tune_wrapper_tune(ray_config_loader: RayTuneModel, encoder_loader: Enco ray.init(ignore_reinit_error=True) try: + data_config: YamlSplitTransformDict + with open("tests/test_data/titanic/titanic_sub_config.yaml") as f: + data_config = YamlSplitTransformDict(**yaml.safe_load(f)) + tune_wrapper = TuneWrapper( model_config=ray_config_loader, model_class=titanic_model.ModelTitanic, data_path="tests/test_data/titanic/titanic_stimulus_split.csv", - data_config_path="tests/test_data/titanic/titanic_sub_config.yaml", + data_config=data_config, encoder_loader=encoder_loader, seed=42, - ray_results_dir=os.path.abspath("tests/test_data/titanic/ray_results"), 
+ ray_results_dir=os.path.abspath( + "tests/test_data/titanic/ray_results"), tune_run_name="test_run", debug=False, autoscaler=False, @@ -108,4 +125,5 @@ def test_tune_wrapper_tune(ray_config_loader: RayTuneModel, encoder_loader: Enco if os.path.exists("tests/test_data/titanic/ray_results"): import shutil - shutil.rmtree("tests/test_data/titanic/ray_results", ignore_errors=True) + shutil.rmtree("tests/test_data/titanic/ray_results", + ignore_errors=True) diff --git a/tests/test_data/titanic/titanic_unique_split.yaml b/tests/test_data/titanic/titanic_unique_split.yaml new file mode 100644 index 00000000..3ef7d17d --- /dev/null +++ b/tests/test_data/titanic/titanic_unique_split.yaml @@ -0,0 +1,91 @@ + +global_params: + seed: 42 + +columns: + - column_name: passenger_id + column_type: meta + data_type: int + encoder: + - name: NumericEncoder + params: {} + - column_name: survived + column_type: label + data_type: int + encoder: + - name: NumericEncoder + params: {} + - column_name: pclass + column_type: input + data_type: int + encoder: + - name: NumericEncoder + params: {} + - column_name: sex + column_type: input + data_type: str + encoder: + - name: StrClassificationEncoder + params: {} + - column_name: age + column_type: input + data_type: float + encoder: + - name: NumericEncoder + params: {} + - column_name: sibsp + column_type: input + data_type: int + encoder: + - name: NumericEncoder + params: {} + - column_name: parch + column_type: input + data_type: int + encoder: + - name: NumericEncoder + params: {} + - column_name: fare + column_type: input + data_type: float + encoder: + - name: NumericEncoder + params: {} + - column_name: embarked + column_type: input + data_type: str + encoder: + - name: StrClassificationEncoder + params: {} + +transforms: + - transformation_name: noise + columns: + - column_name: age + transformations: + - name: GaussianNoise + params: + std: 0.1 + - column_name: fare + transformations: + - name: GaussianNoise + params: + std: 0.1 + - transformation_name: noise2 + columns: + - column_name: age + transformations: + - name: GaussianNoise + params: + std: 0.1 + - column_name: fare + transformations: + - name: GaussianNoise + params: + std: 0.1 + +split: + split_method: RandomSplit + params: + split: [0.7, 0.15, 0.15] + split_input_columns: [age] diff --git a/tests/typing/test_typing.py b/tests/typing/test_typing.py index 22bc8577..d2be37ff 100644 --- a/tests/typing/test_typing.py +++ b/tests/typing/test_typing.py @@ -49,7 +49,8 @@ def test_yaml_data_types() -> None: YamlGlobalParams, YamlSchema, YamlSplit, - YamlSubConfigDict, + YamlSplitConfigDict, + YamlSplitTransformDict, YamlTransform, YamlTransformColumns, YamlTransformColumnsTransformation, diff --git a/tests/utils/test_data_yaml.py b/tests/utils/test_data_yaml.py index e7398d44..a96b5458 100644 --- a/tests/utils/test_data_yaml.py +++ b/tests/utils/test_data_yaml.py @@ -6,8 +6,10 @@ from src.stimulus.utils import yaml_data from src.stimulus.utils.yaml_data import ( YamlConfigDict, - YamlSubConfigDict, - generate_data_configs, + YamlSplitConfigDict, + YamlSplitTransformDict, + generate_split_configs, + generate_split_transform_configs, ) @@ -25,10 +27,20 @@ def load_titanic_yaml_from_file() -> YamlConfigDict: return YamlConfigDict(**yaml_dict) +@pytest.fixture +def load_split_config_yaml_from_file() -> YamlSplitConfigDict: + """Fixture that loads a test unique split YAML configuration file.""" + with open("tests/test_data/titanic/titanic_unique_split.yaml") as f: + yaml_dict = yaml.safe_load(f) + return 
YamlSplitConfigDict(**yaml_dict) + + @pytest.fixture def load_yaml_from_file() -> YamlConfigDict: """Fixture that loads a test YAML configuration file.""" - with open("tests/test_data/dna_experiment/dna_experiment_config_template.yaml") as f: + with open( + "tests/test_data/dna_experiment/dna_experiment_config_template.yaml" + ) as f: yaml_dict = yaml.safe_load(f) return YamlConfigDict(**yaml_dict) @@ -40,12 +52,20 @@ def load_wrong_type_yaml() -> dict: return yaml.safe_load(f) +def test_split_config_validation(load_titanic_yaml_from_file: YamlConfigDict) -> None: + """Test split configuration validation.""" + split_config = generate_split_configs(load_titanic_yaml_from_file)[0] + YamlSplitConfigDict.model_validate(split_config) + + def test_sub_config_validation( - load_titanic_yaml_from_file: YamlConfigDict, + load_split_config_yaml_from_file: YamlSplitConfigDict, ) -> None: """Test sub-config validation.""" - sub_config = generate_data_configs(load_titanic_yaml_from_file)[0] - YamlSubConfigDict.model_validate(sub_config) + split_config = generate_split_transform_configs( + load_split_config_yaml_from_file)[0] + print(f"{split_config=}") + YamlSplitTransformDict.model_validate(split_config) def test_expand_transform_parameter_combinations( @@ -56,15 +76,20 @@ transform = load_yaml_from_file.transforms[0] results = yaml_data.expand_transform_parameter_combinations(transform) assert len(results) == 1 # Only one transform returned - assert isinstance(results[0], yaml_data.YamlTransform) # Should return YamlTransform objects + assert isinstance( + results[0], yaml_data.YamlTransform + ) # Should return YamlTransform objects def test_expand_transform_list_combinations( load_yaml_from_file: YamlConfigDict, ) -> None: """Tests expanding a list of transforms into all parameter combinations.""" - results = yaml_data.expand_transform_list_combinations(load_yaml_from_file.transforms) - assert len(results) == 8 # 4 combinations from first transform x 2 from second + results = yaml_data.expand_transform_list_combinations( + load_yaml_from_file.transforms + ) + # 4 combinations from first transform x 2 from second + assert len(results) == 8 # Each result should be a YamlTransform for result in results: assert isinstance(result, yaml_data.YamlTransform) @@ -76,14 +101,18 @@ load_yaml_from_file: YamlConfigDict, ) -> None: """Tests generating all possible data configurations.""" - configs = yaml_data.generate_data_configs(load_yaml_from_file) + split_configs = yaml_data.generate_split_configs(load_yaml_from_file) + configs: list[YamlSplitTransformDict] = [] + for s_conf in split_configs: + configs.extend(generate_split_transform_configs(s_conf)) + assert len(configs) == 16 # 8 transform combinations x 2 splits # Check each config individually to help debug for i, config in enumerate(configs): assert isinstance( config, - yaml_data.YamlSubConfigDict, + yaml_data.YamlSplitTransformDict, ), f"Config {i} is type {type(config)}, expected YamlSubConfigDict" @@ -98,7 +127,9 @@ """Tests the Pydantic schema validation.""" data = request.getfixturevalue(test_input[0]) if test_input[1]: - with pytest.raises(ValueError, match="Wrong type on a field, see the pydantic report above"): + with pytest.raises( + ValueError, match="Wrong type on a field, see the pydantic report above" + ): yaml_data.check_yaml_schema(data) else: yaml_data.check_yaml_schema(data)
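The recurring pattern in the test updates above: call sites that previously took a YAML path (config_path= / data_config_path=) now parse the file into a YamlSplitTransformDict up front and pass the parsed object as data_config. A minimal end-to-end sketch of that flow, assembled only from calls that appear in these diffs and using the Titanic fixture paths the tests reference; the standalone script itself is illustrative and not part of the change set:

import yaml

from stimulus.data.handlertorch import TorchDataset
from stimulus.data.loaders import EncoderLoader
from stimulus.utils.yaml_data import YamlSplitTransformDict

# Parse the split/transform config once, instead of handing a path around.
with open("tests/test_data/titanic/titanic_sub_config.yaml") as f:
    data_config = YamlSplitTransformDict(**yaml.safe_load(f))

# Encoders are still initialized from the parsed config's column definitions.
encoder_loader = EncoderLoader()
encoder_loader.initialize_column_encoders_from_config(data_config.columns)

# TorchDataset now receives the parsed object via data_config=, not a config path.
dataset = TorchDataset(
    data_config=data_config,
    csv_path="tests/test_data/titanic/titanic_stimulus_split.csv",
    encoder_loader=encoder_loader,
)
print(len(dataset))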