Commit

remove laser attention
lucidrains committed Dec 28, 2024
1 parent 119fcb9 commit c5d1f7b
Showing 3 changed files with 1 addition and 42 deletions.
9 changes: 0 additions & 9 deletions README.md
@@ -515,15 +515,6 @@ docker run -v .:/data --gpus all -it af3
 }
 ```

-```bibtex
-@inproceedings{Duvvuri2024LASERAW,
-    title   = {LASER: Attention with Exponential Transformation},
-    author  = {Sai Surya Duvvuri and Inderjit S. Dhillon},
-    year    = {2024},
-    url     = {https://api.semanticscholar.org/CorpusID:273849947}
-}
-```
-
 ```bibtex
 @article{Zhu2024HyperConnections,
     title   = {Hyper-Connections},
32 changes: 0 additions & 32 deletions alphafold3_pytorch/attention.py
@@ -184,8 +184,6 @@ def __init__(
         query_bias = True,
         window_size = None,
         num_memory_kv: int = 0,
-        laser = False,
-        laser_softclamp_value = 15.,
         enable_attn_softclamp = False,
         attn_softclamp_value = 50.,
         softmax_full_precision = False,
@@ -211,8 +209,6 @@ def __init__(
             dropout = dropout,
             window_size = window_size,
             enable_attn_softclamp = enable_attn_softclamp,
-            laser = laser,
-            laser_softclamp_value = laser_softclamp_value,
             attn_softclamp_value = attn_softclamp_value,
             softmax_full_precision = softmax_full_precision
         )
@@ -322,8 +318,6 @@ class Attend(Module):
     def __init__(
         self,
         dropout = 0.,
-        laser = False,
-        laser_softclamp_value = 15.,
         window_size = None,
         scale: float | None = None,
         enable_attn_softclamp = False,
@@ -352,11 +346,6 @@ def __init__(

         self.attn_dropout = nn.Dropout(dropout)

-        # laser attention
-
-        self.laser = laser
-        self.laser_softclamp_value = laser_softclamp_value
-
         # softclamp attention logits
         # being adopted by a number of recent llms (gemma, grok)

@@ -477,20 +466,10 @@ def local_attn(

         attn = sim.softmax(dim = -1)

-        # maybe laser
-
-        if self.laser:
-            v = softclamp(v, self.laser_softclamp_value)
-
         # aggregate

         out = einsum(attn, v, "... i j, ... j d -> ... i d")

-        # maybe laser
-
-        if self.laser:
-            out = log(out)
-
         # un-window the output

         out = rearrange(out, "b h n w d -> b h (n w) d")
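
The removed branches above lean on the file's `softclamp` helper (the same soft-capping referenced by the `attn_softclamp_value` option and the Gemma/Grok comment). As a point of reference only, the sketch below assumes the usual tanh-based definition rather than quoting the repository's exact helper:

```python
import torch
from torch import Tensor

def softclamp(t: Tensor, value: float) -> Tensor:
    # smoothly bound t to (-value, value): scale down, tanh, scale back up
    return (t / value).tanh() * value

# example: values softly capped at 15 before downstream use
v = torch.randn(2, 8, 32, 64) * 40.
assert softclamp(v, 15.).abs().max().item() <= 15.
```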
@@ -586,19 +565,8 @@ def forward(

         attn = self.attn_dropout(attn)

-        # maybe laser
-
-        if self.laser:
-            v_max = v.amax(dim = -2, keepdim = True)
-            v = (v - v_max).exp()
-
         # aggregate values

         out = einsum(attn, v, "b h i j, b h j d -> b h i d")

-        # maybe laser
-
-        if self.laser:
-            out = log(out) + v_max
-
         return out
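
For reference, the non-windowed LASER path removed here computes attention in value-exp space: values are exponentiated (shifted by their per-key max for numerical stability), aggregated with the ordinary attention weights, and the aggregate is mapped back with a log. A minimal self-contained sketch under those assumptions; `laser_attend` is a hypothetical name used for illustration, not the repository's API:

```python
import torch

def laser_attend(q, k, v):
    # q, k, v: (batch, heads, seq, dim_head)
    scale = q.shape[-1] ** -0.5
    sim = torch.einsum('bhid,bhjd->bhij', q, k) * scale
    attn = sim.softmax(dim = -1)

    # exponentiate values stably: subtract the per-key max before exp
    v_max = v.amax(dim = -2, keepdim = True)
    v_exp = (v - v_max).exp()

    # aggregate in exp space, then map back to log space
    out = torch.einsum('bhij,bhjd->bhid', attn, v_exp)
    return out.log() + v_max

q = k = v = torch.randn(1, 8, 128, 64)
out = laser_attend(q, k, v)  # (1, 8, 128, 64)
```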
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.7.4"
+version = "0.7.5"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" },
