Is your feature request related to a problem? Please describe.
If we don't require torch < 2.5, the build test fails, specifically `test/module/test_attention.py::test_all_masked`:
uni2ts/test/module/test_attention.py, line 85 (commit 2ba614d)
Describe the solution you'd like
The reason might be that torch 2.5 updated `nn.functional.scaled_dot_product_attention` (see the release notes: https://github.com/pytorch/pytorch/releases/tag/v2.5.0), which mention:

> Before this PR, fully masked rows in the attn_mask passed to nn.functional.scaled_dot_product_attention would yield NANs in the output, after this PR, fully masked rows yield 0s.
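For reference, a minimal repro of the behavior change (a sketch, not the project's actual test; the shapes and mask here are illustrative):

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 1, 2, 4)  # (batch, heads, seq, head_dim)
k = torch.randn(1, 1, 2, 4)
v = torch.randn(1, 1, 2, 4)

# Row 0 attends normally; row 1 is fully masked out (all False).
attn_mask = torch.tensor([[True, True],
                          [False, False]])

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
row = out[0, 0, 1]  # the fully masked row: NaNs before torch 2.5, zeros from 2.5 on

# True on either side of the 2.5 change: all-NaN (old) or all-zero (new).
print(bool(torch.isnan(row).all() or (row == 0).all()))
```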
I think we may later want to change the test to assert all 0s. But I am not sure whether we rely on torch.nan in the output for any special purpose (i.e., whether switching fully masked rows to 0s would break the pipeline).
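If we do switch the assertion, one option is to make it tolerant of both torch versions while the lower bound is still open. A rough sketch with hypothetical helper names (not the project's actual test code, and the version parsing is deliberately simple):

```python
import torch


def _torch_major_minor() -> tuple:
    # Rough parse of e.g. "2.5.1+cu121" -> (2, 5); a sketch, not a
    # substitute for proper version parsing (packaging.version).
    major, minor, *_ = torch.__version__.split(".")
    return int(major), int(minor)


def assert_fully_masked_row(row: torch.Tensor) -> None:
    # Hypothetical helper: fully masked rows yield 0s on torch >= 2.5,
    # NaNs on earlier versions.
    if _torch_major_minor() >= (2, 5):
        assert (row == 0).all(), "expected zeros for fully masked row"
    else:
        assert torch.isnan(row).all(), "expected NaNs for fully masked row"
```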
@chenghaoliu89 @cuthalionn Any comments/suggestions?