Skip to content

Commit

Permalink
refactor(typing): fix typing accross codebase.
Browse files Browse the repository at this point in the history
  • Loading branch information
mathysgrapotte committed Feb 20, 2025
1 parent 9e76451 commit ce0d5d4
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 72 deletions.
19 changes: 9 additions & 10 deletions src/stimulus/learner/raytune_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@
from torch import nn, optim
from torch.utils.data import DataLoader

from stimulus.data.handlertorch import TorchDataset
from stimulus.data.loaders import EncoderLoader
from stimulus.data.data_handlers import TorchDataset
from stimulus.learner.predict import PredictWrapper
from stimulus.utils.generic_utils import set_general_seeds
from stimulus.utils.yaml_data import YamlSplitTransformDict
from stimulus.data.interface.data_config_schema import SplitTransformDict
from stimulus.utils.yaml_model_schema import RayTuneModel


Expand All @@ -36,10 +35,10 @@ class TuneWrapper:
def __init__(
self,
model_config: RayTuneModel,
data_config: YamlSplitTransformDict,
data_config: SplitTransformDict,
model_class: nn.Module,
data_path: str,
encoder_loader: EncoderLoader,
#encoder_loader: EncoderLoader,
seed: int,
ray_results_dir: Optional[str] = None,
tune_run_name: Optional[str] = None,
Expand Down Expand Up @@ -113,15 +112,15 @@ def __init__(
self.tuner = self.tuner_initialization(
data_config=data_config,
data_path=data_path,
encoder_loader=encoder_loader,
#encoder_loader=encoder_loader,
autoscaler=autoscaler,
)

def tuner_initialization(
self,
data_config: YamlSplitTransformDict,
data_config: SplitTransformDict,
data_path: str,
encoder_loader: EncoderLoader,
#encoder_loader: EncoderLoader,
*,
autoscaler: bool = False,
) -> tune.Tuner:
Expand Down Expand Up @@ -155,13 +154,13 @@ def tuner_initialization(
training = TorchDataset(
data_config=data_config,
csv_path=data_path,
encoder_loader=encoder_loader,
#encoder_loader=encoder_loader,
split=0,
)
validation = TorchDataset(
data_config=data_config,
csv_path=data_path,
encoder_loader=encoder_loader,
#encoder_loader=encoder_loader,
split=1,
)

Expand Down
72 changes: 31 additions & 41 deletions src/stimulus/typing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,17 @@
from stimulus.data.data_handlers import (
DatasetHandler,
DatasetLoader,
DatasetManager,
DatasetProcessor,
EncodeManager,
SplitManager,
TransformManager,
TorchDataset,
)
from stimulus.data.encoding.encoders import AbstractEncoder as Encoder
from stimulus.data.handlertorch import TorchDataset
from stimulus.data.loaders import EncoderLoader, SplitLoader, TransformLoader
from stimulus.data.splitters.splitters import AbstractSplitter as Splitter
from stimulus.data.transform.data_transformation_generators import (
AbstractDataTransformer as Transform,
from stimulus.data.interface.data_config_parser import (
create_encoders,
create_splitter,
create_transforms,
)
from stimulus.data.splitting import AbstractSplitter as Splitter
from stimulus.data.transforming.transforms import AbstractTransform as Transform
from stimulus.learner.predict import PredictWrapper
from stimulus.learner.raytune_learner import CheckpointDict, TuneModel, TuneWrapper
from stimulus.learner.raytune_parser import (
Expand All @@ -35,18 +33,18 @@
TuneParser,
)
from stimulus.utils.performance import Performance
from stimulus.utils.yaml_data import (
YamlColumns,
YamlColumnsEncoder,
YamlConfigDict,
YamlGlobalParams,
YamlSchema,
YamlSplit,
YamlSplitConfigDict,
YamlSplitTransformDict,
YamlTransform,
YamlTransformColumns,
YamlTransformColumnsTransformation,
from stimulus.data.interface.data_config_schema import (
Columns,
ColumnsEncoder,
ConfigDict,
GlobalParams,
Schema,
Split,
SplitConfigDict,
SplitTransformDict,
Transform,
TransformColumns,
TransformColumnsTransformation,
)
from stimulus.utils.yaml_model_schema import (
CustomTunableParameter,
Expand All @@ -59,32 +57,24 @@
TunableParameter,
Tune,
TuneParams,
YamlRayConfigLoader,
RayConfigLoader,
)

# data/data_handlers.py

DataManager: TypeAlias = DatasetManager | EncodeManager | SplitManager | TransformManager

# data/experiments.py

Loader: TypeAlias = DatasetLoader | EncoderLoader | TransformLoader | SplitLoader

# learner/raytune_parser.py

RayTuneData: TypeAlias = RayTuneMetrics | RayTuneOptimizer | RayTuneResult

# utils/yaml_data.py

YamlData: TypeAlias = (
YamlColumns
| YamlColumnsEncoder
| YamlConfigDict
| YamlGlobalParams
| YamlSchema
| YamlSplit
| YamlSplitConfigDict
| YamlTransform
| YamlTransformColumns
| YamlTransformColumnsTransformation
Data: TypeAlias = (
Columns
| ColumnsEncoder
| ConfigDict
| GlobalParams
| Schema
| Split
| SplitConfigDict
| Transform
| TransformColumns
| TransformColumnsTransformation
)
2 changes: 1 addition & 1 deletion src/stimulus/utils/yaml_model_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class RayTuneModel(pydantic.BaseModel):
tune: Tune


class YamlRayConfigLoader:
class RayConfigLoader:
"""Load and convert YAML configurations to Ray Tune format.
This class handles loading model configurations and converting them into
Expand Down
38 changes: 18 additions & 20 deletions tests/typing/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,8 @@ def test_data_handlers_types() -> None:
from stimulus.typing import (
DatasetHandler,
DatasetLoader,
DatasetManager,
DatasetProcessor,
EncodeManager,
SplitManager,
TransformManager,
TorchDataset,
)
except ImportError:
pytest.fail("Failed to import Data Handlers types")
Expand All @@ -34,29 +31,30 @@ def test_learner_types() -> None:
TuneModel,
TuneParser,
TuneWrapper,
CheckpointDict,
)
except ImportError:
pytest.fail("Failed to import Learner types")


def test_yaml_data_types() -> None:
"""Test the YAML data types."""
def test_data_config_schema_types() -> None:
"""Test the data config schema types."""
try:
from stimulus.typing import (
YamlColumns,
YamlColumnsEncoder,
YamlConfigDict,
YamlGlobalParams,
YamlSchema,
YamlSplit,
YamlSplitConfigDict,
YamlSplitTransformDict,
YamlTransform,
YamlTransformColumns,
YamlTransformColumnsTransformation,
Columns,
ColumnsEncoder,
ConfigDict,
GlobalParams,
Schema,
Split,
SplitConfigDict,
SplitTransformDict,
Transform,
TransformColumns,
TransformColumnsTransformation,
)
except ImportError:
pytest.fail("Failed to import YAML Data types")
pytest.fail("Failed to import Data Config Schema types")


def test_yaml_model_schema_types() -> None:
Expand All @@ -73,7 +71,7 @@ def test_yaml_model_schema_types() -> None:
TunableParameter,
Tune,
TuneParams,
YamlRayConfigLoader,
RayConfigLoader,
)
except ImportError:
pytest.fail("Failed to import YAML Model Schema types")
Expand All @@ -82,6 +80,6 @@ def test_yaml_model_schema_types() -> None:
def test_type_aliases() -> None:
"""Test the type aliases."""
try:
from stimulus.typing import DataManager, Loader, RayTuneData, YamlData
from stimulus.typing import RayTuneData, Data
except ImportError:
pytest.fail("Failed to import Type Aliases")

0 comments on commit ce0d5d4

Please sign in to comment.