refactor(raytune_learner): moved raytune_learner to new paradigm.
mathysgrapotte committed Feb 20, 2025
1 parent 96024a3 commit 3efda09
Showing 5 changed files with 89 additions and 82 deletions.
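In short: `TorchDataset` no longer assembles its own `DatasetLoader` from encoders, column lists, a CSV path, and a split; it now receives a fully built loader. `TuneWrapper` likewise drops `data_config`, `data_path`, and the commented-out `encoder_loader` in favor of explicit `train_loader` and `validation_loader` arguments, and `stimulus.typing` stops importing from `stimulus.learner.raytune_learner`, replacing the re-exported names with local aliases to break the import cycle.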
21 changes: 3 additions & 18 deletions src/stimulus/data/data_handlers.py
@@ -321,29 +321,14 @@ class TorchDataset(torch.utils.data.Dataset):

     def __init__(
         self,
-        encoders: dict[str, encoders_module.AbstractEncoder],
-        input_columns: list[str],
-        label_columns: list[str],
-        meta_columns: list[str],
-        csv_path: str,
-        split: Optional[int] = None,
+        loader: DatasetLoader,
     ) -> None:
         """Initialize the TorchDataset.

         Args:
-            data_config: A YamlSplitTransformDict holding the configuration.
-            csv_path: Path to the CSV data file
-            encoder_loader: Encoder loader instance
-            split: Optional tuple containing split information
+            loader: A DatasetLoader instance
         """
-        self.loader = DatasetLoader(
-            encoders=encoders,
-            input_columns=input_columns,
-            label_columns=label_columns,
-            meta_columns=meta_columns,
-            csv_path=csv_path,
-            split=split,
-        )
+        self.loader = loader

     def __len__(self) -> int:
         return len(self.loader)
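Callers now build the loader explicitly and hand it to the dataset. A minimal sketch of the new pattern, assuming titanic-style test data (the encoder choices, columns, and CSV path are illustrative, borrowed from the test fixtures in this commit):

```python
import torch

from stimulus.data.data_handlers import DatasetLoader, TorchDataset
from stimulus.data.encoding import encoders as encoders_module

# Build the loader up front; TorchDataset no longer constructs one internally.
loader = DatasetLoader(
    encoders={
        "age": encoders_module.NumericEncoder(dtype=torch.float32),
        "fare": encoders_module.NumericEncoder(dtype=torch.float32),
        "survived": encoders_module.NumericEncoder(dtype=torch.int64),
        "passenger_id": encoders_module.NumericEncoder(dtype=torch.int64),
    },
    input_columns=["age", "fare"],
    label_columns=["survived"],
    meta_columns=["passenger_id"],
    csv_path="tests/test_data/titanic/titanic_stimulus_split.csv",
)

dataset = TorchDataset(loader)     # the dataset is now a thin wrapper over the loader
inputs, labels, meta = dataset[0]  # items are still (inputs, labels, meta) tuples
```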
35 changes: 11 additions & 24 deletions src/stimulus/learner/raytune_learner.py
@@ -17,8 +17,8 @@
 from torch.utils.data import DataLoader

 from stimulus.data.data_handlers import TorchDataset
-from stimulus.data.interface.data_config_schema import SplitTransformDict
 from stimulus.learner.predict import PredictWrapper
+from stimulus.typing import DatasetLoader
 from stimulus.utils.generic_utils import set_general_seeds
 from stimulus.utils.yaml_model_schema import RayTuneModel

@@ -35,10 +35,9 @@ class TuneWrapper:
     def __init__(
         self,
         model_config: RayTuneModel,
-        data_config: SplitTransformDict,
         model_class: nn.Module,
-        data_path: str,
-        # encoder_loader: EncoderLoader,
+        train_loader: DatasetLoader,
+        validation_loader: DatasetLoader,
         seed: int,
         ray_results_dir: Optional[str] = None,
         tune_run_name: Optional[str] = None,

@@ -84,11 +83,6 @@ def __init__(
             stop=model_config.tune.run_params.stop,
         )

-        # add the data path to the config
-        if not os.path.exists(data_path):
-            raise ValueError("Data path does not exist. Given path:" + data_path)
-        self.config["data_path"] = os.path.abspath(data_path)
-
         # Set up tune_run path
         if ray_results_dir is None:
             ray_results_dir = os.environ.get("HOME", "")

@@ -103,24 +97,23 @@
         )
         self.config["_debug"] = debug
         self.config["model"] = model_class
-        self.config["encoder_loader"] = encoder_loader
+        self.config["train_loader"] = train_loader
+        self.config["validation_loader"] = validation_loader
         self.config["ray_worker_seed"] = tune.randint(0, 1000)

         self.gpu_per_trial = model_config.tune.gpu_per_trial
         self.cpu_per_trial = model_config.tune.cpu_per_trial

         self.tuner = self.tuner_initialization(
-            data_config=data_config,
-            data_path=data_path,
-            # encoder_loader=encoder_loader,
+            train_loader=train_loader,
+            validation_loader=validation_loader,
             autoscaler=autoscaler,
         )

     def tuner_initialization(
         self,
-        data_config: SplitTransformDict,
-        data_path: str,
-        # encoder_loader: EncoderLoader,
+        train_loader: DatasetLoader,
+        validation_loader: DatasetLoader,
         *,
         autoscaler: bool = False,
     ) -> tune.Tuner:

@@ -152,16 +145,10 @@
         # Pre-load and encode datasets once, then put them in Ray's object store

         training = TorchDataset(
-            data_config=data_config,
-            csv_path=data_path,
-            # encoder_loader=encoder_loader,
-            split=0,
+            loader=train_loader,
         )
         validation = TorchDataset(
-            data_config=data_config,
-            csv_path=data_path,
-            # encoder_loader=encoder_loader,
-            split=1,
+            loader=validation_loader,
         )

         # log to debug the names of the columns and shapes of tensors for a batch of training
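On the training side, the wrapper now expects the two loaders up front, and the datasets are materialized inside `tuner_initialization`. A sketch of the new call, following the updated tests below; `make_loader` is a hypothetical helper, and `ray_config` stands in for the `RayTuneModel` produced by `RayConfigLoader`:

```python
import os

from stimulus.data.data_handlers import DatasetLoader
from stimulus.learner.raytune_learner import TuneWrapper
from tests.test_model import titanic_model

def make_loader(split: int) -> DatasetLoader:
    """Hypothetical helper: build a titanic loader for the given split
    (encoders and column lists as in the test fixtures)."""
    return DatasetLoader(
        encoders=encoders,  # dict[str, AbstractEncoder], built by the caller as above
        input_columns=["embarked", "pclass", "sex", "age", "sibsp", "parch", "fare"],
        label_columns=["survived"],
        meta_columns=["passenger_id"],
        csv_path="tests/test_data/titanic/titanic_stimulus_split.csv",
        split=split,
    )

tune_wrapper = TuneWrapper(
    model_config=ray_config,            # RayTuneModel from RayConfigLoader(...).get_config()
    model_class=titanic_model.ModelTitanic,
    train_loader=make_loader(0),        # split 0 = training
    validation_loader=make_loader(1),   # split 1 = validation
    seed=42,
    ray_results_dir=os.path.abspath("tests/test_data/titanic/ray_results"),
    tune_run_name="test_run",
)
tune_wrapper.tune()  # as exercised by test_tune_wrapper_tune
```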
11 changes: 9 additions & 2 deletions src/stimulus/typing/__init__.py
@@ -7,7 +7,7 @@
 """
 # ruff: noqa: F401

-from typing import TypeAlias
+from typing import TypeAlias, TypeVar, Any

 # these imports mostly alias everything
 from stimulus.data.data_handlers import (

@@ -38,7 +38,6 @@
 from stimulus.data.splitting import AbstractSplitter as Splitter
 from stimulus.data.transforming.transforms import AbstractTransform as Transform
 from stimulus.learner.predict import PredictWrapper
-from stimulus.learner.raytune_learner import CheckpointDict, TuneModel, TuneWrapper
 from stimulus.learner.raytune_parser import (
     RayTuneMetrics,
     RayTuneOptimizer,

@@ -78,3 +77,11 @@
     | TransformColumns
     | TransformColumnsTransformation
 )
+
+# Replace these problematic imports
+# from stimulus.learner.raytune_learner import CheckpointDict, TuneModel, TuneWrapper
+
+# Replace with type aliases if needed
+CheckpointDict = dict[str, Any]
+TuneModel = TypeVar('TuneModel')
+TuneWrapper = TypeVar('TuneWrapper')
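This aliasing removes a circular import: `raytune_learner` now imports `DatasetLoader` from `stimulus.typing`, so `stimulus.typing` can no longer import `CheckpointDict`, `TuneModel`, and `TuneWrapper` back from `raytune_learner` at module load time. The `TypeVar` stand-ins keep the names importable but erase the concrete types for static checkers; a hypothetical alternative (not part of this commit) would defer the import to type-checking time:

```python
# Hypothetical sketch, not in this commit: keep the real types visible to
# static checkers while still avoiding the runtime import cycle.
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # Evaluated only by type checkers, never at runtime.
    from stimulus.learner.raytune_learner import TuneModel, TuneWrapper

CheckpointDict = dict[str, Any]
```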
7 changes: 4 additions & 3 deletions tests/data/test_data_handlers.py
@@ -200,14 +200,14 @@ def test_torch_dataset_init(
     dummy_encoders: dict,
 ) -> None:
     """Test initialization of TorchDataset."""
-    dataset = TorchDataset(
+    loader = DatasetLoader(
         encoders=dummy_encoders,
         input_columns=["age", "fare"],
         label_columns=["survived"],
         meta_columns=["passenger_id"],
         csv_path=titanic_csv_path,
     )
-
+    dataset = TorchDataset(loader)
     assert len(dataset) > 0
     assert isinstance(dataset[0], tuple)
     assert len(dataset[0]) == 3

@@ -221,13 +221,14 @@ def test_torch_dataset_get_item(
     dummy_encoders: dict,
 ) -> None:
     """Test getting item from TorchDataset."""
-    dataset = TorchDataset(
+    loader = DatasetLoader(
         encoders=dummy_encoders,
         input_columns=["age", "fare"],
         label_columns=["survived"],
         meta_columns=["passenger_id"],
         csv_path=titanic_csv_path,
     )
+    dataset = TorchDataset(loader)

     inputs, labels, meta = dataset[0]
     assert isinstance(inputs, dict)
97 changes: 62 additions & 35 deletions tests/learner/test_raytune_learner.py
@@ -6,12 +6,12 @@
 import pytest
 import ray
 import yaml
+import torch

-from stimulus.data.handlertorch import TorchDataset
-from stimulus.data.loaders import EncoderLoader
+from stimulus.data.data_handlers import TorchDataset, DatasetLoader
+from stimulus.data.encoding import encoders as encoders_module
 from stimulus.learner.raytune_learner import TuneWrapper
-from stimulus.utils.yaml_data import YamlSplitTransformDict
-from stimulus.utils.yaml_model_schema import Model, RayTuneModel, YamlRayConfigLoader
+from stimulus.utils.yaml_model_schema import Model, RayTuneModel, RayConfigLoader
 from tests.test_model import titanic_model


@@ -20,35 +20,69 @@ def ray_config_loader() -> RayTuneModel:
     """Load the RayTuneModel configuration."""
     with open("tests/test_model/titanic_model_cpu.yaml") as file:
         model_config = yaml.safe_load(file)
-    return YamlRayConfigLoader(Model(**model_config)).get_config()
+    return RayConfigLoader(Model(**model_config)).get_config()

 @pytest.fixture
-def encoder_loader() -> EncoderLoader:
-    """Load the EncoderLoader configuration."""
-    with open("tests/test_data/titanic/titanic_sub_config.yaml") as file:
-        data_config = yaml.safe_load(file)
-    encoder_loader = EncoderLoader()
-    encoder_loader.initialize_column_encoders_from_config(
-        YamlSplitTransformDict(**data_config).columns,
-    )
-    return encoder_loader
+def get_encoders() -> dict[str, encoders_module.AbstractEncoder]:
+    """Build the dictionary of column encoders for the titanic data."""
+    encoders = {
+        "passenger_id": encoders_module.NumericEncoder(dtype=torch.int64),
+        "survived": encoders_module.NumericEncoder(dtype=torch.int64),
+        "pclass": encoders_module.NumericEncoder(dtype=torch.int64),
+        "sex": encoders_module.StrClassificationEncoder(),
+        "age": encoders_module.NumericEncoder(dtype=torch.float32),
+        "sibsp": encoders_module.NumericEncoder(dtype=torch.int64),
+        "parch": encoders_module.NumericEncoder(dtype=torch.int64),
+        "fare": encoders_module.NumericEncoder(dtype=torch.float32),
+        "embarked": encoders_module.StrClassificationEncoder(),
+    }
+    return encoders
+
+@pytest.fixture
+def get_input_columns() -> list[str]:
+    """Get the input columns."""
+    return ["embarked", "pclass", "sex", "age", "sibsp", "parch", "fare"]
+
+@pytest.fixture
+def get_label_columns() -> list[str]:
+    """Get the label columns."""
+    return ["survived"]
+
+@pytest.fixture
+def get_meta_columns() -> list[str]:
+    """Get the meta columns."""
+    return ["passenger_id"]

 @pytest.fixture
-def titanic_dataset(encoder_loader: EncoderLoader) -> TorchDataset:
-    """Create a TorchDataset instance for testing."""
-    return TorchDataset(
-        csv_path="tests/test_data/titanic/titanic_stimulus_split.csv",
-        config_path="tests/test_data/titanic/titanic_sub_config.yaml",
-        encoder_loader=encoder_loader,
-        split=0,
-    )
+def titanic_dataset() -> str:
+    """Return the path to the titanic split CSV."""
+    return "tests/test_data/titanic/titanic_stimulus_split.csv"
+
+@pytest.fixture
+def get_train_loader(
+    titanic_dataset: str,
+    get_encoders: dict[str, encoders_module.AbstractEncoder],
+    get_input_columns: list[str],
+    get_label_columns: list[str],
+    get_meta_columns: list[str],
+) -> DatasetLoader:
+    """Build the DatasetLoader for the training split."""
+    return DatasetLoader(
+        csv_path=titanic_dataset,
+        encoders=get_encoders,
+        input_columns=get_input_columns,
+        label_columns=get_label_columns,
+        meta_columns=get_meta_columns,
+        split=0,
+    )
+
+@pytest.fixture
+def get_validation_loader(
+    titanic_dataset: str,
+    get_encoders: dict[str, encoders_module.AbstractEncoder],
+    get_input_columns: list[str],
+    get_label_columns: list[str],
+    get_meta_columns: list[str],
+) -> DatasetLoader:
+    """Build the DatasetLoader for the validation split."""
+    return DatasetLoader(
+        csv_path=titanic_dataset,
+        encoders=get_encoders,
+        input_columns=get_input_columns,
+        label_columns=get_label_columns,
+        meta_columns=get_meta_columns,
+        split=1,
+    )


 def test_tunewrapper_init(
     ray_config_loader: RayTuneModel,
-    encoder_loader: EncoderLoader,
+    get_train_loader: DatasetLoader,
+    get_validation_loader: DatasetLoader,
 ) -> None:
     """Test the initialization of the TuneWrapper class."""
     # Filter ResourceWarning during Ray shutdown
@@ -58,16 +92,12 @@ def test_tunewrapper_init(
     ray.init(ignore_reinit_error=True)

     try:
-        data_config: YamlSplitTransformDict
-        with open("tests/test_data/titanic/titanic_sub_config.yaml") as f:
-            data_config = YamlSplitTransformDict(**yaml.safe_load(f))

         tune_wrapper = TuneWrapper(
             model_config=ray_config_loader,
             model_class=titanic_model.ModelTitanic,
-            data_path="tests/test_data/titanic/titanic_stimulus_split.csv",
-            data_config=data_config,
-            encoder_loader=encoder_loader,
+            train_loader=get_train_loader,
+            validation_loader=get_validation_loader,
             seed=42,
             ray_results_dir=os.path.abspath("tests/test_data/titanic/ray_results"),
             tune_run_name="test_run",
@@ -88,7 +118,8 @@

 def test_tune_wrapper_tune(
     ray_config_loader: RayTuneModel,
-    encoder_loader: EncoderLoader,
+    get_train_loader: DatasetLoader,
+    get_validation_loader: DatasetLoader,
 ) -> None:
     """Test the tune method of TuneWrapper class."""
     # Filter ResourceWarning during Ray shutdown
@@ -98,16 +129,12 @@
     ray.init(ignore_reinit_error=True)

     try:
-        data_config: YamlSplitTransformDict
-        with open("tests/test_data/titanic/titanic_sub_config.yaml") as f:
-            data_config = YamlSplitTransformDict(**yaml.safe_load(f))

         tune_wrapper = TuneWrapper(
             model_config=ray_config_loader,
             model_class=titanic_model.ModelTitanic,
-            data_path="tests/test_data/titanic/titanic_stimulus_split.csv",
-            data_config=data_config,
-            encoder_loader=encoder_loader,
+            train_loader=get_train_loader,
+            validation_loader=get_validation_loader,
             seed=42,
             ray_results_dir=os.path.abspath("tests/test_data/titanic/ray_results"),
             tune_run_name="test_run",
