Skip to content

Commit

Permalink
make sure runtime typecheck is turned on during pytest
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 14, 2024
1 parent f35ac78 commit 8d4cfbf
Showing 1 changed file with 27 additions and 1 deletion.
28 changes: 27 additions & 1 deletion tests/test_readme.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import os
os.environ['TYPECHECK'] = 'True'

import torch
import pytest

from alphafold3_pytorch import (
PairformerStack
PairformerStack,
MSAModule
)

def test_pairformer():
Expand All @@ -24,3 +28,25 @@ def test_pairformer():

assert single.shape == single_out.shape
assert pairwise.shape == pairwise_out.shape

def test_msa_module():

single = torch.randn(2, 16, 512)
msa = torch.randn(2, 7, 16, 64)
pairwise = torch.randn(2, 16, 16, 128)
mask = torch.randint(0, 2, (2, 16)).bool()

msa_module = MSAModule(
dim_single = 512,
dim_pairwise = 128,
dim_msa = 64
)

pairwise_out = msa_module(
single_repr = single,
msa = msa,
pairwise_repr = pairwise,
mask = mask
)

assert pairwise.shape == pairwise_out.shape

0 comments on commit 8d4cfbf

Please sign in to comment.