Commit 8be0806

Ryan McKenna authored and Flax Authors committed
Add in_kv_features argument to nnx.MultiHeadAttention, addressing #4756.
PiperOrigin-RevId: 760691901
1 parent a5eebe5 commit 8be0806

File tree

2 files changed: +36 -5 lines changed

flax/nnx/nn/attention.py

Lines changed: 9 additions & 5 deletions

@@ -279,7 +279,8 @@ class MultiHeadAttention(Module):
       should be divisible by the number of heads.
     in_features: int or tuple with number of input features.
     qkv_features: dimension of the key, query, and value.
-    out_features: dimension of the last projection
+    out_features: dimension of the last projection.
+    in_kv_features: number of input features for computing key and value.
     dtype: the dtype of the computation (default: infer from inputs and params)
     param_dtype: the dtype passed to parameter initializers (default: float32)
     broadcast_dropout: bool: use a broadcasted dropout along batch dims.
@@ -326,6 +327,7 @@ def __init__(
     in_features: int,
     qkv_features: int | None = None,
     out_features: int | None = None,
+    in_kv_features: int | None = None,
     *,
     dtype: Dtype | None = None,
     param_dtype: Dtype = jnp.float32,
@@ -357,6 +359,9 @@ def __init__(
     self.out_features = (
       out_features if out_features is not None else in_features
     )
+    self.in_kv_features = (
+      in_kv_features if in_kv_features is not None else in_features
+    )
     self.dtype = dtype
     self.param_dtype = param_dtype
     self.broadcast_dropout = broadcast_dropout
@@ -386,7 +391,6 @@ def __init__(

     linear_general = functools.partial(
       LinearGeneral,
-      in_features=self.in_features,
       out_features=(self.num_heads, self.head_dim),
       dtype=self.dtype,
       param_dtype=self.param_dtype,
@@ -399,9 +403,9 @@ def __init__(
     )
     # project inputs_q to multi-headed q/k/v
     # dimensions are then [batch..., length, n_heads, n_features_per_head]
-    self.query = linear_general(rngs=rngs)
-    self.key = linear_general(rngs=rngs)
-    self.value = linear_general(rngs=rngs)
+    self.query = linear_general(self.in_features, rngs=rngs)
+    self.key = linear_general(self.in_kv_features, rngs=rngs)
+    self.value = linear_general(self.in_kv_features, rngs=rngs)

     self.query_ln: LayerNorm | None
     self.key_ln: LayerNorm | None
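
With this change, the query projection is sized by in_features while the key and value projections are sized by in_kv_features (falling back to in_features when unset), so cross-attention over a memory stream with a different feature width no longer requires matching input sizes. A minimal usage sketch; the shapes and hyperparameter values below are illustrative assumptions, not taken from the commit:

import jax
import jax.numpy as jnp
from flax import nnx

# Query stream has 8 features; the key/value (memory) stream has 16.
layer = nnx.MultiHeadAttention(
  num_heads=2,
  in_features=8,       # feature width of the query input
  in_kv_features=16,   # feature width of the key/value input (new argument)
  qkv_features=32,     # projected q/k/v dimension, split across heads
  out_features=8,
  decode=False,
  rngs=nnx.Rngs(0),
)

q = jnp.ones((1, 4, 8))                                # [batch, q_len, in_features]
kv = jax.random.normal(jax.random.key(0), (1, 6, 16))  # [batch, kv_len, in_kv_features]

out = layer(q, kv)   # keys and values are both computed from kv
assert out.shape == (1, 4, 8)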

tests/nnx/nn/attention_test.py

Lines changed: 27 additions & 0 deletions

@@ -199,5 +199,32 @@ def test_nnx_attention_equivalence(
     np.testing.assert_array_equal(out, out_nnx)


+class TestKVFeatures(parameterized.TestCase):
+
+  def test_varying_num_features(self):
+    key = jax.random.key(42)
+    rngs = nnx.Rngs(42)
+
+    num_heads = 2
+    in_features = 3
+    in_kv_features = 4
+    qkv_features = 6
+    out_features = 6
+
+    x = jax.numpy.ones((1, in_features))
+    y = jax.random.normal(key, (1, in_kv_features))
+    layer = nnx.MultiHeadAttention(
+      num_heads=num_heads,
+      in_features=in_features,
+      qkv_features=qkv_features,
+      out_features=out_features,
+      in_kv_features=in_kv_features,
+      rngs=rngs,
+      decode=False
+    )
+
+    self.assertIsNotNone(layer(x, y))
+
+
 if __name__ == '__main__':
   absltest.main()
