
Commit 973e6f6
refactor(raytune_learner): gets a torchdataset as input.
mathysgrapotte committed Feb 20, 2025
1 parent 2213c5d commit 973e6f6
Showing 2 changed files with 39 additions and 47 deletions.
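
In practice the refactor moves dataset construction out of TuneWrapper: callers now build the TorchDataset (wrapping a DatasetLoader) themselves and pass it in, as the updated test fixtures below do. A minimal sketch of the new calling convention; the CSV path, encoders, column lists, and model class are placeholders, and the DatasetLoader import path is taken from the line this commit removes:

from stimulus.data.data_handlers import TorchDataset
from stimulus.learner.raytune_learner import TuneWrapper
from stimulus.typing import DatasetLoader

# Placeholder wiring; see the Titanic test fixtures below for a concrete setup.
train_dataset = TorchDataset(
    loader=DatasetLoader(
        csv_path="data/train.csv",   # placeholder path
        encoders=encoders,           # placeholder: column name -> encoder mapping
        input_columns=input_columns,
        label_columns=label_columns,
        meta_columns=meta_columns,
        split=0,                     # the fixtures use split=0 for training, split=1 for validation
    ),
)

wrapper = TuneWrapper(
    model_config=model_config,       # a RayTuneModel
    model_class=MyModel,             # placeholder nn.Module subclass
    train_dataset=train_dataset,
    validation_dataset=validation_dataset,  # built the same way with split=1
    seed=42,
)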
src/stimulus/learner/raytune_learner.py (30 changes: 9 additions & 21 deletions)
@@ -18,7 +18,6 @@
 
 from stimulus.data.data_handlers import TorchDataset
 from stimulus.learner.predict import PredictWrapper
-from stimulus.typing import DatasetLoader
 from stimulus.utils.generic_utils import set_general_seeds
 from stimulus.utils.yaml_model_schema import RayTuneModel
 
@@ -36,8 +35,8 @@ def __init__(
         self,
         model_config: RayTuneModel,
         model_class: nn.Module,
-        train_loader: DatasetLoader,
-        validation_loader: DatasetLoader,
+        train_dataset: TorchDataset,
+        validation_dataset: TorchDataset,
         seed: int,
         ray_results_dir: Optional[str] = None,
         tune_run_name: Optional[str] = None,
@@ -97,23 +96,21 @@
         )
         self.config["_debug"] = debug
         self.config["model"] = model_class
-        self.config["train_loader"] = train_loader
-        self.config["validation_loader"] = validation_loader
         self.config["ray_worker_seed"] = tune.randint(0, 1000)
 
         self.gpu_per_trial = model_config.tune.gpu_per_trial
         self.cpu_per_trial = model_config.tune.cpu_per_trial
 
         self.tuner = self.tuner_initialization(
-            train_loader=train_loader,
-            validation_loader=validation_loader,
+            train_dataset=train_dataset,
+            validation_dataset=validation_dataset,
             autoscaler=autoscaler,
         )
 
     def tuner_initialization(
         self,
-        train_loader: DatasetLoader,
-        validation_loader: DatasetLoader,
+        train_dataset: TorchDataset,
+        validation_dataset: TorchDataset,
         *,
         autoscaler: bool = False,
     ) -> tune.Tuner:
@@ -142,18 +139,9 @@ def tuner_initialization(
             f"PER_TRIAL resources -> GPU: {self.gpu_per_trial} CPU: {self.cpu_per_trial}",
         )
 
-        # Pre-load and encode datasets once, then put them in Ray's object store
-
-        training = TorchDataset(
-            loader=train_loader,
-        )
-        validation = TorchDataset(
-            loader=validation_loader,
-        )
-
-        # log to debug the names of the columns and shapes of tensors for a batch of training
-        inputs, labels, meta = training[0:10]
+        # Log shapes of encoded tensors for first batch of training data
+        inputs, labels, meta = train_dataset[0:10]
 
         logging.debug("Training data tensor shapes:")
         for field, tensor in inputs.items():
@@ -165,8 +153,8 @@
         for field, values in meta.items():
             logging.debug(f"Meta field '{field}' length: {len(values)}")
 
-        training_ref = ray.put(training)
-        validation_ref = ray.put(validation)
+        training_ref = ray.put(train_dataset)
+        validation_ref = ray.put(validation_dataset)
 
         self.config["_training_ref"] = training_ref
         self.config["_validation_ref"] = validation_ref
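
The `_training_ref` / `_validation_ref` entries hold Ray object references: `ray.put` stores the already-encoded dataset once in Ray's object store and returns an ObjectRef, so every trial can share that single copy instead of re-serializing the dataset per worker. A self-contained sketch of the pattern; the plain list and the stub trainable below are stand-ins, not stimulus code:

import ray

ray.init(ignore_reinit_error=True)

# Stand-in for an already-encoded TorchDataset.
train_dataset = [{"x": i, "y": i % 2} for i in range(1000)]

# Store once; ray.put returns an ObjectRef handle.
training_ref = ray.put(train_dataset)
config = {"_training_ref": training_ref}

def trainable(config: dict) -> None:
    # Each trial dereferences the shared copy instead of rebuilding it.
    dataset = ray.get(config["_training_ref"])
    print(f"trial sees {len(dataset)} rows")

trainable(config)
ray.shutdown()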
tests/learner/test_raytune_learner.py (56 changes: 30 additions & 26 deletions)
@@ -64,47 +64,51 @@ def titanic_dataset() -> TorchDataset:
 
 
 @pytest.fixture
-def get_train_loader(
+def get_train_dataset(
     titanic_dataset: str,
     get_encoders: dict[str, encoders_module.AbstractEncoder],
     get_input_columns: list[str],
     get_label_columns: list[str],
     get_meta_columns: list[str],
-) -> DatasetLoader:
+) -> TorchDataset:
     """Get the DatasetLoader."""
-    return DatasetLoader(
-        csv_path=titanic_dataset,
-        encoders=get_encoders,
-        input_columns=get_input_columns,
-        label_columns=get_label_columns,
-        meta_columns=get_meta_columns,
-        split=0,
+    return TorchDataset(
+        loader=DatasetLoader(
+            csv_path=titanic_dataset,
+            encoders=get_encoders,
+            input_columns=get_input_columns,
+            label_columns=get_label_columns,
+            meta_columns=get_meta_columns,
+            split=0,
+        ),
     )
 
 
 @pytest.fixture
-def get_validation_loader(
+def get_validation_dataset(
     titanic_dataset: str,
     get_encoders: dict[str, encoders_module.AbstractEncoder],
     get_input_columns: list[str],
     get_label_columns: list[str],
     get_meta_columns: list[str],
-) -> DatasetLoader:
+) -> TorchDataset:
     """Get the DatasetLoader."""
-    return DatasetLoader(
-        csv_path=titanic_dataset,
-        encoders=get_encoders,
-        input_columns=get_input_columns,
-        label_columns=get_label_columns,
-        meta_columns=get_meta_columns,
-        split=1,
+    return TorchDataset(
+        loader=DatasetLoader(
+            csv_path=titanic_dataset,
+            encoders=get_encoders,
+            input_columns=get_input_columns,
+            label_columns=get_label_columns,
+            meta_columns=get_meta_columns,
+            split=1,
+        ),
     )
 
 
 def test_tunewrapper_init(
     ray_config_loader: RayTuneModel,
-    get_train_loader: DatasetLoader,
-    get_validation_loader: DatasetLoader,
+    get_train_dataset: TorchDataset,
+    get_validation_dataset: TorchDataset,
 ) -> None:
     """Test the initialization of the TuneWrapper class."""
     # Filter ResourceWarning during Ray shutdown
@@ -117,8 +121,8 @@ def test_tunewrapper_init(
     tune_wrapper = TuneWrapper(
         model_config=ray_config_loader,
         model_class=titanic_model.ModelTitanic,
-        train_loader=get_train_loader,
-        validation_loader=get_validation_loader,
+        train_dataset=get_train_dataset,
+        validation_dataset=get_validation_dataset,
         seed=42,
         ray_results_dir=os.path.abspath("tests/test_data/titanic/ray_results"),
         tune_run_name="test_run",
@@ -139,8 +143,8 @@ def test_tunewrapper_init(
 
 def test_tune_wrapper_tune(
     ray_config_loader: RayTuneModel,
-    get_train_loader: DatasetLoader,
-    get_validation_loader: DatasetLoader,
+    get_train_dataset: TorchDataset,
+    get_validation_dataset: TorchDataset,
 ) -> None:
     """Test the tune method of TuneWrapper class."""
     # Filter ResourceWarning during Ray shutdown
@@ -153,8 +157,8 @@ def test_tune_wrapper_tune(
     tune_wrapper = TuneWrapper(
         model_config=ray_config_loader,
         model_class=titanic_model.ModelTitanic,
-        train_loader=get_train_loader,
-        validation_loader=get_validation_loader,
+        train_dataset=get_train_dataset,
+        validation_dataset=get_validation_dataset,
         seed=42,
         ray_results_dir=os.path.abspath("tests/test_data/titanic/ray_results"),
         tune_run_name="test_run",
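
The fixtures also make it easy to inspect a dataset the same way tuner_initialization does above: TorchDataset supports slice indexing and returns (inputs, labels, meta), where inputs and labels map field names to tensors and meta maps field names to plain sequences. A hedged sketch, assuming a dataset built like the get_train_dataset fixture:

# Assumes `train_dataset` was built as in the get_train_dataset fixture.
inputs, labels, meta = train_dataset[0:10]  # first ten encoded rows

for field, tensor in inputs.items():
    print(f"Input '{field}' shape: {tuple(tensor.shape)}")
for field, tensor in labels.items():
    print(f"Label '{field}' shape: {tuple(tensor.shape)}")
for field, values in meta.items():
    print(f"Meta field '{field}' length: {len(values)}")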
