From 9bab1485d93b7c81ba1597ea584ebcb03690ff56 Mon Sep 17 00:00:00 2001
From: Ibbyml <238398173+ibbyml@users.noreply.github.com>
Date: Tue, 18 Nov 2025 18:03:45 -0500
Subject: [PATCH] Add is_causal mask argument and tests

---
 flax/nnx/nn/attention.py       | 28 +++++++++++--
 tests/nnx/nn/attention_test.py | 73 ++++++++++++++++++++++++++++++++++
 2 files changed, 98 insertions(+), 3 deletions(-)

diff --git a/flax/nnx/nn/attention.py b/flax/nnx/nn/attention.py
index 163c96a3b..a8186c973 100644
--- a/flax/nnx/nn/attention.py
+++ b/flax/nnx/nn/attention.py
@@ -62,6 +62,7 @@ def dot_product_attention_weights(
   precision: PrecisionLike = None,
   module: Module | None = None,
   promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
+  is_causal: bool = False,
 ):
   """Computes dot-product attention weights given query and key.
 
@@ -94,6 +95,11 @@ def dot_product_attention_weights(
     promote_dtype: function to promote the dtype of the arrays to the desired
       dtype. The function should accept a tuple of ``(query, key)`` and a ``dtype``
       keyword argument, and return a tuple of arrays with the promoted dtype.
+    is_causal: If true, a causal mask is applied in addition to ``mask``:
+      a lower-triangular mask over the last two axes of the attention
+      logits, so that query position ``i`` only attends to key positions
+      ``j <= i``. The mask is materialized and applied to the logits
+      before the softmax.
 
   Returns:
     Output of shape `[batch..., num_heads, q_length, kv_length]`.
@@ -118,9 +124,17 @@ def dot_product_attention_weights(
   if bias is not None:
     attn_weights = attn_weights + bias
   # apply attention mask
-  if mask is not None:
+  if mask is not None or is_causal:
     big_neg = jnp.finfo(dtype).min
-    attn_weights = jnp.where(mask, attn_weights, big_neg)
+    attn_mask = mask
+    if is_causal:
+      # Keep the causal mask 2-D and let broadcasting expand it, so masks with
+      # singleton batch/head/query axes (e.g. padding masks) combine cleanly.
+      T, S = attn_weights.shape[-2:]
+      causal_mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))
+      attn_mask = causal_mask if mask is None else jnp.logical_and(
+        mask, causal_mask)
+    attn_weights = jnp.where(attn_mask, attn_weights, big_neg)
 
   # normalize the attention weights
   attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
@@ -157,6 +171,7 @@ def dot_product_attention(
   precision: PrecisionLike = None,
   module: Module | None = None,
   promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
+  is_causal: bool = False,
 ):
   """Computes dot-product attention given query, key, and value.
 
@@ -198,6 +213,11 @@ def dot_product_attention(
       dtype. The function should accept a tuple of ``(query, key, value)`` and a
       ``dtype`` keyword argument, and return a tuple of arrays with the promoted
       dtype.
+    is_causal: If true, causal attention will be applied. Note, some
+      implementations like xla will generate a mask tensor and apply it to
+      the logits to mask out the non-causal parts of the attention matrix,
+      but other implementations like cudnn will avoid computing the
+      non-causal regions, providing speedups.
 
   Returns:
     Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`.
@@ -224,7 +244,7 @@ def reshape_4d(x):
       reshape_4d, (query, key, value, bias, mask))
     if mask is not None:
       mask = mask.astype(jnp.bool)
-    out = jax.nn.dot_product_attention(query, key, value, bias, mask)
+    out = jax.nn.dot_product_attention(query, key, value, bias, mask, is_causal=is_causal)
     if len(query_shape) > 4:
       out = jnp.reshape(out, query_shape)
     return out
@@ -242,6 +262,8 @@ def reshape_4d(x):
     dtype,
     precision,
     module,
+    promote_dtype,
+    is_causal,
   )
 
   # return weighted sum over values for each query position
diff --git a/tests/nnx/nn/attention_test.py b/tests/nnx/nn/attention_test.py
index c8a9d55a7..4e99f94d9 100644
--- a/tests/nnx/nn/attention_test.py
+++ b/tests/nnx/nn/attention_test.py
@@ -17,6 +17,7 @@
 
 from flax import linen
 from flax import nnx
+from flax.nnx.nn.attention import combine_masks
 from flax.typing import Dtype, PrecisionLike
 
 import numpy as np
@@ -132,6 +133,78 @@ def test_keep_rngs(self, keep_rngs):
     else:
       nnx.split(module, nnx.Param)
 
+  @parameterized.product(use_padding=[True, False], is_cross_attention=[True, False])
+  def test_causal_mask_equivalence(
+    self,
+    use_padding: bool,
+    is_cross_attention: bool
+  ):
+    batch_size = 1
+    num_heads = 2
+    q_len = 2
+    kv_len = 4 if is_cross_attention else q_len
+    head_dim = 4
+
+    q = jax.random.normal(
+      key=jax.random.key(0),
+      shape=(batch_size, 1, q_len, num_heads, head_dim)
+    )
+    k = jax.random.normal(
+      key=jax.random.key(1),
+      shape=(batch_size, 1, kv_len, num_heads, head_dim)
+    )
+    v = jax.random.normal(
+      key=jax.random.key(2),
+      shape=(batch_size, 1, kv_len, num_heads, head_dim)
+    )
+
+    causal_mask = jnp.tril(
+      jnp.ones(shape=(q_len, kv_len), dtype=jnp.bool_)
+    )
+    causal_mask = jnp.broadcast_to(
+      array=causal_mask,
+      shape=(batch_size, 1, num_heads, q_len, kv_len)
+    )
+
+    padding_mask = None
+
+    if use_padding:
+      padding_mask = jnp.ones(
+        shape=(batch_size, 1, 1, q_len, kv_len),
+        dtype=jnp.bool_,
+      )
+      # Pad out key 1 only: it is causally visible to the last query, so the
+      # padding mask changes the result without fully masking any query row.
+      padding_mask = padding_mask.at[..., 1].set(False)
+
+    manual_mask = combine_masks(padding_mask, causal_mask, dtype=q.dtype)
+
+    # Reference: jax.nn path (module=None) with the precombined mask, is_causal=False
+    attn_reference = nnx.dot_product_attention(
+      query=q,
+      key=k,
+      value=v,
+      mask=manual_mask,
+      is_causal=False,
+      deterministic=True,
+      module=None,
+    )
+
+    class DummyModule(nnx.Module):
+      pass
+
+    # Under test: nnx path with only the padding mask; is_causal=True adds the causal part
+    attn_causal = nnx.dot_product_attention(
+      query=q,
+      key=k,
+      value=v,
+      mask=padding_mask,
+      is_causal=True,
+      deterministic=True,
+      module=DummyModule(),
+    )
+
+    np.testing.assert_allclose(attn_reference, attn_causal, atol=1e-6)
 
 # TODO: add all possible constructor argument values to parameterized.product
 class TestLinenConsistency(parameterized.TestCase):