Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions apps/inference/neuronpedia_inference/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def get_sae_lens_ids_from_neuronpedia_id(
(df_exploded["model"] == model_id)
& (df_exploded["neuronpedia_id"].str.endswith(f"/{neuronpedia_id}"))
]

assert (
tmp_df.shape[0] == 1
), f"Found {tmp_df.shape[0]} entries when searching for {model_id}/{neuronpedia_id}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def steering_hook(activations: torch.Tensor, hook: Any) -> torch.Tensor: # noqa
editing_hooks = [
(
(
sae_manager.get_sae_hook(feature.source)
sae_manager.get_decoder_hook(feature.source)
if isinstance(feature, NPSteerFeature)
else feature.hook
),
Expand Down Expand Up @@ -248,7 +248,7 @@ def steering_hook(activations: torch.Tensor, hook: Any) -> torch.Tensor: # noqa
editing_hooks = [
(
(
sae_manager.get_sae_hook(feature.source)
sae_manager.get_decoder_hook(feature.source)
if isinstance(feature, NPSteerFeature)
else feature.hook
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def steering_hook(activations: torch.Tensor, hook: Any) -> torch.Tensor: # noqa
editing_hooks = [
(
(
sae_manager.get_sae_hook(feature.source)
sae_manager.get_decoder_hook(feature.source)
if isinstance(feature, NPSteerFeature)
else feature.hook
),
Expand Down Expand Up @@ -308,7 +308,7 @@ def steering_hook(activations: torch.Tensor, hook: Any) -> torch.Tensor: # noqa
editing_hooks = [
(
(
sae_manager.get_sae_hook(feature.source)
sae_manager.get_decoder_hook(feature.source)
if isinstance(feature, NPSteerFeature)
else feature.hook
),
Expand Down
11 changes: 8 additions & 3 deletions apps/inference/neuronpedia_inference/sae_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def load_sae(self, model_id: str, sae_id: str) -> None:
df_exploded=get_saelens_neuronpedia_directory_df(),
)

loaded_sae, hook_name = SaeLensSAE.load(
loaded_sae, hook_in, hook_out = SaeLensSAE.load(
release=sae_lens_release,
sae_id=sae_lens_id,
device=self.device,
Expand All @@ -118,7 +118,8 @@ def load_sae(self, model_id: str, sae_id: str) -> None:

self.sae_data[sae_id] = {
"sae": loaded_sae,
"hook": hook_name,
"hook": hook_in,
"hook_out": hook_out,
"neuronpedia_id": loaded_sae.cfg.neuronpedia_id,
"type": SAE_TYPE.SAELENS,
# TODO: this should be in SAELens
Expand All @@ -129,7 +130,7 @@ def load_sae(self, model_id: str, sae_id: str) -> None:
or DFA_ENABLED_NP_ID_SEGMENT_ALT in loaded_sae.cfg.neuronpedia_id
)
),
"transcoder": False, # You might want to set this based on some condition
"transcoder": hook_out is not None
}

self.loaded_saes[sae_id] = None # We're using OrderedDict as an OrderedSet
Expand Down Expand Up @@ -261,6 +262,10 @@ def get_sae_type(self, sae_id: str) -> str:

def get_sae_hook(self, sae_id: str) -> str:
    """Return the input hook name registered for *sae_id* (``None`` if unknown)."""
    entry = self.sae_data.get(sae_id, {})
    return entry.get("hook")

def get_decoder_hook(self, sae_id: str) -> str:
    """Return the hook name where decoder output should be written when steering.

    Transcoder-like artifacts store a dedicated output hook under
    ``"hook_out"``; classic SAEs store ``None`` there (see ``load_sae``),
    so the ``or`` falls back to the regular input hook ``"hook"``.
    Returns ``None`` when *sae_id* is unknown, matching ``get_sae_hook``.
    """
    data = self.sae_data.get(sae_id, {})
    return data.get("hook_out") or data.get("hook")

def is_dfa_enabled(self, sae_id: str) -> bool:
    """Whether DFA is flagged for *sae_id*; unknown ids count as disabled."""
    try:
        return self.sae_data[sae_id]["dfa_enabled"]
    except KeyError:
        return False
Expand Down
83 changes: 69 additions & 14 deletions apps/inference/neuronpedia_inference/saes/saelens.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,79 @@
import torch
from sae_lens.sae import SAE
"""SAE/Transcoder loader wrapper for Neuronpedia Searcher.

from neuronpedia_inference.saes.base import BaseSAE
This module previously supported only the vanilla SAE objects exposed by
`sae_lens.sae.SAE`. We now extend the functionality to transparently load
three different artifact classes coming from the sae-lens code-base:

* SAE (classic auto-encoder)
* Transcoder
* SkipTranscoder

The heavy lifting is delegated to `load_artifact_from_pretrained`, a new helper
published upstream that inspects the YAML metadata of a given release/sae_id
and automatically returns an instance of the correct class.

For Neuronpedia Inference we treat each artifact uniformly – the caller only
needs the instantiated object and the hook names. For classic SAEs the single
hook `cfg.hook_name` is sufficient. Transcoders additionally come with
`cfg.hook_name_out`, the location where the decoder output should be steered.

DTYPE_MAP = {
"float16": torch.float16,
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}
The `load` method therefore now returns **three** values:

(artifact, hook_name_in, hook_name_out)

`hook_name_out` is `None` for plain SAEs so users can branch on a simple
truthiness check to detect Transcoder-like artifacts.
"""

from neuronpedia_inference.saes.base import BaseSAE
from sae_lens.toolkit.pretrained_sae_loaders import ( # type: ignore
load_artifact_from_pretrained,
)
from sae_lens.config import DTYPE_MAP # type: ignore


class SaeLensSAE(BaseSAE):
@staticmethod
def load(release: str, sae_id: str, device: str, dtype: str) -> tuple["SAE", str]:
loaded_sae, _, _ = SAE.from_pretrained(
def load(release: str, sae_id: str, device: str, dtype: str):
"""Load an artifact (SAE / Transcoder / SkipTranscoder).

Args:
release: The named release on the HF hub (e.g. "sae_lens")
sae_id: The specific SAE/Transcoder identifier inside *release*.
device: Torch device string, forwarded to the loader.
dtype: One of {"float16", "float32", "bfloat16"} – we convert
the loaded weights to this dtype after loading.

Returns:
artifact: The initialised model instance (type depends on
YAML `type` field).
hook_name_in: Where to read encoder activations from.
hook_name_out: Where to *write* decoder deltas to when steering.
`None` for classic SAEs.
"""

artifact, _cfg_dict, _sparsity = load_artifact_from_pretrained(
release=release,
sae_id=sae_id,
device=device,
)
loaded_sae.to(device, dtype=DTYPE_MAP[dtype])
loaded_sae.fold_W_dec_norm()
loaded_sae.eval()
return loaded_sae, loaded_sae.cfg.hook_name

# Ensure correct dtype & eval mode
artifact.to(device, dtype=DTYPE_MAP[dtype])

# Some classes (SAE, Transcoder, SkipTranscoder) expose this helper –
# if it does not exist we silently ignore the attribute.
if hasattr(artifact, "fold_W_dec_norm"):
try:
artifact.fold_W_dec_norm()
except Exception:
# Folding is a convenience optimization, not critical – we do
# not want loading to fail if it is not implemented.
pass

artifact.eval()

hook_name_in = artifact.cfg.hook_name
hook_name_out = getattr(artifact.cfg, "hook_name_out", None) or None

return artifact, hook_name_in, hook_name_out
10 changes: 7 additions & 3 deletions apps/inference/neuronpedia_inference/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,13 @@ async def initialize(
def load_model_and_sae():
# Validate inputs
df = get_saelens_neuronpedia_directory_df()

models = df["model"].unique()
sae_sets = df["neuronpedia_set"].unique()
if args.model_id not in models:
logger.error(
f"Error: Invalid model_id '{args.model_id}'. Use --list_models to see available options."
f"Error: Invalid model_id '{args.model_id}'. "
"Use --list_models to see available options."
)
exit(1)
# iterate through sae_sets and split them by spaces
Expand All @@ -126,7 +128,8 @@ def load_model_and_sae():
invalid_sae_sets = set(args_sae_sets) - set(sae_sets)
if invalid_sae_sets:
logger.error(
f"Error: Invalid SAE set(s): {', '.join(invalid_sae_sets)}. Use --list_models to see available options."
f"Error: Invalid SAE set(s): {', '.join(invalid_sae_sets)}. "
"Use --list_models to see available options."
)
exit(1)

Expand Down Expand Up @@ -203,7 +206,8 @@ def load_model_and_sae():
config.set_steer_special_token_ids(special_token_ids) # type: ignore

logger.info(
f"Loaded {config.CUSTOM_HF_MODEL_ID if config.CUSTOM_HF_MODEL_ID else config.OVERRIDE_MODEL_ID} on {args.device}"
f"Loaded {config.CUSTOM_HF_MODEL_ID or config.OVERRIDE_MODEL_ID} "
f"on {args.device}"
)
checkCudaError()

Expand Down
Loading