Merged
11 changes: 8 additions & 3 deletions .github/workflows/tests.yml
@@ -8,7 +8,7 @@ on:
workflow_dispatch:
inputs:
branch:
description: 'Branch'
description: "Branch"
required: true
default: main

@@ -32,7 +32,13 @@ jobs:
./start_services.sh --no-jupyter -d

- name: Install poetry
run: pipx install "poetry == 1.8.5"
run: pipx install poetry
Contributor

why not specify the version anymore?

Contributor Author @superdosh commented on Dec 9, 2025


We were pinned to a very old version because of some weirdness (in modelbench) that I never fully understood, but I figure that's not necessary here anymore -- I think it was actually driven by modelbench-private? Potentially pinning to some newer version is good, but at least we shouldn't be stuck on 1 if not necessary. Going to leave it open for now!

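If a newer floor is wanted later, one possible shape for the step, sketched here purely for illustration (the version bounds are assumptions, not part of this PR), would be to use a pip-style range rather than an exact old pin:

```yaml
# Hypothetical sketch: stay on a current poetry release while avoiding a
# silent jump across a future major version. Bounds are illustrative only.
- name: Install poetry
  run: pipx install "poetry>=2.0,<3.0"
```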

- name: Verify MLflow versions match
run: ./scripts/check_mlflow_versions.sh

- name: Check poetry lock file
run: poetry check --lock

- name: Remove existing virtual environment
run: |
@@ -75,6 +81,5 @@ jobs:
run: |
docker exec modelplane-jupyter-1 poetry run python /app/test_notebooks.py


- name: Stop MLflow server
run: docker compose down
4 changes: 4 additions & 0 deletions .gitignore
@@ -7,3 +7,7 @@ secrets.toml
.vscode/
.coverage*
.cache
*.csv
*.txt
*.json
*.jsonl
7 changes: 3 additions & 4 deletions Dockerfile.mlflow
@@ -1,8 +1,7 @@
FROM ghcr.io/mlflow/mlflow:v3.1.1
FROM ghcr.io/mlflow/mlflow:v3.7.0

# The base image does not include various dependencies that are needed for
# the MLflow server. We assume a postgres backend, so we need psycopg2.
# We also need boto3 for S3 support, and google-cloud-storage for GCS support.
# TODO: better way to install these (maybe using poetry.lock to grab consistent versions?)
RUN pip install mlflow[auth]==3.1.1 psycopg2-binary==2.9.10 boto3==1.38.31 \
google-cloud-storage==3.1.0
RUN pip install mlflow[auth]==3.7.0 psycopg2-binary==2.9.11 boto3==1.42.5 \
google-cloud-storage==3.4.1
20 changes: 8 additions & 12 deletions README.md
@@ -43,9 +43,8 @@ for access.
MLFlow server (`MLFLOW_TRACKING_USERNAME` /
`MLFLOW_TRACKING_PASSWORD`).
* Alternatively, put the credentials in `~/.mlflow/credentials` as described [here](https://mlflow.org/docs/latest/ml/auth/#credentials-file).
1. To access `modelbench-private` code (assuming you have
access), you must also set `USE_MODELBENCH_PRIVATE=true` in `.env.jupyteronly`. This will forward your ssh agent to the container
allowing it to load the private repository to build the image.
1. To access the private annotators, you need to set up credentials to access cheval (see `modelgauge.annotators.cheval.registration`)
and reach out to [[email protected]](mailto:[email protected]) for the credentials.
1. Start jupyter with `./start_jupyter.sh`. (You can add the
`-d` flag to start in the background.)

@@ -97,17 +96,14 @@ or you can get the `run_id` via the MLFlow UI.
MLFLOW_TRACKING_URI=http://localhost:8080 poetry run modelplane annotate --annotator_id {annotator_id} --experiment expname --response_run_id {run_id}
```

### Custom Ensembles
#### Private Ensemble
If you have access to the private annotator, you can run directly with:
```
MLFLOW_TRACKING_URI=http://localhost:8080 poetry run modelplane annotate --annotator_id {annotator_id1} --annotator_id {annotator_id2} --ensemble_strategy {ensemble_strategy} --experiment expname --response_file path/to/response.csv
MLFLOW_TRACKING_URI=http://localhost:8080 poetry run modelplane annotate --annotator_id safety-v1.1 --experiment expname --response_run_id {run_id}
```

### Private Ensemble
If you have access to the private ensemble, you can install with the needed extras
```
poetry install --extras modelbench-private
```
And then run annotations with:

### Custom Ensembles
```
MLFLOW_TRACKING_URI=http://localhost:8080 poetry run modelplane annotate --ensemble_id official --experiment expname --response_run_id {run_id}
MLFLOW_TRACKING_URI=http://localhost:8080 poetry run modelplane annotate --annotator_id {annotator_id1} --annotator_id {annotator_id2} --ensemble_strategy {ensemble_strategy} --experiment expname --response_file path/to/response.csv
```
6 changes: 5 additions & 1 deletion docker-compose.yaml
@@ -27,6 +27,9 @@ services:
GOOGLE_APPLICATION_CREDENTIALS: /creds/gcp-key.json
# if not provided via volume below, AWS S3 will not work as artifact store
AWS_SHARED_CREDENTIALS_FILE: /creds/aws-credentials
# https://mlflow.org/docs/latest/community/usage-tracking/
MLFLOW_DISABLE_TELEMETRY: "true"
DO_NOT_TRACK: "true"
depends_on:
- postgres
# grab backend from .env, pass artifact root, if provided, otherwise local storage of artifacts
@@ -37,6 +40,7 @@ services:
--serve-artifacts
--host 0.0.0.0
--port 8080
--allowed-hosts 'mlflow:*,localhost:*,127.0.0.1:*'
ports:
- "8080:8080"
healthcheck:
@@ -51,7 +55,7 @@
- ${GOOGLE_CREDENTIALS_PATH:-/dev/null}:/creds/gcp-key.json:ro
# Volume only needed for AWS S3 storage artifacts
- ${AWS_CREDENTIALS_PATH:-/dev/null}:/creds/aws-credentials:ro

jupyter:
build:
context: .
31 changes: 10 additions & 21 deletions flightpaths/Annotator Development Template.ipynb
@@ -76,9 +76,9 @@
"\n",
"An implementation of the annotator should define:\n",
"\n",
"* `translate_request`: takes a `TestItem` and `SUTResponse` (which have the needed information on the prompt responses for the given LLM) and translates the request into an arbitrary form useful for the `annotate` method.\n",
"* `translate_prompt`: takes a `TextPrompt` and `SUTResponse` (which have the needed information on the prompt responses for the given LLM) and translates the request into an arbitrary form useful for the `annotate` method.\n",
"* `annotate`: take the translated request and produce an annotation in any form.\n",
"* `translate_response`: take the annotation and translate it to the needed from (`SafetyModelResponse`).\n",
"* `translate_response`: take the annotation and translate it to the needed from (`SafetyAnnotation`).\n",
"\n",
"Below, we give a sample `RandomAnnotator`, which randomly decides whether a given response is safe by hazard category."
]
@@ -94,10 +94,10 @@
"\n",
"from pydantic import BaseModel\n",
"\n",
"from modelgauge.annotator import CompletionAnnotator\n",
"from modelgauge.annotation import SafetyAnnotation\n",
"from modelgauge.annotator import Annotator\n",
"from modelgauge.annotator_registry import ANNOTATORS\n",
"from modelgauge.prompt import ChatPrompt, TextPrompt\n",
"from modelgauge.safety_model_response import SafetyModelResponse\n",
"from modelgauge.single_turn_prompt_response import TestItem\n",
"from modelgauge.sut import SUTResponse\n",
"from modelgauge.tests.safe_v1 import BaseSafeTestVersion1\n",
@@ -107,9 +107,8 @@
" response: SUTResponse\n",
"\n",
"\n",
"class RandomAnnotator(CompletionAnnotator[SafetyModelResponse]):\n",
" \"\"\"Simple random implementation for demonstration. Uses SafetyModelResponse\n",
" for consistency with AILuminate benchmarks.\n",
"class RandomAnnotator(Annotator):\n",
" \"\"\"Simple random implementation for demonstration.\n",
"\n",
" The methods to implement are based on the idea that annotate will \n",
" potentially make a call to another system and the translators are helpful \n",
@@ -126,18 +125,16 @@
" response=response,\n",
" )\n",
"\n",
" def annotate(self, annotation_request: AnnotationRequest) -> SafetyModelResponse:\n",
" def annotate(self, annotation_request: AnnotationRequest) -> SafetyAnnotation:\n",
" unsafe_categories = []\n",
" for category in BaseSafeTestVersion1.hazards:\n",
" if random.random() < self.prob_unsafe:\n",
" unsafe_categories.append(category)\n",
" return SafetyModelResponse(\n",
" return SafetyAnnotation(\n",
" is_safe=not unsafe_categories,\n",
" logprobs=random.random(),\n",
" safety_categories=unsafe_categories,\n",
" )\n",
"\n",
" def translate_response(self, request: AnnotationRequest, response: SafetyModelResponse) -> SafetyModelResponse:\n",
" def translate_response(self, request: AnnotationRequest, response: SafetyAnnotation) -> SafetyAnnotation:\n",
" return response\n",
"\n",
"\n",
@@ -222,14 +219,6 @@
" ground_truth=ground_truth,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "af9debec-28be-4a50-82da-5d7025de7d76",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@@ -248,7 +237,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.11"
"version": "3.12.12"
}
},
"nbformat": 4,
46 changes: 14 additions & 32 deletions flightpaths/Ensemble Development Template.ipynb
@@ -70,11 +70,11 @@
"\n",
"* An implementation of an annotator should define:\n",
"\n",
" * `translate_request`: takes a `TestItem` and `SUTResponse` (which have the needed information on the prompt responses for the given LLM) and translates the request into an arbitrary form useful for the `annotate` method.\n",
" * `translate_prompt`: takes a `TextPrompt` and `SUTResponse` (which have the needed information on the prompt responses for the given LLM) and translates the request into an arbitrary form useful for the `annotate` method.\n",
" * `annotate`: take the translated request and produce an annotation in any form.\n",
" * `translate_response`: take the annotation and translate it to the needed from (`SafetyModelResponse`).\n",
"\n",
"* An implementation of the combining logic should define `compute_response` which takes a map of `SafetyModelResponse`s from the underlying annotators and combines them to produce a final annotation (`EnsembleSafetyModelResponse`, which is simply a `SafetyModelResponse` along with the map of underlying responses).\n",
"* An implementation of the combining logic should define `compute_response` which takes a map of `SafetyAnnotation`s from the underlying annotators and combines them to produce a final annotation (`EnsembleSafetyAnnotation`, which is simply a `SafetyAnnotation` along with the map of underlying responses).\n",
"\n",
"Below, we give a sample ensemble, which takes two independent `RandomAnnotator` responses and reports unsafe if any underlying annotation was unsafe. This requires implementations of the `RandomAnnotator` and `AnyUnsafeStrategy`."
]
@@ -100,21 +100,21 @@
"source": [
"import random\n",
"\n",
"from pydantic import BaseModel\n",
"\n",
"from modelgauge.annotator import CompletionAnnotator\n",
"from modelgauge.annotation import SafetyAnnotation\n",
"from modelgauge.annotator import Annotator\n",
"from modelgauge.annotator_registry import ANNOTATORS\n",
"from modelgauge.prompt import ChatPrompt, TextPrompt\n",
"from modelgauge.safety_model_response import SafetyModelResponse\n",
"from modelgauge.sut import SUTResponse\n",
"from modelgauge.tests.safe_v1 import BaseSafeTestVersion1\n",
"from pydantic import BaseModel\n",
"\n",
"\n",
"class AnnotationRequest(BaseModel):\n",
" prompt: ChatPrompt | TextPrompt\n",
" response: SUTResponse\n",
"\n",
"\n",
"class RandomAnnotator(CompletionAnnotator[SafetyModelResponse]):\n",
"class RandomAnnotator(Annotator):\n",
" \"\"\"Simple random implementation for demonstration. Uses SafetyModelResponse\n",
" for consistency with AILuminate benchmarks.\n",
"\n",
@@ -132,19 +132,17 @@
" response=response,\n",
" )\n",
"\n",
" def annotate(self, annotation_request: AnnotationRequest) -> SafetyModelResponse:\n",
" def annotate(self, annotation_request: AnnotationRequest) -> SafetyAnnotation:\n",
" unsafe_categories = []\n",
" for category in BaseSafeTestVersion1.hazards:\n",
" if random.random() < self.prob_unsafe:\n",
" unsafe_categories.append(category)\n",
" return SafetyModelResponse(\n",
" return SafetyAnnotation(\n",
" is_safe=not unsafe_categories,\n",
" is_valid=True,\n",
" logprobs=random.random(),\n",
" safety_categories=unsafe_categories,\n",
" )\n",
"\n",
" def translate_response(self, request: AnnotationRequest, response: SafetyModelResponse) -> SafetyModelResponse:\n",
" def translate_response(self, request: AnnotationRequest, response: SafetyAnnotation) -> SafetyAnnotation:\n",
" return response\n",
"\n",
"\n",
@@ -176,26 +174,22 @@
"source": [
"from typing import Dict\n",
"\n",
"from modelgauge.safety_model_response import EnsembleSafetyModelResponse\n",
"from modelgauge.ensemble_annotator_set import EnsembleStrategy, ENSEMBLE_STRATEGIES\n",
"from modelgauge.annotation import EnsembleSafetyAnnotation\n",
"from modelgauge.ensemble_strategies import ENSEMBLE_STRATEGIES, EnsembleStrategy\n",
"\n",
"\n",
"class AnyUnsafeEnsembleStrategy(EnsembleStrategy):\n",
" def compute_response(self, annotations: Dict[str, SafetyModelResponse]) -> EnsembleSafetyModelResponse:\n",
" response = EnsembleSafetyModelResponse(\n",
" def compute_response(self, annotations: Dict[str, SafetyAnnotation]) -> EnsembleSafetyAnnotation:\n",
" response = EnsembleSafetyAnnotation(\n",
" joined_responses=annotations, \n",
" is_safe=True,\n",
" safety_categories=[],\n",
" is_valid=False,\n",
" )\n",
" for annotation in annotations.values():\n",
" if annotation.is_valid:\n",
" response.is_valid = True\n",
" if not annotation.is_safe:\n",
" response.is_safe = False\n",
" current_categories = set(response.safety_categories)\n",
" current_categories.update(annotation.safety_categories)\n",
" response.safety_categories = sorted(current_categories)\n",
" return response\n",
"\n",
"ENSEMBLE_STRATEGIES[\"anyunsafe\"] = AnyUnsafeEnsembleStrategy()"
@@ -286,18 +280,6 @@
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.11"
}
},
"nbformat": 4,