interestingly enough, they keep a bias for the query projection within attention. this is not usual
lucidrains committed May 13, 2024
1 parent e74b28d commit 07c8a28
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions alphafold3_pytorch/attention.py
@@ -39,7 +39,8 @@ def __init__(
         dim_head = 64,
         heads = 8,
         dropout = 0.,
-        gate_output = False,
+        gate_output = True,
+        query_bias = True,
         flash = True,
         efficient_attn_config: Config = Config(True, True, True)
     ):
@@ -66,7 +67,7 @@ def __init__(
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
-        self.to_q = nn.Linear(dim, dim_inner, bias = False)
+        self.to_q = nn.Linear(dim, dim_inner, bias = query_bias)
         self.to_kv = nn.Linear(dim, dim_inner * 2, bias = False)
         self.to_out = nn.Linear(dim_inner, dim, bias = False)
 
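For context, here is a minimal, self-contained sketch of the pattern this commit adopts: an attention block whose query projection keeps a bias while the key/value and output projections remain bias-free. The class name, the dim = 384 default, and the use of torch's scaled_dot_product_attention are illustrative assumptions, not the repository's actual code.

import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange

class QueryBiasAttention(nn.Module):
    # hypothetical simplification, not the repository's full Attention class
    def __init__(self, dim = 384, dim_head = 64, heads = 8, query_bias = True):
        super().__init__()
        dim_inner = dim_head * heads
        self.heads = heads

        # the unusual detail from the commit message: bias on queries only
        self.to_q = nn.Linear(dim, dim_inner, bias = query_bias)
        self.to_kv = nn.Linear(dim, dim_inner * 2, bias = False)
        self.to_out = nn.Linear(dim_inner, dim, bias = False)

    def forward(self, x):
        q = self.to_q(x)
        k, v = self.to_kv(x).chunk(2, dim = -1)

        # split heads: (b, n, h * d) -> (b, h, n, d)
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = self.heads) for t in (q, k, v))

        out = F.scaled_dot_product_attention(q, k, v)

        # merge heads back and project out
        return self.to_out(rearrange(out, 'b h n d -> b n (h d)'))

# quick check: only the query projection carries a bias
attn = QueryBiasAttention()
x = torch.randn(2, 16, 384)
assert attn.to_q.bias is not None and attn.to_kv.bias is None
print(attn(x).shape)  # torch.Size([2, 16, 384])

Leaving the bias off the key/value projections is the common convention; keeping it on the queries is exactly what the commit message flags as not usual.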
