@@ -279,7 +279,8 @@ class MultiHeadAttention(Module):
       should be divisible by the number of heads.
     in_features: int or tuple with number of input features.
     qkv_features: dimension of the key, query, and value.
-    out_features: dimension of the last projection
+    out_features: dimension of the last projection.
+    in_kv_features: number of input features for computing key and value.
     dtype: the dtype of the computation (default: infer from inputs and params)
     param_dtype: the dtype passed to parameter initializers (default: float32)
     broadcast_dropout: bool: use a broadcasted dropout along batch dims.
@@ -326,6 +327,7 @@ def __init__(
     in_features: int,
     qkv_features: int | None = None,
     out_features: int | None = None,
+    in_kv_features: int | None = None,
     *,
     dtype: Dtype | None = None,
     param_dtype: Dtype = jnp.float32,
@@ -357,6 +359,9 @@ def __init__(
     self.out_features = (
       out_features if out_features is not None else in_features
     )
+    self.in_kv_features = (
+      in_kv_features if in_kv_features is not None else in_features
+    )
     self.dtype = dtype
     self.param_dtype = param_dtype
     self.broadcast_dropout = broadcast_dropout
@@ -386,7 +391,6 @@ def __init__(

     linear_general = functools.partial(
       LinearGeneral,
-      in_features=self.in_features,
       out_features=(self.num_heads, self.head_dim),
       dtype=self.dtype,
       param_dtype=self.param_dtype,
@@ -399,9 +403,9 @@ def __init__(
     )
     # project inputs_q to multi-headed q/k/v
     # dimensions are then [batch..., length, n_heads, n_features_per_head]
-    self.query = linear_general(rngs=rngs)
-    self.key = linear_general(rngs=rngs)
-    self.value = linear_general(rngs=rngs)
+    self.query = linear_general(self.in_features, rngs=rngs)
+    self.key = linear_general(self.in_kv_features, rngs=rngs)
+    self.value = linear_general(self.in_kv_features, rngs=rngs)

     self.query_ln: LayerNorm | None
     self.key_ln: LayerNorm | None
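For context, a minimal usage sketch of the new in_kv_features argument (assuming the flax.nnx API this diff targets; the shapes, the decode=False flag, and the Rngs seed are illustrative assumptions, not taken from the diff):

import jax.numpy as jnp
from flax import nnx

# Cross-attention sketch: queries come from a 16-feature decoder stream,
# while keys/values come from a 32-feature encoder stream.
attn = nnx.MultiHeadAttention(
  num_heads=4,
  in_features=16,     # feature size of inputs_q
  in_kv_features=32,  # feature size of inputs_k / inputs_v (added by this diff)
  qkv_features=16,    # must be divisible by num_heads
  decode=False,
  rngs=nnx.Rngs(0),
)

q = jnp.ones((2, 5, 16))   # [batch, q_length, in_features]
kv = jnp.ones((2, 7, 32))  # [batch, kv_length, in_kv_features]
y = attn(q, kv)            # [batch, q_length, out_features] == (2, 5, 16)

Removing in_features=self.in_features from the functools.partial is what makes this work: self.query is now built with self.in_features while self.key and self.value are built with self.in_kv_features, so the query and key/value streams can have different widths.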