Skip to content

Commit

Permalink
(change)config object implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
ved1beta committed Feb 22, 2025
1 parent 48b7ea9 commit 7d7ead8
Showing 1 changed file with 32 additions and 35 deletions.
67 changes: 32 additions & 35 deletions src/stimulus/data/data_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,24 +36,27 @@
class DatasetManager:
"""Class for managing the dataset.
This class handles loading and organizing dataset configuration from YAML files.
This class handles organizing dataset configuration from YAML data.
It manages column categorization into input, label and meta types based on the config.
Attributes:
config (dict): The loaded configuration dictionary from YAML
config (yaml_data.YamlConfigDict): The configuration object
column_categories (dict): Dictionary mapping column types to lists of column names
Methods:
_load_config(config_path: str) -> dict: Loads the config from a YAML file.
categorize_columns_by_type() -> dict: Organizes the columns into input, label, meta based on the config.
"""

def __init__(
self,
config_path: str,
config: yaml_data.YamlConfigDict,
) -> None:
"""Initialize the DatasetManager."""
self.config = self._load_config(config_path)
"""Initialize the DatasetManager.
Args:
config: The configuration dictionary
"""
self.config = config
self.column_categories = self.categorize_columns_by_type()

def categorize_columns_by_type(self) -> dict:
Expand All @@ -71,7 +74,7 @@ def categorize_columns_by_type(self) -> dict:
}
Example:
>>> manager = DatasetManager("config.yaml")
>>> manager = DatasetManager(config)
>>> categories = manager.categorize_columns_by_type()
>>> print(categories)
{
Expand All @@ -93,24 +96,6 @@ def categorize_columns_by_type(self) -> dict:

return {"input": input_columns, "label": label_columns, "meta": meta_columns}

def _load_config(self, config_path: str) -> yaml_data.YamlConfigDict:
"""Loads and parses a YAML configuration file.
Args:
config_path (str): Path to the YAML config file
Returns:
dict: Parsed configuration dictionary
Example:
>>> manager = DatasetManager()
>>> config = manager._load_config("config.yaml")
>>> print(config["columns"][0]["column_name"])
'hello'
"""
with open(config_path) as file:
return yaml_data.YamlSubConfigDict(**yaml.safe_load(file))

def get_split_columns(self) -> list[str]:
"""Get the columns that are used for splitting."""
return self.config.split.split_input_columns
Expand Down Expand Up @@ -270,16 +255,16 @@ class DatasetHandler:

def __init__(
self,
config_path: str,
config: yaml_data.YamlConfigDict,
csv_path: str,
) -> None:
"""Initialize the DatasetHandler with required config.
Args:
config_path (str): Path to the dataset configuration file.
csv_path (str): Path to the CSV data file.
config: The dataset configuration dictionary
csv_path: Path to the CSV data file
"""
self.dataset_manager = DatasetManager(config_path)
self.dataset_manager = DatasetManager(config)
self.columns = self.read_csv_header(csv_path)
self.data = self.load_csv(csv_path)

Expand Down Expand Up @@ -331,9 +316,14 @@ def save(self, path: str) -> None:
class DatasetProcessor(DatasetHandler):
"""Class for loading dataset, applying transformations and splitting."""

def __init__(self, config_path: str, csv_path: str) -> None:
"""Initialize the DatasetProcessor."""
super().__init__(config_path, csv_path)
def __init__(self, config: yaml_data.YamlConfigDict, csv_path: str) -> None:
"""Initialize the DatasetProcessor.
Args:
config: The dataset configuration dictionary
csv_path: Path to the CSV data file
"""
super().__init__(config, csv_path)

def add_split(self, split_manager: SplitManager, *, force: bool = False) -> None:
"""Add a column specifying the train, validation, test splits of the data.
Expand Down Expand Up @@ -394,13 +384,20 @@ class DatasetLoader(DatasetHandler):

def __init__(
self,
config_path: str,
config: yaml_data.YamlConfigDict,
csv_path: str,
encoder_loader: loaders.EncoderLoader,
split: Union[int, None] = None,
) -> None:
"""Initialize the DatasetLoader."""
super().__init__(config_path, csv_path)
"""Initialize the DatasetLoader.
Args:
config: The dataset configuration dictionary
csv_path: Path to the CSV data file
encoder_loader: Loader for handling data encoding
split: Optional split index (0=train, 1=val, 2=test)
"""
super().__init__(config, csv_path)
self.encoder_manager = EncodeManager(encoder_loader)
self.data = self.load_csv_per_split(csv_path, split) if split is not None else self.load_csv(csv_path)

Expand Down

0 comments on commit 7d7ead8

Please sign in to comment.