
Commit
Alternative implementations for attention in comments for clarity
wouterkool committed Nov 20, 2020
1 parent c1d77bd commit b8c9b8a
Showing 1 changed file with 9 additions and 0 deletions: nets/graph_encoder.py
@@ -109,6 +109,15 @@ def forward(self, q, h=None, mask=None):
self.W_out.view(-1, self.embed_dim)
).view(batch_size, n_query, self.embed_dim)

# Alternative:
# headst = heads.transpose(0, 1)  # swap the batch and head dimensions so W_out broadcasts in the matmul
# # proj_h = torch.einsum('bhni,hij->bhnj', headst, self.W_out)
# projected_heads = torch.matmul(headst, self.W_out)
# out = torch.sum(projected_heads, dim=1)  # sum across heads

# Or:
# out = torch.einsum('hbni,hij->bnj', heads, self.W_out)

return out
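The two commented alternatives compute the same output projection as the original `torch.mm` formulation: flattening the per-head outputs and multiplying by the stacked projection matrices is algebraically a sum over heads of `head_h @ W_h`. A minimal standalone sketch (not the repository's code; shapes and sizes here are illustrative assumptions matching the comments, with `heads` of shape `(n_heads, batch, n_query, val_dim)` and `W_out` of shape `(n_heads, val_dim, embed_dim)`) checks that all three variants agree:

```python
import torch

# Illustrative sizes, not taken from the repository.
n_heads, batch_size, n_query, val_dim, embed_dim = 8, 2, 5, 16, 128
heads = torch.randn(n_heads, batch_size, n_query, val_dim)
W_out = torch.randn(n_heads, val_dim, embed_dim)

# Original: flatten the heads and multiply by the stacked projection matrices.
out_mm = torch.mm(
    heads.permute(1, 2, 0, 3).contiguous().view(-1, n_heads * val_dim),
    W_out.view(-1, embed_dim)
).view(batch_size, n_query, embed_dim)

# Alternative 1: batched matmul per head, then sum across heads.
headst = heads.transpose(0, 1)           # (batch, n_heads, n_query, val_dim)
projected_heads = torch.matmul(headst, W_out)  # W_out broadcasts over batch
out_matmul = projected_heads.sum(dim=1)  # sum across heads

# Alternative 2: one einsum contracting the head and value dimensions at once.
out_einsum = torch.einsum('hbni,hij->bnj', heads, W_out)

assert torch.allclose(out_mm, out_matmul, atol=1e-4)
assert torch.allclose(out_mm, out_einsum, atol=1e-4)
```

The einsum form is the most compact; the `mm` form in the original code avoids the intermediate `(batch, n_heads, n_query, embed_dim)` tensor that the matmul-then-sum variant materializes.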


