diff --git a/environments/anagram_environment.py b/environments/anagram_environment.py
new file mode 100644
index 000000000..48208b86b
--- /dev/null
+++ b/environments/anagram_environment.py
@@ -0,0 +1,955 @@
+"""
+Anagram Word Puzzle Environment
+
+This environment trains models to unscramble anagrams - words with their letters
+randomly rearranged. The model must identify the original English word from the
+scrambled version.
+
+Example:
+- Scrambled: "elppa" -> Answer: "apple"
+- Scrambled: "nragle" -> Answer: "learng" or "glaner" (must be valid word)
+
+This tests pattern recognition, vocabulary knowledge, and reasoning skills.
+"""
+
+import asyncio
+import random
+import re
+import time
+from typing import Dict, List, Optional, Tuple, Union
+
+import wandb
+from pydantic import Field
+from tqdm.asyncio import tqdm_asyncio
+
+from atroposlib.envs.base import (
+ APIServerConfig,
+ BaseEnv,
+ BaseEnvConfig,
+ Item,
+ ScoredDataGroup,
+)
+from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
+
+# Built-in word list for training (common English words of various lengths)
+DEFAULT_WORD_LIST = [
+ # 3-4 letter words
+ "cat",
+ "dog",
+ "run",
+ "sun",
+ "hat",
+ "bat",
+ "map",
+ "cup",
+ "pen",
+ "bed",
+ "red",
+ "big",
+ "hot",
+ "top",
+ "box",
+ "fox",
+ "mix",
+ "six",
+ "day",
+ "way",
+ "say",
+ "may",
+ "pay",
+ "lay",
+ "bay",
+ "ray",
+ "hay",
+ "key",
+ "boy",
+ "toy",
+ "joy",
+ "cow",
+ "how",
+ "now",
+ "row",
+ "low",
+ "new",
+ "few",
+ "dew",
+ "sew",
+ "air",
+ "ear",
+ "far",
+ "car",
+ "bar",
+ "jar",
+ "tar",
+ "war",
+ "arm",
+ "art",
+ "ant",
+ "act",
+ "add",
+ "age",
+ "ago",
+ "aid",
+ "aim",
+ "all",
+ "and",
+ "any",
+ # 5 letter words
+ "apple",
+ "beach",
+ "chair",
+ "dance",
+ "eagle",
+ "flame",
+ "grape",
+ "house",
+ "inner",
+ "joker",
+ "knife",
+ "lemon",
+ "mango",
+ "nurse",
+ "ocean",
+ "piano",
+ "queen",
+ "river",
+ "snake",
+ "table",
+ "uncle",
+ "video",
+ "water",
+ "xenon",
+ "yacht",
+ "zebra",
+ "brain",
+ "climb",
+ "dream",
+ "earth",
+ "fresh",
+ "giant",
+ "happy",
+ "image",
+ "judge",
+ "knock",
+ "laugh",
+ "magic",
+ "night",
+ "olive",
+ "peace",
+ "quick",
+ "robot",
+ "stone",
+ "tiger",
+ "unity",
+ "voice",
+ "world",
+ # 6 letter words
+ "basket",
+ "candle",
+ "desert",
+ "engine",
+ "flower",
+ "garden",
+ "helmet",
+ "insect",
+ "jacket",
+ "kettle",
+ "laptop",
+ "marble",
+ "needle",
+ "orange",
+ "pepper",
+ "quartz",
+ "rabbit",
+ "silver",
+ "temple",
+ "unique",
+ "velvet",
+ "wallet",
+ "yellow",
+ "zipper",
+ "anchor",
+ "bridge",
+ "castle",
+ "donkey",
+ "escape",
+ "frozen",
+ "ginger",
+ "honest",
+ "island",
+ "jingle",
+ "kitten",
+ # 7 letter words
+ "amazing",
+ "balance",
+ "captain",
+ "diamond",
+ "elegant",
+ "fantasy",
+ "general",
+ "healthy",
+ "imagine",
+ "jealous",
+ "kingdom",
+ "library",
+ "machine",
+ "natural",
+ "obvious",
+ "perfect",
+ "quality",
+ "rainbow",
+ "science",
+ "teacher",
+ "unusual",
+ "village",
+ "weather",
+ "example",
+ "younger",
+ "zealous",
+ "ancient",
+ "brother",
+ "chicken",
+ "dolphin",
+ "emperor",
+ "fiction",
+ "glacier",
+ "harvest",
+ "iceberg",
+ # 8+ letter words
+ "absolute",
+ "baseball",
+ "children",
+ "daughter",
+ "elephant",
+ "familiar",
+ "generous",
+ "handsome",
+ "innocent",
+ "kangaroo",
+ "language",
+ "minister",
+ "notebook",
+ "opposite",
+ "pleasure",
+ "question",
+ "romantic",
+ "shoulder",
+ "thousand",
+ "universe",
+ "valuable",
+ "wonderful",
+ "airplane",
+ "birthday",
+ "computer",
+ "dinosaur",
+ "exercise",
+ "function",
+ "grateful",
+ "hospital",
+]
+
+
+def scramble_word(word: str) -> str:
+ """
+ Scramble a word's letters randomly, ensuring it's different from original.
+
+ Args:
+ word: The original word to scramble
+
+ Returns:
+ A scrambled version of the word (different from original if len > 1)
+ """
+ if len(word) <= 1:
+ return word
+
+ letters = list(word.lower())
+ original = word.lower()
+
+ # Try to get a different arrangement
+ max_attempts = 100
+ for _ in range(max_attempts):
+ random.shuffle(letters)
+ scrambled = "".join(letters)
+ if scrambled != original:
+ return scrambled
+
+ # If word has all same letters (like "aaa"), just return it
+ return "".join(letters)
+
+
+class AnagramConfig(BaseEnvConfig):
+ """Configuration for AnagramEnv with customizable options."""
+
+ thinking_mode: bool = Field(
+ default=True,
+ description="Whether to enable thinking mode with tags.",
+ )
+
+ custom_thinking_prompt: Optional[str] = Field(
+ default=None,
+ description="Custom thinking prompt. If None, uses the default thinking prompt.",
+ )
+
+ eval_temperature: float = Field(
+ default=0.6,
+ description="Temperature for evaluation completions.",
+ )
+
+ rollout_temperature: float = Field(
+ default=1.0,
+ description="Temperature for training rollout completions.",
+ )
+
+ eval_max_tokens: int = Field(
+ default=2048,
+ description="Maximum tokens for evaluation completions.",
+ )
+
+ train_max_tokens: int = Field(
+ default=2048,
+ description="Maximum tokens for training completions.",
+ )
+
+ max_retries: int = Field(
+ default=3,
+ ge=1,
+ description="Maximum number of retries for failed API calls.",
+ )
+
+ retry_delay: float = Field(
+ default=1.0,
+ ge=0.0,
+ description="Delay in seconds between retry attempts.",
+ )
+
+ min_response_length: int = Field(
+ default=3,
+ ge=1,
+ description="Minimum response length to consider valid.",
+ )
+
+ min_word_length: int = Field(
+ default=4,
+ ge=3,
+ description="Minimum word length to use for anagram puzzles.",
+ )
+
+ max_word_length: int = Field(
+ default=10,
+ ge=4,
+ description="Maximum word length to use for anagram puzzles.",
+ )
+
+
+class AnagramEnv(BaseEnv):
+ """
+ Anagram Word Puzzle Environment.
+
+ This environment presents scrambled words to the model and rewards it for
+ correctly identifying the original word. It tests vocabulary knowledge,
+ pattern recognition, and reasoning abilities.
+ """
+
+ name = "anagram"
+ env_config_cls = AnagramConfig
+
+ def __init__(
+ self,
+ config: AnagramConfig,
+ server_configs: List[APIServerConfig],
+ slurm=True,
+ testing=False,
+ ):
+ super().__init__(config, server_configs, slurm, testing)
+ self.config: AnagramConfig = config
+ self.percent_correct_buffer = []
+ self.eval_metrics = []
+
+ # Tracking metrics
+ self.successful_solves = 0
+ self.failed_solves = 0
+ self.format_errors = 0
+ self.total_attempts = 0
+ self.rollouts_for_wandb = []
+
+ # Pre-compile regex patterns for <think> and <answer> tag parsing
+ self._think_pattern = re.compile(r"<think>")
+ self._think_close_pattern = re.compile(r"</think>")
+ self._think_content_pattern = re.compile(r"</think>\s*(.*)", re.DOTALL)
+ self._answer_pattern = re.compile(r"<answer>\s*(.*?)\s*</answer>", re.DOTALL)
+
+ # System prompts
+ self.thinking_system_prompt = self._get_thinking_prompt()
+ self.base_system_prompt = (
+ "You are an expert at solving anagram puzzles. "
+ "Given a scrambled word, identify the original English word. "
+ "Wrap your final answer in tags."
+ )
+
+ def _get_thinking_prompt(self) -> str:
+ """Get thinking system prompt."""
+ return (
+ self.config.custom_thinking_prompt
+ if self.config.custom_thinking_prompt
+ else "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
+ "problem and deliberate with yourself via systematic reasoning processes to help come to a correct "
+ "solution prior to answering. You should enclose your thoughts and internal monologue inside "
+ " tags, and then provide your solution or response to the problem."
+ )
+
+ def _reset_metrics(self) -> None:
+ """Reset training metrics."""
+ self.percent_correct_buffer = []
+ self.successful_solves = 0
+ self.failed_solves = 0
+ self.format_errors = 0
+ self.total_attempts = 0
+
+ def _create_system_content(self) -> str:
+ """Create system message content based on thinking mode."""
+ if self.config.thinking_mode:
+ return f"{self.thinking_system_prompt}\n\n{self.base_system_prompt}"
+ return self.base_system_prompt
+
+ @classmethod
+ def config_init(cls) -> Tuple[AnagramConfig, List[APIServerConfig]]:
+ """Initialize default configuration."""
+ env_config = AnagramConfig(
+ tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
+ group_size=8,
+ use_wandb=True,
+ max_num_workers_per_node=8,
+ rollout_server_url="http://localhost:8000",
+ total_steps=2000,
+ batch_size=512,
+ steps_per_eval=25,
+ train_max_tokens=2048,
+ eval_max_tokens=2048,
+ thinking_mode=True,
+ wandb_name="anagram",
+ )
+ server_configs = [
+ APIServerConfig(
+ model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
+ base_url="http://localhost:9004/v1",
+ api_key="x",
+ ),
+ ]
+ return env_config, server_configs
+
+ async def setup(self) -> None:
+ """Set up the environment by preparing word lists."""
+ # Filter words by length configuration
+ self.word_list = [
+ word
+ for word in DEFAULT_WORD_LIST
+ if self.config.min_word_length <= len(word) <= self.config.max_word_length
+ ]
+
+ # Create separate eval set (last 20% of words)
+ split_idx = int(len(self.word_list) * 0.8)
+ self.train_words = self.word_list[:split_idx]
+ self.eval_words = self.word_list[split_idx:]
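+ # Note: DEFAULT_WORD_LIST is grouped by length, so the last 20% are the
+ # longest words; the eval set therefore skews toward harder puzzles.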
+
+ # Shuffle training words deterministically via a local RNG, so the
+ # global random state (used by scramble_word) is not fixed for the run
+ random.Random(42).shuffle(self.train_words)
+
+ print(f"\nAnagram Environment Configuration:")
+ print(
+ f" - Word length range: {self.config.min_word_length}-{self.config.max_word_length}"
+ )
+ print(f" - Training words: {len(self.train_words)}")
+ print(f" - Evaluation words: {len(self.eval_words)}")
+ print(f" - Thinking mode: {self.config.thinking_mode}")
+ print(f" - Sample words: {self.train_words[:5]}")
+
+ self.iter = 0
+
+ def _extract_answer(self, response: str) -> Optional[str]:
+ """
+ Extract the answer from within <answer> tags.
+
+ Args:
+ response: Model response text
+
+ Returns:
+ Extracted answer or None if not found/invalid format
+ """
+ if self.config.thinking_mode:
+ # Check for exactly one pair of think tags
+ think_open_count = len(self._think_pattern.findall(response))
+ think_close_count = len(self._think_close_pattern.findall(response))
+
+ if think_open_count != 1 or think_close_count != 1:
+ return None
+
+ # Parse only the content after the closing </think> tag
+ match = self._think_content_pattern.search(response)
+ if match:
+ response = match.group(1)
+ else:
+ return None
+
+ # Find the answer between <answer> tags
+ matches = self._answer_pattern.findall(response)
+
+ # Must have exactly one answer block
+ if len(matches) != 1:
+ return None
+
+ return matches[0].strip().lower()
+
+ def _create_anagram_prompt(self, scrambled: str, original: str) -> str:
+ """Create the user prompt for anagram solving task."""
+ return (
+ f"Unscramble the following letters to form a valid English word.\n\n"
+ f"Scrambled letters: {scrambled}\n\n"
+ f"Hint: The word has {len(original)} letters.\n\n"
+ f"Provide your answer wrapped in tags."
+ )
+
+ async def get_next_item(self) -> Item:
+ """Generate next training item with anagram puzzle."""
+ self.iter += 1
+
+ # Get next word
+ original_word = self.train_words[self.iter % len(self.train_words)]
+ scrambled_word = scramble_word(original_word)
+
+ # Create system message
+ system_content = self._create_system_content()
+
+ # Create user prompt
+ user_content = self._create_anagram_prompt(scrambled_word, original_word)
+
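+ # Each message dict is frozen into a frozenset of (key, value) pairs so
+ # the prompt tuple is hashable, the usual pattern for Item prompts in
+ # atroposlib environments.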
+ prompt = tuple(
+ [
+ frozenset({"role": "system", "content": system_content}.items()),
+ frozenset({"role": "user", "content": user_content}.items()),
+ ]
+ )
+
+ return (prompt, original_word.lower())
+
+ def _convert_messages_to_list(self, prompt_tuple: Tuple) -> List[Dict]:
+ """Convert frozenset message format to list format."""
+ messages = []
+ for role_dict in prompt_tuple:
+ messages.append(dict(role_dict))
+ return messages
+
+ def _get_train_completion_params(self) -> Dict:
+ """Get completion parameters for training rollouts."""
+ return {
+ "n": self.config.group_size,
+ "max_tokens": self.config.train_max_tokens,
+ "temperature": self.config.rollout_temperature,
+ }
+
+ def _get_eval_completion_params(self) -> Dict:
+ """Get completion parameters for evaluation."""
+ return {
+ "n": 1,
+ "max_tokens": self.config.eval_max_tokens,
+ "temperature": self.config.eval_temperature,
+ "split": "eval",
+ }
+
+ async def collect_trajectories(
+ self, item: Item
+ ) -> Tuple[Optional[ScoredDataGroup], List]:
+ """Collect and score model trajectories."""
+ messages = self._convert_messages_to_list(item[0])
+ completion_params = self._get_train_completion_params()
+
+ max_retries = self.config.max_retries
+ retry_delay = self.config.retry_delay
+
+ for attempt in range(max_retries):
+ try:
+ completions = await self.server.chat_completion(
+ messages=messages, **completion_params
+ )
+
+ if not completions.choices:
+ if attempt < max_retries - 1:
+ await asyncio.sleep(retry_delay)
+ continue
+ return None, []
+
+ # Filter valid completions
+ valid_completions = []
+ for completion_choice in completions.choices:
+ if (
+ completion_choice.message.content is not None
+ and isinstance(completion_choice.message.content, str)
+ and len(completion_choice.message.content.strip())
+ >= self.config.min_response_length
+ ):
+ valid_completions.append(completion_choice)
+
+ if len(valid_completions) < len(completions.choices) // 2:
+ if attempt < max_retries - 1:
+ await asyncio.sleep(retry_delay)
+ continue
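+ # On the final attempt, fall through and score whatever valid
+ # completions remain rather than dropping the whole group.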
+
+ # Build trajectories
+ to_score = []
+ for completion_choice in valid_completions:
+ trajectory_messages = messages + [
+ {
+ "role": "assistant",
+ "content": completion_choice.message.content,
+ }
+ ]
+ to_score.append((tuple(trajectory_messages), item[1]))
+
+ break
+
+ except Exception as e:
+ if attempt < max_retries - 1:
+ await asyncio.sleep(retry_delay)
+ continue
+ print(f"Error in collect_trajectories: {e}")
+ return None, []
+
+ scored_data = await self.score(to_score)
+
+ if scored_data is not None:
+ await self.add_rollouts_for_wandb(scored_data, item)
+
+ return scored_data, []
+
+ async def score(self, rollout_group_data: List[Tuple]) -> Optional[ScoredDataGroup]:
+ """Score a group of rollout data."""
+ if not rollout_group_data:
+ return None
+
+ try:
+ scores = ScoredDataGroup()
+ scores["tokens"] = []
+ scores["masks"] = []
+ scores["scores"] = []
+
+ random.shuffle(rollout_group_data)
+
+ for item in rollout_group_data:
+ if not item or len(item) < 2 or not item[0]:
+ continue
+
+ model_response = item[0][-1]["content"]
+ expected_answer = item[1]
+
+ # Extract answer from model response
+ extracted_answer = self._extract_answer(model_response)
+
+ # Score 1.0 if exact match, 0.0 otherwise
+ reward = 1.0 if extracted_answer == expected_answer else 0.0
+
+ # Track metrics
+ self.total_attempts += 1
+ if reward == 1.0:
+ self.successful_solves += 1
+ else:
+ self.failed_solves += 1
+ if extracted_answer is None:
+ self.format_errors += 1
+
+ # Tokenize the conversation
+ out_dict = tokenize_for_trainer(self.tokenizer, item[0])
+ tokens = out_dict["tokens"]
+ masks = out_dict["masks"]
+
+ # Skip examples with fewer than 10 trainable tokens (-100 marks
+ # positions masked out of the loss)
+ if len([1 for mask in masks if mask != -100]) < 10:
+ continue
+
+ scores["tokens"].append(tokens)
+ scores["masks"].append(masks)
+ scores["scores"].append(reward)
+
+ if len(scores["tokens"]) >= self.config.group_size:
+ break
+
+ if not scores["tokens"]:
+ return None
+
+ # Log group results
+ group_successes = sum(1 for score in scores["scores"] if score == 1.0)
+ group_size = len(scores["scores"])
+ success_indicator = "✅" if group_successes > 0 else "❌"
+
+ total_success_rate = (
+ (self.successful_solves / self.total_attempts * 100)
+ if self.total_attempts > 0
+ else 0.0
+ )
+
+ print(
+ f"{success_indicator} Group scored: {group_successes}/{group_size} solved | "
+ f"Total success rate: {self.successful_solves}/{self.total_attempts} ({total_success_rate:.1f}%)"
+ )
+
+ # Update buffer
+ for score in scores["scores"]:
+ self.percent_correct_buffer.append(max(score, 0))
+
+ # Return None if all scores are the same (no learning signal)
+ if len(set(scores["scores"])) == 1:
+ return None
+
+ return scores
+
+ except Exception as e:
+ print(f"Error in score method: {e}")
+ return None
+
+ async def rollout_and_score_eval(self, word: str) -> dict:
+ """Rollout and score evaluation for a single word."""
+ try:
+ scrambled = scramble_word(word)
+ expected_answer = word.lower()
+
+ system_content = self._create_system_content()
+ user_content = self._create_anagram_prompt(scrambled, word)
+
+ messages = [
+ {"role": "system", "content": system_content},
+ {"role": "user", "content": user_content},
+ ]
+
+ completion_params = self._get_eval_completion_params()
+
+ for attempt in range(self.config.max_retries):
+ try:
+ completion = await self.server.chat_completion(
+ messages=messages, **completion_params
+ )
+
+ if not completion.choices:
+ if attempt < self.config.max_retries - 1:
+ await asyncio.sleep(self.config.retry_delay)
+ continue
+ return {"score": 0.0, "sample": None}
+
+ model_response = completion.choices[0].message.content
+
+ if (
+ model_response is None
+ or len(model_response.strip()) < self.config.min_response_length
+ ):
+ if attempt < self.config.max_retries - 1:
+ await asyncio.sleep(self.config.retry_delay)
+ continue
+ return {"score": 0.0, "sample": None}
+
+ break
+
+ except Exception:
+ if attempt < self.config.max_retries - 1:
+ await asyncio.sleep(self.config.retry_delay)
+ continue
+ return {"score": 0.0, "sample": None}
+
+ extracted_answer = self._extract_answer(model_response)
+ score = 1.0 if extracted_answer == expected_answer else 0.0
+
+ full_messages = messages + [
+ {"role": "assistant", "content": model_response}
+ ]
+
+ sample = {
+ "messages": full_messages,
+ "scrambled_word": scrambled,
+ "original_word": word,
+ "expected_answer": expected_answer,
+ "extracted_answer": extracted_answer,
+ "score": int(score),
+ "correct": bool(score),
+ "format_compliant": extracted_answer is not None,
+ }
+
+ return {"score": score, "sample": sample}
+
+ except Exception as e:
+ print(f"Error in evaluation: {e}")
+ return {"score": 0.0, "sample": None}
+
+ async def evaluate(self, *args, **kwargs) -> None:
+ """Evaluate the model on the evaluation word set."""
+ start_time = time.time()
+
+ try:
+ eval_tasks = [self.rollout_and_score_eval(word) for word in self.eval_words]
+ results = await tqdm_asyncio.gather(*eval_tasks)
+
+ valid_results = [
+ result
+ for result in results
+ if not isinstance(result, Exception)
+ and result
+ and result.get("sample") is not None
+ ]
+
+ if not valid_results:
+ print("Warning: No valid evaluation results obtained")
+ return
+
+ except Exception as e:
+ print(f"Error during evaluation: {e}")
+ return
+
+ scores = [result["score"] for result in valid_results]
+ samples = [result["sample"] for result in valid_results]
+ valid_scores = [s for s in scores if s is not None]
+
+ if not valid_scores:
+ print("Warning: No valid scores found during evaluation")
+ return
+
+ percent_correct = sum(valid_scores) / len(valid_scores)
+ self.eval_metrics.append(("eval/percent_correct", percent_correct))
+
+ format_compliant = sum(
+ 1 for sample in samples if sample.get("format_compliant", False)
+ )
+
+ end_time = time.time()
+
+ eval_metrics = {
+ "eval/percent_correct": percent_correct,
+ "eval/total_samples": len(samples),
+ "eval/correct_samples": sum(valid_scores),
+ "eval/format_compliance_rate": (
+ format_compliant / len(samples) if samples else 0.0
+ ),
+ }
+
+ try:
+ await self.evaluate_log(
+ metrics=eval_metrics,
+ samples=samples,
+ start_time=start_time,
+ end_time=end_time,
+ generation_parameters={
+ "temperature": self.config.eval_temperature,
+ "max_tokens": self.config.eval_max_tokens,
+ "thinking_mode": self.config.thinking_mode,
+ },
+ )
+ except Exception as e:
+ print(f"Error logging evaluation results: {e}")
+
+ async def add_rollouts_for_wandb(
+ self,
+ scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
+ item: Item = None,
+ ) -> None:
+ """Add rollouts to wandb for visualization."""
+ if item is None or scored_data is None or not scored_data.get("tokens"):
+ return
+
+ expected_answer = item[1]
+
+ # Extract scrambled word from the item prompt
+ scrambled_word = "unknown"
+ try:
+ for role_dict in item[0]:
+ role_dict_converted = dict(role_dict)
+ if role_dict_converted.get("role") == "user":
+ user_content = role_dict_converted.get("content", "")
+ if "Scrambled letters:" in user_content:
+ start = user_content.find("Scrambled letters:") + len(
+ "Scrambled letters:"
+ )
+ end = user_content.find("\n", start)
+ scrambled_word = user_content[start:end].strip()
+ break
+ except Exception:
+ scrambled_word = "extraction_failed"
+
+ num_keep = self.config.num_rollouts_per_group_for_logging
+ if num_keep == -1:
+ num_keep = self.config.group_size
+
+ num_keep = min(num_keep, len(scored_data["tokens"]))
+
+ current_rollouts = []
+
+ for i in range(num_keep):
+ full_text = self.tokenizer.decode(
+ scored_data["tokens"][i], skip_special_tokens=True
+ )
+ score_val = scored_data["scores"][i]
+
+ current_rollouts.append(
+ (full_text, score_val, expected_answer, scrambled_word)
+ )
+
+ self.rollouts_for_wandb.append(current_rollouts)
+
+ if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
+ self.rollouts_for_wandb.pop(0)
+
+ async def create_rollout_table(self, wandb_metrics: Dict) -> Dict:
+ """Create wandb table for rollout visualization."""
+ if not self.rollouts_for_wandb:
+ return wandb_metrics
+
+ table = wandb.Table(
+ columns=["full_text", "score", "expected_answer", "scrambled_word"]
+ )
+
+ for group_rollouts in self.rollouts_for_wandb:
+ for rollout_tuple in group_rollouts:
+ if len(rollout_tuple) == 4:
+ table.add_data(*rollout_tuple)
+
+ wandb_metrics["train/rollouts"] = table
+ self.rollouts_for_wandb = []
+ return wandb_metrics
+
+ async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
+ """Log metrics to wandb."""
+ if wandb_metrics is None:
+ wandb_metrics = {}
+
+ if self.percent_correct_buffer:
+ wandb_metrics["train/percent_correct"] = sum(
+ self.percent_correct_buffer
+ ) / len(self.percent_correct_buffer)
+
+ if self.total_attempts > 0:
+ wandb_metrics["train/success_rate"] = (
+ self.successful_solves / self.total_attempts
+ )
+ wandb_metrics["train/failure_rate"] = (
+ self.failed_solves / self.total_attempts
+ )
+ wandb_metrics["train/format_error_rate"] = (
+ self.format_errors / self.total_attempts
+ )
+
+ wandb_metrics.update(
+ {
+ "train/thinking_mode_enabled": (
+ 1.0 if self.config.thinking_mode else 0.0
+ ),
+ "train/total_attempts": self.total_attempts,
+ "train/successful_solves": self.successful_solves,
+ "train/failed_solves": self.failed_solves,
+ "train/format_errors": self.format_errors,
+ }
+ )
+
+ self._reset_metrics()
+
+ for metric_name, metric_value in self.eval_metrics:
+ wandb_metrics[metric_name] = metric_value
+ self.eval_metrics = []
+
+ wandb_metrics = await self.create_rollout_table(wandb_metrics)
+
+ await super().wandb_log(wandb_metrics)
+
+
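+# Launch via the atroposlib CLI. Subcommand names are assumed from the
+# standard BaseEnv entry points and may differ in your checkout:
+#   python environments/anagram_environment.py process   # roll out locally to disk
+#   python environments/anagram_environment.py serve     # serve rollouts to a trainer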
+if __name__ == "__main__":
+ AnagramEnv.cli()