Skip to content

Commit e4215ff

Browse files
committed
Update rnnode.py
1 parent 98c1806 commit e4215ff

File tree

1 file changed

+8
-25
lines changed

1 file changed

+8
-25
lines changed

doc/src/week46/programs/rnnode.py

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,16 @@
55
import matplotlib.pyplot as plt
66

77

8-
# ============================
9-
# 1. Physics model (from rk4.py)
10-
# ============================
8+
# Newton's equation for harmonic oscillations with external force
119

12-
# Global parameters (same idea as in rk4.py)
10+
# Global parameters
1311
gamma = 0.2 # damping
1412
Omegatilde = 0.5 # driving frequency
1513
Ftilde = 1.0 # driving amplitude
1614

1715
def spring_force(v, x, t):
1816
"""
19-
SpringForce from rk4.py:
17+
SpringForce:
2018
note: divided by mass => returns acceleration
2119
a = -2*gamma*v - x + Ftilde*cos(Omegatilde * t)
2220
"""
@@ -25,7 +23,6 @@ def spring_force(v, x, t):
2523

2624
def rk4_trajectory(DeltaT=0.001, tfinal=20.0, x0=1.0, v0=0.0):
2725
"""
28-
Reimplementation of RK4 integrator from rk4.py.
2926
Returns t, x, v arrays.
3027
"""
3128
n = int(np.ceil(tfinal / DeltaT))
@@ -68,9 +65,7 @@ def rk4_trajectory(DeltaT=0.001, tfinal=20.0, x0=1.0, v0=0.0):
6865
return t, x, v
6966

7067

71-
# =====================================
72-
# 2. Sequence generation for RNN training
73-
# =====================================
68+
# Sequence generation for RNN training
7469

7570
def create_sequences(x, seq_len):
7671
"""
@@ -113,9 +108,7 @@ def __getitem__(self, idx):
113108
return self.inputs[idx], self.targets[idx]
114109

115110

116-
# ==============================
117-
# 3. RNN model (LSTM-based)
118-
# ==============================
111+
# RNN model (LSTM-based in this example)
119112

120113
class RNNPredictor(nn.Module):
121114
def __init__(self, input_size=1, hidden_size=32, num_layers=1, output_size=1):
@@ -133,9 +126,7 @@ def forward(self, x):
133126
return out
134127

135128

136-
# ==============================
137-
# 4. Training loop
138-
# ==============================
129+
# Training part
139130

140131
def train_model(
141132
seq_len=50,
@@ -184,9 +175,7 @@ def train_model(
184175
return model, dataset
185176

186177

187-
# ==============================
188-
# 5. Evaluation / visualization
189-
# ==============================
178+
# Evaluation / visualization
190179

191180
def evaluate_and_plot(model, dataset, seq_len=50, device=None):
192181
if device is None:
@@ -197,13 +186,10 @@ def evaluate_and_plot(model, dataset, seq_len=50, device=None):
197186
# Take a single sequence from the dataset
198187
x_seq, y_seq = dataset[0] # shapes: (seq_len, 1), (seq_len, 1)
199188
x_input = x_seq.unsqueeze(0).to(device) # (1, seq_len, 1)
200-
201189
# Model prediction (next-step for whole sequence)
202190
y_pred = model(x_input).cpu().numpy().squeeze(-1).squeeze(0) # (seq_len,)
203-
204191
# True target
205192
y_true = y_seq.numpy().squeeze(-1) # (seq_len,)
206-
207193
# Plot comparison
208194
plt.figure()
209195
plt.plot(y_true, label="True x(t+Δt)", linestyle="-")
@@ -215,10 +201,7 @@ def evaluate_and_plot(model, dataset, seq_len=50, device=None):
215201
plt.tight_layout()
216202
plt.show()
217203

218-
219-
# ==============================
220-
# 6. Main
221-
# ==============================
204+
# This is the main part of the code where we define the network
222205

223206
if __name__ == "__main__":
224207
# Hyperparameters can be tweaked as you like

0 commit comments

Comments
 (0)