55import matplotlib .pyplot as plt
66
77
8- # ============================
9- # 1. Physics model (from rk4.py)
10- # ============================
8+ # Newton's equation for harmonic oscillations with external force
119
12- # Global parameters (same idea as in rk4.py)
10+ # Global parameters
1311gamma = 0.2 # damping
1412Omegatilde = 0.5 # driving frequency
1513Ftilde = 1.0 # driving amplitude
1614
1715def spring_force (v , x , t ):
1816 """
19- SpringForce from rk4.py :
17+ SpringForce:
2018 note: divided by mass => returns acceleration
2119 a = -2*gamma*v - x + Ftilde*cos(Omegatilde * t)
2220 """
@@ -25,7 +23,6 @@ def spring_force(v, x, t):
2523
2624def rk4_trajectory (DeltaT = 0.001 , tfinal = 20.0 , x0 = 1.0 , v0 = 0.0 ):
2725 """
28- Reimplementation of RK4 integrator from rk4.py.
2926 Returns t, x, v arrays.
3027 """
3128 n = int (np .ceil (tfinal / DeltaT ))
@@ -68,9 +65,7 @@ def rk4_trajectory(DeltaT=0.001, tfinal=20.0, x0=1.0, v0=0.0):
6865 return t , x , v
6966
7067
71- # =====================================
72- # 2. Sequence generation for RNN training
73- # =====================================
68+ # Sequence generation for RNN training
7469
7570def create_sequences (x , seq_len ):
7671 """
@@ -113,9 +108,7 @@ def __getitem__(self, idx):
113108 return self .inputs [idx ], self .targets [idx ]
114109
115110
116- # ==============================
117- # 3. RNN model (LSTM-based)
118- # ==============================
111+ # RNN model (LSTM-based in this example)
119112
120113class RNNPredictor (nn .Module ):
121114 def __init__ (self , input_size = 1 , hidden_size = 32 , num_layers = 1 , output_size = 1 ):
@@ -133,9 +126,7 @@ def forward(self, x):
133126 return out
134127
135128
136- # ==============================
137- # 4. Training loop
138- # ==============================
129+ # Training part
139130
140131def train_model (
141132 seq_len = 50 ,
@@ -184,9 +175,7 @@ def train_model(
184175 return model , dataset
185176
186177
187- # ==============================
188- # 5. Evaluation / visualization
189- # ==============================
178+ # Evaluation / visualization
190179
191180def evaluate_and_plot (model , dataset , seq_len = 50 , device = None ):
192181 if device is None :
@@ -197,13 +186,10 @@ def evaluate_and_plot(model, dataset, seq_len=50, device=None):
197186 # Take a single sequence from the dataset
198187 x_seq , y_seq = dataset [0 ] # shapes: (seq_len, 1), (seq_len, 1)
199188 x_input = x_seq .unsqueeze (0 ).to (device ) # (1, seq_len, 1)
200-
201189 # Model prediction (next-step for whole sequence)
202190 y_pred = model (x_input ).cpu ().numpy ().squeeze (- 1 ).squeeze (0 ) # (seq_len,)
203-
204191 # True target
205192 y_true = y_seq .numpy ().squeeze (- 1 ) # (seq_len,)
206-
207193 # Plot comparison
208194 plt .figure ()
209195 plt .plot (y_true , label = "True x(t+Δt)" , linestyle = "-" )
@@ -215,10 +201,7 @@ def evaluate_and_plot(model, dataset, seq_len=50, device=None):
215201 plt .tight_layout ()
216202 plt .show ()
217203
218-
219- # ==============================
220- # 6. Main
221- # ==============================
204+ # This is the main part of the code where we define the network
222205
223206if __name__ == "__main__" :
224207 # Hyperparameters can be tweaked as you like
0 commit comments