diff --git a/neuralprophet/utils.py b/neuralprophet/utils.py index 2c198b929..fd209cb56 100644 --- a/neuralprophet/utils.py +++ b/neuralprophet/utils.py @@ -23,31 +23,50 @@ def save(forecaster, path: str): - """save a fitted np model to a disk file. + """Save a fitted Neural Prophet model to disk. - Parameters - ---------- + Parameters: forecaster : np.forecaster.NeuralProphet input forecaster that is fitted path : str path and filename to be saved. filename could be any but suggested to have extension .np. - Examples - -------- + After you fitted a model, you may save the model to save_test_model.np >>> from neuralprophet import save >>> save(forecaster, "test_save_model.np") """ - # Remove the Lightning trainer since it does not serialise correcly with torch.save - attrs_to_remove = ["trainer"] + # List of attributes to remove + attrs_to_remove_forecaster = ["trainer"] + attrs_to_remove_model = ["_trainer"] + + # Store removed attributes temporarily removed_attrs = {} - for attr in attrs_to_remove: - removed_attrs[attr] = getattr(forecaster, attr) - setattr(forecaster, attr, None) - torch.save(forecaster, path) - - # Restore the Lightning trainer - for attr in attrs_to_remove: - setattr(forecaster, attr, removed_attrs[attr]) + + # Remove specified attributes from forecaster + for attr in attrs_to_remove_forecaster: + if hasattr(forecaster, attr): + removed_attrs[attr] = getattr(forecaster, attr) + setattr(forecaster, attr, None) + + # Remove specified attributes from forecaster.model + for attr in attrs_to_remove_model: + if hasattr(forecaster.model, attr): + removed_attrs[attr] = getattr(forecaster.model, attr) + setattr(forecaster.model, attr, None) + + # Perform the save operation + try: + torch.save(forecaster, path) + except Exception as e: + print(f"An error occurred while saving the model: {e}") + raise + finally: + # Restore the removed attributes + for attr, value in removed_attrs.items(): + if hasattr(forecaster, attr): + setattr(forecaster, attr, value) + elif hasattr(forecaster.model, attr): + setattr(forecaster.model, attr, value) def load(path: str):