Skip to content

Commit

Permalink
Always compute axis limits for TeX
Browse files Browse the repository at this point in the history
  • Loading branch information
JasonGross committed Nov 6, 2024
1 parent 06b1c8d commit c381db9
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 77 deletions.
154 changes: 101 additions & 53 deletions gbmi/exp_max_of_n/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,91 @@ def compute_irrelevant(
}


@torch.no_grad()
def compute_basic_interpretation_axis_limits(
model: HookedTransformer,
*,
include_uncentered: bool = False,
include_equals_OV: bool = False,
includes_eos: Optional[bool] = None,
plot_with: Literal["plotly", "matplotlib"] = "plotly",
) -> Tuple[dict, dict[str, float]]:
cached_data = {}
axis_limits = {
"OV_zmin": np.inf,
"OV_zmax": -np.inf,
"QK_zmin": np.inf,
"QK_zmax": -np.inf,
"OVCentered_zmin": np.inf,
"OVCentered_zmax": -np.inf,
"QKWithAttnScale_zmin": np.inf,
"QKWithAttnScale_zmax": -np.inf,
}
if includes_eos is None:
includes_eos = model.cfg.d_vocab != model.cfg.d_vocab_out
title_kind = "html" if plot_with == "plotly" else "latex"
for attn_scale, with_attn_scale in (("", False), ("WithAttnScale", True)):
QK = compute_QK(
model, includes_eos=includes_eos, with_attn_scale=with_attn_scale
)
axis_limits[f"QK{attn_scale}_zmin"] = np.min(
[axis_limits[f"QK{attn_scale}_zmin"], QK["data"].min()]
)
axis_limits[f"QK{attn_scale}_zmax"] = np.max(
[axis_limits[f"QK{attn_scale}_zmax"], QK["data"].max()]
)
cached_data[("QK", attn_scale)] = QK

if include_uncentered:
OV = compute_OV(model, centered=False, includes_eos=includes_eos)
axis_limits["OV_zmin"] = np.min([axis_limits["OV_zmin"], OV["data"].min()])
axis_limits["OV_zmax"] = np.max([axis_limits["OV_zmax"], OV["data"].max()])
cached_data[("OV", False)] = OV

OV = compute_OV(model, centered=True, includes_eos=includes_eos)
axis_limits["OVCentered_zmin"] = np.min(
[axis_limits["OVCentered_zmin"], OV["data"].min()]
)
axis_limits["OVCentered_zmax"] = np.max(
[axis_limits["OVCentered_zmax"], OV["data"].max()]
)
cached_data[("OV", True)] = OV

for attn_scale, with_attn_scale in (("", False), ("WithAttnScale", True)):
pos_QK = compute_QK_by_position(
model, includes_eos=includes_eos, with_attn_scale=with_attn_scale
)
cached_data[("pos_QK", attn_scale)] = pos_QK
if includes_eos:
axis_limits[f"QK{attn_scale}_zmin"] = np.min(
[axis_limits[f"QK{attn_scale}_zmin"], pos_QK["data"]["QK"].min()]
)
axis_limits[f"QK{attn_scale}_zmax"] = np.max(
[axis_limits[f"QK{attn_scale}_zmax"], pos_QK["data"]["QK"].max()]
)
else:
axis_limits[f"QK{attn_scale}_zmin"] = np.min(
[axis_limits[f"QK{attn_scale}_zmin"], pos_QK["data"]["QK"].min()]
)
axis_limits[f"QK{attn_scale}_zmax"] = np.max(
[axis_limits[f"QK{attn_scale}_zmax"], pos_QK["data"]["QK"].max()]
)

irrelevant = compute_irrelevant(
model,
include_equals_OV=include_equals_OV,
includes_eos=includes_eos,
title_kind=title_kind,
)
cached_data["irrelevant"] = irrelevant
for key, data in irrelevant["data"].items():
if len(data.shape) == 2:
axis_limits["OV_zmin"] = np.min([axis_limits["OV_zmin"], data.min()])
axis_limits["OV_zmax"] = np.max([axis_limits["OV_zmax"], data.max()])

return cached_data, axis_limits


@torch.no_grad()
def display_basic_interpretation(
model: HookedTransformer,
Expand All @@ -485,34 +570,26 @@ def display_basic_interpretation(
plot_with: Literal["plotly", "matplotlib"] = "plotly",
renderer: Optional[str] = None,
show: bool = True,
cached_data: Optional[dict] = None,
axis_limits: Optional[dict[str, float]] = None,
) -> Tuple[dict[str, Union[go.Figure, matplotlib.figure.Figure]], dict[str, float]]:
if cached_data is None:
cached_data, axis_limits = compute_basic_interpretation_axis_limits(
model,
include_uncentered=include_uncentered,
include_equals_OV=include_equals_OV,
includes_eos=includes_eos,
plot_with=plot_with,
)
QK_cmap = colorscale_to_cmap(QK_colorscale)
QK_SVD_cmap = colorscale_to_cmap(QK_SVD_colorscale)
OV_cmap = colorscale_to_cmap(OV_colorscale)
if includes_eos is None:
includes_eos = model.cfg.d_vocab != model.cfg.d_vocab_out
result = {}
axis_limits = {
"OV_zmin": np.inf,
"OV_zmax": -np.inf,
"QK_zmin": np.inf,
"QK_zmax": -np.inf,
"OVCentered_zmin": np.inf,
"OVCentered_zmax": -np.inf,
"QKWithAttnScale_zmin": np.inf,
"QKWithAttnScale_zmax": -np.inf,
}
title_kind = "html" if plot_with == "plotly" else "latex"
for attn_scale, with_attn_scale in (("", False), ("WithAttnScale", True)):
QK = compute_QK(
model, includes_eos=includes_eos, with_attn_scale=with_attn_scale
)
axis_limits[f"QK{attn_scale}_zmin"] = np.min(
[axis_limits[f"QK{attn_scale}_zmin"], QK["data"].min()]
)
axis_limits[f"QK{attn_scale}_zmax"] = np.max(
[axis_limits[f"QK{attn_scale}_zmax"], QK["data"].max()]
)
title_kind = "html" if plot_with == "plotly" else "latex"
QK = cached_data[("QK", attn_scale)]
if includes_eos:
match plot_with:
case "plotly":
Expand Down Expand Up @@ -567,9 +644,7 @@ def display_basic_interpretation(
result[f"EQKE{attn_scale}"] = fig_qk

if include_uncentered:
OV = compute_OV(model, centered=False, includes_eos=includes_eos)
axis_limits["OV_zmin"] = np.min([axis_limits["OV_zmin"], OV["data"].min()])
axis_limits["OV_zmax"] = np.max([axis_limits["OV_zmax"], OV["data"].max()])
OV = cached_data[("OV", False)]
fig_ov = imshow(
OV["data"],
title=OV["title"][title_kind],
Expand All @@ -585,13 +660,7 @@ def display_basic_interpretation(
show=show,
)
result["EVOU"] = fig_ov
OV = compute_OV(model, centered=True, includes_eos=includes_eos)
axis_limits["OVCentered_zmin"] = np.min(
[axis_limits["OVCentered_zmin"], OV["data"].min()]
)
axis_limits["OVCentered_zmax"] = np.max(
[axis_limits["OVCentered_zmax"], OV["data"].max()]
)
OV = cached_data[("OV", True)]
fig_ov = imshow(
OV["data"],
title=OV["title"][title_kind],
Expand All @@ -609,16 +678,8 @@ def display_basic_interpretation(
result["EVOU-centered"] = fig_ov

for attn_scale, with_attn_scale in (("", False), ("WithAttnScale", True)):
pos_QK = compute_QK_by_position(
model, includes_eos=includes_eos, with_attn_scale=with_attn_scale
)
pos_QK = cached_data[("pos_QK", attn_scale)]
if includes_eos:
axis_limits[f"QK{attn_scale}_zmin"] = np.min(
[axis_limits[f"QK{attn_scale}_zmin"], pos_QK["data"]["QK"].min()]
)
axis_limits[f"QK{attn_scale}_zmax"] = np.max(
[axis_limits[f"QK{attn_scale}_zmax"], pos_QK["data"]["QK"].max()]
)
fig_qk = px.scatter(
pos_QK["data"],
title=pos_QK["title"][title_kind],
Expand All @@ -631,12 +692,6 @@ def display_basic_interpretation(
if show:
fig_qk.show(renderer=renderer)
else:
axis_limits[f"QK{attn_scale}_zmin"] = np.min(
[axis_limits[f"QK{attn_scale}_zmin"], pos_QK["data"]["QK"].min()]
)
axis_limits[f"QK{attn_scale}_zmax"] = np.max(
[axis_limits[f"QK{attn_scale}_zmax"], pos_QK["data"]["QK"].max()]
)
fig_qk = imshow(
pos_QK["data"]["QK"],
title=pos_QK["title"][title_kind],
Expand All @@ -653,16 +708,9 @@ def display_basic_interpretation(
)
result[f"EQKP{attn_scale}"] = fig_qk

irrelevant = compute_irrelevant(
model,
include_equals_OV=include_equals_OV,
includes_eos=includes_eos,
title_kind=title_kind,
)
irrelevant = cached_data["irrelevant"]
for key, data in irrelevant["data"].items():
if len(data.shape) == 2:
axis_limits["OV_zmin"] = np.min([axis_limits["OV_zmin"], data.min()])
axis_limits["OV_zmax"] = np.max([axis_limits["OV_zmax"], data.max()])
fig = imshow(
data,
title=key,
Expand Down
91 changes: 67 additions & 24 deletions notebooks_jason/max_of_K_all_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ def optimize_pngs(errs: list[Exception] = []):
EVOU_max_minus_diag_logit_diff,
attention_difference_over_gap,
display_basic_interpretation,
compute_basic_interpretation_axis_limits,
display_EQKE_SVD_analysis,
hist_attention_difference_over_gap,
hist_EVOU_max_minus_diag_logit_diff,
Expand Down Expand Up @@ -1912,13 +1913,67 @@ def handle_compute_EQKE_SVD_analysis(

# %% [markdown]
# # Plots
all_axis_limits = defaultdict(dict)
all_cached_data = {}
with tqdm(
runtime_models.items(), desc="compute_basic_interpretation_axis_limits"
) as pbar:
for seed, (_runtime, model) in pbar:
pbar.set_postfix(dict(seed=seed))
all_cached_data[seed], axis_limits = compute_basic_interpretation_axis_limits(
model,
include_uncentered=True,
plot_with=PLOT_WITH,
)
for k, v in axis_limits.items():
all_axis_limits[k][seed] = v

axis_limits = {}
for k, v in all_axis_limits.items():
if k.endswith("min"):
axis_limits[k] = np.min(list(v.values()))
elif k.endswith("max"):
axis_limits[k] = np.max(list(v.values()))
else:
raise ValueError(f"Unknown axis limit key: {k}")

for k in axis_limits.keys():
k_no_min_max = (
k.replace("zmin", "").replace("zmax", "").replace("min", "").replace("max", "")
)
latex_key = "".join(
[
kpart if kpart[:1] == kpart[:1].capitalize() else kpart.capitalize()
for kpart in k_no_min_max.replace("-", "_").split("_")
]
)
k_min = k.replace("max", "min")
k_max = k.replace("min", "max")
assert k_min in axis_limits, f"Missing {k_min}"
assert k_max in axis_limits, f"Missing {k_max}"
assert k_min == k or k_max == k, f"Unknown key: {k}"
assert k_min != k_max, f"Same key: {k}"
if "centered" not in k.lower():
v_max = np.max([np.abs(axis_limits[k_min]), np.abs(axis_limits[k_max])])
axis_limits[k_min] = -v_max
axis_limits[k_max] = v_max
assert "OV" in k or "QK" in k, f"Unknown key: {k}"

for k, v in axis_limits.items():
k = "".join(
[
kpart if kpart[0] == kpart[0].capitalize() else kpart.capitalize()
for kpart in k.replace("-", "_").split("_")
]
)
latex_values[f"AxisLimits{k}Float"] = v

# %%
if (DISPLAY_PLOTS or SAVE_PLOTS) and INDIVIDUAL_PLOTS:
all_axis_limits = defaultdict(dict)
with tqdm(runtime_models.items(), desc="display_basic_interpretation") as pbar:
for seed, (_runtime, model) in pbar:
pbar.set_postfix(dict(seed=seed))
figs, axis_limits = display_basic_interpretation(
figs, _ = display_basic_interpretation(
model,
include_uncentered=True,
OV_colorscale=default_OV_colorscale,
Expand All @@ -1928,9 +1983,8 @@ def handle_compute_EQKE_SVD_analysis(
plot_with=PLOT_WITH,
renderer=RENDERER,
show=DISPLAY_PLOTS,
cached_data=all_cached_data[seed],
)
for k, v in axis_limits.items():
all_axis_limits[k][seed] = v
for attn_scale in ("", "WithAttnScale"):
for fig in (
figs[f"EQKE{attn_scale}"],
Expand Down Expand Up @@ -1970,15 +2024,6 @@ def handle_compute_EQKE_SVD_analysis(
if unused_keys:
print(f"Unused keys: {unused_keys}")

axis_limits = {}
for k, v in all_axis_limits.items():
if k.endswith("min"):
axis_limits[k] = np.min(list(v.values()))
elif k.endswith("max"):
axis_limits[k] = np.max(list(v.values()))
else:
raise ValueError(f"Unknown axis limit key: {k}")

seen = set()
for k in axis_limits.keys():
k_no_min_max = (
Expand All @@ -2001,8 +2046,14 @@ def handle_compute_EQKE_SVD_analysis(
assert k_min != k_max, f"Same key: {k}"
if "centered" not in k.lower():
v_max = np.max([np.abs(axis_limits[k_min]), np.abs(axis_limits[k_max])])
axis_limits[k_min] = -v_max
axis_limits[k_max] = v_max
if axis_limits[k_min] != -v_max:
print(
f"Warning: {axis_limits[k_min]} == axix_limits[{k_min}] != -v_max == {-v_max}"
)
if axis_limits[k_max] != v_max:
print(
f"Warning: {axis_limits[k_max]} == axix_limits[{k_max}] != v_max == {v_max}"
)

assert "OV" in k or "QK" in k, f"Unknown key: {k}"
if k_no_min_max in seen:
Expand All @@ -2024,15 +2075,6 @@ def handle_compute_EQKE_SVD_analysis(
latex_figures[f"Colorbar-{latex_key}-Vertical"] = figV
latex_figures[f"Colorbar-{latex_key}-Horizontal"] = figH

for k, v in axis_limits.items():
k = "".join(
[
kpart if kpart[0] == kpart[0].capitalize() else kpart.capitalize()
for kpart in k.replace("-", "_").split("_")
]
)
latex_values[f"AxisLimits{k}Float"] = v

with tqdm(
runtime_models.items(), desc="display_basic_interpretation (uniform limits)"
) as pbar:
Expand All @@ -2045,6 +2087,7 @@ def handle_compute_EQKE_SVD_analysis(
QK_colorscale=default_QK_colorscale,
QK_SVD_colorscale=default_QK_SVD_colorscale,
tok_dtick=10,
cached_data=all_cached_data[seed],
**axis_limits,
plot_with=PLOT_WITH,
renderer=RENDERER,
Expand Down

0 comments on commit c381db9

Please sign in to comment.