Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions openfold3/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import random
from pathlib import Path

import biotite.setup_ccd
import numpy as np
import pytest
import torch
from biotite.structure import AtomArray
from torch.random import fork_rng

from openfold3.core.data.primitives.structure.component import BiotiteCCDWrapper
from openfold3.setup_openfold import setup_biotite_ccd
Expand Down Expand Up @@ -76,3 +81,27 @@ def ensure_biotite_ccd(request):
def biotite_ccd_wrapper():
"""Cache CCD wrapper fixture for tests that need it."""
return BiotiteCCDWrapper()


@pytest.fixture(scope="module")
def original_datadir(request: pytest.FixtureRequest) -> Path:
    """Redirect pytest-regressions snapshot storage to test_data/snapshots/."""
    # One snapshot directory per test module, keyed by the module's file stem.
    module_stem = Path(request.path).stem
    return Path(__file__).parent.joinpath("test_data", "snapshots", module_stem)
Comment on lines +86 to +89
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is what allows us to place the snapshots along other test_data, sibling to the 'cassettes'

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Does it make sense to update the snapshot paths for the templates test you added in https://github.com/aqlaboratory/openfold-3/tree/main/openfold3/tests/test_data/cassettes/test_rscb

Can be a different PR if it is cumbersome to update here.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are different snapshots – we have two types

  • the VCR snapshots for web requests
  • ndarrays_regression snapshots for numerics

I'd be wary of mixing those up – they mean different things and solve different problems



@pytest.fixture()
def seeded_rng():
    """Isolate all RNG state (torch, numpy, python) for the duration of a test.

    Uses torch.random.fork_rng() to save/restore torch (+CUDA) state, and
    manually saves/restores numpy and python random state.
    """
    py_state = random.getstate()
    np_state = np.random.get_state()
    try:
        # fork_rng() is a context manager, so torch state is restored even if
        # an exception propagates through the yield.
        with fork_rng():
            torch.manual_seed(123)
            random.seed(123)
            np.random.seed(123)
            yield
    finally:
        # Restore numpy/python state unconditionally (mirrors fork_rng's
        # exception safety) so later tests never see a leaked seed.
        random.setstate(py_state)
        np.random.set_state(np_state)
Binary file not shown.
Binary file not shown.
Binary file not shown.
65 changes: 43 additions & 22 deletions openfold3/tests/test_triangular_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,38 +12,59 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import pytest
import torch

from openfold3.core.model.layers.triangular_attention import TriangleAttention
from openfold3.tests.config import consts


# starting=True -> "starting node" variant: rows attend to rows,
# biased by z[i, k]. False would transpose internally for the
# "ending node" variant (columns attend to columns).
@pytest.mark.parametrize("starting", [True, False])
def test_shape(starting, seeded_rng, ndarrays_regression):
    # c_z: pair representation channel dim (128 in production)
    c_z = consts.c_z
    # c: attention hidden dim (production uses 32; smaller here for speed)
    c = 12
    no_heads = 4

    tan = TriangleAttention(
        c_z,
        c,
        no_heads,
        starting=starting,
    )
    # AlphaFold initializes the output projection to zero (so residual blocks
    # start as identity). Reinitialize all params so the test exercises the
    # actual computation and produces non-trivial output.
    for p in tan.parameters():
        torch.nn.init.normal_(p, std=0.01)
    tan.eval()

    batch_size = consts.batch_size
    n_res = consts.n_res

    # Pair representation: [batch, N_residues, N_residues, C_z]
    x = torch.rand((batch_size, n_res, n_res, c_z))
    shape_before = x.shape
    # chunk_size=None -> no memory-saving chunking, full attention in one pass
    with torch.no_grad():
        x = tan(x, chunk_size=None)
    shape_after = x.shape

    # Shape must be preserved for the residual addition z = z + tri_att(z)
    assert shape_before == shape_after

    # Guard against trivial all-zero output (e.g. from zero-initialized weights)
    assert x.abs().max().item() > 0, (
        "Output is all zeros — snapshot would be meaningless"
    )

    # Snapshot regression: output must be numerically identical across runs.
    # Regenerate with: pytest --force-regen
    ndarrays_regression.check(
        {"output": x.cpu().numpy()},
        default_tolerance=dict(atol=1e-6, rtol=1e-5),
    )
76 changes: 49 additions & 27 deletions openfold3/tests/test_triangular_multiplicative_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import re
import unittest

import torch

Expand All @@ -23,34 +22,57 @@
)
from openfold3.tests.config import consts

# Updates pair representation z[i,j] by projecting to two gated vectors (a, b),
# contracting along a shared dimension (outgoing vs incoming), then projecting
# back. "Outgoing" contracts over the starting node, "Incoming" over the ending
# node. Shape-preserving: [*, N, N, C_z] -> [*, N, N, C_z].

def _make_module(c_z, c):
    """Pick fused vs non-fused variant based on model preset."""
    # Multimer v3 uses a fused variant (single projection split into a, b)
    # vs separate projections for each
    if re.fullmatch("^model_[1-5]_multimer_v3$", consts.model_preset):
        return FusedTriangleMultiplicationOutgoing(c_z, c)
    return TriangleMultiplicationOutgoing(c_z, c)


def test_shape(seeded_rng, ndarrays_regression):
    # c_z: pair representation channel dim (128 in production)
    c_z = consts.c_z
    # c: hidden projection dim (production uses ~128; smaller here for speed)
    c = 11

    tm = _make_module(c_z, c)
    # Reinitialize all params to non-trivial values (some layers may be
    # zero-initialized by default for residual identity at init)
    for p in tm.parameters():
        torch.nn.init.normal_(p, std=0.01)
    tm.eval()

    # NOTE(review): the old unittest version used n_res = consts.c_z here,
    # which was almost certainly a typo; this uses consts.n_res.
    n_res = consts.n_res
    batch_size = consts.batch_size

    # Pair representation: [batch, N_residues, N_residues, C_z]
    x = torch.rand((batch_size, n_res, n_res, c_z))
    # Binary mask: which residue pairs are valid
    mask = torch.randint(0, 2, size=(batch_size, n_res, n_res))
    shape_before = x.shape
    with torch.no_grad():
        x = tm(x, mask)
    shape_after = x.shape

    # Shape must be preserved for the residual addition z = z + tri_mul(z)
    assert shape_before == shape_after

    # Guard against trivial all-zero output (e.g. from zero-initialized weights)
    assert x.abs().max().item() > 0, (
        "Output is all zeros — snapshot would be meaningless"
    )

    # Snapshot regression: output must be numerically identical across runs.
    # Regenerate with: pytest --force-regen
    ndarrays_regression.check(
        {"output": x.cpu().numpy()},
        default_tolerance=dict(atol=1e-6, rtol=1e-5),
    )
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ test = [
"pytest-xdist",
"pytest-cov",
"pytest-benchmark",
"pytest-regressions",
"debugpy",
"pytest-recording",
]
Expand Down
Loading