diff --git a/README.md b/README.md
index a1ab9116..049b4748 100644
--- a/README.md
+++ b/README.md
@@ -515,15 +515,6 @@ docker run -v .:/data --gpus all -it af3
 }
 ```
 
-```bibtex
-@inproceedings{Duvvuri2024LASERAW,
-    title   = {LASER: Attention with Exponential Transformation},
-    author  = {Sai Surya Duvvuri and Inderjit S. Dhillon},
-    year    = {2024},
-    url     = {https://api.semanticscholar.org/CorpusID:273849947}
-}
-```
-
 ```bibtex
 @article{Zhu2024HyperConnections,
     title   = {Hyper-Connections},
diff --git a/alphafold3_pytorch/attention.py b/alphafold3_pytorch/attention.py
index 84612b9e..c8f2f0ce 100644
--- a/alphafold3_pytorch/attention.py
+++ b/alphafold3_pytorch/attention.py
@@ -184,8 +184,6 @@ def __init__(
         query_bias = True,
         window_size = None,
         num_memory_kv: int = 0,
-        laser = False,
-        laser_softclamp_value = 15.,
         enable_attn_softclamp = False,
         attn_softclamp_value = 50.,
         softmax_full_precision = False,
@@ -211,8 +209,6 @@ def __init__(
             dropout = dropout,
             window_size = window_size,
             enable_attn_softclamp = enable_attn_softclamp,
-            laser = laser,
-            laser_softclamp_value = laser_softclamp_value,
             attn_softclamp_value = attn_softclamp_value,
             softmax_full_precision = softmax_full_precision
         )
@@ -322,8 +318,6 @@ class Attend(Module):
     def __init__(
         self,
         dropout = 0.,
-        laser = False,
-        laser_softclamp_value = 15.,
         window_size = None,
         scale: float | None = None,
         enable_attn_softclamp = False,
@@ -352,11 +346,6 @@ def __init__(
 
         self.attn_dropout = nn.Dropout(dropout)
 
-        # laser attention
-
-        self.laser = laser
-        self.laser_softclamp_value = laser_softclamp_value
-
         # softclamp attention logits
 
         # being adopted by a number of recent llms (gemma, grok)
@@ -477,20 +466,10 @@ def local_attn(
 
         attn = sim.softmax(dim = -1)
 
-        # maybe laser
-
-        if self.laser:
-            v = softclamp(v, self.laser_softclamp_value)
-
         # aggregate
 
         out = einsum(attn, v, "... i j, ... j d -> ... i d")
 
-        # maybe laser
-
-        if self.laser:
-            out = log(out)
-
         # un-window the output
 
         out = rearrange(out, "b h n w d -> b h (n w) d")
@@ -586,19 +565,8 @@ def forward(
 
         attn = self.attn_dropout(attn)
 
-        # maybe laser
-
-        if self.laser:
-            v_max = v.amax(dim = -2, keepdim = True)
-            v = (v - v_max).exp()
-
         # aggregate values
 
         out = einsum(attn, v, "b h i j, b h j d -> b h i d")
 
-        # maybe laser
-
-        if self.laser:
-            out = log(out) + v_max
-
         return out
diff --git a/pyproject.toml b/pyproject.toml
index 0c96c8c6..83870d28 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.7.4"
+version = "0.7.5"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" },
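
For reference, the LASER step this patch removes aggregates exponentiated values and takes the log of the result, with a max-subtraction for numerical stability. Below is a minimal standalone PyTorch sketch of that aggregation, separate from the patch itself; the function name and shapes are illustrative, not part of the repository API.

```python
# Minimal sketch of the LASER aggregation removed above (illustrative, not part of the patch).

import torch


def laser_aggregate(attn: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """
    attn: post-softmax attention weights of shape (b, h, i, j)
    v:    values of shape (b, h, j, d)
    Computes log(attn @ exp(v)), shifting by the per-dimension max for stability.
    """
    v_max = v.amax(dim = -2, keepdim = True)                         # max over the sequence dimension
    out = torch.einsum('bhij,bhjd->bhid', attn, (v - v_max).exp())   # aggregate in exp-value space
    return out.log() + v_max                                         # undo the shift in log space
```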