diff --git a/environments/aci_bench/aci_bench/aci_bench.py b/environments/aci_bench/aci_bench/aci_bench.py
index 512d0439..1392f186 100644
--- a/environments/aci_bench/aci_bench/aci_bench.py
+++ b/environments/aci_bench/aci_bench/aci_bench.py
@@ -1,5 +1,6 @@
 from typing import Any
 
+import evaluate
 import verifiers as vf
 from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset
 from datasets.utils.logging import disable_progress_bar
@@ -99,6 +100,7 @@ def load_environment(
     judge_base_url: str | None = None,
     judge_api_key: str | None = None,
     system_prompt: str | None = None,
+    compute_auto_metrics: bool = True,
     **kwargs: Any,
 ) -> vf.Environment:
     # -------- load dataset and convert to vf format --------
@@ -116,6 +118,18 @@ def load_environment(
     # valid_ds = _to_vf_format(dataset["valid"])
     test_ds = _to_vf_format(concatenate_datasets([dataset["test1"], dataset["test2"], dataset["test3"]]))
 
+    # -------- initialize automatic metrics --------
+    rouge_metric = None
+    bertscore_metric = None
+
+    if compute_auto_metrics:
+        try:
+            rouge_metric = evaluate.load("rouge")
+            bertscore_metric = evaluate.load("bertscore")
+        except Exception as e:
+            print(f"Warning: Could not load automatic metrics: {e}")
+            compute_auto_metrics = False
+
     # -------- normalize answer_format --------
     answer_format = AnswerFormat(answer_format) if isinstance(answer_format, str) else answer_format
 
@@ -132,6 +146,8 @@ def load_environment(
     # -------- setup judge --------
     api_key = default_judge_api_key(judge_base_url) if judge_api_key is None else judge_api_key
     sampling_args, default_headers = judge_sampling_args_and_headers(judge_model, judge_base_url)
+    # Remove extra_body as OpenAI doesn't support the usage tracking parameter
+    sampling_args.pop("extra_body", None)
 
     judge_parser = JSONParser(fields=["accuracy", "completeness", "clarity"])
     judge_rubric = vf.JudgeRubric(
@@ -174,6 +190,50 @@ async def judge_rubric_reward(completion: Messages, info: Info, state: State, **
             }
         )
 
+        # --- Automatic Metrics (BLEU, ROUGE, BERTScore) ---
+        if compute_auto_metrics and completion_text and gold_response:
+            auto_metrics: dict[str, Any] = {}
+            predictions = [completion_text]
+            references = [gold_response]
+
+            # BLEU (with smoothing for sentence-level evaluation)
+            try:
+                from sacrebleu.metrics import BLEU
+
+                bleu_scorer = BLEU(smooth_method="exp", effective_order=True)
+                bleu_result = bleu_scorer.sentence_score(completion_text, [gold_response])
+                auto_metrics["bleu"] = bleu_result.score / 100.0  # normalize to 0-1
+            except Exception:
+                auto_metrics["bleu"] = 0.0
+
+            # ROUGE
+            try:
+                rouge_scores = rouge_metric.compute(predictions=predictions, references=references)
+                auto_metrics["rouge1"] = rouge_scores.get("rouge1", 0.0)
+                auto_metrics["rouge2"] = rouge_scores.get("rouge2", 0.0)
+                auto_metrics["rougeL"] = rouge_scores.get("rougeL", 0.0)
+                auto_metrics["rougeLsum"] = rouge_scores.get("rougeLsum", 0.0)
+            except Exception:
+                auto_metrics["rouge1"] = 0.0
+                auto_metrics["rouge2"] = 0.0
+                auto_metrics["rougeL"] = 0.0
+                auto_metrics["rougeLsum"] = 0.0
+
+            # BERTScore
+            try:
+                bert_scores = bertscore_metric.compute(predictions=predictions, references=references, lang="en")
+                auto_metrics["bertscore_precision"] = (
+                    bert_scores["precision"][0] if bert_scores.get("precision") else 0.0
+                )
+                auto_metrics["bertscore_recall"] = bert_scores["recall"][0] if bert_scores.get("recall") else 0.0
+                auto_metrics["bertscore_f1"] = bert_scores["f1"][0] if bert_scores.get("f1") else 0.0
+            except Exception:
+                auto_metrics["bertscore_precision"] = 0.0
+                auto_metrics["bertscore_recall"] = 0.0
+                auto_metrics["bertscore_f1"] = 0.0
+
+            info["auto_metrics"] = auto_metrics
+
         return normalized
 
     judge_rubric.add_reward_func(judge_rubric_reward, weight=1.0)
diff --git a/environments/aci_bench/pyproject.toml b/environments/aci_bench/pyproject.toml
index c97e0579..5286bea8 100644
--- a/environments/aci_bench/pyproject.toml
+++ b/environments/aci_bench/pyproject.toml
@@ -8,6 +8,10 @@ requires-python = ">=3.11"
 dependencies = [
     "verifiers>=0.1.5.post0",
     "medarc_verifiers>=0.1.0",
+    "evaluate",
+    "bert_score",
+    "rouge_score",
+    "sacrebleu",
 ]
 
 [tool.prime.environment]