-
Notifications
You must be signed in to change notification settings - Fork 93
test: triangular attention snapshots #128
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 13 commits
08bc6a6
ab0b7f1
0b90ee2
70b252e
4652a08
9d46b7a
4afb40f
391dbac
081e12c
de728e2
d906f0c
bffa142
0f02bf4
fb63c92
69631ad
b90c8bc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,38 +12,64 @@ | |
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import unittest | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| from openfold3.core.model.layers.triangular_attention import TriangleAttention | ||
| from openfold3.tests.config import consts | ||
|
|
||
|
|
||
| class TestTriangularAttention(unittest.TestCase): | ||
| def test_shape(self): | ||
| c_z = consts.c_z | ||
| c = 12 | ||
| no_heads = 4 | ||
| starting = True | ||
| # starting=True -> "starting node" variant: rows attend to rows, | ||
| # biased by z[i, k]. False would transpose internally for the | ||
| # "ending node" variant (columns attend to columns). | ||
| @pytest.mark.parametrize("starting", [True, False]) | ||
| def test_shape(starting, ndarrays_regression): | ||
| # NOTE: seeding may need further work — torch.manual_seed controls both | ||
| # the random input and the module's weight init. If init changes upstream, | ||
| # regenerate snapshots with: pytest --force-regen | ||
| torch.manual_seed(123) | ||
|
||
|
|
||
| # c_z: pair representation channel dim (128 in production) | ||
| c_z = consts.c_z | ||
| # c: attention hidden dim (production uses 32; smaller here for speed) | ||
| c = 12 | ||
| no_heads = 4 | ||
|
|
||
| tan = TriangleAttention( | ||
| c_z, | ||
| c, | ||
| no_heads, | ||
| starting=starting, | ||
| ) | ||
| tan = TriangleAttention( | ||
| c_z, | ||
| c, | ||
| no_heads, | ||
| starting=starting, | ||
| ) | ||
| # AlphaFold initializes the output projection to zero (so residual blocks | ||
| # start as identity). Reinitialize all params so the test exercises the | ||
| # actual computation and produces non-trivial output. | ||
| for p in tan.parameters(): | ||
| torch.nn.init.normal_(p, std=0.01) | ||
| tan.eval() | ||
|
|
||
| batch_size = consts.batch_size | ||
| n_res = consts.n_res | ||
| batch_size = consts.batch_size | ||
| n_res = consts.n_res | ||
|
|
||
| x = torch.rand((batch_size, n_res, n_res, c_z)) | ||
| shape_before = x.shape | ||
| # Pair representation: [batch, N_residues, N_residues, C_z] | ||
| x = torch.rand((batch_size, n_res, n_res, c_z)) | ||
| shape_before = x.shape | ||
| # chunk_size=None -> no memory-saving chunking, full attention in one pass | ||
| with torch.no_grad(): | ||
| x = tan(x, chunk_size=None) | ||
| shape_after = x.shape | ||
| shape_after = x.shape | ||
|
|
||
| self.assertTrue(shape_before == shape_after) | ||
| # Shape must be preserved for the residual addition z = z + tri_att(z) | ||
| assert shape_before == shape_after | ||
|
|
||
| # Guard against trivial all-zero output (e.g. from zero-initialized weights) | ||
| assert x.abs().max().item() > 0, ( | ||
| "Output is all zeros — snapshot would be meaningless" | ||
| ) | ||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() | ||
| # Snapshot regression: output must be numerically identical across runs. | ||
| # Regenerate with: pytest --force-regen | ||
| ndarrays_regression.check( | ||
| {"output": x.cpu().numpy()}, | ||
| default_tolerance=dict(atol=1e-6, rtol=1e-5), | ||
| ) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is what allows us to place the snapshots along other test_data, sibling to the 'cassettes'
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! Does it make sense to update the snapshot paths for the templates test you added in
https://github.com/aqlaboratory/openfold-3/tree/main/openfold3/tests/test_data/cassettes/test_rscbCan be a different PR if it is cumbersome to update here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These are different snapshots – we have two types
I'd be weary of mixing those up – they mean a different thing and solve a different problem