@@ -275,14 +275,23 @@ def init(config: Olmo2Config, *, key) -> "Olmo2Attention":
         v_proj = hnn.Linear.init(In=Embed, Out=(config.KVHeads, HeadSize), key=k_v, use_bias=use_bias, out_first=True)
         o_proj = hnn.Linear.init(In=(config.Heads, HeadSize), Out=Embed, key=k_o, use_bias=use_bias, out_first=True)

-        # Fix this line - k_norm should be the size of kv_heads * head_size (if that's what HF is doing)
+        # For q_norm, normalization is over the entire hidden dimension
         q_norm = Olmo2RMSNorm.init(
             config.Embed, eps=config.layer_norm_epsilon, use_weight=config.use_layer_norm_weight
         )

-        # Define a specific KVHeadSize axis for k_norm if needed
-        KVEmbedSize = Axis("kv_embed", config.KVHeads.size * HeadSize.size)
-        k_norm = Olmo2RMSNorm.init(KVEmbedSize, eps=config.layer_norm_epsilon, use_weight=config.use_layer_norm_weight)
+        # For k_norm, we need to be careful with the axis
+        if config.num_kv_heads == config.num_heads:
+            # If num_kv_heads equals num_heads, use the same axis as q_norm
+            k_norm = Olmo2RMSNorm.init(
+                config.Embed, eps=config.layer_norm_epsilon, use_weight=config.use_layer_norm_weight
+            )
+        else:
+            # If using grouped query attention, the k_norm needs a smaller size
+            k_norm_axis = Axis("embed", config.num_kv_heads * HeadSize.size)
+            k_norm = Olmo2RMSNorm.init(
+                k_norm_axis, eps=config.layer_norm_epsilon, use_weight=config.use_layer_norm_weight
+            )

         return Olmo2Attention(config, q_proj, k_proj, v_proj, o_proj, q_norm, k_norm)

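Illustrative sketch (not part of the diff): the axis-size arithmetic that the two init branches above encode. The config numbers here are made-up assumptions, not values from this PR.

    hidden_size = 4096                      # size of config.Embed (illustrative)
    num_heads = 32
    num_kv_heads = 8                        # grouped-query attention case (illustrative)
    head_size = hidden_size // num_heads    # 128, size of config.HeadSize

    q_norm_size = hidden_size               # q_norm always covers the full hidden dimension
    if num_kv_heads == num_heads:
        k_norm_size = hidden_size           # same axis as q_norm
    else:
        k_norm_size = num_kv_heads * head_size  # 8 * 128 = 1024

    print(q_norm_size, k_norm_size)         # 4096 1024
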
@@ -305,9 +314,25 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray | AttentionMask], *,
         norm_x = self.q_norm(x)
         q = self.q_proj(norm_x, key=key_q)

-        norm_x = self.k_norm(x)
-        k = self.k_proj(norm_x, key=key_k)
+        # For k_norm, we need special handling if num_kv_heads != num_heads
+        if self.config.num_kv_heads == self.config.num_heads:
+            # Same normalization as q
+            norm_x = self.k_norm(x)
+        else:
+            # We need to slice x down to the right dimensionality for k_norm
+            # This is specific to OLMo2's implementation
+            HeadSize = self.config.HeadSize
+
+            # Slice x down to the smaller embedding used for k_norm:
+            # keep only the first num_kv_heads * head_size features of the
+            # hidden dimension and re-wrap them as a NamedArray
+            norm_x_for_k = x.array[:, : self.config.num_kv_heads * HeadSize.size]
+            norm_x_for_k = hax.named(
+                norm_x_for_k, (x.axes[0], Axis("embed", self.config.num_kv_heads * HeadSize.size))
+            )
+            norm_x = self.k_norm(norm_x_for_k)

+        k = self.k_proj(norm_x, key=key_k)
         v = self.v_proj(x, key=key_v)

         # Reshape for attention
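Illustrative sketch (not from this PR) of what the grouped-query branch in __call__ does before k_norm: slice off the leading num_kv_heads * head_size features and re-wrap them with a smaller "embed" axis. The toy axis sizes and the jnp/hax usage below are assumptions for demonstration only.

    import jax.numpy as jnp
    import haliax as hax
    from haliax import Axis

    Pos = Axis("position", 4)                  # toy sequence length
    Embed = Axis("embed", 16)                  # toy full hidden size
    num_heads, num_kv_heads = 4, 2
    head_size = Embed.size // num_heads        # 4

    x = hax.named(jnp.ones((Pos.size, Embed.size)), (Pos, Embed))

    # Keep only the first num_kv_heads * head_size features, as the GQA branch above does.
    KVEmbed = Axis("embed", num_kv_heads * head_size)          # same name, size 8
    x_for_k = hax.named(x.array[:, : KVEmbed.size], (Pos, KVEmbed))

    print([(a.name, a.size) for a in x_for_k.axes])            # [('position', 4), ('embed', 8)]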