From 50d5e4175740c586e9372ffccfcfccc3beac41c1 Mon Sep 17 00:00:00 2001 From: leoniewgnr <42536262+leoniewgnr@users.noreply.github.com> Date: Thu, 21 Sep 2023 12:39:11 -0700 Subject: [PATCH] [bug] Fix saved model size (#1425) * initial try * tidy up * fixed * reversed test changes * reversed test changes * removed minimal version --- neuralprophet/utils.py | 49 +++++++++++++++++++++++++++++------------- 1 file changed, 34 insertions(+), 15 deletions(-) diff --git a/neuralprophet/utils.py b/neuralprophet/utils.py index 2c198b929..fd209cb56 100644 --- a/neuralprophet/utils.py +++ b/neuralprophet/utils.py @@ -23,31 +23,50 @@ def save(forecaster, path: str): - """save a fitted np model to a disk file. + """Save a fitted Neural Prophet model to disk. - Parameters - ---------- + Parameters: forecaster : np.forecaster.NeuralProphet input forecaster that is fitted path : str path and filename to be saved. filename could be any but suggested to have extension .np. - Examples - -------- + After you fitted a model, you may save the model to save_test_model.np >>> from neuralprophet import save >>> save(forecaster, "test_save_model.np") """ - # Remove the Lightning trainer since it does not serialise correcly with torch.save - attrs_to_remove = ["trainer"] + # List of attributes to remove + attrs_to_remove_forecaster = ["trainer"] + attrs_to_remove_model = ["_trainer"] + + # Store removed attributes temporarily removed_attrs = {} - for attr in attrs_to_remove: - removed_attrs[attr] = getattr(forecaster, attr) - setattr(forecaster, attr, None) - torch.save(forecaster, path) - - # Restore the Lightning trainer - for attr in attrs_to_remove: - setattr(forecaster, attr, removed_attrs[attr]) + + # Remove specified attributes from forecaster + for attr in attrs_to_remove_forecaster: + if hasattr(forecaster, attr): + removed_attrs[attr] = getattr(forecaster, attr) + setattr(forecaster, attr, None) + + # Remove specified attributes from forecaster.model + for attr in attrs_to_remove_model: + if hasattr(forecaster.model, attr): + removed_attrs[attr] = getattr(forecaster.model, attr) + setattr(forecaster.model, attr, None) + + # Perform the save operation + try: + torch.save(forecaster, path) + except Exception as e: + print(f"An error occurred while saving the model: {e}") + raise + finally: + # Restore the removed attributes + for attr, value in removed_attrs.items(): + if hasattr(forecaster, attr): + setattr(forecaster, attr, value) + elif hasattr(forecaster.model, attr): + setattr(forecaster.model, attr, value) def load(path: str):