You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: docs/source/models.rst
+75
Original file line number
Diff line number
Diff line change
@@ -4,13 +4,87 @@ Neural Network Potentials
4
4
=========================
5
5
6
6
7
+
.. figure:: img/tmdnet_model.*
8
+
:align:center
9
+
10
+
Schematic representation of the TorchMD-Net model.
11
+
12
+
13
+
In TorchMD-Net a model, abstracted by the :py:mod:`torchmdnet.models.model.TorchMD_Net` class, is composed of three main components:
14
+
15
+
1. **A representation model** that takes atomic numbers and positions (and optionally other per-sample or per-atom properties that the particular model might make use of) as input and outputs a series of per-atom features. The representation model is responsible for encoding the local environment of each atom.
16
+
17
+
2. **An output model** that takes the per-atom features and outputs a single per-batch label (i.e. total energy).
18
+
19
+
3. Optionally, :ref:`priors <Priors>` can be used to add additional constraints to the output model. For instance, a prior can be used to add a reference energy to the output of the model, or add a :ref:`Coulomb potential <:py:mod:torchmdnet.priors.Coulomb>` to the output of the model.
20
+
21
+
The resulting model can also be used to compute the negative gradient of the output with respect to the input positions (i.e. forces) via backpropagation with `autograd <https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html>`_. This is done by setting the :code:`derivative` flag to :code:`True` when creating the model.
22
+
23
+
.. hint:: Given the large amount of configuration options available, one typically does not instantiate :py:mod:`torchmdnet.models.model.TorchMD_Net` directly, but uses the :py:mod:`torchmdnet.models.model.create_model` function.
24
+
25
+
.. hint:: It is possible to use the :py:mod:`torchmdnet.models.model.TorchMD_Net` class directly instead of using the :py:mod:`torchmdnet.models.model.create_model` function. This can be useful if, for instance, you want to make use of the default parameters of the representation model and output model.
26
+
27
+
.. code:: python
28
+
29
+
from torchmdnet.models.model import TorchMD_Net
30
+
from torchmdnet.models.tensornet import TensorNet
31
+
from torchmdnet.models.output_modules import Scalar
32
+
33
+
model = TorchMD_Net(
34
+
representation_model=TensorNet(),
35
+
output_model=Scalar(hidden_channels=32),
36
+
)
37
+
38
+
7
39
Training a model
8
40
----------------
9
41
10
42
The typical workflow to obtain a neural network potential in TorchMD-Net starts with :ref:`training <training>` one of the `Available Models`_. During this process you will get a checkpoint file that can be used to load the model for inference.
11
43
44
+
Custom training loops
45
+
~~~~~~~~~~~~~~~~~~~~~
46
+
47
+
If you want to use a custom training loop, you can use the :py:mod:`torchmdnet.models.model.TorchMD_Net` class directly. This class is a wrapper around the representation model, output model and priors that takes care of putting the pieces together.
48
+
In order to do this you need to follow these steps:
49
+
50
+
1. Create a new instance of :py:mod:`torchmdnet.models.model.TorchMD_Net` with the representation model, output model and priors you want to use. We provide the :py:mod:`torchmdnet.models.model.create_model` function to help you with this. Check :ref:`torchmd-train` for a list and description of the parameters.
51
+
52
+
2. Use one of the available :ref:`Datasets` to prepare your data, or create a custom one. You may then use a `pytorch DataLoader <https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#preparing-your-data-for-training-with-dataloaders>`_ to iterate over the data.
53
+
54
+
3. Train the model using a custom training loop. See for instance the `PyTorch tutorial on training loops <https://pytorch.org/tutorials/beginner/introyt/trainingyt.html#optimizer>`_.
12
55
56
+
This is a minimal example of a custom training loop:
y, neg_dy = model(args["z"], args["pos"], args["batch"])
80
+
# A simple loss function that uses the energy and forces to train the model
81
+
loss = criterion(y, args["y"]) + criterion(neg_dy, args["neg_dy"])
82
+
loss.backward()
83
+
optimizer.step()
84
+
85
+
86
+
87
+
14
88
Loading a model for inference
15
89
-----------------------------
16
90
@@ -37,6 +111,7 @@ Once you have trained a model you should have a checkpoint that you can load for
37
111
38
112
.. note:: When periodic boundary conditions are required, modules typically offer the possibility of providing the box vectors at construction and/or as an argument to the forward pass. Check the documentation of the class you are using to see if this is the case.
0 commit comments