Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Major cli refactor #121

Draft
wants to merge 4 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,12 @@ dependencies = [
"scipy==1.14.1",
"syrupy>=4.8.0",
"torch>=2.2.2",
"torch==2.2.2; sys_platform == 'darwin' and platform_machine == 'x86_64'"
"torch==2.2.2; sys_platform == 'darwin' and platform_machine == 'x86_64'",
"click>=8.1.0"
]

[project.scripts]
stimulus-shuffle-csv = "stimulus.cli.shuffle_csv:run"
stimulus-transform-csv = "stimulus.cli.transform_csv:run"
stimulus-split-csv = "stimulus.cli.split_csv:run"
stimulus-check-model = "stimulus.cli.check_model:run"
stimulus-tuning = "stimulus.cli.tuning:run"
stimulus-split-yaml-split = "stimulus.cli.split_split:run"
stimulus-split-yaml-transform = "stimulus.cli.split_transforms:run"
stimulus = "stimulus.cli.main:cli"

[project.urls]
Homepage = "https://mathysgrapotte.github.io/stimulus-py"
Expand Down
159 changes: 36 additions & 123 deletions src/stimulus/cli/check_model.py
Original file line number Diff line number Diff line change
@@ -1,104 +1,51 @@
#!/usr/bin/env python3
"""CLI module for checking model configuration and running initial tests."""

import argparse
import logging

import ray
import yaml
from torch.utils.data import DataLoader

from stimulus.data import handlertorch, loaders
from stimulus.data import data_handlers
from stimulus.data.interface import data_config_parser
from stimulus.learner import raytune_learner
from stimulus.utils import model_file_interface, yaml_data, yaml_model_schema
from stimulus.utils import model_file_interface, yaml_model_schema

logger = logging.getLogger(__name__)


def get_args() -> argparse.Namespace:
"""Get the arguments when using from the commandline.
def load_data_config_from_path(data_path: str, data_config_path: str, split: int) -> data_handlers.TorchDataset:
"""Load the data config from a path.

Args:
    data_path: Path to the input CSV file.
    data_config_path: Path to the data config file.
    split: Index of the split to load (0 = train, 1 = validation).

Returns:
    A ``TorchDataset`` wrapping the CSV at ``data_path`` for the requested split,
    configured with the encoders and column roles parsed from the data config.
"""
parser = argparse.ArgumentParser(description="Launch check_model.")
parser.add_argument(
"-d",
"--data",
type=str,
required=True,
metavar="FILE",
help="Path to input csv file.",
)
parser.add_argument(
"-m",
"--model",
type=str,
required=True,
metavar="FILE",
help="Path to model file.",
)
parser.add_argument(
"-e",
"--data_config",
type=str,
required=True,
metavar="FILE",
help="Path to data config file.",
)
parser.add_argument(
"-c",
"--model_config",
type=str,
required=True,
metavar="FILE",
help="Path to yaml config training file.",
)
parser.add_argument(
"-w",
"--initial_weights",
type=str,
required=False,
nargs="?",
const=None,
default=None,
metavar="FILE",
help="The path to the initial weights (optional).",
)
with open(data_config_path) as file:
data_config_dict = yaml.safe_load(file)
data_config_obj = data_config_parser.SplitTransformDict(**data_config_dict)

parser.add_argument(
"-n",
"--num_samples",
type=int,
required=False,
nargs="?",
const=3,
default=3,
metavar="NUM_SAMPLES",
help="Number of samples for tuning. Overwrites tune.tune_params.num_samples in config.",
)
parser.add_argument(
"--ray_results_dirpath",
type=str,
required=False,
nargs="?",
const=None,
default=None,
metavar="DIR_PATH",
help="Location where ray_results output dir should be written. If None, uses ~/ray_results.",
)
parser.add_argument(
"--debug_mode",
action="store_true",
help="Activate debug mode for tuning. Default false, no debug.",
encoders, input_columns, label_columns, meta_columns = data_config_parser.parse_split_transform_config(
data_config_obj,
)

return parser.parse_args()
return data_handlers.TorchDataset(
loader=data_handlers.DatasetLoader(
encoders=encoders,
input_columns=input_columns,
label_columns=label_columns,
meta_columns=meta_columns,
csv_path=data_path,
split=split,
),
)


def main(
model_path: str,
def check_model(
data_path: str,
model_path: str,
data_config_path: str,
model_config_path: str,
initial_weights: str | None = None, # noqa: ARG001
Expand All @@ -119,26 +66,19 @@ def main(
ray_results_dirpath: Directory for ray results.
debug_mode: Whether to run in debug mode.
"""
with open(data_config_path) as file:
data_config = yaml.safe_load(file)
data_config = yaml_data.YamlSplitTransformDict(**data_config)

with open(model_config_path) as file:
model_config = yaml.safe_load(file)
model_config = yaml_model_schema.Model(**model_config)

encoder_loader = loaders.EncoderLoader()
encoder_loader.initialize_column_encoders_from_config(
column_config=data_config.columns,
)

train_dataset = load_data_config_from_path(data_path, data_config_path, split=0)
validation_dataset = load_data_config_from_path(data_path, data_config_path, split=1)
logger.info("Dataset loaded successfully.")

model_class = model_file_interface.import_class_from_file(model_path)

logger.info("Model class loaded successfully.")

ray_config_loader = yaml_model_schema.YamlRayConfigLoader(model=model_config)
with open(model_config_path) as file:
model_config_content = yaml.safe_load(file)
model_config = yaml_model_schema.Model(**model_config_content)

ray_config_loader = yaml_model_schema.RayConfigLoader(model=model_config)
ray_config_dict = ray_config_loader.get_config().model_dump()
ray_config_model = ray_config_loader.get_config()

Expand All @@ -155,13 +95,7 @@ def main(

logger.info("Model instance loaded successfully.")

torch_dataset = handlertorch.TorchDataset(
data_config=data_config,
csv_path=data_path,
encoder_loader=encoder_loader,
)

torch_dataloader = DataLoader(torch_dataset, batch_size=10, shuffle=True)
torch_dataloader = DataLoader(train_dataset, batch_size=10, shuffle=True)

logger.info("Torch dataloader loaded successfully.")

Expand All @@ -187,10 +121,9 @@ def main(

tuner = raytune_learner.TuneWrapper(
model_config=ray_config_model,
data_config=data_config,
model_class=model_class,
data_path=data_path,
encoder_loader=encoder_loader,
train_dataset=train_dataset,
validation_dataset=validation_dataset,
seed=42,
ray_results_dir=ray_results_dirpath,
debug=debug_mode,
Expand All @@ -202,23 +135,3 @@ def main(

logger.info("Tuning completed successfully.")
logger.info("Checks complete")


def run() -> None:
    """Command-line entry point for the model checking script.

    Connects to Ray, parses the CLI arguments via ``get_args``, and forwards
    them to ``main``.

    NOTE(review): this PR removes this argparse entry point in favor of the
    unified ``stimulus`` click CLI declared in pyproject.toml.
    """
    # Attach to an already-running Ray cluster ("auto"); ignore_reinit_error
    # makes this a no-op if Ray was initialized earlier in the process.
    ray.init(address="auto", ignore_reinit_error=True)
    args = get_args()
    main(
        data_path=args.data,
        model_path=args.model,
        data_config_path=args.data_config,
        model_config_path=args.model_config,
        initial_weights=args.initial_weights,
        num_samples=args.num_samples,
        ray_results_dirpath=args.ray_results_dirpath,
        debug_mode=args.debug_mode,
    )


if __name__ == "__main__":
    run()
Loading
Loading