-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain.py
324 lines (262 loc) · 14.2 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
"""
This is the main file for training any model on the chosen experiment.
"""
import torch
import numpy as np
import time
import copy
import matplotlib.pyplot as plt
import importlib
from _Utilities.Adam import Adam
from _Utilities.utilities import count_params, percentage_difference
################################################################
# configs
################################################################
# Default run configuration consumed by train().
# Every commented-out key below is optional: when a key is absent, train() fills it in from the
# selected model class's own `configs` defaults. Uncomment a key to override that default.
configs = {
    'model': 'fno_dse',                     # Model to train - fno, ffno, ufno, geo_fno, geo_ffno, geo_ufno, fno_dse, ffno_dse, ufno_dse
    'experiment': 'Airfoil',                # Burgers, Elasticity, Airfoil, ShearLayer, Humidity
    # Device for training & inference. Falls back to the CPU when CUDA is not available so the
    # script still runs on machines without a GPU.
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),

    ### Data specific parameters
    # 'datapath': '_Data/Elasticity/',      # Path to data
    # 'datapath': '~/data/elasticity',      # Path to data
    # 'num_train': 1000,                    # Number of training samples
    # 'num_test': 20,                       # Number of test samples
    # 'batch_size': 20,                     # Batch size
    # 'epochs': 501,                        # Number of epochs
    'test_epochs': 1,                       # How often (in epochs) the test error is computed and printed

    ### Training specific parameters
    # 'learning_rate': 0.005,               # Learning rate
    # 'scheduler_step': 10,                 # Scheduler step size
    # 'scheduler_gamma': 0.97,              # Scheduler gamma
    # 'weight_decay': 1e-4,                 # Weight decay
    # 'iphi_loss_reg': 0.0,                 # Regularization parameter for IPHI loss term for the diffeomorphism models
    # 'loss_fn': 'L1',                      # Loss function to use - L1, L2

    ### Saving and loading models
    'save_model': True,                     # Whether to save the model or not
    'load_model': False,                    # Whether to load a pretrained model; requires 'model_path'
    'model_path': '_Models/model.pt',       # Path to model file if loading model

    ### Model specific parameters
    # 'modes1': 12,                         # Number of x-modes to use in the Fourier layer
    # 'modes2': 12,                         # Number of y-modes to use in the Fourier layer
    # 'width': 32,                          # Number of channels in the convolutional layers
    # 'in_channels': 2,                     # Number of channels in input linear layer
    # 'out_channels': 1,                    # Number of channels in output linear layer
    # 'is_mesh': True,                      # Is it a mesh?
    # 's1': 40,                             # Number of x-points on latent space GeoFNO grid
    # 's2': 40,                             # Number of y-points on latent space GeoFNO grid

    ### Specifically for Burgers
    # 'data_dist': 'cubic_from_conexp',     # Data distribution to use - uniform, conexp, cubic_from_conexp

    ### Specifically for Airfoil
    # 'data_small_domain': True,            # Whether to use a small domain or not for specifically the Airfoil experiment

    ### Specifically for Shear Layer
    # 'center_1': 256,                      # Center of top interface
    # 'center_2': 768,                      # Center of bottom interface
    # 'uniform': 100,                       # Number of points uniform along interface
    # 'growth': 1.0,                        # Growth factor, how quickly do points become sparse

    ### Specifically for Humidity
    # 'center_lat': 180,                    # Latitude center of the nonuniform sampling region
    # 'center_lon': 140,                    # Longitude center of the nonuniform sampling region
}
def train(configs):
    """
    Train the model named by configs['model'] on the experiment named by configs['experiment'].

    The model class is looked up in the module '{experiment}.architectures' and the data loaders
    are built by 'getDataloaders' from '{experiment}.dataset'. Keys missing from `configs` are
    filled in from the model class's own `configs` defaults.

    Args:
        configs (dict): Run configuration. Keys read here: 'model', 'experiment', 'device',
            'load_model', 'model_path', 'save_model', 'learning_rate', 'weight_decay',
            'scheduler_step', 'scheduler_gamma', 'loss_fn' ('L1' or 'L2'), 'epochs',
            'test_epochs', 'num_train', 'num_test', optionally 'iphi_loss_reg', and 'growth'
            when the experiment is 'Humidity'.
            NOTE: the dictionary is modified in place - model defaults are added with
            setdefault, and 'point_data' / 'denormalizer' are set from the train loader.

    Returns:
        training_times (list): Wall-clock seconds spent in the training pass of each epoch.
        train_loss_hist (list): Training loss, one entry per evaluated epoch.
        test_loss_hist (list): Test loss, one entry per evaluated epoch.
        relative_error_hist (list): Test error (as returned by percentage_difference, summed over
            test batches and divided by 'num_test'), one entry per evaluated epoch.
        relative_median_error_hist (list): Median of the per-batch test errors, one entry per
            evaluated epoch.

    Raises:
        ValueError: If the model name is unknown or the model cannot be loaded for the experiment,
            if the experiment's dataset module cannot be loaded, if 'loss_fn' is not 'L1'/'L2',
            or if no epoch was evaluated (e.g. 'epochs' is 0).

    Side effects:
        Writes '_Models/loss_history.png' and, when configs['save_model'] is true,
        '_Models/{experiment}_{model}.pt'.
    """
    import os  # Local import: only used to make sure the output directory exists.

    device = configs['device']

    # Config name (case-insensitive) -> class name inside '{experiment}.architectures'.
    model_class_names = {
        ### Regular Grids
        'fno': 'FNO',
        'ffno': 'FFNO',
        'ufno': 'UFNO',
        ### Irregular Meshes
        'geo_fno': 'Geo_FNO',
        'geo_ffno': 'Geo_FFNO',
        'geo_ufno': 'Geo_UFNO',
        ### Structured Matrix Method
        'fno_dse': 'FNO_dse',
        'ffno_dse': 'FFNO_dse',
        'ufno_dse': 'UFNO_dse',
    }

    ### Load Model
    try:
        class_name = model_class_names.get(configs['model'].lower())
        if class_name is None:
            raise ValueError('Model not recognized.')
        architectures = importlib.import_module(configs['experiment'] + '.architectures')
        Model = getattr(architectures, class_name)

        # Replace default configs with experiment specific configs if not specified.
        for key in Model.configs:
            configs.setdefault(key, Model.configs[key])
    except Exception as error:
        print(error)
        # Chain the original error so the real cause (missing module, missing class, ...) stays visible.
        raise ValueError('Model not compatible with experiment.') from error

    ### Load Dataset
    try:
        getDataloaders = importlib.import_module(configs['experiment'] + '.dataset').getDataloaders
    except Exception as error:
        print(error)
        raise ValueError('Experiment not recognized.') from error

    # NOTE: Missing configs have been filled in from the defaults in Model.configs, hence we print here!
    print(configs)

    ##############
    # data loaders
    ##############
    start_time = time.time()
    print('Loading and processing data.')
    train_loader, test_loader = getDataloaders(configs)

    ### TODO TEMPORARY, unlikely good idea to put point dataset into dictionary
    configs['point_data'] = None if not hasattr(train_loader, "point_data") else train_loader.point_data
    configs['denormalizer'] = None if not hasattr(train_loader, "denormalizer") else train_loader.denormalizer
    print(f'Processing finished in {time.time()-start_time:.2f}s.')

    #######
    # model
    #######
    # initialize model
    if configs['load_model']:
        # map_location lets a checkpoint saved on another device (e.g. a GPU) load on `device`.
        # NOTE(review): torch.load unpickles arbitrary objects - only load trusted files. Recent
        # PyTorch releases default to weights_only=True, which rejects whole-model pickles like
        # the one saved at the end of this function; confirm against the installed version.
        model = torch.load(configs['model_path'], map_location=device).to(device)
    else:
        model = Model(configs).to(device)

    # TODO: Trainable parameters will now include IPHI parameters as well.
    print(f"Number of trainable parameters: {count_params(model)}")

    optimizer = Adam(model.parameters(), lr=configs['learning_rate'], weight_decay=configs['weight_decay'])
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=configs['scheduler_step'], gamma=configs['scheduler_gamma'])

    # Define the loss function
    if configs['loss_fn'] == 'L1':
        loss_fn = torch.nn.L1Loss()
    elif configs['loss_fn'] == 'L2':
        # Pass the reduction by keyword: the first positional slot is the deprecated `size_average`.
        loss_fn = torch.nn.MSELoss(reduction='mean')
    else:
        raise ValueError('Loss function not recognized.')

    ##########
    # training
    ##########
    training_times = []
    train_loss_hist = []
    test_loss_hist = []
    relative_error_hist = []
    relative_median_error_hist = []

    for epoch in range(configs['epochs']):
        start_train = time.time()
        train_loss = 0
        model.train()
        for inputs, targets in train_loader:
            batch_size = targets.shape[0]

            if isinstance(inputs, list):
                inputs = [x.to(device) for x in inputs]
            else:
                inputs = inputs.to(device)
            targets = targets.to(device)

            predictions = model(inputs)

            # Compute loss separate for complex numbers
            if predictions.is_complex():
                loss = (
                    loss_fn(predictions.real.view(batch_size, -1), targets.real.view(batch_size, -1))
                    + loss_fn(predictions.imag.view(batch_size, -1), targets.imag.view(batch_size, -1))
                )
            else:
                loss = loss_fn(predictions.view(batch_size, -1), targets.view(batch_size, -1))

            # For diffeomorphisms, additional loss term:
            if hasattr(model, "model_iphi") and ('iphi_loss_reg' in configs) and (configs['iphi_loss_reg'] > 0):
                # Sample on the CPU (keeps the random stream unchanged) and move to the configured
                # device instead of assuming CUDA.
                samples_x = torch.rand(batch_size, targets.shape[1], 2).to(device) * 3 - 1  # TODO Hardcoded values, check if applies to all
                samples_xi = model.model_iphi(samples_x)
                loss += configs['iphi_loss_reg'] * loss_fn(samples_xi, samples_x)

            train_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        train_loss /= configs['num_train']
        stop_train = time.time()
        training_times.append(stop_train - start_train)

        scheduler.step()

        ### Only test every test_epochs epochs.
        if epoch % configs['test_epochs'] > 0:
            continue

        ##########
        # testing
        ##########
        with torch.no_grad():
            test_loss = 0
            relative_error = 0
            # NOTE(review): one entry is written per test batch and the sums below are divided by
            # 'num_test', so this presumably expects the test loader to yield exactly 'num_test'
            # batches (batch size 1) - confirm against the experiment's dataset module.
            median_error = torch.zeros(configs['num_test'])
            model.eval()
            for idx, (inputs, targets) in enumerate(test_loader):
                batch_size = targets.shape[0]

                if isinstance(inputs, list):
                    inputs = [x.to(device) for x in inputs]
                else:
                    inputs = inputs.to(device)
                targets = targets.to(device)

                predictions = model(inputs)

                # For different growth factors and sparse data in Humidity, we only consider a small domain.
                if configs['experiment'] == 'Humidity' and configs['growth'] != 1.0:
                    l, r, b, t = test_loader.crop
                    predictions = predictions[:, b:t, l:r, :]
                    targets = targets[:, b:t, l:r, :]

                # Compute loss separate for complex numbers
                if predictions.is_complex():
                    loss = (
                        loss_fn(predictions.real.reshape(batch_size, -1), targets.real.reshape(batch_size, -1))
                        + loss_fn(predictions.imag.reshape(batch_size, -1), targets.imag.reshape(batch_size, -1))
                    )
                    # For metrics we only consider REAL parts
                    predictions = predictions.real
                    targets = targets.real
                else:
                    loss = loss_fn(predictions.reshape(batch_size, -1), targets.reshape(batch_size, -1))

                test_loss += loss.item()

                # Evaluate the error metric once and reuse it for both the mean and the median.
                batch_error = percentage_difference(targets.reshape(batch_size, -1), predictions.reshape(batch_size, -1))
                relative_error += batch_error
                median_error[idx] = batch_error

            test_loss /= configs['num_test']
            relative_error /= configs['num_test']
            median_relative_error = torch.median(median_error)
            relative_error_hist.append(relative_error)
            relative_median_error_hist.append(median_relative_error)

        print(f"Epoch [{epoch:03d}/{configs['epochs']-1}] in {stop_train-start_train:.2f}s with LR {scheduler.get_last_lr()[0]:.2e}: \tTrain loss {train_loss:.4e} \t- Test loss {test_loss:.4e} \t- Test Error {relative_error:.2f}% \t- Median Test Error {median_relative_error.item():.4f}%")

        train_loss_hist.append(train_loss)
        test_loss_hist.append(test_loss)

        # Keep a snapshot of the model whenever it achieves the lowest test error seen so far.
        if relative_error == min(relative_error_hist):
            best_model = copy.deepcopy(model)

    if not relative_error_hist:
        # Without at least one evaluated epoch there is no best model and nothing to plot or save.
        raise ValueError("No epoch was evaluated; 'epochs' must be at least 1.")
    lowest_error = min(relative_error_hist)

    # The plot (and the saved model) go to '_Models'; create it if it is missing.
    os.makedirs('_Models', exist_ok=True)

    # Plot losses
    plt.figure(figsize=(8, 6))
    plt.plot([np.log(x) for x in train_loss_hist], label='Train loss')
    plt.plot([np.log(x) for x in test_loss_hist], label='Test loss')
    plt.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
    plt.xlabel('Epoch')
    plt.ylabel('Log Loss')
    plt.legend()
    plt.grid()
    plt.savefig("_Models/loss_history.png")
    plt.close()

    # Save Model
    if configs['save_model']:
        # Remove unnecessary data from models
        best_model.point_data = None
        best_model.denormalizer = None
        print(f"Experiment: {configs['experiment']} \t- Model: {configs['model']} \t- Error: {lowest_error:.4f}%")
        torch.save(best_model, f"_Models/{configs['experiment']}_{configs['model']}.pt")

    return training_times, train_loss_hist, test_loss_hist, relative_error_hist, relative_median_error_hist
if __name__ == '__main__':
    # Seed every random number generator the script touches (PyTorch CPU, NumPy, PyTorch CUDA)
    # so that repeated runs are reproducible.
    seed = 333
    for seed_rng in (torch.manual_seed, np.random.seed, torch.cuda.manual_seed):
        seed_rng(seed)
    torch.backends.cudnn.deterministic = True

    # Train the single model/experiment pair selected in `configs`.
    train(configs)

    # Disabled: sweep over several models. TODO figure out bug with copying parameters
    # NOTE(review): `new_configs = configs` only aliases the dictionary, and train() fills missing
    # keys in place with setdefault, so defaults picked up for the first model would carry over to
    # the following ones - likely the cause of the bug; copy the dictionary for each run instead.
    # models = ['geo_fno', 'geo_ffno', 'geo_ufno', 'fno_dse', 'ffno_dse', 'ufno_dse']
    # models = ['ffno_dse']
    # for model in models:
    #     new_configs = configs
    #     new_configs['model'] = model
    #     train(new_configs)