diff --git a/environments/medreason/medreason.py b/environments/medreason/medreason.py
new file mode 100644
index 0000000..13605f2
--- /dev/null
+++ b/environments/medreason/medreason.py
@@ -0,0 +1,237 @@
+import json
+from typing import Optional
+
+import verifiers as vf
+from datasets import load_dataset
+from datasets.utils.logging import disable_progress_bar
+from medarc_verifiers.parsers.xml_parser import XMLParser
+from medarc_verifiers.prompts import THINK_XML_SYSTEM_PROMPT, XML_SYSTEM_PROMPT, AnswerFormat
+from medarc_verifiers.rewards.multiple_choice_accuracy import multiple_choice_accuracy
+from medarc_verifiers.utils import default_judge_api_key, judge_sampling_args_and_headers
+from medarc_verifiers.utils.randomize_multiple_choice import randomize_multiple_choice
+from openai import AsyncOpenAI
+from verifiers.types import Info, State
+from verifiers.utils.data_utils import BOXED_SYSTEM_PROMPT, THINK_BOXED_SYSTEM_PROMPT, extract_boxed_answer
+
+disable_progress_bar()
+
+MCQ_QUESTION_TEMPLATE = """\
+Question: {question}
+Choices:
+{choices}
+Answer:"""
+
+OPEN_QUESTION_TEMPLATE = """\
+{question}"""
+
+JUDGE_TEMPLATE = """\
+You are evaluating an AI assistant's answer to a medical question.
+
+Question: {question}
+Reference answer: {answer}
+Assistant's answer: {response}
+
+Is the assistant's answer medically equivalent to the reference answer?
+Consider synonyms, paraphrasing, and reasonable generalizations as correct.
+Answer [yes/no]."""
+
+
+def _parse_options(options_str: str | None) -> dict[str, str] | None:
+    """Parse the options field from the dataset.
+
+    The options field can be a JSON string representing a dict or list,
+    or it can be empty/None for open-ended questions.
+    """
+    if not options_str or options_str.strip() in ("", "None", "null", "{}"):
+        return None
+    try:
+        parsed = json.loads(options_str)
+    except (json.JSONDecodeError, TypeError):
+        return None
+
+    if isinstance(parsed, dict):
+        if not parsed:
+            return None
+        return {str(k): str(v) for k, v in parsed.items()}
+    if isinstance(parsed, list):
+        if not parsed:
+            return None
+        labels = [chr(ord("A") + i) for i in range(len(parsed))]
+        return dict(zip(labels, [str(v) for v in parsed]))
+    return None
+
+
+def _format_mcq_prompt(question: str, options: dict[str, str]) -> str:
+    """Format a multiple-choice question prompt."""
+    choices = "\n".join(f"{k}. {v}" for k, v in options.items())
+    return MCQ_QUESTION_TEMPLATE.format(question=question, choices=choices)
+
+
+def load_environment(
+    use_think: bool = False,
+    system_prompt: Optional[str] = None,
+    shuffle_answers: bool = False,
+    shuffle_seed: int | None = 1618,
+    answer_format: AnswerFormat | str = AnswerFormat.XML,
+    judge_model: str = "gpt-4o-mini",
+    judge_base_url: str | None = None,
+    judge_api_key: str | None = None,
+) -> vf.Environment:
+    """
+    MedReason medical reasoning evaluation environment.
+
+    Supports both multiple-choice and open-ended questions from the MedReason
+    dataset (UCSC-VLAA/MedReason). MCQ items are graded by accuracy; open-ended
+    items use LLM-as-a-Judge evaluation.
+
+    Args:
+        use_think: Enable chain-of-thought reasoning with <think> tags.
+        system_prompt: Custom system prompt override.
+        shuffle_answers: Shuffle MCQ answer options.
+        shuffle_seed: Seed for deterministic answer shuffling.
+        answer_format: Answer format (xml or boxed).
+        judge_model: Model to use for LLM-as-judge evaluation.
+        judge_base_url: Base URL for judge API.
+        judge_api_key: API key for judge model.
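+
+    Example (illustrative; all arguments are optional):
+        env = load_environment(answer_format="boxed", use_think=True)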
+ """ + ds = load_dataset("UCSC-VLAA/MedReason", split="train") + + # Set up judge for open-ended questions + api_key = default_judge_api_key(judge_base_url) if judge_api_key is None else judge_api_key + sampling_args, default_headers = judge_sampling_args_and_headers(judge_model, judge_base_url) + judge_client = AsyncOpenAI(base_url=judge_base_url, api_key=api_key, default_headers=default_headers) + judge_rubric = vf.JudgeRubric( + judge_client=judge_client, + judge_model=judge_model, + judge_prompt="{question}", + judge_sampling_args=sampling_args, + ) + + def _map(ex, idx=None): + question_text = ex["question"] + answer_text = ex["answer"] + options = _parse_options(ex.get("options")) + + if options: + # MCQ: find gold letter by matching answer text to options + gold_letter = None + for letter, opt_text in options.items(): + if opt_text.strip().lower() == answer_text.strip().lower(): + gold_letter = letter + break + + if gold_letter is None: + # Answer is the letter itself + candidate = answer_text.strip().upper() + if candidate in options: + gold_letter = candidate + else: + gold_letter = "A" + + if shuffle_answers and gold_letter in options: + options, gold_letter, _ = randomize_multiple_choice( + options=options, + answer_choice=gold_letter, + seed=shuffle_seed, + row_id=ex.get("id_in_dataset", idx), + ) + + return { + "question": _format_mcq_prompt(question_text, options), + "answer": gold_letter, + "info": { + "is_mcq": True, + "answer_text": options.get(gold_letter, answer_text), + "dataset_name": ex.get("dataset_name", ""), + **({} if not shuffle_answers else {"options": options}), + }, + } + else: + # Open-ended question + return { + "question": OPEN_QUESTION_TEMPLATE.format(question=question_text), + "answer": answer_text, + "info": { + "is_mcq": False, + "dataset_name": ex.get("dataset_name", ""), + "question_raw": question_text, + }, + } + + load_from_cache_file = not shuffle_answers + eval_dataset = ds.map( + _map, + with_indices=True, + remove_columns=ds.column_names, + load_from_cache_file=load_from_cache_file, + ) + + # Set up parser based on answer format + answer_format = AnswerFormat(answer_format) if isinstance(answer_format, str) else answer_format + if answer_format == AnswerFormat.XML: + final_system_prompt = system_prompt or (THINK_XML_SYSTEM_PROMPT if use_think else XML_SYSTEM_PROMPT) + parser_fields = ["think", "answer"] if use_think else ["answer"] + parser = XMLParser(fields=parser_fields, answer_field="answer") + elif answer_format == AnswerFormat.BOXED: + parser = vf.ThinkParser(extract_boxed_answer) if use_think else vf.Parser(extract_boxed_answer) + final_system_prompt = system_prompt or (THINK_BOXED_SYSTEM_PROMPT if use_think else BOXED_SYSTEM_PROMPT) + else: + raise ValueError(f"Unsupported answer format: {answer_format=}") + + async def medreason_reward_func( + completion, + answer, + info: Info, + state: State, + **kwargs, + ) -> float: + """Unified reward: accuracy for MCQ, LLM judge for open-ended.""" + is_mcq = info.get("is_mcq", False) + + if is_mcq: + parsed_answer = parser.parse_answer(completion) or "" + answer_text_val = info.get("answer_text", None) + is_correct = multiple_choice_accuracy( + llm_answer=parsed_answer, + answer_letter=answer, + answer_text=answer_text_val, + ) + return 1.0 if is_correct else 0.0 + else: + # Open-ended: use LLM judge + parsed = parser.parse(completion, last=True) + model_answer = getattr(parsed, "answer", None) + + if model_answer is not None: + question_raw = info.get("question_raw", "") + judge_prompt = 
+                judge_prompt = JUDGE_TEMPLATE.format(
+                    question=question_raw,
+                    answer=answer,
+                    response=model_answer,
+                )
+                judge_response = await judge_rubric.judge(judge_prompt, model_answer, answer, state)
+                judge_response_clean = judge_response.strip().lower()
+            else:
+                judge_response_clean = "no"
+                judge_response = "no answer"
+
+            info.setdefault("judge_feedback", []).append(
+                {
+                    "parsed": judge_response_clean,
+                    "raw_judge": str(judge_response),
+                }
+            )
+
+            if "yes" in judge_response_clean and "no" not in judge_response_clean:
+                return 1.0
+            else:
+                return 0.0
+
+    judge_rubric.add_reward_func(medreason_reward_func, weight=1.0)
+
+    return vf.SingleTurnEnv(
+        eval_dataset=eval_dataset,
+        system_prompt=final_system_prompt,
+        parser=parser,
+        rubric=judge_rubric,
+    )
diff --git a/environments/medreason/pyproject.toml b/environments/medreason/pyproject.toml
new file mode 100644
index 0000000..140c8dd
--- /dev/null
+++ b/environments/medreason/pyproject.toml
@@ -0,0 +1,27 @@
+[project]
+name = "medreason"
+version = "0.1.0"
+description = "MedReason medical reasoning evaluation with mixed MCQ and open-ended QA"
+readme = "README.md"
+requires-python = ">=3.11"
+dependencies = [
+    "datasets>=4.0.0",
+    "verifiers>=0.1.2.post0",
+    "medarc_verifiers>=0.1.0",
+    "openai",
+]
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.hatch.build]
+include = ["medreason.py"]
+
+[tool.uv.sources]
+medarc_verifiers = { git = "https://github.com/MedARC-AI/med-lm-envs" }
+
+[tool.prime.environment]
+loader = "medreason:load_environment"
+display_name = "MedReason"
+visibility = "PUBLIC"
diff --git a/tests/test_environment_packages.py b/tests/test_environment_packages.py
new file mode 100644
index 0000000..2f5c441
--- /dev/null
+++ b/tests/test_environment_packages.py
@@ -0,0 +1,170 @@
+"""Validate structural consistency of all environment packages.
+
+These tests auto-discover every environment under ``environments/`` and verify
+that each package has the required files, valid pyproject.toml configuration,
+and a discoverable ``load_environment`` loader function. All checks are
+offline — no network calls, API keys, or dataset downloads are needed.
+"""
+
+from __future__ import annotations
+
+import ast
+from pathlib import Path
+from typing import Any
+
+import pytest
+
+REPO_ROOT = Path(__file__).resolve().parent.parent
+ENVIRONMENTS_DIR = REPO_ROOT / "environments"
+
+
+def _discover_envs() -> list[str]:
+    """Return sorted list of environment directory names."""
+    if not ENVIRONMENTS_DIR.is_dir():
+        return []
+    return sorted(
+        d.name
+        for d in ENVIRONMENTS_DIR.iterdir()
+        if d.is_dir() and not d.name.startswith((".", "_"))
+    )
+
+
+def _load_toml(path: Path) -> dict[str, Any]:
+    """Load a TOML file, using tomllib (3.11+) or tomli as fallback."""
+    try:
+        import tomllib
+    except ModuleNotFoundError:
+        import tomli as tomllib  # type: ignore[no-redef]
+
+    with open(path, "rb") as f:
+        return tomllib.load(f)
+
+
+def _find_load_environment(env_dir: Path) -> bool:
+    """Search for a ``load_environment`` function in any .py file.
+
+    Handles both single-module envs (``env.py``) and sub-package envs
+    (``env/env/__init__.py`` or ``env/env/module.py``).
+ """ + py_files: list[Path] = list(env_dir.rglob("*.py")) + for py_file in py_files: + try: + source = py_file.read_text(encoding="utf-8") + tree = ast.parse(source) + except (SyntaxError, UnicodeDecodeError): + continue + for node in ast.walk(tree): + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + if node.name == "load_environment": + return True + return False + + +ENV_NAMES = _discover_envs() + +# Environments that are known to be missing [tool.prime.environment] metadata. +# These are pre-existing upstream issues and are marked as expected failures. +_KNOWN_MISSING_LOADER_METADATA: set[str] = { + "healthbench", + "med_dialog", + "medagentbench", + "medcasereasoning", + "mtsamples_procedures", + "mtsamples_replicate", + "pubmedqa", +} + + +# --------------------------------------------------------------------------- +# 1. pyproject.toml exists and is parseable +# --------------------------------------------------------------------------- +@pytest.mark.parametrize("env_name", ENV_NAMES) +def test_pyproject_exists_and_parses(env_name: str) -> None: + """Every environment must have a parseable pyproject.toml.""" + toml_path = ENVIRONMENTS_DIR / env_name / "pyproject.toml" + assert toml_path.exists(), f"{env_name}: missing pyproject.toml" + data = _load_toml(toml_path) + assert "project" in data, f"{env_name}: pyproject.toml missing [project] table" + + +# --------------------------------------------------------------------------- +# 2. [project] has required fields +# --------------------------------------------------------------------------- +@pytest.mark.parametrize("env_name", ENV_NAMES) +def test_project_has_name_and_version(env_name: str) -> None: + """[project] must declare name and version.""" + data = _load_toml(ENVIRONMENTS_DIR / env_name / "pyproject.toml") + project = data.get("project", {}) + assert "name" in project, f"{env_name}: [project] missing 'name'" + assert "version" in project, f"{env_name}: [project] missing 'version'" + + +# --------------------------------------------------------------------------- +# 3. Build system is configured +# --------------------------------------------------------------------------- +@pytest.mark.parametrize("env_name", ENV_NAMES) +def test_build_system_configured(env_name: str) -> None: + """pyproject.toml must have [build-system].""" + data = _load_toml(ENVIRONMENTS_DIR / env_name / "pyproject.toml") + assert "build-system" in data, f"{env_name}: missing [build-system]" + + +# --------------------------------------------------------------------------- +# 4. Loader is discoverable via [tool.prime.environment] or entry-points +# --------------------------------------------------------------------------- +@pytest.mark.parametrize("env_name", ENV_NAMES) +def test_loader_discoverable(env_name: str) -> None: + """Environment loader must be declared in pyproject.toml. 
+    """
+    if env_name in _KNOWN_MISSING_LOADER_METADATA:
+        pytest.xfail(f"{env_name}: known to be missing loader metadata (upstream issue)")
+    data = _load_toml(ENVIRONMENTS_DIR / env_name / "pyproject.toml")
+
+    has_prime = (
+        "tool" in data
+        and "prime" in data.get("tool", {})
+        and "environment" in data["tool"]["prime"]
+        and "loader" in data["tool"]["prime"]["environment"]
+    )
+
+    has_entry_points = (
+        "project" in data
+        and "entry-points" in data.get("project", {})
+        and "verifiers.environments" in data["project"]["entry-points"]
+    )
+
+    assert has_prime or has_entry_points, (
+        f"{env_name}: no loader discoverable. Add [tool.prime.environment] "
+        f"with 'loader' key, or [project.entry-points.\"verifiers.environments\"]"
+    )
+
+
+# ---------------------------------------------------------------------------
+# 5. A load_environment function exists somewhere in the package
+# ---------------------------------------------------------------------------
+@pytest.mark.parametrize("env_name", ENV_NAMES)
+def test_load_environment_exists(env_name: str) -> None:
+    """The environment package must contain a load_environment function."""
+    env_dir = ENVIRONMENTS_DIR / env_name
+    assert _find_load_environment(env_dir), (
+        f"{env_name}: no load_environment function found in any .py file"
+    )
+
+
+# ---------------------------------------------------------------------------
+# 6. Dependencies include verifiers
+# ---------------------------------------------------------------------------
+@pytest.mark.parametrize("env_name", ENV_NAMES)
+def test_dependencies_include_verifiers(env_name: str) -> None:
+    """All environments must depend on verifiers."""
+    data = _load_toml(ENVIRONMENTS_DIR / env_name / "pyproject.toml")
+    deps = data.get("project", {}).get("dependencies", [])
+    dep_names = [
+        d.split(">")[0].split("<")[0].split("=")[0].split("[")[0].strip().lower()
+        for d in deps
+    ]
+    assert "verifiers" in dep_names, f"{env_name}: 'verifiers' not in dependencies"