From 935538c0896f99cb6fba3be9ab60686d4841bb3c Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 8 May 2024 18:10:41 +0200 Subject: [PATCH 1/8] Some optimizations to TensorNet --- torchmdnet/models/tensornet.py | 219 ++++++++++++++++++++------------- 1 file changed, 134 insertions(+), 85 deletions(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index a2006c9e..07bde31e 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -173,7 +173,6 @@ def __init__( act_class, cutoff_lower, cutoff_upper, - trainable_rbf, max_z, dtype, ) @@ -220,6 +219,43 @@ def reset_parameters(self): self.linear.reset_parameters() self.out_norm.reset_parameters() + def _make_static( + self, num_nodes: int, edge_index: Tensor, edge_weight: Tensor, edge_vec: Tensor + ) -> Tuple[Tensor, Tensor, Tensor]: + # Distance module returns -1 for non-existing edges, to avoid having to resize the tensors when we want to ensure static shapes (for CUDA graphs) we make all non-existing edges pertain to a ghost atom + if self.static_shapes: + mask = (edge_index[0] < 0).unsqueeze(0).expand_as(edge_index) + # I trick the model into thinking that the masked edges pertain to the extra atom + # WARNING: This can hurt performance if max_num_pairs >> actual_num_pairs + edge_index = edge_index.masked_fill(mask, num_nodes) + edge_weight = edge_weight.masked_fill(mask[0], 0) + edge_vec = edge_vec.masked_fill( + mask[0].unsqueeze(-1).expand_as(edge_vec), 0 + ) + return edge_index, edge_weight, edge_vec + + def _compute_neighbors( + self, pos: Tensor, batch: Tensor, box: Optional[Tensor] + ) -> Tuple[Tensor, Tensor, Tensor]: + edge_index, edge_weight, edge_vec = self.distance(pos, batch, box) + # This assert convinces TorchScript that edge_vec is a Tensor and not an Optional[Tensor] + assert ( + edge_vec is not None + ), "Distance module did not return directional information" + edge_index, edge_weight, edge_vec = self._make_static( + pos.shape[0], edge_index, edge_weight, edge_vec + ) + return edge_index, edge_weight, edge_vec + + def output(self, X: Tensor) -> Tensor: + I, A, S = decompose_tensor(X) # shape: (n_atoms, hidden_channels, 3, 3) + x = torch.cat( + (tensor_norm(I), tensor_norm(A), tensor_norm(S)), dim=-1 + ) # shape: (n_atoms, 3*hidden_channels) + x = self.out_norm(x) # shape: (n_atoms, 3*hidden_channels) + x = self.act(self.linear((x))) # shape: (n_atoms, hidden_channels) + return x + def forward( self, z: Tensor, @@ -229,45 +265,27 @@ def forward( q: Optional[Tensor] = None, s: Optional[Tensor] = None, ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]: - # Obtain graph, with distances and relative position vectors - edge_index, edge_weight, edge_vec = self.distance(pos, batch, box) - # This assert convinces TorchScript that edge_vec is a Tensor and not an Optional[Tensor] - assert ( - edge_vec is not None - ), "Distance module did not return directional information" - # Distance module returns -1 for non-existing edges, to avoid having to resize the tensors when we want to ensure static shapes (for CUDA graphs) we make all non-existing edges pertain to a ghost atom + if self.static_shapes: + z = torch.cat((z, torch.zeros(1, device=z.device, dtype=z.dtype)), dim=0) # Total charge q is a molecule-wise property. 
We transform it into an atom-wise property, with all atoms belonging to the same molecule being assigned the same charge q if q is None: q = torch.zeros_like(z, device=z.device, dtype=z.dtype) else: q = q[batch] - zp = z - if self.static_shapes: - mask = (edge_index[0] < 0).unsqueeze(0).expand_as(edge_index) - zp = torch.cat((z, torch.zeros(1, device=z.device, dtype=z.dtype)), dim=0) - q = torch.cat((q, torch.zeros(1, device=q.device, dtype=q.dtype)), dim=0) - # I trick the model into thinking that the masked edges pertain to the extra atom - # WARNING: This can hurt performance if max_num_pairs >> actual_num_pairs - edge_index = edge_index.masked_fill(mask, z.shape[0]) - edge_weight = edge_weight.masked_fill(mask[0], 0) - edge_vec = edge_vec.masked_fill( - mask[0].unsqueeze(-1).expand_as(edge_vec), 0 - ) - edge_attr = self.distance_expansion(edge_weight) - mask = edge_index[0] == edge_index[1] - # Normalizing edge vectors by their length can result in NaNs, breaking Autograd. - # I avoid dividing by zero by setting the weight of self edges and self loops to 1 - edge_vec = edge_vec / edge_weight.masked_fill(mask, 1).unsqueeze(1) - X = self.tensor_embedding(zp, edge_index, edge_weight, edge_vec, edge_attr) + edge_index, edge_weight, edge_vec = self._compute_neighbors(pos, batch, box) + edge_attr = self.distance_expansion(edge_weight) # shape: (n_edges, num_rbf) + X = self.tensor_embedding( + z, edge_index, edge_weight, edge_vec, edge_attr + ) # shape: (n_atoms, hidden_channels, 3, 3) for layer in self.layers: - X = layer(X, edge_index, edge_weight, edge_attr, q) - I, A, S = decompose_tensor(X) - x = torch.cat((tensor_norm(I), tensor_norm(A), tensor_norm(S)), dim=-1) - x = self.out_norm(x) - x = self.act(self.linear((x))) - # # Remove the extra atom + X = layer( + X, edge_index, edge_weight, edge_attr, q + ) # shape: (n_atoms, hidden_channels, 3, 3) + x = self.output(X) # shape: (n_atoms, hidden_channels) + # Remove the extra atom if self.static_shapes: x = x[:-1] + z = z[:-1] return x, None, z, pos, batch @@ -284,7 +302,6 @@ def __init__( activation, cutoff_lower, cutoff_upper, - trainable_rbf=False, max_z=128, dtype=torch.float32, ): @@ -326,7 +343,16 @@ def reset_parameters(self): linear.reset_parameters() self.init_norm.reset_parameters() - def _get_atomic_number_message(self, z: Tensor, edge_index: Tensor) -> Tensor: + def _normalize_edges( + self, edge_index: Tensor, edge_weight: Tensor, edge_vec: Tensor + ) -> Tensor: + mask = edge_index[0] == edge_index[1] + # Normalizing edge vectors by their length can result in NaNs, breaking Autograd. 
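The ghost-atom padding in `_make_static` and the masked division in `_normalize_edges` work together: the first keeps tensor shapes fixed so the model can be captured in a CUDA graph, the second keeps the normalization NaN-free. A minimal standalone sketch of both tricks (the helper names `pad_to_static` and `safe_normalize` are illustrative, not part of the codebase):

```python
import torch

def pad_to_static(num_nodes, edge_index, edge_weight, edge_vec):
    # Non-existing edges are reported as -1; redirect them to a ghost atom
    # (index num_nodes) so shapes stay fixed across calls, as CUDA graphs require.
    mask = (edge_index[0] < 0).unsqueeze(0).expand_as(edge_index)
    edge_index = edge_index.masked_fill(mask, num_nodes)
    edge_weight = edge_weight.masked_fill(mask[0], 0)
    edge_vec = edge_vec.masked_fill(mask[0].unsqueeze(-1).expand_as(edge_vec), 0)
    return edge_index, edge_weight, edge_vec

def safe_normalize(edge_index, edge_weight, edge_vec):
    # Self edges and ghost edges have zero length; dividing by zero would
    # produce NaNs that break autograd, so their weight is replaced by 1.
    mask = edge_index[0] == edge_index[1]
    return edge_vec / edge_weight.masked_fill(mask, 1).unsqueeze(1)

num_nodes, max_num_pairs = 3, 6
edge_index = torch.tensor([[0, 1, -1, -1, -1, -1],
                           [1, 0, -1, -1, -1, -1]])
edge_vec = torch.zeros(max_num_pairs, 3)
edge_vec[:2] = torch.randn(2, 3)
edge_weight = edge_vec.norm(dim=-1)

edge_index, edge_weight, edge_vec = pad_to_static(
    num_nodes, edge_index, edge_weight, edge_vec
)
assert edge_index.max().item() == num_nodes  # all padded edges hit the ghost atom
assert not torch.isnan(safe_normalize(edge_index, edge_weight, edge_vec)).any()
```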
+ # I avoid dividing by zero by setting the weight of self edges and self loops to 1 + edge_vec = edge_vec / edge_weight.masked_fill(mask, 1).unsqueeze(1) + return edge_vec + + def _compute_edge_atomic_features(self, z: Tensor, edge_index: Tensor) -> Tensor: Z = self.emb(z) Zij = self.emb2( Z.index_select(0, edge_index.t().reshape(-1)).view( @@ -335,48 +361,66 @@ def _get_atomic_number_message(self, z: Tensor, edge_index: Tensor) -> Tensor: )[..., None, None] return Zij - def _get_tensor_messages( - self, Zij: Tensor, edge_weight: Tensor, edge_vec_norm: Tensor, edge_attr: Tensor - ) -> Tuple[Tensor, Tensor, Tensor]: + def _compute_edge_features( + self, + z: Tensor, + edge_index, + edge_weight: Tensor, + edge_vec: Tensor, + edge_attr: Tensor, + ) -> Tensor: + edge_vec_norm = self._normalize_edges( + edge_index, edge_weight, edge_vec + ) # shape: (n_edges, 3) + Zij = self._compute_edge_atomic_features( + z, edge_index + ) # shape: (n_edges, hidden_channels) C = self.cutoff(edge_weight).reshape(-1, 1, 1, 1) * Zij - eye = torch.eye(3, 3, device=edge_vec_norm.device, dtype=edge_vec_norm.dtype)[ - None, None, ... - ] - Iij = self.distance_proj1(edge_attr)[..., None, None] * C * eye + Iij = self.distance_proj1(edge_attr) Aij = ( self.distance_proj2(edge_attr)[..., None, None] - * C * vector_to_skewtensor(edge_vec_norm)[..., None, :, :] ) Sij = ( self.distance_proj3(edge_attr)[..., None, None] - * C * vector_to_symtensor(edge_vec_norm)[..., None, :, :] ) - return Iij, Aij, Sij + features = Aij + Sij + features.diagonal(dim1=-2, dim2=-1).add_(Iij.unsqueeze(-1)) + return features * C + + def _message_passing( + self, + num_atoms: int, + edge_index: Tensor, + Xij: Tensor, + ) -> Tensor: + source = torch.zeros( + num_atoms, self.hidden_channels, 3, 3, device=Xij.device, dtype=Xij.dtype + ) + X = source.index_add(dim=0, index=edge_index[0], source=Xij) + return X def forward( self, z: Tensor, edge_index: Tensor, edge_weight: Tensor, - edge_vec_norm: Tensor, + edge_vec: Tensor, edge_attr: Tensor, ) -> Tensor: - Zij = self._get_atomic_number_message(z, edge_index) - Iij, Aij, Sij = self._get_tensor_messages( - Zij, edge_weight, edge_vec_norm, edge_attr - ) - source = torch.zeros( - z.shape[0], self.hidden_channels, 3, 3, device=z.device, dtype=Iij.dtype - ) - I = source.index_add(dim=0, index=edge_index[0], source=Iij) - A = source.index_add(dim=0, index=edge_index[0], source=Aij) - S = source.index_add(dim=0, index=edge_index[0], source=Sij) - norm = self.init_norm(tensor_norm(I + A + S)) + Xij = self._compute_edge_features( + z, edge_index, edge_weight, edge_vec, edge_attr + ) # shape: (n_edges, hidden_channels, 3, 3) + X = self._message_passing( + z.shape[0], edge_index, Xij + ) # shape: (n_atoms, hidden_channels, 3, 3) + + norm = self.init_norm(tensor_norm(X)) # shape: (n_atoms, hidden_channels) for linear_scalar in self.linears_scalar: norm = self.act(linear_scalar(norm)) norm = norm.reshape(-1, self.hidden_channels, 3) + I, A, S = decompose_tensor(X) # shape: (n_atoms, hidden_channels, 3, 3) I = ( self.linears_tensor[0](I.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * norm[..., 0, None, None] @@ -389,21 +433,29 @@ def forward( self.linears_tensor[2](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * norm[..., 2, None, None] ) - X = I + A + S + X = I + A + S # shape: (n_atoms, hidden_channels, 3, 3) return X -def tensor_message_passing( - edge_index: Tensor, factor: Tensor, tensor: Tensor, natoms: int -) -> Tensor: +def tensor_message_passing(edge_index: Tensor, tensor: Tensor, natoms: int) -> Tensor: 
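The fused construction in `_compute_edge_features` above no longer materializes an explicit `eye` tensor; it builds `Aij + Sij` and writes the isotropic coefficient straight onto the diagonal in place. A quick equivalence check against the old `Iij * eye + Aij + Sij` formulation, with random stand-ins for the projected edge attributes:

```python
import torch

n_edges, hidden = 10, 4
Iij = torch.randn(n_edges, hidden)         # isotropic coefficient per channel
Aij = torch.randn(n_edges, hidden, 3, 3)
Aij = 0.5 * (Aij - Aij.transpose(-2, -1))  # skew-symmetric component
Sij = torch.randn(n_edges, hidden, 3, 3)
Sij = 0.5 * (Sij + Sij.transpose(-2, -1))
Sij = Sij - Sij.diagonal(dim1=-2, dim2=-1).mean(-1)[..., None, None] * torch.eye(3)

# Old form: broadcast against an explicit identity matrix.
eye = torch.eye(3)[None, None]
old = Iij[..., None, None] * eye + Aij + Sij

# New form: in-place add on the diagonal, no (n_edges, hidden, 3, 3) eye needed.
new = Aij + Sij
new.diagonal(dim1=-2, dim2=-1).add_(Iij.unsqueeze(-1))
assert torch.allclose(old, new)
```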
"""Message passing for tensors.""" - msg = factor * tensor.index_select(0, edge_index[1]) + msg = tensor.index_select(0, edge_index[1]) shape = (natoms, tensor.shape[1], tensor.shape[2], tensor.shape[3]) tensor_m = torch.zeros(*shape, device=tensor.device, dtype=tensor.dtype) tensor_m = tensor_m.index_add(0, edge_index[0], msg) return tensor_m +def compute_tensor_edge_features(X, edge_index, factor): + I, A, S = decompose_tensor(X) + msg = ( + factor[..., 0, None, None] * I.index_select(0, edge_index[1]) + + factor[..., 1, None, None] * A.index_select(0, edge_index[1]) + + factor[..., 2, None, None] * S.index_select(0, edge_index[1]) + ) + return msg + + class Interaction(nn.Module): """Interaction layer. @@ -450,6 +502,22 @@ def reset_parameters(self): for linear in self.linears_tensor: linear.reset_parameters() + def update_tensor_features(self, X, X_aggregated): + B = torch.matmul(X, X_aggregated) + if self.equivariance_invariance_group == "O(3)": + A = torch.matmul(X_aggregated, X) + elif self.equivariance_invariance_group == "SO(3)": + A = B + else: + raise ValueError("Unknown equivariance group") + Xnew = A + B + I, A, S = decompose_tensor(Xnew / (tensor_norm(Xnew) + 1)[..., None, None]) + I = self.linears_tensor[3](I.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + A = self.linears_tensor[4](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + S = self.linears_tensor[5](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + dX = I + A + S + return dX + def forward( self, X: Tensor, @@ -470,28 +538,9 @@ def forward( A = self.linears_tensor[1](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) S = self.linears_tensor[2](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) Y = I + A + S - Im = tensor_message_passing( - edge_index, edge_attr[..., 0, None, None], I, X.shape[0] - ) - Am = tensor_message_passing( - edge_index, edge_attr[..., 1, None, None], A, X.shape[0] - ) - Sm = tensor_message_passing( - edge_index, edge_attr[..., 2, None, None], S, X.shape[0] - ) - msg = Im + Am + Sm - if self.equivariance_invariance_group == "O(3)": - A = torch.matmul(msg, Y) - B = torch.matmul(Y, msg) - I, A, S = decompose_tensor((1 + 0.1 * q[..., None, None, None]) * (A + B)) - if self.equivariance_invariance_group == "SO(3)": - B = torch.matmul(Y, msg) - I, A, S = decompose_tensor(2 * B) - normp1 = (tensor_norm(I + A + S) + 1)[..., None, None] - I, A, S = I / normp1, A / normp1, S / normp1 - I = self.linears_tensor[3](I.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - A = self.linears_tensor[4](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - S = self.linears_tensor[5](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - dX = I + A + S - X = X + dX + (1 + 0.1 * q[..., None, None, None]) * torch.matrix_power(dX, 2) + Y_edges = compute_tensor_edge_features(Y, edge_index, edge_attr) + Ynew = tensor_message_passing(edge_index, Y_edges, X.shape[0]) + charge_factor = 1 + 0.1 * q[..., None, None, None] + dX = self.update_tensor_features(Y, Ynew) * charge_factor + X = X + dX + charge_factor * torch.matrix_power(dX, 2) return X From afb22e1ffd1a3d8593a6c9318549c16e8a1d8b69 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 8 May 2024 19:05:30 +0200 Subject: [PATCH 2/8] Moving some code around --- torchmdnet/models/tensornet.py | 138 +++++++++++++++++++-------------- 1 file changed, 81 insertions(+), 57 deletions(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index 07bde31e..66ec90a0 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -389,37 +389,21 @@ def _compute_edge_features( 
features.diagonal(dim1=-2, dim2=-1).add_(Iij.unsqueeze(-1)) return features * C - def _message_passing( - self, - num_atoms: int, - edge_index: Tensor, - Xij: Tensor, + def _aggregate_edge_features( + self, num_atoms: int, X: Tensor, edge_index: Tensor ) -> Tensor: - source = torch.zeros( - num_atoms, self.hidden_channels, 3, 3, device=Xij.device, dtype=Xij.dtype + Xij = torch.zeros( + num_atoms, + self.hidden_channels, + 3, + 3, + device=X.device, + dtype=X.dtype, ) - X = source.index_add(dim=0, index=edge_index[0], source=Xij) - return X + Xij = Xij.index_add(0, edge_index[0], source=X) + return Xij - def forward( - self, - z: Tensor, - edge_index: Tensor, - edge_weight: Tensor, - edge_vec: Tensor, - edge_attr: Tensor, - ) -> Tensor: - Xij = self._compute_edge_features( - z, edge_index, edge_weight, edge_vec, edge_attr - ) # shape: (n_edges, hidden_channels, 3, 3) - X = self._message_passing( - z.shape[0], edge_index, Xij - ) # shape: (n_atoms, hidden_channels, 3, 3) - - norm = self.init_norm(tensor_norm(X)) # shape: (n_atoms, hidden_channels) - for linear_scalar in self.linears_scalar: - norm = self.act(linear_scalar(norm)) - norm = norm.reshape(-1, self.hidden_channels, 3) + def _tensor_linear(self, X, norm): I, A, S = decompose_tensor(X) # shape: (n_atoms, hidden_channels, 3, 3) I = ( self.linears_tensor[0](I.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) @@ -436,14 +420,52 @@ def forward( X = I + A + S # shape: (n_atoms, hidden_channels, 3, 3) return X + def _norm_mlp(self, norm): + norm = self.init_norm(norm) + for linear_scalar in self.linears_scalar: + norm = self.act(linear_scalar(norm)) + norm = norm.reshape(-1, self.hidden_channels, 3) + return norm + + def forward( + self, + z: Tensor, + edge_index: Tensor, + edge_weight: Tensor, + edge_vec: Tensor, + edge_attr: Tensor, + ) -> Tensor: + Xij = self._compute_edge_features( + z, edge_index, edge_weight, edge_vec, edge_attr + ) # shape: (n_edges, hidden_channels, 3, 3) + X = self._aggregate_edge_features( + z.shape[0], Xij, edge_index + ) # shape: (n_atoms, hidden_channels, 3, 3) + norm = self._norm_mlp(tensor_norm(X)) # shape: (n_atoms, hidden_channels) + X = self._tensor_linear(X, norm) # shape: (n_atoms, hidden_channels, 3, 3) + return X -def tensor_message_passing(edge_index: Tensor, tensor: Tensor, natoms: int) -> Tensor: - """Message passing for tensors.""" - msg = tensor.index_select(0, edge_index[1]) - shape = (natoms, tensor.shape[1], tensor.shape[2], tensor.shape[3]) - tensor_m = torch.zeros(*shape, device=tensor.device, dtype=tensor.dtype) - tensor_m = tensor_m.index_add(0, edge_index[0], msg) - return tensor_m + +class TensorLinear(nn.Module): + + def __init__(self, hidden_channels): + super(TensorLinear, self).__init__() + self.linearI = nn.Linear(hidden_channels, hidden_channels, bias=False) + self.linearA = nn.Linear(hidden_channels, hidden_channels, bias=False) + self.linearS = nn.Linear(hidden_channels, hidden_channels, bias=False) + + def reset_parameters(self): + self.linearI.reset_parameters() + self.linearA.reset_parameters() + self.linearS.reset_parameters() + + def forward(self, X): + I, A, S = decompose_tensor(X) + I = self.linearI(I.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + A = self.linearA(A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + S = self.linearS(S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + dX = I + A + S + return dX def compute_tensor_edge_features(X, edge_index, factor): @@ -456,6 +478,19 @@ def compute_tensor_edge_features(X, edge_index, factor): return msg +def tensor_message_passing(n_atoms: int, 
edge_index: Tensor, tensor: Tensor) -> Tensor: + msg = tensor.index_select( + 0, edge_index[1] + ) # shape = (n_edges, hidden_channels, 3, 3) + tensor_m = torch.zeros( + (n_atoms, tensor.shape[1], tensor.shape[2], tensor.shape[3]), + device=tensor.device, + dtype=tensor.dtype, + ) + tensor_m = tensor_m.index_add(0, edge_index[0], msg) + return tensor_m # shape = (n_atoms, hidden_channels, 3, 3) + + class Interaction(nn.Module): """Interaction layer. @@ -487,11 +522,8 @@ def __init__( self.linears_scalar.append( nn.Linear(2 * hidden_channels, 3 * hidden_channels, bias=True, dtype=dtype) ) - self.linears_tensor = nn.ModuleList() - for _ in range(6): - self.linears_tensor.append( - nn.Linear(hidden_channels, hidden_channels, bias=False) - ) + self.tensor_linear_in = TensorLinear(hidden_channels) + self.tensor_linear_out = TensorLinear(hidden_channels) self.act = activation() self.equivariance_invariance_group = equivariance_invariance_group self.reset_parameters() @@ -499,10 +531,10 @@ def __init__( def reset_parameters(self): for linear in self.linears_scalar: linear.reset_parameters() - for linear in self.linears_tensor: - linear.reset_parameters() + self.tensor_linear_in.reset_parameters() + self.tensor_linear_out.reset_parameters() - def update_tensor_features(self, X, X_aggregated): + def _update_tensor_node_features(self, X, X_aggregated): B = torch.matmul(X, X_aggregated) if self.equivariance_invariance_group == "O(3)": A = torch.matmul(X_aggregated, X) @@ -511,12 +543,7 @@ def update_tensor_features(self, X, X_aggregated): else: raise ValueError("Unknown equivariance group") Xnew = A + B - I, A, S = decompose_tensor(Xnew / (tensor_norm(Xnew) + 1)[..., None, None]) - I = self.linears_tensor[3](I.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - A = self.linears_tensor[4](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - S = self.linears_tensor[5](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - dX = I + A + S - return dX + return Xnew def forward( self, @@ -533,14 +560,11 @@ def forward( edge_attr.shape[0], self.hidden_channels, 3 ) X = X / (tensor_norm(X) + 1)[..., None, None] - I, A, S = decompose_tensor(X) - I = self.linears_tensor[0](I.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - A = self.linears_tensor[1](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - S = self.linears_tensor[2](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - Y = I + A + S - Y_edges = compute_tensor_edge_features(Y, edge_index, edge_attr) - Ynew = tensor_message_passing(edge_index, Y_edges, X.shape[0]) + Y = self.tensor_linear_in(X) + Y_edges = compute_tensor_edge_features(X, edge_index, edge_attr) + Ynew = tensor_message_passing(X.shape[0], edge_index, Y_edges) charge_factor = 1 + 0.1 * q[..., None, None, None] - dX = self.update_tensor_features(Y, Ynew) * charge_factor + Xnew = self._update_tensor_node_features(Y, Ynew) * charge_factor + dX = self.tensor_linear_out(Xnew / (tensor_norm(Xnew) + 1)[..., None, None]) X = X + dX + charge_factor * torch.matrix_power(dX, 2) return X From 6ea18d180a84a303d5c06e509ac0014d7223726c Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Thu, 9 May 2024 11:42:16 +0200 Subject: [PATCH 3/8] More updates --- torchmdnet/models/tensornet.py | 140 ++++++++++++++++++--------------- 1 file changed, 76 insertions(+), 64 deletions(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index 66ec90a0..bad9203b 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -10,6 +10,7 @@ OptimizedDistance, rbf_class_mapping, act_class_mapping, + MLP, ) __all__ = 
["TensorNet"] @@ -289,6 +290,43 @@ def forward( return x, None, z, pos, batch +class TensorLinear(nn.Module): + + def __init__(self, in_channels, out_channels, dtype=torch.float32): + super(TensorLinear, self).__init__() + self.linearI = nn.Linear(in_channels, out_channels, bias=False, dtype=dtype) + self.linearA = nn.Linear(in_channels, out_channels, bias=False, dtype=dtype) + self.linearS = nn.Linear(in_channels, out_channels, bias=False, dtype=dtype) + + def reset_parameters(self): + self.linearI.reset_parameters() + self.linearA.reset_parameters() + self.linearS.reset_parameters() + + def forward(self, X: Tensor, factor: Optional[Tensor] = None) -> Tensor: + if factor is None: + factor = ( + torch.ones(1, device=X.device, dtype=X.dtype) + .unsqueeze(-1) + .unsqueeze(-1) + ).expand(-1, -1, 3) + I, A, S = decompose_tensor(X) + I = ( + self.linearI(I.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + * factor[..., 0, None, None] + ) + A = ( + self.linearA(A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + * factor[..., 1, None, None] + ) + S = ( + self.linearS(S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + * factor[..., 2, None, None] + ) + dX = I + A + S + return dX + + class TensorEmbedding(nn.Module): """Tensor embedding layer. @@ -316,11 +354,7 @@ def __init__( self.emb = nn.Embedding(max_z, hidden_channels, dtype=dtype) self.emb2 = nn.Linear(2 * hidden_channels, hidden_channels, dtype=dtype) self.act = activation() - self.linears_tensor = nn.ModuleList() - for _ in range(3): - self.linears_tensor.append( - nn.Linear(hidden_channels, hidden_channels, bias=False) - ) + self.linear_tensor = TensorLinear(hidden_channels, hidden_channels) self.linears_scalar = nn.ModuleList() self.linears_scalar.append( nn.Linear(hidden_channels, 2 * hidden_channels, bias=True, dtype=dtype) @@ -337,8 +371,7 @@ def reset_parameters(self): self.distance_proj3.reset_parameters() self.emb.reset_parameters() self.emb2.reset_parameters() - for linear in self.linears_tensor: - linear.reset_parameters() + self.linear_tensor.reset_parameters() for linear in self.linears_scalar: linear.reset_parameters() self.init_norm.reset_parameters() @@ -361,7 +394,7 @@ def _compute_edge_atomic_features(self, z: Tensor, edge_index: Tensor) -> Tensor )[..., None, None] return Zij - def _compute_edge_features( + def _compute_edge_tensor_features( self, z: Tensor, edge_index, @@ -403,23 +436,6 @@ def _aggregate_edge_features( Xij = Xij.index_add(0, edge_index[0], source=X) return Xij - def _tensor_linear(self, X, norm): - I, A, S = decompose_tensor(X) # shape: (n_atoms, hidden_channels, 3, 3) - I = ( - self.linears_tensor[0](I.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - * norm[..., 0, None, None] - ) - A = ( - self.linears_tensor[1](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - * norm[..., 1, None, None] - ) - S = ( - self.linears_tensor[2](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - * norm[..., 2, None, None] - ) - X = I + A + S # shape: (n_atoms, hidden_channels, 3, 3) - return X - def _norm_mlp(self, norm): norm = self.init_norm(norm) for linear_scalar in self.linears_scalar: @@ -435,39 +451,17 @@ def forward( edge_vec: Tensor, edge_attr: Tensor, ) -> Tensor: - Xij = self._compute_edge_features( + Xij = self._compute_edge_tensor_features( z, edge_index, edge_weight, edge_vec, edge_attr ) # shape: (n_edges, hidden_channels, 3, 3) X = self._aggregate_edge_features( z.shape[0], Xij, edge_index ) # shape: (n_atoms, hidden_channels, 3, 3) norm = self._norm_mlp(tensor_norm(X)) # shape: (n_atoms, hidden_channels) - X = self._tensor_linear(X, norm) # 
shape: (n_atoms, hidden_channels, 3, 3) + X = self.linear_tensor(X, norm) # shape: (n_atoms, hidden_channels, 3, 3) return X -class TensorLinear(nn.Module): - - def __init__(self, hidden_channels): - super(TensorLinear, self).__init__() - self.linearI = nn.Linear(hidden_channels, hidden_channels, bias=False) - self.linearA = nn.Linear(hidden_channels, hidden_channels, bias=False) - self.linearS = nn.Linear(hidden_channels, hidden_channels, bias=False) - - def reset_parameters(self): - self.linearI.reset_parameters() - self.linearA.reset_parameters() - self.linearS.reset_parameters() - - def forward(self, X): - I, A, S = decompose_tensor(X) - I = self.linearI(I.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - A = self.linearA(A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - S = self.linearS(S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - dX = I + A + S - return dX - - def compute_tensor_edge_features(X, edge_index, factor): I, A, S = decompose_tensor(X) msg = ( @@ -522,8 +516,8 @@ def __init__( self.linears_scalar.append( nn.Linear(2 * hidden_channels, 3 * hidden_channels, bias=True, dtype=dtype) ) - self.tensor_linear_in = TensorLinear(hidden_channels) - self.tensor_linear_out = TensorLinear(hidden_channels) + self.tensor_linear_in = TensorLinear(hidden_channels, hidden_channels) + self.tensor_linear_out = TensorLinear(hidden_channels, hidden_channels) self.act = activation() self.equivariance_invariance_group = equivariance_invariance_group self.reset_parameters() @@ -535,6 +529,7 @@ def reset_parameters(self): self.tensor_linear_out.reset_parameters() def _update_tensor_node_features(self, X, X_aggregated): + X = self.tensor_linear_in(X) B = torch.matmul(X, X_aggregated) if self.equivariance_invariance_group == "O(3)": A = torch.matmul(X_aggregated, X) @@ -545,6 +540,15 @@ def _update_tensor_node_features(self, X, X_aggregated): Xnew = A + B return Xnew + def _compute_vector_node_features(self, edge_attr, edge_weight): + C = self.cutoff(edge_weight) + for linear_scalar in self.linears_scalar: + edge_attr = self.act(linear_scalar(edge_attr)) + edge_attr = (edge_attr * C.view(-1, 1)).reshape( + edge_attr.shape[0], self.hidden_channels, 3 + ) + return edge_attr + def forward( self, X: Tensor, @@ -553,18 +557,26 @@ def forward( edge_attr: Tensor, q: Tensor, ) -> Tensor: - C = self.cutoff(edge_weight) - for linear_scalar in self.linears_scalar: - edge_attr = self.act(linear_scalar(edge_attr)) - edge_attr = (edge_attr * C.view(-1, 1)).reshape( - edge_attr.shape[0], self.hidden_channels, 3 - ) - X = X / (tensor_norm(X) + 1)[..., None, None] - Y = self.tensor_linear_in(X) - Y_edges = compute_tensor_edge_features(X, edge_index, edge_attr) - Ynew = tensor_message_passing(X.shape[0], edge_index, Y_edges) + X = ( + X / (tensor_norm(X) + 1)[..., None, None] + ) # shape (n_atoms, hidden_channels, 3, 3) + node_features = self._compute_vector_node_features( + edge_attr, edge_weight + ) # shape (n_atoms, hidden_channels, 3) + Y_edges = compute_tensor_edge_features( + X, edge_index, node_features + ) # shape (n_edges, hidden_channels, 3, 3) + Y_aggregated = tensor_message_passing( + X.shape[0], edge_index, Y_edges + ) # shape (n_atoms, hidden_channels, 3, 3) + Xnew = self._update_tensor_node_features( + X, Y_aggregated + ) # shape (n_atoms, hidden_channels, 3, 3) + dX = self.tensor_linear_out( + Xnew / (tensor_norm(Xnew) + 1)[..., None, None] + ) # shape (n_atoms, hidden_channels, 3, 3) charge_factor = 1 + 0.1 * q[..., None, None, None] - Xnew = self._update_tensor_node_features(Y, Ynew) * charge_factor - dX = 
self.tensor_linear_out(Xnew / (tensor_norm(Xnew) + 1)[..., None, None]) - X = X + dX + charge_factor * torch.matrix_power(dX, 2) + X = ( + X + (dX + torch.matrix_power(dX, 2)) * charge_factor + ) # shape (n_atoms, hidden_channels, 3, 3) return X From ada1192774d17f0dbb3134ff1841688f9e2f6a43 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 10 May 2024 15:45:51 +0200 Subject: [PATCH 4/8] Updates to TensorEmbeeding MP --- torchmdnet/models/tensornet.py | 110 +++++++++++++++++++++------------ 1 file changed, 70 insertions(+), 40 deletions(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index bad9203b..6c95535c 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -11,35 +11,39 @@ rbf_class_mapping, act_class_mapping, MLP, + nvtx_annotate, + nvtx_range, ) __all__ = ["TensorNet"] -torch.set_float32_matmul_precision("high") +torch.set_float32_matmul_precision("medium") torch.backends.cuda.matmul.allow_tf32 = True +@nvtx_annotate("vector_to_skewtensor") def vector_to_skewtensor(vector): """Creates a skew-symmetric tensor from a vector.""" - batch_size = vector.size(0) + batch_size = vector.shape[:-1] zero = torch.zeros(batch_size, device=vector.device, dtype=vector.dtype) tensor = torch.stack( ( zero, - -vector[:, 2], - vector[:, 1], - vector[:, 2], + -vector[..., 2], + vector[..., 1], + vector[..., 2], zero, - -vector[:, 0], - -vector[:, 1], - vector[:, 0], + -vector[..., 0], + -vector[..., 1], + vector[..., 0], zero, ), - dim=1, + dim=-1, ) - tensor = tensor.view(-1, 3, 3) + tensor = tensor.view(*batch_size, 3, 3) return tensor.squeeze(0) +@nvtx_annotate("vector_to_symtensor") def vector_to_symtensor(vector): """Creates a symmetric traceless tensor from the outer product of a vector with itself.""" tensor = torch.matmul(vector.unsqueeze(-1), vector.unsqueeze(-2)) @@ -50,6 +54,7 @@ def vector_to_symtensor(vector): return S +@nvtx_annotate("decompose_tensor") def decompose_tensor(tensor): """Full tensor decomposition into irreducible components.""" I = (tensor.diagonal(offset=0, dim1=-1, dim2=-2)).mean(-1)[ @@ -60,6 +65,7 @@ def decompose_tensor(tensor): return I, A, S +@nvtx_annotate("tensor_norm") def tensor_norm(tensor): """Computes Frobenius norm.""" return (tensor**2).sum((-2, -1)) @@ -220,6 +226,7 @@ def reset_parameters(self): self.linear.reset_parameters() self.out_norm.reset_parameters() + @nvtx_annotate("make_static") def _make_static( self, num_nodes: int, edge_index: Tensor, edge_weight: Tensor, edge_vec: Tensor ) -> Tuple[Tensor, Tensor, Tensor]: @@ -235,6 +242,7 @@ def _make_static( ) return edge_index, edge_weight, edge_vec + @nvtx_annotate("compute_neighbors") def _compute_neighbors( self, pos: Tensor, batch: Tensor, box: Optional[Tensor] ) -> Tuple[Tensor, Tensor, Tensor]: @@ -248,6 +256,7 @@ def _compute_neighbors( ) return edge_index, edge_weight, edge_vec + @nvtx_annotate("output") def output(self, X: Tensor) -> Tensor: I, A, S = decompose_tensor(X) # shape: (n_atoms, hidden_channels, 3, 3) x = torch.cat( @@ -257,6 +266,7 @@ def output(self, X: Tensor) -> Tensor: x = self.act(self.linear((x))) # shape: (n_atoms, hidden_channels) return x + @nvtx_annotate("TensorNet") def forward( self, z: Tensor, @@ -303,6 +313,7 @@ def reset_parameters(self): self.linearA.reset_parameters() self.linearS.reset_parameters() + @nvtx_annotate("TensorLinear") def forward(self, X: Tensor, factor: Optional[Tensor] = None) -> Tensor: if factor is None: factor = ( @@ -363,6 +374,8 @@ def __init__( nn.Linear(2 * hidden_channels, 3 * 
hidden_channels, bias=True, dtype=dtype) ) self.init_norm = nn.LayerNorm(hidden_channels, dtype=dtype) + self.num_rbf = num_rbf + self.hidden_channels = hidden_channels self.reset_parameters() def reset_parameters(self): @@ -376,6 +389,7 @@ def reset_parameters(self): linear.reset_parameters() self.init_norm.reset_parameters() + @nvtx_annotate("normalize_edges") def _normalize_edges( self, edge_index: Tensor, edge_weight: Tensor, edge_vec: Tensor ) -> Tensor: @@ -385,16 +399,18 @@ def _normalize_edges( edge_vec = edge_vec / edge_weight.masked_fill(mask, 1).unsqueeze(1) return edge_vec + @nvtx_annotate("compute_edge_atomic_features") def _compute_edge_atomic_features(self, z: Tensor, edge_index: Tensor) -> Tensor: Z = self.emb(z) Zij = self.emb2( Z.index_select(0, edge_index.t().reshape(-1)).view( -1, self.hidden_channels * 2 ) - )[..., None, None] + ) return Zij - def _compute_edge_tensor_features( + @nvtx_annotate("compute_edge_tensor_features") + def _compute_node_tensor_features( self, z: Tensor, edge_index, @@ -405,37 +421,45 @@ def _compute_edge_tensor_features( edge_vec_norm = self._normalize_edges( edge_index, edge_weight, edge_vec ) # shape: (n_edges, 3) - Zij = self._compute_edge_atomic_features( + Zij = self.cutoff(edge_weight)[:, None] * self._compute_edge_atomic_features( z, edge_index ) # shape: (n_edges, hidden_channels) - C = self.cutoff(edge_weight).reshape(-1, 1, 1, 1) * Zij - Iij = self.distance_proj1(edge_attr) - Aij = ( - self.distance_proj2(edge_attr)[..., None, None] - * vector_to_skewtensor(edge_vec_norm)[..., None, :, :] - ) - Sij = ( + + A = ( + self.distance_proj2(edge_attr)[ + ..., None + ] # shape: (n_edges, hidden_channels, 1) + * Zij[..., None] # shape: (n_edges, hidden_channels, 1) + * edge_vec_norm[:, None, :] # shape: (n_edges, 1, 3) + ) # shape: (n_edges, hidden_channels, 3) + A = self._aggregate_edge_features( + z.shape[0], A, edge_index[0] + ) # shape: (n_atoms, hidden_channels, 3) + A = vector_to_skewtensor(A) # shape: (n_atoms, hidden_channels, 3, 3) + + S = ( self.distance_proj3(edge_attr)[..., None, None] + * Zij[..., None, None] * vector_to_symtensor(edge_vec_norm)[..., None, :, :] - ) - features = Aij + Sij - features.diagonal(dim1=-2, dim2=-1).add_(Iij.unsqueeze(-1)) - return features * C + ) # shape: (n_edges, hidden_channels, 3, 3) + S = self._aggregate_edge_features( + z.shape[0], S, edge_index[0] + ) # shape: (n_atoms, hidden_channels, 3, 3) + I = self.distance_proj1(edge_attr) * Zij + I = self._aggregate_edge_features(z.shape[0], I, edge_index[0]) + features = A + S + features.diagonal(dim1=-2, dim2=-1).add_(I.unsqueeze(-1)) + return features + @nvtx_annotate("aggregate_edge_features") def _aggregate_edge_features( - self, num_atoms: int, X: Tensor, edge_index: Tensor + self, num_atoms: int, T: Tensor, source_indices: Tensor ) -> Tensor: - Xij = torch.zeros( - num_atoms, - self.hidden_channels, - 3, - 3, - device=X.device, - dtype=X.dtype, - ) - Xij = Xij.index_add(0, edge_index[0], source=X) - return Xij + targetI = torch.zeros(num_atoms, *T.shape[1:], device=T.device, dtype=T.dtype) + I = targetI.index_add(dim=0, index=source_indices, source=T) + return I + @nvtx_annotate("norm_mlp") def _norm_mlp(self, norm): norm = self.init_norm(norm) for linear_scalar in self.linears_scalar: @@ -443,6 +467,7 @@ def _norm_mlp(self, norm): norm = norm.reshape(-1, self.hidden_channels, 3) return norm + @nvtx_annotate("TensorEmbedding") def forward( self, z: Tensor, @@ -451,17 +476,18 @@ def forward( edge_vec: Tensor, edge_attr: Tensor, ) -> Tensor: - Xij = 
self._compute_edge_tensor_features( + X = self._compute_node_tensor_features( z, edge_index, edge_weight, edge_vec, edge_attr - ) # shape: (n_edges, hidden_channels, 3, 3) - X = self._aggregate_edge_features( - z.shape[0], Xij, edge_index ) # shape: (n_atoms, hidden_channels, 3, 3) + # X = self._aggregate_edge_features( + # z.shape[0], Xij, edge_index + # ) # shape: (n_atoms, hidden_channels, 3, 3) norm = self._norm_mlp(tensor_norm(X)) # shape: (n_atoms, hidden_channels) X = self.linear_tensor(X, norm) # shape: (n_atoms, hidden_channels, 3, 3) return X +@nvtx_annotate("compute_tensor_edge_features") def compute_tensor_edge_features(X, edge_index, factor): I, A, S = decompose_tensor(X) msg = ( @@ -472,6 +498,7 @@ def compute_tensor_edge_features(X, edge_index, factor): return msg +@nvtx_annotate("tensor_message_passing") def tensor_message_passing(n_atoms: int, edge_index: Tensor, tensor: Tensor) -> Tensor: msg = tensor.index_select( 0, edge_index[1] @@ -528,6 +555,7 @@ def reset_parameters(self): self.tensor_linear_in.reset_parameters() self.tensor_linear_out.reset_parameters() + @nvtx_annotate("update_tensor_node_features") def _update_tensor_node_features(self, X, X_aggregated): X = self.tensor_linear_in(X) B = torch.matmul(X, X_aggregated) @@ -540,6 +568,7 @@ def _update_tensor_node_features(self, X, X_aggregated): Xnew = A + B return Xnew + @nvtx_annotate("compute_vector_node_features") def _compute_vector_node_features(self, edge_attr, edge_weight): C = self.cutoff(edge_weight) for linear_scalar in self.linears_scalar: @@ -549,6 +578,7 @@ def _compute_vector_node_features(self, edge_attr, edge_weight): ) return edge_attr + @nvtx_annotate("Interaction") def forward( self, X: Tensor, @@ -562,7 +592,7 @@ def forward( ) # shape (n_atoms, hidden_channels, 3, 3) node_features = self._compute_vector_node_features( edge_attr, edge_weight - ) # shape (n_atoms, hidden_channels, 3) + ) # shape (n_edges, hidden_channels, 3) Y_edges = compute_tensor_edge_features( X, edge_index, node_features ) # shape (n_edges, hidden_channels, 3, 3) From 5212c8eb5f54d6399cd33cc6b99c8888a30714bd Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 10 May 2024 15:54:33 +0200 Subject: [PATCH 5/8] Store I as a single number --- torchmdnet/models/tensornet.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index 6c95535c..aa56e13a 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -57,11 +57,10 @@ def vector_to_symtensor(vector): @nvtx_annotate("decompose_tensor") def decompose_tensor(tensor): """Full tensor decomposition into irreducible components.""" - I = (tensor.diagonal(offset=0, dim1=-1, dim2=-2)).mean(-1)[ - ..., None, None - ] * torch.eye(3, 3, device=tensor.device, dtype=tensor.dtype) + I = (tensor.diagonal(offset=0, dim1=-1, dim2=-2)).mean(-1) A = 0.5 * (tensor - tensor.transpose(-2, -1)) - S = 0.5 * (tensor + tensor.transpose(-2, -1)) - I + S = tensor - A + S.diagonal(offset=0, dim1=-1, dim2=-2).sub_(I.unsqueeze(-1)) return I, A, S @@ -260,7 +259,7 @@ def _compute_neighbors( def output(self, X: Tensor) -> Tensor: I, A, S = decompose_tensor(X) # shape: (n_atoms, hidden_channels, 3, 3) x = torch.cat( - (tensor_norm(I), tensor_norm(A), tensor_norm(S)), dim=-1 + (3 * I**2, tensor_norm(A), tensor_norm(S)), dim=-1 ) # shape: (n_atoms, 3*hidden_channels) x = self.out_norm(x) # shape: (n_atoms, 3*hidden_channels) x = self.act(self.linear((x))) # shape: (n_atoms, 
hidden_channels) @@ -322,10 +321,7 @@ def forward(self, X: Tensor, factor: Optional[Tensor] = None) -> Tensor: .unsqueeze(-1) ).expand(-1, -1, 3) I, A, S = decompose_tensor(X) - I = ( - self.linearI(I.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - * factor[..., 0, None, None] - ) + I = self.linearI(I) * factor[..., 0] A = ( self.linearA(A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * factor[..., 1, None, None] @@ -334,7 +330,8 @@ def forward(self, X: Tensor, factor: Optional[Tensor] = None) -> Tensor: self.linearS(S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * factor[..., 2, None, None] ) - dX = I + A + S + dX = A + S + dX.diagonal(dim1=-2, dim2=-1).add_(I.unsqueeze(-1)) return dX @@ -490,10 +487,11 @@ def forward( @nvtx_annotate("compute_tensor_edge_features") def compute_tensor_edge_features(X, edge_index, factor): I, A, S = decompose_tensor(X) - msg = ( - factor[..., 0, None, None] * I.index_select(0, edge_index[1]) - + factor[..., 1, None, None] * A.index_select(0, edge_index[1]) - + factor[..., 2, None, None] * S.index_select(0, edge_index[1]) + msg = factor[..., 1, None, None] * A.index_select(0, edge_index[1]) + factor[ + ..., 2, None, None + ] * S.index_select(0, edge_index[1]) + msg.diagonal(dim1=-2, dim2=-1).add_( + factor[..., 0, None] * I.index_select(0, edge_index[1]).unsqueeze(-1) ) return msg From c1bd2e9234e4c97bf5f299973601795184d948c0 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 10 May 2024 15:55:53 +0200 Subject: [PATCH 6/8] Small changes to I --- torchmdnet/models/tensornet.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index aa56e13a..2c250c07 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -47,19 +47,18 @@ def vector_to_skewtensor(vector): def vector_to_symtensor(vector): """Creates a symmetric traceless tensor from the outer product of a vector with itself.""" tensor = torch.matmul(vector.unsqueeze(-1), vector.unsqueeze(-2)) - I = (tensor.diagonal(offset=0, dim1=-1, dim2=-2)).mean(-1)[ - ..., None, None - ] * torch.eye(3, 3, device=tensor.device, dtype=tensor.dtype) - S = 0.5 * (tensor + tensor.transpose(-2, -1)) - I + S = 0.5 * (tensor + tensor.transpose(-2, -1)) + I = (tensor.diagonal(offset=0, dim1=-1, dim2=-2)).mean(-1) + S.diagonal(offset=0, dim1=-1, dim2=-2).sub_(I.unsqueeze(-1)) return S @nvtx_annotate("decompose_tensor") def decompose_tensor(tensor): """Full tensor decomposition into irreducible components.""" - I = (tensor.diagonal(offset=0, dim1=-1, dim2=-2)).mean(-1) A = 0.5 * (tensor - tensor.transpose(-2, -1)) S = tensor - A + I = (tensor.diagonal(offset=0, dim1=-1, dim2=-2)).mean(-1) S.diagonal(offset=0, dim1=-1, dim2=-2).sub_(I.unsqueeze(-1)) return I, A, S From 09cec641355d0b0a0b9719988f1b4bdef1c41790 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 10 May 2024 16:01:20 +0200 Subject: [PATCH 7/8] Add NVTX decorators --- examples/TensorNet-QM9.yaml | 2 ++ examples/TensorNet-rMD17.yaml | 2 ++ torchmdnet/models/output_modules.py | 1 + torchmdnet/models/tensornet.py | 18 ++++++++--- torchmdnet/models/utils.py | 50 +++++++++++++++++++++++++++-- 5 files changed, 66 insertions(+), 7 deletions(-) diff --git a/examples/TensorNet-QM9.yaml b/examples/TensorNet-QM9.yaml index 6ab98a2c..4000b53f 100644 --- a/examples/TensorNet-QM9.yaml +++ b/examples/TensorNet-QM9.yaml @@ -57,3 +57,5 @@ weight_decay: 0.0 box_vecs: null charge: false spin: false +static_shapes: True +check_errors: False diff --git a/examples/TensorNet-rMD17.yaml 
b/examples/TensorNet-rMD17.yaml index 737e4c95..bc14f73e 100644 --- a/examples/TensorNet-rMD17.yaml +++ b/examples/TensorNet-rMD17.yaml @@ -57,3 +57,5 @@ weight_decay: 0.0 box_vecs: null charge: false spin: false +static_shapes: True +check_errors: False diff --git a/torchmdnet/models/output_modules.py b/torchmdnet/models/output_modules.py index bf408aa3..8d913ec9 100644 --- a/torchmdnet/models/output_modules.py +++ b/torchmdnet/models/output_modules.py @@ -52,6 +52,7 @@ def reduce(self, x, batch): self.dim_size ) ) + # self.dim_size = 1 return scatter(x, batch, dim=0, dim_size=self.dim_size, reduce=self.reduce_op) def post_reduce(self, x): diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index 2c250c07..6a85e619 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -47,9 +47,15 @@ def vector_to_skewtensor(vector): def vector_to_symtensor(vector): """Creates a symmetric traceless tensor from the outer product of a vector with itself.""" tensor = torch.matmul(vector.unsqueeze(-1), vector.unsqueeze(-2)) + S = tensor_to_symtensor(tensor) + return S + + +@nvtx_annotate("tensor_to_symtensor") +def tensor_to_symtensor(tensor): S = 0.5 * (tensor + tensor.transpose(-2, -1)) - I = (tensor.diagonal(offset=0, dim1=-1, dim2=-2)).mean(-1) - S.diagonal(offset=0, dim1=-1, dim2=-2).sub_(I.unsqueeze(-1)) + I = (tensor.diagonal(dim1=-1, dim2=-2)).mean(-1) + S.diagonal(dim1=-1, dim2=-2).sub_(I.unsqueeze(-1)) return S @@ -58,8 +64,8 @@ def decompose_tensor(tensor): """Full tensor decomposition into irreducible components.""" A = 0.5 * (tensor - tensor.transpose(-2, -1)) S = tensor - A - I = (tensor.diagonal(offset=0, dim1=-1, dim2=-2)).mean(-1) - S.diagonal(offset=0, dim1=-1, dim2=-2).sub_(I.unsqueeze(-1)) + I = (tensor.diagonal(dim1=-1, dim2=-2)).mean(-1) + S.diagonal(dim1=-1, dim2=-2).sub_(I.unsqueeze(-1)) return I, A, S @@ -433,14 +439,16 @@ def _compute_node_tensor_features( ) # shape: (n_atoms, hidden_channels, 3) A = vector_to_skewtensor(A) # shape: (n_atoms, hidden_channels, 3, 3) + tensor = torch.matmul(edge_vec_norm.unsqueeze(-1), edge_vec_norm.unsqueeze(-2)) S = ( self.distance_proj3(edge_attr)[..., None, None] * Zij[..., None, None] - * vector_to_symtensor(edge_vec_norm)[..., None, :, :] + * tensor[..., None, :, :] ) # shape: (n_edges, hidden_channels, 3, 3) S = self._aggregate_edge_features( z.shape[0], S, edge_index[0] ) # shape: (n_atoms, hidden_channels, 3, 3) + S = tensor_to_symtensor(S) # shape: (n_atoms, hidden_channels, 3, 3) I = self.distance_proj1(edge_attr) * Zij I = self._aggregate_edge_features(z.shape[0], I, edge_index[0]) features = A + S diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index a0d3e403..297f4e1e 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -209,7 +209,7 @@ def __init__( self.use_periodic = True if self.box is None: self.use_periodic = False - self.box = torch.empty((0, 0)) + self.box = torch.empty((0, 0), device="cpu", dtype=torch.float32) if self.strategy == "cell": # Default the box to 3 times the cutoff, really inefficient for the cell list lbox = cutoff_upper * 3.0 @@ -255,9 +255,10 @@ def forward( use_periodic = self.use_periodic if not use_periodic: use_periodic = box is not None + self.box = self.box.to(pos.device) box = self.box if box is None else box assert box is not None, "Box must be provided" - box = box.to(pos.dtype) + # box = box.to(pos.dtype) max_pairs: int = self.max_num_pairs if self.max_num_pairs < 0: max_pairs = -self.max_num_pairs * 
pos.shape[0] @@ -618,3 +619,48 @@ def scatter( } dtype_mapping = {16: torch.float16, 32: torch.float, 64: torch.float64} + + +# Can be globally disabled by setting the global variable ENABLE_NVTX to False +class nvtx_range: + def __init__(self, name, force_enabled=False): + self.name = name + self.force_enabled = force_enabled + + def __enter__(self): + if self.force_enabled or ENABLE_NVTX: + torch.cuda.synchronize() + torch.cuda.nvtx.range_push(self.name) + + def __exit__(self, type, value, traceback): + if self.force_enabled or ENABLE_NVTX: + torch.cuda.synchronize() + torch.cuda.nvtx.range_pop() + + +ENABLE_NVTX = False + + +def tmdnet_push_range(name: str, force_enabled: bool = False): + if force_enabled or ENABLE_NVTX: + torch.cuda.synchronize() + torch.cuda.nvtx.range_push(name) + + +def tmdnet_pop_range(force_enabled: bool = False): + if force_enabled or ENABLE_NVTX: + torch.cuda.synchronize() + torch.cuda.nvtx.range_pop() + + +def nvtx_annotate(tag: Optional[str] = None): + def Inner(foo): + def wrapper(*args, **kwargs): + if not ENABLE_NVTX: + return foo(*args, **kwargs) + with nvtx_range(foo.__name__ if tag is None else tag): + return foo(*args, **kwargs) + + return wrapper + + return Inner From 063ad645d4a62c89879f582f989aabe0af47ed2c Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 10 May 2024 17:24:19 +0200 Subject: [PATCH 8/8] Remove more work from the edges in TensorEmbedding --- torchmdnet/models/tensornet.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index 6a85e619..deae9161 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -426,7 +426,6 @@ def _compute_node_tensor_features( Zij = self.cutoff(edge_weight)[:, None] * self._compute_edge_atomic_features( z, edge_index ) # shape: (n_edges, hidden_channels) - A = ( self.distance_proj2(edge_attr)[ ..., None @@ -438,19 +437,21 @@ def _compute_node_tensor_features( z.shape[0], A, edge_index[0] ) # shape: (n_atoms, hidden_channels, 3) A = vector_to_skewtensor(A) # shape: (n_atoms, hidden_channels, 3, 3) - - tensor = torch.matmul(edge_vec_norm.unsqueeze(-1), edge_vec_norm.unsqueeze(-2)) - S = ( + I = self.distance_proj1(edge_attr) * Zij + I = self._aggregate_edge_features(z.shape[0], I, edge_index[0]) + # Outer product of edge vectors + tensor = torch.matmul( + edge_vec_norm.unsqueeze(-1), edge_vec_norm.unsqueeze(-2) + ) # shape: (n_edges, 3, 3) + tensor = ( self.distance_proj3(edge_attr)[..., None, None] * Zij[..., None, None] * tensor[..., None, :, :] ) # shape: (n_edges, hidden_channels, 3, 3) - S = self._aggregate_edge_features( - z.shape[0], S, edge_index[0] + tensor = self._aggregate_edge_features( + z.shape[0], tensor, edge_index[0] ) # shape: (n_atoms, hidden_channels, 3, 3) - S = tensor_to_symtensor(S) # shape: (n_atoms, hidden_channels, 3, 3) - I = self.distance_proj1(edge_attr) * Zij - I = self._aggregate_edge_features(z.shape[0], I, edge_index[0]) + S = tensor_to_symtensor(tensor) # shape: (n_atoms, hidden_channels, 3, 3) features = A + S features.diagonal(dim1=-2, dim2=-1).add_(I.unsqueeze(-1)) return features @@ -483,9 +484,6 @@ def forward( X = self._compute_node_tensor_features( z, edge_index, edge_weight, edge_vec, edge_attr ) # shape: (n_atoms, hidden_channels, 3, 3) - # X = self._aggregate_edge_features( - # z.shape[0], Xij, edge_index - # ) # shape: (n_atoms, hidden_channels, 3, 3) norm = self._norm_mlp(tensor_norm(X)) # shape: (n_atoms, hidden_channels) X = 
self.linear_tensor(X, norm) # shape: (n_atoms, hidden_channels, 3, 3) return X
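
The reordering in this last patch is valid because `tensor_to_symtensor` is linear, so it commutes with the `index_add` aggregation: symmetrizing the per-atom sum of raw outer products equals summing per-edge symmetrized tensors, while doing far less work when there are many more edges than atoms. A standalone check of that identity, with `index_add` standing in for `_aggregate_edge_features` and hypothetical sizes:

```python
import torch

def tensor_to_symtensor(t):
    # Same decomposition as in the patch: symmetric part minus the trace.
    S = 0.5 * (t + t.transpose(-2, -1))
    I = t.diagonal(dim1=-1, dim2=-2).mean(-1)
    S.diagonal(dim1=-1, dim2=-2).sub_(I.unsqueeze(-1))
    return S

n_atoms, n_edges, hidden = 4, 12, 8
recv = torch.randint(0, n_atoms, (n_edges,))  # receiving atom of each edge
Tij = torch.randn(n_edges, hidden, 3, 3)
zeros = torch.zeros(n_atoms, hidden, 3, 3)

per_edge = zeros.index_add(0, recv, tensor_to_symtensor(Tij))  # symmetrize, then sum
per_atom = tensor_to_symtensor(zeros.index_add(0, recv, Tij))  # sum, then symmetrize
assert torch.allclose(per_edge, per_atom, atol=1e-5)
```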
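The NVTX helpers introduced in patch 7 are opt-in through the module-level `ENABLE_NVTX` flag, which the decorator reads at call time. A usage sketch, assuming a CUDA build and that the helpers remain importable from `torchmdnet.models.utils`; the ranges only become visible when the script runs under a profiler such as `nsys profile python script.py`:

```python
import torch
from torchmdnet.models import utils

utils.ENABLE_NVTX = True  # off by default; each range also synchronizes the GPU

@utils.nvtx_annotate("my_kernel")
def my_kernel(x):
    return x * 2

with utils.nvtx_range("outer"):  # appears as a named range in the nsys timeline
    y = my_kernel(torch.ones(10, device="cuda"))
```

Note that the synchronize calls inside `nvtx_range` make the annotated regions accurate in the timeline but serialize the GPU, which is why the flag defaults to `False` for production runs.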