diff --git a/.gitignore b/.gitignore index 65d1fa7e..12e31eef 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,17 @@ outputs/ +# Prime-RL (separate repo with trained models) +prime-rl/ + +# CSV data exports +*.csv + +# Generated plots +plots/ + +# Eval results +eval_results/ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[codz] diff --git a/configs/medcalc_bench_8b_nothink_rl.toml b/configs/medcalc_bench_8b_nothink_rl.toml new file mode 100644 index 00000000..e0ed7a8d --- /dev/null +++ b/configs/medcalc_bench_8b_nothink_rl.toml @@ -0,0 +1,72 @@ +inference_gpu_ids = [0, 1, 2, 3, 4, 5] +trainer_gpu_ids = [6, 7] + +max_steps = 300 +seq_len = 4096 + +[wandb] +project = "medcalc-bench-verified-8b" +name = "medcalc-bench-verified-8b-nothink" + +[ckpt] +# Checkpoint at the end of training + +[model] +name = "Qwen/Qwen3-8B-Instruct-2507" + +[orchestrator] +batch_size = 128 # Reduced for 8B model +rollouts_per_example = 16 + +[[orchestrator.env]] +id = "medcalc_bench" + +# Environment arguments for MedCalc-Bench +[orchestrator.env.args] +one_shot = false +add_python_tool = false +use_think = false # No thinking tags +answer_format = "xml" +use_verified_dataset = true # Use MedCalc-Bench-Verified + +[orchestrator.sampling] +max_tokens = 1024 +temperature = 1.0 + +# Evaluation configuration +[orchestrator.eval] +interval = 10 +eval_base_model = true + +[[orchestrator.eval.env]] +id = "medcalc_bench" +num_examples = 300 +rollouts_per_example = 1 + +[orchestrator.eval.env.args] +one_shot = false +add_python_tool = false +use_think = false # No thinking tags +answer_format = "xml" +use_verified_dataset = true # Use MedCalc-Bench-Verified + +[orchestrator.eval.sampling] +temperature = 0.0 +max_tokens = 1024 + +[trainer] +# Default trainer config (GRPO) + +[trainer.model.lora] +rank = 32 +alpha = 64.0 + +[trainer.optim] +lr = 1e-5 + +[inference] +# Inference server config + +[inference.parallel] +tp = 2 +dp = 3 diff --git a/configs/medcalc_bench_rl.toml 
b/configs/medcalc_bench_rl.toml new file mode 100644 index 00000000..59a00ff5 --- /dev/null +++ b/configs/medcalc_bench_rl.toml @@ -0,0 +1,73 @@ +# MedCalc-Bench RL Training Configuration +# Clinical calculator reasoning with numeric/date outputs + +inference_gpu_ids = [0, 1, 2, 3, 4, 5] +trainer_gpu_ids = [6, 7] + +max_steps = 300 +seq_len = 4096 + +[wandb] +project = "medcalc-bench-rl-1" +name = "medcalc-bench-rl-1" + +[ckpt] +# Checkpoint at the end of training + +[model] +name = "Qwen/Qwen3-4B-Instruct-2507" + +[orchestrator] +batch_size = 512 +rollouts_per_example = 16 + +[[orchestrator.env]] +id = "medcalc_bench" + +# Environment arguments for MedCalc-Bench +[orchestrator.env.args] +one_shot = false # Zero-shot prompting +add_python_tool = false # No code execution +use_think = true # Enable reasoning tags +answer_format = "xml" # Use XML format for answers + +[orchestrator.sampling] +max_tokens = 2048 +temperature = 1.0 + +# Evaluation configuration +[orchestrator.eval] +interval = 10 # Evaluate every 10 steps +eval_base_model = true # Also evaluate the base model at step 0 + +[[orchestrator.eval.env]] +id = "medcalc_bench" +num_examples = 300 +rollouts_per_example = 1 + +[orchestrator.eval.env.args] +one_shot = false +add_python_tool = false +use_think = true +answer_format = "xml" + +[orchestrator.eval.sampling] +temperature = 0.0 # Greedy sampling for evaluation +max_tokens = 2048 + +[trainer] +# Default trainer config (GRPO) + +[trainer.model.lora] +rank = 32 +alpha = 64.0 + +[trainer.optim] +lr = 1e-5 # LoRA typically benefits from higher LR + +[inference] +# Inference server config + +[inference.parallel] +tp = 2 # Split model across 2 GPUs for faster per-query latency +dp = 3 # 3 parallel instances diff --git a/configs/medcalc_bench_short_rl.toml b/configs/medcalc_bench_short_rl.toml new file mode 100644 index 00000000..93767f58 --- /dev/null +++ b/configs/medcalc_bench_short_rl.toml @@ -0,0 +1,72 @@ +inference_gpu_ids = [0, 1, 2, 3, 4, 5] 
+trainer_gpu_ids = [6, 7] + +max_steps = 300 +seq_len = 4096 + +[wandb] +project = "medcalc-bench-verified" +name = "medcalc-bench-verified-2048tok" + +[ckpt] +# Checkpoint at the end of training + +[model] +name = "Qwen/Qwen3-4B-Instruct-2507" + +[orchestrator] +batch_size = 256 +rollouts_per_example = 16 + +[[orchestrator.env]] +id = "medcalc_bench" + +# Environment arguments for MedCalc-Bench +[orchestrator.env.args] +one_shot = false +add_python_tool = false +use_think = true +answer_format = "xml" +use_verified_dataset = true # Use MedCalc-Bench-Verified + +[orchestrator.sampling] +max_tokens = 1024 +temperature = 1.0 + +# Evaluation configuration +[orchestrator.eval] +interval = 10 +eval_base_model = true + +[[orchestrator.eval.env]] +id = "medcalc_bench" +num_examples = 300 +rollouts_per_example = 1 + +[orchestrator.eval.env.args] +one_shot = false +add_python_tool = false +use_think = true +answer_format = "xml" +use_verified_dataset = true # Use MedCalc-Bench-Verified + +[orchestrator.eval.sampling] +temperature = 0.0 +max_tokens = 1024 + +[trainer] +# Default trainer config (GRPO) + +[trainer.model.lora] +rank = 32 +alpha = 64.0 + +[trainer.optim] +lr = 1e-5 + +[inference] +# Inference server config + +[inference.parallel] +tp = 2 +dp = 3 diff --git a/configs/medcalc_bench_tools_rl.toml b/configs/medcalc_bench_tools_rl.toml new file mode 100644 index 00000000..63fffe13 --- /dev/null +++ b/configs/medcalc_bench_tools_rl.toml @@ -0,0 +1,77 @@ +# MedCalc-Bench RL Training Configuration (with Python Tool) +# Clinical calculator reasoning with code execution + +inference_gpu_ids = [0, 1, 2, 3, 4, 5] +trainer_gpu_ids = [6, 7] + +max_steps = 300 +seq_len = 8192 # Longer for multi-turn tool use + +[wandb] +project = "medcalc-bench-tools" +name = "medcalc-bench-tools-rl" + +[ckpt] +# Checkpoint at the end of training + +[model] +name = "Qwen/Qwen3-4B-Instruct-2507" + +[orchestrator] +batch_size = 256 # Smaller batch for multi-turn (more tokens per example) 
+rollouts_per_example = 8 + +[[orchestrator.env]] +id = "medcalc_bench" + +# Environment arguments for MedCalc-Bench with tools +[orchestrator.env.args] +one_shot = false +add_python_tool = true # Enable Python tool for calculations +use_think = true +answer_format = "xml" +max_turns = 10 # Allow multiple tool calls +use_verified_dataset = true # Use MedCalc-Bench-Verified + +[orchestrator.sampling] +max_tokens = 2048 +temperature = 1.0 + +# Evaluation configuration +[orchestrator.eval] +interval = 10 +eval_base_model = true + +[[orchestrator.eval.env]] +id = "medcalc_bench" +num_examples = 300 +rollouts_per_example = 1 + +[orchestrator.eval.env.args] +one_shot = false +add_python_tool = true # Tools enabled for eval too +use_think = true +answer_format = "xml" +max_turns = 10 +use_verified_dataset = true # Use MedCalc-Bench-Verified + +[orchestrator.eval.sampling] +temperature = 0.0 +max_tokens = 2048 + +[trainer] +# Default trainer config (GRPO) + +[trainer.model.lora] +rank = 32 +alpha = 64.0 + +[trainer.optim] +lr = 1e-5 + +[inference] +# Inference server config + +[inference.parallel] +tp = 2 +dp = 3 diff --git a/configs/medcasereasoning_rl.toml b/configs/medcasereasoning_rl.toml new file mode 100644 index 00000000..6be8f20b --- /dev/null +++ b/configs/medcasereasoning_rl.toml @@ -0,0 +1,72 @@ +# MedCaseReasoning RL Training Configuration +# Medical diagnosis from case presentations using LLM-as-a-Judge + +inference_gpu_ids = [0, 1, 2, 3, 4, 5] +trainer_gpu_ids = [6, 7] + +max_steps = 300 +seq_len = 2048 # Prompts are shorter (max ~1034 tokens) + +[wandb] +project = "medcasereasoning-rl" +name = "medcasereasoning-rl" + +[ckpt] +# Checkpoint at the end of training + +[model] +name = "Qwen/Qwen3-4B-Instruct-2507" + +[orchestrator] +batch_size = 512 +rollouts_per_example = 16 + +[[orchestrator.env]] +id = "medcasereasoning" + +# Environment arguments - uses LLM-as-a-Judge +[orchestrator.env.args] +judge_model = "gpt-5-nano" +reasoning_effort = "low" +# 
judge_base_url = "http://localhost:8001/v1" # Uncomment for local judge +# judge_api_key = "your-api-key" # Optional, uses OPENAI_API_KEY by default + +[orchestrator.sampling] +max_tokens = 1024 +temperature = 1.0 + +# Evaluation configuration +[orchestrator.eval] +interval = 10 +eval_base_model = true + +[[orchestrator.eval.env]] +id = "medcasereasoning" +num_examples = 100 # Smaller eval set (LLM judge is slow/expensive) +rollouts_per_example = 1 + +[orchestrator.eval.env.args] +judge_model = "gpt-5-nano" +reasoning_effort = "low" +eval_split = "test" # Use test set for evaluation + +[orchestrator.eval.sampling] +temperature = 0.0 +max_tokens = 1024 + +[trainer] +# Default trainer config (GRPO) + +[trainer.model.lora] +rank = 32 +alpha = 64.0 + +[trainer.optim] +lr = 1e-5 + +[inference] +# Inference server config + +[inference.parallel] +tp = 2 +dp = 3 diff --git a/configs/medical_combined_rl.toml b/configs/medical_combined_rl.toml new file mode 100644 index 00000000..1b03a988 --- /dev/null +++ b/configs/medical_combined_rl.toml @@ -0,0 +1,135 @@ +# Combined Medical RL Training Configuration +# MedMCQA + MedCalc-Bench + MedCaseReasoning + +inference_gpu_ids = [0, 1, 2, 3, 4, 5] +trainer_gpu_ids = [6, 7] + +max_steps = 500 +seq_len = 4096 # Accommodate longest prompts (MedCalc can be ~2800 tokens) + +[wandb] +project = "medical-combined-rl" +name = "medical-combined-rl" + +[ckpt] +# Checkpoint at the end of training + +[model] +name = "Qwen/Qwen3-4B-Instruct-2507" + +[orchestrator] +batch_size = 512 +rollouts_per_example = 16 + +# Buffer config - sample evenly from each environment with difficulty filtering +[orchestrator.buffer] +env_ratios = [1.0, 1.0, 1.0] # Equal sampling: MedMCQA, MedCalc, MedCaseReasoning + +# Difficulty-based filtering +easy_threshold = 1.0 # Problems with avg reward >= 1.0 are "easy" (always correct) +hard_threshold = 0.0 # Problems with avg reward <= 0.0 are "hard" (always wrong) +online_difficulty_filtering = true # Filter out rollouts 
with 0% or 100% success rate + +# When resuming, convert some easy/hard back to normal pool +easy_fraction = 0.1 # 10% of easy problems back to normal +hard_fraction = 0.1 # 10% of hard problems back to normal + +# Environment 1: MedMCQA (182k examples, MCQ) +[[orchestrator.env]] +id = "med_mcqa" +name = "medmcqa" + +[orchestrator.env.args] +use_think = true +answer_format = "xml" +shuffle_answers = false + +# Environment 2: MedCalc-Bench (10k examples, calculations) +[[orchestrator.env]] +id = "medcalc_bench" +name = "medcalc" + +[orchestrator.env.args] +one_shot = false +add_python_tool = false +use_think = true +answer_format = "xml" +use_verified_dataset = true + +# Environment 3: MedCaseReasoning (13k examples, diagnosis with LLM judge) +[[orchestrator.env]] +id = "medcasereasoning" +name = "medcase" + +[orchestrator.env.args] +judge_model = "gpt-5-nano" +reasoning_effort = "low" + +[orchestrator.sampling] +max_tokens = 2048 # Accommodate reasoning for all tasks +temperature = 1.0 + +# Evaluation configuration +[orchestrator.eval] +interval = 25 # Less frequent eval since we have 3 envs +eval_base_model = true + +# Eval: MedMCQA +[[orchestrator.eval.env]] +id = "med_mcqa" +name = "medmcqa" +num_examples = 300 +rollouts_per_example = 1 + +[orchestrator.eval.env.args] +use_think = true +answer_format = "xml" +shuffle_answers = false +eval_split = "validation" + +# Eval: MedCalc-Bench +[[orchestrator.eval.env]] +id = "medcalc_bench" +name = "medcalc" +num_examples = 300 +rollouts_per_example = 1 + +[orchestrator.eval.env.args] +one_shot = false +add_python_tool = false +use_think = true +answer_format = "xml" +use_verified_dataset = true + +# Eval: MedCaseReasoning +[[orchestrator.eval.env]] +id = "medcasereasoning" +name = "medcase" +num_examples = 100 # Smaller due to LLM judge cost +rollouts_per_example = 1 + +[orchestrator.eval.env.args] +judge_model = "gpt-5-nano" +reasoning_effort = "low" +eval_split = "test" + +[orchestrator.eval.sampling] +temperature = 
0.0 +max_tokens = 2048 + +[trainer] +# Default trainer config (GRPO) + +[trainer.model.lora] +rank = 32 +alpha = 64.0 + +[trainer.optim] +lr = 1e-5 + +[inference] +# Inference server config + +[inference.parallel] +tp = 2 +dp = 3 diff --git a/configs/medmcqa_rl.toml b/configs/medmcqa_rl.toml new file mode 100644 index 00000000..b7caee85 --- /dev/null +++ b/configs/medmcqa_rl.toml @@ -0,0 +1,72 @@ +# MedMCQA RL Training Configuration +# Medical multiple-choice QA (AIIMS & NEET PG entrance exams) + +inference_gpu_ids = [0, 1, 2, 3, 4, 5] +trainer_gpu_ids = [6, 7] + +max_steps = 300 +seq_len = 1024 + +[wandb] +project = "medmcqa-rl" +name = "medmcqa-rl" + +[ckpt] +# Checkpoint at the end of training + +[model] +name = "Qwen/Qwen3-4B-Instruct-2507" + +[orchestrator] +batch_size = 512 +rollouts_per_example = 16 + +[[orchestrator.env]] +id = "med_mcqa" + +# Environment arguments +[orchestrator.env.args] +use_think = true +answer_format = "xml" +shuffle_answers = false # Keep original answer order + +[orchestrator.sampling] +max_tokens = 512 # Short answers (A/B/C/D with reasoning) +temperature = 1.0 + +# Evaluation configuration +[orchestrator.eval] +interval = 10 +eval_base_model = true + +[[orchestrator.eval.env]] +id = "med_mcqa" +num_examples = 500 # Sample from 4k validation set +rollouts_per_example = 1 + +[orchestrator.eval.env.args] +use_think = true +answer_format = "xml" +shuffle_answers = false +eval_split = "validation" # Test set has no labels, must use validation + +[orchestrator.eval.sampling] +temperature = 0.0 +max_tokens = 512 + +[trainer] +# Default trainer config (GRPO) + +[trainer.model.lora] +rank = 32 +alpha = 64.0 + +[trainer.optim] +lr = 1e-5 + +[inference] +# Inference server config + +[inference.parallel] +tp = 2 +dp = 3 diff --git a/environments/med_mcqa/med_mcqa.py b/environments/med_mcqa/med_mcqa.py index 28392a3c..6d40b2e3 100644 --- a/environments/med_mcqa/med_mcqa.py +++ b/environments/med_mcqa/med_mcqa.py @@ -60,14 +60,18 @@ def 
load_environment( shuffle_answers: bool = False, shuffle_seed: int | None = 1618, answer_format: AnswerFormat | str = AnswerFormat.XML, + eval_split: str = "validation", # "validation" or "test" ) -> vf.Environment: """ - Load the MedMCQA environment with train and validation splits. + Load the MedMCQA environment with train and eval splits. Supports reasoning (use_think=True) or standard evaluation. Returns a SingleTurnEnv ready for model evaluation. + + Args: + eval_split: Which split to use for evaluation ("validation" or "test") """ train_ds = load_dataset("lighteval/med_mcqa", split="train") - val_ds = load_dataset("lighteval/med_mcqa", split="validation") + eval_ds = load_dataset("lighteval/med_mcqa", split=eval_split) def _map_example(example: dict[str, Any]) -> dict[str, Any] | None: cop = example.get("cop", -1) @@ -113,7 +117,7 @@ def _map_example(example: dict[str, Any]) -> dict[str, Any] | None: remove_columns=columns_to_remove, load_from_cache_file=load_from_cache_file, ).filter(lambda x: x is not None, load_from_cache_file=load_from_cache_file) - val_mapped = val_ds.map( + eval_mapped = eval_ds.map( _map_example, remove_columns=columns_to_remove, load_from_cache_file=load_from_cache_file, @@ -142,7 +146,7 @@ def accuracy(completion: Any, answer: str, parser: vf.Parser, info: dict[str, An env = vf.SingleTurnEnv( dataset=train_mapped, - eval_dataset=val_mapped, + eval_dataset=eval_mapped, system_prompt=system_prompt, parser=parser, rubric=rubric, diff --git a/environments/medcalc_bench/medcalc_bench/medcalc_bench.py b/environments/medcalc_bench/medcalc_bench/medcalc_bench.py index 8959da60..2b2ef37e 100644 --- a/environments/medcalc_bench/medcalc_bench/medcalc_bench.py +++ b/environments/medcalc_bench/medcalc_bench/medcalc_bench.py @@ -297,6 +297,7 @@ def load_environment( answer_format: AnswerFormat | str = AnswerFormat.XML, use_think: bool = False, system_prompt: str | None = None, + use_verified_dataset: bool = False, # Use 
nsk7153/MedCalc-Bench-Verified instead **kwargs, ) -> vf.Environment: # -------- normalize answer_format -------- @@ -322,7 +323,10 @@ def load_environment( system_prompt = system_prompt # -------- load dataset and convert to vf format -------- - ds = load_dataset("ncbi/MedCalc-Bench-v1.2") + if use_verified_dataset: + ds = load_dataset("nsk7153/MedCalc-Bench-Verified") + else: + ds = load_dataset("ncbi/MedCalc-Bench-v1.2") one_shot_examples = None if one_shot: # create mapping from calc id to one-shot example @@ -377,9 +381,10 @@ def _map(row: dict): use_calculator=add_calculator_tool, **kwargs, ) - # Add ToolRubric to track tool usage metrics - tool_rubric = vf.ToolRubric(tools=env.tools) - env.rubric = vf.RubricGroup(rubrics=[tool_rubric, env.rubric]) + # Add ToolRubric to track tool usage metrics (if available) + if hasattr(vf, "ToolRubric"): + tool_rubric = vf.ToolRubric(tools=env.tools) + env.rubric = vf.RubricGroup(rubrics=[tool_rubric, env.rubric]) return env else: return vf.SingleTurnEnv( diff --git a/environments/medcasereasoning/medcasereasoning.py b/environments/medcasereasoning/medcasereasoning.py index 471c1615..2713f745 100644 --- a/environments/medcasereasoning/medcasereasoning.py +++ b/environments/medcasereasoning/medcasereasoning.py @@ -43,6 +43,8 @@ def load_environment( judge_model: str = "gpt-4o-mini", judge_base_url: str | None = None, judge_api_key: str | None = None, + reasoning_effort: str | None = None, # "low", "medium", "high" for reasoning models + eval_split: str = "val", # "val" or "test" ) -> vf.Environment: """ MedCaseReasoning environment using LLM-as-a-Judge evaluation. @@ -50,11 +52,18 @@ def load_environment( This environment loads the MedCaseReasoning dataset and uses an LLM judge to evaluate whether model responses are equivalent to the ground truth medical diagnoses. 
+ + Args: + judge_model: Model to use for judging (e.g., "gpt-4o-mini", "gpt-5-nano") + judge_base_url: Optional base URL for judge API + judge_api_key: Optional API key for judge + reasoning_effort: Reasoning effort for reasoning models ("low", "medium", "high") + eval_split: Which split to use for evaluation ("val" or "test") """ # Load the MedCaseReasoning dataset full_dataset = load_dataset("zou-lab/MedCaseReasoning") - # Use train split for training, val split for evaluation + # Use train split for training train_dataset = full_dataset["train"].map( lambda x: { "question": QUESTION_TEMPLATE.format(question=x["case_prompt"]), @@ -64,7 +73,7 @@ def load_environment( } ) - eval_dataset = full_dataset["val"].map( + eval_dataset = full_dataset[eval_split].map( lambda x: { "question": QUESTION_TEMPLATE.format(question=x["case_prompt"]), "answer": x["final_diagnosis"], @@ -73,11 +82,11 @@ def load_environment( } ) - # System prompt for the task - # Initialize OpenAI client for judge api_key = default_judge_api_key(judge_base_url) if judge_api_key is None else judge_api_key - sampling_args, default_headers = judge_sampling_args_and_headers(judge_model, judge_base_url) + sampling_args, default_headers = judge_sampling_args_and_headers( + judge_model, judge_base_url, reasoning_effort=reasoning_effort + ) judge_client = AsyncOpenAI(base_url=judge_base_url, api_key=api_key, default_headers=default_headers) # Create JudgeRubric with custom prompt diff --git a/scripts/evaluate_models.py b/scripts/evaluate_models.py new file mode 100644 index 00000000..ebe63e22 --- /dev/null +++ b/scripts/evaluate_models.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python3 +""" +Evaluate trained RL models on test sets using the verifiers library. 
+ +Usage: + # Start vLLM server first (in another terminal): + cd ~/med-lm-envs/prime-rl + CUDA_VISIBLE_DEVICES=0,1 uv run vllm serve --tensor-parallel-size 2 --port 8000 + + # Then run evaluation: + cd ~/med-lm-envs + uv run python scripts/evaluate_models.py --api-base http://localhost:8000/v1 --dataset med_mcqa + + # For MedCaseReasoning (needs judge): + export OPENAI_API_KEY= + uv run python scripts/evaluate_models.py --api-base http://localhost:8000/v1 --dataset medcasereasoning +""" + +import argparse +import asyncio +import json +import sys +from pathlib import Path + +# Add environments to path +sys.path.insert(0, "/admin/home/nikhil/med-lm-envs/environments/medcalc_bench") +sys.path.insert(0, "/admin/home/nikhil/med-lm-envs/environments/med_mcqa") +sys.path.insert(0, "/admin/home/nikhil/med-lm-envs/environments/medcasereasoning") + +import verifiers as vf +from openai import AsyncOpenAI + +# Model paths +MODELS = { + "base": "Qwen/Qwen3-4B-Instruct-2507", + "medcalc_trained": "/admin/home/nikhil/med-lm-envs/prime-rl/outputs_verified_short/weights/step_300", + "medmcqa_trained": "/admin/home/nikhil/med-lm-envs/prime-rl/outputs_medmcqa/weights/step_300", + "medcase_trained": "/admin/home/nikhil/med-lm-envs/prime-rl/outputs_medcasereasoning/weights/step_300", +} + +OUTPUT_DIR = Path("/admin/home/nikhil/med-lm-envs/eval_results") + + +def load_medcalc_env(): + """Load MedCalc-Bench-Verified environment.""" + from medcalc_bench.medcalc_bench import load_environment + + env = load_environment( + use_think=True, + use_verified_dataset=True, + answer_format="xml", + ) + return env + + +def load_medmcqa_env(): + """Load MedMCQA environment.""" + from med_mcqa import load_environment + + env = load_environment( + use_think=True, + answer_format="xml", + ) + return env + + +def load_medcase_env(judge_api_key: str = None): + """Load MedCaseReasoning environment.""" + import sys + sys.path.insert(0, "/admin/home/nikhil/med-lm-envs/environments/medcasereasoning") + from 
medcasereasoning import load_environment + + env = load_environment( + judge_model="gpt-5-nano", + reasoning_effort="low", + judge_api_key=judge_api_key, + ) + return env + + +async def evaluate_on_env( + env: vf.SingleTurnEnv, + client: AsyncOpenAI, + model_name: str, + n_samples: int = 100, + temperature: float = 0.0, + max_tokens: int = 1024, + max_concurrent: int = 10, +) -> dict: + """ + Evaluate a model on a verifiers environment. + + Uses the environment's built-in evaluate() method which handles: + - Prompt formatting with system prompt + - Response parsing + - Reward calculation via rubric + """ + # Configure the environment + env.set_kwargs( + client=client, + model=model_name, + temperature=temperature, + max_tokens=max_tokens, + ) + + # Get eval dataset + eval_dataset = env.get_eval_dataset() + if eval_dataset is None: + raise ValueError("Environment has no eval dataset") + + # Limit samples if needed + if n_samples > 0 and n_samples < len(eval_dataset): + eval_dataset = eval_dataset.select(range(n_samples)) + + print(f" Evaluating on {len(eval_dataset)} samples...") + + # Run evaluation using the environment's evaluate method + results = await env.evaluate( + eval_dataset, + max_concurrent=max_concurrent, + rollouts_per_example=1, + ) + + # Calculate accuracy from results + rewards = [r.reward for r in results if r.reward is not None] + accuracy = sum(1 for r in rewards if r > 0.5) / len(rewards) if rewards else 0 + mean_reward = sum(rewards) / len(rewards) if rewards else 0 + + return { + "n_samples": len(eval_dataset), + "n_completed": len(rewards), + "accuracy": accuracy, + "mean_reward": mean_reward, + } + + +async def main(): + parser = argparse.ArgumentParser(description="Evaluate trained models using verifiers") + parser.add_argument("--api-base", type=str, default="http://localhost:8000/v1", + help="API base URL for vLLM server") + parser.add_argument("--api-key", type=str, default="EMPTY", + help="API key (use EMPTY for local vLLM)") + 
parser.add_argument("--model-name", type=str, default="model", + help="Model name to use in API calls") + parser.add_argument("--dataset", type=str, required=True, + choices=["medcalc_bench", "med_mcqa", "medcasereasoning", "all"], + help="Dataset to evaluate on") + parser.add_argument("--n-samples", type=int, default=100, + help="Number of samples to evaluate") + parser.add_argument("--max-concurrent", type=int, default=10, + help="Maximum concurrent API requests") + parser.add_argument("--temperature", type=float, default=0.0, + help="Sampling temperature") + parser.add_argument("--max-tokens", type=int, default=1024, + help="Maximum tokens to generate") + parser.add_argument("--judge-api-key", type=str, default=None, + help="OpenAI API key for MedCaseReasoning judge") + parser.add_argument("--output-name", type=str, default=None, + help="Name for output file (default: dataset name)") + args = parser.parse_args() + + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + + # Create client + client = AsyncOpenAI(base_url=args.api_base, api_key=args.api_key) + + # Define dataset loaders + loaders = { + "medcalc_bench": load_medcalc_env, + "med_mcqa": load_medmcqa_env, + "medcasereasoning": lambda: load_medcase_env(args.judge_api_key), + } + + if args.dataset == "all": + datasets_to_eval = list(loaders.keys()) + else: + datasets_to_eval = [args.dataset] + + results = {} + + for dataset_name in datasets_to_eval: + print(f"\n{'='*60}") + print(f"Evaluating on {dataset_name}") + print(f"{'='*60}") + + try: + env = loaders[dataset_name]() + result = await evaluate_on_env( + env=env, + client=client, + model_name=args.model_name, + n_samples=args.n_samples, + temperature=args.temperature, + max_tokens=args.max_tokens, + max_concurrent=args.max_concurrent, + ) + results[dataset_name] = result + print(f" Accuracy: {result['accuracy']:.1%}") + print(f" Mean Reward: {result['mean_reward']:.3f}") + except Exception as e: + print(f" Error: {e}") + import traceback + 
traceback.print_exc() + results[dataset_name] = {"error": str(e)} + + # Save results + output_name = args.output_name or args.dataset + output_file = OUTPUT_DIR / f"{output_name}_results.json" + with open(output_file, "w") as f: + json.dump(results, f, indent=2) + print(f"\nResults saved to {output_file}") + + # Print summary + print("\n" + "="*60) + print("SUMMARY") + print("="*60) + for dataset, result in results.items(): + if "error" in result: + print(f" {dataset}: ERROR - {result['error']}") + else: + print(f" {dataset}: {result['accuracy']:.1%} accuracy ({result['n_completed']}/{result['n_samples']} samples)") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/scripts/evaluate_models.sh b/scripts/evaluate_models.sh new file mode 100755 index 00000000..dcf9cea2 --- /dev/null +++ b/scripts/evaluate_models.sh @@ -0,0 +1,126 @@ +#!/bin/bash +# Evaluate trained RL models on test sets using medarc-eval +# +# Usage: +# ./evaluate_models.sh # Evaluate all models on all datasets +# ./evaluate_models.sh --model medcalc # Evaluate specific model +# ./evaluate_models.sh --dataset medmcqa # Evaluate on specific dataset +# +# Prerequisites: +# 1. Start vLLM server for the model you want to evaluate: +# cd ~/med-lm-envs/prime-rl +# CUDA_VISIBLE_DEVICES=0,1 uv run vllm serve --tensor-parallel-size 2 --port 8000 +# +# 2. 
Run this script with the API pointing to the vLLM server + +set -e + +# Configuration +N_SAMPLES=100 +MAX_CONCURRENT=10 +TEMPERATURE=0.0 +MAX_TOKENS=1024 +API_BASE="http://localhost:8000/v1" + +# Model paths +BASE_MODEL="Qwen/Qwen3-4B-Instruct-2507" +MEDCALC_MODEL="/admin/home/nikhil/med-lm-envs/prime-rl/outputs_verified_short/weights/step_300" +MEDMCQA_MODEL="/admin/home/nikhil/med-lm-envs/prime-rl/outputs_medmcqa/weights/step_300" +MEDCASE_MODEL="/admin/home/nikhil/med-lm-envs/prime-rl/outputs_medcasereasoning/weights/step_300" + +# Output directory +OUTPUT_DIR="/admin/home/nikhil/med-lm-envs/eval_results" +mkdir -p "$OUTPUT_DIR" + +echo "==========================================" +echo "Model Evaluation Script" +echo "==========================================" +echo "Samples per dataset: $N_SAMPLES" +echo "Output directory: $OUTPUT_DIR" +echo "" + +# Function to run evaluation +run_eval() { + local model_name=$1 + local dataset=$2 + local extra_args=$3 + + echo ">>> Evaluating $model_name on $dataset" + + output_file="$OUTPUT_DIR/${model_name}_${dataset}.json" + + cd /admin/home/nikhil/med-lm-envs + uv run medarc-eval "$dataset" \ + -m "openai/model" \ + -b "$API_BASE" \ + -n "$N_SAMPLES" \ + --max-concurrent "$MAX_CONCURRENT" \ + --temperature "$TEMPERATURE" \ + --max-tokens "$MAX_TOKENS" \ + $extra_args \ + 2>&1 | tee "$OUTPUT_DIR/${model_name}_${dataset}.log" + + echo ">>> Finished $model_name on $dataset" + echo "" +} + +# Print instructions +cat << 'EOF' +======================================== +INSTRUCTIONS +======================================== + +To evaluate a model, you need to: + +1. 
Start a vLLM server for the model: + + # For base model: + cd ~/med-lm-envs/prime-rl + CUDA_VISIBLE_DEVICES=0,1 uv run vllm serve Qwen/Qwen3-4B-Instruct-2507 \ + --tensor-parallel-size 2 --port 8000 + + # For MedCalc-trained model: + CUDA_VISIBLE_DEVICES=0,1 uv run vllm serve \ + /admin/home/nikhil/med-lm-envs/prime-rl/outputs_verified_short/weights/step_300 \ + --tensor-parallel-size 2 --port 8000 + + # For MedMCQA-trained model: + CUDA_VISIBLE_DEVICES=0,1 uv run vllm serve \ + /admin/home/nikhil/med-lm-envs/prime-rl/outputs_medmcqa/weights/step_300 \ + --tensor-parallel-size 2 --port 8000 + + # For MedCaseReasoning-trained model: + CUDA_VISIBLE_DEVICES=0,1 uv run vllm serve \ + /admin/home/nikhil/med-lm-envs/prime-rl/outputs_medcasereasoning/weights/step_300 \ + --tensor-parallel-size 2 --port 8000 + +2. In another terminal, run the evaluation: + + # MedCalc-Bench-Verified + cd ~/med-lm-envs + uv run medarc-eval medcalc_bench \ + -m openai/model \ + -b http://localhost:8000/v1 \ + -n 100 \ + --temperature 0.0 \ + --use-think \ + --use-verified-dataset + + # MedMCQA + uv run medarc-eval med_mcqa \ + -m openai/model \ + -b http://localhost:8000/v1 \ + -n 100 \ + --temperature 0.0 \ + --use-think + + # MedCaseReasoning (requires judge API key) + export OPENAI_API_KEY= + uv run medarc-eval medcasereasoning \ + -m openai/model \ + -b http://localhost:8000/v1 \ + -n 100 \ + --temperature 0.0 + +======================================== +EOF diff --git a/scripts/plot_combined_7panel.py b/scripts/plot_combined_7panel.py new file mode 100644 index 00000000..cc997248 --- /dev/null +++ b/scripts/plot_combined_7panel.py @@ -0,0 +1,199 @@ +#!/usr/bin/env python3 +""" +7-panel plot for combined medical RL training. 
+Row 1: 3 test accuracy plots (MedMCQA, MedCalc, MedCase) +Row 2: 4 training reward plots (MedMCQA, MedCalc, MedCase, Mean) +""" + +import matplotlib.pyplot as plt +import pandas as pd +from pathlib import Path + +# Colors for each dataset +COLORS = { + 'medmcqa': '#0d0887', # Dark blue/purple + 'medcalc': '#7e03a8', # Purple + 'medcase': '#cc4778', # Pink/magenta + 'mean': '#f89540', # Orange +} + +LABELS = { + 'medmcqa': 'MedMCQA', + 'medcalc': 'MedCalc-Bench', + 'medcase': 'MedCaseReasoning', + 'mean': 'Mean Reward', +} + +# Style settings for publication +plt.rcParams.update({ + 'font.family': 'serif', + 'font.size': 11, + 'axes.labelsize': 12, + 'axes.titlesize': 13, + 'xtick.labelsize': 10, + 'ytick.labelsize': 10, + 'legend.fontsize': 9, + 'figure.dpi': 150, + 'savefig.dpi': 300, + 'savefig.bbox': 'tight', + 'axes.grid': False, + 'axes.spines.top': False, + 'axes.spines.right': False, +}) + +# CSV file paths for combined training +TEST_CSVS = { + 'medmcqa': "/admin/home/nikhil/med-lm-envs/wandb_export_2026-01-29T08_59_05.095-05_00.csv", + 'medcalc': "/admin/home/nikhil/med-lm-envs/wandb_export_2026-01-29T08_59_30.758-05_00.csv", + 'medcase': "/admin/home/nikhil/med-lm-envs/wandb_export_2026-01-29T08_59_51.236-05_00.csv", +} + +REWARD_CSVS = { + 'medmcqa': "/admin/home/nikhil/med-lm-envs/wandb_export_2026-01-29T09_00_00.134-05_00.csv", + 'medcase': "/admin/home/nikhil/med-lm-envs/wandb_export_2026-01-29T09_00_08.253-05_00.csv", + 'medcalc': "/admin/home/nikhil/med-lm-envs/wandb_export_2026-01-29T09_00_16.389-05_00.csv", + 'mean': "/admin/home/nikhil/med-lm-envs/wandb_export_2026-01-29T09_00_21.894-05_00.csv", +} + + +def load_test_data(csv_path: str) -> pd.DataFrame: + """Load test accuracy data.""" + df = pd.read_csv(csv_path) + df.columns = ['Step', 'pass@1', 'pass@1_min', 'pass@1_max'] + return df + + +def load_reward_data(csv_path: str) -> pd.DataFrame: + """Load reward data.""" + df = pd.read_csv(csv_path) + df.columns = ['Step', 'reward_mean', 
'reward_min', 'reward_max'] + return df + + +def main(): + # Create figure with custom layout: 3 on top, 4 on bottom + fig = plt.figure(figsize=(16, 8)) + + # Create grid spec for custom layout + gs = fig.add_gridspec(2, 12, hspace=0.35, wspace=0.5) + + # Row 1: 3 test accuracy plots (each takes 4 columns, with gaps) + ax_test = [ + fig.add_subplot(gs[0, 0:4]), # MedMCQA + fig.add_subplot(gs[0, 4:8]), # MedCalc + fig.add_subplot(gs[0, 8:12]), # MedCase + ] + + # Row 2: 4 training reward plots (each takes 3 columns) + ax_reward = [ + fig.add_subplot(gs[1, 0:3]), # MedMCQA reward + fig.add_subplot(gs[1, 3:6]), # MedCalc reward + fig.add_subplot(gs[1, 6:9]), # MedCase reward + fig.add_subplot(gs[1, 9:12]), # Mean reward + ] + + test_datasets = ['medmcqa', 'medcalc', 'medcase'] + reward_datasets = ['medmcqa', 'medcalc', 'medcase', 'mean'] + + # Y-axis limits for test accuracy + test_ylims = { + 'medmcqa': (0.55, 0.65), + 'medcalc': (0.35, 0.75), + 'medcase': (0.15, 0.40), + } + + # Y-axis limits for reward + reward_ylims = { + 'medmcqa': (0.4, 0.9), + 'medcalc': (0.2, 0.9), + 'medcase': (0.0, 0.5), + 'mean': (0.3, 0.7), + } + + # ========================================= + # Row 1: Test Accuracy plots + # ========================================= + for idx, dataset_key in enumerate(test_datasets): + ax = ax_test[idx] + df = load_test_data(TEST_CSVS[dataset_key]) + color = COLORS[dataset_key] + label = LABELS[dataset_key] + + # Plot data points + ax.plot(df['Step'], df['pass@1'], 'o-', color=color, + linewidth=2, markersize=4, alpha=0.5, label='Data') + + # Add smoothed trend + window = 3 + df['smoothed'] = df['pass@1'].rolling(window=window, center=True).mean() + ax.plot(df['Step'], df['smoothed'], '-', color=color, + linewidth=3, alpha=1.0, label=f'Smoothed (w={window})') + + ax.set_xlabel('Training Step') + if idx == 0: + ax.set_ylabel('Pass@1 Accuracy') + ax.set_title(f'{label}') + ax.set_ylim(test_ylims[dataset_key]) + 
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f'{y:.0%}')) + ax.legend(loc='lower right', fontsize=8, framealpha=0.9) + + # ========================================= + # Row 2: Training Reward plots + # ========================================= + for idx, dataset_key in enumerate(reward_datasets): + ax = ax_reward[idx] + df = load_reward_data(REWARD_CSVS[dataset_key]) + color = COLORS[dataset_key] + label = LABELS[dataset_key] + + # Plot raw data with low alpha + ax.plot(df['Step'], df['reward_mean'], '-', color=color, + linewidth=0.5, alpha=0.3, label='Raw') + + # Add smoothed trend line (EMA) + span = 20 + df['smoothed'] = df['reward_mean'].ewm(span=span).mean() + ax.plot(df['Step'], df['smoothed'], '-', color=color, + linewidth=2.5, alpha=1.0, label=f'Smoothed (EMA={span})') + + ax.set_xlabel('Training Step') + if idx == 0: + ax.set_ylabel('Mean Reward') + ax.set_title(f'{label}') + ax.set_ylim(reward_ylims[dataset_key]) + ax.legend(loc='lower right', fontsize=8, framealpha=0.9) + + # Adjust layout + plt.subplots_adjust(left=0.05, right=0.98, top=0.95, bottom=0.08) + + # Save + output_dir = Path("/admin/home/nikhil/med-lm-envs/plots") + output_dir.mkdir(exist_ok=True) + plt.savefig(output_dir / 'combined_7panel.png') + plt.savefig(output_dir / 'combined_7panel.pdf') + print(f"Saved to {output_dir / 'combined_7panel.png'}") + print(f"Saved to {output_dir / 'combined_7panel.pdf'}") + + # Print summary statistics + print("\n" + "="*60) + print("Combined Training Summary Statistics") + print("="*60) + + for dataset_key in test_datasets: + test_df = load_test_data(TEST_CSVS[dataset_key]) + start_acc = test_df['pass@1'].iloc[0] + end_acc = test_df['pass@1'].iloc[-1] + print(f"\n{LABELS[dataset_key]}:") + print(f" Test Accuracy: {start_acc:.1%} → {end_acc:.1%} (+{end_acc - start_acc:.1%})") + + print("\nReward Summary:") + for dataset_key in reward_datasets: + reward_df = load_reward_data(REWARD_CSVS[dataset_key]) + reward_df['smoothed'] = 
reward_df['reward_mean'].ewm(span=20).mean() + start_reward = reward_df['smoothed'].iloc[:10].mean() + end_reward = reward_df['smoothed'].iloc[-10:].mean() + print(f" {LABELS[dataset_key]}: {start_reward:.3f} → {end_reward:.3f} (+{end_reward - start_reward:.3f})") + + +if __name__ == "__main__": + main() diff --git a/scripts/plot_combined_training.py b/scripts/plot_combined_training.py new file mode 100644 index 00000000..d9dcc76f --- /dev/null +++ b/scripts/plot_combined_training.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python3 +""" +Combined plot showing test accuracy and training reward for all three datasets. +""" + +import matplotlib.pyplot as plt +import pandas as pd +from pathlib import Path + +# Plasma colormap colors - distinct colors for each dataset +COLORS = { + 'medcalc': '#0d0887', # Dark blue/purple + 'medmcqa': '#7e03a8', # Purple + 'medcase': '#cc4778', # Pink/magenta +} + +LABELS = { + 'medcalc': 'MedCalc-Bench-Verified', + 'medmcqa': 'MedMCQA', + 'medcase': 'MedCaseReasoning', +} + +# Style settings for publication +plt.rcParams.update({ + 'font.family': 'serif', + 'font.size': 11, + 'axes.labelsize': 12, + 'axes.titlesize': 13, + 'xtick.labelsize': 10, + 'ytick.labelsize': 10, + 'legend.fontsize': 9, + 'figure.dpi': 150, + 'savefig.dpi': 300, + 'savefig.bbox': 'tight', + 'axes.grid': False, + 'axes.spines.top': False, + 'axes.spines.right': False, +}) + +# CSV file paths +TEST_CSVS = { + 'medcalc': "/admin/home/nikhil/med-lm-envs/wandb_export_2026-01-29T01_13_09.334-05_00.csv", + 'medmcqa': "/admin/home/nikhil/med-lm-envs/wandb_export_2026-01-29T02_22_42.275-05_00.csv", + 'medcase': "/admin/home/nikhil/med-lm-envs/wandb_export_2026-01-29T03_37_39.838-05_00.csv", +} + +REWARD_CSVS = { + 'medcalc': "/admin/home/nikhil/med-lm-envs/wandb_export_2026-01-29T01_22_36.871-05_00.csv", + 'medmcqa': "/admin/home/nikhil/med-lm-envs/wandb_export_2026-01-29T02_22_54.385-05_00.csv", + 'medcase': 
"/admin/home/nikhil/med-lm-envs/wandb_export_2026-01-29T03_37_48.290-05_00.csv", +} + + +def load_test_data(csv_path: str) -> pd.DataFrame: + """Load test accuracy data.""" + df = pd.read_csv(csv_path) + df.columns = ['Step', 'pass@1', 'pass@1_min', 'pass@1_max'] + return df + + +def load_reward_data(csv_path: str) -> pd.DataFrame: + """Load reward data.""" + df = pd.read_csv(csv_path) + df.columns = ['Step', 'reward_mean', 'reward_min', 'reward_max'] + return df + + +def main(): + # Create figure with 2x3 grid of subplots + fig, axes = plt.subplots(2, 3, figsize=(14, 8)) + + datasets = ['medcalc', 'medmcqa', 'medcase'] + + # Y-axis limits for each dataset (test accuracy) + test_ylims = { + 'medcalc': (0.3, 0.7), + 'medmcqa': (0.5, 0.65), + 'medcase': (0.1, 0.4), + } + + # Y-axis limits for each dataset (reward) + reward_ylims = { + 'medcalc': (0.2, 0.9), + 'medmcqa': (0.4, 0.9), + 'medcase': (0.0, 0.4), + } + + # ========================================= + # Row 1: Test Accuracy plots + # ========================================= + for col, dataset_key in enumerate(datasets): + ax = axes[0, col] + df = load_test_data(TEST_CSVS[dataset_key]) + color = COLORS[dataset_key] + label = LABELS[dataset_key] + + # Plot data points + ax.plot(df['Step'], df['pass@1'], 'o-', color=color, + linewidth=2, markersize=4, alpha=0.5, label='Data') + + # Add smoothed trend + window = 3 + df['smoothed'] = df['pass@1'].rolling(window=window, center=True).mean() + ax.plot(df['Step'], df['smoothed'], '-', color=color, + linewidth=3, alpha=1.0, label=f'Smoothed (w={window})') + + ax.set_xlabel('Training Step') + if col == 0: + ax.set_ylabel('Pass@1 Accuracy') + ax.set_title(f'{label}') + ax.set_ylim(test_ylims[dataset_key]) + ax.set_xlim(-10, 350) + ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f'{y:.0%}')) + ax.legend(loc='lower right', fontsize=8, framealpha=0.9) + + # ========================================= + # Row 2: Training Reward plots + # 
========================================= + for col, dataset_key in enumerate(datasets): + ax = axes[1, col] + df = load_reward_data(REWARD_CSVS[dataset_key]) + color = COLORS[dataset_key] + label = LABELS[dataset_key] + + # Plot raw data with low alpha + ax.plot(df['Step'], df['reward_mean'], '-', color=color, + linewidth=0.5, alpha=0.3, label='Raw') + + # Add smoothed trend line (EMA) + span = 20 + df['smoothed'] = df['reward_mean'].ewm(span=span).mean() + ax.plot(df['Step'], df['smoothed'], '-', color=color, + linewidth=2.5, alpha=1.0, label=f'Smoothed (EMA={span})') + + ax.set_xlabel('Training Step') + if col == 0: + ax.set_ylabel('Mean Reward') + ax.set_ylim(reward_ylims[dataset_key]) + ax.set_xlim(0, 335) + ax.legend(loc='lower right', fontsize=8, framealpha=0.9) + + # Add row labels + fig.text(0.02, 0.75, 'Test Accuracy', va='center', rotation='vertical', fontsize=12, fontweight='bold') + fig.text(0.02, 0.3, 'Training Reward', va='center', rotation='vertical', fontsize=12, fontweight='bold') + + # Adjust layout + plt.tight_layout() + plt.subplots_adjust(left=0.08) + + # Save + output_dir = Path("/admin/home/nikhil/med-lm-envs/plots") + output_dir.mkdir(exist_ok=True) + plt.savefig(output_dir / 'combined_training.png') + plt.savefig(output_dir / 'combined_training.pdf') + print(f"Saved to {output_dir / 'combined_training.png'}") + print(f"Saved to {output_dir / 'combined_training.pdf'}") + + # Print summary statistics + print("\n" + "="*60) + print("Summary Statistics") + print("="*60) + + for dataset_key in ['medcalc', 'medmcqa', 'medcase']: + test_df = load_test_data(TEST_CSVS[dataset_key]) + reward_df = load_reward_data(REWARD_CSVS[dataset_key]) + + start_acc = test_df['pass@1'].iloc[0] + end_acc = test_df['pass@1'].iloc[-1] + + reward_df['smoothed'] = reward_df['reward_mean'].ewm(span=20).mean() + start_reward = reward_df['smoothed'].iloc[:10].mean() + end_reward = reward_df['smoothed'].iloc[-10:].mean() + + print(f"\n{LABELS[dataset_key]}:") + print(f" 
Test Accuracy: {start_acc:.1%} → {end_acc:.1%} (+{end_acc - start_acc:.1%})") + print(f" Reward: {start_reward:.3f} → {end_reward:.3f} (+{end_reward - start_reward:.3f})") + + +if __name__ == "__main__": + main() diff --git a/scripts/plot_medcasereasoning_reward.py b/scripts/plot_medcasereasoning_reward.py new file mode 100644 index 00000000..3ba18823 --- /dev/null +++ b/scripts/plot_medcasereasoning_reward.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +"""Plot MedCaseReasoning reward curve from W&B exported CSV.""" + +import matplotlib.pyplot as plt +import pandas as pd +from pathlib import Path + +# Plasma colormap colors +PLASMA_COLORS = ['#0d0887', '#7e03a8', '#cc4778', '#f89540', '#f0f921'] + +# Style settings for publication +plt.rcParams.update({ + 'font.family': 'serif', + 'font.size': 11, + 'axes.labelsize': 12, + 'axes.titlesize': 13, + 'xtick.labelsize': 10, + 'ytick.labelsize': 10, + 'legend.fontsize': 10, + 'figure.figsize': (8, 5), + 'figure.dpi': 150, + 'savefig.dpi': 300, + 'savefig.bbox': 'tight', + 'axes.grid': False, + 'axes.spines.top': False, + 'axes.spines.right': False, +}) + +# Load data +csv_path = Path("/admin/home/nikhil/med-lm-envs/wandb_export_2026-01-29T03_37_48.290-05_00.csv") +df = pd.read_csv(csv_path) + +# Clean column names +df.columns = ['Step', 'reward_mean', 'reward_min', 'reward_max'] + +# Create plot +fig, ax = plt.subplots(figsize=(8, 5)) + +# Plot raw data with low alpha (plasma colors - same as other datasets) +ax.plot(df['Step'], df['reward_mean'], '-', color=PLASMA_COLORS[1], # 7e03a8 purple + linewidth=0.8, alpha=0.3, label='_nolegend_') + +# Add smoothed trend line (exponential moving average) +span = 20 # smoothing window +df['smoothed'] = df['reward_mean'].ewm(span=span).mean() +ax.plot(df['Step'], df['smoothed'], '-', color=PLASMA_COLORS[3], # f89540 orange + linewidth=2.5, label='MedCaseReasoning (smoothed)') + +# Calculate improvement for summary +start_reward = df['smoothed'].iloc[:10].mean() +end_reward = 
df['smoothed'].iloc[-10:].mean() +improvement = end_reward - start_reward + +ax.set_title('MedCaseReasoning Training Reward') + +ax.set_xlabel('Training Step') +ax.set_ylabel('Mean Reward') +ax.legend(loc='lower right', framealpha=0.9) +ax.set_ylim(0.0, 0.4) +ax.set_xlim(0, 335) + +plt.tight_layout() + +# Save +output_dir = Path("/admin/home/nikhil/med-lm-envs/plots") +output_dir.mkdir(exist_ok=True) +plt.savefig(output_dir / 'medcasereasoning_reward.png') +plt.savefig(output_dir / 'medcasereasoning_reward.pdf') +print(f"Saved to {output_dir / 'medcasereasoning_reward.png'}") +print(f"Saved to {output_dir / 'medcasereasoning_reward.pdf'}") + +# Print summary +print(f"\nReward Summary:") +print(f" Start (smoothed): {start_reward:.3f}") +print(f" End (smoothed): {end_reward:.3f}") +print(f" Improvement: +{improvement:.3f}") +print(f" Raw min: {df['reward_mean'].min():.3f}") +print(f" Raw max: {df['reward_mean'].max():.3f}") diff --git a/scripts/plot_medcasereasoning_training.py b/scripts/plot_medcasereasoning_training.py new file mode 100644 index 00000000..4ea868cf --- /dev/null +++ b/scripts/plot_medcasereasoning_training.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +"""Plot MedCaseReasoning training accuracy from W&B exported CSV.""" + +import matplotlib.pyplot as plt +import pandas as pd +from pathlib import Path + +# Plasma colormap colors +PLASMA_COLORS = ['#0d0887', '#7e03a8', '#cc4778', '#f89540', '#f0f921'] + +# Style settings for publication +plt.rcParams.update({ + 'font.family': 'serif', + 'font.size': 11, + 'axes.labelsize': 12, + 'axes.titlesize': 13, + 'xtick.labelsize': 10, + 'ytick.labelsize': 10, + 'legend.fontsize': 10, + 'figure.figsize': (8, 5), + 'figure.dpi': 150, + 'savefig.dpi': 300, + 'savefig.bbox': 'tight', + 'axes.grid': False, + 'axes.spines.top': False, + 'axes.spines.right': False, +}) + +# Load data +csv_path = Path("/admin/home/nikhil/med-lm-envs/wandb_export_2026-01-29T03_37_39.838-05_00.csv") +df = pd.read_csv(csv_path) + +# Clean 
column names +df.columns = ['Step', 'pass@1', 'pass@1_min', 'pass@1_max'] + +# Create plot +fig, ax = plt.subplots(figsize=(8, 5)) + +# Plot main line (plasma colors - same as other datasets) +ax.plot(df['Step'], df['pass@1'], 'o-', color=PLASMA_COLORS[1], # 7e03a8 purple + linewidth=2, markersize=5, alpha=0.8, label='MedCaseReasoning') + +# Add trend line (moving average) +window = 3 +df['smoothed'] = df['pass@1'].rolling(window=window, center=True).mean() +ax.plot(df['Step'], df['smoothed'], '--', color=PLASMA_COLORS[3], # f89540 orange + linewidth=2, alpha=0.8, label=f'Smoothed (window={window})') + +# For summary stats +start_acc = df['pass@1'].iloc[0] +end_acc = df['pass@1'].iloc[-1] +improvement = end_acc - start_acc + +ax.set_xlabel('Training Step') +ax.set_ylabel('Pass@1 Accuracy') +ax.set_title('MedCaseReasoning Test Accuracy') +ax.legend(loc='lower right', framealpha=0.9) +ax.set_ylim(0.1, 0.4) +ax.set_xlim(-10, 350) + +ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f'{y:.0%}')) + +plt.tight_layout() + +# Save +output_dir = Path("/admin/home/nikhil/med-lm-envs/plots") +output_dir.mkdir(exist_ok=True) +plt.savefig(output_dir / 'medcasereasoning_training.png') +plt.savefig(output_dir / 'medcasereasoning_training.pdf') +print(f"Saved to {output_dir / 'medcasereasoning_training.png'}") +print(f"Saved to {output_dir / 'medcasereasoning_training.pdf'}") + +# Print summary +print(f"\nTraining Summary:") +print(f" Start accuracy: {start_acc:.1%}") +print(f" End accuracy: {end_acc:.1%}") +print(f" Improvement: +{improvement:.1%}") +print(f" Steps: {df['Step'].iloc[-1]}") diff --git a/scripts/plot_medmcqa_reward.py b/scripts/plot_medmcqa_reward.py new file mode 100644 index 00000000..c88e3ce6 --- /dev/null +++ b/scripts/plot_medmcqa_reward.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +"""Plot MedMCQA reward curve from W&B exported CSV.""" + +import matplotlib.pyplot as plt +import pandas as pd +from pathlib import Path + +# Plasma colormap colors 
+PLASMA_COLORS = ['#0d0887', '#7e03a8', '#cc4778', '#f89540', '#f0f921'] + +# Style settings for publication +plt.rcParams.update({ + 'font.family': 'serif', + 'font.size': 11, + 'axes.labelsize': 12, + 'axes.titlesize': 13, + 'xtick.labelsize': 10, + 'ytick.labelsize': 10, + 'legend.fontsize': 10, + 'figure.figsize': (8, 5), + 'figure.dpi': 150, + 'savefig.dpi': 300, + 'savefig.bbox': 'tight', + 'axes.grid': False, + 'axes.spines.top': False, + 'axes.spines.right': False, +}) + +# Load data +csv_path = Path("/admin/home/nikhil/med-lm-envs/wandb_export_2026-01-29T02_22_54.385-05_00.csv") +df = pd.read_csv(csv_path) + +# Clean column names +df.columns = ['Step', 'reward_mean', 'reward_min', 'reward_max'] + +# Create plot +fig, ax = plt.subplots(figsize=(8, 5)) + +# Plot raw data with low alpha (plasma colors - same as MedCalc) +ax.plot(df['Step'], df['reward_mean'], '-', color=PLASMA_COLORS[1], # 7e03a8 purple + linewidth=0.8, alpha=0.3, label='_nolegend_') + +# Add smoothed trend line (exponential moving average) +span = 20 # smoothing window +df['smoothed'] = df['reward_mean'].ewm(span=span).mean() +ax.plot(df['Step'], df['smoothed'], '-', color=PLASMA_COLORS[3], # f89540 orange + linewidth=2.5, label='MedMCQA (smoothed)') + +# Calculate improvement for summary +start_reward = df['smoothed'].iloc[:10].mean() +end_reward = df['smoothed'].iloc[-10:].mean() +improvement = end_reward - start_reward + +ax.set_title('MedMCQA Training Reward') + +ax.set_xlabel('Training Step') +ax.set_ylabel('Mean Reward') +ax.legend(loc='lower right', framealpha=0.9) +ax.set_ylim(0.4, 0.9) +ax.set_xlim(0, 335) + +plt.tight_layout() + +# Save +output_dir = Path("/admin/home/nikhil/med-lm-envs/plots") +output_dir.mkdir(exist_ok=True) +plt.savefig(output_dir / 'medmcqa_reward.png') +plt.savefig(output_dir / 'medmcqa_reward.pdf') +print(f"Saved to {output_dir / 'medmcqa_reward.png'}") +print(f"Saved to {output_dir / 'medmcqa_reward.pdf'}") + +# Print summary +print(f"\nReward Summary:") 
+print(f" Start (smoothed): {start_reward:.3f}") +print(f" End (smoothed): {end_reward:.3f}") +print(f" Improvement: +{improvement:.3f}") +print(f" Raw min: {df['reward_mean'].min():.3f}") +print(f" Raw max: {df['reward_mean'].max():.3f}") diff --git a/scripts/plot_medmcqa_training.py b/scripts/plot_medmcqa_training.py new file mode 100644 index 00000000..e4d39073 --- /dev/null +++ b/scripts/plot_medmcqa_training.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +"""Plot MedMCQA training accuracy from W&B exported CSV.""" + +import matplotlib.pyplot as plt +import pandas as pd +from pathlib import Path + +# Plasma colormap colors +PLASMA_COLORS = ['#0d0887', '#7e03a8', '#cc4778', '#f89540', '#f0f921'] + +# Style settings for publication +plt.rcParams.update({ + 'font.family': 'serif', + 'font.size': 11, + 'axes.labelsize': 12, + 'axes.titlesize': 13, + 'xtick.labelsize': 10, + 'ytick.labelsize': 10, + 'legend.fontsize': 10, + 'figure.figsize': (8, 5), + 'figure.dpi': 150, + 'savefig.dpi': 300, + 'savefig.bbox': 'tight', + 'axes.grid': False, + 'axes.spines.top': False, + 'axes.spines.right': False, +}) + +# Load data +csv_path = Path("/admin/home/nikhil/med-lm-envs/wandb_export_2026-01-29T02_22_42.275-05_00.csv") +df = pd.read_csv(csv_path) + +# Clean column names +df.columns = ['Step', 'pass@1', 'pass@1_min', 'pass@1_max'] + +# Create plot +fig, ax = plt.subplots(figsize=(8, 5)) + +# Plot main line (plasma colors - same as MedCalc) +ax.plot(df['Step'], df['pass@1'], 'o-', color=PLASMA_COLORS[1], # 7e03a8 purple + linewidth=2, markersize=5, alpha=0.8, label='MedMCQA') + +# Add trend line (moving average) +window = 3 +df['smoothed'] = df['pass@1'].rolling(window=window, center=True).mean() +ax.plot(df['Step'], df['smoothed'], '--', color=PLASMA_COLORS[3], # f89540 orange + linewidth=2, alpha=0.8, label=f'Smoothed (window={window})') + +# For summary stats +start_acc = df['pass@1'].iloc[0] +end_acc = df['pass@1'].iloc[-1] +improvement = end_acc - start_acc + 
+ax.set_xlabel('Training Step') +ax.set_ylabel('Pass@1 Accuracy') +ax.set_title('MedMCQA Test Accuracy') +ax.legend(loc='lower right', framealpha=0.9) +ax.set_ylim(0.5, 0.7) +ax.set_xlim(-10, 350) + +ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f'{y:.0%}')) + +plt.tight_layout() + +# Save +output_dir = Path("/admin/home/nikhil/med-lm-envs/plots") +output_dir.mkdir(exist_ok=True) +plt.savefig(output_dir / 'medmcqa_training.png') +plt.savefig(output_dir / 'medmcqa_training.pdf') +print(f"Saved to {output_dir / 'medmcqa_training.png'}") +print(f"Saved to {output_dir / 'medmcqa_training.pdf'}") + +# Print summary +print(f"\nTraining Summary:") +print(f" Start accuracy: {start_acc:.1%}") +print(f" End accuracy: {end_acc:.1%}") +print(f" Improvement: +{improvement:.1%}") +print(f" Steps: {df['Step'].iloc[-1]}") diff --git a/scripts/plot_reward_csv.py b/scripts/plot_reward_csv.py new file mode 100644 index 00000000..49d39c91 --- /dev/null +++ b/scripts/plot_reward_csv.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +"""Plot reward curve from W&B exported CSV.""" + +import matplotlib.pyplot as plt +import pandas as pd +from pathlib import Path + +# Plasma colormap colors +PLASMA_COLORS = ['#0d0887', '#7e03a8', '#cc4778', '#f89540', '#f0f921'] + +# Style settings for publication +plt.rcParams.update({ + 'font.family': 'serif', + 'font.size': 11, + 'axes.labelsize': 12, + 'axes.titlesize': 13, + 'xtick.labelsize': 10, + 'ytick.labelsize': 10, + 'legend.fontsize': 10, + 'figure.figsize': (8, 5), + 'figure.dpi': 150, + 'savefig.dpi': 300, + 'savefig.bbox': 'tight', + 'axes.grid': False, + 'axes.spines.top': False, + 'axes.spines.right': False, +}) + +# Load data +csv_path = Path("/admin/home/nikhil/med-lm-envs/wandb_export_2026-01-29T01_22_36.871-05_00.csv") +df = pd.read_csv(csv_path) + +# Clean column names +df.columns = ['Step', 'reward_mean', 'reward_min', 'reward_max'] + +# Create plot +fig, ax = plt.subplots(figsize=(8, 5)) + +# Plot raw data with low alpha 
(plasma colors) +ax.plot(df['Step'], df['reward_mean'], '-', color=PLASMA_COLORS[1], + linewidth=0.8, alpha=0.3, label='_nolegend_') + +# Add smoothed trend line (exponential moving average) +span = 20 # smoothing window +df['smoothed'] = df['reward_mean'].ewm(span=span).mean() +ax.plot(df['Step'], df['smoothed'], '-', color=PLASMA_COLORS[3], + linewidth=2.5, label='MedCalc-Bench-Verified (smoothed)') + +# Calculate improvement for title +start_reward = df['smoothed'].iloc[:10].mean() +end_reward = df['smoothed'].iloc[-10:].mean() +improvement = end_reward - start_reward + +ax.set_title('MedCalc-Bench-Verified Training Reward') + +ax.set_xlabel('Training Step') +ax.set_ylabel('Mean Reward') +ax.legend(loc='lower right', framealpha=0.9) +ax.set_ylim(0.2, 0.9) +ax.set_xlim(0, 335) + +plt.tight_layout() + +# Save +output_dir = Path("/admin/home/nikhil/med-lm-envs/plots") +output_dir.mkdir(exist_ok=True) +plt.savefig(output_dir / 'medcalc_reward.png') +plt.savefig(output_dir / 'medcalc_reward.pdf') +print(f"Saved to {output_dir / 'medcalc_reward.png'}") +print(f"Saved to {output_dir / 'medcalc_reward.pdf'}") + +# Print summary +print(f"\nReward Summary:") +print(f" Start (smoothed): {start_reward:.3f}") +print(f" End (smoothed): {end_reward:.3f}") +print(f" Improvement: +{improvement:.3f}") +print(f" Raw min: {df['reward_mean'].min():.3f}") +print(f" Raw max: {df['reward_mean'].max():.3f}") diff --git a/scripts/plot_training_results.py b/scripts/plot_training_results.py new file mode 100644 index 00000000..5917c889 --- /dev/null +++ b/scripts/plot_training_results.py @@ -0,0 +1,609 @@ +#!/usr/bin/env python3 +""" +Generate publication-quality plots from RL training results. +Fetches data from W&B and creates matplotlib figures. 
+""" + +import argparse +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +# Try to import wandb, fall back to manual data if not available + +import wandb +WANDB_AVAILABLE = True + + +# Style settings for publication +plt.rcParams.update({ + 'font.family': 'serif', + 'font.size': 11, + 'axes.labelsize': 12, + 'axes.titlesize': 12, + 'xtick.labelsize': 10, + 'ytick.labelsize': 10, + 'legend.fontsize': 10, + 'figure.figsize': (8, 5), + 'figure.dpi': 150, + 'savefig.dpi': 300, + 'savefig.bbox': 'tight', + 'axes.grid': True, + 'grid.alpha': 0.3, +}) + +# Color palette (colorblind-friendly) +COLORS = { + 'medmcqa': '#2196F3', # Blue + 'medcalc': '#4CAF50', # Green + 'medcase': '#FF9800', # Orange + 'combined': '#9C27B0', # Purple +} + +LABELS = { + 'medmcqa': 'MedMCQA', + 'medcalc': 'MedCalc-Bench', + 'medcase': 'MedCaseReasoning', +} + + +def fetch_wandb_data(project: str, run_name: str | None = None) -> pd.DataFrame: + """Fetch training data from W&B.""" + if not WANDB_AVAILABLE: + raise ImportError("wandb not installed. Install with: pip install wandb") + + api = wandb.Api() + runs = api.runs(project) + + if run_name: + runs = [r for r in runs if run_name in r.name] + + all_data = [] + for run in runs: + history = run.history() + history['run_name'] = run.name + all_data.append(history) + + return pd.concat(all_data, ignore_index=True) + + +def fetch_wandb_run( + entity: str, + project: str, + run_id: str, + metrics: list[str] | None = None, +) -> pd.DataFrame: + """ + Fetch specific metrics from a W&B run. + + Args: + entity: W&B entity (username or team) + project: W&B project name + run_id: The run ID (from the URL, e.g., 'abc123xy') + metrics: List of metric names to fetch. If None, fetches all. + + Returns: + DataFrame with step and requested metrics + """ + if not WANDB_AVAILABLE: + raise ImportError("wandb not installed. 
Install with: pip install wandb") + + api = wandb.Api() + run = api.run(f"{entity}/{project}/{run_id}") + + # Get history with specific keys if provided + if metrics: + # Always include _step + keys = ['_step'] + [m for m in metrics if m != '_step'] + history = run.history(keys=keys) + else: + history = run.history() + + return history + + +def fetch_training_data_from_wandb( + entity: str, + project: str, + run_id: str, + reward_key: str = "train/mean_reward", + accuracy_keys: dict[str, str] | None = None, +) -> tuple[dict, dict]: + """ + Fetch reward and accuracy data from a W&B run. + + Args: + entity: W&B entity + project: W&B project + run_id: Run ID + reward_key: Key for reward metric + accuracy_keys: Dict mapping env name to accuracy metric key + e.g., {'medmcqa': 'eval/medmcqa/pass@1'} + + Returns: + Tuple of (accuracy_data, reward_data) ready for plotting + """ + if not WANDB_AVAILABLE: + raise ImportError("wandb not installed. Install with: pip install wandb") + + api = wandb.Api() + run = api.run(f"{entity}/{project}/{run_id}") + + # Fetch all history + history = run.history() + + # Process reward data + reward_df = history[['_step', reward_key]].dropna() + reward_data = { + 'combined': { + 'steps': reward_df['_step'].tolist(), + 'reward': reward_df[reward_key].tolist(), + 'label': 'Combined Training', + } + } + + # Process accuracy data if keys provided + accuracy_data = {} + if accuracy_keys: + for env_name, key in accuracy_keys.items(): + if key in history.columns: + acc_df = history[['_step', key]].dropna() + accuracy_data[env_name] = { + 'steps': acc_df['_step'].tolist(), + 'values': acc_df[key].tolist(), + } + + return accuracy_data, reward_data + + +def list_wandb_runs(entity: str, project: str, limit: int = 20): + """List recent runs in a W&B project.""" + if not WANDB_AVAILABLE: + raise ImportError("wandb not installed. 
Install with: pip install wandb") + + api = wandb.Api() + runs = api.runs(f"{entity}/{project}", per_page=limit) + + print(f"\nRecent runs in {entity}/{project}:") + print("-" * 80) + for run in runs: + print(f" ID: {run.id}") + print(f" Name: {run.name}") + print(f" State: {run.state}") + print(f" Created: {run.created_at}") + print(f" URL: {run.url}") + print("-" * 80) + + +def list_wandb_metrics(entity: str, project: str, run_id: str): + """List available metrics in a W&B run.""" + if not WANDB_AVAILABLE: + raise ImportError("wandb not installed. Install with: pip install wandb") + + api = wandb.Api() + run = api.run(f"{entity}/{project}/{run_id}") + history = run.history() + + print(f"\nAvailable metrics in run {run_id}:") + print("-" * 40) + for col in sorted(history.columns): + non_null = history[col].notna().sum() + print(f" {col}: {non_null} values") + + +def plot_from_wandb( + entity: str, + project: str, + run_id: str, + output_dir: Path, + reward_key: str = "train/mean_reward", + accuracy_keys: dict[str, str] | None = None, +): + """ + Fetch data from W&B and generate plots. 
+ + Example: + plot_from_wandb( + entity="your-username", + project="medical-rl", + run_id="abc123xy", + output_dir=Path("plots"), + reward_key="train/mean_reward", + accuracy_keys={ + 'medmcqa': 'eval/medmcqa/pass@1', + 'medcalc': 'eval/medcalc_bench/pass@1', + 'medcase': 'eval/medcasereasoning/pass@1', + } + ) + """ + output_dir.mkdir(parents=True, exist_ok=True) + + accuracy_data, reward_data = fetch_training_data_from_wandb( + entity, project, run_id, reward_key, accuracy_keys + ) + + if accuracy_data: + plot_training_curves( + accuracy_data, + output_dir / 'training_curves_wandb.png', + title='Medical RL Training Progress', + ylabel='Pass@1 Accuracy' + ) + + if reward_data: + plot_reward_curves( + reward_data, + output_dir / 'reward_curves_wandb.png', + title='Training Reward' + ) + + if accuracy_data and reward_data: + plot_multi_panel( + accuracy_data, + reward_data, + output_dir / 'training_combined_wandb.png', + title='Medical RL Training' + ) + + print(f"\nPlots saved to {output_dir}") + + +def plot_training_curves( + data: dict[str, dict], + output_path: Path, + title: str = "RL Training Progress", + ylabel: str = "Pass@1 Accuracy", +): + """ + Plot training curves for multiple environments. 
def _save_figure(output_path: Path) -> None:
    """Finalize the current figure: tight layout, save PNG + PDF, report, close.

    Shared tail of every plotting helper in this module so the save/report
    behavior stays consistent (PNG for quick viewing, PDF for the paper).
    """
    plt.tight_layout()
    plt.savefig(output_path)
    plt.savefig(output_path.with_suffix('.pdf'))  # Also save PDF for paper
    print(f"Saved plot to {output_path} and {output_path.with_suffix('.pdf')}")
    plt.close()


def plot_comparison_bar(
    data: dict[str, dict],
    output_path: Path,
    title: str = "Model Comparison",
):
    """
    Create a grouped bar chart comparing base model vs trained model.

    Args:
        data: Dict of {env_name: {'base': float, 'trained': float}}
        output_path: Where to save the figure
        title: Plot title
    """
    fig, ax = plt.subplots(figsize=(8, 5))

    envs = list(data.keys())
    x = np.arange(len(envs))
    width = 0.35

    base_values = [data[env]['base'] for env in envs]
    trained_values = [data[env]['trained'] for env in envs]

    bars1 = ax.bar(x - width/2, base_values, width, label='Base Model',
                   color='#9E9E9E', alpha=0.8)
    bars2 = ax.bar(x + width/2, trained_values, width, label='RL-Trained',
                   color='#2196F3', alpha=0.8)

    # Add value labels on bars
    def add_labels(bars):
        for bar in bars:
            height = bar.get_height()
            ax.annotate(f'{height:.1%}',
                        xy=(bar.get_x() + bar.get_width() / 2, height),
                        xytext=(0, 3),
                        textcoords="offset points",
                        ha='center', va='bottom', fontsize=9)

    add_labels(bars1)
    add_labels(bars2)

    ax.set_ylabel('Pass@1 Accuracy')
    ax.set_title(title)
    ax.set_xticks(x)
    ax.set_xticklabels([LABELS.get(env, env) for env in envs])
    ax.legend(loc='upper right')
    ax.set_ylim(0, 1)

    _save_figure(output_path)


def plot_reward_curves(
    data: dict[str, dict],
    output_path: Path,
    title: str = "Training Reward",
):
    """
    Plot reward curves during training.

    Args:
        data: Dict of {env_name: {'steps': [...], 'reward': [...], 'reward_std': [...]}}
        output_path: Where to save the figure
        title: Plot title
    """
    fig, ax = plt.subplots(figsize=(8, 5))

    for env_name, env_data in data.items():
        steps = env_data['steps']
        rewards = env_data['reward']
        label = env_data.get('label', LABELS.get(env_name, env_name))
        color = COLORS.get(env_name, None)

        ax.plot(steps, rewards, '-', label=label, color=color,
                linewidth=2, alpha=0.8)

        # Add shading for std if provided
        if 'reward_std' in env_data:
            ax.fill_between(steps,
                            np.array(rewards) - np.array(env_data['reward_std']),
                            np.array(rewards) + np.array(env_data['reward_std']),
                            alpha=0.2, color=color)

    ax.set_xlabel('Training Step')
    ax.set_ylabel('Mean Reward')
    ax.set_title(title)
    ax.legend(loc='best', framealpha=0.9)

    # Add minor gridlines
    ax.grid(True, which='major', linestyle='-', alpha=0.3)
    ax.grid(True, which='minor', linestyle=':', alpha=0.2)
    ax.minorticks_on()

    _save_figure(output_path)


def plot_multi_panel(
    accuracy_data: dict[str, dict],
    reward_data: dict[str, dict],
    output_path: Path,
    title: str = "Medical RL Training",
):
    """
    Create a two-panel figure with accuracy and reward curves.

    Args:
        accuracy_data: Dict of {env_name: {'steps': [...], 'values': [...], 'std': [...]}}
        reward_data: Dict of {env_name: {'steps': [...], 'reward': [...], 'reward_std': [...]}}
        output_path: Where to save the figure
        title: Overall figure title (suptitle)
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4.5))

    # Left panel: Accuracy
    for env_name, env_data in accuracy_data.items():
        steps = env_data['steps']
        values = env_data['values']
        label = env_data.get('label', LABELS.get(env_name, env_name))
        color = COLORS.get(env_name, None)

        ax1.plot(steps, values, 'o-', label=label, color=color,
                 linewidth=2, markersize=5, alpha=0.8)

        if 'std' in env_data:
            ax1.fill_between(steps,
                             np.array(values) - np.array(env_data['std']),
                             np.array(values) + np.array(env_data['std']),
                             alpha=0.15, color=color)

    ax1.set_xlabel('Training Step')
    ax1.set_ylabel('Pass@1 Accuracy')
    ax1.set_title('(a) Evaluation Accuracy')
    ax1.legend(loc='lower right', framealpha=0.9)
    ax1.set_ylim(0, 1)
    ax1.grid(True, which='major', linestyle='-', alpha=0.3)

    # Right panel: Reward
    for env_name, env_data in reward_data.items():
        steps = env_data['steps']
        rewards = env_data['reward']
        label = env_data.get('label', LABELS.get(env_name, env_name))
        color = COLORS.get(env_name, None)

        ax2.plot(steps, rewards, '-', label=label, color=color,
                 linewidth=2, alpha=0.8)

        if 'reward_std' in env_data:
            ax2.fill_between(steps,
                             np.array(rewards) - np.array(env_data['reward_std']),
                             np.array(rewards) + np.array(env_data['reward_std']),
                             alpha=0.15, color=color)

    ax2.set_xlabel('Training Step')
    ax2.set_ylabel('Mean Reward')
    ax2.set_title('(b) Training Reward')
    ax2.legend(loc='lower right', framealpha=0.9)
    ax2.grid(True, which='major', linestyle='-', alpha=0.3)

    fig.suptitle(title, fontsize=14, y=1.02)
    _save_figure(output_path)
def export_wandb_to_csv(project: str, output_path: Path):
    """Export the run history of every run in a W&B project to CSV files.

    Args:
        project: W&B project path (e.g. "entity/project") passed to api.runs().
        output_path: Directory receiving one CSV per run; created if missing.
    """
    if not WANDB_AVAILABLE:
        print("wandb not installed. Install with: pip install wandb")
        return

    # Ensure the target directory exists, matching the --export-csv CLI path
    # which mkdirs before writing.
    output_path.mkdir(parents=True, exist_ok=True)

    api = wandb.Api()
    runs = api.runs(project)

    for run in runs:
        history = run.history()
        # Run names may contain '/', which would otherwise be treated as a
        # path separator and fail to write.
        csv_path = output_path / f"{run.name.replace('/', '_')}.csv"
        history.to_csv(csv_path, index=False)
        print(f"Exported {run.name} to {csv_path}")
parser.add_argument("--export-csv", action="store_true", + help="Export W&B data to CSV") + + # Metric keys + parser.add_argument("--reward-key", type=str, default="train/mean_reward", + help="W&B key for reward metric") + + args = parser.parse_args() + + if args.example: + create_example_plots(args.output_dir) + + elif args.list_runs: + if not args.wandb_entity or not args.wandb_project: + print("Error: --wandb-entity and --wandb-project required") + else: + list_wandb_runs(args.wandb_entity, args.wandb_project) + + elif args.list_metrics: + if not args.wandb_entity or not args.wandb_project or not args.wandb_run: + print("Error: --wandb-entity, --wandb-project, and --wandb-run required") + else: + list_wandb_metrics(args.wandb_entity, args.wandb_project, args.wandb_run) + + elif args.export_csv: + if not args.wandb_entity or not args.wandb_project or not args.wandb_run: + print("Error: --wandb-entity, --wandb-project, and --wandb-run required") + else: + args.output_dir.mkdir(parents=True, exist_ok=True) + df = fetch_wandb_run(args.wandb_entity, args.wandb_project, args.wandb_run) + csv_path = args.output_dir / f"{args.wandb_run}.csv" + df.to_csv(csv_path, index=False) + print(f"Exported to {csv_path}") + + elif args.wandb_run: + if not args.wandb_entity or not args.wandb_project: + print("Error: --wandb-entity and --wandb-project required") + else: + # Default accuracy keys - customize these for your runs + accuracy_keys = { + 'medmcqa': 'eval/medmcqa/pass@1', + 'medcalc': 'eval/medcalc_bench/pass@1', + 'medcase': 'eval/medcasereasoning/pass@1', + } + plot_from_wandb( + args.wandb_entity, + args.wandb_project, + args.wandb_run, + args.output_dir, + args.reward_key, + accuracy_keys, + ) + + else: + print("""Usage: + # Generate example plots with placeholder data + python plot_training_results.py --example + + # List runs in a W&B project + python plot_training_results.py --list-runs --wandb-entity USER --wandb-project PROJECT + + # List metrics available in a run 
#!/usr/bin/env python3
"""Plot W&B exported CSV data."""

import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path

# Plasma colormap colors
PLASMA_COLORS = ['#0d0887', '#7e03a8', '#cc4778', '#f89540', '#f0f921']

# Publication-oriented matplotlib defaults (serif fonts, high-DPI export,
# no top/right spines).
plt.rcParams.update({
    'font.family': 'serif',
    'font.size': 11,
    'axes.labelsize': 12,
    'axes.titlesize': 13,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
    'figure.figsize': (8, 5),
    'figure.dpi': 150,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    'axes.grid': False,
    'axes.spines.top': False,
    'axes.spines.right': False,
})

# Read the exported run history.
export_path = Path("/admin/home/nikhil/med-lm-envs/wandb_export_2026-01-29T01_13_09.334-05_00.csv")
data = pd.read_csv(export_path)

# Normalize the W&B export headers to short names.
data.columns = ['Step', 'pass@1', 'pass@1_min', 'pass@1_max']

# Optional: filter to max_steps from config
# data = data[data['Step'] <= 300]

# Summary statistics for the printed report below.
start_acc = data['pass@1'].iloc[0]
end_acc = data['pass@1'].iloc[-1]
improvement = end_acc - start_acc

fig, ax = plt.subplots(figsize=(8, 5))

# Raw accuracy curve (plasma colors).
ax.plot(data['Step'], data['pass@1'], 'o-', color=PLASMA_COLORS[1],
        linewidth=2, markersize=5, alpha=0.8, label='MedCalc-Bench-Verified')

# Centered moving-average trend line.
window = 3
data['smoothed'] = data['pass@1'].rolling(window=window, center=True).mean()
ax.plot(data['Step'], data['smoothed'], '--', color=PLASMA_COLORS[3],
        linewidth=2, alpha=0.8, label=f'Smoothed (window={window})')

ax.set_xlabel('Training Step')
ax.set_ylabel('Pass@1 Accuracy')
ax.set_title('MedCalc-Bench-Verified Test Accuracy')
ax.legend(loc='lower right', framealpha=0.9)
ax.set_ylim(0.3, 0.7)
ax.set_xlim(-10, 350)

# Render the y axis as percentages.
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f'{y:.0%}'))

plt.tight_layout()

# Save PNG (viewing) and PDF (paper).
output_dir = Path("/admin/home/nikhil/med-lm-envs/plots")
output_dir.mkdir(exist_ok=True)
plt.savefig(output_dir / 'medcalc_training.png')
plt.savefig(output_dir / 'medcalc_training.pdf')
print(f"Saved to {output_dir / 'medcalc_training.png'}")
print(f"Saved to {output_dir / 'medcalc_training.pdf'}")

# Print summary
print("\nTraining Summary:")
print(f" Start accuracy: {start_acc:.1%}")
print(f" End accuracy: {end_acc:.1%}")
print(f" Improvement: +{improvement:.1%}")
print(f" Steps: {data['Step'].iloc[-1]}")
# README template rendered per-repo via str.format() in main(). Placeholders:
# {repo_name}, {description}, {training_data}, {repo_id}. Literal braces in the
# usage example are escaped as {{ }} so format() leaves them intact.
MODEL_CARD_TEMPLATE = """---
license: apache-2.0
base_model: Qwen/Qwen3-4B-Instruct-2507
tags:
- medical
- reinforcement-learning
- qwen3
- healthcare
---

# {repo_name}

{description}

## Model Details

- **Base Model**: [Qwen/Qwen3-4B-Instruct-2507](https://huggingface.co/Qwen/Qwen3-4B-Instruct-2507)
- **Training Method**: Reinforcement Learning (GRPO)
- **Framework**: [verifiers](https://github.com/willccbb/verifiers) + [prime-rl](https://github.com/PrimeIntellect-ai/prime-rl)

## Training Data

{training_data}

## Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("{repo_id}")
tokenizer = AutoTokenizer.from_pretrained("{repo_id}")

# Example usage
messages = [
    {{"role": "user", "content": "Your medical question here"}}
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=512)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

## License

Apache 2.0
"""
def main():
    """Render a model card for each entry in MODELS and upload it to the Hub.

    For every repo: format MODEL_CARD_TEMPLATE with that repo's metadata,
    write it as README.md inside the local weights directory, create the Hub
    repo if needed, then upload the whole folder.
    """
    api = HfApi()

    for repo_id, info in MODELS.items():
        print(f"\n{'='*60}")
        print(f"Uploading: {repo_id}")
        print(f"From: {info['path']}")
        print(f"{'='*60}")

        # Create model card
        model_card = MODEL_CARD_TEMPLATE.format(
            repo_name=repo_id.split("/")[1],
            description=info["description"],
            training_data=TRAINING_DATA[repo_id],
            repo_id=repo_id,
        )

        # Write model card into the model directory so it uploads alongside
        # the weights. Explicit UTF-8 keeps output platform-independent.
        readme_path = os.path.join(info["path"], "README.md")
        with open(readme_path, "w", encoding="utf-8") as f:
            f.write(model_card)
        print("Created README.md")

        # Create the repo if it doesn't exist; exist_ok makes reruns idempotent.
        try:
            api.create_repo(repo_id, repo_type="model", exist_ok=True)
            print(f"Created/verified repo: {repo_id}")
        except Exception as e:
            # Best-effort: log and continue — upload_folder below surfaces
            # real authorization/connectivity failures.
            print(f"Repo creation note: {e}")

        # Upload the folder
        print("Uploading files...")
        upload_folder(
            folder_path=info["path"],
            repo_id=repo_id,
            repo_type="model",
            ignore_patterns=["STABLE"],  # Skip the STABLE marker file
        )
        print(f"✓ Successfully uploaded to https://huggingface.co/{repo_id}")

    print("\n" + "="*60)
    print("ALL UPLOADS COMPLETE!")
    print("="*60)
    for repo_id in MODELS:
        print(f" https://huggingface.co/{repo_id}")


if __name__ == "__main__":
    main()