-
Notifications
You must be signed in to change notification settings - Fork 760
Added is_causal mask argument to flax.nnx.dot_product_attention #5093
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
vfdev-5
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR @ibbyml !
|
Pushed the updated parameterized tests and removed the unnecessary formatting. The tests now cover self-attention with and without a padding mask as well as cross-attention with and without a padding mask. Happy to adjust anything else. |
Thanks! For parameterized tests, the idea is to write all 3 test cases as a parameterized single one. I do not think we need to parameterize on B, T, S etc. |
|
@ibbyml thanks for the updates! Please squash all commits into 1 otherwise CI will fail for num_commits >= 5 |
b06c6fc to
0c33425
Compare
|
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
|
@ibbyml let's keep only your updates |
Sorry about that. Accidentally pulled everything. New push should be squashed correctly. |
0c33425 to
9bab148
Compare
vfdev-5
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks @ibbyml !
What does this PR do?
is_causalarg toflax.nnx.dot_product_attentionanddot_product_attention_weights.is_causalthrough tojax.nn.dot_product_attentionfast path when possible.is_causalwith input masks with thecombine_maskshelperChecklist