 import torch.nn as nn
 from torch.utils.data import DataLoader, RandomSampler
 from typing import Tuple
+from collections import defaultdict
 
 from .rnn import BaseRNN
 from .spice_utils import SpiceDataset
@@ -130,13 +131,23 @@ def batch_train(
         )
 
         # small l2-regularization on logits to keep the absolute values in the smallest possible range (only the difference between values matters)
-        loss_step += 0.001 * ys_pred.abs().sum(dim=-1).mean()
+        # loss_step += 0.01 * torch.pow(ys_pred, 2).mean()
+
+        # l2 reg on module outputs
+        for module in model.submodules_rnn:
+            input_size_module = model.submodules_rnn[module].linear_in.in_features + 1
+            # loss_step += 0.01 * torch.pow(model.submodules_rnn[module](torch.rand((1, 100, input_size_module))), 2).mean()
+            loss_step += 0.01 * torch.abs(model.submodules_rnn[module](torch.ones((1, 100, input_size_module)))).mean()
+
+        # l2 reg on state values
+        # for state in model.state:
+        #     loss_step += 0.01 * torch.pow(model.state[state], 2).mean()
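+        # note: this presumably penalizes the mean absolute response of each submodule to a
+        # fixed all-ones probe input, keeping the discovered equations' output scale small
+        # instead of regularizing the batch logits directly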
 
         # Add SINDy regularization loss
         if sindy_weight > 0 and model.sindy_loss != 0:
             loss_step = loss_step + sindy_weight * model.sindy_loss
 
-        loss_batch += loss_step
+        loss_batch += loss_step.item()
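+        # .item() keeps loss_batch as a plain float detached from the autograd graph;
+        # gradients still flow through loss_step.backward() below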
         iterations += 1
 
         if torch.is_grad_enabled():
@@ -146,8 +157,8 @@ def batch_train(
             loss_step.backward()
             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
             optimizer.step()
-
-    return model, optimizer, loss_batch.item() / iterations
+
+    return model, optimizer, loss_batch / iterations
 
 
 def fit_sindy_second_stage(
@@ -377,7 +388,7 @@ def fit_model(
         dataloader_test = DataLoader(dataset_test, batch_size=len(dataset_test))
 
     # set up learning rate scheduler
-    warmup_steps = 0
+    warmup_steps = 500
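+    # note: enables LR warm-up; warmup_lr_lambda below presumably ramps the learning
+    # rate over these first 500 optimizer steps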
     warmup_steps = warmup_steps if epochs > warmup_steps else 1  # int(epochs * 0.125/16)
     if scheduler and optimizer is not None:
         # Define the LambdaLR scheduler for warm-up
@@ -475,14 +486,17 @@ def warmup_lr_lambda(current_step):
             print("\n" + "=" * 80)
             print(f"SPICE model before {n_calls_to_train_model} epochs:")
             print("=" * 80)
-            model.print_spice_model(ensemble_idx=4)
+            model.print_spice_model(ensemble_idx=0)
 
             model.thresholding(threshold=sindy_threshold, base_threshold=0.1, n_terms_cutoff=1)
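+            # assumption: thresholding() prunes SINDy library terms whose coefficients fall
+            # below the threshold, so the print-outs before/after show which terms survive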
 
+            # TODO: Try optimizer reset for stability
+            # optimizer.state = defaultdict(dict)
+
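+            # note: clearing optimizer.state (a defaultdict) would reset per-parameter
+            # statistics such as Adam's running moment estimates after thresholding, which
+            # is why `from collections import defaultdict` was added above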
             print("\n" + "=" * 80)
             print(f"SPICE model after {n_calls_to_train_model} epochs:")
             print("=" * 80)
-            model.print_spice_model(ensemble_idx=4)
+            model.print_spice_model(ensemble_idx=0)
 
             # check for convergence
             dloss = last_loss - loss_test if dataset_test is not None else last_loss - loss_train