diff --git a/apps/inference/neuronpedia_inference/endpoints/activation/all.py b/apps/inference/neuronpedia_inference/endpoints/activation/all.py index 7a1224e38..1cc82e12b 100644 --- a/apps/inference/neuronpedia_inference/endpoints/activation/all.py +++ b/apps/inference/neuronpedia_inference/endpoints/activation/all.py @@ -21,6 +21,7 @@ from neuronpedia_inference.shared import ( Model, calculate_per_source_dfa, + get_activations_by_index, get_layer_num_from_sae_id, safe_cast, with_request_lock, @@ -173,7 +174,7 @@ def _process_sources( hook_name = sae_manager.get_sae_hook(selected_source) sae_type = sae_manager.get_sae_type(selected_source) - activations_by_index = self._get_activations_by_index( + activations_by_index = get_activations_by_index( sae_type, selected_source, cache, hook_name ) @@ -195,24 +196,6 @@ def _process_sources( return source_activations - def _get_activations_by_index( - self, - sae_type: str, - selected_source: str, - cache: ActivationCache, - hook_name: str, - ) -> torch.Tensor: - """Get activations by index for a specific layer and SAE type.""" - if sae_type == "neurons": - mlp_activation_data = cache[hook_name].to(Config.get_instance().device) - return torch.transpose(mlp_activation_data[0], 0, 1) - - activation_data = cache[hook_name].to(Config.get_instance().device) - feature_activation_data = ( - SAEManager.get_instance().get_sae(selected_source).encode(activation_data) - ) - return torch.transpose(feature_activation_data.squeeze(0), 0, 1) - def _process_source_activations( self, activations_by_index: torch.Tensor, diff --git a/apps/inference/neuronpedia_inference/endpoints/activation/topk_by_token.py b/apps/inference/neuronpedia_inference/endpoints/activation/topk_by_token.py index 482645ec6..250cc90c7 100644 --- a/apps/inference/neuronpedia_inference/endpoints/activation/topk_by_token.py +++ b/apps/inference/neuronpedia_inference/endpoints/activation/topk_by_token.py @@ -15,11 +15,14 @@ from neuronpedia_inference_client.models.activation_topk_by_token_post_request import ( ActivationTopkByTokenPostRequest, ) -from transformer_lens import ActivationCache from neuronpedia_inference.config import Config from neuronpedia_inference.sae_manager import SAEManager -from neuronpedia_inference.shared import Model, with_request_lock +from neuronpedia_inference.shared import ( + Model, + get_activations_by_index, + with_request_lock, +) logger = logging.getLogger(__name__) @@ -110,21 +113,3 @@ async def activation_topk_by_token( results=results, tokens=str_tokens, # type: ignore ) - - -# Keep the get_activations_by_index function from the original code -def get_activations_by_index( - sae_type: str, - selected_layer: str, - cache: ActivationCache | dict[str, torch.Tensor], - hook_name: str, -) -> torch.Tensor: - if sae_type == "neurons": - mlp_activation_data = cache[hook_name].to(Config.get_instance().device) - return torch.transpose(mlp_activation_data[0], 0, 1) - - activation_data = cache[hook_name].to(Config.get_instance().device) - feature_activation_data = ( - SAEManager.get_instance().get_sae(selected_layer).encode(activation_data) - ) - return torch.transpose(feature_activation_data.squeeze(0), 0, 1) diff --git a/apps/inference/neuronpedia_inference/shared.py b/apps/inference/neuronpedia_inference/shared.py index 0e925796a..550cc8f2c 100644 --- a/apps/inference/neuronpedia_inference/shared.py +++ b/apps/inference/neuronpedia_inference/shared.py @@ -4,7 +4,10 @@ import einops import torch -from transformer_lens import HookedTransformer +from transformer_lens import ActivationCache, HookedTransformer + +from neuronpedia_inference.config import Config +from neuronpedia_inference.sae_manager import SAEManager request_lock = asyncio.Lock() @@ -135,3 +138,20 @@ def calculate_per_source_dfa( def get_layer_num_from_sae_id(sae_id: str) -> int: return int(sae_id.split("-")[0]) if not sae_id.isdigit() else int(sae_id) + + +def get_activations_by_index( + sae_type: str, + selected_source: str, + cache: ActivationCache, + hook_name: str, +) -> torch.Tensor: + """Get activations by index for a specific layer and SAE type.""" + if sae_type == "neurons": + mlp_activation_data = cache[hook_name].to(Config.get_instance().device) + return torch.transpose(mlp_activation_data[0], 0, 1) + activation_data = cache[hook_name].to(Config.get_instance().device) + feature_activation_data = ( + SAEManager.get_instance().get_sae(selected_source).encode(activation_data) + ) + return torch.transpose(feature_activation_data.squeeze(0), 0, 1) diff --git a/apps/inference/tests/unit/test_all.py b/apps/inference/tests/unit/test_all.py index f25bacd0a..ddb14d4d9 100644 --- a/apps/inference/tests/unit/test_all.py +++ b/apps/inference/tests/unit/test_all.py @@ -99,6 +99,7 @@ def processor() -> ActivationProcessor: return ActivationProcessor() +@patch("neuronpedia_inference.endpoints.activation.all.get_activations_by_index") @patch("neuronpedia_inference.endpoints.activation.all.Model") @patch("neuronpedia_inference.endpoints.activation.all.SAEManager") @patch("neuronpedia_inference.endpoints.activation.all.Config") @@ -106,6 +107,7 @@ def test_process_activations_basic( mock_config_class: MagicMock, mock_sae_manager_class: MagicMock, mock_model_class: MagicMock, + mock_get_activations: MagicMock, processor: ActivationProcessor, sample_request: ActivationAllPostRequest, mock_model: Mock, @@ -117,6 +119,7 @@ def test_process_activations_basic( mock_model_class.get_instance.return_value = mock_model mock_sae_manager_class.get_instance.return_value = mock_sae_manager mock_config_class.get_instance.return_value = mock_config + mock_get_activations.return_value = torch.randn(128, 5) # Execute result = processor.process_activations(sample_request) # Verify result structure @@ -132,6 +135,7 @@ def test_process_activations_basic( mock_model.run_with_cache.assert_called_once() +@patch("neuronpedia_inference.endpoints.activation.all.get_activations_by_index") @patch("neuronpedia_inference.endpoints.activation.all.Model") @patch("neuronpedia_inference.endpoints.activation.all.SAEManager") @patch("neuronpedia_inference.endpoints.activation.all.Config") @@ -139,6 +143,7 @@ def test_process_activations_with_sort_by_token_indexes( mock_config_class: MagicMock, mock_sae_manager_class: MagicMock, mock_model_class: MagicMock, + mock_get_activations: MagicMock, processor: ActivationProcessor, sample_request: ActivationAllPostRequest, mock_model: Mock, @@ -150,6 +155,7 @@ def test_process_activations_with_sort_by_token_indexes( mock_model_class.get_instance.return_value = mock_model mock_sae_manager_class.get_instance.return_value = mock_sae_manager mock_config_class.get_instance.return_value = mock_config + mock_get_activations.return_value = torch.randn(128, 5) # Modify request to include sort_by_token_indexes sample_request.sort_by_token_indexes = [1, 2, 3] # Execute @@ -159,6 +165,7 @@ def test_process_activations_with_sort_by_token_indexes( assert len(result.tokens) == 5 # Should match mock token length +@patch("neuronpedia_inference.endpoints.activation.all.get_activations_by_index") @patch("neuronpedia_inference.endpoints.activation.all.Model") @patch("neuronpedia_inference.endpoints.activation.all.SAEManager") @patch("neuronpedia_inference.endpoints.activation.all.Config") @@ -166,6 +173,7 @@ def test_process_activations_with_feature_filter( mock_config_class: MagicMock, mock_sae_manager_class: MagicMock, mock_model_class: MagicMock, + mock_get_activations: MagicMock, processor: ActivationProcessor, sample_request: ActivationAllPostRequest, mock_model: Mock, @@ -177,6 +185,7 @@ def test_process_activations_with_feature_filter( mock_model_class.get_instance.return_value = mock_model mock_sae_manager_class.get_instance.return_value = mock_sae_manager mock_config_class.get_instance.return_value = mock_config + mock_get_activations.return_value = torch.randn(128, 5) # Modify request for single layer with feature filter sample_request.selected_sources = ["0-test_set"] sample_request.feature_filter = [0, 1, 5, 10] @@ -212,6 +221,7 @@ def test_process_activations_with_neurons_sae_type( assert isinstance(result, ActivationAllPost200Response) +@patch("neuronpedia_inference.endpoints.activation.all.get_activations_by_index") @patch("neuronpedia_inference.endpoints.activation.all.Model") @patch("neuronpedia_inference.endpoints.activation.all.SAEManager") @patch("neuronpedia_inference.endpoints.activation.all.Config") @@ -219,6 +229,7 @@ def test_process_activations_with_dfa_enabled( mock_config_class: MagicMock, mock_sae_manager_class: MagicMock, mock_model_class: MagicMock, + mock_get_activations: MagicMock, processor: ActivationProcessor, sample_request: ActivationAllPostRequest, mock_model: Mock, @@ -230,6 +241,7 @@ def test_process_activations_with_dfa_enabled( mock_model_class.get_instance.return_value = mock_model mock_sae_manager_class.get_instance.return_value = mock_sae_manager mock_config_class.get_instance.return_value = mock_config + mock_get_activations.return_value = torch.randn(128, 5) # Enable DFA mock_sae_manager.is_dfa_enabled.return_value = True # Mock calculate_per_source_dfa function @@ -270,6 +282,7 @@ def test_process_activations_invalid_token_index( processor.process_activations(sample_request) +@patch("neuronpedia_inference.endpoints.activation.all.get_activations_by_index") @patch("neuronpedia_inference.endpoints.activation.all.Model") @patch("neuronpedia_inference.endpoints.activation.all.SAEManager") @patch("neuronpedia_inference.endpoints.activation.all.Config") @@ -277,6 +290,7 @@ def test_process_activations_ignore_bos( mock_config_class: MagicMock, mock_sae_manager_class: MagicMock, mock_model_class: MagicMock, + mock_get_activations: MagicMock, processor: ActivationProcessor, sample_request: ActivationAllPostRequest, mock_model: Mock, @@ -288,6 +302,7 @@ def test_process_activations_ignore_bos( mock_model_class.get_instance.return_value = mock_model mock_sae_manager_class.get_instance.return_value = mock_sae_manager mock_config_class.get_instance.return_value = mock_config + mock_get_activations.return_value = torch.randn(128, 5) sample_request.ignore_bos = True # Execute result = processor.process_activations(sample_request)