refactor(check_model): refactor the check-model CLI.
mathysgrapotte committed Feb 21, 2025
1 parent 01d4532 commit 09044c7
Showing 3 changed files with 182 additions and 126 deletions.
153 changes: 53 additions & 100 deletions src/stimulus/cli/check_model.py
@@ -3,112 +3,82 @@

import logging

import click
import ray
import yaml
from torch.utils.data import DataLoader

from stimulus.data import handlertorch, loaders
from stimulus.data import data_handlers
from stimulus.data.interface import data_config_parser
from stimulus.learner import raytune_learner
from stimulus.utils import model_file_interface, yaml_data, yaml_model_schema
from stimulus.utils import model_file_interface, yaml_model_schema

logger = logging.getLogger(__name__)


@click.command()
@click.option(
"-d",
"--data",
type=click.Path(exists=True),
required=True,
help="Path to input csv file.",
)
@click.option(
"-m",
"--model",
type=click.Path(exists=True),
required=True,
help="Path to model file.",
)
@click.option(
"-e",
"--data-config",
type=click.Path(exists=True),
required=True,
help="Path to data config file.",
)
@click.option(
"-c",
"--model-config",
type=click.Path(exists=True),
required=True,
help="Path to yaml config training file.",
)
@click.option(
"-w",
"--initial-weights",
type=click.Path(exists=True),
help="The path to the initial weights (optional).",
)
@click.option(
"-n",
"--num-samples",
type=int,
default=3,
help="Number of samples for tuning. Overwrites tune.tune_params.num_samples in config.",
)
@click.option(
"--ray-results-dirpath",
type=click.Path(),
help="Location where ray_results output dir should be written. If None, uses ~/ray_results.",
)
@click.option(
"--debug-mode",
is_flag=True,
help="Activate debug mode for tuning. Default false, no debug.",
)
def load_data_config_from_path(data_path: str, data_config_path: str, split: int) -> data_handlers.TorchDataset:
"""Load the data config from a path.
Args:
data_config_path: Path to the data config file.
Returns:
A tuple of the parsed configuration.
"""
with open(data_config_path) as file:
data_config_dict = yaml.safe_load(file)
data_config_obj = data_config_parser.SplitTransformDict(**data_config_dict)

encoders, input_columns, label_columns, meta_columns = data_config_parser.parse_split_transform_config(
data_config_obj,
)

return data_handlers.TorchDataset(
loader=data_handlers.DatasetLoader(
encoders=encoders,
input_columns=input_columns,
label_columns=label_columns,
meta_columns=meta_columns,
csv_path=data_path,
split=split,
),
)


def check_model(
data: str,
model: str,
data_config: str,
model_config: str,
data_path: str,
model_path: str,
data_config_path: str,
model_config_path: str,
initial_weights: str | None = None, # noqa: ARG001
num_samples: int = 3,
ray_results_dirpath: str | None = None,
*,
debug_mode: bool = False,
) -> None:
"""Run the main model checking pipeline.
Args:
data: Path to input data file.
model: Path to model file.
data_config: Path to data config file.
model_config: Path to model config file.
data_path: Path to input data file.
model_path: Path to model file.
data_config_path: Path to data config file.
model_config_path: Path to model config file.
initial_weights: Optional path to initial weights.
num_samples: Number of samples for tuning.
ray_results_dirpath: Directory for ray results.
debug_mode: Whether to run in debug mode.
"""
with open(data_config) as file:
data_config_dict = yaml.safe_load(file)
data_config_obj = yaml_data.YamlSplitTransformDict(**data_config_dict)

with open(model_config) as file:
model_config_dict = yaml.safe_load(file)
model_config_obj = yaml_model_schema.Model(**model_config_dict)

encoder_loader = loaders.EncoderLoader()
encoder_loader.initialize_column_encoders_from_config(
column_config=data_config_obj.columns,
)

train_dataset = load_data_config_from_path(data_path, data_config_path, split=0)
validation_dataset = load_data_config_from_path(data_path, data_config_path, split=1)
logger.info("Dataset loaded successfully.")

model_class = model_file_interface.import_class_from_file(model)
model_class = model_file_interface.import_class_from_file(model_path)

logger.info("Model class loaded successfully.")

ray_config_loader = yaml_model_schema.YamlRayConfigLoader(model=model_config_obj)
with open(model_config_path) as file:
model_config_content = yaml.safe_load(file)
model_config = yaml_model_schema.Model(**model_config_content)

ray_config_loader = yaml_model_schema.RayConfigLoader(model=model_config)
ray_config_dict = ray_config_loader.get_config().model_dump()
ray_config_model = ray_config_loader.get_config()

@@ -125,13 +95,7 @@ def check_model(

logger.info("Model instance loaded successfully.")

torch_dataset = handlertorch.TorchDataset(
data_config=data_config_obj,
csv_path=data,
encoder_loader=encoder_loader,
)

torch_dataloader = DataLoader(torch_dataset, batch_size=10, shuffle=True)
torch_dataloader = DataLoader(train_dataset, batch_size=10, shuffle=True)

logger.info("Torch dataloader loaded successfully.")

@@ -153,14 +117,13 @@
logger.info("Model checking single pass completed successfully.")

# override num_samples
model_config_obj.tune.tune_params.num_samples = num_samples
model_config.tune.tune_params.num_samples = num_samples

tuner = raytune_learner.TuneWrapper(
model_config=ray_config_model,
data_config=data_config_obj,
model_class=model_class,
data_path=data,
encoder_loader=encoder_loader,
train_dataset=train_dataset,
validation_dataset=validation_dataset,
seed=42,
ray_results_dir=ray_results_dirpath,
debug=debug_mode,
@@ -172,13 +135,3 @@

logger.info("Tuning completed successfully.")
logger.info("Checks complete")


def run() -> None:
"""Run the model checking script."""
ray.init(address="auto", ignore_reinit_error=True)
check_model()


if __name__ == "__main__":
run()
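
For orientation, a minimal sketch of calling the refactored check_model directly with its new keyword arguments (the file paths are hypothetical placeholders, and a working Ray environment is assumed for the tuning step):

from stimulus.cli.check_model import check_model

# All paths below are placeholders; point them at your own files.
check_model(
    data_path="data/train.csv",              # input CSV
    model_path="models/my_model.py",         # file defining the model class
    data_config_path="configs/data.yaml",    # split/transform data config
    model_config_path="configs/model.yaml",  # tuning/model config
    num_samples=1,                           # keep the tuning check small
    debug_mode=True,
)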
101 changes: 77 additions & 24 deletions src/stimulus/cli/main.py
@@ -3,33 +3,86 @@
import click
from importlib_metadata import version

@click.group(context_settings=dict(help_option_names=["-h", "--help"]))

@click.group(context_settings={"help_option_names": ["-h", "--help"]})
@click.version_option(version("stimulus-py"), "-v", "--version")
def cli():
def cli() -> None:
"""Stimulus is an open-science framework for data processing and model training."""
pass


@cli.command()
def check_model():
"""Check model configuration and run initial tests.
check-model will connect to an existing ray cluster. Make sure you start a ray cluster before by running:
ray start --head
\b
Required Options:
-d, --data PATH Path to input csv file
-m, --model PATH Path to model file
-e, --data-config PATH Path to data config file
-c, --model-config PATH Path to yaml config training file
@click.option(
"-d",
"--data",
type=click.Path(exists=True),
required=True,
help="Path to input csv file",
)
@click.option(
"-m",
"--model",
type=click.Path(exists=True),
required=True,
help="Path to model file",
)
@click.option(
"-e",
"--data-config",
type=click.Path(exists=True),
required=True,
help="Path to data config file",
)
@click.option(
"-c",
"--model-config",
type=click.Path(exists=True),
required=True,
help="Path to yaml config training file",
)
@click.option(
"-w",
"--initial-weights",
type=click.Path(exists=True),
help="Path to initial weights",
)
@click.option(
"-n",
"--num-samples",
type=int,
default=3,
help="Number of samples for tuning [default: 3]",
)
@click.option(
"--ray-results-dirpath",
type=click.Path(),
help="Location for ray_results output dir",
)
@click.option(
"--debug-mode",
is_flag=True,
help="Activate debug mode for tuning",
)
def check_model(
data: str,
model: str,
data_config: str,
model_config: str,
initial_weights: str | None,
num_samples: int,
ray_results_dirpath: str | None,
*,
debug_mode: bool,
) -> None:
"""Check model configuration and run initial tests."""
from stimulus.cli.check_model import check_model as check_model_func

\b
Optional Options:
-w, --initial-weights PATH Path to initial weights
-n, --num-samples INTEGER Number of samples for tuning [default: 3]
--ray-results-dirpath PATH Location for ray_results output dir
--debug-mode Activate debug mode for tuning
"""
from stimulus.cli.check_model import main
main()
check_model_func(
data_path=data,
model_path=model,
data_config_path=data_config,
model_config_path=model_config,
initial_weights=initial_weights,
num_samples=num_samples,
ray_results_dirpath=ray_results_dirpath,
debug_mode=debug_mode,
)
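
As a usage sketch, the relocated options can be exercised through click's test runner; the command name check-model and the hypothetical file paths (which must exist, since the options use click.Path(exists=True)) are assumptions:

from click.testing import CliRunner

from stimulus.cli.main import cli

runner = CliRunner()
result = runner.invoke(
    cli,
    [
        "check-model",
        "-d", "data/train.csv",
        "-m", "models/my_model.py",
        "-e", "configs/data.yaml",
        "-c", "configs/model.yaml",
        "-n", "1",
    ],
)
print(result.exit_code, result.output)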