diff --git a/apps/autointerp/README.md b/apps/autointerp/README.md index 6003742a7..4a5eed2e7 100644 --- a/apps/autointerp/README.md +++ b/apps/autointerp/README.md @@ -3,7 +3,9 @@ - [repo status](#repo-status) - [what this is](#what-this-is) - [simple non-docker setup](#simple-non-docker-setup) +- [Simple Usage](#simple-usage) - [some docker commands for reference](#some-docker-commands-for-reference) +- [Testing, Linting, and Formatting](#testing-linting-and-formatting) ## repo status @@ -33,6 +35,31 @@ as much as possible we try to use classes/types from the `packages/python/neuron poetry run uvicorn server:app --host 0.0.0.0 --port 5003 --workers 1 --reload ``` +## Simple Usage + +Generate an explanation given 2 activations + +``` +curl -X POST "http://localhost:5003/v1/explain/default" \ + -H "Content-Type: application/json" \ + -d '{ + "activations": [ + { + "tokens": ["The", " cat", " sat", " on", " the", " mat"], + "values": [0.0, 0.8, 0.0, 0.0, 0.0, 0.0] + }, + { + "tokens": ["I", " like", " felines"], + "values": [0, 0, 0.9] + } + ], + "openrouter_key": "YOUR_OPENROUTER_KEY", + "model": "openai/gpt-4o-mini" + }' +``` + +See other endpoints under `/schemas/openapi/autointerp/paths`. + ## some docker commands for reference build the image from root directory diff --git a/apps/autointerp/neuronpedia_autointerp/logging.py b/apps/autointerp/neuronpedia_autointerp/logging.py new file mode 100644 index 000000000..8ecc6d881 --- /dev/null +++ b/apps/autointerp/neuronpedia_autointerp/logging.py @@ -0,0 +1,63 @@ +import os +from logging.config import dictConfig + + +def initialize_logging(): + home_dir = os.environ.get("HOME_DIR", ".") + log_directory = os.path.join(home_dir, "logs") + os.makedirs(log_directory, exist_ok=True) + + dictConfig( + { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "default": { + "format": "[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s", + "datefmt": "%Y-%m-%d %H:%M:%S", + } + }, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "stream": "ext://sys.stdout", + "formatter": "default", + "level": "DEBUG", + }, + "file": { + "class": "logging.handlers.RotatingFileHandler", + "filename": f"{log_directory}/server.log", + "maxBytes": 10485760, # 10MB + "backupCount": 5, + "formatter": "default", + "level": "DEBUG", + }, + }, + "loggers": { + "neuronpedia_autointerp": { # Main package logger + "level": "DEBUG", + "handlers": ["console", "file"], + "propagate": False, + }, + "neuronpedia_autointerp.routes": { # All routes logger + "level": "DEBUG", + "handlers": ["console", "file"], + "propagate": False, + }, + "uvicorn": { # Add uvicorn logger + "level": "INFO", + "handlers": ["console", "file"], + "propagate": False, + }, + "fastapi": { # Add fastapi logger + "level": "INFO", + "handlers": ["console", "file"], + "propagate": False, + }, + }, + "root": { # Root logger + "level": "INFO", + "handlers": ["console", "file"], + }, + } + ) diff --git a/apps/autointerp/server.py b/apps/autointerp/server.py index 49ea4eae2..b5d229e8b 100644 --- a/apps/autointerp/server.py +++ b/apps/autointerp/server.py @@ -1,5 +1,4 @@ -# ruff: noqa: T201 - +import logging import os from collections.abc import Awaitable, Callable from contextlib import asynccontextmanager @@ -21,6 +20,7 @@ ) from sentence_transformers import SentenceTransformer +from neuronpedia_autointerp.logging import initialize_logging from neuronpedia_autointerp.routes.explain.default import explain_default from neuronpedia_autointerp.routes.score.embedding import generate_score_embedding from neuronpedia_autointerp.routes.score.fuzz_detection import ( @@ -35,9 +35,12 @@ load_dotenv() SECRET = os.getenv("SECRET") +initialize_logging() +logger = logging.getLogger(__name__) + # only initialize sentry if we have a dsn if os.getenv("SENTRY_DSN"): - print("initializing sentry") + logger.info("initializing sentry") sentry_sdk.init( dsn=os.getenv("SENTRY_DSN"), # Set traces_sample_rate to 1.0 to capture 100% @@ -55,21 +58,20 @@ def initialize_globals(): - print("initializing globals") + logger.info("initializing globals") global model if torch.cuda.is_available(): model = SentenceTransformer( "dunzhang/stella_en_400M_v5", trust_remote_code=True, # type: ignore[call-arg] ).cuda() - print("initialized embedding model") + logger.info("initialized embedding model") else: - print("no cuda available, not initializing embedding model") + logger.info("no cuda available, not initializing embedding model") @router.post("/explain/default") async def explanation_endpoint(request: ExplainDefaultPostRequest): - print("Explain Default Called") return await explain_default(request) @@ -77,13 +79,11 @@ async def explanation_endpoint(request: ExplainDefaultPostRequest): async def score_embedding_endpoint(request: ScoreEmbeddingPostRequest): if model is None: raise HTTPException(status_code=500, detail="Model not initialized") - print("Score Embedding Called") return await generate_score_embedding(request, model) @router.post("/score/fuzz-detection") async def score_fuzz_detection_endpoint(request: ScoreFuzzDetectionPostRequest): - print("Score Fuzz Detection Called") return await generate_score_fuzz_detection(request) @@ -91,7 +91,7 @@ async def score_fuzz_detection_endpoint(request: ScoreFuzzDetectionPostRequest): async def lifespan(app: FastAPI): initialize_globals() yield - print("shutting down") + logger.info("shutting down") app = FastAPI(lifespan=lifespan) diff --git a/apps/inference/neuronpedia_inference/sae_manager.py b/apps/inference/neuronpedia_inference/sae_manager.py index c036c8415..4df8614de 100644 --- a/apps/inference/neuronpedia_inference/sae_manager.py +++ b/apps/inference/neuronpedia_inference/sae_manager.py @@ -190,23 +190,23 @@ def print_sae_status(self): """ Print a nicely formatted status of loadable and loaded SAEs. """ - print("\nSAE Status:") - print("===========") + logger.info("\nSAE Status:") + logger.info("===========") - print("\nLoadable SAEs:") + logger.info("\nLoadable SAEs:") for sae_set, sae_ids in self.sae_set_to_saes.items(): if sae_set == self.NEURONS_SOURCESET: continue - print(f" {sae_set}:") + logger.info(f" {sae_set}:") for sae_id in sae_ids: status = "Loaded" if sae_id in self.loaded_saes else "Not Loaded" - print(f" - {sae_id}: {status}") + logger.info(f" - {sae_id}: {status}") - print("\nCurrently Loaded SAEs:") + logger.info("\nCurrently Loaded SAEs:") for i, sae_id in enumerate(self.loaded_saes, 1): - print(f" {i}. {sae_id}") + logger.info(f" {i}. {sae_id}") - print(f"\nTotal Loaded: {len(self.loaded_saes)} / {self.max_loaded_saes}") + logger.info(f"\nTotal Loaded: {len(self.loaded_saes)} / {self.max_loaded_saes}") # Utility methods def get_sae_type(self, sae_id: str) -> str: