Skip to content

Commit

Permalink
Merge pull request #120 from mathysgrapotte/major-refactor-manager
Browse files Browse the repository at this point in the history
Major refactor manager
  • Loading branch information
mathysgrapotte authored Feb 21, 2025
2 parents ed69b4e + 973e6f6 commit a3abc79
Show file tree
Hide file tree
Showing 30 changed files with 1,096 additions and 2,023 deletions.
4 changes: 2 additions & 2 deletions src/stimulus/cli/check_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from stimulus.data import handlertorch, loaders
from stimulus.learner import raytune_learner
from stimulus.utils import launch_utils, yaml_data, yaml_model_schema
from stimulus.utils import model_file_interface, yaml_data, yaml_model_schema

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -134,7 +134,7 @@ def main(

logger.info("Dataset loaded successfully.")

model_class = launch_utils.import_class_from_file(model_path)
model_class = model_file_interface.import_class_from_file(model_path)

logger.info("Model class loaded successfully.")

Expand Down
4 changes: 2 additions & 2 deletions src/stimulus/cli/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
import torch
from torch.utils.data import DataLoader

from stimulus.data.handlertorch import TorchDataset
from stimulus.data.data_handlers import TorchDataset
from stimulus.learner.predict import PredictWrapper
from stimulus.utils.launch_utils import get_experiment, import_class_from_file
from stimulus.utils.model_file_interface import get_experiment, import_class_from_file


def get_args() -> argparse.Namespace:
Expand Down
4 changes: 2 additions & 2 deletions src/stimulus/cli/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from stimulus.data import loaders
from stimulus.learner import raytune_learner, raytune_parser
from stimulus.utils import launch_utils, yaml_data, yaml_model_schema
from stimulus.utils import model_file_interface, yaml_data, yaml_model_schema

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -186,7 +186,7 @@ def main(
column_config=data_config.columns,
)

model_class = launch_utils.import_class_from_file(model_path)
model_class = model_file_interface.import_class_from_file(model_path)

ray_config_loader = yaml_model_schema.YamlRayConfigLoader(model=model_config)
ray_config_model = ray_config_loader.get_config()
Expand Down
Loading

0 comments on commit a3abc79

Please sign in to comment.