
Commit 58f298f

Allow appending extra values to embedding vector
1 parent 166b7db commit 58f298f

File tree

7 files changed: +97 -8 lines changed

    tests/test_model.py (+11)
    torchmdnet/models/model.py (+19 -2)
    torchmdnet/models/tensornet.py (+21 -3)
    torchmdnet/models/torchmd_et.py (+15 -1)
    torchmdnet/models/torchmd_gn.py (+15 -1)
    torchmdnet/models/torchmd_t.py (+15 -1)
    torchmdnet/scripts/train.py (+1)
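For orientation before the per-file diffs, here is a minimal usage sketch of the new option, assembled from the test added below. The field names "atomic" and "global", the "tensornet" model choice, and the helpers load_example_args/create_example_batch are taken from (or assumed to match) the test utilities; this is not a prescribed API.

import torch
from torchmdnet.models.model import create_model
from utils import load_example_args, create_example_batch  # test helpers from tests/

z, pos, batch = create_example_batch()          # 6 atoms grouped into 2 molecules
args = load_example_args("tensornet", prior_model=None)
args["extra_embedding"] = ["atomic", "global"]  # dataset fields to append per atom
model = create_model(args)

# "atomic" holds one value per atom, "global" one value per molecule; per-molecule
# values are expanded to atoms inside forward() via batch indexing.
y, neg_dy = model(z, pos, batch=batch,
                  extra_args={"atomic": torch.rand(6), "global": torch.rand(2)})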

tests/test_model.py

+11

@@ -227,3 +227,14 @@ def test_gradients(model_name):
     torch.autograd.gradcheck(
         model, (z, pos, batch), eps=1e-4, atol=1e-3, rtol=1e-2, nondet_tol=1e-3
     )
+
+
+@mark.parametrize("model_name", models.__all_models__)
+@mark.parametrize("use_batch", [True, False])
+def test_extra_embedding(model_name, use_batch):
+    z, pos, batch = create_example_batch()
+    args = load_example_args(model_name, prior_model=None)
+    args["extra_embedding"] = ["atomic", "global"]
+    model = create_model(args)
+    batch = batch if use_batch else None
+    model(z, pos, batch=batch, extra_args={'atomic': torch.rand(6), 'global': torch.rand(2)})
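A note on the shapes in this test: create_example_batch evidently yields 6 atoms grouped into 2 molecules, so "atomic" supplies one value per atom and "global" one value per molecule. As the model.py hunk below shows, any field whose shape differs from z is expanded per atom by batch indexing; a standalone sketch of that expansion:

import torch

batch = torch.tensor([0, 0, 0, 1, 1, 1])  # molecule index of each of the 6 atoms
glob = torch.rand(2)                      # one value per molecule
per_atom = glob[batch]                    # shape (6,): each molecule's value repeated per atom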

torchmdnet/models/model.py

+19 -2

@@ -38,6 +38,12 @@ def create_model(args, prior_model=None, mean=None, std=None):
         args["static_shapes"] = False
     if "vector_cutoff" not in args:
         args["vector_cutoff"] = False
+    if "extra_embedding" not in args:
+        extra_embedding = None
+    elif isinstance(args["extra_embedding"], str):
+        extra_embedding = [args["extra_embedding"]]
+    else:
+        extra_embedding = args["extra_embedding"]
 
     shared_args = dict(
         hidden_channels=args["embedding_dimension"],
@@ -57,6 +63,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
             else None
         ),
         dtype=dtype,
+        extra_embedding=extra_embedding
     )
 
     # representation network
@@ -370,7 +377,7 @@
                 If this is omitted, periodic boundary conditions are not applied.
             q (Tensor, optional): Atomic charges in the molecule. Shape: (N,).
             s (Tensor, optional): Atomic spins in the molecule. Shape: (N,).
-            extra_args (Dict[str, Tensor], optional): Extra arguments to pass to the prior model.
+            extra_args (Dict[str, Tensor], optional): Extra arguments to pass to the model.
 
         Returns:
             Tuple[Tensor, Optional[Tensor]]: The output of the model and the derivative of the output with respect to the positions if derivative is True, None otherwise.
@@ -380,9 +387,19 @@
 
         if self.derivative:
             pos.requires_grad_(True)
+        if self.representation_model.extra_embedding is None:
+            extra_embedding_args = None
+        else:
+            extra = []
+            for arg in self.representation_model.extra_embedding:
+                t = extra_args[arg]
+                if t.shape != z.shape:
+                    t = t[batch]
+                extra.append(t)
+            extra_embedding_args = tuple(extra)
         # run the potentially wrapped representation model
         x, v, z, pos, batch = self.representation_model(
-            z, pos, batch, box=box, q=q, s=s
+            z, pos, batch, box=box, q=q, s=s, extra_embedding_args=extra_embedding_args
         )
         # apply the output network
         x = self.output_model.pre_reduce(x, v, z, pos, batch)
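Note that create_model normalizes the option so a single field can be given as a plain string; the accepted forms, per the first hunk above:

args["extra_embedding"] = "atomic"              # normalized to ["atomic"]
args["extra_embedding"] = ["atomic", "global"]  # a list is passed through unchanged
# key absent -> extra_embedding = None, and the feature is disabled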

torchmdnet/models/tensornet.py

+21 -3

@@ -120,6 +120,9 @@ class TensorNet(nn.Module):
             (default: :obj:`True`)
         check_errors (bool, optional): Whether to check for errors in the distance module.
             (default: :obj:`True`)
+        extra_embedding (tuple, optional): The names of extra fields to append to the
+            embedding vector for each atom.
+            (default: :obj:`None`)
     """
 
     def __init__(
@@ -139,6 +142,7 @@ def __init__(
         check_errors=True,
         dtype=torch.float32,
         box_vecs=None,
+        extra_embedding=None,
     ):
         super(TensorNet, self).__init__()
 
@@ -163,6 +167,7 @@ def __init__(
         self.activation = activation
         self.cutoff_lower = cutoff_lower
         self.cutoff_upper = cutoff_upper
+        self.extra_embedding = extra_embedding
         act_class = act_class_mapping[activation]
         self.distance_expansion = rbf_class_mapping[rbf_type](
             cutoff_lower, cutoff_upper, num_rbf, trainable_rbf
@@ -176,6 +181,7 @@ def __init__(
             trainable_rbf,
             max_z,
             dtype,
+            extra_embedding,
         )
 
         self.layers = nn.ModuleList()
@@ -228,6 +234,7 @@ def forward(
         box: Optional[Tensor] = None,
         q: Optional[Tensor] = None,
         s: Optional[Tensor] = None,
+        extra_embedding_args: Optional[Tuple[Tensor]] = None,
     ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
         # Obtain graph, with distances and relative position vectors
         edge_index, edge_weight, edge_vec = self.distance(pos, batch, box)
@@ -258,7 +265,7 @@ def forward(
         # Normalizing edge vectors by their length can result in NaNs, breaking Autograd.
         # I avoid dividing by zero by setting the weight of self edges and self loops to 1
         edge_vec = edge_vec / edge_weight.masked_fill(mask, 1).unsqueeze(1)
-        X = self.tensor_embedding(zp, edge_index, edge_weight, edge_vec, edge_attr)
+        X = self.tensor_embedding(zp, edge_index, edge_weight, edge_vec, edge_attr, extra_embedding_args)
         for layer in self.layers:
             X = layer(X, edge_index, edge_weight, edge_attr, q)
         I, A, S = decompose_tensor(X)
@@ -287,6 +294,7 @@ def __init__(
         trainable_rbf=False,
         max_z=128,
         dtype=torch.float32,
+        extra_embedding=None,
     ):
         super(TensorEmbedding, self).__init__()
 
@@ -297,6 +305,10 @@ def __init__(
         self.cutoff = CosineCutoff(cutoff_lower, cutoff_upper)
         self.max_z = max_z
         self.emb = nn.Embedding(max_z, hidden_channels, dtype=dtype)
+        if extra_embedding is not None:
+            self.reshape_embedding = nn.Linear(hidden_channels + len(extra_embedding), hidden_channels, dtype=dtype)
+        else:
+            self.reshape_embedding = None
         self.emb2 = nn.Linear(2 * hidden_channels, hidden_channels, dtype=dtype)
         self.act = activation()
         self.linears_tensor = nn.ModuleList()
@@ -319,15 +331,20 @@ def reset_parameters(self):
         self.distance_proj2.reset_parameters()
         self.distance_proj3.reset_parameters()
         self.emb.reset_parameters()
+        if self.reshape_embedding is not None:
+            self.reshape_embedding.reset_parameters()
         self.emb2.reset_parameters()
         for linear in self.linears_tensor:
             linear.reset_parameters()
         for linear in self.linears_scalar:
             linear.reset_parameters()
         self.init_norm.reset_parameters()
 
-    def _get_atomic_number_message(self, z: Tensor, edge_index: Tensor) -> Tensor:
+    def _get_atomic_number_message(self, z: Tensor, edge_index: Tensor, extra_embedding_args: Optional[Tuple[Tensor]]) -> Tensor:
         Z = self.emb(z)
+        if self.reshape_embedding is not None:
+            Z = torch.cat((Z,) + tuple(t.unsqueeze(1) for t in extra_embedding_args), dim=1)
+            Z = self.reshape_embedding(Z)
         Zij = self.emb2(
             Z.index_select(0, edge_index.t().reshape(-1)).view(
                 -1, self.hidden_channels * 2
@@ -362,8 +379,9 @@ def forward(
         edge_weight: Tensor,
         edge_vec_norm: Tensor,
         edge_attr: Tensor,
+        extra_embedding_args: Optional[Tuple[Tensor]],
     ) -> Tensor:
-        Zij = self._get_atomic_number_message(z, edge_index)
+        Zij = self._get_atomic_number_message(z, edge_index, extra_embedding_args)
         Iij, Aij, Sij = self._get_tensor_messages(
             Zij, edge_weight, edge_vec_norm, edge_attr
         )
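The embedding change itself is compact: each extra value becomes one additional column appended to the learned atom embedding, and a linear layer projects the widened vector back to hidden_channels so downstream layers are untouched. A self-contained sketch with made-up sizes (6 atoms, hidden_channels=8, two extra fields):

import torch
import torch.nn as nn

n_atoms, hidden = 6, 8
emb = torch.randn(n_atoms, hidden)                   # stand-in for the nn.Embedding output
extras = (torch.rand(n_atoms), torch.rand(n_atoms))  # extra fields, already expanded per atom
x = torch.cat((emb,) + tuple(t.unsqueeze(1) for t in extras), dim=1)  # (6, 10)
reshape_embedding = nn.Linear(hidden + len(extras), hidden)
x = reshape_embedding(x)                             # projected back to (6, 8)

TorchMD_ET, TorchMD_GN, and TorchMD_T below apply the same cat-then-project pattern to their scalar embeddings.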

torchmdnet/models/torchmd_et.py

+15 -1

@@ -79,7 +79,9 @@ class TorchMD_ET(nn.Module):
             (default: :obj:`False`)
         check_errors (bool, optional): Whether to check for errors in the distance module.
             (default: :obj:`True`)
-
+        extra_embedding (tuple, optional): The names of extra fields to append to the
+            embedding vector for each atom.
+            (default: :obj:`None`)
     """
 
     def __init__(
@@ -102,6 +104,7 @@ def __init__(
         box_vecs=None,
         vector_cutoff=False,
         dtype=torch.float32,
+        extra_embedding=None,
     ):
         super(TorchMD_ET, self).__init__()
 
@@ -133,10 +136,15 @@ def __init__(
         self.cutoff_upper = cutoff_upper
         self.max_z = max_z
         self.dtype = dtype
+        self.extra_embedding = extra_embedding
 
         act_class = act_class_mapping[activation]
 
         self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype)
+        if extra_embedding is not None:
+            self.reshape_embedding = nn.Linear(hidden_channels + len(extra_embedding), hidden_channels, dtype=dtype)
+        else:
+            self.reshape_embedding = None
 
         self.distance = OptimizedDistance(
             cutoff_lower,
@@ -181,6 +189,8 @@ def __init__(
 
     def reset_parameters(self):
         self.embedding.reset_parameters()
+        if self.reshape_embedding is not None:
+            self.reshape_embedding.reset_parameters()
         self.distance_expansion.reset_parameters()
         if self.neighbor_embedding is not None:
             self.neighbor_embedding.reset_parameters()
@@ -196,8 +206,12 @@ def forward(
         box: Optional[Tensor] = None,
         q: Optional[Tensor] = None,
         s: Optional[Tensor] = None,
+        extra_embedding_args: Optional[Tuple[Tensor]] = None,
     ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
         x = self.embedding(z)
+        if self.reshape_embedding is not None:
+            x = torch.cat((x,) + tuple(t.unsqueeze(1) for t in extra_embedding_args), dim=1)
+            x = self.reshape_embedding(x)
 
         edge_index, edge_weight, edge_vec = self.distance(pos, batch, box)
         # This assert must be here to convince TorchScript that edge_vec is not None

torchmdnet/models/torchmd_gn.py

+15 -1

@@ -86,7 +86,9 @@ class TorchMD_GN(nn.Module):
             (default: :obj:`None`)
         check_errors (bool, optional): Whether to check for errors in the distance module.
             (default: :obj:`True`)
-
+        extra_embedding (tuple, optional): The names of extra fields to append to the
+            embedding vector for each atom.
+            (default: :obj:`None`)
     """
 
     def __init__(
@@ -107,6 +109,7 @@ def __init__(
         aggr="add",
         dtype=torch.float32,
         box_vecs=None,
+        extra_embedding=None,
     ):
         super(TorchMD_GN, self).__init__()
 
@@ -136,10 +139,15 @@ def __init__(
         self.cutoff_upper = cutoff_upper
         self.max_z = max_z
         self.aggr = aggr
+        self.extra_embedding = extra_embedding
 
         act_class = act_class_mapping[activation]
 
         self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype)
+        if extra_embedding is not None:
+            self.reshape_embedding = nn.Linear(hidden_channels + len(extra_embedding), hidden_channels, dtype=dtype)
+        else:
+            self.reshape_embedding = None
 
         self.distance = OptimizedDistance(
             cutoff_lower,
@@ -184,6 +192,8 @@ def __init__(
 
     def reset_parameters(self):
         self.embedding.reset_parameters()
+        if self.reshape_embedding is not None:
+            self.reshape_embedding.reset_parameters()
         self.distance_expansion.reset_parameters()
         if self.neighbor_embedding is not None:
             self.neighbor_embedding.reset_parameters()
@@ -198,8 +208,12 @@ def forward(
         box: Optional[Tensor] = None,
         s: Optional[Tensor] = None,
         q: Optional[Tensor] = None,
+        extra_embedding_args: Optional[Tuple[Tensor]] = None,
     ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
         x = self.embedding(z)
+        if self.reshape_embedding is not None:
+            x = torch.cat((x,) + tuple(t.unsqueeze(1) for t in extra_embedding_args), dim=1)
+            x = self.reshape_embedding(x)
 
         edge_index, edge_weight, _ = self.distance(pos, batch, box)
         edge_attr = self.distance_expansion(edge_weight)

torchmdnet/models/torchmd_t.py

+15 -1

@@ -76,7 +76,9 @@ class TorchMD_T(nn.Module):
             (default: :obj:`None`)
         check_errors (bool, optional): Whether to check for errors in the distance module.
             (default: :obj:`True`)
-
+        extra_embedding (tuple, optional): The names of extra fields to append to the
+            embedding vector for each atom.
+            (default: :obj:`None`)
     """
 
     def __init__(
@@ -98,6 +100,7 @@ def __init__(
         max_num_neighbors=32,
         dtype=torch.float,
         box_vecs=None,
+        extra_embedding=None,
     ):
         super(TorchMD_T, self).__init__()
 
@@ -124,11 +127,16 @@ def __init__(
         self.cutoff_lower = cutoff_lower
         self.cutoff_upper = cutoff_upper
         self.max_z = max_z
+        self.extra_embedding = extra_embedding
 
         act_class = act_class_mapping[activation]
         attn_act_class = act_class_mapping[attn_activation]
 
         self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype)
+        if extra_embedding is not None:
+            self.reshape_embedding = nn.Linear(hidden_channels + len(extra_embedding), hidden_channels, dtype=dtype)
+        else:
+            self.reshape_embedding = None
 
         self.distance = OptimizedDistance(
             cutoff_lower,
@@ -177,6 +185,8 @@ def __init__(
 
     def reset_parameters(self):
         self.embedding.reset_parameters()
+        if self.reshape_embedding is not None:
+            self.reshape_embedding.reset_parameters()
         self.distance_expansion.reset_parameters()
         if self.neighbor_embedding is not None:
             self.neighbor_embedding.reset_parameters()
@@ -192,8 +202,12 @@ def forward(
         box: Optional[Tensor] = None,
         s: Optional[Tensor] = None,
         q: Optional[Tensor] = None,
+        extra_embedding_args: Optional[Tuple[Tensor]] = None,
     ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
         x = self.embedding(z)
+        if self.reshape_embedding is not None:
+            x = torch.cat((x,) + tuple(t.unsqueeze(1) for t in extra_embedding_args), dim=1)
+            x = self.reshape_embedding(x)
 
         edge_index, edge_weight, _ = self.distance(pos, batch, box)
         edge_attr = self.distance_expansion(edge_weight)

torchmdnet/scripts/train.py

+1

@@ -80,6 +80,7 @@ def get_argparse():
     parser.add_argument('--charge', type=bool, default=False, help='Model needs a total charge. Set this to True if your dataset contains charges and you want them passed down to the model.')
     parser.add_argument('--spin', type=bool, default=False, help='Model needs a spin state. Set this to True if your dataset contains spin states and you want them passed down to the model.')
     parser.add_argument('--embedding-dimension', type=int, default=256, help='Embedding dimension')
+    parser.add_argument('--extra-embedding', type=str, default=None, help='Extra fields of the dataset to pass to the model and append to the embedding vector.', action="extend", nargs="*")
     parser.add_argument('--num-layers', type=int, default=6, help='Number of interaction layers in the model')
     parser.add_argument('--num-rbf', type=int, default=64, help='Number of radial basis functions in model')
     parser.add_argument('--activation', type=str, default='silu', choices=list(act_class_mapping.keys()), help='Activation function')
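Because the flag combines action="extend" with nargs="*" (available since Python 3.8), field names can be listed after a single flag or accumulated across repeated flags; a self-contained check of that parsing behavior:

import argparse

p = argparse.ArgumentParser()
p.add_argument('--extra-embedding', type=str, default=None,
               help='Extra fields of the dataset to pass to the model and append to the embedding vector.',
               action="extend", nargs="*")

print(p.parse_args(["--extra-embedding", "atomic", "global"]).extra_embedding)
# -> ['atomic', 'global']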
