Skip to content

Commit

Permalink
Lstm saving (#8)
Browse files Browse the repository at this point in the history
* Add saving of the LSTM model so it does not need to be retrained in the future.
  • Loading branch information
Jad-yehya authored Jul 22, 2024
1 parent 0203f1f commit 125f312
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
1 change: 1 addition & 0 deletions solvers/AR.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def mean_overlaping_pred(

return averaged_predictions


def run(self, _):

self.model.to(self.device)
Expand Down
11 changes: 6 additions & 5 deletions solvers/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def forward(self, x):
class Solver(BaseSolver):
name = "LSTM"

install_cmd = "pip"
requirements = ["torch", "tqdm"]
install_cmd = "conda"
requirements = ["pip:torch", "tqdm"]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Expand All @@ -56,7 +56,7 @@ class Solver(BaseSolver):
"n_epochs": [50],
"lr": [1e-5],
"window": [True],
"window_size": [128], # window_size = seq_len
"window_size": [256], # window_size = seq_len
"stride": [1],
"percentile": [97],
"encoder_layers": [32],
Expand All @@ -75,8 +75,6 @@ def set_objective(self, X_train, y_test, X_test):
self.n_features = X_train.shape[1]
self.seq_len = self.window_size

print("Simulated data shape: ", X_train.shape, X_test.shape)

self.model = LSTM_Autoencoder(
self.seq_len,
self.n_features,
Expand Down Expand Up @@ -147,6 +145,9 @@ def run(self, _):

ti.set_postfix(train_loss=f"{train_loss:.5f}")

# Saving the model
torch.save(self.model.state_dict(), "model.pth")

self.model.eval()
raw_reconstruction = []
for x in self.test_loader:
Expand Down

0 comments on commit 125f312

Please sign in to comment.