Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ def train(
"ignore",
message=".*does not have many workers which may be a bottleneck.*",
)
warnings.filterwarnings(
"ignore",
message="Starting from v1\\.9\\.0, `tensorboardX` has been removed.*",
)
self.model.fit(
output_train=output,
cell_line_input=cell_line_input,
Expand Down
14 changes: 9 additions & 5 deletions drevalpy/models/SimpleNeuralNetwork/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Utility functions for the simple neural network models."""

import os
import secrets
from typing import Any

Expand Down Expand Up @@ -229,11 +230,14 @@ def fit(
monitor = "train_loss" if (val_loader is None) else "val_loss"

early_stop_callback = EarlyStopping(monitor=monitor, mode="min", patience=patience)
name = "version-" + "".join(
[secrets.choice("0123456789abcdef") for i in range(20)]
) # preventing conflicts of filenames

unique_subfolder = os.path.join(model_checkpoint_dir, "run_" + secrets.token_hex(8))
os.makedirs(unique_subfolder, exist_ok=True)

# prevent conflicts
name = "version-" + "".join([secrets.choice("0123456789abcdef") for _ in range(10)])
self.checkpoint_callback = pl.callbacks.ModelCheckpoint(
dirpath=model_checkpoint_dir,
dirpath=unique_subfolder,
monitor=monitor,
mode="min",
save_top_k=1,
Expand Down Expand Up @@ -262,7 +266,7 @@ def fit(

# load best model
if self.checkpoint_callback.best_model_path is not None:
checkpoint = torch.load(self.checkpoint_callback.best_model_path) # noqa: S614
checkpoint = torch.load(self.checkpoint_callback.best_model_path, weights_only=True) # noqa: S614
self.load_state_dict(checkpoint["state_dict"])
else:
print("checkpoint_callback: No best model found, using the last model.")
Expand Down
Loading