Skip to content

Commit e30e2b0

Browse files
committed
fix synthon for USPTO50k
1 parent 489677c commit e30e2b0

File tree

2 files changed

+3
-206
lines changed

2 files changed

+3
-206
lines changed

diff.txt

-202
This file was deleted.

torchdrug/datasets/uspto50k.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,9 @@ def _get_synthon(self, reactant, product):
173173

174174
if len(edge_added) > 0:
175175
if len(edge_added) == 1: # add a single edge
176-
edge = edge_added[0]
177-
reverse_edge = edge.flip(0)
176+
reverse_edge = edge_added.flip(1)
178177
any = -torch.ones(2, 1, dtype=torch.long)
179-
pattern = torch.cat([edge, reverse_edge])
178+
pattern = torch.cat([edge_added, reverse_edge])
180179
pattern = torch.cat([pattern, any], dim=-1)
181180
index, num_match = product.match(pattern)
182181
edge_mask = torch.ones(product.num_edge, dtype=torch.bool)
@@ -186,7 +185,7 @@ def _get_synthon(self, reactant, product):
186185
_synthons = product.connected_components()[0]
187186
assert len(_synthons) >= len(_reactants) # because a few samples contain multiple products
188187

189-
h, t = edge
188+
h, t = edge_added[0]
190189
reaction_center = torch.tensor([product.atom_map[h], product.atom_map[t]])
191190
with _reactants.graph():
192191
_reactants.reaction_center = reaction_center.expand(len(_reactants), -1)

0 commit comments

Comments
 (0)