Skip to content

Commit

Permalink
fix: mistaken logical code in auto_save_model_if_necessary;
Browse files Browse the repository at this point in the history
  • Loading branch information
WenjieDu committed Apr 25, 2023
1 parent 922bbfb commit c7b6e26
Show file tree
Hide file tree
Showing 12 changed files with 30 additions and 22 deletions.
18 changes: 12 additions & 6 deletions pypots/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,25 +195,31 @@ def save_model(
f'Failed to save the model to "{saving_path}" because of the below error! \n{e}'
)

def auto_save_model_if_necessary(self, saving_name: str = None):
def auto_save_model_if_necessary(
self,
training_finished: bool = True,
saving_name: str = None,
):
"""Automatically save the current model into a file if in need.
Parameters
----------
training_finished : bool, default = False,
Whether the training is already finished when invoke this function.
The saving_strategy "better" only works when training_finished is False.
The saving_strategy "best" only works when training_finished is True.
saving_name : str, default = None,
The file name of the saved model.
"""
if self.saving_path is not None and self.auto_save_model:
name = self.__class__.__name__ if saving_name is None else saving_name
if self.saving_strategy == "best":
if not training_finished and self.saving_strategy == "better":
self.save_model(self.saving_path, name)
else: # self.saving_strategy == "better"
elif training_finished and self.saving_strategy == "best":
self.save_model(self.saving_path, name)

logger.info(
f"Successfully saved the model to {os.path.join(self.saving_path, name)}"
)
else:
return

Expand Down
3 changes: 2 additions & 1 deletion pypots/classification/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,8 @@ def _train_model(
self.patience = self.original_patience
# save the model if necessary
self.auto_save_model_if_necessary(
saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}"
training_finished=False,
saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}",
)
else:
self.patience -= 1
Expand Down
2 changes: 1 addition & 1 deletion pypots/classification/brits.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self.auto_save_model_if_necessary()
self.auto_save_model_if_necessary(training_finished=True)

def classify(self, X: Union[dict, str], file_type: str = "h5py"):
"""Classify the input data with the trained model.
Expand Down
2 changes: 1 addition & 1 deletion pypots/classification/grud.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self.auto_save_model_if_necessary()
self.auto_save_model_if_necessary(training_finished=True)

def classify(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray:
"""Classify the input data with the trained model.
Expand Down
2 changes: 1 addition & 1 deletion pypots/classification/raindrop.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,7 +826,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self.auto_save_model_if_necessary()
self.auto_save_model_if_necessary(training_finished=True)

def classify(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray:
"""Classify the input data with the trained model.
Expand Down
5 changes: 3 additions & 2 deletions pypots/clustering/crli.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,8 @@ def _train_model(
self.patience = self.original_patience
# save the model if necessary
self.auto_save_model_if_necessary(
saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}"
training_finished=False,
saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}",
)
else:
self.patience -= 1
Expand Down Expand Up @@ -596,7 +597,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self.auto_save_model_if_necessary()
self.auto_save_model_if_necessary(training_finished=True)

def cluster(
self,
Expand Down
5 changes: 3 additions & 2 deletions pypots/clustering/vader.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,8 @@ def _train_model(
self.patience = self.original_patience
# save the model if necessary
self.auto_save_model_if_necessary(
saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}"
training_finished=False,
saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}",
)
else:
self.patience -= 1
Expand Down Expand Up @@ -683,7 +684,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self.auto_save_model_if_necessary()
self.auto_save_model_if_necessary(training_finished=True)

def cluster(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray:
"""Cluster the input with the trained model.
Expand Down
3 changes: 2 additions & 1 deletion pypots/imputation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,8 @@ def _train_model(
self.patience = self.original_patience
# save the model if necessary
self.auto_save_model_if_necessary(
saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}"
training_finished=False,
saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}",
)
else:
self.patience -= 1
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/brits.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self.auto_save_model_if_necessary()
self.auto_save_model_if_necessary(training_finished=True)

def impute(
self,
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/saits.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self.auto_save_model_if_necessary()
self.auto_save_model_if_necessary(training_finished=True)

def impute(
self,
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self.auto_save_model_if_necessary()
self.auto_save_model_if_necessary(training_finished=True)

def impute(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray:
"""Impute missing values in the given data with the trained model.
Expand Down
6 changes: 2 additions & 4 deletions pypots/utils/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ def create_dir_if_not_exist(path: str, is_dir: bool = True) -> None:
"""
path = extract_parent_dir(path) if not is_dir else path
if os.path.exists(path):
logger.info(f'The given directory "{path}" exists.')
else:
if not os.path.exists(path):
os.makedirs(path, exist_ok=True)
logger.info(f'Successfully created "{path}".')
logger.info(f'Successfully created the given path "{path}".')

0 comments on commit c7b6e26

Please sign in to comment.