@@ -62,6 +62,7 @@ def dot_product_attention_weights(
   precision: PrecisionLike = None,
   module: Module | None = None,
   promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
+  is_causal: bool = False,
 ):
   """Computes dot-product attention weights given query and key.

@@ -94,6 +95,11 @@ def dot_product_attention_weights(
     promote_dtype: function to promote the dtype of the arrays to the desired
       dtype. The function should accept a tuple of ``(query, key)`` and a ``dtype``
       keyword argument, and return a tuple of arrays with the promoted dtype.
+    is_causal: If true, causal attention is applied. Note that some
+      implementations, such as xla, generate a mask tensor and apply it to the
+      logits to mask out the non-causal parts of the attention matrix, while
+      others, such as cudnn, avoid computing the non-causal regions entirely,
+      which provides speedups.

   Returns:
     Output of shape `[batch..., num_heads, q_length, kv_length]`.
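
For reference, a minimal standalone sketch of the xla-style behavior described in the new docstring text (illustration only; the shapes and zero logits below are made up): non-causal logits are replaced with the dtype minimum so that softmax gives them effectively zero weight.

import jax
import jax.numpy as jnp

T, S = 4, 4                                      # query length, key/value length
logits = jnp.zeros((T, S), dtype=jnp.float32)    # stand-in attention logits
causal = jnp.tril(jnp.ones((T, S), dtype=bool))  # True where key index <= query index
masked = jnp.where(causal, logits, jnp.finfo(logits.dtype).min)
weights = jax.nn.softmax(masked)                 # each row attends only to positions <= its own
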
@@ -118,9 +124,17 @@ def dot_product_attention_weights(
   if bias is not None:
     attn_weights = attn_weights + bias
   # apply attention mask
-  if mask is not None:
+  if mask is not None or is_causal:
     big_neg = jnp.finfo(dtype).min
-    attn_weights = jnp.where(mask, attn_weights, big_neg)
+    masks = [m for m in [mask] if m is not None]
+    if is_causal:
+      T, S = attn_weights.shape[-2:]
+      causal_mask = jnp.tril(jnp.ones((T, S), dtype=dtype))
+      target_shape = mask.shape if mask is not None else attn_weights.shape
+      masks.append(jnp.broadcast_to(causal_mask, target_shape))
+    combined_mask = combine_masks(*masks, dtype=dtype)
+    assert combined_mask is not None
+    attn_weights = jnp.where(combined_mask, attn_weights, big_neg)

   # normalize the attention weights
   attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
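
To sanity-check the branch above: with this change, is_causal=True should produce the same weights as passing an explicit causal mask. A hedged sketch, assuming the diff targets flax's nnx attention module and that the import paths below are correct (they are an assumption, not taken from this diff):

import jax
import jax.numpy as jnp
from flax import nnx
from flax.nnx.nn.attention import dot_product_attention_weights  # assumed location

kq, kk = jax.random.split(jax.random.key(0))
q = jax.random.normal(kq, (2, 4, 8, 16))  # [batch, q_length, num_heads, head_dim]
k = jax.random.normal(kk, (2, 4, 8, 16))  # [batch, kv_length, num_heads, head_dim]

w_flag = dot_product_attention_weights(q, k, is_causal=True)
w_mask = dot_product_attention_weights(q, k, mask=nnx.make_causal_mask(jnp.ones((2, 4))))
assert jnp.allclose(w_flag, w_mask)  # both paths mask the same entries before softmax
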
@@ -157,6 +171,7 @@ def dot_product_attention(
   precision: PrecisionLike = None,
   module: Module | None = None,
   promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
+  is_causal: bool = False,
 ):
   """Computes dot-product attention given query, key, and value.

@@ -198,6 +213,11 @@ def dot_product_attention(
       dtype. The function should accept a tuple of ``(query, key, value)`` and a
       ``dtype`` keyword argument, and return a tuple of arrays with the promoted
       dtype.
+    is_causal: If true, causal attention is applied. Note that some
+      implementations, such as xla, generate a mask tensor and apply it to the
+      logits to mask out the non-causal parts of the attention matrix, while
+      others, such as cudnn, avoid computing the non-causal regions entirely,
+      which provides speedups.

   Returns:
     Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`.
@@ -224,7 +244,7 @@ def reshape_4d(x):
         reshape_4d, (query, key, value, bias, mask))
     if mask is not None:
       mask = mask.astype(jnp.bool)
-    out = jax.nn.dot_product_attention(query, key, value, bias, mask)
+    out = jax.nn.dot_product_attention(query, key, value, bias, mask, is_causal=is_causal)
     if len(query_shape) > 4:
       out = jnp.reshape(out, query_shape)
     return out
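
End to end, the flag is simply forwarded: on the cudnn path above it goes to jax.nn.dot_product_attention, and on the default path it is handled by the mask logic in dot_product_attention_weights. A minimal usage sketch, assuming the updated dot_product_attention is exported as nnx.dot_product_attention (an assumption, not shown in this diff):

import jax
from flax import nnx  # assumes dot_product_attention is exported here

kq, kk, kv = jax.random.split(jax.random.key(0), 3)
q = jax.random.normal(kq, (2, 4, 8, 16))  # [batch, q_length, num_heads, head_dim]
k = jax.random.normal(kk, (2, 4, 8, 16))  # [batch, kv_length, num_heads, head_dim]
v = jax.random.normal(kv, (2, 4, 8, 16))  # [batch, kv_length, num_heads, v_head_dim]

out = nnx.dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # (2, 4, 8, 16): [batch..., q_length, num_heads, v_depth_per_head]
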
@@ -242,6 +262,8 @@ def reshape_4d(x):
     dtype,
     precision,
     module,
+    promote_dtype,
+    is_causal,
   )

   # return weighted sum over values for each query position