Skip to content

Commit

Permalink
More customizability
Browse files Browse the repository at this point in the history
  • Loading branch information
JasonGross committed Dec 20, 2024
1 parent 898b9c3 commit 8d96f31
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion gbmi/training_tools/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def plot_tensors(
plot_1D_kind: Literal["line", "scatter"] = "line",
title="Subplots of Matrices",
groups: Optional[Collection[Collection[str]]] = None,
default_heatmap_kwargs: dict[str, Any] = {},
**kwargs,
) -> go.Figure:
# Calculate grid size based on the number of matrices
Expand Down Expand Up @@ -149,7 +150,11 @@ def plot_tensors(
elif len(matrix.shape) == 2:
# 2D data - heatmap
fig.add_trace(
go.Heatmap(z=matrix, name=name, **zmax_zmin_args.get(name, {})),
go.Heatmap(
z=matrix,
name=name,
**zmax_zmin_args.get(name, default_heatmap_kwargs),
),
row=row,
col=col,
)
Expand Down Expand Up @@ -992,6 +997,7 @@ def log_matrices(
model: HookedTransformer,
*,
unsafe: bool = False,
default_heatmap_kwargs: dict[str, Any] = {},
**kwargs,
):
matrices = dict(self.matrices_to_log(model, unsafe=unsafe))
Expand Down Expand Up @@ -1021,6 +1027,7 @@ def log_matrices(
if self.group_colorbars
else None
),
default_heatmap_kwargs=default_heatmap_kwargs,
)
}
if matrices
Expand Down

0 comments on commit 8d96f31

Please sign in to comment.