Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
396 changes: 396 additions & 0 deletions torchtitan/grpo/test/gsm8k_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,396 @@
import random
import time
from typing import Dict, List, Optional, Tuple, TypedDict, Union

from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from tqdm.asyncio import tqdm_asyncio

from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
ScoredDataGroup,
)
from atroposlib.type_definitions import Item

# System prompt used as the "system" message for both training rollouts and
# evaluation requests.  It is assembled from two parts: a reasoning preamble
# (which ends with a blank line) followed by the answer-format instructions.
_REASONING_PREAMBLE = "".join(
    [
        "You are a deep thinking AI, you may use extremely long chains of thought ",
        "to deeply consider the problem and deliberate with yourself via systematic ",
        "reasoning processes to help come to a correct solution prior to answering. ",
        "You should enclose your thoughts and internal monologue inside <think> </think> ",
        "tags, and then provide your solution or response to the problem.\n\n",
    ]
)

# NOTE: the 2048-token budget is written literally in the prompt text; it is
# not derived from any configuration value.
_ANSWER_FORMAT_LINES = [
    "You are allocated a maximum of 2048 tokens, please strive to use less.",
    "",
    "You will then provide your answer like this: \\boxed{your answer here}",
    "It is important that you provide your answer in the correct format.",
    "If you do not, you will not receive credit for your answer.",
    "So please end your answer with \\boxed{your answer here}",
]

system_prompt = _REASONING_PREAMBLE + "\n".join(_ANSWER_FORMAT_LINES)


class GSM8kRow(TypedDict):
    """One GSM8K training example as consumed by ``GSM8kEnv``."""

    # Natural-language problem statement; sent to the model as the user message.
    question: str
    # Reference solution text; the final answer is the part after the last "#".
    answer: str


class GSM8kEnv(BaseEnv):
    """GSM8K math word-problem environment: rollout collection, scoring, eval.

    Training items are drawn from the GSM8K ``main`` train split and
    evaluation runs over the test split.  Completions are expected to end with
    a ``\\boxed{...}`` answer, which is parsed and compared against the
    reference answer with ``math_verify``.
    """

    name = "gsm8k"

    # Temperature used for evaluation rollouts.  Kept in one place so the value
    # reported by ``evaluate`` always matches the value actually sampled with.
    _EVAL_TEMPERATURE = 0.6
    # Mask entries equal to this value are not counted as valid tokens.
    _MASK_IGNORE = -100
    # Rollouts with fewer valid (non-ignored) tokens than this are dropped.
    _MIN_VALID_TOKENS = 5

    def __init__(
        self,
        config: BaseEnvConfig,
        server_configs: List[APIServerConfig],
        slurm=True,
        testing=False,
    ):
        """Initialise the base environment and the metric buffers.

        Args:
            config: Environment configuration, forwarded to ``BaseEnv``.
            server_configs: Inference server configurations, forwarded to
                ``BaseEnv``.
            slurm: Forwarded unchanged to ``BaseEnv``.
            testing: Forwarded unchanged to ``BaseEnv``.
        """
        super().__init__(config, server_configs, slurm, testing)
        print(f"DEBUG: GSM8kEnv initialized with {len(self.server.servers)} servers")
        for i, server in enumerate(self.server.servers):
            if hasattr(server, "config"):
                print(f"DEBUG: Server {i}: {server.config.base_url}")
        # Per-sample correctness (1 / 0) accumulated between wandb_log calls.
        self.percent_correct_buffer = []
        # (metric_name, value) pairs produced by evaluate(), flushed by wandb_log.
        self.eval_metrics = []
        # Add tracking for wandb visualizations
        self.rollouts_for_wandb = []
        self.completion_lengths = []

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
        """Return the default environment config and inference server list."""
        env_config = BaseEnvConfig(
            tokenizer_name="Qwen/Qwen2.5-7B",
            group_size=8,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=12,
            steps_per_eval=100,
            max_token_length=2048,
            wandb_name="gsm8k_qwen3_test",
        )
        server_configs = [
            APIServerConfig(
                model_name="Qwen/Qwen2.5-7B",
                base_url="http://localhost:9001/v1",
                api_key="x",
                num_requests_for_eval=256,
            ),
        ]

        return env_config, server_configs

    @staticmethod
    def _extract_gold_answer(answer_text: str) -> str:
        """Return the text after the last ``#``, stripped and with commas removed."""
        return answer_text.split("#")[-1].strip().replace(",", "")

    @staticmethod
    def _parse_gold(boxed_answer: str):
        """Parse a ``\\boxed{...}`` reference answer with the default LaTeX config."""
        return parse(
            boxed_answer,
            extraction_mode="first_match",
            extraction_config=[LatexExtractionConfig()],
        )

    @staticmethod
    def _parse_model_answer(response_text: Optional[str]):
        """Parse the model's answer from the text after the last ``</think>``.

        A missing (``None``) response is treated as an empty string so that a
        content-less completion is scored as wrong instead of raising.
        """
        text = response_text or ""
        return parse(
            text.split("</think>")[-1],
            extraction_config=[
                LatexExtractionConfig(
                    normalization_config=NormalizationConfig(
                        nits=False,
                        malformed_operators=False,
                        basic_latex=True,
                        equations=True,
                        boxed="all",
                        units=True,
                    ),
                    # Ensures that boxed is tried first
                    boxed_match_priority=0,
                    try_extract_without_anchor=False,
                )
            ],
            extraction_mode="first_match",
        )

    @staticmethod
    def _length_penalized_score(length, length_threshold, max_allowed_length):
        """Return the score of a correct answer after the linear length penalty.

        Lengths up to ``length_threshold`` keep the full score of 1.0; beyond
        it the score falls linearly, reaching 0.0 at ``max_allowed_length``
        (and staying there for anything longer).
        """
        if length <= length_threshold:
            return 1.0
        penalty_range = max_allowed_length - length_threshold
        if penalty_range <= 0:
            # Degenerate configuration (no room between threshold and max):
            # anything over the threshold is fully penalised.
            return 0.0
        percentage_of_range = min((length - length_threshold) / penalty_range, 1.0)
        return 1.0 - percentage_of_range

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Add buffered train/eval metrics to ``wandb_metrics`` and log them.

        Both buffers are cleared afterwards.  ``train/percent_correct`` is
        omitted when no training samples were scored since the last call.
        """
        if wandb_metrics is None:
            wandb_metrics = {}

        if self.percent_correct_buffer:
            wandb_metrics["train/percent_correct"] = sum(
                self.percent_correct_buffer
            ) / len(self.percent_correct_buffer)

        self.percent_correct_buffer = []
        for item in self.eval_metrics:
            wandb_metrics[item[0]] = item[1]
        self.eval_metrics = []
        # Call the parent method to handle the server metrics
        await super().wandb_log(wandb_metrics)

    async def setup(self):
        """Load and shuffle the GSM8K splits and reset the item counter."""
        self.train = load_dataset("gsm8k", "main", split="train").shuffle(seed=42)
        test_data = load_dataset("gsm8k", "main", split="test").shuffle(seed=42)
        self.test = [
            {
                "question": row["question"],
                "gold_answer": self._extract_gold_answer(row["answer"]),
            }
            for row in test_data
        ]
        self.iter = 0

    def save_checkpoint(self, step, data=None):
        """Store the current item counter in ``data`` and delegate to the parent."""
        if data is None:
            data = {}
        data["iter"] = self.iter
        super().save_checkpoint(step, data)

    async def rollout_and_score_eval(self, question: str, answer: str) -> dict:
        """Generate one completion for ``question`` and score it against ``answer``.

        Args:
            question: Problem text, sent as the user message.
            answer: Bare reference answer (without ``\\boxed``).

        Returns:
            ``{"score": 0 or 1, "sample": {...}}`` where ``sample`` holds the
            conversation, the parsed answers and the finish reason.
        """
        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
            completion = await managed.chat_completion(
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": question},
                ],
                n=1,
                max_tokens=self.config.max_token_length,
                temperature=self._EVAL_TEMPERATURE,
            )

        choice = completion.choices[0]
        # Guard against a completion without text content.
        response_content = choice.message.content or ""

        gold_parsed = self._parse_gold("\\boxed{" + answer + "}")
        answer_parsed = self._parse_model_answer(response_content)

        # NOTE(review): math_verify's documented signature appears to be
        # verify(gold, target); the original answer-first order is kept here.
        # Confirm whether the order matters for these answers.
        score = 1 if verify(answer_parsed, gold_parsed) else 0

        sample = {
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": question},
                {"role": "assistant", "content": response_content},
            ],
            "question": question,
            "gold_answer": answer,
            "gold_parsed": str(gold_parsed) if gold_parsed else None,
            "model_parsed": str(answer_parsed) if answer_parsed else None,
            "score": int(score),
            "correct": bool(score),
            "finish_reason": choice.finish_reason,
            # split() returns the whole string when "</think>" is absent.
            "response_after_think": response_content.split("</think>")[-1],
        }

        return {"score": score, "sample": sample}

    async def evaluate(self, *args, **kwargs):
        """Score every test question and record ``eval/percent_correct``.

        The metric is queued for the next ``wandb_log`` call and also passed,
        together with the per-question samples, to ``evaluate_log``.  Nothing
        is logged when the test set is empty.
        """
        start_time = time.time()

        eval_tasks = [
            self.rollout_and_score_eval(item["question"], item["gold_answer"])
            for item in self.test
        ]
        if not eval_tasks:
            # Avoid dividing by zero below; there is nothing to report.
            print("Warning: evaluate() called with an empty test set, skipping")
            return
        results = await tqdm_asyncio.gather(*eval_tasks)

        # Extract scores and samples
        scores = [result["score"] for result in results]
        samples = [result["sample"] for result in results]

        percent_correct = sum(scores) / len(scores)

        end_time = time.time()

        # Add to existing metrics for wandb
        self.eval_metrics.append(("eval/percent_correct", percent_correct))

        # Log evaluation results
        eval_metrics = {
            "eval/percent_correct": percent_correct,
        }

        await self.evaluate_log(
            metrics=eval_metrics,
            samples=samples,
            start_time=start_time,
            end_time=end_time,
            generation_parameters={
                # Report the temperature that rollout_and_score_eval really uses.
                "temperature": self._EVAL_TEMPERATURE,
                "max_tokens": self.config.max_token_length,
            },
        )

    async def collect_trajectories(
        self, item: GSM8kRow
    ) -> Tuple[Optional[ScoredDataGroup], list[Item]]:
        """Sample ``group_size`` completions for ``item`` and score them.

        Args:
            item: Training row with ``question`` and ``answer`` keys.

        Returns:
            ``(scored_group, backlog)``.  ``scored_group`` is ``None`` when
            generation fails or when ``score`` rejects the group; the backlog
            is always empty.
        """
        print(f"DEBUG: collect_trajectories() called for question: {item['question'][:80]}...")
        try:
            user_message = {"role": "user", "content": item["question"]}
            gold_answer = "\\boxed{" + self._extract_gold_answer(item["answer"]) + "}"

            print("DEBUG: About to call managed_server...")
            async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
                print("DEBUG: Inside managed_server context, about to call chat_completion...")

                chat_completions = await managed.chat_completion(
                    messages=[{"role": "system", "content": system_prompt}, user_message],
                    n=self.config.group_size,
                    max_tokens=self.config.max_token_length,
                    temperature=1.0,
                )
                print(f"DEBUG: chat_completion returned, got {len(chat_completions.choices)} completions")

                state = managed.get_state()
                nodes = state["nodes"]
                print(f"DEBUG: Got state with {len(nodes)} nodes")
        except Exception as e:
            # Best effort: a failed generation skips this item instead of
            # aborting the run.
            print(f"ERROR in collect_trajectories: {type(e).__name__}: {e}")
            import traceback

            traceback.print_exc()
            return None, []

        choices = chat_completions.choices
        if len(nodes) < len(choices):
            # Each completion is paired with nodes[i] below; without enough
            # nodes that indexing would raise, so skip the item instead.
            print(
                f"ERROR in collect_trajectories: got {len(nodes)} nodes "
                f"for {len(choices)} completions"
            )
            return None, []

        to_score = []
        to_backlog = []
        # NOTE(review): assumes nodes[i] belongs to choices[i] - confirm
        # against the managed server implementation.
        for i, chat_completion in enumerate(choices):
            messages = (
                {"role": "system", "content": system_prompt},
                user_message,
                {"role": "assistant", "content": chat_completion.message.content},
            )
            to_score.append(
                {
                    "messages": messages,
                    "gold_answer": gold_answer,
                    "finish_reason": chat_completion.finish_reason,
                    "tokens": nodes[i].tokens,
                    "masks": nodes[i].masked_tokens,
                    "logprobs": nodes[i].logprobs,
                }
            )
        to_postprocess = await self.score(to_score)
        return to_postprocess, to_backlog

    async def score(
        self, rollout_group_data
    ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
        """Score a group of rollouts that share one reference answer.

        Each kept rollout scores 1.0 when its parsed answer verifies against
        the reference and -1.0 otherwise.  When every kept rollout is correct,
        the scores are replaced by length-penalised values in [0.0, 1.0].

        Args:
            rollout_group_data: List of dicts with ``messages``,
                ``gold_answer``, ``tokens``, ``masks`` and ``logprobs`` keys.
                The list is shuffled in place.

        Returns:
            The populated ``ScoredDataGroup``, or ``None`` when the group is
            empty, the reference answer cannot be parsed, or fewer than
            ``group_size`` rollouts survive filtering.
        """
        print(f"DEBUG: score() called with {len(rollout_group_data)} rollouts")
        if not rollout_group_data:
            # Nothing to score (and no reference answer to read).
            return None

        scores = ScoredDataGroup()
        scores["tokens"] = []
        scores["masks"] = []
        scores["scores"] = []
        scores["inference_logprobs"] = []

        gold_parsed = self._parse_gold(rollout_group_data[0]["gold_answer"])
        print(f"DEBUG: Gold answer parsed: {len(gold_parsed)} elements")
        if len(gold_parsed) == 0:
            # If the gold solution is not parseable, we return None
            print("DEBUG: Gold solution not parseable, returning None")
            return None

        # We require the answer to be provided in correct latex (no malformed operators)
        random.shuffle(rollout_group_data)
        for item in rollout_group_data:
            masks = item["masks"]

            # remove obviously bad examples (checked first to skip needless parsing)
            num_valid_tokens = sum(1 for mask in masks if mask != self._MASK_IGNORE)
            if num_valid_tokens < self._MIN_VALID_TOKENS:
                print(f"Filtering out sample with only {num_valid_tokens} valid tokens")
                continue

            answer_parsed = self._parse_model_answer(item["messages"][-1]["content"])
            # Reward 1 if the content is the same as the ground truth, 0 otherwise
            # NOTE(review): math_verify's documented signature appears to be
            # verify(gold, target); the original answer-first order is kept.
            reward = verify(answer_parsed, gold_parsed)

            scores["tokens"].append(item["tokens"])
            scores["masks"].append(masks)
            scores["inference_logprobs"].append(item["logprobs"])
            scores["scores"].append(1.0 if reward else -1.0)

            if len(scores["tokens"]) >= self.config.group_size:
                break

        # Check if we have enough valid samples after filtering
        if len(scores["tokens"]) < self.config.group_size:
            print(f"Warning: Only got {len(scores['tokens'])} samples after filtering, need {self.config.group_size}")
            return None

        # Record correctness as 1 / 0 for the train/percent_correct metric.
        self.percent_correct_buffer.extend(
            max(sample_score, 0) for sample_score in scores["scores"]
        )

        # check if all the same
        if all(sample_score == 1 for sample_score in scores["scores"]):
            # Do length penalty :)
            token_lengths = [len(tokens) for tokens in scores["tokens"]]
            if not token_lengths or max(token_lengths) == 0:
                # What? But don't want to crash a run so just in case...
                return None

            # Get max allowed token length from config
            max_allowed_length = self.config.max_token_length
            # Set threshold at 50% of max_token_length - no penalty below this
            length_threshold = max_allowed_length * 0.5

            # Apply modified length penalty with threshold
            scores["scores"] = [
                self._length_penalized_score(
                    length, length_threshold, max_allowed_length
                )
                for length in token_lengths
            ]

        # Groups whose scores are all identical are still returned, so training
        # proceeds on them as well.
        print(f"DEBUG: Returning scores with {len(scores['tokens'])} samples, scores: {scores['scores']}")
        return scores

    async def get_next_item(self) -> GSM8kRow:
        """Return the next training row, wrapping around at the end of the split."""
        next_item = self.train[self.iter % len(self.train)]
        self.iter += 1
        return next_item


if __name__ == "__main__":
    # Script entry point: hand control to the environment's command-line
    # interface (``cli`` is not defined in this file, so it comes from the
    # BaseEnv hierarchy).
    GSM8kEnv.cli()
Loading