Skip to content

Commit 98c1806

Browse files
committed
Create rnnode.py
1 parent 4a2795c commit 98c1806

File tree

1 file changed

+243
-0
lines changed

1 file changed

+243
-0
lines changed

doc/src/week46/programs/rnnode.py

Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
import numpy as np
2+
import torch
3+
import torch.nn as nn
4+
from torch.utils.data import Dataset, DataLoader
5+
import matplotlib.pyplot as plt
6+
7+
8+
# ============================
9+
# 1. Physics model (from rk4.py)
10+
# ============================
11+
12+
# Global parameters (same idea as in rk4.py)
13+
gamma = 0.2 # damping
14+
Omegatilde = 0.5 # driving frequency
15+
Ftilde = 1.0 # driving amplitude
16+
17+
def spring_force(v, x, t):
18+
"""
19+
SpringForce from rk4.py:
20+
note: divided by mass => returns acceleration
21+
a = -2*gamma*v - x + Ftilde*cos(Omegatilde * t)
22+
"""
23+
return -2.0 * gamma * v - x + Ftilde * np.cos(Omegatilde * t)
24+
25+
26+
def rk4_trajectory(DeltaT=0.001, tfinal=20.0, x0=1.0, v0=0.0):
27+
"""
28+
Reimplementation of RK4 integrator from rk4.py.
29+
Returns t, x, v arrays.
30+
"""
31+
n = int(np.ceil(tfinal / DeltaT))
32+
33+
t = np.zeros(n, dtype=np.float32)
34+
x = np.zeros(n, dtype=np.float32)
35+
v = np.zeros(n, dtype=np.float32)
36+
37+
x[0] = x0
38+
v[0] = v0
39+
40+
for i in range(n - 1):
41+
# k1
42+
k1x = DeltaT * v[i]
43+
k1v = DeltaT * spring_force(v[i], x[i], t[i])
44+
45+
# k2
46+
vv = v[i] + 0.5 * k1v
47+
xx = x[i] + 0.5 * k1x
48+
k2x = DeltaT * vv
49+
k2v = DeltaT * spring_force(vv, xx, t[i] + 0.5 * DeltaT)
50+
51+
# k3
52+
vv = v[i] + 0.5 * k2v
53+
xx = x[i] + 0.5 * k2x
54+
k3x = DeltaT * vv
55+
k3v = DeltaT * spring_force(vv, xx, t[i] + 0.5 * DeltaT)
56+
57+
# k4
58+
vv = v[i] + k3v
59+
xx = x[i] + k3x
60+
k4x = DeltaT * vv
61+
k4v = DeltaT * spring_force(vv, xx, t[i] + DeltaT)
62+
63+
# Update
64+
x[i + 1] = x[i] + (k1x + 2.0 * k2x + 2.0 * k3x + k4x) / 6.0
65+
v[i + 1] = v[i] + (k1v + 2.0 * k2v + 2.0 * k3v + k4v) / 6.0
66+
t[i + 1] = t[i] + DeltaT
67+
68+
return t, x, v
69+
70+
71+
# =====================================
72+
# 2. Sequence generation for RNN training
73+
# =====================================
74+
75+
def create_sequences(x, seq_len):
76+
"""
77+
Given a 1D array x (e.g., position as a function of time),
78+
create input/target sequences for next-step prediction.
79+
80+
Inputs: [x_i, x_{i+1}, ..., x_{i+seq_len-1}]
81+
Targets: [x_{i+1}, ..., x_{i+seq_len}]
82+
"""
83+
xs = []
84+
ys = []
85+
for i in range(len(x) - seq_len):
86+
seq_x = x[i : i + seq_len]
87+
seq_y = x[i + 1 : i + seq_len + 1] # shifted by one step
88+
xs.append(seq_x)
89+
ys.append(seq_y)
90+
91+
xs = np.array(xs, dtype=np.float32) # shape: (num_samples, seq_len)
92+
ys = np.array(ys, dtype=np.float32) # shape: (num_samples, seq_len)
93+
# Add feature dimension (1 feature: the position x)
94+
xs = np.expand_dims(xs, axis=-1) # (num_samples, seq_len, 1)
95+
ys = np.expand_dims(ys, axis=-1) # (num_samples, seq_len, 1)
96+
return xs, ys
97+
98+
99+
class OscillatorDataset(Dataset):
100+
def __init__(self, seq_len=50, DeltaT=0.001, tfinal=20.0, x0=1.0, v0=0.0):
101+
t, x, v = rk4_trajectory(DeltaT=DeltaT, tfinal=tfinal, x0=x0, v0=v0)
102+
self.t = t
103+
self.x = x
104+
self.v = v
105+
xs, ys = create_sequences(x, seq_len=seq_len)
106+
self.inputs = torch.from_numpy(xs) # (N, seq_len, 1)
107+
self.targets = torch.from_numpy(ys) # (N, seq_len, 1)
108+
109+
def __len__(self):
110+
return self.inputs.shape[0]
111+
112+
def __getitem__(self, idx):
113+
return self.inputs[idx], self.targets[idx]
114+
115+
116+
# ==============================
117+
# 3. RNN model (LSTM-based)
118+
# ==============================
119+
120+
class RNNPredictor(nn.Module):
121+
def __init__(self, input_size=1, hidden_size=32, num_layers=1, output_size=1):
122+
super().__init__()
123+
self.lstm = nn.LSTM(input_size=input_size,
124+
hidden_size=hidden_size,
125+
num_layers=num_layers,
126+
batch_first=True)
127+
self.fc = nn.Linear(hidden_size, output_size)
128+
129+
def forward(self, x):
130+
# x: (batch, seq_len, input_size)
131+
out, _ = self.lstm(x) # out: (batch, seq_len, hidden_size)
132+
out = self.fc(out) # (batch, seq_len, output_size)
133+
return out
134+
135+
136+
# ==============================
137+
# 4. Training loop
138+
# ==============================
139+
140+
def train_model(
141+
seq_len=50,
142+
DeltaT=0.001,
143+
tfinal=20.0,
144+
batch_size=64,
145+
num_epochs=10,
146+
hidden_size=64,
147+
lr=1e-3,
148+
device=None,
149+
):
150+
if device is None:
151+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
152+
print(f"Using device: {device}")
153+
154+
# Dataset & DataLoader
155+
dataset = OscillatorDataset(seq_len=seq_len, DeltaT=DeltaT, tfinal=tfinal)
156+
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
157+
158+
# Model, loss, optimizer
159+
model = RNNPredictor(input_size=1, hidden_size=hidden_size, output_size=1)
160+
model.to(device)
161+
162+
criterion = nn.MSELoss()
163+
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
164+
165+
# Training loop
166+
model.train()
167+
for epoch in range(num_epochs):
168+
epoch_loss = 0.0
169+
for batch_x, batch_y in train_loader:
170+
batch_x = batch_x.to(device)
171+
batch_y = batch_y.to(device)
172+
173+
optimizer.zero_grad()
174+
outputs = model(batch_x)
175+
loss = criterion(outputs, batch_y)
176+
loss.backward()
177+
optimizer.step()
178+
179+
epoch_loss += loss.item() * batch_x.size(0)
180+
181+
epoch_loss /= len(train_loader.dataset)
182+
print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.6f}")
183+
184+
return model, dataset
185+
186+
187+
# ==============================
188+
# 5. Evaluation / visualization
189+
# ==============================
190+
191+
def evaluate_and_plot(model, dataset, seq_len=50, device=None):
192+
if device is None:
193+
device = next(model.parameters()).device
194+
195+
model.eval()
196+
with torch.no_grad():
197+
# Take a single sequence from the dataset
198+
x_seq, y_seq = dataset[0] # shapes: (seq_len, 1), (seq_len, 1)
199+
x_input = x_seq.unsqueeze(0).to(device) # (1, seq_len, 1)
200+
201+
# Model prediction (next-step for whole sequence)
202+
y_pred = model(x_input).cpu().numpy().squeeze(-1).squeeze(0) # (seq_len,)
203+
204+
# True target
205+
y_true = y_seq.numpy().squeeze(-1) # (seq_len,)
206+
207+
# Plot comparison
208+
plt.figure()
209+
plt.plot(y_true, label="True x(t+Δt)", linestyle="-")
210+
plt.plot(y_pred, label="Predicted x(t+Δt)", linestyle="--")
211+
plt.xlabel("Time step in sequence")
212+
plt.ylabel("Position")
213+
plt.legend()
214+
plt.title("RNN next-step prediction on oscillator trajectory")
215+
plt.tight_layout()
216+
plt.show()
217+
218+
219+
# ==============================
220+
# 6. Main
221+
# ==============================
222+
223+
if __name__ == "__main__":
224+
# Hyperparameters can be tweaked as you like
225+
seq_len = 50
226+
DeltaT = 0.001
227+
tfinal = 20.0
228+
num_epochs = 10
229+
batch_size = 64
230+
hidden_size = 64
231+
lr = 1e-3
232+
233+
model, dataset = train_model(
234+
seq_len=seq_len,
235+
DeltaT=DeltaT,
236+
tfinal=tfinal,
237+
batch_size=batch_size,
238+
num_epochs=num_epochs,
239+
hidden_size=hidden_size,
240+
lr=lr,
241+
)
242+
243+
evaluate_and_plot(model, dataset, seq_len=seq_len)

0 commit comments

Comments
 (0)