
Conversation

ibbyml commented Nov 17, 2025

What does this PR do?

  • Adds an is_causal arg to flax.nnx.dot_product_attention and dot_product_attention_weights.
  • Forwards is_causal through to the jax.nn.dot_product_attention fast path when possible.
  • Implements manual causal masking logic that:
    • Supports both self-attention and cross-attention
    • Composes is_causal with input masks via the combine_masks helper (sketched below)
  • Adds 3 attention tests to ensure compatibility and correctness.
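For context, the manual path can be sketched roughly as follows: build a boolean (T, S) causal mask from query and key positions, then AND it with any user-supplied mask via combine_masks. This is a minimal illustration, not the exact code from this PR; the make_causal_mask helper, the shapes, and the top-left alignment convention for cross-attention are assumptions.

```python
import jax.numpy as jnp
from flax import nnx

def make_causal_mask(q_len: int, kv_len: int) -> jnp.ndarray:
  # Boolean mask of shape (1, 1, q_len, kv_len): query i may attend to
  # key j iff j <= i. For self-attention q_len == kv_len; for
  # cross-attention this assumes the top-left alignment convention.
  q_pos = jnp.arange(q_len)[:, None]
  k_pos = jnp.arange(kv_len)[None, :]
  return (k_pos <= q_pos)[None, None, :, :]

# Compose with an optional padding mask; combine_masks logically ANDs
# every mask that is not None, so a None padding mask is a no-op.
q_len, kv_len = 4, 4
padding_mask = jnp.ones((1, 1, q_len, kv_len), dtype=bool)  # illustrative only
mask = nnx.combine_masks(
    make_causal_mask(q_len, kv_len), padding_mask, dtype=bool
)
```

When the fast path applies, no mask is built at all and the flag is simply forwarded to jax.nn.dot_product_attention(..., is_causal=True).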

Checklist

  • This change is discussed in this discussion.
  • The documentation and docstrings adhere to the documentation guidelines.
  • This change includes necessary high-coverage tests. (No quality testing = no merge!)

vfdev-5 (Collaborator) left a comment

Thanks for the PR @ibbyml !

ibbyml (Author) commented Nov 17, 2025

Pushed the updated parameterized tests and removed the unnecessary formatting. The tests now cover self-attention with and without a padding mask as well as cross-attention with and without a padding mask. Happy to adjust anything else.

vfdev-5 (Collaborator) commented Nov 17, 2025

> Pushed the updated parameterized tests and removed the unnecessary formatting. The tests now cover self-attention with and without a padding mask as well as cross-attention with and without a padding mask. Happy to adjust anything else.

Thanks! For parameterized tests, the idea is to write all 3 test cases as a single parameterized test. I do not think we need to parameterize on B, T, S, etc.
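For illustration, a single parameterized test covering those cases could look roughly like this. The test class, names, shapes, and the explicit-mask reference (including its top-left causal alignment for cross-attention) are assumptions, not the test actually merged:

```python
from absl.testing import absltest, parameterized
import jax
import jax.numpy as jnp
from flax import nnx

class DotProductAttentionIsCausalTest(parameterized.TestCase):

  @parameterized.product(
      cross_attention=[False, True],
      use_padding_mask=[False, True],
  )
  def test_is_causal(self, cross_attention, use_padding_mask):
    B, T, H, D = 2, 4, 2, 8
    S = 6 if cross_attention else T
    kq, kk, kv = jax.random.split(jax.random.key(0), 3)
    q = jax.random.normal(kq, (B, T, H, D))
    k = jax.random.normal(kk, (B, S, H, D))
    v = jax.random.normal(kv, (B, S, H, D))
    padding = None
    if use_padding_mask:
      # Mask out the last key position, broadcast over batch/head/query axes.
      padding = (jnp.arange(S) < S - 1)[None, None, None, :]
    out = nnx.dot_product_attention(q, k, v, mask=padding, is_causal=True)
    # Reference: same call with an explicit causal mask instead of is_causal.
    causal = (jnp.arange(S)[None, :] <= jnp.arange(T)[:, None])[None, None]
    ref_mask = nnx.combine_masks(causal, padding, dtype=bool)
    ref = nnx.dot_product_attention(q, k, v, mask=ref_mask)
    self.assertTrue(jnp.allclose(out, ref, atol=1e-6))

if __name__ == '__main__':
  absltest.main()
```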

vfdev-5 (Collaborator) commented Nov 18, 2025

@ibbyml thanks for the updates! Please squash all commits into one, otherwise CI will fail for num_commits >= 5.

review-notebook-app commented

Check out this pull request on ReviewNB to see visual diffs & provide feedback on Jupyter Notebooks.

vfdev-5 (Collaborator) commented Nov 18, 2025

@ibbyml let's keep only your updates

ibbyml (Author) commented Nov 18, 2025

> @ibbyml let's keep only your updates

Sorry about that. I accidentally pulled everything. The new push should be squashed correctly.

vfdev-5 (Collaborator) left a comment

LGTM, thanks @ibbyml !
