@@ -265,28 +265,24 @@ def init(config: Olmo2Config, *, key) -> "Olmo2Attention":
265
265
use_bias = config .attention_bias
266
266
Embed = config .Embed
267
267
QHeadsPerGroup = hax .Axis ("q_heads_per_group" , config .num_heads // config .num_kv_heads )
268
+ HeadSize = config .HeadSize
268
269
269
270
k_q , k_k , k_v , k_o , k_q_norm , k_k_norm = jrandom .split (key , 6 )
270
271
q_proj = hnn .Linear .init (
271
- In = Embed , Out = (config .KVHeads , QHeadsPerGroup , config .HeadSize ), key = k_q , use_bias = use_bias , out_first = True
272
- )
273
- k_proj = hnn .Linear .init (
274
- In = Embed , Out = (config .KVHeads , config .HeadSize ), key = k_k , use_bias = use_bias , out_first = True
275
- )
276
- v_proj = hnn .Linear .init (
277
- In = Embed , Out = (config .KVHeads , config .HeadSize ), key = k_v , use_bias = use_bias , out_first = True
278
- )
279
- o_proj = hnn .Linear .init (
280
- In = (config .Heads , config .HeadSize ), Out = Embed , key = k_o , use_bias = use_bias , out_first = True
272
+ In = Embed , Out = (config .KVHeads , QHeadsPerGroup , HeadSize ), key = k_q , use_bias = use_bias , out_first = True
281
273
)
274
+ k_proj = hnn .Linear .init (In = Embed , Out = (config .KVHeads , HeadSize ), key = k_k , use_bias = use_bias , out_first = True )
275
+ v_proj = hnn .Linear .init (In = Embed , Out = (config .KVHeads , HeadSize ), key = k_v , use_bias = use_bias , out_first = True )
276
+ o_proj = hnn .Linear .init (In = (config .Heads , HeadSize ), Out = Embed , key = k_o , use_bias = use_bias , out_first = True )
282
277
283
- # OLMo2 uses normalization over the entire hidden dimension for q and k
278
+ # q_norm normalizes over Embed; k_norm normalizes over kv_heads * head_size (TODO: confirm this matches the HF implementation)
284
279
q_norm = Olmo2RMSNorm .init (
285
280
config .Embed , eps = config .layer_norm_epsilon , use_weight = config .use_layer_norm_weight
286
281
)
287
- k_norm = Olmo2RMSNorm .init (
288
- config .Embed , eps = config .layer_norm_epsilon , use_weight = config .use_layer_norm_weight
289
- )
282
+
283
+ # Define a dedicated KVEmbedSize axis (kv_heads * head_size) for k_norm
284
+ KVEmbedSize = Axis ("kv_embed" , config .KVHeads .size * HeadSize .size )
285
+ k_norm = Olmo2RMSNorm .init (KVEmbedSize , eps = config .layer_norm_epsilon , use_weight = config .use_layer_norm_weight )
290
286
291
287
return Olmo2Attention (config , q_proj , k_proj , v_proj , o_proj , q_norm , k_norm )
292
288
0 commit comments