Skip to content

Commit

Permalink
go all-in on tensor typing for this project
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 13, 2024
1 parent 06f8690 commit d36f675
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions alphafold3_pytorch/alphafold3.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,10 @@ def __init__(
@typecheck
def forward(
self,
x: Tensor,
x: Float['... n d'],
**kwargs
):
) -> Float['... n d']:

x = self.norm(x)
return self.fn(x, **kwargs)

Expand All @@ -117,9 +118,10 @@ def __init__(
@typecheck
def forward(
self,
x: Tensor,
x: Float['... n d'],
cond: Tensor
):
) -> Float['... n d']:

normed = self.norm(x)
normed_cond = self.norm_cond(cond)

Expand Down Expand Up @@ -157,11 +159,11 @@ def __init__(
@typecheck
def forward(
self,
x: Tensor,
x: Float['... n d'],
*,
cond: Tensor,
**kwargs
):
) -> Float['... n d']:
x = self.adaptive_norm(x, cond = cond)

out = self.fn(x, **kwargs)
Expand Down Expand Up @@ -214,7 +216,8 @@ def forward(
self,
x: Float['b n n d'],
mask: Float['b n n'] | None = None
):
) -> Float['b n n d']:

if exists(mask):
mask = rearrange(mask, '... -> ... 1')

Expand Down

0 comments on commit d36f675

Please sign in to comment.