Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.sdk.resources import OTELResourceDetector, get_aggregated_resources
from safetensors.torch import load_file, save_file, save_model
from safetensors.torch import load_model, save_file, save_model

from .api import Session
from .dataset import IterableStreamDataSet
Expand Down Expand Up @@ -274,7 +274,9 @@ def sleep_until_epoch_ms(target_ms: int) -> None:
rel_path = parameters.get("path") if parameters else latest.get("path")
if isinstance(rel_path, str):
path = os.path.join(work_dir, rel_path)
model.load_state_dict(merge_models(previous_model_path, path))
# As https://github.com/huggingface/safetensors/blob/806426784adb43631e9a1102d4621126bb589347/bindings/python/py_src/safetensors/torch.py#L228C33-L228C48
# it should be fine to use `strict=False` here.
model.load_state_dict(merge_models(previous_model_path, path), strict=False)
save_model(model, previous_model_path)

# Once we updated the model, we no longer need the parameter file.
Expand Down Expand Up @@ -381,7 +383,7 @@ def sleep_until_epoch_ms(target_ms: int) -> None:
rel_path = incomming.get("path")
if isinstance(rel_path, str):
path = os.path.join(work_dir, rel_path)
model.load_state_dict(load_file(path))
load_model(model, path)
os.remove(previous_model_path)
shutil.copy(path, previous_model_path)
os.remove(path)
Expand Down