Skip to content

Commit

Permalink
Factor out plot_matrices_from_model
Browse files Browse the repository at this point in the history
  • Loading branch information
JasonGross committed Dec 27, 2024
1 parent dc8c213 commit edbda5a
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions gbmi/training_tools/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -1088,17 +1088,15 @@ def plot_matrices(
return figs

@torch.no_grad()
def log_matrices(
def plot_matrices_from_model(
self,
logger: Run,
model: HookedTransformer,
*,
unsafe: bool = False,
default_heatmap_kwargs: dict[str, Any] = {},
**kwargs,
):
self.assert_model_supported(model, unsafe=unsafe)
figs = self.plot_matrices(
return self.plot_matrices(
model.W_E,
model.W_pos,
model.W_U,
Expand All @@ -1114,6 +1112,20 @@ def log_matrices(
attention_dir=model.cfg.attention_dir,
default_heatmap_kwargs=default_heatmap_kwargs,
)

@torch.no_grad()
def log_matrices(
self,
logger: Run,
model: HookedTransformer,
*,
unsafe: bool = False,
default_heatmap_kwargs: dict[str, Any] = {},
**kwargs,
):
figs = self.plot_matrices_from_model(
model, unsafe=unsafe, default_heatmap_kwargs=default_heatmap_kwargs
)
logger.log(
{encode_4_byte_unicode(k): v for k, v in figs.items()},
commit=False,
Expand Down

0 comments on commit edbda5a

Please sign in to comment.