diff --git a/environments/eval_environments/super_gpqa_eval.py b/environments/eval_environments/super_gpqa_eval.py
new file mode 100644
index 000000000..1ccbb4af0
--- /dev/null
+++ b/environments/eval_environments/super_gpqa_eval.py
@@ -0,0 +1,712 @@
+"""
+SuperGPQA Evaluation Environment for Atropos
+
+This environment evaluates models on the SuperGPQA benchmark.
+
+Dataset: m-a-p/SuperGPQA
+Paper: https://www.arxiv.org/pdf/2502.14739
+
+SuperGPQA is a comprehensive benchmark designed to evaluate the
+knowledge and reasoning abilities of Large Language Models (LLMs)
+across 285 graduate-level disciplines.
+It features at least 50 questions per discipline,
+covering a broad spectrum of graduate-level topics.
+"""
+
+import asyncio
+import os
+import random
+import re
+import time
+from string import ascii_uppercase
+from typing import Dict, List, Optional, Tuple
+
+from datasets import load_dataset
+from eval_helpers import (
+ create_system_content,
+ extract_letter_from_answer_tag,
+ get_default_thinking_prompt,
+)
+from pydantic import Field
+from tqdm.asyncio import tqdm_asyncio
+
+from atroposlib.envs.base import (
+ APIServerConfig,
+ BaseEnv,
+ BaseEnvConfig,
+ EvalHandlingEnum,
+)
+
+SUPER_GPQA_ZERO_SHOT_PROMPT = """Answer the following multiple-choice question. There is only one correct answer.
+Provide your final answer within <answer></answer> tags, containing only the letter (A, B, C, D, E, F, G, H, I, or J).
+
+Example format:
+<answer>A</answer>
+
+Question: {Question}"""
+
+
+class SuperGPQAEvalConfig(BaseEnvConfig):
+ """Configuration for SuperGPQA eval environment"""
+
+ custom_system_prompt: Optional[str] = Field(
+ default=None,
+ description="Custom system prompt to append after thinking prompt (if thinking_mode) or use directly.",
+ )
+
+ custom_thinking_prompt: Optional[str] = Field(
+ default=None,
+ description="Custom thinking prompt. If None, uses the default thinking prompt.",
+ )
+
+ thinking_mode: bool = Field(
+ default=False,
+ description="Whether to enable thinking mode with tags.",
+ )
+
+ dataset_name: str = Field(
+ default="m-a-p/SuperGPQA",
+ description="HuggingFace dataset for SuperGPQA.",
+ )
+
+ eval_split: str = Field(
+ default="train",
+ description="Dataset split to use for evaluation (SuperGPQA has a single split 'train' used for evaluation).",
+ )
+
+ eval_max_tokens: int = Field(
+ default=32000,
+ description="Maximum tokens for reasoning models.",
+ )
+
+ min_response_length: int = Field(
+ default=1,
+ ge=1,
+ description="Min length for a valid response.",
+ )
+
+ # all main results in supergpqa use temperature=0
+ eval_temperature: float = Field(
+ default=0,
+ description="Temperature for evaluation (0.0 for deterministic).",
+ )
+
+ max_retries: int = Field(
+ default=3,
+ ge=1,
+ description="Maximum retries for failed API calls.",
+ )
+
+ retry_delay: float = Field(
+ default=1.0,
+ ge=0.0,
+ description="Delay between retry attempts in seconds.",
+ )
+
+ full_debug: bool = Field(
+ default=False,
+ description="Enable verbose debug logging.",
+ )
+
+ shuffle_seed: Optional[int] = Field(
+ default=42,
+ description="Seed for shuffling answer positions. Set to None for random shuffling each run.",
+ )
+
+
+class SuperGPQAEvalEnv(BaseEnv):
+ """
+ SuperGPQA eval environment for Atropos, supporting reasoning/instruct models via a 0-shot prompt.
+
+ Pipeline:
+ - Load SuperGPQA from HuggingFace
+ - Use 0-shot prompts from the SuperGPQA paper
+ - Extract answer choice via regex
+ """
+
+ name = "supergpqa_eval"
+ env_config_cls = SuperGPQAEvalConfig
+
+ def __init__(
+ self,
+ config: SuperGPQAEvalConfig,
+ server_configs: List[APIServerConfig],
+ slurm=True,
+ testing=False,
+ ):
+ super().__init__(config, server_configs, slurm, testing)
+ self.config: SuperGPQAEvalConfig = config
+
+ self.eval_metrics = []
+ if self.config.shuffle_seed is not None:
+ self.shuffle_rng = random.Random(self.config.shuffle_seed)
+ else:
+ self.shuffle_rng = random.Random()
+
+        self._think_pattern = re.compile(r"<think>")
+        self._think_close_pattern = re.compile(r"</think>")
+        # Content that follows the closing </think> tag
+        self._think_content_pattern = re.compile(r"</think>\s*(.*)", re.DOTALL)
+        # Content inside the <think>...</think> block
+        self._thinking_extract_pattern = re.compile(r"<think>(.*?)</think>", re.DOTALL)
+        self._answer_pattern = re.compile(
+            r"<answer>(.*?)</answer>", re.DOTALL | re.IGNORECASE
+        )
+
+ # Build fallback answer extraction patterns
+ self._build_extraction_patterns()
+
+ def _get_thinking_prompt(self) -> str:
+ """Get thinking system prompt."""
+ return get_default_thinking_prompt(self.config.custom_thinking_prompt)
+
+ def _create_system_content(self) -> Optional[str]:
+ """Create system message content based on thinking mode."""
+ return create_system_content(
+ self.config.thinking_mode,
+ self.config.custom_thinking_prompt,
+ self.config.custom_system_prompt,
+ )
+
+ def _build_extraction_patterns(self):
+ """Build regex patterns for extracting answer letters from model responses."""
+ letters = "ABCDEFGHIJ"
+ letter_pattern = rf"([{letters}]|\([{letters}]\))"
+
+ # Patterns ordered by priority (most specific first)
+ self._pattern_final_answer_hope = re.compile(
+ rf"(?i:final\s+answer\s+is)\s*:?\s*{letter_pattern}\.?\s*I\s*hope",
+ re.IGNORECASE,
+ )
+ self._pattern_final_answer_is = re.compile(
+ rf"(?i:final\s+answer).{{0,100}}?\s+is\s*:?\s*{letter_pattern}",
+ re.IGNORECASE | re.DOTALL,
+ )
+ self._pattern_the_answer_is = re.compile(
+ rf"(?i:the\s+answer\s+is)\s*:?\s*{letter_pattern}", re.IGNORECASE
+ )
+ self._pattern_answer_colon = re.compile(
+ rf"(?i:answer)\s*:\s*.{{0,50}}?{letter_pattern}", re.IGNORECASE | re.DOTALL
+ )
+ self._pattern_answer_space = re.compile(
+ rf"(?i:answer)\s+{letter_pattern}", re.IGNORECASE
+ )
+ self._pattern_start = re.compile(
+ rf"^\s*\**{letter_pattern}\**[\s\.\)\:]", re.IGNORECASE
+ )
+ self._pattern_line_start = re.compile(
+ rf"\n\s*\**{letter_pattern}\**[\s\.\)\:]", re.IGNORECASE
+ )
+ self._pattern_standalone = re.compile(rf"\b{letter_pattern}\b", re.IGNORECASE)
+
+ self._extraction_patterns = [
+ (0, self._pattern_final_answer_hope, "final_answer_hope"),
+ (50, self._pattern_final_answer_is, "final_answer_is"),
+ (75, self._pattern_the_answer_is, "the_answer_is"),
+ (100, self._pattern_answer_colon, "answer_colon"),
+ (150, self._pattern_answer_space, "answer_space"),
+ (200, self._pattern_start, "start"),
+ (210, self._pattern_line_start, "line_start"),
+ (250, self._pattern_standalone, "standalone"),
+ ]
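+
+        # Illustrative (hypothetical) responses and the first fallback pattern each
+        # one would hit, given the priority order above:
+        #   "My final answer is B. I hope that helps."  -> 'B' via final_answer_hope
+        #   "So the answer is (C)."                     -> 'C' via the_answer_is
+        #   "Answer: D"                                 -> 'D' via answer_colon
+        #   "A) because the reaction is exothermic"     -> 'A' via start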
+
+ @classmethod
+ def config_init(cls) -> Tuple[SuperGPQAEvalConfig, List[APIServerConfig]]:
+ """Initialize default configuration for the environment."""
+ env_config = SuperGPQAEvalConfig(
+ tokenizer_name="NousResearch/Hermes-3-Llama-3.1-8B",
+ group_size=1,
+ use_wandb=True,
+ max_num_workers_per_node=128,
+ rollout_server_url="http://localhost:8000",
+ total_steps=1,
+ batch_size=1,
+ steps_per_eval=1,
+ inference_weight=1.0,
+ wandb_name="super_gpqa_eval",
+ eval_handling=EvalHandlingEnum.STOP_TRAIN,
+ max_eval_workers=256,
+ max_num_workers=1024,
+ dataset_name="m-a-p/SuperGPQA",
+ eval_temperature=0,
+ eval_max_tokens=0, # this uses the default model max
+ thinking_mode=False,
+ )
+
+ server_configs = [
+ APIServerConfig(
+ model_name="Hermes-3-Llama-3.1-8B",
+ base_url="http://localhost:9000/v1",
+ api_key=os.getenv("OPENAI_API_KEY", "none"),
+ num_max_requests_at_once=32,
+ num_requests_for_eval=1024,
+ ),
+ ]
+
+ return env_config, server_configs
+
+ async def setup(self) -> None:
+ """Load SuperGPQA dataset and process all points to create prompts for rollout generation."""
+ print("\nSuperGPQA Evaluation Setup:")
+ print("=" * 20 + "DATASET DETAILS" + "=" * 20)
+ print(f" Dataset: {self.config.dataset_name}")
+ print(f" Evaluation split: {self.config.eval_split}")
+ print("=" * 20 + "GENERATION DETAILS" + "=" * 20)
+ print(f" Thinking mode: {self.config.thinking_mode}")
+ print(f" Max tokens: {self.config.eval_max_tokens}")
+
+ try:
+ dataset = load_dataset(
+ self.config.dataset_name,
+ split=self.config.eval_split,
+ )
+ self.eval_data = list(dataset)
+ print(f" Loaded {len(self.eval_data)} evaluation items")
+ except Exception as e:
+ print(f"Error loading dataset from HuggingFace: {e}")
+ raise
+
+ self.all_eval_items = []
+ for item in self.eval_data:
+ processed = self._process_super_gpqa_item(item)
+ self.all_eval_items.append(processed)
+
+ self.iter = 0
+
+ def _process_super_gpqa_item(self, item: Dict) -> Dict:
+        """
+        Process a SuperGPQA item.
+
+        Shuffle answer positions to avoid reward hacking where the model
+        might learn positions rather than the correct option.
+
+        TODO: check whether shuffling should happen per rollout or once per eval.
+        """
+ correct_answer = item["answer"]
+ all_answers = item["options"]
+ choices = all_answers.copy()
+ self.shuffle_rng.shuffle(choices)
+ gold_index = choices.index(correct_answer)
+
+ return {
+ "question": item["question"],
+ "choices": choices,
+ "gold_index": gold_index,
+ "gold_letter": ascii_uppercase[gold_index],
+ "subfield": item.get("subfield", "unknown"),
+ "original_item": item,
+ }
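+
+    # Illustrative (hypothetical) example of the shuffling above: for an item with
+    # options ["Mars", "Venus", "Earth", "Jupiter"] and answer "Earth", a shuffle that
+    # yields choices ["Venus", "Earth", "Jupiter", "Mars"] gives gold_index=1 and
+    # gold_letter="B", i.e. the gold letter tracks the shuffled position of the text.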
+
+ def _format_super_gpqa_prompt(self, question: str, choices: List[str]) -> str:
+ """
+ Format a SuperGPQA question, add answer choices to the 0-shot prompt.
+ """
+ return (
+ SUPER_GPQA_ZERO_SHOT_PROMPT.format(
+ Question=question.strip(),
+ )
+ + "\n\n"
+ + "\n".join(
+ [
+ f"{ascii_uppercase[i]}) {choice.strip()}"
+ for i, choice in enumerate(choices)
+ ]
+ )
+ )
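+
+    # A minimal sketch (hypothetical item) of the rendered prompt: the zero-shot
+    # template, a blank line, then the lettered choices, e.g.
+    #
+    #   Question: Which planet is known as the Red Planet?
+    #
+    #   A) Venus
+    #   B) Mars
+    #   C) Jupiter
+    #   D) Mercury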
+
+ def _validate_thinking_format(self, response: str) -> Tuple[bool, str]:
+ """Validate thinking format and extract content after tags."""
+ if not self.config.thinking_mode:
+ return True, response
+
+ think_open_count = len(self._think_pattern.findall(response))
+ think_close_count = len(self._think_close_pattern.findall(response))
+
+ if think_open_count != 1 or think_close_count != 1:
+ return False, response
+
+ match = self._think_content_pattern.search(response)
+ if match:
+ return True, match.group(1).strip()
+ else:
+ return False, response
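+
+    # Illustrative (hypothetical) behaviour when thinking_mode is enabled:
+    #   "<think>reasoning...</think>\nThe answer is B" -> (True, "The answer is B")
+    #   "a response with no think tags"                -> (False, response unchanged)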
+
+ def _extract_thinking_content(self, response: str) -> Optional[str]:
+ """Extract the content inside tags."""
+ match = self._thinking_extract_pattern.search(response)
+ if match:
+ return match.group(1).strip()
+ return None
+
+ def _extract_answer(
+ self, response: str, num_choices: int, choices: Optional[List[str]] = None
+ ) -> Tuple[Optional[str], str]:
+ """
+ Extract the answer letter from the model's response.
+
+        Primary method: Look for <answer></answer> tags, or match against choice texts.
+ Fallback: Use priority-ordered regex patterns.
+ """
+ if not response:
+ return None, "empty_response"
+
+ valid_letters = set(ascii_uppercase[:num_choices])
+
+        # PRIMARY: Try <answer> tags first
+ # Also matches against choice texts if provided
+ letter, method = extract_letter_from_answer_tag(
+ response, valid_letters, debug=self.config.full_debug, choices=choices
+ )
+ if letter:
+ return letter, method
+
+ # FALLBACK: Try each pattern in priority order
+ for _, pattern, method_name in self._extraction_patterns:
+ matches = pattern.findall(response)
+ if matches:
+ match = (
+ matches[-1]
+ if method_name
+ in [
+ "final_answer_is",
+ "the_answer_is",
+ "answer_colon",
+ "answer_space",
+ ]
+ else matches[0]
+ )
+ if isinstance(match, tuple):
+ match = match[0]
+ letter = match.strip("()").upper()
+
+ if letter in valid_letters:
+ if self.config.full_debug:
+ print(
+ f" Extracted '{letter}' using fallback method '{method_name}'"
+ )
+ return letter, f"fallback_{method_name}"
+
+        # Last resort: scan valid letters from highest to lowest (deterministic,
+        # since sets have no guaranteed iteration order)
+        for letter in sorted(valid_letters, reverse=True):
+ if letter in response.upper():
+ if self.config.full_debug:
+ print(
+ f" Extracted '{letter}' using fallback 'last_valid_letter'"
+ )
+ return letter, "fallback_last_valid_letter"
+
+ return None, "no_match"
+
+ async def get_next_item(self):
+ """Get next item for training (not used in eval-only environment)."""
+ self.iter += 1
+ if self.all_eval_items:
+ item = self.all_eval_items[self.iter % len(self.all_eval_items)]
+ return item
+ return None
+
+ async def collect_trajectories(self, item):
+ """Collect trajectories (not used in eval-only environment)."""
+ return None, []
+
+ async def score(self, rollout_group_data):
+ """Score rollouts (not used in eval-only environment)."""
+ return None
+
+ async def rollout_and_score_eval(self, eval_item: Dict) -> Dict:
+ """Evaluate a single SuperGPQA question."""
+ try:
+ question = eval_item.get("question", "")
+ choices = eval_item.get("choices", [])
+ gold_letter = eval_item.get("gold_letter", "A")
+ subfield = eval_item.get("subfield", "unknown")
+
+ if not question or len(choices) < 2:
+ return {"is_correct": None, "sample": None}
+
+ formatted_prompt = self._format_super_gpqa_prompt(question, choices)
+
+ messages = []
+ system_content = self._create_system_content()
+ if system_content:
+ messages.append({"role": "system", "content": system_content})
+ messages.append({"role": "user", "content": formatted_prompt})
+
+ model_response = None
+ finish_reason = None
+
+ completion_kwargs = {
+ "messages": messages,
+ "n": 1,
+ "temperature": self.config.eval_temperature,
+ "split": "eval",
+ }
+ if (
+ self.config.eval_max_tokens > 0
+ ): # 0 means "use model default", so we don't pass the parameter
+ completion_kwargs["max_tokens"] = self.config.eval_max_tokens
+
+ for attempt in range(self.config.max_retries):
+ try:
+ completion = await self.server.chat_completion(**completion_kwargs)
+
+ if completion.choices and completion.choices[0].message.content:
+ model_response = completion.choices[0].message.content
+ finish_reason = getattr(
+ completion.choices[0], "finish_reason", None
+ )
+
+ if (
+ len(model_response.strip())
+ >= self.config.min_response_length
+ ):
+ break
+ elif attempt < self.config.max_retries - 1:
+ if self.config.full_debug:
+ print(" Response too short, retrying...")
+ await asyncio.sleep(self.config.retry_delay)
+
+ except Exception as e:
+ # Always log API errors to help diagnose issues
+ print(
+ f" API Error (attempt {attempt + 1}/{self.config.max_retries}): {type(e).__name__}: {e}"
+ )
+ if hasattr(e, "response"):
+ try:
+ print(
+ f" Response: {e.response.text[:500] if hasattr(e.response, 'text') else e.response}"
+ )
+ except Exception:
+ pass
+ if attempt < self.config.max_retries - 1:
+ await asyncio.sleep(self.config.retry_delay)
+ else:
+ print(f" Failed after {self.config.max_retries} attempts")
+ return {"is_correct": None, "sample": None}
+
+ if not model_response:
+ return {"is_correct": None, "sample": None}
+
+ # Validate thinking format if enabled
+ format_valid, content_for_extraction = self._validate_thinking_format(
+ model_response
+ )
+
+ # Extract thinking content for logging
+ thinking_content = None
+ if self.config.thinking_mode:
+ thinking_content = self._extract_thinking_content(model_response)
+
+ extracted_answer, extraction_method = self._extract_answer(
+ content_for_extraction, num_choices=len(choices), choices=choices
+ )
+ is_correct = extracted_answer == gold_letter if extracted_answer else False
+
+ # Build sample record
+ sample = {
+ "question": question,
+ "choices": choices,
+ "gold_answer": gold_letter,
+ "model_response": model_response,
+ "extracted_answer": extracted_answer,
+ "extraction_method": extraction_method,
+ "is_correct": is_correct,
+ "subfield": subfield,
+ "finish_reason": finish_reason,
+ "response_length": len(model_response),
+ "thinking_mode": self.config.thinking_mode,
+ "format_valid": format_valid,
+ }
+
+ if self.config.thinking_mode:
+ sample["thinking_content"] = thinking_content
+ sample["response_after_think"] = (
+ content_for_extraction if format_valid else None
+ )
+
+ if self.config.full_debug:
+ status = "✓" if is_correct else "✗"
+ print(
+ f" [{status}] {subfield}: gold={gold_letter}, extracted={extracted_answer}"
+ )
+
+ return {"is_correct": is_correct, "sample": sample}
+
+ except Exception as e:
+ if self.config.full_debug:
+ print(f"Error in rollout_and_score_eval: {e}")
+ import traceback
+
+ traceback.print_exc()
+ return {"is_correct": None, "sample": None}
+
+ async def evaluate(self, *args, **kwargs) -> None:
+ """Run SuperGPQA evaluation."""
+ start_time = time.time()
+
+ print(f"\n{'='*60}")
+ print("Starting SuperGPQA Evaluation:")
+ print(f"{'='*60}")
+ print(f" Total questions: {len(self.all_eval_items)}")
+ print(f" Max tokens (for reasoning): {self.config.eval_max_tokens}")
+ print(f" Thinking mode: {self.config.thinking_mode}")
+ print(f"{'='*60}\n")
+
+ try:
+ eval_tasks = [
+ self.rollout_and_score_eval(item) for item in self.all_eval_items
+ ]
+ results = await tqdm_asyncio.gather(
+ *eval_tasks, desc="Evaluating SuperGPQA"
+ )
+
+ valid_results = [
+ r
+ for r in results
+ if r and r.get("sample") is not None and r.get("is_correct") is not None
+ ]
+
+ if not valid_results:
+ print("Warning: No valid evaluation results obtained")
+ return
+
+ except Exception as e:
+ print(f"Error during evaluation: {e}")
+ import traceback
+
+ traceback.print_exc()
+ return
+
+ end_time = time.time()
+
+ # Compute metrics
+ samples = [r["sample"] for r in valid_results]
+
+ # Overall accuracy
+ total_correct = sum(1 for r in valid_results if r["is_correct"])
+ total_count = len(valid_results)
+ overall_accuracy = total_correct / total_count if total_count > 0 else 0.0
+
+ # Per-subfield accuracy
+ subfield_results = {}
+ for sample in samples:
+ subfield = sample.get("subfield", "unknown")
+ if subfield not in subfield_results:
+ subfield_results[subfield] = {"correct": 0, "total": 0}
+ subfield_results[subfield]["total"] += 1
+ if sample["is_correct"]:
+ subfield_results[subfield]["correct"] += 1
+
+ # Extraction method statistics
+ extraction_methods = {}
+ for sample in samples:
+ method = sample.get("extraction_method", "unknown")
+ if method not in extraction_methods:
+ extraction_methods[method] = {"count": 0, "correct": 0}
+ extraction_methods[method]["count"] += 1
+ if sample["is_correct"]:
+ extraction_methods[method]["correct"] += 1
+
+ # Average response length
+ response_lengths = [s.get("response_length", 0) for s in samples]
+ avg_response_length = (
+ sum(response_lengths) / len(response_lengths) if response_lengths else 0
+ )
+
+ # Format compliance
+ format_compliant = sum(1 for s in samples if s.get("format_valid", True))
+ format_compliance_rate = format_compliant / len(samples) if samples else 0.0
+
+ # Thinking utilization
+ thinking_utilization = 0
+ if self.config.thinking_mode:
+ thinking_utilization = sum(1 for s in samples if s.get("thinking_content"))
+
+ # Build metrics dictionary
+ eval_metrics = {
+ "eval/overall_accuracy": overall_accuracy,
+ "eval/total_questions": total_count,
+ "eval/total_correct": total_correct,
+ "eval/evaluation_time_seconds": end_time - start_time,
+ "eval/avg_response_length": avg_response_length,
+ "eval/format_compliance_rate": format_compliance_rate,
+ "eval/thinking_mode_enabled": 1.0 if self.config.thinking_mode else 0.0,
+ }
+
+ if self.config.thinking_mode:
+ thinking_utilization_rate = (
+ thinking_utilization / len(samples) if samples else 0.0
+ )
+ eval_metrics["eval/thinking_utilization_rate"] = thinking_utilization_rate
+
+ # Add subfield metrics
+ for subfield, stats in subfield_results.items():
+ if stats["total"] > 0:
+ subfield_accuracy = stats["correct"] / stats["total"]
+ subfield_key = subfield.replace(" ", "_").replace("-", "_").lower()
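+                # e.g. a (hypothetical) subfield "Quantum Mechanics" becomes the
+                # metric key "eval/subfield_quantum_mechanics_accuracy"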
+ eval_metrics[f"eval/subfield_{subfield_key}_accuracy"] = (
+ subfield_accuracy
+ )
+
+ # Store metrics for wandb logging
+ self.eval_metrics = [(k, v) for k, v in eval_metrics.items()]
+
+ # Print summary
+ print(f"\n{'='*60}")
+ print("SuperGPQA Evaluation Results")
+ print(f"{'='*60}")
+ print(
+ f"Overall Accuracy: {overall_accuracy:.4f} ({total_correct}/{total_count})"
+ )
+ print(f"Evaluation Time: {end_time - start_time:.1f} seconds")
+ print(f"Avg Response Length: {avg_response_length:.0f} chars")
+ if self.config.thinking_mode:
+ print(f"Format Compliance: {format_compliance_rate:.4f}")
+ print(f"Thinking Utilization: {thinking_utilization}/{total_count}")
+
+ print("\nSubfield Breakdown:")
+ for subfield, stats in sorted(subfield_results.items()):
+ if stats["total"] > 0:
+ subfield_acc = stats["correct"] / stats["total"]
+ print(
+ f" {subfield}: {subfield_acc:.4f} ({stats['correct']}/{stats['total']})"
+ )
+
+ print("\nExtraction Method Statistics:")
+ for method, stats in sorted(
+ extraction_methods.items(), key=lambda x: -x[1]["count"]
+ ):
+ if stats["count"] > 0:
+ method_acc = stats["correct"] / stats["count"]
+ print(f" {method}: {stats['count']} uses, {method_acc:.4f} accuracy")
+
+ print(f"{'='*60}\n")
+
+ try:
+ await self.evaluate_log(
+ metrics=eval_metrics,
+ start_time=start_time,
+ end_time=end_time,
+ samples=samples,
+ generation_parameters={
+ "temperature": self.config.eval_temperature,
+ "max_tokens": self.config.eval_max_tokens,
+ "thinking_mode": self.config.thinking_mode,
+ },
+ )
+ except Exception as e:
+ print(f"Error logging evaluation results: {e}")
+
+ async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
+ """Log metrics to wandb."""
+ if wandb_metrics is None:
+ wandb_metrics = {}
+
+ for metric_name, metric_value in self.eval_metrics:
+ wandb_metrics[metric_name] = metric_value
+ self.eval_metrics = []
+
+ wandb_metrics["config/thinking_mode"] = (
+ 1.0 if self.config.thinking_mode else 0.0
+ )
+ wandb_metrics["config/eval_max_tokens"] = self.config.eval_max_tokens
+ await super().wandb_log(wandb_metrics)
+
+
+if __name__ == "__main__":
+ SuperGPQAEvalEnv.cli()