Skip to content

Commit 166b7db

Browse files
authored
Use switching function for Coulomb prior (#287)
* Use switching function for Coulomb prior * Updated documentation
1 parent 5a4fca7 commit 166b7db

File tree

4 files changed

+30
-17
lines changed

4 files changed

+30
-17
lines changed

docs/source/priors.rst

+6-5
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,12 @@ It is possible to configure more than one prior in this way:
3131

3232
.. code:: yaml
3333
34-
prior_model:
35-
Atomref: {} # No additional arguments
36-
Coulomb:
37-
alpha: 1
38-
max_num_neighbors: 10
34+
prior_model:
35+
Atomref: {} # No additional arguments
36+
Coulomb:
37+
lower_switch_distance: 4
38+
upper_switch_distance: 8
39+
max_num_neighbors: 128
3940
4041
4142

tests/test_priors.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -88,19 +88,25 @@ def test_coulomb(dtype):
8888
types = torch.tensor([0, 1, 2, 1], dtype=torch.long) # Atom types
8989
distance_scale = 1e-9 # Convert nm to meters
9090
energy_scale = 1000.0/6.02214076e23 # Convert kJ/mol to Joules
91-
alpha = 1.8
91+
lower_switch_distance = 0.9
92+
upper_switch_distance = 1.3
9293

9394
# Use the Coulomb class to compute the energy.
9495

95-
coulomb = Coulomb(alpha, 5, distance_scale=distance_scale, energy_scale=energy_scale)
96+
coulomb = Coulomb(lower_switch_distance, upper_switch_distance, 5, distance_scale=distance_scale, energy_scale=energy_scale)
9697
energy = coulomb.post_reduce(torch.zeros((1,)), types, pos, torch.zeros_like(types), extra_args={'partial_charges':charge})[0]
9798

9899
# Compare to the expected value.
99100

100101
def compute_interaction(pos1, pos2, z1, z2):
101102
delta = pos1-pos2
102103
r = torch.sqrt(torch.dot(delta, delta))
103-
return torch.erf(alpha*r)*138.935*z1*z2/r
104+
if r < lower_switch_distance:
105+
return 0
106+
energy = 138.935*z1*z2/r
107+
if r < upper_switch_distance:
108+
energy *= 0.5-0.5*torch.cos(torch.pi*(r-lower_switch_distance)/(upper_switch_distance-lower_switch_distance))
109+
return energy
104110

105111
expected = 0
106112
for i in range(len(pos)):

torchmdnet/priors/coulomb.py

+14-8
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@
88
from typing import Optional, Dict
99

1010
class Coulomb(BasePrior):
11-
"""This class implements a Coulomb potential, scaled by :math:`\\textrm{erf}(\\textrm{alpha}*r)` to reduce its
11+
"""This class implements a Coulomb potential, scaled by a cosine switching function to reduce its
1212
effect at short distances.
1313
1414
Parameters
1515
----------
16-
alpha : float
17-
Scaling factor for the error function.
16+
lower_switch_distance : float
17+
distance below which the interaction strength is zero.
18+
upper_switch_distance : float
19+
distance above which the interaction has full strength
1820
max_num_neighbors : int
1921
Maximum number of neighbors per atom allowed.
2022
distance_scale : float, optional
@@ -31,20 +33,22 @@ class Coulomb(BasePrior):
3133
The Dataset used with this class must include a `partial_charges` field for each sample, and provide
3234
`distance_scale` and `energy_scale` attributes if they are not explicitly passed as arguments.
3335
"""
34-
def __init__(self, alpha, max_num_neighbors, distance_scale=None, energy_scale=None, box_vecs=None, dataset=None):
36+
def __init__(self, lower_switch_distance, upper_switch_distance, max_num_neighbors, distance_scale=None, energy_scale=None, box_vecs=None, dataset=None):
3537
super(Coulomb, self).__init__()
3638
if distance_scale is None:
3739
distance_scale = dataset.distance_scale
3840
if energy_scale is None:
3941
energy_scale = dataset.energy_scale
4042
self.distance = OptimizedDistance(0, torch.inf, max_num_pairs=-max_num_neighbors)
41-
self.alpha = alpha
43+
self.lower_switch_distance = lower_switch_distance
44+
self.upper_switch_distance = upper_switch_distance
4245
self.max_num_neighbors = max_num_neighbors
4346
self.distance_scale = float(distance_scale)
4447
self.energy_scale = float(energy_scale)
4548
self.initial_box = box_vecs
4649
def get_init_args(self):
47-
return {'alpha': self.alpha,
50+
return {'lower_switch_distance': self.lower_switch_distance,
51+
'upper_switch_distance': self.upper_switch_distance,
4852
'max_num_neighbors': self.max_num_neighbors,
4953
'distance_scale': self.distance_scale,
5054
'energy_scale': self.energy_scale,
@@ -78,14 +82,16 @@ def post_reduce(self, y, z, pos, batch, box: Optional[torch.Tensor] = None, extr
7882
"""
7983
# Convert to nm and calculate distance.
8084
x = 1e9*self.distance_scale*pos
81-
alpha = self.alpha/(1e9*self.distance_scale)
8285
box = box if box is not None else self.initial_box
8386
edge_index, distance, _ = self.distance(x, batch, box=box)
8487

8588
# Compute the energy, converting to the dataset's units. Multiply by 0.5 because every atom pair
8689
# appears twice.
8790
q = extra_args['partial_charges'][edge_index]
88-
energy = torch.erf(alpha*distance)*q[0]*q[1]/distance
91+
lower = torch.tensor(self.lower_switch_distance)
92+
upper = torch.tensor(self.upper_switch_distance)
93+
phase = (torch.max(lower, torch.min(upper, distance))-lower)/(upper-lower)
94+
energy = (0.5-0.5*torch.cos(torch.pi*phase))*q[0]*q[1]/distance
8995
energy = 0.5*(2.30707e-28/self.energy_scale/self.distance_scale)*scatter(energy, batch[edge_index[0]], dim=0, reduce="sum")
9096
energy = energy.reshape(y.shape)
9197
return y + energy

torchmdnet/scripts/train.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def get_argparse():
7474
# model architecture
7575
parser.add_argument('--model', type=str, default='graph-network', choices=models.__all_models__, help='Which model to train')
7676
parser.add_argument('--output-model', type=str, default='Scalar', choices=output_modules.__all__, help='The type of output model')
77-
parser.add_argument('--prior-model', type=str, default=None, help='Which prior model to use. It can be a string, a dict if you want to add arguments for it or a dicts to add more than one prior. e.g. {"Atomref": {"max_z":100}, "Coulomb":{"max_num_neighs"=100, "alpha"=1}', action="extend", nargs="*")
77+
parser.add_argument('--prior-model', type=str, default=None, help='Which prior model to use. It can be a string, a dict if you want to add arguments for it or a dicts to add more than one prior. e.g. {"Atomref": {"max_z":100}, "Coulomb":{"max_num_neighs"=100, "lower_switch_distance"=4, "upper_switch_distance"=8}', action="extend", nargs="*")
7878

7979
# architectural args
8080
parser.add_argument('--charge', type=bool, default=False, help='Model needs a total charge. Set this to True if your dataset contains charges and you want them passed down to the model.')

0 commit comments

Comments
 (0)