diff --git a/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py b/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py index ba0d84fe..ee8765f0 100644 --- a/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py +++ b/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py @@ -94,6 +94,10 @@ def train( "ignore", message=".*does not have many workers which may be a bottleneck.*", ) + warnings.filterwarnings( + "ignore", + message="Starting from v1\\.9\\.0, `tensorboardX` has been removed.*", + ) self.model.fit( output_train=output, cell_line_input=cell_line_input, diff --git a/drevalpy/models/SimpleNeuralNetwork/utils.py b/drevalpy/models/SimpleNeuralNetwork/utils.py index 9209ed7a..0a6c5801 100644 --- a/drevalpy/models/SimpleNeuralNetwork/utils.py +++ b/drevalpy/models/SimpleNeuralNetwork/utils.py @@ -1,5 +1,6 @@ """Utility functions for the simple neural network models.""" +import os import secrets from typing import Any @@ -229,11 +230,14 @@ def fit( monitor = "train_loss" if (val_loader is None) else "val_loss" early_stop_callback = EarlyStopping(monitor=monitor, mode="min", patience=patience) - name = "version-" + "".join( - [secrets.choice("0123456789abcdef") for i in range(20)] - ) # preventing conflicts of filenames + + unique_subfolder = os.path.join(model_checkpoint_dir, "run_" + secrets.token_hex(8)) + os.makedirs(unique_subfolder, exist_ok=True) + + # prevent conflicts + name = "version-" + "".join([secrets.choice("0123456789abcdef") for _ in range(10)]) self.checkpoint_callback = pl.callbacks.ModelCheckpoint( - dirpath=model_checkpoint_dir, + dirpath=unique_subfolder, monitor=monitor, mode="min", save_top_k=1, @@ -262,7 +266,7 @@ def fit( # load best model if self.checkpoint_callback.best_model_path is not None: - checkpoint = torch.load(self.checkpoint_callback.best_model_path) # noqa: S614 + checkpoint = torch.load(self.checkpoint_callback.best_model_path, weights_only=True) # noqa: S614 self.load_state_dict(checkpoint["state_dict"]) else: print("checkpoint_callback: No best model found, using the last model.")