Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions apps/autointerp/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
- [repo status](#repo-status)
- [what this is](#what-this-is)
- [simple non-docker setup](#simple-non-docker-setup)
- [Simple Usage](#simple-usage)
- [some docker commands for reference](#some-docker-commands-for-reference)
- [Testing, Linting, and Formatting](#testing-linting-and-formatting)

## repo status

Expand Down Expand Up @@ -33,6 +35,31 @@ as much as possible we try to use classes/types from the `packages/python/neuron
poetry run uvicorn server:app --host 0.0.0.0 --port 5003 --workers 1 --reload
```

## Simple Usage

Generate an explanation from two example activations:

```bash
curl -X POST "http://localhost:5003/v1/explain/default" \
-H "Content-Type: application/json" \
-d '{
"activations": [
{
"tokens": ["The", " cat", " sat", " on", " the", " mat"],
"values": [0.0, 0.8, 0.0, 0.0, 0.0, 0.0]
},
{
"tokens": ["I", " like", " felines"],
"values": [0, 0, 0.9]
}
],
"openrouter_key": "YOUR_OPENROUTER_KEY",
"model": "openai/gpt-4o-mini"
}'
```

See other endpoints under `/schemas/openapi/autointerp/paths`.

## some docker commands for reference

build the image from root directory
Expand Down
63 changes: 63 additions & 0 deletions apps/autointerp/neuronpedia_autointerp/logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import os
from logging.config import dictConfig


def initialize_logging():
    """Configure application-wide logging via ``logging.config.dictConfig``.

    Creates a ``logs`` directory under ``$HOME_DIR`` (defaulting to the
    current working directory) and routes the package, route, uvicorn, and
    fastapi loggers to both stdout and a size-rotated file
    ``logs/server.log``.

    NOTE(review): this module is named ``logging.py`` inside the package;
    absolute imports keep stdlib ``logging`` reachable, but relative
    imports of ``logging`` from sibling modules would hit this file —
    worth confirming nothing does.
    """
    base_dir = os.environ.get("HOME_DIR", ".")
    log_dir = os.path.join(base_dir, "logs")
    os.makedirs(log_dir, exist_ok=True)

    # Both handlers accept DEBUG so the per-logger levels below are the
    # effective filters.
    console_handler = {
        "class": "logging.StreamHandler",
        "stream": "ext://sys.stdout",
        "formatter": "default",
        "level": "DEBUG",
    }
    file_handler = {
        "class": "logging.handlers.RotatingFileHandler",
        "filename": f"{log_dir}/server.log",
        "maxBytes": 10485760,  # rotate after 10MB
        "backupCount": 5,
        "formatter": "default",
        "level": "DEBUG",
    }

    def _named_logger(level):
        # Each named logger writes to both handlers directly and does not
        # propagate, so records are never duplicated by ancestor loggers.
        return {
            "level": level,
            "handlers": ["console", "file"],
            "propagate": False,
        }

    dictConfig(
        {
            "version": 1,
            "disable_existing_loggers": False,
            "formatters": {
                "default": {
                    "format": "[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s",
                    "datefmt": "%Y-%m-%d %H:%M:%S",
                }
            },
            "handlers": {
                "console": console_handler,
                "file": file_handler,
            },
            "loggers": {
                # Main package logger and its routes sub-logger.
                "neuronpedia_autointerp": _named_logger("DEBUG"),
                "neuronpedia_autointerp.routes": _named_logger("DEBUG"),
                # Framework loggers, kept at INFO to reduce noise.
                "uvicorn": _named_logger("INFO"),
                "fastapi": _named_logger("INFO"),
            },
            # Anything else falls through to the root logger at INFO.
            "root": {
                "level": "INFO",
                "handlers": ["console", "file"],
            },
        }
    )
20 changes: 10 additions & 10 deletions apps/autointerp/server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# ruff: noqa: T201

import logging
import os
from collections.abc import Awaitable, Callable
from contextlib import asynccontextmanager
Expand All @@ -21,6 +20,7 @@
)
from sentence_transformers import SentenceTransformer

from neuronpedia_autointerp.logging import initialize_logging
from neuronpedia_autointerp.routes.explain.default import explain_default
from neuronpedia_autointerp.routes.score.embedding import generate_score_embedding
from neuronpedia_autointerp.routes.score.fuzz_detection import (
Expand All @@ -35,9 +35,12 @@
load_dotenv()
SECRET = os.getenv("SECRET")

initialize_logging()
logger = logging.getLogger(__name__)

# only initialize sentry if we have a dsn
if os.getenv("SENTRY_DSN"):
print("initializing sentry")
logger.info("initializing sentry")
sentry_sdk.init(
dsn=os.getenv("SENTRY_DSN"),
# Set traces_sample_rate to 1.0 to capture 100%
Expand All @@ -55,43 +58,40 @@


def initialize_globals():
print("initializing globals")
logger.info("initializing globals")
global model
if torch.cuda.is_available():
model = SentenceTransformer(
"dunzhang/stella_en_400M_v5",
trust_remote_code=True, # type: ignore[call-arg]
).cuda()
print("initialized embedding model")
logger.info("initialized embedding model")
else:
print("no cuda available, not initializing embedding model")
logger.info("no cuda available, not initializing embedding model")


@router.post("/explain/default")
async def explanation_endpoint(request: ExplainDefaultPostRequest):
print("Explain Default Called")
return await explain_default(request)


@router.post("/score/embedding")
async def score_embedding_endpoint(request: ScoreEmbeddingPostRequest):
if model is None:
raise HTTPException(status_code=500, detail="Model not initialized")
print("Score Embedding Called")
return await generate_score_embedding(request, model)


@router.post("/score/fuzz-detection")
async def score_fuzz_detection_endpoint(request: ScoreFuzzDetectionPostRequest):
print("Score Fuzz Detection Called")
return await generate_score_fuzz_detection(request)


@asynccontextmanager
async def lifespan(app: FastAPI):
initialize_globals()
yield
print("shutting down")
logger.info("shutting down")


app = FastAPI(lifespan=lifespan)
Expand Down
16 changes: 8 additions & 8 deletions apps/inference/neuronpedia_inference/sae_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,23 +190,23 @@ def print_sae_status(self):
"""
Print a nicely formatted status of loadable and loaded SAEs.
"""
print("\nSAE Status:")
print("===========")
logger.info("\nSAE Status:")
logger.info("===========")

print("\nLoadable SAEs:")
logger.info("\nLoadable SAEs:")
for sae_set, sae_ids in self.sae_set_to_saes.items():
if sae_set == self.NEURONS_SOURCESET:
continue
print(f" {sae_set}:")
logger.info(f" {sae_set}:")
for sae_id in sae_ids:
status = "Loaded" if sae_id in self.loaded_saes else "Not Loaded"
print(f" - {sae_id}: {status}")
logger.info(f" - {sae_id}: {status}")

print("\nCurrently Loaded SAEs:")
logger.info("\nCurrently Loaded SAEs:")
for i, sae_id in enumerate(self.loaded_saes, 1):
print(f" {i}. {sae_id}")
logger.info(f" {i}. {sae_id}")

print(f"\nTotal Loaded: {len(self.loaded_saes)} / {self.max_loaded_saes}")
logger.info(f"\nTotal Loaded: {len(self.loaded_saes)} / {self.max_loaded_saes}")

# Utility methods
def get_sae_type(self, sae_id: str) -> str:
Expand Down