Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 2 additions & 19 deletions apps/inference/neuronpedia_inference/endpoints/activation/all.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from neuronpedia_inference.shared import (
Model,
calculate_per_source_dfa,
get_activations_by_index,
get_layer_num_from_sae_id,
safe_cast,
with_request_lock,
Expand Down Expand Up @@ -173,7 +174,7 @@ def _process_sources(
hook_name = sae_manager.get_sae_hook(selected_source)
sae_type = sae_manager.get_sae_type(selected_source)

activations_by_index = self._get_activations_by_index(
activations_by_index = get_activations_by_index(
sae_type, selected_source, cache, hook_name
)

Expand All @@ -195,24 +196,6 @@ def _process_sources(

return source_activations

def _get_activations_by_index(
    self,
    sae_type: str,
    selected_source: str,
    cache: ActivationCache,
    hook_name: str,
) -> torch.Tensor:
    """Get activations by index for a specific layer and SAE type."""
    # Pull the cached activations for this hook onto the configured device;
    # both branches need the same tensor, so fetch it once up front.
    raw = cache[hook_name].to(Config.get_instance().device)
    if sae_type == "neurons":
        # Raw neuron activations: drop the batch dim via [0], then swap
        # to (index, position) orientation.
        return raw[0].transpose(0, 1)
    # SAE features: encode the cached activations, squeeze the batch dim,
    # then transpose to (feature index, position).
    sae = SAEManager.get_instance().get_sae(selected_source)
    return sae.encode(raw).squeeze(0).transpose(0, 1)

def _process_source_activations(
self,
activations_by_index: torch.Tensor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@
from neuronpedia_inference_client.models.activation_topk_by_token_post_request import (
ActivationTopkByTokenPostRequest,
)
from transformer_lens import ActivationCache

from neuronpedia_inference.config import Config
from neuronpedia_inference.sae_manager import SAEManager
from neuronpedia_inference.shared import Model, with_request_lock
from neuronpedia_inference.shared import (
Model,
get_activations_by_index,
with_request_lock,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -110,21 +113,3 @@ async def activation_topk_by_token(
results=results,
tokens=str_tokens, # type: ignore
)


# Keep the get_activations_by_index function from the original code
def get_activations_by_index(
    sae_type: str,
    selected_layer: str,
    cache: ActivationCache | dict[str, torch.Tensor],
    hook_name: str,
) -> torch.Tensor:
    """Return activations transposed to (index, position) for one hook.

    For ``sae_type == "neurons"`` the raw cached MLP activations are used
    directly; otherwise the cached activations are encoded through the SAE
    registered for ``selected_layer``.
    """
    device = Config.get_instance().device
    # Both branches start from the same cached tensor on the target device.
    hook_acts = cache[hook_name].to(device)
    if sae_type == "neurons":
        return hook_acts[0].transpose(0, 1)
    encoded = SAEManager.get_instance().get_sae(selected_layer).encode(hook_acts)
    return encoded.squeeze(0).transpose(0, 1)
22 changes: 21 additions & 1 deletion apps/inference/neuronpedia_inference/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@

import einops
import torch
from transformer_lens import HookedTransformer
from transformer_lens import ActivationCache, HookedTransformer

from neuronpedia_inference.config import Config
from neuronpedia_inference.sae_manager import SAEManager

request_lock = asyncio.Lock()

Expand Down Expand Up @@ -135,3 +138,20 @@ def calculate_per_source_dfa(

def get_layer_num_from_sae_id(sae_id: str) -> int:
    """Extract the layer number from an SAE id.

    A purely numeric id (e.g. ``"5"``) is the layer itself; otherwise the
    layer is the leading ``-``-separated segment (e.g. ``"5-res-jb"`` -> 5).
    """
    if sae_id.isdigit():
        return int(sae_id)
    layer_part, _, _ = sae_id.partition("-")
    return int(layer_part)


def get_activations_by_index(
    sae_type: str,
    selected_source: str,
    cache: ActivationCache,
    hook_name: str,
) -> torch.Tensor:
    """Get activations by index for a specific layer and SAE type."""
    device = Config.get_instance().device
    # Fetch the cached activations for this hook once; both the raw-neuron
    # and SAE-feature paths start from the same tensor.
    hook_acts = cache[hook_name].to(device)
    if sae_type == "neurons":
        # Raw neuron activations: take batch element 0 and reorient to
        # (neuron index, position).
        return hook_acts[0].transpose(0, 1)
    # SAE features: encode, drop the batch dim, reorient to
    # (feature index, position).
    sae = SAEManager.get_instance().get_sae(selected_source)
    return sae.encode(hook_acts).squeeze(0).transpose(0, 1)
15 changes: 15 additions & 0 deletions apps/inference/tests/unit/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,15 @@ def processor() -> ActivationProcessor:
return ActivationProcessor()


@patch("neuronpedia_inference.endpoints.activation.all.get_activations_by_index")
@patch("neuronpedia_inference.endpoints.activation.all.Model")
@patch("neuronpedia_inference.endpoints.activation.all.SAEManager")
@patch("neuronpedia_inference.endpoints.activation.all.Config")
def test_process_activations_basic(
mock_config_class: MagicMock,
mock_sae_manager_class: MagicMock,
mock_model_class: MagicMock,
mock_get_activations: MagicMock,
processor: ActivationProcessor,
sample_request: ActivationAllPostRequest,
mock_model: Mock,
Expand All @@ -117,6 +119,7 @@ def test_process_activations_basic(
mock_model_class.get_instance.return_value = mock_model
mock_sae_manager_class.get_instance.return_value = mock_sae_manager
mock_config_class.get_instance.return_value = mock_config
mock_get_activations.return_value = torch.randn(128, 5)
# Execute
result = processor.process_activations(sample_request)
# Verify result structure
Expand All @@ -132,13 +135,15 @@ def test_process_activations_basic(
mock_model.run_with_cache.assert_called_once()


@patch("neuronpedia_inference.endpoints.activation.all.get_activations_by_index")
@patch("neuronpedia_inference.endpoints.activation.all.Model")
@patch("neuronpedia_inference.endpoints.activation.all.SAEManager")
@patch("neuronpedia_inference.endpoints.activation.all.Config")
def test_process_activations_with_sort_by_token_indexes(
mock_config_class: MagicMock,
mock_sae_manager_class: MagicMock,
mock_model_class: MagicMock,
mock_get_activations: MagicMock,
processor: ActivationProcessor,
sample_request: ActivationAllPostRequest,
mock_model: Mock,
Expand All @@ -150,6 +155,7 @@ def test_process_activations_with_sort_by_token_indexes(
mock_model_class.get_instance.return_value = mock_model
mock_sae_manager_class.get_instance.return_value = mock_sae_manager
mock_config_class.get_instance.return_value = mock_config
mock_get_activations.return_value = torch.randn(128, 5)
# Modify request to include sort_by_token_indexes
sample_request.sort_by_token_indexes = [1, 2, 3]
# Execute
Expand All @@ -159,13 +165,15 @@ def test_process_activations_with_sort_by_token_indexes(
assert len(result.tokens) == 5 # Should match mock token length


@patch("neuronpedia_inference.endpoints.activation.all.get_activations_by_index")
@patch("neuronpedia_inference.endpoints.activation.all.Model")
@patch("neuronpedia_inference.endpoints.activation.all.SAEManager")
@patch("neuronpedia_inference.endpoints.activation.all.Config")
def test_process_activations_with_feature_filter(
mock_config_class: MagicMock,
mock_sae_manager_class: MagicMock,
mock_model_class: MagicMock,
mock_get_activations: MagicMock,
processor: ActivationProcessor,
sample_request: ActivationAllPostRequest,
mock_model: Mock,
Expand All @@ -177,6 +185,7 @@ def test_process_activations_with_feature_filter(
mock_model_class.get_instance.return_value = mock_model
mock_sae_manager_class.get_instance.return_value = mock_sae_manager
mock_config_class.get_instance.return_value = mock_config
mock_get_activations.return_value = torch.randn(128, 5)
# Modify request for single layer with feature filter
sample_request.selected_sources = ["0-test_set"]
sample_request.feature_filter = [0, 1, 5, 10]
Expand Down Expand Up @@ -212,13 +221,15 @@ def test_process_activations_with_neurons_sae_type(
assert isinstance(result, ActivationAllPost200Response)


@patch("neuronpedia_inference.endpoints.activation.all.get_activations_by_index")
@patch("neuronpedia_inference.endpoints.activation.all.Model")
@patch("neuronpedia_inference.endpoints.activation.all.SAEManager")
@patch("neuronpedia_inference.endpoints.activation.all.Config")
def test_process_activations_with_dfa_enabled(
mock_config_class: MagicMock,
mock_sae_manager_class: MagicMock,
mock_model_class: MagicMock,
mock_get_activations: MagicMock,
processor: ActivationProcessor,
sample_request: ActivationAllPostRequest,
mock_model: Mock,
Expand All @@ -230,6 +241,7 @@ def test_process_activations_with_dfa_enabled(
mock_model_class.get_instance.return_value = mock_model
mock_sae_manager_class.get_instance.return_value = mock_sae_manager
mock_config_class.get_instance.return_value = mock_config
mock_get_activations.return_value = torch.randn(128, 5)
# Enable DFA
mock_sae_manager.is_dfa_enabled.return_value = True
# Mock calculate_per_source_dfa function
Expand Down Expand Up @@ -270,13 +282,15 @@ def test_process_activations_invalid_token_index(
processor.process_activations(sample_request)


@patch("neuronpedia_inference.endpoints.activation.all.get_activations_by_index")
@patch("neuronpedia_inference.endpoints.activation.all.Model")
@patch("neuronpedia_inference.endpoints.activation.all.SAEManager")
@patch("neuronpedia_inference.endpoints.activation.all.Config")
def test_process_activations_ignore_bos(
mock_config_class: MagicMock,
mock_sae_manager_class: MagicMock,
mock_model_class: MagicMock,
mock_get_activations: MagicMock,
processor: ActivationProcessor,
sample_request: ActivationAllPostRequest,
mock_model: Mock,
Expand All @@ -288,6 +302,7 @@ def test_process_activations_ignore_bos(
mock_model_class.get_instance.return_value = mock_model
mock_sae_manager_class.get_instance.return_value = mock_sae_manager
mock_config_class.get_instance.return_value = mock_config
mock_get_activations.return_value = torch.randn(128, 5)
sample_request.ignore_bos = True
# Execute
result = processor.process_activations(sample_request)
Expand Down