
Commit f227eb0

fix olmo
1 parent ab464d3 commit f227eb0

1 file changed (+31, -6):

src/levanter/models/olmo.py
@@ -275,14 +275,23 @@ def init(config: Olmo2Config, *, key) -> "Olmo2Attention":
         v_proj = hnn.Linear.init(In=Embed, Out=(config.KVHeads, HeadSize), key=k_v, use_bias=use_bias, out_first=True)
         o_proj = hnn.Linear.init(In=(config.Heads, HeadSize), Out=Embed, key=k_o, use_bias=use_bias, out_first=True)
 
-        # Fix this line - k_norm should be the size of kv_heads * head_size (if that's what HF is doing)
+        # For q_norm, normalization is over the entire hidden dimension
         q_norm = Olmo2RMSNorm.init(
             config.Embed, eps=config.layer_norm_epsilon, use_weight=config.use_layer_norm_weight
         )
 
-        # Define a specific KVHeadSize axis for k_norm if needed
-        KVEmbedSize = Axis("kv_embed", config.KVHeads.size * HeadSize.size)
-        k_norm = Olmo2RMSNorm.init(KVEmbedSize, eps=config.layer_norm_epsilon, use_weight=config.use_layer_norm_weight)
+        # For k_norm, we need to be careful with the axis
+        if config.num_kv_heads == config.num_heads:
+            # If num_kv_heads equals num_heads, use the same axis as q_norm
+            k_norm = Olmo2RMSNorm.init(
+                config.Embed, eps=config.layer_norm_epsilon, use_weight=config.use_layer_norm_weight
+            )
+        else:
+            # If using grouped query attention, the k_norm needs a smaller size
+            k_norm_axis = Axis("embed", config.num_kv_heads * HeadSize.size)
+            k_norm = Olmo2RMSNorm.init(
+                k_norm_axis, eps=config.layer_norm_epsilon, use_weight=config.use_layer_norm_weight
+            )
 
         return Olmo2Attention(config, q_proj, k_proj, v_proj, o_proj, q_norm, k_norm)
 
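For context, the arithmetic behind the two branches can be sketched in isolation. The numbers below are hypothetical, not taken from any real Olmo2Config; they only show what size the k_norm axis ends up with in the full-attention case versus the grouped-query case.

# Hypothetical sizes, for illustration only
hidden_size = 4096                       # config.Embed.size
num_heads = 32                           # config.num_heads
head_size = hidden_size // num_heads     # 128, i.e. config.HeadSize.size
num_kv_heads = 8                         # config.num_kv_heads (grouped-query attention)

q_norm_size = hidden_size                # q_norm always spans the full embedding
if num_kv_heads == num_heads:
    k_norm_size = hidden_size            # same axis as q_norm
else:
    k_norm_size = num_kv_heads * head_size   # smaller "embed" axis for the KV heads

print(q_norm_size, k_norm_size)          # 4096 1024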

@@ -305,9 +314,25 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray | AttentionMask], *,
         norm_x = self.q_norm(x)
         q = self.q_proj(norm_x, key=key_q)
 
-        norm_x = self.k_norm(x)
-        k = self.k_proj(norm_x, key=key_k)
+        # For k_norm, we need special handling if num_kv_heads != num_heads
+        if self.config.num_kv_heads == self.config.num_heads:
+            # Same normalization as q
+            norm_x = self.k_norm(x)
+        else:
+            # x needs to be reduced to the dimensionality k_norm expects
+            # This is specific to OLMo2's implementation
+            HeadSize = self.config.HeadSize
+
+            # Slice down to the smaller embedding used for the KV heads;
+            # the sliced features are just the leading subset of the full
+            # embedding, so a plain slice is enough here
+            norm_x_for_k = x.array[:, : self.config.num_kv_heads * HeadSize.size]
+            norm_x_for_k = hax.named(
+                norm_x_for_k, (x.axes[0], Axis("embed", self.config.num_kv_heads * HeadSize.size))
+            )
+            norm_x = self.k_norm(norm_x_for_k)
 
+        k = self.k_proj(norm_x, key=key_k)
         v = self.v_proj(x, key=key_v)
 
         # Reshape for attention
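As a rough standalone sketch of the slicing in the new __call__ branch, using hax.named and Axis as they appear in the diff, with made-up sizes rather than a real config:

import haliax as hax
from haliax import Axis

# Hypothetical axes, for illustration only
Pos = Axis("position", 4)
Embed = Axis("embed", 16)        # full hidden size
kv_width = 8                     # num_kv_heads * head_size under grouped-query attention

x = hax.ones((Pos, Embed))

# Mirror the diff: keep the first kv_width features of the raw array, then
# rebuild a named array over a smaller "embed" axis before applying k_norm
sliced = x.array[:, :kv_width]
x_for_k = hax.named(sliced, (x.axes[0], Axis("embed", kv_width)))
print(x_for_k.axes)              # the second axis now has size kv_width

The sliced array keeps only the leading num_kv_heads * head_size features, which matches the size of the reduced k_norm axis created in init.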
