
Commit 5dd98ea

Committed Mar 17, 2025
test olmo better
1 parent 524316e commit 5dd98ea

File tree

2 files changed: +31 −24 lines
src/levanter/models/olmo.py

+10 −14
@@ -265,28 +265,24 @@ def init(config: Olmo2Config, *, key) -> "Olmo2Attention":
         use_bias = config.attention_bias
         Embed = config.Embed
         QHeadsPerGroup = hax.Axis("q_heads_per_group", config.num_heads // config.num_kv_heads)
+        HeadSize = config.HeadSize
 
         k_q, k_k, k_v, k_o, k_q_norm, k_k_norm = jrandom.split(key, 6)
         q_proj = hnn.Linear.init(
-            In=Embed, Out=(config.KVHeads, QHeadsPerGroup, config.HeadSize), key=k_q, use_bias=use_bias, out_first=True
-        )
-        k_proj = hnn.Linear.init(
-            In=Embed, Out=(config.KVHeads, config.HeadSize), key=k_k, use_bias=use_bias, out_first=True
-        )
-        v_proj = hnn.Linear.init(
-            In=Embed, Out=(config.KVHeads, config.HeadSize), key=k_v, use_bias=use_bias, out_first=True
-        )
-        o_proj = hnn.Linear.init(
-            In=(config.Heads, config.HeadSize), Out=Embed, key=k_o, use_bias=use_bias, out_first=True
+            In=Embed, Out=(config.KVHeads, QHeadsPerGroup, HeadSize), key=k_q, use_bias=use_bias, out_first=True
         )
+        k_proj = hnn.Linear.init(In=Embed, Out=(config.KVHeads, HeadSize), key=k_k, use_bias=use_bias, out_first=True)
+        v_proj = hnn.Linear.init(In=Embed, Out=(config.KVHeads, HeadSize), key=k_v, use_bias=use_bias, out_first=True)
+        o_proj = hnn.Linear.init(In=(config.Heads, HeadSize), Out=Embed, key=k_o, use_bias=use_bias, out_first=True)
 
-        # OLMo2 uses normalization over the entire hidden dimension for q and k
+        # Fix this line - k_norm should be the size of kv_heads * head_size (if that's what HF is doing)
         q_norm = Olmo2RMSNorm.init(
             config.Embed, eps=config.layer_norm_epsilon, use_weight=config.use_layer_norm_weight
         )
-        k_norm = Olmo2RMSNorm.init(
-            config.Embed, eps=config.layer_norm_epsilon, use_weight=config.use_layer_norm_weight
-        )
+
+        # Define a specific KVHeadSize axis for k_norm if needed
+        KVEmbedSize = Axis("kv_embed", config.KVHeads.size * HeadSize.size)
+        k_norm = Olmo2RMSNorm.init(KVEmbedSize, eps=config.layer_norm_epsilon, use_weight=config.use_layer_norm_weight)
 
         return Olmo2Attention(config, q_proj, k_proj, v_proj, o_proj, q_norm, k_norm)
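For reference, the shape arithmetic behind the new "kv_embed" axis: under grouped-query attention, num_kv_heads * head_size is smaller than the full hidden size, which is why k_norm can no longer be sized by config.Embed. A minimal sketch of that arithmetic with haliax, using illustrative sizes (the specific numbers are not taken from this commit):

import haliax as hax

# Illustrative sizes only; real values come from Olmo2Config.
hidden_size = 4096
num_heads = 32
num_kv_heads = 8
head_size = hidden_size // num_heads  # 128

KVHeads = hax.Axis("kv_heads", num_kv_heads)
HeadSize = hax.Axis("head_size", head_size)

# Same arithmetic as the diff: kv_embed = num_kv_heads * head_size
KVEmbedSize = hax.Axis("kv_embed", KVHeads.size * HeadSize.size)

assert KVEmbedSize.size == 1024          # length of the k_norm weight under GQA
assert KVEmbedSize.size != hidden_size   # equal only when num_kv_heads == num_heads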

tests/test_olmo.py

+21 −10
@@ -214,20 +214,31 @@ def test_olmo2_roundtrip(scan_layers, num_kv_heads):
     torch_out = torch_model(input_torch)
     torch_out = torch_out.logits[0].detach().cpu().numpy()
 
-    # Add this before the roundtrip test fails
-    print("HF model params:", {k: v.shape for k, v in torch_model.state_dict().items() if "layers.0" in k})
-    print(
-        "Levanter model expected:",
-        {
-            "layers.0.mlp.gate_proj.weight": (config.hidden_dim, config.intermediate_dim),
-            "layers.0.mlp.up_proj.weight": (config.hidden_dim, config.intermediate_dim),
-            "layers.0.mlp.down_proj.weight": (config.intermediate_dim, config.hidden_dim),
-        },
-    )
     with tempfile.TemporaryDirectory() as tmpdir:
         # Save HF model
         torch_model.save_pretrained(f"{tmpdir}/torch_model")
 
+        # Add this before the converter.load_pretrained line
+        torch_state_dict = torch.load(f"{tmpdir}/torch_model/pytorch_model.bin")
+        print("\nDetailed HF Shapes:")
+        for k, v in torch_state_dict.items():
+            if "layers.0" in k:
+                print(f"{k}: {v.shape}")
+
+        # Create a template model to inspect
+        template_model = Olmo2LMHeadModel.init(Vocab=Vocab, config=config, key=random.PRNGKey(0))
+        print("\nLevanter Model Parameter Structure:")
+        for layer_idx in range(config.num_layers):
+            print(f"Layer {layer_idx}:")
+
+            # Print the attention module params shapes
+            attn = template_model.transformer.layers.blocks[layer_idx].self_attn
+            print(f" q_proj: {attn.q_proj.weight.array.shape}")
+            print(f" k_proj: {attn.k_proj.weight.array.shape}")
+            print(f" v_proj: {attn.v_proj.weight.array.shape}")
+            print(f" o_proj: {attn.o_proj.weight.array.shape}")
+            print(f" q_norm: {attn.q_norm.weight.array.shape if attn.q_norm.weight is not None else None}")
+            print(f" k_norm: {attn.k_norm.weight.array.shape if attn.k_norm.weight is not None else None}")
         # Load into our model
         model = converter.load_pretrained(
             Olmo2LMHeadModel, ref=f"{tmpdir}/torch_model", resize_vocab_to_match_tokenizer=False
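A caveat on the torch.load call in the debug block above (not part of the commit): recent transformers releases write model.safetensors by default rather than pytorch_model.bin, so the direct load may fail depending on the installed version. A minimal fallback sketch, assuming the safetensors package is available; load_hf_state_dict is a hypothetical helper, not an existing function in this repo:

import os

import torch
from safetensors.torch import load_file


def load_hf_state_dict(model_dir: str) -> dict:
    """Load an HF checkpoint's state dict from either the safetensors or torch pickle file."""
    safetensors_path = os.path.join(model_dir, "model.safetensors")
    if os.path.exists(safetensors_path):
        return load_file(safetensors_path)
    return torch.load(os.path.join(model_dir, "pytorch_model.bin"), map_location="cpu")


# e.g. in the test above: torch_state_dict = load_hf_state_dict(f"{tmpdir}/torch_model")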
