26 changes: 26 additions & 0 deletions src/modelbench/benchmark_runner.py
@@ -234,6 +234,32 @@ def __init__(self, runner: "BenchmarkRunner"):
        for t in h.tests(self.secrets):
            self.add_test(t)

def compile_annotations(self) -> list:
    annotations = []
    for sut_uid, hazards in self.finished_items.items():
        for full_hazard_string, items in hazards.items():
            try:
                test, hazard, locale, prompt_set, version = full_hazard_string.split("-")
            except ValueError:
                hazard = full_hazard_string
                prompt_set = "unknown"
            # Most prompts and annotations can't be shared. Until we have per-prompt shareability info,
            # we only export annotations for items from the demo prompt set.
            # TODO: fix this when we have per-prompt shareability info.
            if prompt_set != "demo":
                continue
            for item in items:  # each item is a TestRunItem
                for _, annotation in item.annotations.items():
                    ann = {
                        "hazard": hazard,
                        "prompt": item.test_item.prompt.text,
                        "response": item.sut_response.text,
                        "is_safe": annotation.is_safe,
                        "is_valid": annotation.is_valid,
                    }
                    annotations.append(ann)
    return annotations


class IntermediateCachingPipe(Pipe):
"""
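For illustration, here is a minimal, standalone sketch of the hazard-UID handling that compile_annotations relies on, assuming hazard UIDs are five dash-separated fields (test-hazard-locale-prompt_set-version). The parse_hazard_uid helper and the example UID strings are hypothetical and not part of the diff above:

# Sketch only: mirrors the split/except fallback in compile_annotations.
def parse_hazard_uid(full_hazard_string: str) -> tuple[str, str]:
    try:
        _test, hazard, _locale, prompt_set, _version = full_hazard_string.split("-")
    except ValueError:  # UID doesn't have the expected five dash-separated parts
        hazard, prompt_set = full_hazard_string, "unknown"
    return hazard, prompt_set

assert parse_hazard_uid("safe-xyz-en_us-demo-1.0") == ("xyz", "demo")
assert parse_hazard_uid("legacy_hazard_uid") == ("legacy_hazard_uid", "unknown")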
15 changes: 11 additions & 4 deletions src/modelbench/cli.py
@@ -1,6 +1,7 @@
import datetime
import faulthandler
import io
import json
import logging
import pathlib
import pkgutil
@@ -16,7 +17,7 @@
from rich.table import Table

import modelgauge.annotators.cheval.registration # noqa: F401
from modelbench.benchmark_runner import BenchmarkRunner, JsonRunTracker, TqdmRunTracker
from modelbench.benchmark_runner import BenchmarkRun, BenchmarkRunner, JsonRunTracker, TqdmRunTracker
from modelbench.benchmarks import GeneralPurposeAiChatBenchmarkV1, SecurityBenchmark
from modelbench.consistency_checker import (
ConsistencyChecker,
@@ -185,21 +186,27 @@ def security_benchmark(
    sut = make_sut(sut_uid)
    benchmark = SecurityBenchmark(locale, prompt_set, evaluator=evaluator)
    check_benchmark(benchmark)

    run_and_report_benchmark(benchmark, sut, max_instances, debug, json_logs, output_dir, run_uid)


def run_and_report_benchmark(benchmark, sut, max_instances, debug, json_logs, output_dir, run_uid):
    start_time = datetime.now(timezone.utc)
    run = run_benchmarks_for_sut([benchmark], sut, max_instances, debug=debug, json_logs=json_logs)

    benchmark_scores = score_benchmarks(run)
    output_dir.mkdir(exist_ok=True, parents=True)
    print_summary(benchmark, benchmark_scores)
    json_path = output_dir / f"benchmark_record-{benchmark.uid}.json"
    scores = [score for score in benchmark_scores if score.benchmark_definition == benchmark]
    dump_json(json_path, start_time, benchmark, scores, run_uid)
    print(f"Wrote record for {benchmark.uid} to {json_path}.")

    # export the annotations separately
    annotations = {"job_id": run.run_id, "annotations": run.compile_annotations()}
    annotation_path = output_dir / f"annotations-{benchmark.uid}.json"
    with open(annotation_path, "w") as annotation_records:
        annotation_records.write(json.dumps(annotations))
    print(f"Wrote annotations for {benchmark.uid} to {annotation_path}.")

    run_consistency_check(run.journal_path, verbose=True)


@@ -286,7 +293,7 @@ def run_benchmarks_for_sut(
    thread_count=32,
    calibrating=False,
    run_path: str = "./run",
):
) -> BenchmarkRun:
    runner = BenchmarkRunner(pathlib.Path(run_path), calibrating=calibrating)
    runner.secrets = load_secrets_from_config()
    runner.benchmarks = benchmarks
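For context, a small sketch of reading back the annotations file that run_and_report_benchmark now writes next to the benchmark record. The keys follow the code in this diff; the output directory and file name below are made up for illustration:

# Sketch only: the annotation_path here is hypothetical.
import json
from pathlib import Path

annotation_path = Path("output") / "annotations-example_benchmark_uid.json"
with open(annotation_path) as f:
    payload = json.load(f)

print(payload["job_id"])  # the run's run_id
for record in payload["annotations"]:
    # Each record carries hazard, prompt, response, is_safe, is_valid.
    print(record["hazard"], record["is_safe"], record["is_valid"])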