Skip to content

Commit 2429c94

Browse files
authored
Add custom training loop to the docs (#298)
Add a more thourough description of the TorchMD_Net model
1 parent 8ca7f60 commit 2429c94

File tree

3 files changed

+79
-0
lines changed

3 files changed

+79
-0
lines changed

docs/source/img/tmdnet_model.pdf

69.4 KB
Binary file not shown.

docs/source/img/tmdnet_model.svg

+4
Loading

docs/source/models.rst

+75
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,87 @@ Neural Network Potentials
44
=========================
55

66

7+
.. figure:: img/tmdnet_model.*
8+
:align: center
9+
10+
Schematic representation of the TorchMD-Net model.
11+
12+
13+
In TorchMD-Net a model, abstracted by the :py:mod:`torchmdnet.models.model.TorchMD_Net` class, is composed of three main components:
14+
15+
1. **A representation model** that takes atomic numbers and positions (and optionally other per-sample or per-atom properties that the particular model might make use of) as input and outputs a series of per-atom features. The representation model is responsible for encoding the local environment of each atom.
16+
17+
2. **An output model** that takes the per-atom features and outputs a single per-batch label (i.e. total energy).
18+
19+
3. Optionally, :ref:`priors <Priors>` can be used to add additional constraints to the output model. For instance, a prior can be used to add a reference energy to the output of the model, or add a :ref:`Coulomb potential <:py:mod:torchmdnet.priors.Coulomb>` to the output of the model.
20+
21+
The resulting model can also be used to compute the negative gradient of the output with respect to the input positions (i.e. forces) via backpropagation with `autograd <https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html>`_. This is done by setting the :code:`derivative` flag to :code:`True` when creating the model.
22+
23+
.. hint:: Given the large amount of configuration options available, one typically does not instantiate :py:mod:`torchmdnet.models.model.TorchMD_Net` directly, but uses the :py:mod:`torchmdnet.models.model.create_model` function.
24+
25+
.. hint:: It is possible to use the :py:mod:`torchmdnet.models.model.TorchMD_Net` class directly instead of using the :py:mod:`torchmdnet.models.model.create_model` function. This can be useful if, for instance, you want to make use of the default parameters of the representation model and output model.
26+
27+
.. code:: python
28+
29+
from torchmdnet.models.model import TorchMD_Net
30+
from torchmdnet.models.tensornet import TensorNet
31+
from torchmdnet.models.output_modules import Scalar
32+
33+
model = TorchMD_Net(
34+
representation_model=TensorNet(),
35+
output_model=Scalar(hidden_channels=32),
36+
)
37+
38+
739
Training a model
840
----------------
941

1042
The typical workflow to obtain a neural network potential in TorchMD-Net starts with :ref:`training <training>` one of the `Available Models`_. During this process you will get a checkpoint file that can be used to load the model for inference.
1143

44+
Custom training loops
45+
~~~~~~~~~~~~~~~~~~~~~
46+
47+
If you want to use a custom training loop, you can use the :py:mod:`torchmdnet.models.model.TorchMD_Net` class directly. This class is a wrapper around the representation model, output model and priors that takes care of putting the pieces together.
48+
In order to do this you need to follow these steps:
49+
50+
1. Create a new instance of :py:mod:`torchmdnet.models.model.TorchMD_Net` with the representation model, output model and priors you want to use. We provide the :py:mod:`torchmdnet.models.model.create_model` function to help you with this. Check :ref:`torchmd-train` for a list and description of the parameters.
51+
52+
2. Use one of the available :ref:`Datasets` to prepare your data, or create a custom one. You may then use a `pytorch DataLoader <https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#preparing-your-data-for-training-with-dataloaders>`_ to iterate over the data.
53+
54+
3. Train the model using a custom training loop. See for instance the `PyTorch tutorial on training loops <https://pytorch.org/tutorials/beginner/introyt/trainingyt.html#optimizer>`_.
1255

56+
This is a minimal example of a custom training loop:
1357

58+
.. code:: python
59+
60+
import torch
61+
from torchmdnet.models.model import create_model
62+
from torchmdnet.datasets import MD17
63+
from torch_geometric.loader import DataLoader
64+
import yaml
65+
66+
args = yaml.load(open("TensorNet-rMD17.yaml"), Loader=yaml.FullLoader)
67+
model = create_model(args, prior_model=None)
68+
dataset = MD17(root="~/data", molecules="revised_aspirin")
69+
if torch.cuda.is_available():
70+
model = model.to("cuda")
71+
dataset = dataset.to("cuda")
72+
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
73+
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
74+
criterion = torch.nn.MSELoss()
75+
for epoch in range(10):
76+
for batch in dataloader:
77+
args = batch.to_dict()
78+
optimizer.zero_grad()
79+
y, neg_dy = model(args["z"], args["pos"], args["batch"])
80+
# A simple loss function that uses the energy and forces to train the model
81+
loss = criterion(y, args["y"]) + criterion(neg_dy, args["neg_dy"])
82+
loss.backward()
83+
optimizer.step()
84+
85+
86+
87+
1488
Loading a model for inference
1589
-----------------------------
1690

@@ -37,6 +111,7 @@ Once you have trained a model you should have a checkpoint that you can load for
37111

38112
.. note:: When periodic boundary conditions are required, modules typically offer the possibility of providing the box vectors at construction and/or as an argument to the forward pass. Check the documentation of the class you are using to see if this is the case.
39113

114+
40115
.. _delta-learning:
41116
Training on relative energies
42117
-----------------------------

0 commit comments

Comments
 (0)