|
| 1 | +import os |
| 2 | +import numpy as np |
| 3 | +import time |
| 4 | + |
| 5 | +import torch |
| 6 | +from torch.nn import MSELoss, L1Loss |
| 7 | +from torch.utils.data import DataLoader, WeightedRandomSampler |
| 8 | +from torch.optim.lr_scheduler import ReduceLROnPlateau |
| 9 | + |
| 10 | +import schnetpack as spk |
| 11 | +import schnetpack.atomistic as atm |
| 12 | +import schnetpack.representation as rep |
| 13 | +from schnetpack.nn.cutoff import CosineCutoff |
| 14 | +from schnetpack.data.loader import _collate_aseatoms |
| 15 | +from schnetpack.environment import SimpleEnvironmentProvider |
| 16 | + |
| 17 | +from torchmdnet.nnp.schnet_dataset import SchNetDataset |
| 18 | +from torchmdnet.nnp.utils import LoadFromFile, LogWriter |
| 19 | +from torchmdnet.nnp.utils import save_argparse |
| 20 | +from torchmdnet.nnp.utils import train_val_test_split, set_batch_size |
| 21 | +from torchmdnet.nnp.npdataset import NpysDataset, NpysDataset2 |
| 22 | +from torchmdnet.nnp.model import make_schnet_model |
| 23 | + |
| 24 | +import argparse |
| 25 | + |
| 26 | +import pytorch_lightning as pl |
| 27 | +from pytorch_lightning.callbacks import LearningRateMonitor |
| 28 | + |
| 29 | + |
def get_args():
    """Parse training configuration from the command line (optionally pre-loaded
    from a YAML file via ``--conf``).

    The resolved configuration is saved to ``<log-dir>/input.yaml`` (excluding
    the ``conf`` file handle) so a run can be reproduced later with ``--conf``.

    Returns:
        argparse.Namespace: all training hyperparameters.
    """
    # fmt: off
    parser = argparse.ArgumentParser(description='Training')
    parser.add_argument('--conf','-c', type=open, action=LoadFromFile)#keep first
    parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
    parser.add_argument('--batch-size', default=32,type=int, help='batch size')
    parser.add_argument('--num-epochs', default=300,type=int, help='number of epochs')
    parser.add_argument('--order', default=None, help='Npy file with order on which to split idx_train,idx_val,idx_test')
    parser.add_argument('--coords', default='coords.npy', help='Data source')
    parser.add_argument('--forces', default='forces.npy', help='Data source')
    parser.add_argument('--embeddings', default='embeddings.npy', help='Data source')
    parser.add_argument('--weights', default=None, help='Data source')
    parser.add_argument('--splits', default=None, help='Npz with splits idx_train,idx_val,idx_test')
    parser.add_argument('--gpus', default=0, help='Number of GPUs. Use CUDA_VISIBLE_DEVICES=1,2 to decide gpu')
    parser.add_argument('--num-nodes', type=int, default=1, help='Number of nodes')
    parser.add_argument('--log-dir', '-l', default='/tmp/net', help='log file')
    parser.add_argument('--label', default='forces', help='Label')
    # help text was a copy-paste of --label's; clarified
    parser.add_argument('--derivative', default='forces', help='Label of the derivative output')
    parser.add_argument('--eval-interval',type=int,default=2,help='eval interval, one eval per n updates (default: 2)')
    parser.add_argument('--save-interval',type=int,default=10,help='save interval, one save per n updates (default: 10)')
    parser.add_argument('--seed',type=int,default=1,help='random seed (default: 1)')
    parser.add_argument('--load-model',default=None,help='Restart training using a model checkpoint')
    parser.add_argument('--progress',action='store_true', default=False,help='Progress bar during batching')
    parser.add_argument('--val-ratio',type=float, default=0.05,help='Percentual of validation set')
    parser.add_argument('--test-ratio',type=float, default=0,help='Percentual of test set')
    parser.add_argument('--num-workers',type=int, default=0,help='Number of workers for data prefetch')
    parser.add_argument('--num-filters',type=int, default=128,help='Number of filter in model')
    parser.add_argument('--num-gaussians',type=int, default=50,help='Number of Gaussians in model')
    parser.add_argument('--num-interactions',type=int, default=2,help='Number of interactions in model')
    parser.add_argument('--max-z',type=int, default=100,help='Max atomic number in model')
    parser.add_argument('--cutoff',type=float, default=9,help='Cutoff in model')
    parser.add_argument('--lr-patience',type=int,default=10,help='Patience for lr-schedule. Patience per eval-interval of validation')
    parser.add_argument('--lr-min',type=float, default=1e-6,help='Minimum learning rate before early stop')
    # help text was a copy-paste of --lr-min's; this is the ReduceLROnPlateau factor
    parser.add_argument('--lr-factor',type=float, default=0.8,help='Multiplicative factor applied to the learning rate on plateau')
    parser.add_argument('--distributed-backend', default='ddp', help='Distributed backend: dp, ddp, ddp2')
    # fmt: on
    args = parser.parse_args()

    # No validation data -> nothing to evaluate periodically.
    if args.val_ratio == 0:
        args.eval_interval = 0

    # Make sure the log directory exists before writing input.yaml into it
    # (splits.npz and checkpoints are written there later as well).
    os.makedirs(args.log_dir, exist_ok=True)
    save_argparse(args, os.path.join(args.log_dir, "input.yaml"), exclude=["conf"])

    return args
| 74 | + |
| 75 | + |
def make_splits(
    dataset_len, val_ratio, test_ratio, seed, filename=None, splits=None, order=None
):
    """Compute (or load) the train/validation/test index arrays.

    Args:
        dataset_len: total number of samples to split.
        val_ratio: fraction of samples used for validation.
        test_ratio: fraction of samples used for testing.
        seed: RNG seed forwarded to the random splitter.
        filename: if given, the resulting indices are saved there as an .npz
            with keys idx_train/idx_val/idx_test.
        splits: optional path to an existing .npz with keys
            idx_train/idx_val/idx_test; when given it takes precedence over
            computing a fresh random split.
        order: optional ordering forwarded to ``train_val_test_split``.

    Returns:
        tuple: (idx_train, idx_val, idx_test) index arrays.
    """
    if splits is not None:
        # Use a context manager: np.load on an .npz returns an NpzFile that
        # keeps the archive's file handle open until closed.
        with np.load(splits) as npz:
            idx_train = npz["idx_train"]
            idx_val = npz["idx_val"]
            idx_test = npz["idx_test"]
    else:
        idx_train, idx_val, idx_test = train_val_test_split(
            dataset_len, val_ratio, test_ratio, seed, order
        )

    if filename is not None:
        np.savez(filename, idx_train=idx_train, idx_val=idx_val, idx_test=idx_test)

    return idx_train, idx_val, idx_test
| 93 | + |
| 94 | + |
class LNNP(pl.LightningModule):
    """PyTorch-Lightning wrapper around a SchNet model trained on labels
    (default: forces) stored in .npy files.

    Training uses MSE with an optional per-sample weighted sampler; testing
    reports L1 (MAE). The LR is reduced on a plateau of ``val_loss``.
    """

    def __init__(self, hparams):
        super(LNNP, self).__init__()
        self.hparams = hparams
        if self.hparams.load_model:
            # Full restarts are handled via Trainer(resume_from_checkpoint=...)
            raise NotImplementedError  # TODO
        else:
            self.model = make_schnet_model(self.hparams)
        # save linear fit model with random parameters
        self.loss_fn = MSELoss()  # train/val criterion
        self.test_fn = L1Loss()  # test criterion (MAE)

    def prepare_data(self):
        """Load the npy datasets, wrap them for SchNet, split into
        train/val/test subsets and set up per-sample weights."""
        print("Preparing data...", flush=True)
        self.dataset = NpysDataset2(
            self.hparams.coords, self.hparams.forces, self.hparams.embeddings
        )
        self.dataset = SchNetDataset(
            self.dataset,
            environment_provider=SimpleEnvironmentProvider(),
            # Use the configured label; this was hard-coded to ["forces"],
            # silently ignoring --label.
            label=[self.hparams.label],
        )
        self.idx_train, self.idx_val, self.idx_test = make_splits(
            len(self.dataset),
            self.hparams.val_ratio,
            self.hparams.test_ratio,
            self.hparams.seed,
            os.path.join(self.hparams.log_dir, "splits.npz"),
            self.hparams.splits,
            # Previously dropped, so --order had no effect.
            # NOTE(review): assumes train_val_test_split accepts the value of
            # --order as-is — confirm whether it expects a path or an array.
            order=self.hparams.order,
        )
        self.train_dataset = torch.utils.data.Subset(self.dataset, self.idx_train)
        self.val_dataset = torch.utils.data.Subset(self.dataset, self.idx_val)
        self.test_dataset = torch.utils.data.Subset(self.dataset, self.idx_test)
        print(
            "train {}, val {}, test {}".format(
                len(self.train_dataset), len(self.val_dataset), len(self.test_dataset)
            )
        )

        # Optional per-sample weights for the training sampler; defaults to
        # uniform weights over the whole dataset.
        if self.hparams.weights is not None:
            self.weights = torch.from_numpy(np.load(self.hparams.weights))
        else:
            self.weights = torch.ones(len(self.dataset))

    def forward(self, x):
        """Run the wrapped SchNet model on a collated batch."""
        return self.model(x)

    def train_dataloader(self):
        """Dataloader over the training subset using a weighted sampler."""
        train_loader = DataLoader(
            self.train_dataset,
            sampler=WeightedRandomSampler(
                self.weights[self.idx_train], len(self.train_dataset)
            ),
            batch_size=set_batch_size(self.hparams.batch_size, len(self.train_dataset)),
            shuffle=False,  # shuffling is the sampler's job
            collate_fn=_collate_aseatoms,
            num_workers=self.hparams.num_workers,
            pin_memory=True,
        )
        return train_loader

    def training_step(self, batch, batch_idx):
        """One optimization step: MSE between predicted and reference label."""
        prediction = self(batch)
        loss = self.loss_fn(prediction[self.hparams.label], batch[self.hparams.label])
        self.log('train_loss', loss, on_step=True, on_epoch=False, prog_bar=True, logger=True)
        return loss

    def val_dataloader(self):
        """Plain (unweighted) dataloader over the validation subset, or None
        when the validation split is empty."""
        val_loader = None
        if len(self.val_dataset) > 0:
            val_loader = DataLoader(
                self.val_dataset,
                batch_size=set_batch_size(
                    self.hparams.batch_size, len(self.val_dataset)
                ),
                collate_fn=_collate_aseatoms,
                num_workers=self.hparams.num_workers,
                pin_memory=True,
            )
        return val_loader

    def validation_step(self, batch, batch_idx):
        # The original code called torch.set_grad_enabled(True)/(False) around
        # the forward pass, which force-disables autograd globally afterwards.
        # The context manager restores the surrounding (no-grad) state instead.
        # Grad is presumably needed here because the model differentiates its
        # output (e.g. forces from energy) — confirm against make_schnet_model.
        with torch.enable_grad():
            prediction = self(batch)
        loss = self.loss_fn(prediction[self.hparams.label], batch[self.hparams.label])
        return loss

    def validation_epoch_end(self, validation_step_outputs):
        """Average the per-batch validation losses and log 'val_loss', which
        the ReduceLROnPlateau scheduler and the checkpoint callback monitor."""
        avg_loss = torch.stack(validation_step_outputs).mean()
        self.log("val_loss", avg_loss)

    def test_dataloader(self):
        """Plain dataloader over the test subset, or None when it is empty."""
        test_loader = None
        if len(self.test_dataset) > 0:
            test_loader = DataLoader(
                self.test_dataset,
                batch_size=set_batch_size(
                    self.hparams.batch_size, len(self.test_dataset)
                ),
                collate_fn=_collate_aseatoms,
                num_workers=self.hparams.num_workers,
                pin_memory=True,
            )
        return test_loader

    def test_step(self, batch, batch_idx):
        # See validation_step for why grad is enabled via a context manager.
        with torch.enable_grad():
            prediction = self(batch)
        loss = self.test_fn(prediction[self.hparams.label], batch[self.hparams.label])
        return loss

    def test_epoch_end(self, test_step_outputs):
        """Average the per-batch test MAEs and log 'test_loss'."""
        avg_loss = torch.stack(test_step_outputs).mean()
        self.log("test_loss", avg_loss)

    def configure_optimizers(self):
        """Adam optimizer plus a ReduceLROnPlateau schedule on 'val_loss',
        checked once per epoch."""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.hparams.lr)
        scheduler = ReduceLROnPlateau(
            optimizer,
            "min",
            factor=self.hparams.lr_factor,
            patience=self.hparams.lr_patience,
            min_lr=self.hparams.lr_min,
        )
        lr_scheduler = {
            'scheduler': scheduler,
            'monitor': 'val_loss',
            'interval': 'epoch',
            'frequency': 1,
        }
        return [optimizer], [lr_scheduler]
| 230 | + |
def main():
    """Entry point: parse the configuration, build the Lightning model,
    train it, then evaluate on the test split."""
    from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

    args = get_args()

    # Seed CPU and every CUDA device before the model is constructed so the
    # randomly initialised weights are reproducible.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    lightning_model = LNNP(args)

    checkpoints = ModelCheckpoint(
        filepath=args.log_dir,
        monitor="val_loss",
        save_top_k=8,
        period=args.eval_interval,
    )
    tb_logger = pl.loggers.TensorBoardLogger(args.log_dir)
    lr_monitor = LearningRateMonitor(logging_interval='epoch')

    trainer_kwargs = dict(
        gpus=args.gpus,
        max_epochs=args.num_epochs,
        distributed_backend=args.distributed_backend,
        num_nodes=args.num_nodes,
        default_root_dir=args.log_dir,
        auto_lr_find=False,
        resume_from_checkpoint=args.load_model,
        checkpoint_callback=checkpoints,
        callbacks=[lr_monitor],
        logger=tb_logger,
        reload_dataloaders_every_epoch=False,
    )
    trainer = pl.Trainer(**trainer_kwargs)

    trainer.fit(lightning_model)

    # Run the test set once the fit has completed.
    trainer.test()

    # logs = LogWriter(args.log_dir,keys=('epoch','train_loss','val_loss','test_mae','lr'))
| 267 | + |
| 268 | + |
| 269 | +# logs.write_row({'epoch':epoch,'train_loss':train_loss,'val_loss':val_loss, |
| 270 | +# 'test_mae':test_mae, 'lr':optimizer.param_groups[0]['lr']}) |
| 271 | +# progress.set_postfix({'Loss': train_loss, 'lr':optimizer.param_groups[0]['lr']}) |
| 272 | + |
| 273 | +# if optimizer.param_groups[0]['lr'] < args.lr_min: |
| 274 | +# print("Early stop reached") |
| 275 | +# break |
| 276 | + |
| 277 | + |
| 278 | +if __name__ == "__main__": |
| 279 | + main() |
0 commit comments