diff --git a/environments/eval_environments/vision_evals/README.md b/environments/eval_environments/vision_evals/README.md new file mode 100644 index 000000000..9e6d79191 --- /dev/null +++ b/environments/eval_environments/vision_evals/README.md @@ -0,0 +1,216 @@ +# Vision Evaluation Benchmarks + +This folder contains 27 vision and multimodal benchmarks for evaluating vision-language models. The implementations follow VLMEvalKit patterns where applicable and use the Atropos Eval class for consistent async evaluation. + +## Benchmarks + +| Benchmark | What it Tests | Dataset | Scoring | +|-----------|---------------|---------|---------| +| MMMU | Multi-discipline academic QA | MMMU/MMMU | MCQ accuracy | +| MMMU-Pro | Harder MMMU with 10 choices | MMMU/MMMU_Pro | MCQ accuracy | +| MMBench | General multimodal understanding | lmms-lab/MMBench | MCQ accuracy | +| MMStar | Expert-level multimodal QA | Lin-Chen/MMStar | MCQ accuracy | +| MMVet | Open-ended VLM capabilities | lmms-lab/MMVet | GPT scoring | +| MMVP | CLIP blindspot detection | MMVP/MMVP | MCQ accuracy | +| AI2D | Scientific diagram understanding | lmms-lab/ai2d | MCQ accuracy | +| BLINK | Visual perception tasks | BLINK-Benchmark/BLINK | MCQ accuracy | +| ChartQA | Chart question answering | ahmed-masry/ChartQA | Relaxed accuracy | +| CharXiv | Scientific chart understanding | princeton-nlp/CharXiv | GPT judge | +| CountBench | Object counting | nielsr/countbench | Numeric match | +| DocVQA | Document understanding | lmms-lab/DocVQA | ANLS score | +| DynaMath | Dynamic math reasoning | DynaMath/DynaMath_Sample | JSON extraction | +| HallusionBench | Visual hallucination detection | lmms-lab/HallusionBench | Yes/No accuracy | +| InfoVQA | Infographic QA | lmms-lab/InfoVQA | ANLS score | +| LogicVista | Visual logical reasoning | Yijia-Xiao/LogicVista | GPT extraction | +| MathVerse | Visual math (multi-version) | AI4Math/MathVerse | GPT extraction + scoring | +| MathVision | Visual math problems | MathLLMs/MathVision | GPT extraction | +| MathVista | Visual math reasoning | AI4Math/MathVista | GPT extraction | +| MMT-Bench | Multi-task multimodal | OpenGVLab/MMT-Bench | MCQ accuracy | +| OCRBench | OCR capabilities | echo840/OCRBench | Substring match | +| POPE | Object hallucination | lmms-lab/POPE | Yes/No accuracy | +| RealWorldQA | Real-world visual QA | xai-org/RealworldQA | Fuzzy match | +| SEED-Bench2 | Visual understanding | lmms-lab/SEED-Bench-2 | MCQ accuracy | +| VisuLogic | Visual logic puzzles | visulogic dataset | MCQ accuracy | +| VLMBlind | Basic visual perception | XAI/vlmsareblind | Task-specific | +| WeMath | Visual math with 4D metrics | We-Math/We-Math | 4D scoring | + +## Running an Evaluation + +All benchmarks use the same CLI pattern: + +```bash +python mmmu_environment.py \ + --model-name "gpt-4o" \ + --server-url "https://api.openai.com/v1" +``` + +For local models with vLLM or Ollama: + +```bash +python mmbench_environment.py \ + --model-name "Qwen/Qwen2-VL-7B-Instruct" \ + --server-url "http://localhost:8000/v1" +``` + +The evaluations use the `ServerManager` from atroposlib for making API calls. + +## Comparison with VLMEvalKit + +These implementations are aligned with VLMEvalKit where it makes sense, but simplified for standalone use. Here are the key differences and similarities: + +### Scoring Methods + +**ChartQA** uses relaxed accuracy with 5% tolerance. Percentages are converted to decimals before comparison (5% becomes 0.05). This matches VLMEvalKit behavior in `vqa_eval.py`. + +**DocVQA and InfoVQA** use ANLS (Average Normalized Levenshtein Similarity) with a 0.5 threshold. This is the standard metric from the original papers. + +**MathVista** uses GPT-based answer extraction with 5 in-context learning examples. There is a prefetch mechanism that tries regex first before calling GPT. The extraction prompt and ICL examples are taken from VLMEvalKit. + +**MathVerse** uses two-stage GPT evaluation. First GPT extracts the answer from the response, then GPT judges whether the extracted answer matches the ground truth. This matches the VLMEvalKit approach. + +**WeMath** computes 4-dimensional metrics beyond simple accuracy: +- IK (Insufficient Knowledge): wrong on steps AND wrong on multi +- IG (Inadequate Generalization): right on steps BUT wrong on multi +- CM (Complete Mastery): right on steps AND right on multi +- RM (Rote Memorization): wrong on steps BUT right on multi + +**MMVet** uses GPT to score open-ended responses on a 0-1 scale. Without an API key it falls back to substring matching. + +**OCRBench** uses category-specific scoring. For handwritten math expressions it compares without spaces. For other categories it does case-insensitive substring matching. + +### What We Changed + +**Simpler data loading**: We use HuggingFace datasets directly instead of VLMEvalKit's TSV preprocessing. This makes the code easier to understand but may load data slightly differently. + +**Async evaluation**: Everything runs async with tqdm progress bars. VLMEvalKit uses synchronous evaluation by default. + +**No circular evaluation**: VLMEvalKit supports "circular" MCQ evaluation where options are rotated and the model must get all rotations correct. We do not implement this, which means our MCQ scores may be slightly higher than VLMEvalKit on some benchmarks. + +**Unified CLI**: All benchmarks use the same `eval_runner` CLI instead of VLMEvalKit's `run.py` with config files. + +### Expected Score Differences + +Due to the differences above, you should expect: + +- MCQ benchmarks (MMMU, MMBench, MMStar, AI2D): Within 1-2% of VLMEvalKit +- VQA benchmarks (DocVQA, ChartQA): Very close, same scoring methods +- Math benchmarks (MathVista, MathVerse): Within 2-3%, depends on GPT extraction +- Open-ended (MMVet): Can vary more, depends on GPT judge prompts + +## Benchmark Details + +### General Multimodal Understanding + +**MMMU** tests multi-discipline academic knowledge across 30 subjects from accounting to physics. Questions require understanding images and domain knowledge. The validation split has about 900 questions. + +**MMMU-Pro** is a harder version with 10 answer choices instead of 4. It has three variants: standard (10 options), standard_4 (4 options), and vision (question in image). + +**MMBench** is a comprehensive benchmark covering perception, reasoning, and knowledge. It has English and Chinese versions. + +**MMStar** focuses on expert-level questions that require both visual understanding and specialized knowledge. + +**SEED-Bench2** tests visual understanding across many categories including scene understanding, instance identity, and spatial relations. The dataset is large (24k samples) so we stream by default and limit to 1000 samples. + +**MMT-Bench** is a multi-task benchmark covering 32 different task types. Good for testing breadth of capabilities. + +### Document and Chart Understanding + +**DocVQA** tests understanding of document images like forms, receipts, and scientific papers. Uses ANLS scoring which allows for minor OCR errors. + +**InfoVQA** is similar to DocVQA but focuses on infographics with more complex layouts. + +**ChartQA** tests chart reading. Has human and augmented subsets. The human subset is harder. Uses relaxed accuracy (5% tolerance for numbers). + +**CharXiv** focuses on scientific charts from arXiv papers. Uses GPT as a judge with grading queries from the dataset. + +**OCRBench** tests pure OCR capabilities across 10 categories from regular text to handwritten math expressions. + +### Math and Reasoning + +**MathVista** is a visual math benchmark with multiple question types (free form, multiple choice) and answer types (integer, float, text, list). Uses the dataset's built-in query prompts. + +**MathVerse** has problems at different visual complexity levels from "text dominant" to "vision only". Uses two-stage GPT evaluation. + +**MathVision** is another visual math benchmark. Uses GPT extraction with fallback to regex. + +**DynaMath** tests dynamic math reasoning with JSON-formatted outputs. Has subject and difficulty level breakdowns. + +**WeMath** provides detailed 4D metrics to understand where models fail. Good for diagnosing reasoning vs memorization issues. + +**LogicVista** tests visual logical reasoning with 5 skill types. Supports multi-letter answers where multiple options can be correct. + +**VisuLogic** tests visual logic with diagram-based puzzles. + +### Perception and Hallucination + +**POPE** tests object hallucination with yes/no questions about whether objects exist in images. Has random, popular, and adversarial variants. + +**HallusionBench** tests visual hallucinations more broadly. Questions are designed to trick models into seeing things that are not there. + +**MMVP** tests visual perception on cases where CLIP-based models tend to fail. Useful for understanding encoder limitations. + +**BLINK** tests basic visual perception like counting, spatial relations, and similarity. Models often struggle on these "easy" tasks. + +**VLMBlind** (VLMs Are Blind) tests very basic visual tasks that humans find trivial but VLMs often fail. Includes counting grid cells, finding circled letters, and counting Olympic rings. + +**CountBench** is a simple object counting benchmark. + +### Real World + +**RealWorldQA** tests understanding of real-world images from XAI. Uses fuzzy matching for answers. + +**AI2D** tests understanding of scientific diagrams from AI2 (Allen Institute). Good for testing diagram reasoning. + +## GPT Judge Configuration + +Several benchmarks use GPT for answer extraction or scoring. To enable this: + +```bash +export OPENAI_API_KEY="your-key" +``` + +You can also configure the judge model when instantiating: + +```python +eval_env = MathVista( + use_gpt_extraction=True, + judge_model="gpt-4o-mini", + judge_base_url="https://api.openai.com/v1", +) +asyncio.run(eval_runner(eval_env)) +``` + +Without an API key, benchmarks fall back to regex-based extraction which is less accurate but free. + +## Output Format + +Results are saved to the eval directory: + +``` +eval_results/ + metrics.json # Overall scores + samples.jsonl # Per-item predictions +``` + +The metrics.json file contains accuracy and other metrics depending on the benchmark. The samples.jsonl file has one line per question with the prediction, answer, and whether it was correct. + +## Adding New Benchmarks + +To add a new vision benchmark: + +1. Create a new file like `new_benchmark_environment.py` +2. Inherit from `EvalBase` +3. Implement `setup_data()` to load the dataset +4. Implement `run_item(self, server: ServerManager, data_item: dict)` to process one item +5. Use `await self.chat_completion(server, messages)` for API calls +6. Add image encoding using the standard `encode_image()` pattern + +See any existing benchmark for a template. The MMMU implementation is a good starting point for MCQ benchmarks. DocVQA is a good template for VQA benchmarks. + +## References + +- VLMEvalKit: https://github.com/open-compass/VLMEvalKit +- OpenVLM Leaderboard: https://huggingface.co/spaces/opencompass/open_vlm_leaderboard +- MMMU: https://mmmu-benchmark.github.io/ +- MathVista: https://mathvista.github.io/ +- DocVQA: https://www.docvqa.org/ diff --git a/environments/eval_environments/vision_evals/ai2d_environment.py b/environments/eval_environments/vision_evals/ai2d_environment.py new file mode 100644 index 000000000..ea85e0ec4 --- /dev/null +++ b/environments/eval_environments/vision_evals/ai2d_environment.py @@ -0,0 +1,159 @@ +"""AI2D (AI2 Diagrams) evaluation environment.""" + +import asyncio +import base64 +import io +from string import ascii_uppercase +from typing import List, Optional, Tuple + +from datasets import load_dataset +from environments.eval_environments.eval import EvalBase, eval_runner +from PIL import Image + +from atroposlib.envs.server_handling.server_manager import ServerManager +from environments.eval_environments.eval_helpers import ( + extract_letter_from_answer_tag, + extract_mcqa_answer_with_fallback, +) + + +class AI2D(EvalBase): + """AI2D evaluation - diagram understanding benchmark.""" + + def setup_data(self) -> list: + split = getattr(self, "split", "test") + use_mask = getattr(self, "use_mask", True) + + try: + dataset = load_dataset("lmms-lab/ai2d", split=split) + print(f"Loaded {len(dataset)} examples from AI2D ({split})") + return list(dataset) + except Exception as e: + print(f"Warning: Could not load AI2D: {e}") + try: + dataset = load_dataset("allenai/ai2_diagrams", split=split) + print(f"Loaded {len(dataset)} examples from AI2D ({split})") + return list(dataset) + except Exception: + raise ValueError(f"Could not load AI2D dataset: {e}") + + def encode_image(self, pil_image: Image.Image) -> str: + buffer = io.BytesIO() + pil_image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + def get_image_base64(self, item: dict) -> Optional[str]: + for key in ["image", "decoded_image"]: + if key in item and item[key] is not None: + if isinstance(item[key], Image.Image): + return self.encode_image(item[key]) + return None + + def build_messages(self, item: dict) -> List[dict]: + image_base64 = self.get_image_base64(item) + question = item.get("question", "") + + choices = item.get("choices", []) + if isinstance(choices, str): + try: + choices = eval(choices) + except Exception: + choices = [] + + options = {} + if choices: + for i, choice in enumerate(choices): + options[ascii_uppercase[i]] = choice + else: + for letter in ascii_uppercase[:6]: + if letter in item and item[letter] is not None: + val = item[letter] + if isinstance(val, str) and val.strip(): + options[letter] = val + + prompt = f"Question: {question}\n" + if options: + prompt += "Options:\n" + for letter in sorted(options.keys()): + prompt += f"{letter}. {options[letter]}\n" + prompt += "\nPlease select the correct answer from the options above." + + content = [] + if image_base64: + content.append( + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{image_base64}"}, + } + ) + content.append({"type": "text", "text": prompt}) + + return [{"role": "user", "content": content}] + + def extract_answer( + self, response: str, num_choices: int + ) -> Tuple[Optional[str], str]: + valid_letters = set(ascii_uppercase[:num_choices]) + + letter, method = extract_letter_from_answer_tag(response, valid_letters) + if letter: + return letter, method + + letter, method = extract_mcqa_answer_with_fallback(response, num_choices) + return letter, method + + async def run_item( + self, server: ServerManager, data_item: dict + ) -> Tuple[dict, dict]: + try: + messages = self.build_messages(data_item) + completion = await self.chat_completion(server, messages) + + if not completion.choices: + return {"accuracy": 0.0}, {"error": "Empty response"} + + message = completion.choices[0].message + response = message.content or "" + + if not response: + return {"accuracy": 0.0}, {"error": "Empty response"} + + answer = data_item.get("answer", "") + + choices = data_item.get("choices", []) + if isinstance(choices, str): + try: + choices = eval(choices) + except Exception: + choices = [] + + num_choices = len(choices) if choices else 4 + + extracted, method = self.extract_answer(response, num_choices) + + correct = False + if extracted and answer: + if str(answer).isdigit(): + answer_letter = ascii_uppercase[int(answer)] + else: + answer_letter = str(answer).upper() + correct = extracted.upper() == answer_letter + + sample = { + "id": data_item.get("index", data_item.get("id", "")), + "question": data_item.get("question", "")[:200], + "answer": answer, + "prediction": extracted, + "raw_response": response[:500], + "correct": correct, + "extraction_method": method, + } + + return {"accuracy": 1.0 if correct else 0.0}, sample + + except Exception as e: + return {"accuracy": 0.0}, {"error": str(e)} + + +if __name__ == "__main__": + asyncio.run(eval_runner(AI2D(split="test", temperature=0.0, max_tokens=256))) diff --git a/environments/eval_environments/vision_evals/blink_environment.py b/environments/eval_environments/vision_evals/blink_environment.py new file mode 100644 index 000000000..986b26abb --- /dev/null +++ b/environments/eval_environments/vision_evals/blink_environment.py @@ -0,0 +1,170 @@ +"""BLINK evaluation environment.""" + +import asyncio +import base64 +import io +from string import ascii_uppercase +from typing import List, Optional, Tuple + +from datasets import load_dataset +from environments.eval_environments.eval import EvalBase, eval_runner +from PIL import Image + +from atroposlib.envs.server_handling.server_manager import ServerManager +from environments.eval_environments.eval_helpers import ( + extract_letter_from_answer_tag, + extract_mcqa_answer_with_fallback, +) + + +class BLINK(EvalBase): + """BLINK evaluation - visual perception benchmark.""" + + def setup_data(self) -> list: + split = getattr(self, "split", "val") + task = getattr(self, "task", "Counting") # One of the BLINK task categories + + try: + dataset = load_dataset("BLINK-Benchmark/BLINK", task, split=split) + print(f"Loaded {len(dataset)} examples from BLINK ({split}, {task})") + return list(dataset) + except Exception as e: + print(f"Warning: Could not load BLINK: {e}") + try: + tasks = [ + "Counting", + "Spatial_Relation", + "Object_Localization", + "Visual_Similarity", + ] + all_data = [] + for t in tasks: + try: + ds = load_dataset("BLINK-Benchmark/BLINK", t, split=split) + for item in ds: + item["task"] = t + all_data.append(item) + except Exception: + pass + if all_data: + print(f"Loaded {len(all_data)} examples from BLINK ({split})") + return all_data + raise ValueError(f"Could not load BLINK dataset: {e}") + except Exception: + raise ValueError(f"Could not load BLINK dataset: {e}") + + def encode_image(self, pil_image: Image.Image) -> str: + buffer = io.BytesIO() + pil_image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + def get_images(self, item: dict) -> List[str]: + """Get all images from item (BLINK can have multiple images).""" + images = [] + for i in range(1, 5): + key = f"image_{i}" + if key in item and item[key] is not None: + if isinstance(item[key], Image.Image): + images.append(self.encode_image(item[key])) + + if not images and "image" in item and item["image"] is not None: + if isinstance(item["image"], Image.Image): + images.append(self.encode_image(item["image"])) + + return images + + def build_messages(self, item: dict) -> List[dict]: + images = self.get_images(item) + question = item.get("question", "") + + options = {} + for letter in ascii_uppercase[:6]: + if letter in item and item[letter] is not None: + val = item[letter] + if isinstance(val, str) and val.strip(): + options[letter] = val + + prompt = f"Question: {question}\n" + if options: + prompt += "Options:\n" + for letter in sorted(options.keys()): + prompt += f"{letter}. {options[letter]}\n" + prompt += "\nPlease select the correct answer from the options above." + + content = [] + for img_b64 in images: + content.append( + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{img_b64}"}, + } + ) + content.append({"type": "text", "text": prompt}) + + return [{"role": "user", "content": content}] + + def extract_answer( + self, response: str, num_choices: int + ) -> Tuple[Optional[str], str]: + valid_letters = set(ascii_uppercase[:num_choices]) + + letter, method = extract_letter_from_answer_tag(response, valid_letters) + if letter: + return letter, method + + letter, method = extract_mcqa_answer_with_fallback(response, num_choices) + return letter, method + + async def run_item( + self, server: ServerManager, data_item: dict + ) -> Tuple[dict, dict]: + try: + messages = self.build_messages(data_item) + completion = await self.chat_completion(server, messages) + + if not completion.choices: + return {"accuracy": 0.0}, {"error": "Empty response"} + + message = completion.choices[0].message + response = message.content or "" + + if not response: + return {"accuracy": 0.0}, {"error": "Empty response"} + + answer = data_item.get("answer", "") + + num_choices = sum( + 1 + for letter in ascii_uppercase[:6] + if letter in data_item + and data_item[letter] is not None + and isinstance(data_item[letter], str) + and data_item[letter].strip() + ) + num_choices = max(num_choices, 4) + + extracted, method = self.extract_answer(response, num_choices) + + correct = False + if extracted and answer: + correct = extracted.upper() == str(answer).upper() + + sample = { + "id": data_item.get("index", data_item.get("id", "")), + "question": data_item.get("question", "")[:200], + "category": data_item.get("category", ""), + "answer": answer, + "prediction": extracted, + "raw_response": response[:500], + "correct": correct, + "extraction_method": method, + } + + return {"accuracy": 1.0 if correct else 0.0}, sample + + except Exception as e: + return {"accuracy": 0.0}, {"error": str(e)} + + +if __name__ == "__main__": + asyncio.run(eval_runner(BLINK(split="val", temperature=0.0, max_tokens=256))) diff --git a/environments/eval_environments/vision_evals/chartqa_environment.py b/environments/eval_environments/vision_evals/chartqa_environment.py new file mode 100644 index 000000000..8d44a970d --- /dev/null +++ b/environments/eval_environments/vision_evals/chartqa_environment.py @@ -0,0 +1,192 @@ +"""ChartQA evaluation environment.""" + +import asyncio +import base64 +import io +import re +from pathlib import Path +from typing import List, Optional, Tuple + +from datasets import load_dataset +from environments.eval_environments.eval import EvalBase, eval_runner +from PIL import Image + +from atroposlib.envs.server_handling.server_manager import ServerManager + + +class ChartQA(EvalBase): + """ + ChartQA evaluation environment. + + A benchmark for question answering about charts with relaxed accuracy scoring. + """ + + def setup_data(self) -> list: + subset = getattr(self, "subset", "human") + dataset = load_dataset("ahmed-masry/ChartQA", split="test") + + if subset == "human": + dataset = dataset.filter(lambda x: x.get("type", "") == "human") + elif subset == "augmented": + dataset = dataset.filter(lambda x: x.get("type", "") == "augmented") + + print(f"Loaded {len(dataset)} examples from ChartQA ({subset})") + return list(dataset) + + def encode_image(self, pil_image: Image.Image) -> str: + buffer = io.BytesIO() + pil_image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + def get_image_base64(self, item: dict) -> str: + images_path: Optional[str] = getattr(self, "images_path", None) + if images_path: + imgname = item.get("imgname", "") + image_path = Path(images_path) / imgname + with open(image_path, "rb") as f: + return base64.b64encode(f.read()).decode("utf-8") + + if "image" in item and item["image"] is not None: + img = item["image"] + if isinstance(img, bytes): + return base64.b64encode(img).decode("utf-8") + elif isinstance(img, Image.Image): + return self.encode_image(img) + else: + raise ValueError(f"Unknown image type: {type(img)}") + + raise ValueError("Could not find image for item") + + def build_messages(self, item: dict) -> List[dict]: + image_base64 = self.get_image_base64(item) + query = item.get("query", item.get("question", "")) + + prompt = f"""Answer this question about the chart. Provide only the answer, nothing else. + +Question: {query}""" + + return [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{image_base64}"}, + }, + {"type": "text", "text": prompt}, + ], + } + ] + + def extract_answer(self, response: str) -> str: + response = response.strip() + + patterns = [ + r"(?:the answer is|answer:)\s*(.+?)(?:\.|$)", + r"^(\d+[\d,\.]*%?)$", + r"^(yes|no)$", + ] + + for pattern in patterns: + match = re.search(pattern, response, re.IGNORECASE) + if match: + return match.group(1).strip() + + if len(response.split()) <= 5: + return response + + first_line = response.split("\n")[0] + return first_line.strip() + + def _to_float(self, text: str) -> Optional[float]: + """ + Convert string to float, handling percentages. + + Following VLMEvalKit: percentages are converted to decimals (5% -> 0.05). + """ + text = str(text).strip() + try: + # Remove commas and dollar signs + text = text.replace(",", "").replace("$", "") + if text.endswith("%"): + # Convert percentage to decimal (VLMEvalKit behavior) + return float(text.rstrip("%")) / 100.0 + else: + return float(text) + except ValueError: + return None + + def score_relaxed(self, prediction: str, answer: str) -> bool: + """ + Calculate relaxed correctness following VLMEvalKit. + + For numeric answers: allows 5% relative tolerance. + For non-numeric answers: exact match (case-insensitive). + + Reference: https://arxiv.org/pdf/2203.10244.pdf, section 5.1 + """ + pred = str(prediction).strip() + ans = str(answer).strip() + + relaxed_tolerance = getattr(self, "relaxed_tolerance", 0.05) + + pred_float = self._to_float(pred) + ans_float = self._to_float(ans) + + if pred_float is not None and ans_float is not None: + if ans_float == 0: + return abs(pred_float) < 1e-6 + relative_change = abs(pred_float - ans_float) / abs(ans_float) + return relative_change <= relaxed_tolerance + + # Non-numeric: exact match (case-insensitive) + return pred.lower() == ans.lower() + + async def run_item( + self, server: ServerManager, data_item: dict + ) -> Tuple[dict, dict]: + try: + messages = self.build_messages(data_item) + + completion = await self.chat_completion(server, messages) + + if not completion.choices: + return {"accuracy": 0.0}, {"error": "Empty response"} + + message = completion.choices[0].message + response = message.content or "" + if hasattr(message, "reasoning") and message.reasoning and not response: + response = message.reasoning + if not response and hasattr(message, "model_extra"): + reasoning = message.model_extra.get("reasoning", "") + if reasoning: + response = reasoning + + if not response: + return {"accuracy": 0.0}, {"error": "Empty response"} + + extracted = self.extract_answer(response) + answer = data_item.get("label", data_item.get("answer", "")) + correct = self.score_relaxed(extracted, answer) + + sample = { + "question": data_item.get("query", data_item.get("question", "")), + "answer": answer, + "prediction": extracted, + "correct": correct, + } + + return {"accuracy": 1.0 if correct else 0.0}, sample + + except Exception as e: + return {"accuracy": 0.0}, {"error": str(e)} + + +if __name__ == "__main__": + asyncio.run( + eval_runner( + ChartQA( + subset="human", relaxed_tolerance=0.05, temperature=0.0, max_tokens=2048 + ) + ) + ) diff --git a/environments/eval_environments/vision_evals/charxiv_environment.py b/environments/eval_environments/vision_evals/charxiv_environment.py new file mode 100644 index 000000000..bacb16b12 --- /dev/null +++ b/environments/eval_environments/vision_evals/charxiv_environment.py @@ -0,0 +1,382 @@ +"""CharXiv evaluation environment.""" + +import asyncio +import base64 +import io +import json +import os +import re +from typing import Dict, List, Optional, Tuple + +import openai +from datasets import load_dataset +from environments.eval_environments.eval import EvalBase, eval_runner +from PIL import Image + +from atroposlib.envs.server_handling.server_manager import ServerManager + +DESCRIPTIVE_CATEGORIES = { + 1: "Information Extraction", + 2: "Information Extraction", + 3: "Information Extraction", + 4: "Information Extraction", + 5: "Information Extraction", + 6: "Information Extraction", + 7: "Information Extraction", + 8: "Enumeration", + 9: "Enumeration", + 10: "Counting", + 11: "Pattern Recognition", + 12: "Counting", + 13: "Enumeration", + 14: "Enumeration", + 15: "Enumeration", + 16: "Pattern Recognition", + 17: "Compositionality", + 18: "Pattern Recognition", + 19: "Counting", +} + +REASONING_CATEGORIES = { + 1: "Text-in-Chart", + 2: "Text-in-General", + 3: "Number-in-Chart", + 4: "Number-in-General", +} + +DESCRIPTIVE_QUESTIONS = { + 1: "What is the title of the chart?", + 2: "What is the label of the x-axis?", + 3: "What is the label of the y-axis?", + 4: "What is the leftmost labeled tick on the x-axis?", + 5: "What is the rightmost labeled tick on the x-axis?", + 6: "What is the spatially lowest labeled tick on the y-axis?", + 7: "What is the spatially highest labeled tick on the y-axis?", + 8: "What are all the labels in the legend?", + 9: "List all the categories in the x-axis.", + 10: "How many distinct bars are there?", + 11: "Does the chart contain a grid?", + 12: "How many lines are there in the chart?", + 13: "Is there a legend in the chart?", + 14: "What are the names of the curves in the chart?", + 15: "Does the chart contain horizontal bars?", + 16: "Do the bars have error bars?", + 17: "Describe the general trend of the chart.", + 18: "Is there any point emphasized/highlighted in the chart?", + 19: "How many sections does the pie chart have?", +} + +GRADING_QUERY_TEMPLATE = """You are evaluating a model's answer to a chart understanding question. + +Question: {question} +Ground Truth Answer: {answer} +Model's Answer: {prediction} + +Please evaluate whether the model's answer is correct or partially correct. +Consider semantic equivalence - different phrasings that mean the same thing should be considered correct. +For numerical answers, exact matches or very close values should be considered correct. +For yes/no questions, the meaning should match the ground truth. +For enumeration questions (listing items), all items should be present regardless of order. + +Respond with a JSON object containing: +- "extract_answer": The key answer extracted from the model's response +- "score": A float from 0.0 to 1.0 indicating correctness (0.0 = wrong, 0.5 = partial, 1.0 = correct) + +Example response: {{"extract_answer": "60", "score": 1.0}}""" + + +class CharXiv(EvalBase): + MODES = ["descriptive", "reasoning"] + + def setup_data(self) -> list: + mode = getattr(self, "mode", "descriptive") + split = getattr(self, "split", "validation") + + dataset = load_dataset("princeton-nlp/CharXiv", "default", split=split) + + data = [] + for item in dataset: + if mode == "descriptive": + for i in range(1, 5): + q_key = f"descriptive_q{i}" + a_key = f"descriptive_a{i}" + if a_key in item and item[a_key]: + template_id = item.get(q_key, i) + if ( + isinstance(template_id, int) + and template_id in DESCRIPTIVE_QUESTIONS + ): + question = DESCRIPTIVE_QUESTIONS[template_id] + else: + question = ( + str(template_id) + if template_id + else f"Descriptive question {i}" + ) + + data.append( + { + "image": item["image"], + "question": question, + "answer": item[a_key], + "qid": ( + template_id if isinstance(template_id, int) else i + ), + "category": item.get("category", ""), + "grading_query": item.get("grading_query", ""), + } + ) + elif mode == "reasoning": + if "reasoning_q" in item and item.get("reasoning_a"): + data.append( + { + "image": item["image"], + "question": item["reasoning_q"], + "answer": item["reasoning_a"], + "inst_category": item.get("category", 1), + "grading_query": item.get("grading_query", ""), + } + ) + else: + raise ValueError( + f"Invalid mode: {mode}. Must be 'descriptive' or 'reasoning'." + ) + + print(f"Loaded {len(data)} examples from CharXiv ({mode}, {split})") + return data + + def encode_image(self, pil_image: Image.Image) -> str: + buffer = io.BytesIO() + pil_image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + def get_image_base64(self, item: dict) -> str: + if "image" in item and item["image"] is not None: + if isinstance(item["image"], Image.Image): + return self.encode_image(item["image"]) + raise ValueError("Could not find image for item") + + def build_messages(self, item: dict) -> List[dict]: + image_base64 = self.get_image_base64(item) + question = item.get("question", "") + + return [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{image_base64}"}, + }, + {"type": "text", "text": question}, + ], + } + ] + + async def _judge_with_gpt( + self, question: str, answer: str, prediction: str, grading_query: str = "" + ) -> Tuple[Optional[str], float]: + judge_model = getattr(self, "judge_model", "gpt-4o-mini") + judge_base_url = getattr(self, "judge_base_url", "https://api.openai.com/v1") + judge_api_key = os.environ.get( + getattr(self, "judge_api_key_env", "OPENAI_API_KEY"), "" + ) + + if not judge_api_key: + return None, 0.0 + + if grading_query: + prompt = grading_query.replace("{PREDICTION}", prediction) + else: + prompt = GRADING_QUERY_TEMPLATE.format( + question=question, answer=answer, prediction=prediction + ) + + try: + judge_client = openai.AsyncOpenAI( + api_key=judge_api_key, + base_url=judge_base_url, + ) + + completion = await judge_client.chat.completions.create( + model=judge_model, + messages=[{"role": "user", "content": prompt}], + temperature=0.0, + max_tokens=256, + ) + + response = completion.choices[0].message.content.strip() + + try: + result = json.loads(response) + if isinstance(result, dict): + extract_answer = result.get("extract_answer", "") + score = float(result.get("score", 0.0)) + return extract_answer, score + except json.JSONDecodeError: + json_match = re.search(r"\{[^}]+\}", response) + if json_match: + try: + result = json.loads(json_match.group()) + extract_answer = result.get("extract_answer", "") + score = float(result.get("score", 0.0)) + return extract_answer, score + except (json.JSONDecodeError, ValueError): + pass + + return None, 0.0 + + except Exception as e: + print(f"GPT judge error: {e}") + return None, 0.0 + + def _fallback_score( + self, prediction: str, answer: str, mode: str + ) -> Tuple[str, float]: + prediction = prediction.strip().lower() + answer = answer.strip().lower() + + if not prediction: + return "", 0.0 + + if mode == "reasoning": + if answer in prediction: + return prediction, 1.0 + try: + pred_nums = re.findall(r"-?\d+\.?\d*", prediction) + ans_nums = re.findall(r"-?\d+\.?\d*", answer) + if pred_nums and ans_nums: + for p in pred_nums: + for a in ans_nums: + if abs(float(p) - float(a)) < 0.01: + return prediction, 1.0 + except ValueError: + pass + return prediction, 0.0 + + else: + pred_words = set(prediction.split()) + ans_words = set(answer.split()) + if not ans_words: + return prediction, 0.0 + overlap = len(pred_words & ans_words) / len(ans_words) + return prediction, min(overlap, 1.0) + + def get_category(self, item: dict, mode: str) -> str: + if mode == "descriptive": + qid = item.get("qid", 1) + return DESCRIPTIVE_CATEGORIES.get(qid, "Information Extraction") + else: + inst_category = item.get("inst_category", 1) + return REASONING_CATEGORIES.get(inst_category, "Text-in-Chart") + + async def run_item( + self, server: ServerManager, data_item: dict + ) -> Tuple[dict, dict]: + try: + messages = self.build_messages(data_item) + mode = getattr(self, "mode", "descriptive") + + completion = await self.chat_completion(server, messages) + + if not completion.choices: + return {"accuracy": 0.0, "score": 0.0}, {"error": "Empty response"} + + message = completion.choices[0].message + response = message.content or "" + if hasattr(message, "reasoning") and message.reasoning and not response: + response = message.reasoning + if not response and hasattr(message, "model_extra"): + reasoning = message.model_extra.get("reasoning", "") + if reasoning: + response = reasoning + + if not response: + return {"accuracy": 0.0, "score": 0.0}, {"error": "Empty response"} + + use_gpt_judge = getattr(self, "use_gpt_judge", True) + grading_query = data_item.get("grading_query", "") + answer = data_item.get("answer", "") + question = data_item.get("question", "") + + if use_gpt_judge: + extracted, score = await self._judge_with_gpt( + question, answer, response, grading_query + ) + evaluation_method = "gpt_judge" + else: + extracted, score = self._fallback_score(response, answer, mode) + evaluation_method = "fallback" + + if extracted is None: + extracted, score = self._fallback_score(response, answer, mode) + evaluation_method = "fallback" + + category = self.get_category(data_item, mode) + + sample = { + "question": data_item.get("question", ""), + "answer": answer, + "prediction": response[:500], + "extract_answer": extracted, + "score": score, + "category": category, + "mode": mode, + "qid": data_item.get("qid", data_item.get("inst_category", "")), + "evaluation_method": evaluation_method, + } + + return {"accuracy": score, "score": score}, sample + + except Exception as e: + return {"accuracy": 0.0, "score": 0.0}, {"error": str(e)} + + +def compute_category_metrics(samples: List[dict]) -> Dict: + from collections import defaultdict + + scores_by_category = defaultdict(list) + + for sample in samples: + if "error" in sample: + continue + category = sample.get("category", "Unknown") + score = sample.get("score", 0.0) + scores_by_category[category].append(score) + + metrics = {} + total_score = 0.0 + total_count = 0 + + for category, scores in scores_by_category.items(): + if scores: + avg_score = sum(scores) / len(scores) + metrics[category] = { + "count": len(scores), + "average_score": avg_score, + } + total_score += sum(scores) + total_count += len(scores) + + if total_count > 0: + metrics["Overall"] = { + "count": total_count, + "average_score": total_score / total_count, + } + + return metrics + + +if __name__ == "__main__": + asyncio.run( + eval_runner( + CharXiv( + mode="descriptive", # or "reasoning" + split="validation", + use_gpt_judge=True, + judge_model="gpt-4o-mini", + temperature=0.0, + max_tokens=1024, + ) + ) + ) diff --git a/environments/eval_environments/vision_evals/countbench_environment.py b/environments/eval_environments/vision_evals/countbench_environment.py new file mode 100644 index 000000000..1fc673ab1 --- /dev/null +++ b/environments/eval_environments/vision_evals/countbench_environment.py @@ -0,0 +1,137 @@ +"""CountBench evaluation environment.""" + +import asyncio +import base64 +import io +import re +from typing import List, Optional, Tuple + +from datasets import load_dataset +from environments.eval_environments.eval import EvalBase, eval_runner +from PIL import Image + +from atroposlib.envs.server_handling.server_manager import ServerManager + + +class CountBench(EvalBase): + """CountBench evaluation - object counting benchmark.""" + + def setup_data(self) -> list: + split = getattr(self, "split", "train") # CountBench only has train split + + try: + dataset = load_dataset("nielsr/countbench", split=split) + print(f"Loaded {len(dataset)} examples from CountBench ({split})") + return list(dataset) + except Exception as e: + print(f"Warning: Could not load CountBench: {e}") + try: + # Try train split explicitly + dataset = load_dataset("nielsr/countbench", split="train") + print(f"Loaded {len(dataset)} examples from CountBench (train)") + return list(dataset) + except Exception: + try: + dataset = load_dataset( + "google-research/countbenchqa", split="train" + ) + print(f"Loaded {len(dataset)} examples from CountBench (train)") + return list(dataset) + except Exception: + raise ValueError(f"Could not load CountBench dataset: {e}") + + def encode_image(self, pil_image: Image.Image) -> str: + buffer = io.BytesIO() + pil_image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + def get_image_base64(self, item: dict) -> Optional[str]: + for key in ["image", "decoded_image"]: + if key in item and item[key] is not None: + if isinstance(item[key], Image.Image): + return self.encode_image(item[key]) + return None + + def build_messages(self, item: dict) -> List[dict]: + image_base64 = self.get_image_base64(item) + question = item.get("question", "") + + prompt = f"{question}\n\nNote: Answer with a number directly, e.g., 3. Do not include any additional text." + + content = [] + if image_base64: + content.append( + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{image_base64}"}, + } + ) + content.append({"type": "text", "text": prompt}) + + return [{"role": "user", "content": content}] + + def extract_number(self, response: str) -> Optional[str]: + """Extract a number from the response.""" + numbers = re.findall(r"\b(\d+)\b", response) + if numbers: + return numbers[0] + return None + + def score(self, prediction: str, answer: str) -> bool: + """Score counting answer - check if answer appears in prediction.""" + answer_str = str(answer).strip() + + if answer_str in prediction: + return True + + extracted = self.extract_number(prediction) + if extracted and extracted == answer_str: + return True + + try: + pred_num = int(self.extract_number(prediction) or prediction.strip()) + ans_num = int(answer_str) + return pred_num == ans_num + except (ValueError, TypeError): + pass + + return False + + async def run_item( + self, server: ServerManager, data_item: dict + ) -> Tuple[dict, dict]: + try: + messages = self.build_messages(data_item) + completion = await self.chat_completion(server, messages) + + if not completion.choices: + return {"accuracy": 0.0}, {"error": "Empty response"} + + message = completion.choices[0].message + response = message.content or "" + + if not response: + return {"accuracy": 0.0}, {"error": "Empty response"} + + answer = data_item.get("answer", data_item.get("number", "")) + + correct = self.score(response, answer) + extracted = self.extract_number(response) + + sample = { + "id": data_item.get("index", data_item.get("id", "")), + "question": data_item.get("question", "")[:200], + "answer": answer, + "prediction": extracted or response[:50], + "raw_response": response[:200], + "correct": correct, + } + + return {"accuracy": 1.0 if correct else 0.0}, sample + + except Exception as e: + return {"accuracy": 0.0}, {"error": str(e)} + + +if __name__ == "__main__": + asyncio.run(eval_runner(CountBench(split="test", temperature=0.0, max_tokens=64))) diff --git a/environments/eval_environments/vision_evals/docvqa_environment.py b/environments/eval_environments/vision_evals/docvqa_environment.py new file mode 100644 index 000000000..ff341cda3 --- /dev/null +++ b/environments/eval_environments/vision_evals/docvqa_environment.py @@ -0,0 +1,192 @@ +import asyncio +import base64 +import io +import re +from typing import List, Tuple + +from datasets import load_dataset +from environments.eval_environments.eval import EvalBase, eval_runner +from PIL import Image + +from atroposlib.envs.server_handling.server_manager import ServerManager + + +class DocVQA(EvalBase): + QUESTION_TYPES = [ + "figure/diagram", + "layout", + "table/list", + "Image/Photo", + "handwritten", + "form", + "free_text", + "others", + ] + + def setup_data(self) -> list: + # Note: test split has hidden answers (for server evaluation) + # Use validation for local evaluation + split = getattr(self, "split", "validation") + dataset = load_dataset("lmms-lab/DocVQA", "DocVQA", split=split) + print(f"Loaded {len(dataset)} examples from DocVQA ({split})") + return list(dataset) + + def encode_image(self, pil_image: Image.Image) -> str: + buffer = io.BytesIO() + pil_image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + def get_image_base64(self, item: dict) -> str: + if "image" in item and item["image"] is not None: + if isinstance(item["image"], Image.Image): + return self.encode_image(item["image"]) + raise ValueError( + f"Could not find image for item {item.get('questionId', 'unknown')}" + ) + + def build_messages(self, item: dict) -> List[dict]: + image_base64 = self.get_image_base64(item) + question = item.get("question", "") + + prompt = f"""Look at the document and answer the question. + +Question: {question} + +Provide only the answer, as concisely as possible.""" + + return [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{image_base64}"}, + }, + {"type": "text", "text": prompt}, + ], + } + ] + + def extract_answer(self, response: str) -> str: + response = response.strip() + + patterns = [ + r"answer[:\s]+(.+?)(?:\.|$)", + r"\"([^\"]+)\"", + ] + + for pattern in patterns: + match = re.search(pattern, response, re.IGNORECASE) + if match: + return match.group(1).strip() + + lines = response.split("\n") + if lines: + return lines[-1].strip() + + return response + + def normalize_text(self, text: str) -> str: + text = text.lower().strip() + text = re.sub(r"[^\w\s]", "", text) + text = " ".join(text.split()) + return text + + def anls_score( + self, prediction: str, answers: List[str], threshold: float = 0.5 + ) -> float: + """ + Calculate Average Normalized Levenshtein Similarity (ANLS). + This is the standard metric for DocVQA. + """ + pred_norm = self.normalize_text(prediction) + + if not pred_norm: + return 0.0 + + max_score = 0.0 + for answer in answers: + ans_norm = self.normalize_text(answer) + if not ans_norm: + continue + + if pred_norm == ans_norm: + max_score = 1.0 + break + + distance = self._levenshtein_distance(pred_norm, ans_norm) + max_len = max(len(pred_norm), len(ans_norm)) + nls = 1 - distance / max_len if max_len > 0 else 0 + + if nls >= threshold: + max_score = max(max_score, nls) + + return max_score + + def _levenshtein_distance(self, s1: str, s2: str) -> int: + if len(s1) < len(s2): + return self._levenshtein_distance(s2, s1) + + if len(s2) == 0: + return len(s1) + + previous_row = range(len(s2) + 1) + for i, c1 in enumerate(s1): + current_row = [i + 1] + for j, c2 in enumerate(s2): + insertions = previous_row[j + 1] + 1 + deletions = current_row[j] + 1 + substitutions = previous_row[j] + (c1 != c2) + current_row.append(min(insertions, deletions, substitutions)) + previous_row = current_row + + return previous_row[-1] + + async def run_item( + self, server: ServerManager, data_item: dict + ) -> Tuple[dict, dict]: + try: + messages = self.build_messages(data_item) + completion = await self.chat_completion(server, messages) + + if not completion.choices: + return {"accuracy": 0.0, "anls": 0.0}, {"error": "Empty response"} + + message = completion.choices[0].message + response = message.content or "" + if hasattr(message, "reasoning") and message.reasoning and not response: + response = message.reasoning + if not response and hasattr(message, "model_extra"): + reasoning = message.model_extra.get("reasoning", "") + if reasoning: + response = reasoning + + if not response: + return {"accuracy": 0.0, "anls": 0.0}, {"error": "Empty response"} + + extracted = self.extract_answer(response) + answers = data_item.get("answers", []) + if isinstance(answers, str): + answers = [answers] + + anls = self.anls_score(extracted, answers) + correct = anls >= 0.5 + + sample = { + "questionId": data_item.get("questionId", ""), + "question": data_item.get("question", ""), + "answers": answers, + "prediction": extracted, + "anls": anls, + "correct": correct, + "question_types": data_item.get("question_types", []), + } + + return {"accuracy": 1.0 if correct else 0.0, "anls": anls}, sample + + except Exception as e: + return {"accuracy": 0.0, "anls": 0.0}, {"error": str(e)} + + +if __name__ == "__main__": + asyncio.run(eval_runner(DocVQA(split="test", temperature=0.0, max_tokens=256))) diff --git a/environments/eval_environments/vision_evals/dynamath_environment.py b/environments/eval_environments/vision_evals/dynamath_environment.py new file mode 100644 index 000000000..f001bddcf --- /dev/null +++ b/environments/eval_environments/vision_evals/dynamath_environment.py @@ -0,0 +1,249 @@ +"""DynaMath evaluation environment.""" + +import asyncio +import base64 +import io +import json +import re +from string import ascii_uppercase +from typing import List, Optional, Tuple + +import numpy as np +from datasets import load_dataset +from environments.eval_environments.eval import EvalBase, eval_runner +from PIL import Image + +from atroposlib.envs.server_handling.server_manager import ServerManager + + +class DynaMath(EvalBase): + """DynaMath evaluation - dynamic mathematical reasoning benchmark.""" + + GUIDE = """ +## Answer Instruction +Please provide an answer to the question outlined above. Your response should adhere to the following JSON format, which includes two keys: 'solution' and 'short answer'. The 'solution' key can contain detailed steps needed to solve the question, and the 'short answer' key should provide a concise response. {INST} + +Example of expected JSON response format: + +{{ + "solution": "[Detailed step-by-step explanation]", + "short answer": "[Concise Answer]" +}} +""" + + def setup_data(self) -> list: + # DynaMath_Sample uses variant splits: sample_variant1, sample_variant2, etc. + split = getattr(self, "split", "sample_variant1") + + try: + # DynaMath_Sample is the publicly available dataset + dataset = load_dataset("DynaMath/DynaMath_Sample", split=split) + print(f"Loaded {len(dataset)} examples from DynaMath ({split})") + return list(dataset) + except Exception as e: + print(f"Warning: Could not load DynaMath: {e}") + try: + # Try sample_variant1 explicitly + dataset = load_dataset( + "DynaMath/DynaMath_Sample", split="sample_variant1" + ) + print(f"Loaded {len(dataset)} examples from DynaMath (sample_variant1)") + return list(dataset) + except Exception: + try: + dataset = load_dataset("lmms-lab/DynaMath", split="test") + print(f"Loaded {len(dataset)} examples from DynaMath (test)") + return list(dataset) + except Exception: + raise ValueError(f"Could not load DynaMath dataset: {e}") + + def encode_image(self, pil_image: Image.Image) -> str: + buffer = io.BytesIO() + pil_image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + def get_image_base64(self, item: dict) -> Optional[str]: + for key in ["image", "decoded_image"]: + if key in item and item[key] is not None: + if isinstance(item[key], Image.Image): + return self.encode_image(item[key]) + return None + + def build_messages(self, item: dict) -> List[dict]: + image_base64 = self.get_image_base64(item) + question = item.get("question", "") + answer_type = item.get("answer_type", "free_form") + + use_json_format = getattr(self, "use_json_format", True) + + if use_json_format: + if answer_type == "multiple choice": + inst = "Provide the corresponding choice option in the 'short answer' key, such as 'A', 'B', 'C', or 'D'." + elif answer_type == "float": + inst = "Format the answer as a three-digit floating-point number and provide it in the 'short answer' key." + else: + inst = "Float numbers in the answer should be formatted as three-digit floating-point numbers." + + prompt = f"## Question\n{question}" + self.GUIDE.format(INST=inst) + else: + prompt = question + + content = [] + if image_base64: + content.append( + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{image_base64}"}, + } + ) + content.append({"type": "text", "text": prompt}) + + return [{"role": "user", "content": content}] + + def preprocess_response(self, response: str) -> str: + """Preprocess response to extract JSON.""" + response = str(response) + if 0 <= response.find("{") < response.rfind("}"): + response = response[response.find("{") : response.rfind("}") + 1] + response = response.replace("\\", "").replace("\\n", "\n") + return response + + def transfer_pi(self, value: str) -> float: + """Convert pi symbol to numeric value.""" + if "\u03c0" in value: + parts = value.split("\u03c0") + return float(parts[0]) * np.pi + return float(value) + + def parse_answer(self, answer: str, answer_type: str) -> Tuple[bool, Optional[str]]: + """Parse answer based on type.""" + if answer_type == "float": + if answer.isdigit(): + return True, str(float(answer)) + parts = answer.split(" ") + answer = parts[0] + try: + result = self.transfer_pi(answer) + return True, str(result) + except Exception: + return False, None + + elif answer_type == "multiple choice": + if len(answer) == 1 and answer.upper() in ascii_uppercase[:5]: + return True, answer.upper() + # Check if any letter appears + for ch in ascii_uppercase[:5]: + if ch in answer.upper(): + return True, ch + return False, None + + else: + return True, answer + + def extract_answer( + self, response: str, answer_type: str + ) -> Tuple[bool, Optional[str]]: + """Extract answer from response.""" + processed = self.preprocess_response(response) + + try: + dj = json.loads(processed, strict=False) + short_answer = dj.get("short answer") + if short_answer is not None: + return self.parse_answer(str(short_answer), answer_type) + except Exception: + pass + + if answer_type == "multiple choice": + for ch in ascii_uppercase[:5]: + if response.strip().upper().startswith(ch): + return True, ch + for ch in ascii_uppercase[:5]: + if ch in response.upper()[:20]: + return True, ch + elif answer_type == "float": + numbers = re.findall(r"-?\d+\.?\d*", response) + if numbers: + try: + return True, str(float(numbers[0])) + except ValueError: + pass + + return False, None + + def score_answer( + self, extracted: Optional[str], answer: str, answer_type: str, parsed: bool + ) -> bool: + """Score the extracted answer against ground truth.""" + if not parsed or extracted is None: + # Check if answer appears in raw response for MC + return False + + if answer_type == "float": + try: + pred_val = float(extracted) + ans_val = float(answer) + return abs(pred_val - ans_val) <= 0.001 + except (ValueError, TypeError): + return False + + elif answer_type == "multiple choice": + return extracted.upper() == str(answer).upper() + + else: + # Free form: substring match + return ( + extracted.lower() in answer.lower() + or answer.lower() in extracted.lower() + ) + + async def run_item( + self, server: ServerManager, data_item: dict + ) -> Tuple[dict, dict]: + try: + messages = self.build_messages(data_item) + + completion = await self.chat_completion(server, messages) + + if not completion.choices: + return {"accuracy": 0.0}, {"error": "Empty response"} + + message = completion.choices[0].message + response = message.content or "" + + if not response: + return {"accuracy": 0.0}, {"error": "Empty response"} + + answer = data_item.get("ground_truth", data_item.get("answer", "")) + answer_type = data_item.get("answer_type", "free_form") + + parsed, extracted = self.extract_answer(response, answer_type) + correct = self.score_answer(extracted, answer, answer_type, parsed) + + sample = { + "id": data_item.get("index", data_item.get("id", "")), + "question": data_item.get("question", "")[:200], + "subject": data_item.get("subject", ""), + "knowledge_level": data_item.get("knowledge_level", ""), + "answer_type": answer_type, + "answer": answer, + "prediction": extracted, + "parsed": parsed, + "raw_response": response[:500], + "correct": correct, + } + + return {"accuracy": 1.0 if correct else 0.0}, sample + + except Exception as e: + return {"accuracy": 0.0}, {"error": str(e)} + + +if __name__ == "__main__": + asyncio.run( + eval_runner( + DynaMath( + split="test", use_json_format=True, temperature=0.0, max_tokens=1024 + ) + ) + ) diff --git a/environments/eval_environments/vision_evals/hallusionbench_environment.py b/environments/eval_environments/vision_evals/hallusionbench_environment.py new file mode 100644 index 000000000..47a255ae2 --- /dev/null +++ b/environments/eval_environments/vision_evals/hallusionbench_environment.py @@ -0,0 +1,147 @@ +"""HallusionBench evaluation environment.""" + +import asyncio +import base64 +import io +import re +from typing import List, Optional, Tuple + +from datasets import load_dataset +from environments.eval_environments.eval import EvalBase, eval_runner +from PIL import Image + +from atroposlib.envs.server_handling.server_manager import ServerManager + + +class HallusionBench(EvalBase): + """HallusionBench evaluation - visual hallucination benchmark.""" + + def setup_data(self) -> list: + # HallusionBench has 'image' and 'non_image' splits + split = getattr(self, "split", "image") + + try: + dataset = load_dataset("lmms-lab/HallusionBench", split=split) + print(f"Loaded {len(dataset)} examples from HallusionBench ({split})") + return list(dataset) + except Exception as e: + print(f"Warning: Could not load HallusionBench: {e}") + try: + # Try combining both splits + all_data = [] + for s in ["image", "non_image"]: + try: + ds = load_dataset("lmms-lab/HallusionBench", split=s) + all_data.extend(list(ds)) + except Exception: + pass + if all_data: + print( + f"Loaded {len(all_data)} examples from HallusionBench (combined)" + ) + return all_data + raise ValueError(f"Could not load HallusionBench dataset: {e}") + except Exception: + raise ValueError(f"Could not load HallusionBench dataset: {e}") + + def encode_image(self, pil_image: Image.Image) -> str: + buffer = io.BytesIO() + pil_image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + def get_image_base64(self, item: dict) -> Optional[str]: + for key in ["image", "decoded_image"]: + if key in item and item[key] is not None: + if isinstance(item[key], Image.Image): + return self.encode_image(item[key]) + return None + + def build_messages(self, item: dict) -> List[dict]: + image_base64 = self.get_image_base64(item) + question = item.get("question", "") + + prompt = f"{question}\n\nPlease answer yes or no." + + content = [] + if image_base64: + content.append( + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{image_base64}"}, + } + ) + content.append({"type": "text", "text": prompt}) + + return [{"role": "user", "content": content}] + + def extract_yorn(self, response: str) -> str: + """Extract Yes/No from response.""" + response_lower = response.lower().strip() + + if response_lower.startswith("yes"): + return "Yes" + if response_lower.startswith("no"): + return "No" + + yes_patterns = [r"\byes\b", r"\btrue\b", r"\bcorrect\b"] + no_patterns = [r"\bno\b", r"\bfalse\b", r"\bincorrect\b"] + + for pattern in yes_patterns: + if re.search(pattern, response_lower): + return "Yes" + + for pattern in no_patterns: + if re.search(pattern, response_lower): + return "No" + + return "Unknown" + + async def run_item( + self, server: ServerManager, data_item: dict + ) -> Tuple[dict, dict]: + try: + messages = self.build_messages(data_item) + completion = await self.chat_completion(server, messages) + + if not completion.choices: + return {"accuracy": 0.0}, {"error": "Empty response"} + + message = completion.choices[0].message + response = message.content or "" + + if not response: + return {"accuracy": 0.0}, {"error": "Empty response"} + + answer = data_item.get("answer", data_item.get("gt_answer", "")) + extracted = self.extract_yorn(response) + + answer_norm = str(answer).strip().lower() + if answer_norm in ["yes", "true", "1"]: + answer_norm = "Yes" + elif answer_norm in ["no", "false", "0"]: + answer_norm = "No" + else: + answer_norm = str(answer).strip() + + correct = extracted == answer_norm + + sample = { + "id": data_item.get("index", data_item.get("id", "")), + "question": data_item.get("question", "")[:200], + "category": data_item.get("category", data_item.get("subcategory", "")), + "answer": answer_norm, + "prediction": extracted, + "raw_response": response[:200], + "correct": correct, + } + + return {"accuracy": 1.0 if correct else 0.0}, sample + + except Exception as e: + return {"accuracy": 0.0}, {"error": str(e)} + + +if __name__ == "__main__": + asyncio.run( + eval_runner(HallusionBench(split="test", temperature=0.0, max_tokens=64)) + ) diff --git a/environments/eval_environments/vision_evals/infovqa_environment.py b/environments/eval_environments/vision_evals/infovqa_environment.py new file mode 100644 index 000000000..9e01100c9 --- /dev/null +++ b/environments/eval_environments/vision_evals/infovqa_environment.py @@ -0,0 +1,176 @@ +import asyncio +import base64 +import io +import re +from typing import List, Tuple + +from datasets import load_dataset +from environments.eval_environments.eval import EvalBase, eval_runner +from PIL import Image + +from atroposlib.envs.server_handling.server_manager import ServerManager + + +class InfoVQA(EvalBase): + def setup_data(self) -> list: + split = getattr(self, "split", "validation") + dataset = load_dataset("lmms-lab/DocVQA", "InfographicVQA", split=split) + print(f"Loaded {len(dataset)} examples from InfoVQA ({split})") + return list(dataset) + + def encode_image(self, pil_image: Image.Image) -> str: + buffer = io.BytesIO() + pil_image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + def get_image_base64(self, item: dict) -> str: + if "image" in item and item["image"] is not None: + if isinstance(item["image"], Image.Image): + return self.encode_image(item["image"]) + raise ValueError(f"Could not find image for item {item.get('id', 'unknown')}") + + def build_messages(self, item: dict) -> List[dict]: + image_base64 = self.get_image_base64(item) + question = item.get("question", "") + + prompt = f"""Look at the infographic and answer the question. + +Question: {question} + +Provide only the answer, as concisely as possible.""" + + return [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{image_base64}"}, + }, + {"type": "text", "text": prompt}, + ], + } + ] + + def extract_answer(self, response: str) -> str: + response = response.strip() + + patterns = [ + r"answer[:\s]+(.+?)(?:\.|$)", + r"\"([^\"]+)\"", + ] + + for pattern in patterns: + match = re.search(pattern, response, re.IGNORECASE) + if match: + return match.group(1).strip() + + lines = response.split("\n") + if lines: + return lines[-1].strip() + + return response + + def normalize_text(self, text: str) -> str: + text = text.lower().strip() + text = re.sub(r"[^\w\s]", "", text) + text = " ".join(text.split()) + return text + + def anls_score( + self, prediction: str, answers: List[str], threshold: float = 0.5 + ) -> float: + """ + Calculate Average Normalized Levenshtein Similarity (ANLS). + This is the standard metric for InfoVQA. + """ + pred_norm = self.normalize_text(prediction) + + if not pred_norm: + return 0.0 + + max_score = 0.0 + for answer in answers: + ans_norm = self.normalize_text(answer) + if not ans_norm: + continue + + if pred_norm == ans_norm: + max_score = 1.0 + break + + distance = self._levenshtein_distance(pred_norm, ans_norm) + max_len = max(len(pred_norm), len(ans_norm)) + nls = 1 - distance / max_len if max_len > 0 else 0 + + if nls >= threshold: + max_score = max(max_score, nls) + + return max_score + + def _levenshtein_distance(self, s1: str, s2: str) -> int: + if len(s1) < len(s2): + return self._levenshtein_distance(s2, s1) + + if len(s2) == 0: + return len(s1) + + previous_row = range(len(s2) + 1) + for i, c1 in enumerate(s1): + current_row = [i + 1] + for j, c2 in enumerate(s2): + insertions = previous_row[j + 1] + 1 + deletions = current_row[j] + 1 + substitutions = previous_row[j] + (c1 != c2) + current_row.append(min(insertions, deletions, substitutions)) + previous_row = current_row + + return previous_row[-1] + + async def run_item( + self, server: ServerManager, data_item: dict + ) -> Tuple[dict, dict]: + try: + messages = self.build_messages(data_item) + completion = await self.chat_completion(server, messages) + + if not completion.choices: + return {"accuracy": 0.0, "anls": 0.0}, {"error": "Empty response"} + + message = completion.choices[0].message + response = message.content or "" + if hasattr(message, "reasoning") and message.reasoning and not response: + response = message.reasoning + if not response and hasattr(message, "model_extra"): + reasoning = message.model_extra.get("reasoning", "") + if reasoning: + response = reasoning + + if not response: + return {"accuracy": 0.0, "anls": 0.0}, {"error": "Empty response"} + + extracted = self.extract_answer(response) + answers = data_item.get("answer", []) + if isinstance(answers, str): + answers = [answers] + + anls = self.anls_score(extracted, answers) + correct = anls >= 0.5 + + sample = { + "id": data_item.get("id", ""), + "question": data_item.get("question", ""), + "answers": answers, + "prediction": extracted, + "anls": anls, + "correct": correct, + } + + return {"accuracy": 1.0 if correct else 0.0, "anls": anls}, sample + + except Exception as e: + return {"accuracy": 0.0, "anls": 0.0}, {"error": str(e)} + + +if __name__ == "__main__": + asyncio.run(eval_runner(InfoVQA(split="test", temperature=0.0, max_tokens=256))) diff --git a/environments/eval_environments/vision_evals/logicvista_environment.py b/environments/eval_environments/vision_evals/logicvista_environment.py new file mode 100644 index 000000000..8d240e4bb --- /dev/null +++ b/environments/eval_environments/vision_evals/logicvista_environment.py @@ -0,0 +1,303 @@ +"""LogicVista evaluation environment.""" + +import asyncio +import base64 +import io +import os +import re +from typing import Dict, List, Optional, Tuple + +import openai +from datasets import load_dataset +from environments.eval_environments.eval import EvalBase, eval_runner +from PIL import Image + +from atroposlib.envs.server_handling.server_manager import ServerManager + +EXTRACTION_PROMPT_TEMPLATE = """You are a information extractor that extracts multiple choice letter answer choices \ +from a paragraph that contains the answer choice and sometimes explaination of why that \ +choice is correct to the given question. +What letter did the following answer choose? If the answer did not select a letter answer choice, \ +first try to infer the answer based off the given choices. +If it does not correspond to an answer choice OR there is no selected answer, respond with Z. +Make sure you answer with ONLY the letters chosen. +Example 1: +Question: +What is the main object in image? +Options: A. teddy bear B. rabbit C. cat D. dog + +Answer: +a cute teddy bear + +Your output: A +Example 2: +Question: +What is the main object in image? +Options: A. teddy bear B. rabbit C. cat D. dog + +Answer: +Spider + +Your output: Z +Example 3: +Question: +Which figure is a rotation of the object? + +Answer: +The figure on the right, labeled "D," is a rotation of the object shown in the top left corner. + +Your output: D +Example 4: +Question: +Which of the boxes comes next in the sequence? Select from A-E + +Answer: +The sequence of the boxes is A, B, C, D, E. + +Your output: ABCDE +Example 5: +Question: +{question} + +Answer: +{prediction} + +Your output: """ + + +class LogicVista(EvalBase): + SKILL_CATEGORIES = [ + "inductive", + "deductive", + "numerical", + "spatial", + "mechanical", + ] + + CAPABILITY_CATEGORIES = [ + "diagram", + "ocr", + "patterns", + "graphs", + "tables", + "3d shapes", + "puzzles", + "sequences", + "physics", + ] + + def setup_data(self) -> list: + split = getattr(self, "split", "test") + dataset = load_dataset("lscpku/LogicVista", split=split) + print(f"Loaded {len(dataset)} examples from LogicVista ({split})") + return list(dataset) + + def encode_image(self, pil_image: Image.Image) -> str: + buffer = io.BytesIO() + pil_image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + def get_image_base64(self, item: dict) -> str: + if "image" in item and item["image"] is not None: + if isinstance(item["image"], Image.Image): + return self.encode_image(item["image"]) + raise ValueError( + f"Could not find image for item {item.get('question_id', 'unknown')}" + ) + + def build_messages(self, item: dict) -> List[dict]: + image_base64 = self.get_image_base64(item) + question = item.get("question", "") + + prompt = f"""{question} + +Provide your answer as the letter(s) of the correct choice(s), e.g., A, B, C, D, or multiple letters if applicable.""" + + return [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{image_base64}"}, + }, + {"type": "text", "text": prompt}, + ], + } + ] + + async def _extract_with_gpt(self, question: str, response: str) -> Optional[str]: + judge_model = getattr(self, "judge_model", "gpt-4o-mini") + judge_base_url = getattr(self, "judge_base_url", "https://api.openai.com/v1") + judge_api_key = os.environ.get( + getattr(self, "judge_api_key_env", "OPENAI_API_KEY"), "" + ) + + if not judge_api_key: + return None + + try: + judge_client = openai.AsyncOpenAI( + api_key=judge_api_key, + base_url=judge_base_url, + ) + + prompt = EXTRACTION_PROMPT_TEMPLATE.format( + question=question, prediction=response + ) + + completion = await judge_client.chat.completions.create( + model=judge_model, + messages=[{"role": "user", "content": prompt}], + temperature=0.0, + max_tokens=128, + ) + + result = completion.choices[0].message.content.strip() + + if result and result.isupper() and result.isalpha(): + return result + return None + + except Exception as e: + print(f"GPT extraction error: {e}") + return None + + def extract_answer(self, response: str) -> str: + response = response.strip().upper() + + letters_with_sep = re.findall(r"[A-E](?:\s*[,\s]\s*[A-E])*", response) + if letters_with_sep: + letters = re.findall(r"[A-E]", letters_with_sep[-1]) + return "".join(sorted(set(letters))) + + letters = re.findall( + r"[A-E]", response[-20:] if len(response) > 20 else response + ) + if letters: + return "".join(sorted(set(letters))) + + all_letters = re.findall(r"[A-E]", response) + if all_letters: + return "".join(sorted(set(all_letters[-4:]))) + + return "" + + def score(self, prediction: str, answer: str) -> bool: + if not prediction: + return False + + answer_letters = re.findall(r"[A-Ea-e]", answer) + answer_normalized = "".join(sorted(set(c.lower() for c in answer_letters))) + + pred_letters = [c.lower() for c in prediction if c.isalpha()] + pred_normalized = "".join(sorted(set(pred_letters))) + + return pred_normalized == answer_normalized + + async def run_item( + self, server: ServerManager, data_item: dict + ) -> Tuple[dict, dict]: + try: + messages = self.build_messages(data_item) + + completion = await self.chat_completion(server, messages) + + if not completion.choices: + return {"accuracy": 0.0, "hit": 0}, {"error": "Empty response"} + + message = completion.choices[0].message + response = message.content or "" + if hasattr(message, "reasoning") and message.reasoning and not response: + response = message.reasoning + if not response and hasattr(message, "model_extra"): + reasoning = message.model_extra.get("reasoning", "") + if reasoning: + response = reasoning + + if not response: + return {"accuracy": 0.0, "hit": 0}, {"error": "Empty response"} + + use_gpt_extraction = getattr(self, "use_gpt_extraction", True) + extracted = None + extraction_method = "regex" + + if use_gpt_extraction: + question = data_item.get("question", "") + gpt_result = await self._extract_with_gpt(question, response) + if gpt_result and gpt_result != "Z": + extracted = gpt_result + extraction_method = "gpt" + + if not extracted: + extracted = self.extract_answer(response) + extraction_method = "regex" + + answer = data_item.get("answer", "") + correct = self.score(extracted, answer) + + sample = { + "question_id": data_item.get("question_id", data_item.get("index", "")), + "question": data_item.get("question", ""), + "answer": answer, + "prediction": extracted, + "raw_response": response[:500], + "hit": 1 if correct else 0, + "correct": correct, + "skill": data_item.get("skill", ""), + "extraction_method": extraction_method, + } + + return { + "accuracy": 1.0 if correct else 0.0, + "hit": 1 if correct else 0, + }, sample + + except Exception as e: + return {"accuracy": 0.0, "hit": 0}, {"error": str(e)} + + +def compute_skill_metrics(samples: List[dict]) -> Dict: + import pandas as pd + + df = pd.DataFrame(samples) + + if "hit" not in df.columns or "skill" not in df.columns: + return {"overall_accuracy": df.get("hit", pd.Series([0])).mean()} + + metrics = {} + + # Overall accuracy + metrics["Overall"] = { + "total": len(df), + "correct": int(df["hit"].sum()), + "accuracy": float(df["hit"].mean() * 100), + } + + # By skill category + skill_keywords = ["inductive", "deductive", "numerical", "spatial", "mechanical"] + + for skill in skill_keywords: + skill_df = df[df["skill"].str.contains(skill, case=False, na=False)] + if len(skill_df) > 0: + metrics[skill] = { + "total": len(skill_df), + "correct": int(skill_df["hit"].sum()), + "accuracy": float(skill_df["hit"].mean() * 100), + } + + return metrics + + +if __name__ == "__main__": + asyncio.run( + eval_runner( + LogicVista( + split="test", + use_gpt_extraction=True, + judge_model="gpt-4o-mini", + temperature=0.0, + max_tokens=512, + ) + ) + ) diff --git a/environments/eval_environments/vision_evals/mathverse_environment.py b/environments/eval_environments/vision_evals/mathverse_environment.py new file mode 100644 index 000000000..a09213c3e --- /dev/null +++ b/environments/eval_environments/vision_evals/mathverse_environment.py @@ -0,0 +1,336 @@ +"""MathVerse evaluation environment.""" + +import asyncio +import base64 +import io +import os +import re +from typing import List, Optional, Tuple + +import openai +from datasets import load_dataset +from environments.eval_environments.eval import EvalBase, eval_runner +from PIL import Image + +from atroposlib.envs.server_handling.server_manager import ServerManager + +EXTRACT_ICL_EXAMPLES = [ + "1.\nModel response: 'The perimeter of the sector is approximately (-2, 1)'\n" + "Extracted Answer: (-2, 1)\n", + "2.\nModel response: 'The correct option is D. They give the solutions to $f(t)=g(t)$.'\n" + "Extracted Answer: D\n", + "3.\nModel response: 'The range is (-4, 1]. Domain: (-3, 3], Range: (-4, 1]'\n" + "Extracted Answer: Domain: (-3, 3], Range: (-4, 1]\n", + "4.\nModel response: 'I cannot provide the answer because there is not enough information.'\n" + "Extracted Answer: null\n", + "5.\nModel response: 'The distance d between Ned and Bart is approximately 22.3 meters.'\n" + "Extracted answer: 22.3\n", + "6.\nModel response: 'The equation for f is f(x) = -x^2 - 2x + 1'\n" + "Extracted answer: f(x) = -x^2 - 2x + 1\n", +] + +SCORE_ICL_EXAMPLES = [ + """[Question]: Write the set of numbers represented on the number line in interval notation. +[Standard Answer]: (-2,1] +[Model_answer] : Extracted Answer: \\((-2, 1)\\) +Judgement: 0 +""", + """[Question]: As shown in the figure, circle O has a radius 1.0, if angle BAC = 60.0, then the length of BC is () +Choices: +A:2 +B:2√{3} +C:√{3} +D:2√{2} +[Standard Answer]: C +[Model_answer] : B:2√{3} +Judgement: 0 +""", + """[Question]: Find the domain and range of the function f using interval notation. +[Standard Answer]: domain: [-4, 0) and range: (-3, 1] +[Model_answer] : Range: \\((-4, 1]\\) +Judgement: 0 +""", + """[Question]: As shown in the figure, circle O has a radius 1.0, if angle BAC = 60.0, then the length of BC is () +Choices: +A:2 +B:2√{3} +C:√{3} +D:2√{2} +[Standard Answer]: C +[Model_answer] : null +Judgement: 0 +""", +] + + +class MathVerse(EvalBase): + PROBLEM_VERSIONS = [ + "Text Dominant", + "Text Lite", + "Vision Intensive", + "Vision Dominant", + "Vision Only", + ] + + def setup_data(self) -> list: + config = getattr(self, "config", "testmini") + dataset = load_dataset("AI4Math/MathVerse", config, split="testmini") + print(f"Loaded {len(dataset)} examples from MathVerse ({config})") + return list(dataset) + + def encode_image(self, pil_image: Image.Image) -> str: + buffer = io.BytesIO() + pil_image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + def get_image_base64(self, item: dict) -> str: + if "image" in item and item["image"] is not None: + if isinstance(item["image"], Image.Image): + return self.encode_image(item["image"]) + raise ValueError( + f"Could not find image for item {item.get('sample_index', 'unknown')}" + ) + + def build_messages(self, item: dict) -> List[dict]: + image_base64 = self.get_image_base64(item) + + use_cot = getattr(self, "use_cot", False) + if use_cot and "query_cot" in item: + question = item["query_cot"] + elif "question_for_eval" in item: + question = item["question_for_eval"] + else: + question = item.get("question", "") + + prompt = f"""{question} + +Please solve the problem step by step and provide your final answer.""" + + return [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{image_base64}"}, + }, + {"type": "text", "text": prompt}, + ], + } + ] + + async def _extract_with_gpt(self, prediction: str) -> Optional[str]: + judge_model = getattr(self, "judge_model", "gpt-4o-mini") + judge_base_url = getattr(self, "judge_base_url", "https://api.openai.com/v1") + judge_api_key = os.environ.get( + getattr(self, "judge_api_key_env", "OPENAI_API_KEY"), "" + ) + + if not judge_api_key: + return None + + try: + judge_client = openai.AsyncOpenAI( + api_key=judge_api_key, + base_url=judge_base_url, + ) + + task_description = ( + "I am providing you a response from a model to a math problem, " + "termed 'Model Response'. You should extract the answer from the " + "response as 'Extracted Answer'. Directly output the extracted " + "answer with no explanation.\n\n" + ) + prompt = task_description + for example in EXTRACT_ICL_EXAMPLES: + prompt += example + "\n\n" + prompt += f"7.\nModel response: '{prediction}'\nExtracted Answer: " + + completion = await judge_client.chat.completions.create( + model=judge_model, + messages=[{"role": "user", "content": prompt}], + temperature=0.0, + max_tokens=256, + ) + + result = completion.choices[0].message.content.strip() + return result if result else None + + except Exception as e: + print(f"GPT extraction error: {e}") + return None + + async def _score_with_gpt( + self, question: str, answer: str, extracted: str + ) -> Optional[bool]: + judge_model = getattr(self, "judge_model", "gpt-4o-mini") + judge_base_url = getattr(self, "judge_base_url", "https://api.openai.com/v1") + judge_api_key = os.environ.get( + getattr(self, "judge_api_key_env", "OPENAI_API_KEY"), "" + ) + + if not judge_api_key: + return None + + if str(extracted).strip() == str(answer).strip(): + return True + + try: + judge_client = openai.AsyncOpenAI( + api_key=judge_api_key, + base_url=judge_base_url, + ) + + task_description = ( + "Below are two answers to a math question. Question is [Question], " + "[Standard Answer] is the standard answer to the question, and " + "[Model_answer] is the answer extracted from a model's output to " + "this question. Determine whether these two answers are consistent.\n" + "Please note that only when the [Model_answer] completely matches " + "the [Standard Answer] means they are consistent. For non-MCQ " + "questions, if the meaning is expressed in the same way, it is also " + "considered consistent, for example, 0.5m and 50cm.\n" + "If they are consistent, Judgement is 1; if different, Judgement is 0.\n\n" + ) + prompt = task_description + for example in SCORE_ICL_EXAMPLES: + prompt += example + "\n\n" + prompt += f"""[Question]: {question} +[Standard Answer]: {answer} +[Model_answer] : {extracted} +Judgement:""" + + completion = await judge_client.chat.completions.create( + model=judge_model, + messages=[{"role": "user", "content": prompt}], + temperature=0.0, + max_tokens=16, + ) + + result = completion.choices[0].message.content.strip() + if result in ["0", "1"]: + return int(result) == 1 + return None + + except Exception as e: + print(f"GPT scoring error: {e}") + return None + + def extract_answer_fallback(self, response: str) -> str: + response = response.strip().upper() + + for char in reversed(response): + if char in "ABCDE": + return char + + numbers = re.findall(r"-?\d+\.?\d*", response) + if numbers: + return numbers[-1] + + return response[:100] + + def score_fallback(self, prediction: str, answer: str) -> bool: + pred = prediction.strip().upper() + ans = answer.strip().upper() + + if pred == ans: + return True + + try: + pred_num = float(pred) + ans_num = float(ans) + return abs(pred_num - ans_num) < 0.01 + except ValueError: + pass + + return False + + async def run_item( + self, server: ServerManager, data_item: dict + ) -> Tuple[dict, dict]: + try: + messages = self.build_messages(data_item) + + completion = await self.chat_completion(server, messages) + + if not completion.choices: + return {"accuracy": 0.0}, {"error": "Empty response"} + + message = completion.choices[0].message + response = message.content or "" + if hasattr(message, "reasoning") and message.reasoning and not response: + response = message.reasoning + if not response and hasattr(message, "model_extra"): + reasoning = message.model_extra.get("reasoning", "") + if reasoning: + response = reasoning + + if not response: + return {"accuracy": 0.0}, {"error": "Empty response"} + + use_gpt_evaluation = getattr(self, "use_gpt_evaluation", True) + answer = data_item.get("answer", "") + question = data_item.get("question_for_eval", data_item.get("question", "")) + + if use_gpt_evaluation: + extracted = await self._extract_with_gpt(response) + if not extracted: + extracted = self.extract_answer_fallback(response) + extraction_method = "fallback" + else: + extraction_method = "gpt" + else: + extracted = self.extract_answer_fallback(response) + extraction_method = "fallback" + + if use_gpt_evaluation: + score_result = await self._score_with_gpt(question, answer, extracted) + if score_result is not None: + correct = score_result + scoring_method = "gpt" + else: + correct = self.score_fallback(extracted, answer) + scoring_method = "fallback" + else: + correct = self.score_fallback(extracted, answer) + scoring_method = "fallback" + + metadata = data_item.get("metadata", {}) + sample = { + "sample_index": data_item.get("sample_index", ""), + "problem_index": data_item.get("problem_index", ""), + "problem_version": data_item.get("problem_version", ""), + "question": question[:200], + "answer": answer, + "prediction": extracted, + "raw_response": response[:500], + "correct": correct, + "subject": ( + metadata.get("subject", "") if isinstance(metadata, dict) else "" + ), + "subfield": ( + metadata.get("subfield", "") if isinstance(metadata, dict) else "" + ), + "extraction_method": extraction_method, + "scoring_method": scoring_method, + } + + return {"accuracy": 1.0 if correct else 0.0}, sample + + except Exception as e: + return {"accuracy": 0.0}, {"error": str(e)} + + +if __name__ == "__main__": + asyncio.run( + eval_runner( + MathVerse( + split="testmini", + use_cot=False, + use_gpt_evaluation=True, + judge_model="gpt-4o-mini", + temperature=0.0, + max_tokens=2048, + ) + ) + ) diff --git a/environments/eval_environments/vision_evals/mathvision_environment.py b/environments/eval_environments/vision_evals/mathvision_environment.py new file mode 100644 index 000000000..bd98f9d4f --- /dev/null +++ b/environments/eval_environments/vision_evals/mathvision_environment.py @@ -0,0 +1,342 @@ +"""MathVision evaluation environment.""" + +import asyncio +import base64 +import io +import os +import re +from typing import Dict, List, Optional, Tuple + +import openai +from datasets import load_dataset +from environments.eval_environments.eval import EvalBase, eval_runner +from PIL import Image + +from atroposlib.envs.server_handling.server_manager import ServerManager + +ICL_EXAMPLES = [ + """Hint: Please answer the question and provide the final answer at the end. +Question: Which number is missing? +Model response: The number missing in the sequence is 14. +Extracted answer: 14 +""", + "Hint: Please answer the question and provide the final answer at the end.\n" + "Question: What is the fraction of females facing the camera?\n" + "Model response: The fraction of females facing the camera is 0.6.\n" + "Extracted answer: 0.6\n", + """Hint: Please answer the question and provide the final answer at the end. +Question: How much money does Luca need to buy a sour apple candy and a butter-scotch candy? (Unit: $) +Model response: Luca needs $1.45 to buy a sour apple candy and a butterscotch candy. +Extracted answer: 1.45 +""", + """Hint: Please answer the question and provide the final answer at the end. +Question: Between which two years does the line graph saw its maximum peak? +Model response: The line graph saw its maximum peak between 2007 and 2008. +Extracted answer: [2007, 2008] +""", + """Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end. +Question: What fraction of the shape is blue? +Choices: (A) 3/11 (B) 8/11 (C) 6/11 (D) 3/5 +Model response: The correct answer is (B) 8/11. +Extracted answer: B +""", +] + + +def can_infer_option(answer: str, choices: Dict[str, str]) -> Optional[str]: + if "Failed to obtain answer via API" in answer: + return None + + answer_mod = answer + for c in ".()[],:;!*#{}": + answer_mod = answer_mod.replace(c, " ") + + splits = [x.strip() for x in answer_mod.split()] + count = sum(1 for ch in choices if ch in splits) + + if count == 1: + for ch in choices: + if "A" in splits and len(splits) > 3: + continue + if ch in splits and splits.index(ch) > (len(splits) - 5): + return ch + + return None + + +def can_infer_text(answer: str, choices: Dict[str, str]) -> Optional[str]: + answer_lower = answer.lower() + + if len(answer_lower) > 2 * sum(len(str(v)) for v in choices.values()): + return None + + cands = [] + for k, v in choices.items(): + if str(v).lower() in answer_lower: + cands.append(k) + + if len(cands) == 1: + return cands[0] + + return None + + +def can_infer(answer: str, choices: Dict[str, str]) -> Optional[str]: + answer = str(answer) + result = can_infer_option(answer, choices) + if result: + return result + return can_infer_text(answer, choices) + + +def is_equal(asw: str, gt_asw: str) -> bool: + asw = str(asw).lower().strip() + gt_asw = str(gt_asw).lower().strip() + + if gt_asw == asw: + return True + + try: + a = eval(gt_asw) + b = eval(asw) + if abs(float(a) - float(b)) < 1e-6: + return True + except Exception: + pass + + try: + from latex2sympy2 import latex2sympy + + a = latex2sympy(gt_asw) + b = latex2sympy(asw) + if abs(eval(str(a)) - eval(str(b))) < 1e-6: + return True + if abs(float(a) - float(b)) < 1e-6: + return True + except Exception: + pass + + return False + + +class MathVision(EvalBase): + def setup_data(self) -> list: + split = getattr(self, "split", "testmini") + try: + dataset = load_dataset("MathLLMs/MathVision", split=split) + except Exception: + dataset = load_dataset("MathLLMs/MathVision", "default", split=split) + print(f"Loaded {len(dataset)} examples from MathVision ({split})") + return list(dataset) + + def encode_image(self, pil_image: Image.Image) -> str: + buffer = io.BytesIO() + pil_image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + def get_image_base64(self, item: dict) -> str: + for key in ["decoded_image", "image"]: + if key in item and item[key] is not None: + if isinstance(item[key], Image.Image): + return self.encode_image(item[key]) + raise ValueError(f"Could not find image for item {item.get('id', 'unknown')}") + + def build_messages(self, item: dict) -> List[dict]: + image_base64 = self.get_image_base64(item) + question = item.get("question", "") + choices = item.get("choices", []) + + if choices: + try: + if isinstance(choices, str): + choices = eval(choices) + choices_text = "\n".join( + [f"({chr(65+i)}) {c}" for i, c in enumerate(choices)] + ) + hint = "Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end." + prompt = f"Hint: {hint}\nQuestion: {question}\nChoices:\n{choices_text}" + except Exception: + hint = "Please answer the question and provide the final answer at the end." + prompt = f"Hint: {hint}\nQuestion: {question}" + else: + hint = "Please answer the question and provide the final answer at the end." + prompt = f"Hint: {hint}\nQuestion: {question}" + + return [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{image_base64}"}, + }, + {"type": "text", "text": prompt}, + ], + } + ] + + def _prefetch_answer(self, response: str, item: dict) -> Tuple[Optional[str], bool]: + choices = item.get("choices", []) + + if choices: + try: + if isinstance(choices, str): + choices = eval(choices) + if len(choices) > 0: + choices_dict = {chr(65 + i): val for i, val in enumerate(choices)} + result = can_infer(response, choices_dict) + if result: + return result, True + except Exception: + pass + + return None, False + + async def _extract_with_gpt(self, question: str, response: str) -> Optional[str]: + judge_model = getattr(self, "judge_model", "gpt-4o-mini") + judge_base_url = getattr(self, "judge_base_url", "https://api.openai.com/v1") + judge_api_key = os.environ.get( + getattr(self, "judge_api_key_env", "OPENAI_API_KEY"), "" + ) + + if not judge_api_key: + return None + + try: + judge_client = openai.AsyncOpenAI( + api_key=judge_api_key, + base_url=judge_base_url, + ) + + task_description = """Please read the following example. +Then extract the answer from the model response and type it at the end of the prompt. + +""" + prompt = task_description + for example in ICL_EXAMPLES: + prompt += example + "\n" + prompt += question + "\n" + prompt += f"Model response: {response}\n" + prompt += "Extracted answer:" + + completion = await judge_client.chat.completions.create( + model=judge_model, + messages=[{"role": "user", "content": prompt}], + temperature=0.0, + max_tokens=128, + ) + + result = completion.choices[0].message.content.strip() + return result if result else None + + except Exception as e: + print(f"GPT extraction error: {e}") + return None + + def extract_answer_fallback(self, response: str) -> str: + response = response.strip() + + for char in reversed(response.upper()): + if char in "ABCDEFGH": + return char + + numbers = re.findall(r"-?\d+\.?\d*", response) + if numbers: + return numbers[-1] + + return response[:100] + + def score(self, prediction: str, answer: str, item: dict) -> bool: + choices = item.get("choices", []) + + if choices: + try: + if isinstance(choices, str): + choices = eval(choices) + if len(choices) > 0: + choices_dict = {chr(65 + i): val for i, val in enumerate(choices)} + result = can_infer(prediction, choices_dict) + if result: + return result.upper() == answer.upper() + except Exception: + pass + + return is_equal(prediction, answer) + + async def run_item( + self, server: ServerManager, data_item: dict + ) -> Tuple[dict, dict]: + try: + messages = self.build_messages(data_item) + + completion = await self.chat_completion(server, messages) + + if not completion.choices: + return {"accuracy": 0.0}, {"error": "Empty response"} + + message = completion.choices[0].message + response = message.content or "" + if hasattr(message, "reasoning") and message.reasoning and not response: + response = message.reasoning + if not response and hasattr(message, "model_extra"): + reasoning = message.model_extra.get("reasoning", "") + if reasoning: + response = reasoning + + if not response: + return {"accuracy": 0.0}, {"error": "Empty response"} + + use_gpt_extraction = getattr(self, "use_gpt_extraction", True) + answer = data_item.get("answer", "") + + prefetch_result, prefetch_success = self._prefetch_answer( + response, data_item + ) + + if prefetch_success and prefetch_result: + extracted = prefetch_result + extraction_method = "prefetch" + elif use_gpt_extraction: + question = data_item.get("question", "") + gpt_result = await self._extract_with_gpt(question, response) + if gpt_result: + extracted = gpt_result + extraction_method = "gpt" + else: + extracted = self.extract_answer_fallback(response) + extraction_method = "fallback" + else: + extracted = self.extract_answer_fallback(response) + extraction_method = "fallback" + + correct = self.score(extracted, answer, data_item) + + sample = { + "id": data_item.get("id", data_item.get("index", "")), + "question": data_item.get("question", "")[:200], + "answer": answer, + "prediction": extracted, + "raw_response": response[:500], + "correct": correct, + "category": data_item.get("category", ""), + "extraction_method": extraction_method, + } + + return {"accuracy": 1.0 if correct else 0.0}, sample + + except Exception as e: + return {"accuracy": 0.0}, {"error": str(e)} + + +if __name__ == "__main__": + asyncio.run( + eval_runner( + MathVision( + split="testmini", + use_gpt_extraction=True, + judge_model="gpt-4o-mini", + temperature=0.0, + max_tokens=2048, + ) + ) + ) diff --git a/environments/eval_environments/vision_evals/mathvista_environment.py b/environments/eval_environments/vision_evals/mathvista_environment.py new file mode 100644 index 000000000..651f42623 --- /dev/null +++ b/environments/eval_environments/vision_evals/mathvista_environment.py @@ -0,0 +1,408 @@ +"""MathVista evaluation environment.""" + +import asyncio +import base64 +import io +import os +import re +from typing import Dict, List, Optional, Tuple + +import openai +from datasets import load_dataset +from environments.eval_environments.eval import EvalBase, eval_runner +from PIL import Image + +from atroposlib.envs.server_handling.server_manager import ServerManager + +ICL_EXAMPLES = [ + """ +Hint: Please answer the question requiring an integer answer and provide the final value, +e.g., 1, 2, 3, at the end. +Question: Which number is missing? +Model response: The number missing in the sequence is 14. +Extracted answer: 14 +""", + """ +Hint: Please answer the question requiring a floating-point number with one decimal place and provide the final value, +e.g., 1.2, 1.3, 1.4, at the end. +Question: What is the fraction of females facing the camera? +Model response: The fraction of females facing the camera is 0.6, +which means that six out of ten females in the group are facing the camera. +Extracted answer: 0.6 +""", + """ +Hint: Please answer the question requiring a floating-point number with two decimal places and provide the final value, +e.g., 1.23, 1.34, 1.45, at the end. +Question: How much money does Luca need to buy a sour apple candy and a butter-scotch candy? (Unit: $) +Model response: Luca needs $1.45 to buy a sour apple candy and a butterscotch candy. +Extracted answer: 1.45 +""", + """ +Hint: Please answer the question requiring a Python list as an answer and provide the final list, +e.g., [1, 2, 3], [1.2, 1.3, 1.4], at the end. +Question: Between which two years does the line graph saw its maximum peak? +Model response: The line graph saw its maximum peak between 2007 and 2008. +Extracted answer: [2007, 2008] +""", + """ +Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end. +Question: What fraction of the shape is blue? +Choices: (A) 3/11 (B) 8/11 (C) 6/11 (D) 3/5 +Model response: The correct answer is (B) 8/11. +Extracted answer: B +""", +] + + +def build_extraction_prompt(question: str, prediction: str) -> str: + task_description = """Please read the following example. +Then extract the answer from the model response and type it at the end of the prompt. +""" + prompt = task_description + for example in ICL_EXAMPLES: + prompt += example + "\n" + prompt += question + "\n" + prompt += "Model response: " + prediction + "\n" + prompt += "Extracted answer:" + return prompt + + +def can_infer_option(answer: str, choices: Dict[str, str]) -> Optional[str]: + if "Failed to obtain answer via API" in answer: + return None + + answer_mod = answer + for c in ".()[],:;!*#{}": + answer_mod = answer_mod.replace(c, " ") + + splits = [x.strip() for x in answer_mod.split()] + count = sum(1 for ch in choices if ch in splits) + + if count == 1: + for ch in choices: + if "A" in splits and len(splits) > 3: + continue + if ch in splits and splits.index(ch) > (len(splits) - 5): + return ch + + return None + + +def can_infer_text(answer: str, choices: Dict[str, str]) -> Optional[str]: + answer_lower = answer.lower() + + if len(answer_lower) > 2 * sum(len(str(v)) for v in choices.values()): + return None + + cands = [] + for k, v in choices.items(): + if str(v).lower() in answer_lower: + cands.append(k) + + if len(cands) == 1: + return cands[0] + + return None + + +def can_infer(answer: str, choices: Dict[str, str]) -> Optional[str]: + answer = str(answer) + result = can_infer_option(answer, choices) + if result: + return result + return can_infer_text(answer, choices) + + +class MathVista(EvalBase): + TASK_TYPES = ["FQA", "GPS", "MWP", "TQA", "VQA"] + SKILL_TYPES = ["ALG", "ARI", "GEO", "LOG", "NUM", "SCI", "STA"] + + def setup_data(self) -> list: + split = getattr(self, "split", "testmini") + dataset = load_dataset("AI4Math/MathVista", split=split) + print(f"Loaded {len(dataset)} examples from MathVista ({split})") + return list(dataset) + + def encode_image(self, pil_image: Image.Image) -> str: + buffer = io.BytesIO() + pil_image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + def get_image_base64(self, item: dict) -> str: + if "decoded_image" in item and item["decoded_image"] is not None: + return self.encode_image(item["decoded_image"]) + if "image" in item and item["image"] is not None: + if isinstance(item["image"], Image.Image): + return self.encode_image(item["image"]) + raise ValueError(f"Could not find image for item {item.get('pid', 'unknown')}") + + def build_messages(self, item: dict) -> List[dict]: + image_base64 = self.get_image_base64(item) + + use_query = getattr(self, "use_query", True) + if use_query and "query" in item: + prompt = item["query"] + else: + prompt = self._build_custom_prompt(item) + + return [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{image_base64}"}, + }, + {"type": "text", "text": prompt}, + ], + } + ] + + def _build_custom_prompt(self, item: dict) -> str: + question = item.get("question", "") + question_type = item.get("question_type", "free_form") + answer_type = item.get("answer_type", "text") + precision = item.get("precision", 2) + + if question_type == "multi_choice": + choices = item.get("choices", []) + choices_text = "\n".join(choices) if choices else "" + hint = ( + "Please answer the question and provide the correct option letter, " + "e.g., A, B, C, D, at the end." + ) + return f"Hint: {hint}\nQuestion: {question}\nChoices:\n{choices_text}" + + if answer_type == "integer": + hint = ( + "Please answer the question requiring an integer answer " + "and provide the final value, e.g., 1, 2, 3, at the end." + ) + elif answer_type == "float": + hint = ( + f"Please answer the question requiring a floating-point number " + f"with {precision} decimal place(s) and provide the final value at the end." + ) + elif answer_type == "list": + hint = ( + "Please answer the question requiring a Python list as an answer " + "and provide the final list, e.g., [1, 2, 3], at the end." + ) + else: + hint = "Please answer the question and provide the final answer at the end." + + return f"Hint: {hint}\nQuestion: {question}" + + def _prefetch_answer(self, response: str, item: dict) -> Tuple[Optional[str], bool]: + question_type = item.get("question_type", "free_form") + answer_type = item.get("answer_type", "text") + + if question_type == "multi_choice": + choices_list = item.get("choices", []) + if choices_list: + choices = {chr(65 + i): val for i, val in enumerate(choices_list)} + result = can_infer(response, choices) + if result: + return result, True + + # Fallback: find last letter + for char in reversed(response.upper()): + if char in "ABCDEFGH": + return char, True + return None, False + + response = response.strip() + + if answer_type == "integer": + numbers = re.findall(r"-?\d+", response) + if numbers: + return numbers[-1], True + + elif answer_type == "float": + numbers = re.findall(r"-?\d+\.?\d*", response) + if numbers: + return numbers[-1], True + + elif answer_type == "list": + match = re.search(r"\[[\d\.,\s-]+\]", response) + if match: + return match.group(0), True + + return None, False + + async def _extract_with_gpt(self, question: str, response: str) -> Optional[str]: + judge_model = getattr(self, "judge_model", "gpt-4o-mini") + judge_base_url = getattr(self, "judge_base_url", "https://api.openai.com/v1") + judge_api_key = os.environ.get( + getattr(self, "judge_api_key_env", "OPENAI_API_KEY"), "" + ) + + if not judge_api_key: + return None + + try: + judge_client = openai.AsyncOpenAI( + api_key=judge_api_key, + base_url=judge_base_url, + ) + + prompt = build_extraction_prompt(question, response) + + completion = await judge_client.chat.completions.create( + model=judge_model, + messages=[{"role": "user", "content": prompt}], + temperature=0.0, + max_tokens=128, + ) + + result = completion.choices[0].message.content.strip() + return result if result else None + + except Exception as e: + print(f"GPT extraction error: {e}") + return None + + def extract_answer( + self, response: str, answer_type: str, question_type: str + ) -> str: + response = response.strip() + + if question_type == "multi_choice": + for char in reversed(response.upper()): + if char in "ABCDEFGH": + return char + return "" + + if answer_type == "integer": + numbers = re.findall(r"-?\d+", response) + return numbers[-1] if numbers else "" + + if answer_type == "float": + numbers = re.findall(r"-?\d+\.?\d*", response) + return numbers[-1] if numbers else "" + + if answer_type == "list": + match = re.search(r"\[[\d\.,\s-]+\]", response) + return match.group(0) if match else "" + + return response + + def score( + self, prediction: str, answer: str, answer_type: str, precision: int = 0 + ) -> bool: + pred = prediction.strip() + ans = answer.strip() + + if not pred: + return False + + if answer_type == "text": + return pred.upper() == ans.upper() + + if answer_type == "integer": + try: + return int(float(pred)) == int(float(ans)) + except (ValueError, OverflowError): + return False + + if answer_type == "float": + try: + tolerance = 10 ** (-precision) if precision > 0 else 0.01 + return abs(float(pred) - float(ans)) < tolerance + except ValueError: + return False + + if answer_type == "list": + try: + pred_list = eval(pred) + ans_list = eval(ans) + return pred_list == ans_list + except Exception: + return False + + return pred.lower() == ans.lower() + + async def run_item( + self, server: ServerManager, data_item: dict + ) -> Tuple[dict, dict]: + try: + messages = self.build_messages(data_item) + + completion = await self.chat_completion(server, messages) + + if not completion.choices: + return {"accuracy": 0.0}, {"error": "Empty response"} + + message = completion.choices[0].message + response = message.content or "" + if hasattr(message, "reasoning") and message.reasoning and not response: + response = message.reasoning + if not response and hasattr(message, "model_extra"): + reasoning = message.model_extra.get("reasoning", "") + if reasoning: + response = reasoning + + if not response: + return {"accuracy": 0.0}, {"error": "Empty response"} + + answer_type = data_item.get("answer_type", "text") + question_type = data_item.get("question_type", "free_form") + precision = data_item.get("precision", 0) + + use_gpt_extraction = getattr(self, "use_gpt_extraction", True) + prefetch_result, prefetch_success = self._prefetch_answer( + response, data_item + ) + + if prefetch_success and prefetch_result: + extracted = prefetch_result + extraction_method = "prefetch" + elif use_gpt_extraction: + question = data_item.get("query", data_item.get("question", "")) + gpt_result = await self._extract_with_gpt(question, response) + if gpt_result: + extracted = gpt_result + extraction_method = "gpt" + else: + extracted = self.extract_answer( + response, answer_type, question_type + ) + extraction_method = "regex_fallback" + else: + extracted = self.extract_answer(response, answer_type, question_type) + extraction_method = "regex" + + answer = data_item.get("answer", "") + correct = self.score(extracted, answer, answer_type, precision) + + sample = { + "pid": data_item.get("pid", ""), + "question": data_item.get("question", ""), + "answer": answer, + "prediction": extracted, + "raw_response": response[:500], + "correct": correct, + "question_type": question_type, + "answer_type": answer_type, + "extraction_method": extraction_method, + } + + return {"accuracy": 1.0 if correct else 0.0}, sample + + except Exception as e: + return {"accuracy": 0.0}, {"error": str(e)} + + +if __name__ == "__main__": + asyncio.run( + eval_runner( + MathVista( + split="testmini", + use_query=True, + use_gpt_extraction=True, + judge_model="gpt-4o-mini", + temperature=0.0, + max_tokens=4096, + ) + ) + ) diff --git a/environments/eval_environments/vision_evals/mmbench_environment.py b/environments/eval_environments/vision_evals/mmbench_environment.py new file mode 100644 index 000000000..ca96b1052 --- /dev/null +++ b/environments/eval_environments/vision_evals/mmbench_environment.py @@ -0,0 +1,160 @@ +"""MMBench evaluation environment.""" + +import asyncio +import base64 +import io +from string import ascii_uppercase +from typing import List, Optional, Tuple + +from datasets import load_dataset +from environments.eval_environments.eval import EvalBase, eval_runner +from PIL import Image + +from atroposlib.envs.server_handling.server_manager import ServerManager +from environments.eval_environments.eval_helpers import ( + extract_letter_from_answer_tag, + extract_mcqa_answer_with_fallback, +) + + +class MMBench(EvalBase): + """MMBench evaluation - comprehensive multimodal benchmark.""" + + def setup_data(self) -> list: + split = getattr(self, "split", "dev") + lang = getattr(self, "lang", "en") # en, cn, cc + version = getattr(self, "version", "v1.1") # v1.0 or v1.1 + + try: + dataset = load_dataset("lmms-lab/MMBench", lang, split=split) + print(f"Loaded {len(dataset)} examples from MMBench ({split}, {lang})") + return list(dataset) + except Exception as e: + print(f"Warning: Could not load from lmms-lab: {e}") + try: + dataset = load_dataset("lmms-lab/MMBench_EN", split=split) + print(f"Loaded {len(dataset)} examples from MMBench ({split})") + return list(dataset) + except Exception: + raise ValueError(f"Could not load MMBench dataset: {e}") + + def encode_image(self, pil_image: Image.Image) -> str: + buffer = io.BytesIO() + pil_image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + def get_image_base64(self, item: dict) -> Optional[str]: + for key in ["image", "decoded_image"]: + if key in item and item[key] is not None: + if isinstance(item[key], Image.Image): + return self.encode_image(item[key]) + return None + + def build_messages(self, item: dict) -> List[dict]: + image_base64 = self.get_image_base64(item) + question = item.get("question", "") + hint = item.get("hint", "") + + options = {} + for letter in ascii_uppercase: + if letter in item and item[letter] is not None: + val = item[letter] + if isinstance(val, str) and val.strip(): + options[letter] = val + elif not isinstance(val, float): + options[letter] = str(val) + + prompt = "" + if hint and str(hint).strip() and str(hint).lower() != "nan": + prompt += f"Hint: {hint}\n" + prompt += f"Question: {question}\n" + + if options: + prompt += "Options:\n" + for letter in sorted(options.keys()): + prompt += f"{letter}. {options[letter]}\n" + prompt += "\nPlease select the correct answer from the options above." + + content = [] + if image_base64: + content.append( + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{image_base64}"}, + } + ) + content.append({"type": "text", "text": prompt}) + + return [{"role": "user", "content": content}] + + def extract_answer( + self, response: str, num_choices: int + ) -> Tuple[Optional[str], str]: + valid_letters = set(ascii_uppercase[:num_choices]) + + letter, method = extract_letter_from_answer_tag(response, valid_letters) + if letter: + return letter, method + + letter, method = extract_mcqa_answer_with_fallback(response, num_choices) + return letter, method + + async def run_item( + self, server: ServerManager, data_item: dict + ) -> Tuple[dict, dict]: + try: + messages = self.build_messages(data_item) + completion = await self.chat_completion(server, messages) + + if not completion.choices: + return {"accuracy": 0.0}, {"error": "Empty response"} + + message = completion.choices[0].message + response = message.content or "" + + if not response: + return {"accuracy": 0.0}, {"error": "Empty response"} + + answer = data_item.get("answer", "") + + num_choices = 0 + for letter in ascii_uppercase: + if letter in data_item and data_item[letter] is not None: + val = data_item[letter] + if isinstance(val, str) and val.strip(): + num_choices += 1 + elif not isinstance(val, float): + num_choices += 1 + num_choices = max(num_choices, 4) + + extracted, method = self.extract_answer(response, num_choices) + + correct = False + if extracted and answer: + correct = extracted.upper() == str(answer).upper() + + sample = { + "id": data_item.get("index", data_item.get("id", "")), + "question": data_item.get("question", "")[:200], + "category": data_item.get("category", data_item.get("l2-category", "")), + "answer": answer, + "prediction": extracted, + "raw_response": response[:500], + "correct": correct, + "extraction_method": method, + } + + return {"accuracy": 1.0 if correct else 0.0}, sample + + except Exception as e: + return {"accuracy": 0.0}, {"error": str(e)} + + +if __name__ == "__main__": + asyncio.run( + eval_runner( + MMBench( + split="dev", lang="en", version="v1.1", temperature=0.0, max_tokens=256 + ) + ) + ) diff --git a/environments/eval_environments/vision_evals/mmmu_environment.py b/environments/eval_environments/vision_evals/mmmu_environment.py new file mode 100644 index 000000000..f64d9d473 --- /dev/null +++ b/environments/eval_environments/vision_evals/mmmu_environment.py @@ -0,0 +1,186 @@ +"""MMMU (Massive Multi-discipline Multimodal Understanding) evaluation environment.""" + +import asyncio +import base64 +import io +import re +from string import ascii_uppercase +from typing import List, Optional, Tuple + +from datasets import load_dataset +from environments.eval_environments.eval import EvalBase, eval_runner +from PIL import Image + +from atroposlib.envs.server_handling.server_manager import ServerManager +from environments.eval_environments.eval_helpers import ( + extract_letter_from_answer_tag, + extract_mcqa_answer_with_fallback, +) + + +class MMMU(EvalBase): + """MMMU evaluation - multi-discipline multimodal understanding benchmark.""" + + def setup_data(self) -> list: + split = getattr(self, "split", "validation") + subset = getattr(self, "subset", None) + + if subset: + dataset = load_dataset("MMMU/MMMU", subset, split=split) + else: + subjects = [ + "Accounting", + "Agriculture", + "Architecture_and_Engineering", + "Art", + "Art_Theory", + "Basic_Medical_Science", + "Biology", + "Chemistry", + "Clinical_Medicine", + "Computer_Science", + "Design", + "Diagnostics_and_Laboratory_Medicine", + "Economics", + "Electronics", + "Energy_and_Power", + "Finance", + "Geography", + "History", + "Literature", + "Manage", + "Marketing", + "Materials", + "Math", + "Mechanical_Engineering", + "Music", + "Pharmacy", + "Physics", + "Psychology", + "Public_Health", + "Sociology", + ] + all_data = [] + for subj in subjects: + try: + ds = load_dataset("MMMU/MMMU", subj, split=split) + for item in ds: + item["subject"] = subj + all_data.append(item) + except Exception as e: + print(f"Warning: Could not load subject {subj}: {e}") + print(f"Loaded {len(all_data)} examples from MMMU ({split})") + return all_data + + print(f"Loaded {len(dataset)} examples from MMMU ({split}, {subset})") + return list(dataset) + + def encode_image(self, pil_image: Image.Image) -> str: + buffer = io.BytesIO() + pil_image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + def get_images(self, item: dict) -> List[str]: + """Extract all images from the item (MMMU can have multiple images).""" + images = [] + for i in range(1, 8): # MMMU supports up to 7 images + key = f"image_{i}" + if key in item and item[key] is not None: + if isinstance(item[key], Image.Image): + images.append(self.encode_image(item[key])) + return images + + def build_messages(self, item: dict) -> List[dict]: + images = self.get_images(item) + question = item.get("question", "") + options = item.get("options", []) + + if isinstance(options, str): + try: + options = eval(options) + except Exception: + options = [] + + if options: + options_text = "\n".join( + [f"({ascii_uppercase[i]}) {opt}" for i, opt in enumerate(options)] + ) + prompt = f"Question: {question}\n\nOptions:\n{options_text}\n\nPlease select the correct answer from the options above." + else: + prompt = f"Question: {question}\n\nProvide your answer." + + content = [] + for img_b64 in images: + content.append( + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{img_b64}"}, + } + ) + content.append({"type": "text", "text": prompt}) + + return [{"role": "user", "content": content}] + + def extract_answer( + self, response: str, num_choices: int + ) -> Tuple[Optional[str], str]: + """Extract answer letter from response.""" + valid_letters = set(ascii_uppercase[:num_choices]) + + letter, method = extract_letter_from_answer_tag(response, valid_letters) + if letter: + return letter, method + + letter, method = extract_mcqa_answer_with_fallback(response, num_choices) + return letter, method + + async def run_item( + self, server: ServerManager, data_item: dict + ) -> Tuple[dict, dict]: + try: + messages = self.build_messages(data_item) + completion = await self.chat_completion(server, messages) + + if not completion.choices: + return {"accuracy": 0.0}, {"error": "Empty response"} + + message = completion.choices[0].message + response = message.content or "" + + if not response: + return {"accuracy": 0.0}, {"error": "Empty response"} + + answer = data_item.get("answer", "") + options = data_item.get("options", []) + if isinstance(options, str): + try: + options = eval(options) + except Exception: + options = [] + + num_choices = len(options) if options else 4 + extracted, method = self.extract_answer(response, num_choices) + + correct = False + if extracted and answer: + correct = extracted.upper() == answer.upper() + + sample = { + "id": data_item.get("id", ""), + "question": data_item.get("question", "")[:200], + "subject": data_item.get("subject", ""), + "answer": answer, + "prediction": extracted, + "raw_response": response[:500], + "correct": correct, + "extraction_method": method, + } + + return {"accuracy": 1.0 if correct else 0.0}, sample + + except Exception as e: + return {"accuracy": 0.0}, {"error": str(e)} + + +if __name__ == "__main__": + asyncio.run(eval_runner(MMMU(split="validation", temperature=0.0, max_tokens=1024))) diff --git a/environments/eval_environments/vision_evals/mmmu_pro_environment.py b/environments/eval_environments/vision_evals/mmmu_pro_environment.py new file mode 100644 index 000000000..1b5ffccc2 --- /dev/null +++ b/environments/eval_environments/vision_evals/mmmu_pro_environment.py @@ -0,0 +1,210 @@ +"""MMMU-Pro evaluation environment.""" + +import asyncio +import base64 +import io +import re +from string import ascii_uppercase +from typing import List, Optional, Tuple + +from datasets import load_dataset +from environments.eval_environments.eval import EvalBase, eval_runner +from PIL import Image + +from atroposlib.envs.server_handling.server_manager import ServerManager +from environments.eval_environments.eval_helpers import ( + extract_letter_from_answer_tag, + extract_mcqa_answer_with_fallback, +) + + +class MMMUPro(EvalBase): + """MMMU-Pro evaluation - harder version of MMMU with 10 choices.""" + + def setup_data(self) -> list: + split = getattr(self, "split", "test") + variant = getattr(self, "variant", "standard") # standard, vision, standard_4 + + config_map = { + "standard": "standard (10 options)", + "standard_4": "standard (4 options)", + "vision": "vision", + } + config = config_map.get(variant, "standard (10 options)") + + try: + dataset = load_dataset("MMMU/MMMU_Pro", config, split=split) + print(f"Loaded {len(dataset)} examples from MMMU-Pro ({split}, {config})") + return list(dataset) + except Exception as e: + print(f"Error loading MMMU-Pro: {e}") + try: + dataset = load_dataset( + "MMMU/MMMU_Pro", "standard (10 options)", split="test" + ) + print(f"Loaded {len(dataset)} examples from MMMU-Pro (test)") + return list(dataset) + except Exception: + raise ValueError(f"Could not load MMMU-Pro dataset: {e}") + + def encode_image(self, pil_image: Image.Image) -> str: + buffer = io.BytesIO() + pil_image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + def get_images(self, item: dict) -> List[str]: + """Extract all images from the item.""" + images = [] + for i in range(1, 8): + key = f"image_{i}" + if key in item and item[key] is not None: + if isinstance(item[key], Image.Image): + images.append(self.encode_image(item[key])) + if "image" in item and item["image"] is not None: + if isinstance(item["image"], Image.Image): + images.append(self.encode_image(item["image"])) + return images + + def build_messages(self, item: dict) -> List[dict]: + images = self.get_images(item) + question = item.get("question", "") + options = item.get("options", []) + + if isinstance(options, str): + try: + options = eval(options) + except Exception: + options = [] + + variant = getattr(self, "variant", "standard") + + if variant == "vision": + prompt = "Answer the following multiple-choice question in the image. Answer directly with the option letter from the given choices." + else: + if options: + options_text = "\n".join( + [f"{ascii_uppercase[i]}. {opt}" for i, opt in enumerate(options)] + ) + prompt = f"Question: {question}\n\nOptions:\n{options_text}\n\n" + + if variant == "cot": + prompt += ( + "Answer the following multiple-choice question. " + "The last line of your response should be of the following format: " + "'Answer: $LETTER' (without quotes) where LETTER is one of the options. " + "Think step by step before answering." + ) + else: + prompt += ( + "Answer directly with the option letter from the given choices." + ) + else: + prompt = f"Question: {question}\n\nProvide your answer." + + content = [] + for img_b64 in images: + content.append( + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{img_b64}"}, + } + ) + content.append({"type": "text", "text": prompt}) + + return [{"role": "user", "content": content}] + + def extract_answer_cot(self, response: str) -> Optional[str]: + """Extract answer from COT response format 'Answer: X'.""" + lines = response.strip().split("\n") + lines = [x.strip() for x in lines] + + for line in reversed(lines): + if line.startswith("Answer:"): + rest = line[7:].strip() + from collections import Counter + + letter_counts = Counter( + ch for ch in rest.upper() if ch in ascii_uppercase[:10] + ) + if len(letter_counts) == 1: + return list(letter_counts.keys())[0] + elif letter_counts: + for ch in rest.upper(): + if ch in ascii_uppercase[:10]: + return ch + return None + + def extract_answer( + self, response: str, num_choices: int + ) -> Tuple[Optional[str], str]: + """Extract answer letter from response.""" + variant = getattr(self, "variant", "standard") + + if variant == "cot": + cot_answer = self.extract_answer_cot(response) + if cot_answer: + return cot_answer, "cot_extraction" + + valid_letters = set(ascii_uppercase[:num_choices]) + + letter, method = extract_letter_from_answer_tag(response, valid_letters) + if letter: + return letter, method + + letter, method = extract_mcqa_answer_with_fallback(response, num_choices) + return letter, method + + async def run_item( + self, server: ServerManager, data_item: dict + ) -> Tuple[dict, dict]: + try: + messages = self.build_messages(data_item) + completion = await self.chat_completion(server, messages) + + if not completion.choices: + return {"accuracy": 0.0}, {"error": "Empty response"} + + message = completion.choices[0].message + response = message.content or "" + + if not response: + return {"accuracy": 0.0}, {"error": "Empty response"} + + answer = data_item.get("answer", "") + options = data_item.get("options", []) + if isinstance(options, str): + try: + options = eval(options) + except Exception: + options = [] + + num_choices = len(options) if options else 10 + extracted, method = self.extract_answer(response, num_choices) + + correct = False + if extracted and answer: + correct = extracted.upper() == answer.upper() + + sample = { + "id": data_item.get("id", ""), + "question": data_item.get("question", "")[:200], + "subject": data_item.get("subject", ""), + "answer": answer, + "prediction": extracted, + "raw_response": response[:500], + "correct": correct, + "extraction_method": method, + } + + return {"accuracy": 1.0 if correct else 0.0}, sample + + except Exception as e: + return {"accuracy": 0.0}, {"error": str(e)} + + +if __name__ == "__main__": + asyncio.run( + eval_runner( + MMMUPro(split="test", variant="standard", temperature=0.0, max_tokens=1024) + ) + ) diff --git a/environments/eval_environments/vision_evals/mmstar_environment.py b/environments/eval_environments/vision_evals/mmstar_environment.py new file mode 100644 index 000000000..e6ea99e75 --- /dev/null +++ b/environments/eval_environments/vision_evals/mmstar_environment.py @@ -0,0 +1,145 @@ +"""MMStar evaluation environment.""" + +import asyncio +import base64 +import io +from string import ascii_uppercase +from typing import List, Optional, Tuple + +from datasets import load_dataset +from environments.eval_environments.eval import EvalBase, eval_runner +from PIL import Image + +from atroposlib.envs.server_handling.server_manager import ServerManager +from environments.eval_environments.eval_helpers import ( + extract_letter_from_answer_tag, + extract_mcqa_answer_with_fallback, +) + + +class MMStar(EvalBase): + """MMStar evaluation - expert-level multimodal benchmark.""" + + def setup_data(self) -> list: + split = getattr(self, "split", "val") + + try: + dataset = load_dataset("Lin-Chen/MMStar", split=split) + print(f"Loaded {len(dataset)} examples from MMStar ({split})") + return list(dataset) + except Exception as e: + print(f"Warning: Could not load MMStar: {e}") + try: + dataset = load_dataset("lmms-lab/MMStar", split=split) + print(f"Loaded {len(dataset)} examples from MMStar ({split})") + return list(dataset) + except Exception: + raise ValueError(f"Could not load MMStar dataset: {e}") + + def encode_image(self, pil_image: Image.Image) -> str: + buffer = io.BytesIO() + pil_image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + def get_image_base64(self, item: dict) -> Optional[str]: + for key in ["image", "decoded_image"]: + if key in item and item[key] is not None: + if isinstance(item[key], Image.Image): + return self.encode_image(item[key]) + return None + + def build_messages(self, item: dict) -> List[dict]: + image_base64 = self.get_image_base64(item) + question = item.get("question", "") + + options = {} + for letter in ascii_uppercase[:6]: # MMStar typically has up to 6 options + if letter in item and item[letter] is not None: + val = item[letter] + if isinstance(val, str) and val.strip(): + options[letter] = val + + prompt = f"Question: {question}\n" + if options: + prompt += "Options:\n" + for letter in sorted(options.keys()): + prompt += f"{letter}. {options[letter]}\n" + prompt += "\nPlease select the correct answer from the options above." + + content = [] + if image_base64: + content.append( + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{image_base64}"}, + } + ) + content.append({"type": "text", "text": prompt}) + + return [{"role": "user", "content": content}] + + def extract_answer( + self, response: str, num_choices: int + ) -> Tuple[Optional[str], str]: + valid_letters = set(ascii_uppercase[:num_choices]) + + letter, method = extract_letter_from_answer_tag(response, valid_letters) + if letter: + return letter, method + + letter, method = extract_mcqa_answer_with_fallback(response, num_choices) + return letter, method + + async def run_item( + self, server: ServerManager, data_item: dict + ) -> Tuple[dict, dict]: + try: + messages = self.build_messages(data_item) + completion = await self.chat_completion(server, messages) + + if not completion.choices: + return {"accuracy": 0.0}, {"error": "Empty response"} + + message = completion.choices[0].message + response = message.content or "" + + if not response: + return {"accuracy": 0.0}, {"error": "Empty response"} + + answer = data_item.get("answer", "") + + num_choices = sum( + 1 + for letter in ascii_uppercase[:6] + if letter in data_item + and data_item[letter] is not None + and isinstance(data_item[letter], str) + and data_item[letter].strip() + ) + num_choices = max(num_choices, 4) + + extracted, method = self.extract_answer(response, num_choices) + + correct = False + if extracted and answer: + correct = extracted.upper() == str(answer).upper() + + sample = { + "id": data_item.get("index", data_item.get("id", "")), + "question": data_item.get("question", "")[:200], + "category": data_item.get("category", ""), + "answer": answer, + "prediction": extracted, + "raw_response": response[:500], + "correct": correct, + "extraction_method": method, + } + + return {"accuracy": 1.0 if correct else 0.0}, sample + + except Exception as e: + return {"accuracy": 0.0}, {"error": str(e)} + + +if __name__ == "__main__": + asyncio.run(eval_runner(MMStar(split="val", temperature=0.0, max_tokens=256))) diff --git a/environments/eval_environments/vision_evals/mmt_bench_environment.py b/environments/eval_environments/vision_evals/mmt_bench_environment.py new file mode 100644 index 000000000..629de9cec --- /dev/null +++ b/environments/eval_environments/vision_evals/mmt_bench_environment.py @@ -0,0 +1,174 @@ +"""MMT-Bench evaluation environment.""" + +import asyncio +import base64 +import io +from string import ascii_uppercase +from typing import List, Optional, Tuple + +from datasets import load_dataset +from environments.eval_environments.eval import EvalBase, eval_runner +from PIL import Image + +from atroposlib.envs.server_handling.server_manager import ServerManager +from environments.eval_environments.eval_helpers import ( + extract_letter_from_answer_tag, + extract_mcqa_answer_with_fallback, +) + + +class MMTBench(EvalBase): + """MMT-Bench evaluation - multi-task multimodal benchmark.""" + + def setup_data(self) -> list: + split = getattr(self, "split", "train") + max_samples = getattr(self, "max_samples", None) # None = use all samples + + try: + # Try full dataset download first + dataset = load_dataset("OpenGVLab/MMT-Bench", split=split) + data = list(dataset) + if max_samples: + data = data[:max_samples] + print(f"Loaded {len(data)} examples from MMT-Bench ({split})") + return data + except Exception as e: + print(f"Warning: Full download failed, using streaming: {e}") + # Fallback to streaming if full download fails (known column mismatch issue) + try: + dataset = load_dataset( + "OpenGVLab/MMT-Bench", split=split, streaming=True + ) + if max_samples: + data = list(dataset.take(max_samples)) + else: + # Stream all available samples + data = [] + for i, item in enumerate(dataset): + data.append(item) + if i % 5000 == 0 and i > 0: + print(f" Streamed {i} samples...") + print( + f"Loaded {len(data)} examples from MMT-Bench ({split}, streaming)" + ) + return data + except Exception: + raise ValueError(f"Could not load MMT-Bench dataset: {e}") + + def encode_image(self, pil_image: Image.Image) -> str: + buffer = io.BytesIO() + pil_image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + def get_image_base64(self, item: dict) -> Optional[str]: + for key in ["image", "decoded_image"]: + if key in item and item[key] is not None: + val = item[key] + if isinstance(val, Image.Image): + return self.encode_image(val) + elif isinstance(val, str) and len(val) > 100: + # Already base64-encoded string + return val + return None + + def build_messages(self, item: dict) -> List[dict]: + image_base64 = self.get_image_base64(item) + question = item.get("question", "") + hint = item.get("hint", "") + + options = {} + for letter in ascii_uppercase[:8]: # Support up to 8 options + if letter in item and item[letter] is not None: + val = item[letter] + if isinstance(val, str) and val.strip(): + options[letter] = val + + prompt = "" + if hint and str(hint).strip() and str(hint).lower() != "nan": + prompt += f"Hint: {hint}\n" + prompt += f"Question: {question}\n" + + if options: + prompt += "Options:\n" + for letter in sorted(options.keys()): + prompt += f"{letter}. {options[letter]}\n" + prompt += "\nPlease select the correct answer from the options above." + + content = [] + if image_base64: + content.append( + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{image_base64}"}, + } + ) + content.append({"type": "text", "text": prompt}) + + return [{"role": "user", "content": content}] + + def extract_answer( + self, response: str, num_choices: int + ) -> Tuple[Optional[str], str]: + valid_letters = set(ascii_uppercase[:num_choices]) + + letter, method = extract_letter_from_answer_tag(response, valid_letters) + if letter: + return letter, method + + letter, method = extract_mcqa_answer_with_fallback(response, num_choices) + return letter, method + + async def run_item( + self, server: ServerManager, data_item: dict + ) -> Tuple[dict, dict]: + try: + messages = self.build_messages(data_item) + + completion = await self.chat_completion(server, messages) + + if not completion.choices: + return {"accuracy": 0.0}, {"error": "Empty response"} + + message = completion.choices[0].message + response = message.content or "" + + if not response: + return {"accuracy": 0.0}, {"error": "Empty response"} + + answer = data_item.get("answer", "") + + num_choices = sum( + 1 + for letter in ascii_uppercase[:8] + if letter in data_item + and data_item[letter] is not None + and isinstance(data_item[letter], str) + and data_item[letter].strip() + ) + num_choices = max(num_choices, 4) + + extracted, method = self.extract_answer(response, num_choices) + + correct = False + if extracted and answer: + correct = extracted.upper() == str(answer).upper() + + sample = { + "id": data_item.get("index", data_item.get("id", "")), + "question": data_item.get("question", "")[:200], + "task": data_item.get("task", ""), + "answer": answer, + "prediction": extracted, + "raw_response": response[:500], + "correct": correct, + "extraction_method": method, + } + + return {"accuracy": 1.0 if correct else 0.0}, sample + + except Exception as e: + return {"accuracy": 0.0}, {"error": str(e)} + + +if __name__ == "__main__": + asyncio.run(eval_runner(MMTBench(split="val", temperature=0.0, max_tokens=256))) diff --git a/environments/eval_environments/vision_evals/mmvet_environment.py b/environments/eval_environments/vision_evals/mmvet_environment.py new file mode 100644 index 000000000..6cfe7a47d --- /dev/null +++ b/environments/eval_environments/vision_evals/mmvet_environment.py @@ -0,0 +1,186 @@ +"""MMVet evaluation environment.""" + +import asyncio +import base64 +import io +import os +from typing import List, Optional, Tuple + +import openai +from datasets import load_dataset +from environments.eval_environments.eval import EvalBase, eval_runner +from PIL import Image + +from atroposlib.envs.server_handling.server_manager import ServerManager + + +class MMVet(EvalBase): + """MMVet evaluation - comprehensive VLM capability benchmark with GPT-based scoring.""" + + def setup_data(self) -> list: + split = getattr(self, "split", "test") + + try: + dataset = load_dataset("lmms-lab/MMVet", split=split) + print(f"Loaded {len(dataset)} examples from MMVet ({split})") + return list(dataset) + except Exception as e: + print(f"Warning: Could not load MMVet: {e}") + try: + dataset = load_dataset("whyu/MM-Vet", split=split) + print(f"Loaded {len(dataset)} examples from MMVet ({split})") + return list(dataset) + except Exception: + raise ValueError(f"Could not load MMVet dataset: {e}") + + def encode_image(self, pil_image: Image.Image) -> str: + buffer = io.BytesIO() + pil_image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + def get_image_base64(self, item: dict) -> Optional[str]: + for key in ["image", "decoded_image"]: + if key in item and item[key] is not None: + if isinstance(item[key], Image.Image): + return self.encode_image(item[key]) + return None + + def build_messages(self, item: dict) -> List[dict]: + image_base64 = self.get_image_base64(item) + question = item.get("question", "") + + content = [] + if image_base64: + content.append( + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{image_base64}"}, + } + ) + content.append({"type": "text", "text": question}) + + return [{"role": "user", "content": content}] + + async def gpt_score(self, question: str, answer: str, prediction: str) -> float: + """Use GPT to score the prediction against the ground truth answer.""" + judge_model = getattr(self, "judge_model", "gpt-4o-mini") + judge_base_url = getattr(self, "judge_base_url", "https://api.openai.com/v1") + judge_api_key = os.environ.get( + getattr(self, "judge_api_key_env", "OPENAI_API_KEY"), "" + ) + + if not judge_api_key: + pred_lower = prediction.lower().strip() + ans_lower = answer.lower().strip() + if pred_lower == ans_lower: + return 1.0 + elif ans_lower in pred_lower or pred_lower in ans_lower: + return 0.5 + return 0.0 + + try: + judge_client = openai.AsyncOpenAI( + api_key=judge_api_key, + base_url=judge_base_url, + ) + + prompt = f"""You are evaluating the quality of a model's answer compared to a reference answer. + +Question: {question} + +Reference Answer: {answer} + +Model's Answer: {prediction} + +Score the model's answer on a scale from 0 to 1: +- 1.0: Completely correct and matches the reference +- 0.5-0.9: Partially correct or captures the main idea +- 0.1-0.4: Somewhat related but mostly incorrect +- 0.0: Completely wrong or irrelevant + +Output ONLY a single number between 0 and 1.""" + + completion = await judge_client.chat.completions.create( + model=judge_model, + messages=[{"role": "user", "content": prompt}], + temperature=0.0, + max_tokens=10, + ) + + result = completion.choices[0].message.content.strip() + try: + score = float(result) + return max(0.0, min(1.0, score)) + except ValueError: + return 0.0 + + except Exception as e: + print(f"GPT scoring error: {e}") + pred_lower = prediction.lower().strip() + ans_lower = answer.lower().strip() + if pred_lower == ans_lower: + return 1.0 + elif ans_lower in pred_lower: + return 0.5 + return 0.0 + + async def run_item( + self, server: ServerManager, data_item: dict + ) -> Tuple[dict, dict]: + try: + messages = self.build_messages(data_item) + + completion = await self.chat_completion(server, messages) + + if not completion.choices: + return {"accuracy": 0.0}, {"error": "Empty response"} + + message = completion.choices[0].message + response = message.content or "" + + if not response: + return {"accuracy": 0.0}, {"error": "Empty response"} + + question = data_item.get("question", "") + answer = data_item.get("answer", "") + + use_gpt_scoring = getattr(self, "use_gpt_scoring", True) + if use_gpt_scoring: + score = await self.gpt_score(question, answer, response) + else: + pred_lower = response.lower().strip() + ans_lower = answer.lower().strip() + if pred_lower == ans_lower: + score = 1.0 + elif ans_lower in pred_lower: + score = 0.5 + else: + score = 0.0 + + sample = { + "id": data_item.get("index", data_item.get("id", "")), + "question": question[:200], + "category": data_item.get("capability", data_item.get("category", "")), + "answer": answer[:200], + "prediction": response[:500], + "score": score, + } + + return {"accuracy": score}, sample + + except Exception as e: + return {"accuracy": 0.0}, {"error": str(e)} + + +if __name__ == "__main__": + asyncio.run( + eval_runner( + MMVet( + split="test", + use_gpt_scoring=True, + judge_model="gpt-4o-mini", + temperature=0.0, + max_tokens=512, + ) + ) + ) diff --git a/environments/eval_environments/vision_evals/mmvp_environment.py b/environments/eval_environments/vision_evals/mmvp_environment.py new file mode 100644 index 000000000..7a8cbb271 --- /dev/null +++ b/environments/eval_environments/vision_evals/mmvp_environment.py @@ -0,0 +1,153 @@ +"""MMVP (Multimodal Visual Perception) evaluation environment.""" + +import asyncio +import base64 +import io +from string import ascii_uppercase +from typing import List, Optional, Tuple + +from datasets import load_dataset +from environments.eval_environments.eval import EvalBase, eval_runner +from PIL import Image + +from atroposlib.envs.server_handling.server_manager import ServerManager +from environments.eval_environments.eval_helpers import ( + extract_letter_from_answer_tag, + extract_mcqa_answer_with_fallback, +) + + +class MMVP(EvalBase): + """MMVP evaluation - visual perception benchmark testing CLIP blindspots.""" + + def setup_data(self) -> list: + split = getattr(self, "split", "train") # MMVP only has train split + + try: + dataset = load_dataset("MMVP/MMVP", split=split) + print(f"Loaded {len(dataset)} examples from MMVP ({split})") + return list(dataset) + except Exception as e: + print(f"Warning: Could not load MMVP: {e}") + try: + dataset = load_dataset("lmms-lab/MMVP", split=split) + print(f"Loaded {len(dataset)} examples from MMVP ({split})") + return list(dataset) + except Exception: + raise ValueError(f"Could not load MMVP dataset: {e}") + + def encode_image(self, pil_image: Image.Image) -> str: + buffer = io.BytesIO() + pil_image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + def get_images(self, item: dict) -> List[str]: + """Get all images from item (MMVP typically has paired images).""" + images = [] + for i in range(1, 3): + key = f"image_{i}" + if key in item and item[key] is not None: + if isinstance(item[key], Image.Image): + images.append(self.encode_image(item[key])) + + if not images and "image" in item and item["image"] is not None: + if isinstance(item["image"], Image.Image): + images.append(self.encode_image(item["image"])) + + return images + + def build_messages(self, item: dict) -> List[dict]: + images = self.get_images(item) + question = item.get("question", "") + + options = {} + for letter in ascii_uppercase[:4]: # MMVP typically has 2-4 options + if letter in item and item[letter] is not None: + val = item[letter] + if isinstance(val, str) and val.strip(): + options[letter] = val + + prompt = f"Question: {question}\n" + if options: + prompt += "Options:\n" + for letter in sorted(options.keys()): + prompt += f"{letter}. {options[letter]}\n" + prompt += "\nPlease select the correct answer from the options above." + + content = [] + for img_b64 in images: + content.append( + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{img_b64}"}, + } + ) + content.append({"type": "text", "text": prompt}) + + return [{"role": "user", "content": content}] + + def extract_answer( + self, response: str, num_choices: int + ) -> Tuple[Optional[str], str]: + valid_letters = set(ascii_uppercase[:num_choices]) + + letter, method = extract_letter_from_answer_tag(response, valid_letters) + if letter: + return letter, method + + letter, method = extract_mcqa_answer_with_fallback(response, num_choices) + return letter, method + + async def run_item( + self, server: ServerManager, data_item: dict + ) -> Tuple[dict, dict]: + try: + messages = self.build_messages(data_item) + completion = await self.chat_completion(server, messages) + + if not completion.choices: + return {"accuracy": 0.0}, {"error": "Empty response"} + + message = completion.choices[0].message + response = message.content or "" + + if not response: + return {"accuracy": 0.0}, {"error": "Empty response"} + + answer = data_item.get("answer", "") + + num_choices = sum( + 1 + for letter in ascii_uppercase[:4] + if letter in data_item + and data_item[letter] is not None + and isinstance(data_item[letter], str) + and data_item[letter].strip() + ) + num_choices = max(num_choices, 2) + + extracted, method = self.extract_answer(response, num_choices) + + correct = False + if extracted and answer: + correct = extracted.upper() == str(answer).upper() + + sample = { + "id": data_item.get("index", data_item.get("id", "")), + "question": data_item.get("question", "")[:200], + "category": data_item.get("category", ""), + "answer": answer, + "prediction": extracted, + "raw_response": response[:500], + "correct": correct, + "extraction_method": method, + } + + return {"accuracy": 1.0 if correct else 0.0}, sample + + except Exception as e: + return {"accuracy": 0.0}, {"error": str(e)} + + +if __name__ == "__main__": + asyncio.run(eval_runner(MMVP(split="test", temperature=0.0, max_tokens=256))) diff --git a/environments/eval_environments/vision_evals/ocrbench_environment.py b/environments/eval_environments/vision_evals/ocrbench_environment.py new file mode 100644 index 000000000..3125c10fb --- /dev/null +++ b/environments/eval_environments/vision_evals/ocrbench_environment.py @@ -0,0 +1,142 @@ +"""OCRBench evaluation environment.""" + +import asyncio +import base64 +import io +from typing import Dict, List, Optional, Tuple + +from datasets import load_dataset +from environments.eval_environments.eval import EvalBase, eval_runner +from PIL import Image + +from atroposlib.envs.server_handling.server_manager import ServerManager + + +class OCRBench(EvalBase): + """OCRBench evaluation - OCR capabilities benchmark.""" + + # Categories and their scoring + CATEGORIES = [ + "Regular Text Recognition", + "Irregular Text Recognition", + "Artistic Text Recognition", + "Handwriting Recognition", + "Digit String Recognition", + "Non-Semantic Text Recognition", + "Scene Text-centric VQA", + "Doc-oriented VQA", + "Key Information Extraction", + "Handwritten Mathematical Expression Recognition", + ] + + def setup_data(self) -> list: + split = getattr(self, "split", "test") + + try: + dataset = load_dataset("echo840/OCRBench", split=split) + print(f"Loaded {len(dataset)} examples from OCRBench ({split})") + return list(dataset) + except Exception as e: + print(f"Warning: Could not load OCRBench: {e}") + try: + dataset = load_dataset("lmms-lab/OCRBench", split=split) + print(f"Loaded {len(dataset)} examples from OCRBench ({split})") + return list(dataset) + except Exception: + raise ValueError(f"Could not load OCRBench dataset: {e}") + + def encode_image(self, pil_image: Image.Image) -> str: + buffer = io.BytesIO() + pil_image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + def get_image_base64(self, item: dict) -> Optional[str]: + for key in ["image", "decoded_image"]: + if key in item and item[key] is not None: + if isinstance(item[key], Image.Image): + return self.encode_image(item[key]) + return None + + def build_messages(self, item: dict) -> List[dict]: + image_base64 = self.get_image_base64(item) + question = item.get("question", "") + + prompt = f"{question}\n\nAnswer the question using a single word or phrase." + + content = [] + if image_base64: + content.append( + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{image_base64}"}, + } + ) + content.append({"type": "text", "text": prompt}) + + return [{"role": "user", "content": content}] + + def score_ocr(self, prediction: str, answers: List[str], category: str) -> bool: + """Category-specific scoring for OCR tasks.""" + predict = prediction.strip() + + if category == "Handwritten Mathematical Expression Recognition": + predict_clean = predict.replace("\n", " ").replace(" ", "") + for answer in answers: + answer_clean = answer.strip().replace("\n", " ").replace(" ", "") + if answer_clean in predict_clean: + return True + else: + predict_lower = predict.lower().replace("\n", " ") + for answer in answers: + answer_lower = answer.lower().strip().replace("\n", " ") + if answer_lower in predict_lower: + return True + + return False + + async def run_item( + self, server: ServerManager, data_item: dict + ) -> Tuple[dict, dict]: + try: + messages = self.build_messages(data_item) + completion = await self.chat_completion(server, messages) + + if not completion.choices: + return {"accuracy": 0.0}, {"error": "Empty response"} + + message = completion.choices[0].message + response = message.content or "" + + if not response: + return {"accuracy": 0.0}, {"error": "Empty response"} + + answers = data_item.get("answer", []) + if isinstance(answers, str): + try: + answers = eval(answers) + except Exception: + answers = [answers] + if not isinstance(answers, list): + answers = [answers] + + category = data_item.get("category", "") + + correct = self.score_ocr(response, answers, category) + + sample = { + "id": data_item.get("index", data_item.get("id", "")), + "question": data_item.get("question", "")[:200], + "category": category, + "answer": answers[0] if answers else "", + "prediction": response[:200], + "correct": correct, + } + + return {"accuracy": 1.0 if correct else 0.0}, sample + + except Exception as e: + return {"accuracy": 0.0}, {"error": str(e)} + + +if __name__ == "__main__": + asyncio.run(eval_runner(OCRBench(split="test", temperature=0.0, max_tokens=256))) diff --git a/environments/eval_environments/vision_evals/pope_environment.py b/environments/eval_environments/vision_evals/pope_environment.py new file mode 100644 index 000000000..dfe7e15a5 --- /dev/null +++ b/environments/eval_environments/vision_evals/pope_environment.py @@ -0,0 +1,134 @@ +"""POPE (Polling-based Object Probing Evaluation) evaluation environment.""" + +import asyncio +import base64 +import io +import re +from typing import List, Optional, Tuple + +from datasets import load_dataset +from environments.eval_environments.eval import EvalBase, eval_runner +from PIL import Image + +from atroposlib.envs.server_handling.server_manager import ServerManager + + +class POPE(EvalBase): + """POPE evaluation - object hallucination benchmark with yes/no questions.""" + + def setup_data(self) -> list: + split = getattr(self, "split", "test") + variant = getattr(self, "variant", "random") # random, popular, adversarial + + try: + dataset = load_dataset("lmms-lab/POPE", split=split) + print(f"Loaded {len(dataset)} examples from POPE ({split})") + return list(dataset) + except Exception as e: + print(f"Warning: Could not load POPE: {e}") + try: + dataset = load_dataset("OpenGVLab/POPE", split=split) + print(f"Loaded {len(dataset)} examples from POPE ({split})") + return list(dataset) + except Exception: + raise ValueError(f"Could not load POPE dataset: {e}") + + def encode_image(self, pil_image: Image.Image) -> str: + buffer = io.BytesIO() + pil_image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + def get_image_base64(self, item: dict) -> Optional[str]: + for key in ["image", "decoded_image"]: + if key in item and item[key] is not None: + if isinstance(item[key], Image.Image): + return self.encode_image(item[key]) + return None + + def build_messages(self, item: dict) -> List[dict]: + image_base64 = self.get_image_base64(item) + question = item.get("question", "") + + prompt = f"{question}\n\nPlease answer yes or no." + + content = [] + if image_base64: + content.append( + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{image_base64}"}, + } + ) + content.append({"type": "text", "text": prompt}) + + return [{"role": "user", "content": content}] + + def extract_yorn(self, response: str) -> str: + """Extract Yes/No from response.""" + response_lower = response.lower().strip() + + if response_lower.startswith("yes"): + return "Yes" + if response_lower.startswith("no"): + return "No" + + yes_patterns = [r"\byes\b", r"\btrue\b", r"\bcorrect\b", r"\baffirmative\b"] + no_patterns = [r"\bno\b", r"\bfalse\b", r"\bincorrect\b", r"\bnegative\b"] + + for pattern in yes_patterns: + if re.search(pattern, response_lower): + return "Yes" + + for pattern in no_patterns: + if re.search(pattern, response_lower): + return "No" + + return "Unknown" + + async def run_item( + self, server: ServerManager, data_item: dict + ) -> Tuple[dict, dict]: + try: + messages = self.build_messages(data_item) + completion = await self.chat_completion(server, messages) + + if not completion.choices: + return {"accuracy": 0.0}, {"error": "Empty response"} + + message = completion.choices[0].message + response = message.content or "" + + if not response: + return {"accuracy": 0.0}, {"error": "Empty response"} + + answer = data_item.get("answer", "") + extracted = self.extract_yorn(response) + + answer_norm = answer.strip().lower() + if answer_norm in ["yes", "true", "1"]: + answer_norm = "Yes" + elif answer_norm in ["no", "false", "0"]: + answer_norm = "No" + else: + answer_norm = answer.strip() + + correct = extracted == answer_norm + + sample = { + "id": data_item.get("index", data_item.get("question_id", "")), + "question": data_item.get("question", "")[:200], + "category": data_item.get("category", ""), + "answer": answer_norm, + "prediction": extracted, + "raw_response": response[:200], + "correct": correct, + } + + return {"accuracy": 1.0 if correct else 0.0}, sample + + except Exception as e: + return {"accuracy": 0.0}, {"error": str(e)} + + +if __name__ == "__main__": + asyncio.run(eval_runner(POPE(split="test", temperature=0.0, max_tokens=64))) diff --git a/environments/eval_environments/vision_evals/realworldqa_environment.py b/environments/eval_environments/vision_evals/realworldqa_environment.py new file mode 100644 index 000000000..79b2e620e --- /dev/null +++ b/environments/eval_environments/vision_evals/realworldqa_environment.py @@ -0,0 +1,121 @@ +import asyncio +import base64 +import io +from typing import List, Tuple + +from datasets import load_dataset +from environments.eval_environments.eval import EvalBase, eval_runner +from PIL import Image + +from atroposlib.envs.server_handling.server_manager import ServerManager + + +class RealWorldQA(EvalBase): + def setup_data(self) -> list: + split = getattr(self, "split", "test") + dataset = load_dataset("xai-org/RealworldQA", split=split) + print(f"Loaded {len(dataset)} examples from RealWorldQA ({split})") + return list(dataset) + + def encode_image(self, pil_image: Image.Image) -> str: + buffer = io.BytesIO() + pil_image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + def get_image_base64(self, item: dict) -> str: + if "image" in item and item["image"] is not None: + if isinstance(item["image"], Image.Image): + return self.encode_image(item["image"]) + raise ValueError("Could not find image for item") + + def build_messages(self, item: dict) -> List[dict]: + image_base64 = self.get_image_base64(item) + question = item.get("question", "") + + prompt = f"""{question} + +Provide a brief, direct answer.""" + + return [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{image_base64}"}, + }, + {"type": "text", "text": prompt}, + ], + } + ] + + def extract_answer(self, response: str) -> str: + response = response.strip() + lines = response.split("\n") + if lines: + return lines[0].strip() + return response + + def score(self, prediction: str, answer: str) -> bool: + pred = prediction.strip().lower() + ans = answer.strip().lower() + + if not pred: + return False + + if pred == ans: + return True + + if ans in pred or pred in ans: + return True + + pred_words = set(pred.split()) + ans_words = set(ans.split()) + overlap = pred_words & ans_words + if len(overlap) >= len(ans_words) * 0.5: + return True + + return False + + async def run_item( + self, server: ServerManager, data_item: dict + ) -> Tuple[dict, dict]: + try: + messages = self.build_messages(data_item) + + completion = await self.chat_completion(server, messages) + + if not completion.choices: + return {"accuracy": 0.0}, {"error": "Empty response"} + + message = completion.choices[0].message + response = message.content or "" + if hasattr(message, "reasoning") and message.reasoning and not response: + response = message.reasoning + if not response and hasattr(message, "model_extra"): + reasoning = message.model_extra.get("reasoning", "") + if reasoning: + response = reasoning + + if not response: + return {"accuracy": 0.0}, {"error": "Empty response"} + + extracted = self.extract_answer(response) + answer = data_item.get("answer", "") + correct = self.score(extracted, answer) + + sample = { + "question": data_item.get("question", ""), + "answer": answer, + "prediction": extracted, + "correct": correct, + } + + return {"accuracy": 1.0 if correct else 0.0}, sample + + except Exception as e: + return {"accuracy": 0.0}, {"error": str(e)} + + +if __name__ == "__main__": + asyncio.run(eval_runner(RealWorldQA(split="test", temperature=0.0, max_tokens=256))) diff --git a/environments/eval_environments/vision_evals/seedbench2_plus_environment.py b/environments/eval_environments/vision_evals/seedbench2_plus_environment.py new file mode 100644 index 000000000..4d3243399 --- /dev/null +++ b/environments/eval_environments/vision_evals/seedbench2_plus_environment.py @@ -0,0 +1,195 @@ +"""SEED-Bench2-Plus evaluation environment.""" + +import asyncio +import base64 +import io +from string import ascii_uppercase +from typing import List, Optional, Tuple + +from datasets import load_dataset +from environments.eval_environments.eval import EvalBase, eval_runner +from PIL import Image + +from atroposlib.envs.server_handling.server_manager import ServerManager +from environments.eval_environments.eval_helpers import ( + extract_letter_from_answer_tag, + extract_mcqa_answer_with_fallback, +) + + +class SEEDBench2Plus(EvalBase): + """SEED-Bench2-Plus evaluation - comprehensive visual understanding benchmark.""" + + def setup_data(self) -> list: + split = getattr(self, "split", "test") + max_samples = getattr(self, "max_samples", None) + + try: + # Use streaming to avoid memory issues with this large dataset + dataset = load_dataset("lmms-lab/SEED-Bench-2", split=split, streaming=True) + + # Take samples from streaming dataset + if max_samples: + data = list(dataset.take(max_samples)) + else: + # Default to 1000 samples to avoid loading entire 24k dataset + data = list(dataset.take(1000)) + + print(f"Loaded {len(data)} examples from SEED-Bench2 ({split}, streaming)") + return data + except Exception as e: + print(f"Warning: Could not load SEED-Bench2: {e}") + try: + dataset = load_dataset( + "lmms-lab/SEED-Bench", split=split, streaming=True + ) + if max_samples: + data = list(dataset.take(max_samples)) + else: + data = list(dataset.take(1000)) + print( + f"Loaded {len(data)} examples from SEED-Bench ({split}, streaming)" + ) + return data + except Exception: + raise ValueError(f"Could not load SEED-Bench2-Plus dataset: {e}") + + def encode_image(self, pil_image: Image.Image) -> str: + buffer = io.BytesIO() + pil_image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + def get_image_base64(self, item: dict) -> Optional[str]: + for key in ["image", "decoded_image"]: + if key in item and item[key] is not None: + val = item[key] + if isinstance(val, Image.Image): + return self.encode_image(val) + elif isinstance(val, list) and len(val) > 0: + # SEED-Bench-2 stores images as a list of PIL images + if isinstance(val[0], Image.Image): + return self.encode_image(val[0]) + return None + + def build_messages(self, item: dict) -> List[dict]: + image_base64 = self.get_image_base64(item) + question = item.get("question", "") + + options = {} + for letter in ascii_uppercase[:6]: + # Check for choice_a, choice_b format + choice_key = f"choice_{letter.lower()}" + if choice_key in item and item[choice_key] is not None: + val = item[choice_key] + if isinstance(val, str) and val.strip(): + options[letter] = val + elif letter in item and item[letter] is not None: + val = item[letter] + if isinstance(val, str) and val.strip(): + options[letter] = val + + if not options: + choices = item.get("choices", []) + if isinstance(choices, str): + try: + choices = eval(choices) + except Exception: + choices = [] + for i, choice in enumerate(choices): + options[ascii_uppercase[i]] = choice + + prompt = f"Question: {question}\n" + if options: + prompt += "Options:\n" + for letter in sorted(options.keys()): + prompt += f"{letter}. {options[letter]}\n" + prompt += "\nPlease select the correct answer from the options above." + + content = [] + if image_base64: + content.append( + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{image_base64}"}, + } + ) + content.append({"type": "text", "text": prompt}) + + return [{"role": "user", "content": content}] + + def extract_answer( + self, response: str, num_choices: int + ) -> Tuple[Optional[str], str]: + valid_letters = set(ascii_uppercase[:num_choices]) + + letter, method = extract_letter_from_answer_tag(response, valid_letters) + if letter: + return letter, method + + letter, method = extract_mcqa_answer_with_fallback(response, num_choices) + return letter, method + + async def run_item( + self, server: ServerManager, data_item: dict + ) -> Tuple[dict, dict]: + try: + messages = self.build_messages(data_item) + + completion = await self.chat_completion(server, messages) + + if not completion.choices: + return {"accuracy": 0.0}, {"error": "Empty response"} + + message = completion.choices[0].message + response = message.content or "" + + if not response: + return {"accuracy": 0.0}, {"error": "Empty response"} + + answer = data_item.get("answer", "") + + choices = data_item.get("choices", []) + if isinstance(choices, str): + try: + choices = eval(choices) + except Exception: + choices = [] + + num_choices = len(choices) if choices else 4 + if num_choices == 0: + num_choices = sum( + 1 + for letter in ascii_uppercase[:6] + if letter in data_item and data_item[letter] is not None + ) + num_choices = max(num_choices, 4) + + extracted, method = self.extract_answer(response, num_choices) + + correct = False + if extracted and answer: + correct = extracted.upper() == str(answer).upper() + + sample = { + "id": data_item.get("index", data_item.get("question_id", "")), + "question": data_item.get("question", "")[:200], + "category": data_item.get( + "question_type_id", data_item.get("category", "") + ), + "answer": answer, + "prediction": extracted, + "raw_response": response[:500], + "correct": correct, + "extraction_method": method, + } + + return {"accuracy": 1.0 if correct else 0.0}, sample + + except Exception as e: + return {"accuracy": 0.0}, {"error": str(e)} + + +if __name__ == "__main__": + asyncio.run( + eval_runner(SEEDBench2Plus(split="test", temperature=0.0, max_tokens=256)) + ) diff --git a/environments/eval_environments/vision_evals/visulogic_environment.py b/environments/eval_environments/vision_evals/visulogic_environment.py new file mode 100644 index 000000000..61aa41864 --- /dev/null +++ b/environments/eval_environments/vision_evals/visulogic_environment.py @@ -0,0 +1,184 @@ +import asyncio +import base64 +import io +import json +import zipfile +from pathlib import Path +from typing import List, Tuple + +from environments.eval_environments.eval import EvalBase, eval_runner +from huggingface_hub import hf_hub_download +from PIL import Image + +from atroposlib.envs.server_handling.server_manager import ServerManager + +DEFAULT_DATA_DIR = Path.home() / ".cache" / "visulogic_hf" + + +class VisuLogic(EvalBase): + TAGS = [ + "Quantitative Reasoning", + "Spatial Reasoning", + "Positional Reasoning", + "Attribute Reasoning", + "Stylistic Reasoning", + "Other", + ] + + def _download_data(self, data_dir: Path) -> None: + jsonl_path = data_dir / "data.jsonl" + images_dir = data_dir / "images" + + if jsonl_path.exists() and images_dir.exists(): + return + + print(f"Downloading VisuLogic dataset to {data_dir}...") + data_dir.mkdir(parents=True, exist_ok=True) + + # Download data.jsonl + hf_hub_download( + repo_id="VisuLogic/VisuLogic", + filename="data.jsonl", + repo_type="dataset", + local_dir=data_dir, + ) + + # Download and extract images.zip + images_zip_path = hf_hub_download( + repo_id="VisuLogic/VisuLogic", + filename="images.zip", + repo_type="dataset", + local_dir=data_dir, + ) + + print("Extracting images...") + with zipfile.ZipFile(images_zip_path, "r") as zip_ref: + zip_ref.extractall(data_dir) + + print("Download complete!") + + def setup_data(self) -> list: + """ + Load and return dataset as a list. + + Auto-downloads the VisuLogic dataset if data_path is not specified + or doesn't exist. + """ + data_path = getattr(self, "data_path", None) + + if data_path is None: + data_dir = DEFAULT_DATA_DIR + self._download_data(data_dir) + jsonl_path = data_dir / "data.jsonl" + self.images_base = str(data_dir) + else: + data_dir = Path(data_path) + jsonl_path = data_dir / "data.jsonl" + self.images_base = str(data_dir) + + if not jsonl_path.exists(): + raise FileNotFoundError( + f"Dataset not found at {jsonl_path}. " + "Remove data_path argument to auto-download." + ) + + dataset = [] + with open(jsonl_path, "r", encoding="utf-8") as f: + for line in f: + item = json.loads(line.strip()) + dataset.append(item) + + print(f"Loaded {len(dataset)} examples from VisuLogic") + return dataset + + def encode_image(self, pil_image: Image.Image) -> str: + buffer = io.BytesIO() + pil_image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + def get_image_base64(self, item: dict) -> str: + image_path = item.get("image_path", "") + full_path = Path(self.images_base) / image_path + if full_path.exists(): + with Image.open(full_path) as img: + return self.encode_image(img) + raise ValueError(f"Could not find image at {full_path}") + + def build_messages(self, item: dict) -> List[dict]: + image_base64 = self.get_image_base64(item) + question = item.get("question", "") + + prompt = f"""{question} + +Answer with only the letter (A, B, C, or D).""" + + return [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{image_base64}"}, + }, + {"type": "text", "text": prompt}, + ], + } + ] + + def extract_answer(self, response: str) -> str: + response = response.strip().upper() + + for char in reversed(response): + if char in "ABCD": + return char + + return "" + + def score(self, prediction: str, answer: str) -> bool: + if not prediction: + return False + return prediction.upper() == answer.upper() + + async def run_item( + self, server: ServerManager, data_item: dict + ) -> Tuple[dict, dict]: + try: + messages = self.build_messages(data_item) + + completion = await self.chat_completion(server, messages) + + if not completion.choices: + return {"accuracy": 0.0}, {"error": "Empty response"} + + message = completion.choices[0].message + response = message.content or "" + if hasattr(message, "reasoning") and message.reasoning and not response: + response = message.reasoning + if not response and hasattr(message, "model_extra"): + reasoning = message.model_extra.get("reasoning", "") + if reasoning: + response = reasoning + + if not response: + return {"accuracy": 0.0}, {"error": "Empty response"} + + extracted = self.extract_answer(response) + answer = data_item.get("label", "") + correct = self.score(extracted, answer) + + sample = { + "question": data_item.get("question", ""), + "answer": answer, + "prediction": extracted, + "correct": correct, + "tag": data_item.get("tag", ""), + } + + return {"accuracy": 1.0 if correct else 0.0}, sample + + except Exception as e: + return {"accuracy": 0.0}, {"error": str(e)} + + +if __name__ == "__main__": + asyncio.run(eval_runner(VisuLogic(temperature=0.0, max_tokens=256))) diff --git a/environments/eval_environments/vision_evals/vlmblind_environment.py b/environments/eval_environments/vision_evals/vlmblind_environment.py new file mode 100644 index 000000000..c5fb65809 --- /dev/null +++ b/environments/eval_environments/vision_evals/vlmblind_environment.py @@ -0,0 +1,165 @@ +"""VLMBlind (VLMs are Blind) evaluation environment.""" + +import asyncio +import base64 +import io +import re +from typing import List, Optional, Tuple + +from datasets import load_dataset +from environments.eval_environments.eval import EvalBase, eval_runner +from PIL import Image + +from atroposlib.envs.server_handling.server_manager import ServerManager + + +class VLMBlind(EvalBase): + """VLMBlind evaluation - tests basic visual perception abilities of VLMs.""" + + TASK_PATTERNS = { + "Subway Connections": r"\{([^}]+)\}", + "Nested Squares": r"\{([^}]+)\}", + "Line Plot Intersections": r"\{([^}]+)\}", + "Touching Circles": None, # Substring match + "Counting Grid": r"(\d+)\s*(?:rows?|r).*?(\d+)\s*(?:columns?|cols?|c)|(\d+)\s*[xX×]\s*(\d+)", + "Olympic Counting": None, # Substring match + "Circled Letter": r"\{([^}]+)\}", + } + + def setup_data(self) -> list: + # XAI/vlmsareblind only has 'valid' split + split = getattr(self, "split", "valid") + + try: + dataset = load_dataset("XAI/vlmsareblind", split=split) + print(f"Loaded {len(dataset)} examples from VLMBlind ({split})") + return list(dataset) + except Exception as e: + print(f"Warning: Could not load VLMBlind: {e}") + try: + # Try valid split explicitly + dataset = load_dataset("XAI/vlmsareblind", split="valid") + print(f"Loaded {len(dataset)} examples from VLMBlind (valid)") + return list(dataset) + except Exception: + raise ValueError(f"Could not load VLMBlind dataset: {e}") + + def encode_image(self, pil_image: Image.Image) -> str: + buffer = io.BytesIO() + pil_image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + def get_image_base64(self, item: dict) -> Optional[str]: + for key in ["image", "decoded_image"]: + if key in item and item[key] is not None: + if isinstance(item[key], Image.Image): + return self.encode_image(item[key]) + return None + + def build_messages(self, item: dict) -> List[dict]: + image_base64 = self.get_image_base64(item) + # XAI/vlmsareblind uses 'prompt' instead of 'question' + question = item.get("prompt", item.get("question", "")) + + prompt = f"{question}\n\nProvide your answer." + + content = [] + if image_base64: + content.append( + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{image_base64}"}, + } + ) + content.append({"type": "text", "text": prompt}) + + return [{"role": "user", "content": content}] + + def extract_and_score( + self, response: str, answer: str, task: str + ) -> Tuple[bool, str]: + """Task-specific answer extraction and scoring.""" + response_lower = response.lower().strip() + answer_lower = str(answer).lower().strip() + + if task in [ + "Subway Connections", + "Nested Squares", + "Line Plot Intersections", + "Circled Letter", + ]: + match = re.search(r"\{([^}]+)\}", response) + if match: + extracted = match.group(1).strip().lower() + return extracted == answer_lower, extracted + return answer_lower in response_lower, response_lower[:50] + + elif task == "Touching Circles": + return answer_lower in response_lower, response_lower[:50] + + elif "Counting Grid" in task or "Grid" in task: + patterns = [ + r"(\d+)\s*[xX×]\s*(\d+)", + r"(\d+)\s*(?:rows?|r).*?(\d+)\s*(?:columns?|cols?|c)", + r"(\d+)\s*(?:columns?|cols?|c).*?(\d+)\s*(?:rows?|r)", + ] + for pattern in patterns: + match = re.search(pattern, response) + if match: + groups = match.groups() + extracted = f"{groups[0]}x{groups[1]}" + ans_match = re.search(r"(\d+)\s*[xX×,]\s*(\d+)", answer) + if ans_match: + answer_parsed = f"{ans_match.group(1)}x{ans_match.group(2)}" + return extracted == answer_parsed, extracted + return answer_lower in response_lower, response_lower[:50] + + elif "Olympic" in task or "Counting" in task: + return answer_lower in response_lower, response_lower[:50] + + else: + return answer_lower in response_lower, response_lower[:50] + + async def run_item( + self, server: ServerManager, data_item: dict + ) -> Tuple[dict, dict]: + try: + messages = self.build_messages(data_item) + + completion = await self.chat_completion(server, messages) + + if not completion.choices: + return {"accuracy": 0.0}, {"error": "Empty response"} + + message = completion.choices[0].message + response = message.content or "" + + if not response: + return {"accuracy": 0.0}, {"error": "Empty response"} + + # XAI/vlmsareblind uses 'groundtruth' instead of 'answer' + answer = data_item.get("groundtruth", data_item.get("answer", "")) + task = data_item.get("task", data_item.get("category", "")) + + correct, extracted = self.extract_and_score(response, answer, task) + + sample = { + "id": data_item.get("index", data_item.get("id", "")), + "question": data_item.get("prompt", data_item.get("question", ""))[ + :200 + ], + "task": task, + "answer": answer, + "prediction": extracted, + "raw_response": response[:500], + "correct": correct, + } + + return {"accuracy": 1.0 if correct else 0.0}, sample + + except Exception as e: + return {"accuracy": 0.0}, {"error": str(e)} + + +if __name__ == "__main__": + asyncio.run(eval_runner(VLMBlind(split="test", temperature=0.0, max_tokens=512))) diff --git a/environments/eval_environments/vision_evals/wemath_environment.py b/environments/eval_environments/vision_evals/wemath_environment.py new file mode 100644 index 000000000..432d6d0fb --- /dev/null +++ b/environments/eval_environments/vision_evals/wemath_environment.py @@ -0,0 +1,373 @@ +"""We-Math evaluation environment.""" + +import asyncio +import base64 +import io +import re +import string +from collections import defaultdict +from typing import Dict, List, Tuple + +import pandas as pd +from datasets import load_dataset +from environments.eval_environments.eval import EvalBase, eval_runner +from PIL import Image + +from atroposlib.envs.server_handling.server_manager import ServerManager + + +class WeMath(EvalBase): + """ + We-Math evaluation environment. + + A benchmark for visual mathematical reasoning with multi-step problems + and 4-dimensional evaluation metrics (IK, IG, CM, RM). + """ + + def setup_data(self) -> list: + split = getattr(self, "split", "testmini") + dataset = load_dataset("We-Math/We-Math", split=split) + print(f"Loaded {len(dataset)} examples from We-Math ({split})") + return list(dataset) + + def encode_image(self, pil_image: Image.Image) -> str: + buffer = io.BytesIO() + pil_image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + def get_image_base64(self, item: dict) -> str: + img = item.get("image_path") or item.get("image") + if img is not None: + if isinstance(img, Image.Image): + return self.encode_image(img) + elif isinstance(img, bytes): + return base64.b64encode(img).decode("utf-8") + raise ValueError( + f"Could not find image for item {item.get('ID', item.get('problem_id', 'unknown'))}" + ) + + def build_messages(self, item: dict) -> List[dict]: + """Build prompt with question, options, and optional hint (MCQ format).""" + image_base64 = self.get_image_base64(item) + question = item.get("question", "") + + # Build options from A-H if present + options = {} + for letter in string.ascii_uppercase[:8]: # A-H + if ( + letter in item + and item[letter] is not None + and not pd.isna(item.get(letter, float("nan"))) + ): + options[letter] = item[letter] + + # Build prompt + prompt_parts = [] + + # Add hint if present + hint = item.get("hint", "") + if hint and not pd.isna(hint): + prompt_parts.append(f"Hint: {hint}") + + prompt_parts.append(f"Question: {question}") + + # Add options if present + if options: + options_text = "Options:\n" + for letter, value in options.items(): + options_text += f"{letter}. {value}\n" + prompt_parts.append(options_text) + + # Add COT requirement if dataset is WeMath_COT + use_cot = getattr(self, "use_cot", False) + requirement = item.get("requirement", "") + if use_cot and requirement and not pd.isna(requirement): + prompt_parts.append(requirement) + else: + prompt_parts.append( + "Answer with the option's letter from the given choices directly." + ) + + prompt = "\n".join(prompt_parts) + + return [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{image_base64}"}, + }, + {"type": "text", "text": prompt}, + ], + } + ] + + def extract_answer(self, response: str) -> str: + """ + Extract MCQ answer letter from response. + + Following VLMEvalKit logic: look for letter after "Answer" keyword, + or extract first valid letter (A-H). + """ + response = str(response).strip() + + # Try to find answer after "Answer" keyword + answer_match = re.search(r"Answer[:\s]*([A-Ha-h])", response, re.IGNORECASE) + if answer_match: + return answer_match.group(1).upper() + + # Clean response and look for first valid letter + cleaned = re.sub(r"[>><<:.\s]", "", response).strip() + if cleaned and cleaned[0].upper() in "ABCDEFGH": + return cleaned[0].upper() + + # Fallback: find any letter A-H in the response + for char in response.upper(): + if char in "ABCDEFGH": + return char + + return "" + + def score(self, prediction: str, answer: str) -> bool: + """Check if prediction matches answer (case-insensitive).""" + if not prediction: + return False + return prediction.upper() == answer.upper() + + async def run_item( + self, server: ServerManager, data_item: dict + ) -> Tuple[dict, dict]: + try: + messages = self.build_messages(data_item) + + completion = await self.chat_completion(server, messages) + + if not completion.choices: + return {"accuracy": 0.0, "hit": 0}, {"error": "Empty response"} + + message = completion.choices[0].message + response = message.content or "" + if hasattr(message, "reasoning") and message.reasoning and not response: + response = message.reasoning + if not response and hasattr(message, "model_extra"): + reasoning = message.model_extra.get("reasoning", "") + if reasoning: + response = reasoning + + if not response: + return {"accuracy": 0.0, "hit": 0}, {"error": "Empty response"} + + extracted = self.extract_answer(response) + answer = data_item.get("answer", "") + correct = self.score(extracted, answer) + + # Get problem metadata for 4-dimensional analysis + problem_id = data_item.get("ID", data_item.get("problem_id", "")) + key = data_item.get("key", "") # e.g., "2steps_1", "2steps_multi", etc. + + sample = { + "ID": problem_id, + "key": key, + "question": data_item.get("question", ""), + "answer": answer, + "prediction": extracted, + "raw_response": response[:500], # Truncate for logging + "hit": 1 if correct else 0, + "joker": correct, # VLMEvalKit naming convention + } + + return { + "accuracy": 1.0 if correct else 0.0, + "hit": 1 if correct else 0, + }, sample + + except Exception as e: + return {"accuracy": 0.0, "hit": 0}, {"error": str(e)} + + +def compute_4d_metrics(samples: List[dict]) -> Dict: + """ + Compute We-Math 4-dimensional metrics from evaluation samples. + + This implements the evaluation logic from VLMEvalKit's wemath.py. + + Returns metrics for: + - IK (Insufficient Knowledge): Steps wrong AND multi wrong + - IG (Inadequate Generalization): Steps right BUT multi wrong + - CM (Complete Mastery): Steps right AND multi right + - RM (Rote Memorization): Steps wrong BUT multi right + """ + # Convert samples to DataFrame + df = pd.DataFrame(samples) + + if "key" not in df.columns or df["key"].isna().all(): + # Dataset doesn't have step structure, return basic accuracy + return { + "overall_accuracy": df["hit"].mean() if "hit" in df.columns else 0.0, + "note": "Dataset lacks step structure for 4D metrics", + } + + # Separate by step type + data_2steps = df[df["key"].str.contains("2steps", na=False)] + data_3steps = df[df["key"].str.contains("3steps", na=False)] + + metrics = {} + + # Process 2-step problems + if len(data_2steps) > 0: + merged_2steps = _process_steps_data(data_2steps, 2) + if merged_2steps is not None: + metrics["2step"] = _calculate_step_metrics(merged_2steps, 2) + + # Process 3-step problems + if len(data_3steps) > 0: + merged_3steps = _process_steps_data(data_3steps, 3) + if merged_3steps is not None: + metrics["3step"] = _calculate_step_metrics(merged_3steps, 3) + + # Compute overall 4D metrics + if "2step" in metrics or "3step" in metrics: + total_counts = _compute_total_counts(metrics) + total_count = 525 # Standard We-Math total + + # Compute rates and final scores + final_metrics = _compute_final_scores(total_counts, total_count) + metrics["overall"] = final_metrics + + # Basic accuracy + metrics["step_accuracy"] = df["hit"].mean() if "hit" in df.columns else 0.0 + + return metrics + + +def _process_steps_data(df: pd.DataFrame, steps: int) -> pd.DataFrame: + """Process step data and merge by problem ID.""" + try: + steps_data = {} + for i in range(1, steps + 1): + key = f"{steps}steps_{i}" + step_df = df[df["key"] == key].copy() + if len(step_df) == 0: + return None + step_df.columns = [f"{col}_{i}" for col in step_df.columns] + steps_data[i] = step_df + + # Get multi-step data + multi_key = f"{steps}steps_multi" + multi_df = df[df["key"] == multi_key].copy() + if len(multi_df) == 0: + return None + multi_df.columns = [f"{col}_multi" for col in multi_df.columns] + + # Merge all steps + merged = steps_data[1] + for i in range(2, steps + 1): + merged = pd.merge( + merged, + steps_data[i], + left_on="ID_1", + right_on=f"ID_{i}", + how="left", + ) + merged = pd.merge( + merged, multi_df, left_on="ID_1", right_on="ID_multi", how="left" + ) + + return merged + except Exception: + return None + + +def _calculate_step_metrics(merged: pd.DataFrame, steps: int) -> Dict: + """Calculate metrics for a step type (2-step or 3-step).""" + try: + # Get joker columns + joker_cols = [f"joker_{i}" for i in range(1, steps + 1)] + joker_multi = "joker_multi" + + # Check if columns exist + for col in joker_cols + [joker_multi]: + if col not in merged.columns: + return {} + + # Calculate conditions + all_steps_correct = merged[joker_cols].all(axis=1) + any_step_correct = merged[joker_cols].any(axis=1) + all_steps_wrong = ~merged[joker_cols].any(axis=1) + any_step_wrong = ~merged[joker_cols].all(axis=1) + multi_correct = merged[joker_multi] == True # noqa: E712 + + return { + # Strict: ALL steps must be correct + "CM_strict": int((all_steps_correct & multi_correct).sum()), + "IG": int((all_steps_correct & ~multi_correct).sum()), + "RM_strict": int((any_step_wrong & multi_correct).sum()), + "IK": int((any_step_wrong & ~multi_correct).sum()), + # Loose: ANY step correct + "CM_loose": int((any_step_correct & multi_correct).sum()), + "RM_loose": int((all_steps_wrong & multi_correct).sum()), + "total": len(merged), + } + except Exception: + return {} + + +def _compute_total_counts(metrics: Dict) -> Dict: + """Aggregate counts across step types.""" + totals = defaultdict(int) + + for step_type in ["2step", "3step"]: + if step_type in metrics: + for key in ["CM_strict", "CM_loose", "IG", "RM_strict", "RM_loose", "IK"]: + if key in metrics[step_type]: + totals[key] += metrics[step_type][key] + + return dict(totals) + + +def _compute_final_scores(total_counts: Dict, total_count: int = 525) -> Dict: + """Compute final 4D scores and rates.""" + results = {} + + # Calculate rates + for key in ["IK", "IG", "CM_strict", "CM_loose", "RM_strict", "RM_loose"]: + count = total_counts.get(key, 0) + results[f"{key}_count"] = count + results[f"{key}_rate"] = count / total_count if total_count > 0 else 0.0 + + # Calculate RM rates (relative to CM + RM) + cm_rm_strict = total_counts.get("CM_strict", 0) + total_counts.get("RM_strict", 0) + cm_rm_loose = total_counts.get("CM_loose", 0) + total_counts.get("RM_loose", 0) + + results["RM_strict_relative"] = ( + total_counts.get("RM_strict", 0) / cm_rm_strict if cm_rm_strict > 0 else 0.0 + ) + results["RM_loose_relative"] = ( + total_counts.get("RM_loose", 0) / cm_rm_loose if cm_rm_loose > 0 else 0.0 + ) + + # Final scores (VLMEvalKit formula) + results["score_strict"] = ( + total_count + - 0.5 * total_counts.get("IG", 0) + - total_counts.get("RM_strict", 0) + - total_counts.get("IK", 0) + ) / total_count + + results["score_loose"] = ( + total_count + - 0.5 * total_counts.get("IG", 0) + - total_counts.get("RM_loose", 0) + - total_counts.get("IK", 0) + ) / total_count + + return results + + +if __name__ == "__main__": + asyncio.run( + eval_runner( + WeMath(split="testmini", use_cot=False, temperature=0.0, max_tokens=512) + ) + )