Skip to content

Commit 0e4cda1

Browse files
committed
parameter recovery checked
1 parent fa56123 commit 0e4cda1

3 files changed

Lines changed: 25 additions & 9 deletions

File tree

spice/resources/rnn.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,6 +605,8 @@ def print_spice_model(self, participant_id: int = 0, ensemble_idx: int = 0) -> N
605605
equation_str = module + "[t+1] = "
606606
for index_term, term in enumerate(self.sindy_library_names[module]):
607607
coeff_value = self.sindy_coefficients[module][participant_id, ensemble_idx, index_term].item()
608+
if term == module:
609+
coeff_value += 1
608610
if np.abs(coeff_value) > 1e-3:
609611
if equation_str[-3:] != " = ":
610612
equation_str += "+ "

spice/resources/spice_training.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torch.nn as nn
55
from torch.utils.data import DataLoader, RandomSampler
66
from typing import Tuple
7+
from collections import defaultdict
78

89
from .rnn import BaseRNN
910
from .spice_utils import SpiceDataset
@@ -130,13 +131,23 @@ def batch_train(
130131
)
131132

132133
# small l2-regularization on logits to keep the absolute values in the smallest possible range (only diff between values is necessary)
133-
loss_step += 0.001 * ys_pred.abs().sum(dim=-1).mean()
134+
# loss_step += 0.01 * torch.pow(ys_pred, 2).mean()
135+
136+
# l2 reg on module outputs
137+
for module in model.submodules_rnn:
138+
input_size_module = model.submodules_rnn[module].linear_in.in_features+1
139+
# loss_step += 0.01 * torch.pow(model.submodules_rnn[module](torch.rand((1, 100, input_size_module))), 2).mean()
140+
loss_step += 0.01 * torch.abs(model.submodules_rnn[module](torch.ones((1, 100, input_size_module)))).mean()
141+
142+
# l2 reg on state values
143+
# for state in model.state:
144+
# loss_step += 0.01 * torch.pow(model.state[state], 2).mean()
134145

135146
# Add SINDy regularization loss
136147
if sindy_weight > 0 and model.sindy_loss != 0:
137148
loss_step = loss_step + sindy_weight * model.sindy_loss
138149

139-
loss_batch += loss_step
150+
loss_batch += loss_step.item()
140151
iterations += 1
141152

142153
if torch.is_grad_enabled():
@@ -146,8 +157,8 @@ def batch_train(
146157
loss_step.backward()
147158
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
148159
optimizer.step()
149-
150-
return model, optimizer, loss_batch.item()/iterations
160+
161+
return model, optimizer, loss_batch/iterations
151162

152163

153164
def fit_sindy_second_stage(
@@ -377,7 +388,7 @@ def fit_model(
377388
dataloader_test = DataLoader(dataset_test, batch_size=len(dataset_test))
378389

379390
# set up learning rate scheduler
380-
warmup_steps = 0
391+
warmup_steps = 500
381392
warmup_steps = warmup_steps if epochs > warmup_steps else 1 #int(epochs * 0.125/16)
382393
if scheduler and optimizer is not None:
383394
# Define the LambdaLR scheduler for warm-up
@@ -475,14 +486,17 @@ def warmup_lr_lambda(current_step):
475486
print("\n"+"="*80)
476487
print(f"SPICE model before {n_calls_to_train_model} epochs:")
477488
print("="*80)
478-
model.print_spice_model(ensemble_idx=4)
489+
model.print_spice_model(ensemble_idx=0)
479490

480491
model.thresholding(threshold=sindy_threshold, base_threshold=0.1, n_terms_cutoff=1)
481492

493+
# TODO: Try optimizer reset for stability
494+
# optimizer.state = defaultdict(dict)
495+
482496
print("\n"+"="*80)
483497
print(f"SPICE model after {n_calls_to_train_model} epochs:")
484498
print("="*80)
485-
model.print_spice_model(ensemble_idx=4)
499+
model.print_spice_model(ensemble_idx=0)
486500

487501
# check for convergence
488502
dloss = last_loss - loss_test if dataset_test is not None else last_loss - loss_train

tutorials/0_data_preparation.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -536,7 +536,7 @@
536536
],
537537
"metadata": {
538538
"kernelspec": {
539-
"display_name": ".venv",
539+
"display_name": "spice",
540540
"language": "python",
541541
"name": "python3"
542542
},
@@ -550,7 +550,7 @@
550550
"name": "python",
551551
"nbconvert_exporter": "python",
552552
"pygments_lexer": "ipython3",
553-
"version": "3.11.5"
553+
"version": "3.11.13"
554554
}
555555
},
556556
"nbformat": 4,

0 commit comments

Comments (0)