Skip to content

Add in_kv_features argument to nnx.MultiHeadAttention, addressing https://github.com/google/flax/issues/4756. #4757

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions flax/nnx/nn/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,8 @@ class MultiHeadAttention(Module):
should be divisible by the number of heads.
in_features: int or tuple with number of input features.
qkv_features: dimension of the key, query, and value.
out_features: dimension of the last projection
out_features: dimension of the last projection.
in_kv_features: number of input features for computing key and value.
dtype: the dtype of the computation (default: infer from inputs and params)
param_dtype: the dtype passed to parameter initializers (default: float32)
broadcast_dropout: bool: use a broadcasted dropout along batch dims.
Expand Down Expand Up @@ -326,6 +327,7 @@ def __init__(
in_features: int,
qkv_features: int | None = None,
out_features: int | None = None,
in_kv_features: int | None = None,
*,
dtype: Dtype | None = None,
param_dtype: Dtype = jnp.float32,
Expand Down Expand Up @@ -357,6 +359,9 @@ def __init__(
self.out_features = (
out_features if out_features is not None else in_features
)
self.in_kv_features = (
in_kv_features if in_kv_features is not None else in_features
)
self.dtype = dtype
self.param_dtype = param_dtype
self.broadcast_dropout = broadcast_dropout
Expand Down Expand Up @@ -386,7 +391,6 @@ def __init__(

linear_general = functools.partial(
LinearGeneral,
in_features=self.in_features,
out_features=(self.num_heads, self.head_dim),
dtype=self.dtype,
param_dtype=self.param_dtype,
Expand All @@ -399,9 +403,9 @@ def __init__(
)
# project inputs_q to multi-headed q/k/v
# dimensions are then [batch..., length, n_heads, n_features_per_head]
self.query = linear_general(rngs=rngs)
self.key = linear_general(rngs=rngs)
self.value = linear_general(rngs=rngs)
self.query = linear_general(self.in_features, rngs=rngs)
self.key = linear_general(self.in_kv_features, rngs=rngs)
self.value = linear_general(self.in_kv_features, rngs=rngs)

self.query_ln: LayerNorm | None
self.key_ln: LayerNorm | None
Expand Down
27 changes: 27 additions & 0 deletions tests/nnx/nn/attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,5 +199,32 @@ def test_nnx_attention_equivalence(
np.testing.assert_array_equal(out, out_nnx)


class TestKVFeatures(parameterized.TestCase):

def test_varying_num_features(self):
key = jax.random.key(42)
rngs = nnx.Rngs(42)

num_heads = 2
in_features = 3
in_kv_features = 4
qkv_features = 6
out_features = 6

x = jax.numpy.ones((1, in_features))
y = jax.random.normal(key, (1, in_kv_features))
layer = nnx.MultiHeadAttention(
num_heads=num_heads,
in_features=in_features,
qkv_features=qkv_features,
out_features=out_features,
in_kv_features=in_kv_features,
rngs=rngs,
decode=False
)

self.assertIsNotNone(layer(x, y))


if __name__ == '__main__':
absltest.main()
Loading