-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathfit2_qm9.py
44 lines (40 loc) · 982 Bytes
/
fit2_qm9.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from train.fitter import fit_qm9
# QM9 ablation batch: one baseline run plus three variants. Each variant
# overrides a single setting via `special_config`. All runs share the same
# training options; only the tag (also used in the model save path) and the
# override dict differ.

# Common date prefix shared by every run tag in this batch.
DATE = '0920'

# (tag, special_config) for each run, in execution order. `None` means the
# run passes no `special_config` argument at all (as the original baseline
# call did), so whatever default fit_qm9 defines applies.
RUNS = [
    # normal
    (DATE, None),
    # no hamiltonian kernel
    (DATE + '-noham', {'HGN_LAYERS': 0}),
    # no LSTM when producing p0 & q0
    (DATE + '-nolstm', {'LSTM': False}),
    # no ADJ3 loss
    (DATE + '-noadj', {'GAMMA_A': 0.000}),
]


def _fit(tag, special_config=None):
    """Call fit_qm9 once with the batch-wide options and model path net/<tag>.pt.

    Args:
        tag: Run identifier. It is passed as ``tag`` and used to build
            ``model_save_path``.
        special_config: Optional override dict. It is forwarded only when it
            is not None, so the baseline run keeps fit_qm9's own default.
    """
    # The keyword order matches the original hand-written calls.
    kwargs = dict(
        use_cuda=True,
        limit=-1,
        use_tqdm=False,
        force_save=False,
        model_save_path='net/{}.pt'.format(tag),
    )
    if special_config is not None:
        kwargs['special_config'] = special_config
    kwargs['tag'] = tag
    fit_qm9(**kwargs)


# Runs execute sequentially at import/run time, as in the original script.
for TAG, special_config in RUNS:
    _fit(TAG, special_config)