Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 12 additions & 72 deletions apps/inference/neuronpedia_inference/endpoints/activation/all.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import re
from typing import Any

import einops
import torch
from fastapi import APIRouter
from fastapi.responses import JSONResponse
Expand All @@ -21,6 +20,9 @@
from neuronpedia_inference.sae_manager import SAEManager
from neuronpedia_inference.shared import (
Model,
calculate_per_source_dfa,
get_layer_num_from_sae_id,
safe_cast,
with_request_lock,
)

Expand Down Expand Up @@ -95,24 +97,6 @@ async def activation_all(
)


def _get_safe_dtype(dtype: torch.dtype) -> torch.dtype:
"""
Convert float16 to float32, leave other dtypes unchanged.
"""
return torch.float32 if dtype == torch.float16 else dtype


def _safe_cast(tensor: torch.Tensor, target_dtype: torch.dtype) -> torch.Tensor:
"""
Safely cast a tensor to the target dtype, creating a copy if needed.
Convert float16 to float32, leave other dtypes unchanged.
"""
safe_dtype = _get_safe_dtype(tensor.dtype)
if safe_dtype != tensor.dtype or safe_dtype != target_dtype:
return tensor.to(target_dtype)
return tensor


class ActivationProcessor:
@torch.no_grad()
def process_activations(
Expand Down Expand Up @@ -355,61 +339,17 @@ def _calculate_dfa_values(
v = cache["v", layer_num]
attn_weights = cache["pattern", layer_num]

# Determine the safe dtype for operations
v_dtype = _get_safe_dtype(v.dtype)
attn_weights_dtype = _get_safe_dtype(attn_weights.dtype)
encoder_dtype = _get_safe_dtype(encoder.W_enc.dtype)

# Use the highest precision dtype
op_dtype = max(
v_dtype,
attn_weights_dtype,
encoder_dtype,
key=lambda x: x.itemsize,
result = calculate_per_source_dfa(
model=model,
encoder=encoder,
v=v,
attn_weights=attn_weights,
feature_index=idx,
max_value_index=max_value_index,
)

# Check if the model uses GQA
use_gqa = (
hasattr(model.cfg, "n_key_value_heads")
and model.cfg.n_key_value_heads is not None
and model.cfg.n_key_value_heads < model.cfg.n_heads
)

if use_gqa:
n_query_heads = attn_weights.shape[1]
n_kv_heads = v.shape[2]
expansion_factor = n_query_heads // n_kv_heads
v = v.repeat_interleave(expansion_factor, dim=2)

# Cast tensors to operation dtype
v = _safe_cast(v, op_dtype)
attn_weights = _safe_cast(attn_weights, op_dtype)

v_cat = einops.rearrange(
v, "batch src_pos n_heads d_head -> batch src_pos (n_heads d_head)"
)

attn_weights_bcast = einops.repeat(
attn_weights,
"batch n_heads dest_pos src_pos -> batch dest_pos src_pos (n_heads d_head)",
d_head=model.cfg.d_head,
)

decomposed_z_cat = attn_weights_bcast * v_cat.unsqueeze(1)

# Cast encoder weights to operation dtype
W_enc = _safe_cast(encoder.W_enc[:, idx], op_dtype)

per_src_pos_dfa = einops.einsum(
decomposed_z_cat,
W_enc,
"batch dest_pos src_pos d_model, d_model -> batch dest_pos src_pos",
)

result = per_src_pos_dfa[torch.arange(1), torch.tensor([max_value_index]), :]

# Cast the result back to the original dtype of v
return _safe_cast(result, v.dtype)
return safe_cast(result, v.dtype)

def _calculate_table_counts(
self,
Expand All @@ -435,7 +375,7 @@ def _calculate_table_counts(
def _get_layer_num(sae_id: str) -> int:
"""Get layer number from SAE ID."""
try:
return int(sae_id.split("-")[0]) if not sae_id.isdigit() else int(sae_id)
return get_layer_num_from_sae_id(sae_id)

except ValueError:
if "blocks" in sae_id:
Expand Down
80 changes: 14 additions & 66 deletions apps/inference/neuronpedia_inference/endpoints/activation/single.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging
from typing import Any

import einops
import torch
from fastapi import APIRouter
from fastapi.responses import JSONResponse
Expand All @@ -18,7 +17,12 @@

from neuronpedia_inference.config import Config
from neuronpedia_inference.sae_manager import SAEManager
from neuronpedia_inference.shared import Model, with_request_lock
from neuronpedia_inference.shared import (
Model,
calculate_per_source_dfa,
get_layer_num_from_sae_id,
with_request_lock,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -127,28 +131,6 @@ async def activation_single(
return ActivationSinglePost200Response(activation=result, tokens=str_tokens)


def _get_safe_dtype(dtype: torch.dtype) -> torch.dtype:
"""
Convert float16 to float32, leave other dtypes unchanged.
"""
return torch.float32 if dtype == torch.float16 else dtype


def _safe_cast(tensor: torch.Tensor, target_dtype: torch.dtype) -> torch.Tensor:
"""
Safely cast a tensor to the target dtype, creating a copy if needed.
Convert float16 to float32, leave other dtypes unchanged.
"""
safe_dtype = _get_safe_dtype(tensor.dtype)
if safe_dtype != tensor.dtype or safe_dtype != target_dtype:
return tensor.to(target_dtype)
return tensor


def get_layer_num_from_sae_id(sae_id: str) -> int:
    """Parse the layer number from an SAE ID: either a bare number ("7")
    or a number-prefixed ID ("7-res-jb")."""
    if sae_id.isdigit():
        return int(sae_id)
    return int(sae_id.split("-")[0])


def process_activations(
model: HookedTransformer, layer: str, index: int, tokens: torch.Tensor
) -> ActivationSinglePost200ResponseActivation:
Expand Down Expand Up @@ -249,51 +231,17 @@ def calculate_dfa(
v = cache["v", layer_num] # [batch, src_pos, n_heads, d_head]
attn_weights = cache["pattern", layer_num] # [batch, n_heads, dest_pos, src_pos]

# Determine the safe dtype for operations
v_dtype = _get_safe_dtype(v.dtype)
attn_weights_dtype = _get_safe_dtype(attn_weights.dtype)
sae_dtype = _get_safe_dtype(sae.W_enc.dtype)

# Use the highest precision dtype
op_dtype = max(v_dtype, attn_weights_dtype, sae_dtype, key=lambda x: x.itemsize)

# Check if the model uses GQA
use_gqa = (
hasattr(model.cfg, "n_key_value_heads")
and model.cfg.n_key_value_heads is not None
and model.cfg.n_key_value_heads < model.cfg.n_heads
per_src_dfa = calculate_per_source_dfa(
model=model,
encoder=sae,
v=v,
attn_weights=attn_weights,
feature_index=index,
max_value_index=max_value_index,
)

if use_gqa:
n_query_heads = attn_weights.shape[1]
n_kv_heads = v.shape[2]
expansion_factor = n_query_heads // n_kv_heads
v = v.repeat_interleave(expansion_factor, dim=2)

# Cast tensors to operation dtype
v = _safe_cast(v, op_dtype)
attn_weights = _safe_cast(attn_weights, op_dtype)

v_cat = einops.rearrange(
v, "batch src_pos n_heads d_head -> batch src_pos (n_heads d_head)"
)
attn_weights_bcast = einops.repeat(
attn_weights,
"batch n_heads dest_pos src_pos -> batch dest_pos src_pos (n_heads d_head)",
d_head=model.cfg.d_head,
)
decomposed_z_cat = attn_weights_bcast * v_cat.unsqueeze(1)

# Cast SAE weights to operation dtype
W_enc = _safe_cast(sae.W_enc[:, index], op_dtype)

per_src_pos_dfa = einops.einsum(
decomposed_z_cat,
W_enc,
"batch dest_pos src_pos d_model, d_model -> batch dest_pos src_pos",
)
per_src_dfa = per_src_pos_dfa[torch.arange(1), torch.tensor([max_value_index]), :]
dfa_values = per_src_dfa[0].tolist()

return {
"dfa_values": dfa_values,
"dfa_target_index": max_value_index,
Expand Down
95 changes: 95 additions & 0 deletions apps/inference/neuronpedia_inference/shared.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
from functools import wraps
from typing import Any

import einops
import torch
from transformer_lens import HookedTransformer

Expand Down Expand Up @@ -40,3 +42,96 @@ def set_instance(cls, model: HookedTransformer) -> None:
"float16": torch.float16,
"bfloat16": torch.bfloat16,
}


def _get_safe_dtype(dtype: torch.dtype) -> torch.dtype:
"""
Convert float16 to float32, leave other dtypes unchanged.
"""
return torch.float32 if dtype == torch.float16 else dtype


def safe_cast(tensor: torch.Tensor, target_dtype: torch.dtype) -> torch.Tensor:
    """Cast ``tensor`` to ``target_dtype`` unless no cast is needed.

    The input is returned as-is only when its dtype already equals the
    target AND is not float16 (float16 is treated as unsafe for compute).
    """
    safe = torch.float32 if tensor.dtype == torch.float16 else tensor.dtype
    if safe == tensor.dtype == target_dtype:
        return tensor
    return tensor.to(target_dtype)


def calculate_per_source_dfa(
    model: HookedTransformer,
    encoder: Any,
    v: torch.Tensor,
    attn_weights: torch.Tensor,
    feature_index: int,
    max_value_index: int,
) -> torch.Tensor:
    """
    Core DFA calculation logic that can be shared between different endpoints.

    Args:
        model: The transformer model (only ``model.cfg`` is read here)
        encoder: The SAE or encoder object with W_enc attribute
        v: The value tensor from cache
           # assumes [batch, src_pos, n_heads, d_head] — matches caller comments
        attn_weights: The attention pattern tensor from cache
           # assumes [batch, n_heads, dest_pos, src_pos] — matches caller comments
        feature_index: Index of the feature in the encoder
        max_value_index: Destination position where maximum activation occurs

    Returns:
        Tensor of shape [batch, src_pos] with the raw DFA values for
        destination position ``max_value_index`` (before formatting or casting)
    """
    # Determine the safe dtype for operations (float16 widens to float32)
    v_dtype = _get_safe_dtype(v.dtype)
    attn_weights_dtype = _get_safe_dtype(attn_weights.dtype)
    encoder_dtype = _get_safe_dtype(encoder.W_enc.dtype)

    # Use the highest precision dtype among the three operands
    op_dtype = max(v_dtype, attn_weights_dtype, encoder_dtype, key=lambda x: x.itemsize)

    # Check if the model uses GQA (fewer key/value heads than query heads)
    use_gqa = (
        hasattr(model.cfg, "n_key_value_heads")
        and model.cfg.n_key_value_heads is not None
        and model.cfg.n_key_value_heads < model.cfg.n_heads
    )

    if use_gqa:
        # Expand the shared K/V heads so v lines up head-for-head with the
        # query-head-indexed attention pattern
        n_query_heads = attn_weights.shape[1]
        n_kv_heads = v.shape[2]
        expansion_factor = n_query_heads // n_kv_heads
        v = v.repeat_interleave(expansion_factor, dim=2)

    # Cast tensors to operation dtype
    v = safe_cast(v, op_dtype)
    attn_weights = safe_cast(attn_weights, op_dtype)

    # Flatten heads: [batch, src_pos, n_heads * d_head]
    v_cat = einops.rearrange(
        v, "batch src_pos n_heads d_head -> batch src_pos (n_heads d_head)"
    )

    # Only the single destination position `max_value_index` is ever read, so
    # select it *before* broadcasting. This avoids materialising the full
    # [batch, dest_pos, src_pos, n_heads * d_head] tensor (O(seq^2 * d_model)
    # memory); the subsequent reduction is only over d_model, so the result
    # is numerically identical.
    attn_at_dest = attn_weights[:, :, max_value_index, :]  # [batch, n_heads, src_pos]
    attn_bcast = einops.repeat(
        attn_at_dest,
        "batch n_heads src_pos -> batch src_pos (n_heads d_head)",
        d_head=model.cfg.d_head,
    )

    # Per-source-position contribution to z at the target destination position
    decomposed_z_cat = attn_bcast * v_cat

    # Cast encoder weights to operation dtype
    W_enc = safe_cast(encoder.W_enc[:, feature_index], op_dtype)

    # Project each source position's z-contribution onto the feature's encoder
    # direction. Returns [batch, src_pos]; unlike the previous hard-coded
    # `torch.arange(1)` indexing, this also handles batch sizes > 1 instead of
    # silently keeping only the first batch element.
    return einops.einsum(
        decomposed_z_cat,
        W_enc,
        "batch src_pos d_model, d_model -> batch src_pos",
    )


def get_layer_num_from_sae_id(sae_id: str) -> int:
    """Extract the leading layer number from an SAE ID string.

    A fully numeric ID ("7") parses directly; otherwise the portion before
    the first "-" ("7-res-jb") is parsed. Raises ValueError when neither
    form yields an integer.
    """
    head, _, _ = sae_id.partition("-")
    return int(sae_id) if sae_id.isdigit() else int(head)
Loading