diff --git a/executors/accelerate/src/hypha/accelerate_executor/training.py b/executors/accelerate/src/hypha/accelerate_executor/training.py index 335ec5e9..00b2861c 100644 --- a/executors/accelerate/src/hypha/accelerate_executor/training.py +++ b/executors/accelerate/src/hypha/accelerate_executor/training.py @@ -21,7 +21,7 @@ from opentelemetry.sdk.metrics import MeterProvider from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader from opentelemetry.sdk.resources import OTELResourceDetector, get_aggregated_resources -from safetensors.torch import load_file, save_file, save_model +from safetensors.torch import load_model, save_file, save_model from .api import Session from .dataset import IterableStreamDataSet @@ -274,7 +274,9 @@ def sleep_until_epoch_ms(target_ms: int) -> None: rel_path = parameters.get("path") if parameters else latest.get("path") if isinstance(rel_path, str): path = os.path.join(work_dir, rel_path) - model.load_state_dict(merge_models(previous_model_path, path)) + # As https://github.com/huggingface/safetensors/blob/806426784adb43631e9a1102d4621126bb589347/bindings/python/py_src/safetensors/torch.py#L228C33-L228C48 + # it should be fine to use `strict=False` here. + model.load_state_dict(merge_models(previous_model_path, path), strict=False) save_model(model, previous_model_path) # Once we updated the model, we no longer need the parameter file. @@ -381,7 +383,7 @@ def sleep_until_epoch_ms(target_ms: int) -> None: rel_path = incomming.get("path") if isinstance(rel_path, str): path = os.path.join(work_dir, rel_path) - model.load_state_dict(load_file(path)) + load_model(model, path) os.remove(previous_model_path) shutil.copy(path, previous_model_path) os.remove(path)