Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove yaml from classes #119

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
82 commits
Select commit Hold shift + click to select a range
d5fb8e8
REPLACE: src/stimulus/typing/__init__ -> changed YamlGlobalConfig to …
Feb 19, 2025
89b5c26
REPLACE: src/stimulus/utils/yaml_data.py -> changed YamlGlobalConfig …
Feb 19, 2025
c6a38e7
REPLACE: tests/typing/test_typing.py -> changed YamlGlobalConfig to G…
Feb 19, 2025
5ba13f0
REPLACE: src/stimulus/typing/__init__.py -> Changed YamlColumnsEncode…
Feb 19, 2025
1863ced
REPLACE: src/stimulus/utils/yaml_data.py -> Changed YamlColumnsEncode…
Feb 19, 2025
e9c40eb
REPLACE: tests/typing/test_typing.py -> Changed YamlColumnsEncoder to…
Feb 19, 2025
f55b7f6
REPLACE: tests/typing/test_typing.py -> Changed YamlColumns to Columns
Feb 19, 2025
11062b4
REPLACE: tests/data/loaders.py -> Changed YamlColumns to Columns
Feb 19, 2025
c370891
REPLACE: tests/typing/__init__.py -> Changed YamlColumns to Columns
Feb 19, 2025
61e5787
REPLACE: src/stimulus/utils/yaml_data.py -> Changed YamlColumns to Co…
Feb 19, 2025
3c45094
REPLACE: src/stimulus/utils/yaml_data.py -> Changed YamlTransformColu…
Feb 19, 2025
ec66cdf
REPLACE: src/stimulus/typing/__init__.py -> Changed YamlTransformColu…
Feb 19, 2025
4746659
REPLACE: tests/data/test_data_handlers.py -> Changed YamlTransformCol…
Feb 19, 2025
1f99221
REPLACE: tests/typing/test_typing.py -> Changed YamlTransformColumnsT…
Feb 19, 2025
3ffe710
REPLACE: src/stimulus/utils/yaml_data.py -> Changed YamlTransformColu…
Feb 19, 2025
5d0cfa2
REPLACE: src/stimulus/typing/__init__.py -> Changed YamlTransformColu…
Feb 19, 2025
edede13
REPLACE: tests/data/test_data_handlers.py -> Changed YamlTransformCol…
Feb 19, 2025
53fe62e
REPLACE: tests/typing/test_typing.py -> Changed YamlTransformColumns …
Feb 19, 2025
fd8f92a
REPLACE: src/stimulus/utils/yaml_data.py -> Changed YamlTransform to …
Feb 19, 2025
93bfe2e
REPLACE: src/stimulus/data/loaders.py -> Changed YamlTransform to Tra…
Feb 19, 2025
0f7da0d
REPLACE: src/stimulus/typing/__init__.py -> Changed YamlTransform to …
Feb 19, 2025
b9f1b3a
REPLACE: tests/data/test_data_handlers.py -> Changed YamlTransform to…
Feb 19, 2025
ac3aa00
REPLACE: tests/typing/test_typing.py -> Changed YamlTransform to Tran…
Feb 19, 2025
b1cc3aa
REPLACE: tests/utils/test_data_yaml.py -> Changed YamlTransform to Tr…
Feb 19, 2025
1c7e996
REPLACE: src/stimulus/cli/check_model.py -> Changed YamlSplit* to Split*
Feb 19, 2025
940bb59
REPLACE: src/stimulus/cli/split_csv.py -> Changed YamlSplit* to Split*
Feb 19, 2025
9e4f368
REPLACE: src/stimulus/cli/split_split.py -> Changed YamlSplit* to Split*
Feb 19, 2025
a4a9e6b
REPLACE: src/stimulus/cli/split_transforms.py -> Changed YamlSplit* t…
Feb 19, 2025
d9baee1
REPLACE: src/stimulus/cli/transform_csv.py -> Changed YamlSplit* to S…
Feb 19, 2025
7a99a54
REPLACE: src/stimulus/cli/tuning.py -> Changed YamlSplit* to Split*
Feb 19, 2025
5bdeaa2
REPLACE: src/stimulus/data/data_handlers.py -> Changed YamlSplit* to …
Feb 19, 2025
843ed2d
REPLACE: src/stimulus/data/handlertorch.py -> Changed YamlSplit* to S…
Feb 19, 2025
ac93ac7
REPLACE: src/stimulus/data/loaders.py -> Changed YamlSplit* to Split*
Feb 19, 2025
ac3b64d
REPLACE: src/stimulus/data/raytune_learner.py -> Changed YamlSplit* t…
Feb 19, 2025
cb49240
REPLACE: src/stimulus/typing/__init.py -> Changed YamlSplit* to Split*
Feb 19, 2025
6aff6b5
REPLACE: src/stimulus/utils/yaml_data.py -> Changed YamlSplit* to Split*
Feb 19, 2025
1f33072
REPLACE: tests/cli/test_check_model.py -> Changed YamlSplit* to Split*
Feb 19, 2025
d661d30
REPLACE: tests/cli/test_shuffle_csv.py -> Changed YamlSplit* to Split*
Feb 19, 2025
a4f0595
REPLACE: tests/cli/test_tuning.py -> Changed YamlSplit* to Split*
Feb 19, 2025
01e7102
REPLACE: tests/data/test_data_handlers.py -> Changed YamlSplit* to Sp…
Feb 19, 2025
60d1467
REPLACE: tests/data/test_handlertorch.py -> Changed YamlSplit* to Split*
Feb 19, 2025
f12638a
REPLACE: tests/learner/test_raytune_learner.py -> Changed YamlSplit* …
Feb 19, 2025
8d7030d
REPLACE: tests/typing/test_typing.py -> Changed YamlSplit* to Split*
Feb 19, 2025
721b7fa
REPLACE: tests/utils/test_data_yaml.py -> Changed YamlSplit* to Split*
Feb 19, 2025
116a3f5
REPLACE: src/stimulus/utils/yaml_data.py -> Changed YamlConfigDict to…
Feb 19, 2025
d75b82f
REPLACE: src/stimulus/cli/split_split.py -> Changed YamlConfigDict to…
Feb 19, 2025
b13bd9f
REPLACE: src/stimulus/typing/__init__.py -> Changed YamlConfigDict to…
Feb 19, 2025
11978da
REPLACE: tests/data/test_data_handlers.py -> Changed YamlConfigDict t…
Feb 19, 2025
6f3c4e9
REPLACE: tests/data/test_experiment.py -> Changed YamlConfigDict to C…
Feb 19, 2025
b22b124
REPLACE: tests/typing/test_typing.py -> Changed YamlConfigDict to Con…
Feb 19, 2025
fc40247
REPLACE: tests/utils/test_data_yaml.py -> Changed YamlConfigDict to C…
Feb 19, 2025
306750d
REPLACE: src/stimulus/utils/yaml_data.py -> Changed YamlSchema to Schema
Feb 19, 2025
4a5a1e3
REPLACE: src/stimulus/typing/__init__.py -> Changed YamlSchema to Schema
Feb 19, 2025
ff2c7f2
REPLACE: tests/typing/test_typing.py -> Changed YamlSchema to Schema
Feb 19, 2025
fff0bb0
CHANGE: src/stimulus/yaml_data.py -> Changed some variables to remove…
Feb 19, 2025
982e83f
CHANGE: src/stimulus/yaml_data.py -> Changed some comments to remove
Feb 19, 2025
df5d4db
CHANGE: src/stimulus/yaml_data.py -> Changed function check_yaml_sche…
Feb 19, 2025
36978bb
CHANGE: src/stimulus/yaml_data.py -> Removed Yaml in classes in comments
Feb 19, 2025
e809194
CHANGE: src/stimulus/yaml_data.py -> Changed left variables to remove…
Feb 19, 2025
3027d37
FORMATTING: src/stimulus/cli/check_model.py
Feb 19, 2025
5418c00
FORMATTING: src/stimulus/cli/split_csv.py
Feb 19, 2025
7dea223
FORMATTING: src/stimulus/cli/split_transforms.py
Feb 19, 2025
129bd04
FORMATTING: src/stimulus/cli/transform_csv.py
Feb 19, 2025
9f583c1
FORMATTING: src/stimulus/cli/tuning.py
Feb 19, 2025
d52aae5
FORMATTING: src/stimulus/data/data_handlers.py
Feb 19, 2025
1d096fe
FORMATTING: src/stimulus/data/loaders.py
Feb 19, 2025
5c263a0
FORMATTING: src/stimulus/learner/raytune_learner.py
Feb 19, 2025
f377f07
FORMATTING: src/stimulus/typing/__init__.py
Feb 19, 2025
c5dbac0
FORMATTING: src/stimulus/utils/yaml_data.py
Feb 19, 2025
34191e2
FORMATTING: tests/cli/test_check_model.py
Feb 19, 2025
7f973d4
FORMATTING: tests/cli/test_shuffle_csv.py
Feb 19, 2025
8aafd46
FORMATTING: tests/cli/test_split_split.py
Feb 19, 2025
f97bc24
FORMATTING: tests/cli/test_split_transforms.py
Feb 19, 2025
e9c9561
FORMATTING: tests/cli/test_tuning.py
Feb 19, 2025
b80560a
FORMATTING: tests/data/test_data_handlers.py
Feb 19, 2025
40cb045
FORMATTING: tests/data/test_experiment.py
Feb 19, 2025
d1991ef
FORMATTING: tests/data/transform/test_data_transformers.py
Feb 19, 2025
5306445
FORMATTING: tests/learmes/test_raytune_learner.py
Feb 19, 2025
d7d5656
FORMATTING: tests/utils/test_data_yaml.py
Feb 19, 2025
26f8b3a
FIX: references to old yaml functions
Feb 19, 2025
abc4157
FIX: multiple files -> fix merge conflict
Feb 19, 2025
216298d
FORMATTING: formatting some files
Feb 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/stimulus/cli/check_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def main(
"""
with open(data_config_path) as file:
data_config = yaml.safe_load(file)
data_config = yaml_data.YamlSplitTransformDict(**data_config)
data_config = yaml_data.SplitTransformDict(**data_config)

with open(model_config_path) as file:
model_config = yaml.safe_load(file)
Expand Down
4 changes: 2 additions & 2 deletions src/stimulus/cli/shuffle_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import yaml

from stimulus.data.data_handlers import DatasetProcessor
from stimulus.utils.yaml_data import YamlSplitTransformDict
from stimulus.utils.yaml_data import SplitTransformDict


def get_args() -> argparse.Namespace:
Expand Down Expand Up @@ -56,7 +56,7 @@ def main(data_csv: str, config_yaml: str, out_path: str) -> None:
"""
# read the yaml file
with open(config_yaml) as f:
data_config: YamlSplitTransformDict = YamlSplitTransformDict(
data_config: SplitTransformDict = SplitTransformDict(
**yaml.safe_load(f),
)
# create a DatasetProcessor object from the config and the csv
Expand Down
4 changes: 2 additions & 2 deletions src/stimulus/cli/split_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from stimulus.data.data_handlers import DatasetProcessor, SplitManager
from stimulus.data.loaders import SplitLoader
from stimulus.utils.yaml_data import YamlSplitConfigDict
from stimulus.utils.yaml_data import SplitConfigDict


def get_args() -> argparse.Namespace:
Expand Down Expand Up @@ -70,7 +70,7 @@ def main(
# create a split manager from the config
split_config = processor.dataset_manager.config.split
with open(config_yaml) as f:
yaml_config = YamlSplitConfigDict(**yaml.safe_load(f))
yaml_config = SplitConfigDict(**yaml.safe_load(f))
split_loader = SplitLoader(seed=yaml_config.global_params.seed)
split_loader.initialize_splitter_from_config(split_config)
split_manager = SplitManager(split_loader)
Expand Down
12 changes: 6 additions & 6 deletions src/stimulus/cli/split_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
import yaml

from stimulus.utils.yaml_data import (
YamlConfigDict,
YamlSplitConfigDict,
check_yaml_schema,
ConfigDict,
SplitConfigDict,
check_schema,
dump_yaml_list_into_files,
generate_split_configs,
)
Expand Down Expand Up @@ -64,13 +64,13 @@ def main(config_yaml: str, out_dir_path: str) -> None:
with open(config_yaml) as conf_file:
yaml_config = yaml.safe_load(conf_file)

yaml_config_dict: YamlConfigDict = YamlConfigDict(**yaml_config)
yaml_config_dict: ConfigDict = ConfigDict(**yaml_config)
# check if the yaml schema is correct
# FIXME: isn't it redundant to check and already class with pydantic ?
check_yaml_schema(yaml_config_dict)
check_schema(yaml_config_dict)

# generate the yaml files per split
split_configs: list[YamlSplitConfigDict] = generate_split_configs(yaml_config_dict)
split_configs: list[SplitConfigDict] = generate_split_configs(yaml_config_dict)

# dump all the YAML configs into files
dump_yaml_list_into_files(split_configs, out_dir_path, "test_split")
Expand Down
8 changes: 4 additions & 4 deletions src/stimulus/cli/split_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
import yaml

from stimulus.utils.yaml_data import (
YamlSplitConfigDict,
YamlSplitTransformDict,
SplitConfigDict,
SplitTransformDict,
dump_yaml_list_into_files,
generate_split_transform_configs,
)
Expand Down Expand Up @@ -60,10 +60,10 @@ def main(config_yaml: str, out_dir_path: str) -> None:
with open(config_yaml) as conf_file:
yaml_config = yaml.safe_load(conf_file)

yaml_config_dict: YamlSplitConfigDict = YamlSplitConfigDict(**yaml_config)
yaml_config_dict: SplitConfigDict = SplitConfigDict(**yaml_config)

# Generate the yaml files for each transform
split_transform_configs: list[YamlSplitTransformDict] = generate_split_transform_configs(yaml_config_dict)
split_transform_configs: list[SplitTransformDict] = generate_split_transform_configs(yaml_config_dict)

# Dump all the YAML configs into files
dump_yaml_list_into_files(split_transform_configs, out_dir_path, "test_transforms")
Expand Down
4 changes: 2 additions & 2 deletions src/stimulus/cli/transform_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from stimulus.data.data_handlers import DatasetProcessor, TransformManager
from stimulus.data.loaders import TransformLoader
from stimulus.utils.yaml_data import YamlSplitConfigDict
from stimulus.utils.yaml_data import SplitConfigDict


def get_args() -> argparse.Namespace:
Expand Down Expand Up @@ -55,7 +55,7 @@ def main(data_csv: str, config_yaml: str, out_path: str) -> None:
# initialize the transform manager
transform_config = processor.dataset_manager.config.transforms
with open(config_yaml) as f:
yaml_config = YamlSplitConfigDict(**yaml.safe_load(f))
yaml_config = SplitConfigDict(**yaml.safe_load(f))
transform_loader = TransformLoader(seed=yaml_config.global_params.seed)
transform_loader.initialize_column_data_transformers_from_config(transform_config)
transform_manager = TransformManager(transform_loader)
Expand Down
10 changes: 5 additions & 5 deletions src/stimulus/cli/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def get_args() -> argparse.Namespace:
def main(
model_path: str,
data_path: str,
data_config: yaml_data.YamlSplitTransformDict,
data_config: yaml_data.SplitTransformDict,
model_config_path: str,
initial_weights: str | None = None, # noqa: ARG001
ray_results_dirpath: str | None = None,
Expand All @@ -167,7 +167,7 @@ def main(
Args:
data_path: Path to input data file.
model_path: Path to model file.
data_config: A YamlSplitTransformObject
data_config: A SplitTransformObject
model_config_path: Path to model config file.
initial_weights: Optional path to initial weights.
ray_results_dirpath: Directory for ray results.
Expand Down Expand Up @@ -239,10 +239,10 @@ def run() -> None:
"""Run the model checking script."""
ray.init(address="auto", ignore_reinit_error=True)
args = get_args()
# Try to convert the configuration file to a YamlSplitTransformDict
config_dict: yaml_data.YamlSplitTransformDict
# Try to convert the configuration file to a SplitTransformDict
config_dict: yaml_data.SplitTransformDict
with open(args.data_config) as f:
config_dict = yaml_data.YamlSplitTransformDict(**yaml.safe_load(f))
config_dict = yaml_data.SplitTransformDict(**yaml.safe_load(f))
main(
data_path=args.data,
model_path=args.model,
Expand Down
22 changes: 13 additions & 9 deletions src/stimulus/data/data_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ class DatasetManager:

def __init__(
self,
config_dict: yaml_data.YamlSplitConfigDict,
config_dict: yaml_data.SplitConfigDict,
) -> None:
"""Initialize the DatasetManager."""
# self.config = self._load_config(config_path)
self.config: yaml_data.YamlSplitTransformDict = config_dict
self.config: yaml_data.SplitTransformDict = config_dict
self.column_categories = self.categorize_columns_by_type()

def categorize_columns_by_type(self) -> dict:
Expand Down Expand Up @@ -95,7 +95,7 @@ def categorize_columns_by_type(self) -> dict:
return {"input": input_columns, "label": label_columns, "meta": meta_columns}

# TODO: Remove or change this function as the config is now preloaded
def _load_config(self, config_path: str) -> yaml_data.YamlSplitConfigDict:
def _load_config(self, config_path: str) -> yaml_data.SplitConfigDict:
"""Loads and parses a YAML configuration file.

Args:
Expand All @@ -112,8 +112,8 @@ def _load_config(self, config_path: str) -> yaml_data.YamlSplitConfigDict:
"""
with open(config_path) as file:
# FIXME: cette fonction est appellé pour test_shuffle_csv et test_tune
return yaml_data.YamlSplitConfigDict(**yaml.safe_load(file))
return yaml_data.YamlSplitTransformDict(**yaml.safe_load(file))
return yaml_data.SplitConfigDict(**yaml.safe_load(file))
return yaml_data.SplitTransformDict(**yaml.safe_load(file))

def get_split_columns(self) -> list[str]:
"""Get the columns that are used for splitting."""
Expand Down Expand Up @@ -282,13 +282,13 @@ class DatasetHandler:

def __init__(
self,
data_config: yaml_data.YamlSplitTransformDict,
data_config: yaml_data.SplitTransformDict,
csv_path: str,
) -> None:
"""Initialize the DatasetHandler with required config.

Args:
data_config (yaml_data.YamlSplitTransformDict): A YamlSplitTransformDict object holding the config.
data_config (yaml_data.SplitTransformDict): A SplitTransformDict object holding the config.
csv_path (str): Path to the CSV data file.
"""
self.dataset_manager = DatasetManager(data_config)
Expand Down Expand Up @@ -343,7 +343,11 @@ def save(self, path: str) -> None:
class DatasetProcessor(DatasetHandler):
"""Class for loading dataset, applying transformations and splitting."""

def __init__(self, data_config: yaml_data.YamlSplitTransformDict, csv_path: str) -> None:
def __init__(
self,
data_config: yaml_data.SplitTransformDict,
csv_path: str,
) -> None:
"""Initialize the DatasetProcessor."""
super().__init__(data_config, csv_path)

Expand Down Expand Up @@ -416,7 +420,7 @@ class DatasetLoader(DatasetHandler):

def __init__(
self,
data_config: yaml_data.YamlSplitTransformDict,
data_config: yaml_data.SplitTransformDict,
csv_path: str,
encoder_loader: loaders.EncoderLoader,
split: Union[int, None] = None,
Expand Down
6 changes: 3 additions & 3 deletions src/stimulus/data/handlertorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,23 @@
from torch.utils.data import Dataset

from stimulus.data import data_handlers, loaders
from stimulus.utils.yaml_data import YamlSplitTransformDict
from stimulus.utils.yaml_data import SplitTransformDict


class TorchDataset(Dataset):
"""Class for creating a torch dataset."""

def __init__(
self,
data_config: YamlSplitTransformDict,
data_config: SplitTransformDict,
csv_path: str,
encoder_loader: loaders.EncoderLoader,
split: Optional[int] = None,
) -> None:
"""Initialize the TorchDataset.

Args:
data_config: A YamlSplitTransformDict holding the configuration.
data_config: A SplitTransformDict holding the configuration.
csv_path: Path to the CSV data file
encoder_loader: Encoder loader instance
split: Optional tuple containing split information
Expand Down
12 changes: 6 additions & 6 deletions src/stimulus/data/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ def __init__(self, seed: Optional[float] = None) -> None:

def initialize_column_encoders_from_config(
self,
column_config: yaml_data.YamlColumns,
column_config: yaml_data.Columns,
) -> None:
"""Build the loader from a config dictionary.

Args:
column_config (yaml_data.YamlColumns): Configuration dictionary containing field names (column_name) and their encoder specifications.
column_config (yaml_data.Columns): Configuration dictionary containing field names (column_name) and their encoder specifications.
"""
for field in column_config:
encoder = self.get_encoder(field.encoder[0].name, field.encoder[0].params)
Expand Down Expand Up @@ -176,12 +176,12 @@ def set_data_transformer_as_attribute(

def initialize_column_data_transformers_from_config(
self,
transform_config: yaml_data.YamlTransform,
transform_config: yaml_data.Transform,
) -> None:
"""Build the loader from a config dictionary.

Args:
transform_config (yaml_data.YamlTransform): Configuration dictionary containing transforms configurations.
transform_config (yaml_data.Transform): Configuration dictionary containing transforms configurations.

Example:
Given a YAML config like:
Expand Down Expand Up @@ -282,12 +282,12 @@ def set_splitter_as_attribute(self, splitter: Any) -> None:

def initialize_splitter_from_config(
self,
split_config: yaml_data.YamlSplitConfigDict,
split_config: yaml_data.SplitConfigDict,
) -> None:
"""Build the loader from a config dictionary.

Args:
split_config (yaml_data.YamlSplitConfigDict): Configuration dictionary containing split configurations.
split_config (yaml_data.SplitConfigDict): Configuration dictionary containing split configurations.
"""
splitter = self.get_splitter(split_config.split_method, split_config.params)
self.set_splitter_as_attribute(splitter)
23 changes: 16 additions & 7 deletions src/stimulus/learner/raytune_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from stimulus.data.loaders import EncoderLoader
from stimulus.learner.predict import PredictWrapper
from stimulus.utils.generic_utils import set_general_seeds
from stimulus.utils.yaml_data import YamlSplitTransformDict
from stimulus.utils.yaml_data import SplitTransformDict
from stimulus.utils.yaml_model_schema import RayTuneModel


Expand All @@ -36,7 +36,7 @@ class TuneWrapper:
def __init__(
self,
model_config: RayTuneModel,
data_config: YamlSplitTransformDict,
data_config: SplitTransformDict,
model_class: nn.Module,
data_path: str,
encoder_loader: EncoderLoader,
Expand Down Expand Up @@ -119,7 +119,7 @@ def __init__(

def tuner_initialization(
self,
data_config: YamlSplitTransformDict,
data_config: SplitTransformDict,
data_path: str,
encoder_loader: EncoderLoader,
*,
Expand Down Expand Up @@ -318,7 +318,8 @@ def objective(self) -> dict[str, float]:
**{"train_" + metric: value for metric, value in predict_train.compute_metrics(metrics).items()},
}

def export_model(self, export_dir: str | None = None) -> None: # type: ignore[override]
# type: ignore[override]
def export_model(self, export_dir: str | None = None) -> None:
"""Export model to safetensors format."""
if export_dir is None:
return
Expand All @@ -329,11 +330,19 @@ def load_checkpoint(self, checkpoint: dict[Any, Any] | None) -> None:
if checkpoint is None:
return
checkpoint_dir = checkpoint["checkpoint_dir"]
self.model = safe_load_model(self.model, os.path.join(checkpoint_dir, "model.safetensors"))
self.optimizer.load_state_dict(torch.load(os.path.join(checkpoint_dir, "optimizer.pt")))
self.model = safe_load_model(
self.model,
os.path.join(checkpoint_dir, "model.safetensors"),
)
self.optimizer.load_state_dict(
torch.load(os.path.join(checkpoint_dir, "optimizer.pt")),
)

def save_checkpoint(self, checkpoint_dir: str) -> dict[Any, Any]:
"""Save model and optimizer state to checkpoint."""
safe_save_model(self.model, os.path.join(checkpoint_dir, "model.safetensors"))
torch.save(self.optimizer.state_dict(), os.path.join(checkpoint_dir, "optimizer.pt"))
torch.save(
self.optimizer.state_dict(),
os.path.join(checkpoint_dir, "optimizer.pt"),
)
return {"checkpoint_dir": checkpoint_dir}
Loading
Loading