Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
16 changes: 11 additions & 5 deletions analyze_times_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@
print("(Row = same first operand, Col = same second operand)")

# Create a table showing for each problem: row/col bias
print(f"\n{'Problem':<10} {'Correct':<8} {'Rank':<6} {'Same Row':<10} {'Same Col':<10} {'Bias':<12}")
print(
f"\n{'Problem':<10} {'Correct':<8} {'Rank':<6} {'Same Row':<10} {'Same Col':<10} {'Bias':<12}"
)
print("-" * 60)

row_bias_count = 0
Expand Down Expand Up @@ -124,7 +126,11 @@
print("\n\n5. HARDEST MULTIPLICATIONS (correct not in top-3)")
print("-" * 40)

hard = [r for r in data if r["neighborhood"]["correct_rank"] is None or r["neighborhood"]["correct_rank"] > 3]
hard = [
r
for r in data
if r["neighborhood"]["correct_rank"] is None or r["neighborhood"]["correct_rank"] > 3
]
hard.sort(key=lambda x: x["neighborhood"]["correct_rank"] or 999)

for r in hard:
Expand Down Expand Up @@ -166,7 +172,7 @@
for bucket in sorted(diff_counts.keys()):
count = diff_counts[bucket]
bar = "#" * (count // 3)
print(f" {bucket:>3}-{bucket+4:<3}: {count:>3} {bar}")
print(f" {bucket:>3}-{bucket + 4:<3}: {count:>3} {bar}")

# 7. Asymmetry analysis
print("\n\n7. ASYMMETRY ANALYSIS (a*b vs b*a)")
Expand All @@ -175,7 +181,7 @@
print("\nDoes the model treat a*b differently from b*a?")

for a in range(2, 9):
for b in range(a+1, 10):
for b in range(a + 1, 10):
# Find both
ab = next((r for r in data if r["a"] == a and r["b"] == b), None)
ba = next((r for r in data if r["a"] == b and r["b"] == a), None)
Expand All @@ -189,5 +195,5 @@
if abs(rank_ab - rank_ba) > 2 or abs(prob_ab - prob_ba) > 0.1:
print(f" {a}*{b}: rank={rank_ab}, prob={prob_ab:.3f}")
print(f" {b}*{a}: rank={rank_ba}, prob={prob_ba:.3f}")
print(f" Δrank={rank_ab-rank_ba:+d}, Δprob={prob_ab-prob_ba:+.3f}")
print(f" Δrank={rank_ab - rank_ba:+d}, Δprob={prob_ab - prob_ba:+.3f}")
print()
53 changes: 30 additions & 23 deletions attention_head_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import json
from collections import defaultdict
from dataclasses import dataclass
from typing import Any

import mlx.core as mx
import mlx.nn as nn
Expand All @@ -22,6 +21,7 @@
@dataclass
class HeadAblationResult:
"""Result of ablating a single attention head."""

layer: int
head: int
original_answer: str
Expand Down Expand Up @@ -149,7 +149,11 @@ def forward_with_head_ablation(
out = lyr(h, mask=mask)
except TypeError:
out = lyr(h)
h = out.hidden_states if hasattr(out, "hidden_states") else (out[0] if isinstance(out, tuple) else out)
h = (
out.hidden_states
if hasattr(out, "hidden_states")
else (out[0] if isinstance(out, tuple) else out)
)
continue

# Reshape for multi-head attention
Expand All @@ -158,7 +162,7 @@ def forward_with_head_ablation(
v = v.reshape(B, L, num_heads, head_dim).transpose(0, 2, 1, 3)

# Compute attention scores
scale_factor = head_dim ** -0.5
scale_factor = head_dim**-0.5
scores = (q @ k.transpose(0, 1, 3, 2)) * scale_factor

# Apply causal mask
Expand All @@ -176,11 +180,9 @@ def forward_with_head_ablation(

# Instead, create a mask
head_mask = mx.ones((num_heads,))
head_mask = mx.concatenate([
head_mask[:ablate_head],
mx.zeros((1,)),
head_mask[ablate_head + 1:]
])
head_mask = mx.concatenate(
[head_mask[:ablate_head], mx.zeros((1,)), head_mask[ablate_head + 1 :]]
)
head_mask = head_mask.reshape(1, num_heads, 1, 1)
attn_output = attn_output * head_mask

Expand Down Expand Up @@ -218,7 +220,11 @@ def forward_with_head_ablation(
out = lyr(h, mask=mask)
except TypeError:
out = lyr(h)
h = out.hidden_states if hasattr(out, "hidden_states") else (out[0] if isinstance(out, tuple) else out)
h = (
out.hidden_states
if hasattr(out, "hidden_states")
else (out[0] if isinstance(out, tuple) else out)
)

# Final prediction
if norm is not None:
Expand Down Expand Up @@ -265,15 +271,17 @@ def forward_with_head_ablation(
if impact > 0.1: # Significant impact
layer_impacts.append((head, impact, ablated_token, ablated_prob))

query_results.append(HeadAblationResult(
layer=layer,
head=head,
original_answer=baseline_token,
original_prob=baseline_prob,
ablated_answer=ablated_token,
ablated_prob=ablated_prob,
impact=impact,
))
query_results.append(
HeadAblationResult(
layer=layer,
head=head,
original_answer=baseline_token,
original_prob=baseline_prob,
ablated_answer=ablated_token,
ablated_prob=ablated_prob,
impact=impact,
)
)

if layer_impacts:
print(f"high-impact heads: {[(h, f'{i:.2f}') for h, i, _, _ in layer_impacts[:5]]}")
Expand All @@ -295,16 +303,15 @@ def forward_with_head_ablation(
head_importance[(r.layer, r.head)].append(r.impact)

# Sort by average impact
sorted_heads = sorted(
head_importance.items(),
key=lambda x: -np.mean(x[1])
)[:20]
sorted_heads = sorted(head_importance.items(), key=lambda x: -np.mean(x[1]))[:20]

print("\nTop 20 most important heads (across all test queries):")
for (layer, head), impacts in sorted_heads:
avg_impact = np.mean(impacts)
count = len(impacts)
print(f" Layer {layer:2d}, Head {head:2d}: avg_impact={avg_impact:.3f}, affects {count}/{len(test_queries)} queries")
print(
f" Layer {layer:2d}, Head {head:2d}: avg_impact={avg_impact:.3f}, affects {count}/{len(test_queries)} queries"
)

return results

Expand Down
Loading
Loading