Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions apps/autointerp/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
- [repo status](#repo-status)
- [what this is](#what-this-is)
- [simple non-docker setup](#simple-non-docker-setup)
- [Simple Usage](#simple-usage)
- [some docker commands for reference](#some-docker-commands-for-reference)
- [Testing, Linting, and Formatting](#testing-linting-and-formatting)

## repo status

Expand Down Expand Up @@ -33,6 +35,31 @@ as much as possible we try to use classes/types from the `packages/python/neuron
poetry run uvicorn server:app --host 0.0.0.0 --port 5003 --workers 1 --reload
```

## Simple Usage

Generate an explanation given two activations

```
curl -X POST "http://localhost:5003/v1/explain/default" \
-H "Content-Type: application/json" \
-d '{
"activations": [
{
"tokens": ["The", " cat", " sat", " on", " the", " mat"],
"values": [0.0, 0.8, 0.0, 0.0, 0.0, 0.0]
},
{
"tokens": ["I", " like", " felines"],
"values": [0, 0, 0.9]
}
],
"openrouter_key": "YOUR_OPENROUTER_KEY",
"model": "openai/gpt-4o-mini"
}'
```

See other endpoints under `/schemas/openapi/autointerp/paths`.

## some docker commands for reference

build the image from root directory
Expand Down
63 changes: 63 additions & 0 deletions apps/autointerp/neuronpedia_autointerp/logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import os
from logging.config import dictConfig


def initialize_logging():
    """Configure application-wide logging via ``dictConfig``.

    All records go to two handlers: stdout and a rotating file under
    ``$HOME_DIR/logs/server.log`` (``HOME_DIR`` defaults to the current
    directory). Package loggers run at DEBUG; uvicorn/fastapi and the
    root logger run at INFO.
    """
    base_dir = os.environ.get("HOME_DIR", ".")
    log_directory = os.path.join(base_dir, "logs")
    os.makedirs(log_directory, exist_ok=True)

    shared_handlers = ["console", "file"]

    def _named_logger(level: str) -> dict:
        # Every named logger uses the same two handlers and never
        # propagates, so each record is emitted exactly once.
        return {
            "level": level,
            "handlers": shared_handlers,
            "propagate": False,
        }

    config = {
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {
            "default": {
                "format": "[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s",
                "datefmt": "%Y-%m-%d %H:%M:%S",
            }
        },
        "handlers": {
            "console": {
                "class": "logging.StreamHandler",
                "stream": "ext://sys.stdout",
                "formatter": "default",
                "level": "DEBUG",
            },
            "file": {
                "class": "logging.handlers.RotatingFileHandler",
                "filename": os.path.join(log_directory, "server.log"),
                "maxBytes": 10 * 1024 * 1024,  # rotate at 10MB
                "backupCount": 5,
                "formatter": "default",
                "level": "DEBUG",
            },
        },
        "loggers": {
            # Main package logger and the routes subtree.
            "neuronpedia_autointerp": _named_logger("DEBUG"),
            "neuronpedia_autointerp.routes": _named_logger("DEBUG"),
            # Framework loggers routed through the same handlers.
            "uvicorn": _named_logger("INFO"),
            "fastapi": _named_logger("INFO"),
        },
        "root": {
            "level": "INFO",
            "handlers": shared_handlers,
        },
    }

    dictConfig(config)
31 changes: 17 additions & 14 deletions apps/autointerp/server.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# ruff: noqa: T201

import logging
import os
from collections.abc import Awaitable, Callable
from contextlib import asynccontextmanager

import sentry_sdk
import torch
Expand All @@ -20,6 +20,7 @@
)
from sentence_transformers import SentenceTransformer

from neuronpedia_autointerp.logging import initialize_logging
from neuronpedia_autointerp.routes.explain.default import explain_default
from neuronpedia_autointerp.routes.score.embedding import generate_score_embedding
from neuronpedia_autointerp.routes.score.fuzz_detection import (
Expand All @@ -34,9 +35,12 @@
load_dotenv()
SECRET = os.getenv("SECRET")

initialize_logging()
logger = logging.getLogger(__name__)

# only initialize sentry if we have a dsn
if os.getenv("SENTRY_DSN"):
print("initializing sentry")
logger.info("initializing sentry")
sentry_sdk.init(
dsn=os.getenv("SENTRY_DSN"),
# Set traces_sample_rate to 1.0 to capture 100%
Expand All @@ -54,45 +58,44 @@


def initialize_globals():
print("initializing globals")
logger.info("initializing globals")
global model
if torch.cuda.is_available():
model = SentenceTransformer(
"dunzhang/stella_en_400M_v5",
trust_remote_code=True, # type: ignore[call-arg]
).cuda()
print("initialized embedding model")
logger.info("initialized embedding model")
else:
print("no cuda available, not initializing embedding model")
logger.info("no cuda available, not initializing embedding model")


@router.post("/explain/default")
async def explanation_endpoint(request: ExplainDefaultPostRequest):
print("Explain Default Called")
return await explain_default(request)


@router.post("/score/embedding")
async def score_embedding_endpoint(request: ScoreEmbeddingPostRequest):
if model is None:
raise HTTPException(status_code=500, detail="Model not initialized")
print("Score Embedding Called")
return await generate_score_embedding(request, model)


@router.post("/score/fuzz-detection")
async def score_fuzz_detection_endpoint(request: ScoreFuzzDetectionPostRequest):
print("Score Fuzz Detection Called")
return await generate_score_fuzz_detection(request)


app = FastAPI()
app.include_router(router)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context for the FastAPI app.

    Startup: initialize module globals (loads the embedding model when
    CUDA is available). Shutdown: log only — nothing to release.
    ``app`` is required by the lifespan protocol but unused here.
    """
    initialize_globals()
    yield
    logger.info("shutting down")


@app.on_event("startup") # type: ignore[deprecated]
async def startup_event():
initialize_globals()
app = FastAPI(lifespan=lifespan)
app.include_router(router)


@app.middleware("http")
Expand Down
16 changes: 8 additions & 8 deletions apps/inference/neuronpedia_inference/sae_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,23 +190,23 @@ def print_sae_status(self):
"""
Print a nicely formatted status of loadable and loaded SAEs.
"""
print("\nSAE Status:")
print("===========")
logger.info("\nSAE Status:")
logger.info("===========")

print("\nLoadable SAEs:")
logger.info("\nLoadable SAEs:")
for sae_set, sae_ids in self.sae_set_to_saes.items():
if sae_set == self.NEURONS_SOURCESET:
continue
print(f" {sae_set}:")
logger.info(f" {sae_set}:")
for sae_id in sae_ids:
status = "Loaded" if sae_id in self.loaded_saes else "Not Loaded"
print(f" - {sae_id}: {status}")
logger.info(f" - {sae_id}: {status}")

print("\nCurrently Loaded SAEs:")
logger.info("\nCurrently Loaded SAEs:")
for i, sae_id in enumerate(self.loaded_saes, 1):
print(f" {i}. {sae_id}")
logger.info(f" {i}. {sae_id}")

print(f"\nTotal Loaded: {len(self.loaded_saes)} / {self.max_loaded_saes}")
logger.info(f"\nTotal Loaded: {len(self.loaded_saes)} / {self.max_loaded_saes}")

# Utility methods
def get_sae_type(self, sae_id: str) -> str:
Expand Down
100 changes: 54 additions & 46 deletions apps/inference/neuronpedia_inference/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import sys
import traceback
from collections.abc import Awaitable
from contextlib import asynccontextmanager
from typing import Callable

import sentry_sdk
Expand Down Expand Up @@ -55,55 +56,11 @@
global initialized
initialized = False

app = FastAPI()

# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=False,
allow_methods=["*"],
allow_headers=["*"],
)

args = parse_env_and_args()


# we have to initialize SAE's AFTER server startup, because some infrastructure providers require
# our server to respond to health checks within a few minutes of starting up
@app.on_event("startup") # pyright: ignore[reportDeprecated]
async def startup_event():
logger.info("Starting initialization...")
# Wait briefly to ensure server is ready
await asyncio.sleep(3)
# Start initialization in background
asyncio.create_task(initialize(args.custom_hf_model_id))
logger.info("Initialization started")


v1_router = APIRouter(prefix="/v1")

v1_router.include_router(activation_all_router)
v1_router.include_router(steer_completion_chat_router)
v1_router.include_router(steer_completion_router)
v1_router.include_router(activation_single_router)
v1_router.include_router(activation_topk_by_token_router)
v1_router.include_router(sae_topk_by_decoder_cossim_router)
v1_router.include_router(sae_vector_router)
v1_router.include_router(tokenize_router)

app.include_router(v1_router)


@app.get("/health")
async def health_check():
return {"status": "healthy"}


@app.post("/initialize")
async def initialize(
custom_hf_model_id: str | None = None,
):
# Define initialization function that will be called during startup
async def initialize(custom_hf_model_id: str | None = None):
logger.info("Initializing...")

# Move the heavy operations to a separate thread pool to prevent blocking
Expand Down Expand Up @@ -218,6 +175,57 @@ def load_model_and_sae():
await asyncio.get_event_loop().run_in_executor(None, load_model_and_sae)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context: launch heavy model/SAE initialization in the
    background after startup (so health checks can be answered within a
    few minutes, as some infrastructure providers require), then log on
    shutdown.
    """
    logger.info("Starting initialization...")
    # Wait briefly to ensure server is ready
    await asyncio.sleep(3)
    # Start initialization in background. Keep a strong reference on
    # app.state: the event loop holds only weak references to tasks, so a
    # fire-and-forget create_task() result may be garbage-collected before
    # the coroutine finishes.
    init_task = asyncio.create_task(initialize(args.custom_hf_model_id))
    app.state.initialization_task = init_task

    def _log_init_result(task: asyncio.Task) -> None:
        # Surface background failures; otherwise the exception is
        # swallowed silently when the task is collected.
        if not task.cancelled() and task.exception() is not None:
            logger.error("Initialization failed", exc_info=task.exception())

    init_task.add_done_callback(_log_init_result)
    logger.info("Initialization started")

    yield

    logger.info("Shutting down...")


app = FastAPI(lifespan=lifespan)

# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=False,
allow_methods=["*"],
allow_headers=["*"],
)

v1_router = APIRouter(prefix="/v1")

v1_router.include_router(activation_all_router)
v1_router.include_router(steer_completion_chat_router)
v1_router.include_router(steer_completion_router)
v1_router.include_router(activation_single_router)
v1_router.include_router(activation_topk_by_token_router)
v1_router.include_router(sae_topk_by_decoder_cossim_router)
v1_router.include_router(sae_vector_router)
v1_router.include_router(tokenize_router)

app.include_router(v1_router)


@app.get("/health")
async def health_check():
    """Liveness probe; reports healthy as soon as the process is serving."""
    payload = {"status": "healthy"}
    return payload


@app.post("/initialize")
async def initialize_endpoint(
    custom_hf_model_id: str | None = None,
):
    """HTTP wrapper around the module-level ``initialize`` coroutine so
    initialization can be (re)triggered manually after startup.
    """
    return await initialize(custom_hf_model_id)


@app.middleware("http")
async def check_secret_key(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
Expand Down