diff --git a/apps/autointerp/README.md b/apps/autointerp/README.md index 6003742a7..4a5eed2e7 100644 --- a/apps/autointerp/README.md +++ b/apps/autointerp/README.md @@ -3,7 +3,9 @@ - [repo status](#repo-status) - [what this is](#what-this-is) - [simple non-docker setup](#simple-non-docker-setup) +- [Simple Usage](#simple-usage) - [some docker commands for reference](#some-docker-commands-for-reference) +- [Testing, Linting, and Formatting](#testing-linting-and-formatting) ## repo status @@ -33,6 +35,31 @@ as much as possible we try to use classes/types from the `packages/python/neuron poetry run uvicorn server:app --host 0.0.0.0 --port 5003 --workers 1 --reload ``` +## Simple Usage + +Generate an explanation given 2 activations + +``` +curl -X POST "http://localhost:5003/v1/explain/default" \ + -H "Content-Type: application/json" \ + -d '{ + "activations": [ + { + "tokens": ["The", " cat", " sat", " on", " the", " mat"], + "values": [0.0, 0.8, 0.0, 0.0, 0.0, 0.0] + }, + { + "tokens": ["I", " like", " felines"], + "values": [0, 0, 0.9] + } + ], + "openrouter_key": "YOUR_OPENROUTER_KEY", + "model": "openai/gpt-4o-mini" + }' +``` + +See other endpoints under `/schemas/openapi/autointerp/paths`. + ## some docker commands for reference build the image from root directory diff --git a/apps/autointerp/neuronpedia_autointerp/logging.py b/apps/autointerp/neuronpedia_autointerp/logging.py new file mode 100644 index 000000000..8ecc6d881 --- /dev/null +++ b/apps/autointerp/neuronpedia_autointerp/logging.py @@ -0,0 +1,63 @@ +import os +from logging.config import dictConfig + + +def initialize_logging(): + home_dir = os.environ.get("HOME_DIR", ".") + log_directory = os.path.join(home_dir, "logs") + os.makedirs(log_directory, exist_ok=True) + + dictConfig( + { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "default": { + "format": "[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s", + "datefmt": "%Y-%m-%d %H:%M:%S", + } + }, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "stream": "ext://sys.stdout", + "formatter": "default", + "level": "DEBUG", + }, + "file": { + "class": "logging.handlers.RotatingFileHandler", + "filename": f"{log_directory}/server.log", + "maxBytes": 10485760, # 10MB + "backupCount": 5, + "formatter": "default", + "level": "DEBUG", + }, + }, + "loggers": { + "neuronpedia_autointerp": { # Main package logger + "level": "DEBUG", + "handlers": ["console", "file"], + "propagate": False, + }, + "neuronpedia_autointerp.routes": { # All routes logger + "level": "DEBUG", + "handlers": ["console", "file"], + "propagate": False, + }, + "uvicorn": { # Add uvicorn logger + "level": "INFO", + "handlers": ["console", "file"], + "propagate": False, + }, + "fastapi": { # Add fastapi logger + "level": "INFO", + "handlers": ["console", "file"], + "propagate": False, + }, + }, + "root": { # Root logger + "level": "INFO", + "handlers": ["console", "file"], + }, + } + ) diff --git a/apps/autointerp/server.py b/apps/autointerp/server.py index ff98a306b..b5d229e8b 100644 --- a/apps/autointerp/server.py +++ b/apps/autointerp/server.py @@ -1,7 +1,7 @@ -# ruff: noqa: T201 - +import logging import os from collections.abc import Awaitable, Callable +from contextlib import asynccontextmanager import sentry_sdk import torch @@ -20,6 +20,7 @@ ) from sentence_transformers import SentenceTransformer +from neuronpedia_autointerp.logging import initialize_logging from neuronpedia_autointerp.routes.explain.default import explain_default from neuronpedia_autointerp.routes.score.embedding import generate_score_embedding from neuronpedia_autointerp.routes.score.fuzz_detection import ( @@ -34,9 +35,12 @@ load_dotenv() SECRET = os.getenv("SECRET") +initialize_logging() +logger = logging.getLogger(__name__) + # only initialize sentry if we have a dsn if os.getenv("SENTRY_DSN"): - print("initializing sentry") + logger.info("initializing sentry") sentry_sdk.init( dsn=os.getenv("SENTRY_DSN"), # Set traces_sample_rate to 1.0 to capture 100% @@ -54,21 +58,20 @@ def initialize_globals(): - print("initializing globals") + logger.info("initializing globals") global model if torch.cuda.is_available(): model = SentenceTransformer( "dunzhang/stella_en_400M_v5", trust_remote_code=True, # type: ignore[call-arg] ).cuda() - print("initialized embedding model") + logger.info("initialized embedding model") else: - print("no cuda available, not initializing embedding model") + logger.info("no cuda available, not initializing embedding model") @router.post("/explain/default") async def explanation_endpoint(request: ExplainDefaultPostRequest): - print("Explain Default Called") return await explain_default(request) @@ -76,23 +79,23 @@ async def explanation_endpoint(request: ExplainDefaultPostRequest): async def score_embedding_endpoint(request: ScoreEmbeddingPostRequest): if model is None: raise HTTPException(status_code=500, detail="Model not initialized") - print("Score Embedding Called") return await generate_score_embedding(request, model) @router.post("/score/fuzz-detection") async def score_fuzz_detection_endpoint(request: ScoreFuzzDetectionPostRequest): - print("Score Fuzz Detection Called") return await generate_score_fuzz_detection(request) -app = FastAPI() -app.include_router(router) +@asynccontextmanager +async def lifespan(app: FastAPI): + initialize_globals() + yield + logger.info("shutting down") -@app.on_event("startup") # type: ignore[deprecated] -async def startup_event(): - initialize_globals() +app = FastAPI(lifespan=lifespan) +app.include_router(router) @app.middleware("http") diff --git a/apps/inference/neuronpedia_inference/sae_manager.py b/apps/inference/neuronpedia_inference/sae_manager.py index c036c8415..4df8614de 100644 --- a/apps/inference/neuronpedia_inference/sae_manager.py +++ b/apps/inference/neuronpedia_inference/sae_manager.py @@ -190,23 +190,23 @@ def print_sae_status(self): """ Print a nicely formatted status of loadable and loaded SAEs. """ - print("\nSAE Status:") - print("===========") + logger.info("\nSAE Status:") + logger.info("===========") - print("\nLoadable SAEs:") + logger.info("\nLoadable SAEs:") for sae_set, sae_ids in self.sae_set_to_saes.items(): if sae_set == self.NEURONS_SOURCESET: continue - print(f" {sae_set}:") + logger.info(f" {sae_set}:") for sae_id in sae_ids: status = "Loaded" if sae_id in self.loaded_saes else "Not Loaded" - print(f" - {sae_id}: {status}") + logger.info(f" - {sae_id}: {status}") - print("\nCurrently Loaded SAEs:") + logger.info("\nCurrently Loaded SAEs:") for i, sae_id in enumerate(self.loaded_saes, 1): - print(f" {i}. {sae_id}") + logger.info(f" {i}. {sae_id}") - print(f"\nTotal Loaded: {len(self.loaded_saes)} / {self.max_loaded_saes}") + logger.info(f"\nTotal Loaded: {len(self.loaded_saes)} / {self.max_loaded_saes}") # Utility methods def get_sae_type(self, sae_id: str) -> str: diff --git a/apps/inference/neuronpedia_inference/server.py b/apps/inference/neuronpedia_inference/server.py index 3526cb80c..4cfbcc905 100644 --- a/apps/inference/neuronpedia_inference/server.py +++ b/apps/inference/neuronpedia_inference/server.py @@ -6,6 +6,7 @@ import sys import traceback from collections.abc import Awaitable +from contextlib import asynccontextmanager from typing import Callable import sentry_sdk @@ -55,55 +56,11 @@ global initialized initialized = False -app = FastAPI() - -# Add CORS middleware -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=False, - allow_methods=["*"], - allow_headers=["*"], -) - args = parse_env_and_args() -# we have to initialize SAE's AFTER server startup, because some infrastructure providers require -# our server to respond to health checks within a few minutes of starting up -@app.on_event("startup") # pyright: ignore[reportDeprecated] -async def startup_event(): - logger.info("Starting initialization...") - # Wait briefly to ensure server is ready - await asyncio.sleep(3) - # Start initialization in background - asyncio.create_task(initialize(args.custom_hf_model_id)) - logger.info("Initialization started") - - -v1_router = APIRouter(prefix="/v1") - -v1_router.include_router(activation_all_router) -v1_router.include_router(steer_completion_chat_router) -v1_router.include_router(steer_completion_router) -v1_router.include_router(activation_single_router) -v1_router.include_router(activation_topk_by_token_router) -v1_router.include_router(sae_topk_by_decoder_cossim_router) -v1_router.include_router(sae_vector_router) -v1_router.include_router(tokenize_router) - -app.include_router(v1_router) - - -@app.get("/health") -async def health_check(): - return {"status": "healthy"} - - -@app.post("/initialize") -async def initialize( - custom_hf_model_id: str | None = None, -): +# Define initialization function that will be called during startup +async def initialize(custom_hf_model_id: str | None = None): logger.info("Initializing...") # Move the heavy operations to a separate thread pool to prevent blocking @@ -218,6 +175,57 @@ def load_model_and_sae(): await asyncio.get_event_loop().run_in_executor(None, load_model_and_sae) +@asynccontextmanager +async def lifespan(app: FastAPI): # noqa: ARG001 + logger.info("Starting initialization...") + # Wait briefly to ensure server is ready + await asyncio.sleep(3) + # Start initialization in background + asyncio.create_task(initialize(args.custom_hf_model_id)) + logger.info("Initialization started") + + yield + + logger.info("Shutting down...") + + +app = FastAPI(lifespan=lifespan) + +# Add CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=False, + allow_methods=["*"], + allow_headers=["*"], +) + +v1_router = APIRouter(prefix="/v1") + +v1_router.include_router(activation_all_router) +v1_router.include_router(steer_completion_chat_router) +v1_router.include_router(steer_completion_router) +v1_router.include_router(activation_single_router) +v1_router.include_router(activation_topk_by_token_router) +v1_router.include_router(sae_topk_by_decoder_cossim_router) +v1_router.include_router(sae_vector_router) +v1_router.include_router(tokenize_router) + +app.include_router(v1_router) + + +@app.get("/health") +async def health_check(): + return {"status": "healthy"} + + +@app.post("/initialize") +async def initialize_endpoint( + custom_hf_model_id: str | None = None, +): + return await initialize(custom_hf_model_id) + + @app.middleware("http") async def check_secret_key( request: Request, call_next: Callable[[Request], Awaitable[Response]]