Skip to content

Commit 262ad2d

Browse files
committed
Training SchNet with Lightning
1 parent c6b0a4c commit 262ad2d

15 files changed

+825
-2
lines changed

LICENSE

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
MIT License
22

3-
Copyright (c) 2020 TorchMD
3+
Copyright (c) 2020 compscience.org
44

55
Permission is hereby granted, free of charge, to any person obtaining a copy
66
of this software and associated documentation files (the "Software"), to deal

MANIFEST.in

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
include requirements.txt README.md LICENSE

Makefile

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
ifndef version
2+
$(error version variable is not set. Call with `make release version=XXX`)
3+
endif
4+
5+
release:
6+
git checkout master
7+
git fetch
8+
git pull
9+
git tag -a $(version) -m "$(version) release"
10+
git push --tags origin $(version)

README.md

+20-1
Original file line numberDiff line numberDiff line change
@@ -1 +1,20 @@
1-
# torchmd-net
1+
# torchmd-cg
2+
3+
## Installation
4+
5+
```
6+
conda create -n torchmd_cg
7+
conda activate torchmd_cg
8+
conda install pytorch torchvision cudatoolkit=10.1 -c pytorch
9+
conda install pyyaml ipython scikit-learn tqdm
10+
pip install pytorch-lightning
11+
pip install torchmd-cg
12+
```
13+
14+
## Usage
15+
16+
```
17+
pip install moleculekit
18+
conda install seaborn pandas jupyter
19+
```
20+
Check the jupyter notebook in the `tutorial` folder

requirements.txt

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
torchmd
2+
schnetpack

scripts/light_train.py

+279
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
import os
2+
import numpy as np
3+
import time
4+
5+
import torch
6+
from torch.nn import MSELoss, L1Loss
7+
from torch.utils.data import DataLoader, WeightedRandomSampler
8+
from torch.optim.lr_scheduler import ReduceLROnPlateau
9+
10+
import schnetpack as spk
11+
import schnetpack.atomistic as atm
12+
import schnetpack.representation as rep
13+
from schnetpack.nn.cutoff import CosineCutoff
14+
from schnetpack.data.loader import _collate_aseatoms
15+
from schnetpack.environment import SimpleEnvironmentProvider
16+
17+
from torchmdnet.nnp.schnet_dataset import SchNetDataset
18+
from torchmdnet.nnp.utils import LoadFromFile, LogWriter
19+
from torchmdnet.nnp.utils import save_argparse
20+
from torchmdnet.nnp.utils import train_val_test_split, set_batch_size
21+
from torchmdnet.nnp.npdataset import NpysDataset, NpysDataset2
22+
from torchmdnet.nnp.model import make_schnet_model
23+
24+
import argparse
25+
26+
import pytorch_lightning as pl
27+
from pytorch_lightning.callbacks import LearningRateMonitor
28+
29+
30+
def get_args():
31+
# fmt: off
32+
parser = argparse.ArgumentParser(description='Training')
33+
parser.add_argument('--conf','-c', type=open, action=LoadFromFile)#keep first
34+
parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
35+
parser.add_argument('--batch-size', default=32,type=int, help='batch size')
36+
parser.add_argument('--num-epochs', default=300,type=int, help='number of epochs')
37+
parser.add_argument('--order', default=None, help='Npy file with order on which to split idx_train,idx_val,idx_test')
38+
parser.add_argument('--coords', default='coords.npy', help='Data source')
39+
parser.add_argument('--forces', default='forces.npy', help='Data source')
40+
parser.add_argument('--embeddings', default='embeddings.npy', help='Data source')
41+
parser.add_argument('--weights', default=None, help='Data source')
42+
parser.add_argument('--splits', default=None, help='Npz with splits idx_train,idx_val,idx_test')
43+
parser.add_argument('--gpus', default=0, help='Number of GPUs. Use CUDA_VISIBLE_DEVICES=1,2 to decide gpu')
44+
parser.add_argument('--num-nodes', type=int, default=1, help='Number of nodes')
45+
parser.add_argument('--log-dir', '-l', default='/tmp/net', help='log file')
46+
parser.add_argument('--label', default='forces', help='Label')
47+
parser.add_argument('--derivative', default='forces', help='Label')
48+
parser.add_argument('--eval-interval',type=int,default=2,help='eval interval, one eval per n updates (default: 2)')
49+
parser.add_argument('--save-interval',type=int,default=10,help='save interval, one save per n updates (default: 10)')
50+
parser.add_argument('--seed',type=int,default=1,help='random seed (default: 1)')
51+
parser.add_argument('--load-model',default=None,help='Restart training using a model checkpoint')
52+
parser.add_argument('--progress',action='store_true', default=False,help='Progress bar during batching')
53+
parser.add_argument('--val-ratio',type=float, default=0.05,help='Percentual of validation set')
54+
parser.add_argument('--test-ratio',type=float, default=0,help='Percentual of test set')
55+
parser.add_argument('--num-workers',type=int, default=0,help='Number of workers for data prefetch')
56+
parser.add_argument('--num-filters',type=int, default=128,help='Number of filter in model')
57+
parser.add_argument('--num-gaussians',type=int, default=50,help='Number of Gaussians in model')
58+
parser.add_argument('--num-interactions',type=int, default=2,help='Number of interactions in model')
59+
parser.add_argument('--max-z',type=int, default=100,help='Max atomic number in model')
60+
parser.add_argument('--cutoff',type=float, default=9,help='Cutoff in model')
61+
parser.add_argument('--lr-patience',type=int,default=10,help='Patience for lr-schedule. Patience per eval-interval of validation')
62+
parser.add_argument('--lr-min',type=float, default=1e-6,help='Minimum learning rate before early stop')
63+
parser.add_argument('--lr-factor',type=float, default=0.8,help='Minimum learning rate before early stop')
64+
parser.add_argument('--distributed-backend', default='ddp', help='Distributed backend: dp, ddp, ddp2')
65+
# fmt: on
66+
args = parser.parse_args()
67+
68+
if args.val_ratio == 0:
69+
args.eval_interval = 0
70+
71+
save_argparse(args, os.path.join(args.log_dir, "input.yaml"), exclude=["conf"])
72+
73+
return args
74+
75+
76+
def make_splits(
77+
dataset_len, val_ratio, test_ratio, seed, filename=None, splits=None, order=None
78+
):
79+
if splits is not None:
80+
splits = np.load(splits)
81+
idx_train = splits["idx_train"]
82+
idx_val = splits["idx_val"]
83+
idx_test = splits["idx_test"]
84+
else:
85+
idx_train, idx_val, idx_test = train_val_test_split(
86+
dataset_len, val_ratio, test_ratio, seed, order
87+
)
88+
89+
if filename is not None:
90+
np.savez(filename, idx_train=idx_train, idx_val=idx_val, idx_test=idx_test)
91+
92+
return idx_train, idx_val, idx_test
93+
94+
95+
class LNNP(pl.LightningModule):
96+
def __init__(self, hparams):
97+
super(LNNP, self).__init__()
98+
self.hparams = hparams
99+
if self.hparams.load_model:
100+
raise NotImplementedError # TODO
101+
else:
102+
self.model = make_schnet_model(self.hparams)
103+
# save linear fit model with random parameters
104+
self.loss_fn = MSELoss()
105+
self.test_fn = L1Loss()
106+
107+
def prepare_data(self):
108+
print("Preparing data...", flush=True)
109+
self.dataset = NpysDataset2(
110+
self.hparams.coords, self.hparams.forces, self.hparams.embeddings
111+
)
112+
self.dataset = SchNetDataset(
113+
self.dataset,
114+
environment_provider=SimpleEnvironmentProvider(),
115+
label=["forces"],
116+
)
117+
self.idx_train, self.idx_val, self.idx_test = make_splits(
118+
len(self.dataset),
119+
self.hparams.val_ratio,
120+
self.hparams.test_ratio,
121+
self.hparams.seed,
122+
os.path.join(self.hparams.log_dir, f"splits.npz"),
123+
self.hparams.splits,
124+
)
125+
self.train_dataset = torch.utils.data.Subset(self.dataset, self.idx_train)
126+
self.val_dataset = torch.utils.data.Subset(self.dataset, self.idx_val)
127+
self.test_dataset = torch.utils.data.Subset(self.dataset, self.idx_test)
128+
print(
129+
"train {}, val {}, test {}".format(
130+
len(self.train_dataset), len(self.val_dataset), len(self.test_dataset)
131+
)
132+
)
133+
134+
if self.hparams.weights is not None:
135+
self.weights = torch.from_numpy(np.load(self.hparams.weights))
136+
else:
137+
self.weights = torch.ones(len(self.dataset))
138+
139+
def forward(self, x):
140+
return self.model(x)
141+
142+
def train_dataloader(self):
143+
train_loader = DataLoader(
144+
self.train_dataset,
145+
sampler=WeightedRandomSampler(
146+
self.weights[self.idx_train], len(self.train_dataset)
147+
),
148+
batch_size=set_batch_size(self.hparams.batch_size, len(self.train_dataset)),
149+
shuffle=False,
150+
collate_fn=_collate_aseatoms,
151+
num_workers=self.hparams.num_workers,
152+
pin_memory=True,
153+
)
154+
return train_loader
155+
156+
def training_step(self, batch, batch_idx):
157+
prediction = self(batch)
158+
loss = self.loss_fn(prediction[self.hparams.label], batch[self.hparams.label])
159+
self.log('train_loss', loss, on_step=True, on_epoch=False, prog_bar=True, logger=True)
160+
return loss
161+
162+
def val_dataloader(self):
163+
val_loader = None
164+
if len(self.val_dataset) > 0:
165+
# val_loader = DataLoader(self.val_dataset, sampler=WeightedRandomSampler(self.weights[self.idx_val], len(self.val_dataset)),
166+
val_loader = DataLoader(
167+
self.val_dataset,
168+
batch_size=set_batch_size(
169+
self.hparams.batch_size, len(self.val_dataset)
170+
),
171+
collate_fn=_collate_aseatoms,
172+
num_workers=self.hparams.num_workers,
173+
pin_memory=True,
174+
)
175+
return val_loader
176+
177+
def validation_step(self, batch, batch_idx):
178+
torch.set_grad_enabled(True)
179+
prediction = self(batch)
180+
torch.set_grad_enabled(False)
181+
loss = self.loss_fn(prediction[self.hparams.label], batch[self.hparams.label])
182+
return loss
183+
184+
def validation_epoch_end(self, validation_step_outputs):
185+
avg_loss = torch.stack(validation_step_outputs).mean()
186+
self.log("val_loss", avg_loss)
187+
188+
def test_dataloader(self):
189+
test_loader = None
190+
if len(self.test_dataset) > 0:
191+
# test_loader = DataLoader(self.test_dataset, sampler=WeightedRandomSampler(self.weights[self.idx_test], len(self.test_dataset)),
192+
test_loader = DataLoader(
193+
self.test_dataset,
194+
batch_size=set_batch_size(
195+
self.hparams.batch_size, len(self.test_dataset)
196+
),
197+
collate_fn=_collate_aseatoms,
198+
num_workers=self.hparams.num_workers,
199+
pin_memory=True,
200+
)
201+
return test_loader
202+
203+
def test_step(self, batch, batch_idx):
204+
torch.set_grad_enabled(True)
205+
prediction = self(batch)
206+
torch.set_grad_enabled(False)
207+
loss = self.test_fn(prediction[self.hparams.label], batch[self.hparams.label])
208+
return loss
209+
210+
def test_epoch_end(self, test_step_outputs):
211+
avg_loss = torch.stack(test_step_outputs).mean()
212+
self.log("test_loss", avg_loss)
213+
214+
def configure_optimizers(self):
215+
# optimizer = torch.optim.SGD(self.model.parameters(), lr=self.hparams.lr, momentum=0.9)
216+
optimizer = torch.optim.Adam(self.model.parameters(), lr=self.hparams.lr)
217+
scheduler = ReduceLROnPlateau(
218+
optimizer,
219+
"min",
220+
factor=self.hparams.lr_factor,
221+
patience=self.hparams.lr_patience,
222+
min_lr=self.hparams.lr_min
223+
)
224+
lr_scheduler = {'scheduler':scheduler,
225+
'monitor':'val_loss',
226+
'interval': 'epoch',
227+
'frequency': 1,
228+
}
229+
return [optimizer], [lr_scheduler]
230+
231+
def main():
232+
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
233+
234+
args = get_args()
235+
torch.manual_seed(args.seed)
236+
torch.cuda.manual_seed_all(args.seed)
237+
238+
model = LNNP(args)
239+
checkpoint_callback = ModelCheckpoint(
240+
filepath=args.log_dir,
241+
monitor="val_loss",
242+
save_top_k=8,
243+
period=args.eval_interval,
244+
)
245+
lr_monitor = LearningRateMonitor(logging_interval='epoch')
246+
tb_logger = pl.loggers.TensorBoardLogger(args.log_dir)
247+
trainer = pl.Trainer(
248+
gpus=args.gpus,
249+
max_epochs=args.num_epochs,
250+
distributed_backend=args.distributed_backend,
251+
num_nodes=args.num_nodes,
252+
default_root_dir=args.log_dir,
253+
auto_lr_find=False,
254+
resume_from_checkpoint=args.load_model,
255+
checkpoint_callback=checkpoint_callback,
256+
callbacks=[lr_monitor],
257+
logger=tb_logger,
258+
reload_dataloaders_every_epoch=False
259+
)
260+
261+
trainer.fit(model)
262+
263+
# run test set after completing the fit
264+
trainer.test()
265+
266+
# logs = LogWriter(args.log_dir,keys=('epoch','train_loss','val_loss','test_mae','lr'))
267+
268+
269+
# logs.write_row({'epoch':epoch,'train_loss':train_loss,'val_loss':val_loss,
270+
# 'test_mae':test_mae, 'lr':optimizer.param_groups[0]['lr']})
271+
# progress.set_postfix({'Loss': train_loss, 'lr':optimizer.param_groups[0]['lr']})
272+
273+
# if optimizer.param_groups[0]['lr'] < args.lr_min:
274+
# print("Early stop reached")
275+
# break
276+
277+
278+
if __name__ == "__main__":
279+
main()

setup.py

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import setuptools
2+
import subprocess
3+
import os
4+
5+
try:
6+
version = (
7+
subprocess.check_output(["git", "describe", "--abbrev=0", "--tags"])
8+
.strip()
9+
.decode("utf-8")
10+
)
11+
except Exception as e:
12+
print("Could not get version tag. Defaulting to version 0")
13+
version = "0"
14+
15+
with open("requirements.txt") as f:
16+
requirements = f.read().splitlines()
17+
18+
if __name__ == "__main__":
19+
with open("README.md", "r") as fh:
20+
long_description = fh.read()
21+
22+
setuptools.setup(
23+
name="torchmdnet",
24+
version=version,
25+
author="Acellera",
26+
author_email="[email protected]",
27+
description="TorchMD-net. Training Schnet with pytorch lighthing",
28+
long_description=long_description,
29+
long_description_content_type="text/markdown",
30+
url="https://github.com/torchmd/torchmd-net/",
31+
classifiers=[
32+
"Programming Language :: Python :: 3",
33+
"Operating System :: POSIX :: Linux",
34+
"License :: OSI Approved :: MIT License",
35+
],
36+
packages=setuptools.find_packages(include=["torchmdnet*"], exclude=[]),
37+
# package_data={"torchmdnet": ["config.ini", "logging.ini"],},
38+
install_requires=requirements,
39+
)

torchmdnet/__init__.py

Whitespace-only changes.

torchmdnet/nnp/__init__.py

Whitespace-only changes.

torchmdnet/nnp/calculators/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)