20 changes: 20 additions & 0 deletions areal/dataset/__init__.py
@@ -19,6 +19,8 @@
"virl39k",
"hh-rlhf",
"torl_data",
"MATH-500",
"amc12"
]

logger = logging.getLogger("Dataset")
@@ -123,6 +125,24 @@ def _get_custom_dataset(
max_length=max_length,
**kwargs,
)
elif "MATH-500" in path and type == 'rl':
from .math500 import get_math500_rl_dataset
return get_math500_rl_dataset(
path=path,
split=split,
tokenizer=tokenizer,
max_length=max_length,
**kwargs,
)
elif "amc12" in path and type == 'rl':
from .amc12 import get_amc12_rl_dataset
return get_amc12_rl_dataset(
path=path,
split=split,
tokenizer=tokenizer,
max_length=max_length,
**kwargs,
)
else:
# Fallback: try loading as a generic HuggingFace dataset from disk.
# This supports arbitrary datasets saved via dataset.save_to_disk().
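For context, a hedged sketch of how the new branches are reached (the model name, the local path, and passing type as a keyword are illustrative assumptions; _get_custom_dataset is the internal helper patched above):

from transformers import AutoTokenizer

from areal.dataset import _get_custom_dataset  # internal helper patched above

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")  # placeholder model

# The MATH-500 branch is selected by substring match on `path` plus type == "rl";
# the loader then reads {path}/{split}.jsonl.
dataset = _get_custom_dataset(
    path="/data/MATH-500",  # placeholder local path containing test.jsonl
    split="test",
    type="rl",
    tokenizer=tokenizer,
    max_length=2048,
)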
41 changes: 41 additions & 0 deletions areal/dataset/amc12.py
@@ -0,0 +1,41 @@
import os

from datasets import load_dataset

def get_amc12_rl_dataset(
path: str,
split: str,
tokenizer,
max_length: int | None = None,
):
# ---- Load the raw JSONL split: {path}/{split}.jsonl ----
data_file = os.path.join(path, f"{split}.jsonl")
dataset = load_dataset(
"json",
data_files={split: data_file},
split=split,
)

def process(sample):
messages = [{
"role": "user",
"content": (
sample["question"]
+ "\n\nReturn ONLY the final answer in LaTeX as: \\boxed{...}."
+ "\nDo not include any other text."
+ "\nAfter the box, output the token <END> and then stop."
),
}]
return {"messages": messages, "answer": sample["answer"]}

dataset = dataset.map(process).remove_columns(["question"])

if max_length is not None:
def filter_length(sample):
return (
len(tokenizer.encode(sample["messages"][0]["content"]))
<= max_length
)

dataset = dataset.filter(filter_length)

return dataset
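A minimal usage sketch under the schema the loader assumes: {path}/{split}.jsonl with one {"question": ..., "answer": ...} object per line (the path and tokenizer below are placeholders):

from transformers import AutoTokenizer

from areal.dataset.amc12 import get_amc12_rl_dataset

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")  # placeholder
ds = get_amc12_rl_dataset(path="/data/amc12", split="test", tokenizer=tok, max_length=1024)
print(ds[0]["messages"][0]["content"])  # question plus the boxed-answer instructions
print(ds[0]["answer"])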
36 changes: 36 additions & 0 deletions areal/dataset/math500.py
@@ -0,0 +1,36 @@
import os

from datasets import load_dataset

def get_math500_rl_dataset(
path: str,
split: str,
tokenizer,
max_length: int | None = None,
):
# ---- Load the raw JSONL split: {path}/{split}.jsonl ----
data_file = os.path.join(path, f"{split}.jsonl")
dataset = load_dataset(
"json",
data_files={split: data_file},
split=split,
)

def process(sample):
messages = [{
"role": "user",
"content": sample["problem"] + "\nPlease put your final answer within \\boxed{}.",
}]
return {"messages": messages, "answer": sample["answer"]}

dataset = dataset.map(process).remove_columns(["problem"])

if max_length is not None:
def filter_length(sample):
return (
len(tokenizer.encode(sample["messages"][0]["content"]))
<= max_length
)

dataset = dataset.filter(filter_length)

return dataset
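math500 mirrors amc12 but reads a "problem" field rather than "question". A self-contained smoke test under that assumption (file contents invented for illustration):

import json
import os
import tempfile

from transformers import AutoTokenizer

from areal.dataset.math500 import get_math500_rl_dataset

tmp = tempfile.mkdtemp()
with open(os.path.join(tmp, "test.jsonl"), "w") as f:
    f.write(json.dumps({"problem": "What is $1+1$?", "answer": "2"}) + "\n")

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")  # placeholder
ds = get_math500_rl_dataset(path=tmp, split="test", tokenizer=tok)
assert ds[0]["answer"] == "2"
assert "boxed" in ds[0]["messages"][0]["content"]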
1 change: 1 addition & 0 deletions areal/experimental/openai/types.py
@@ -39,6 +39,7 @@ class InteractionWithTokenLogpReward:
# Common
model_response: ModelResponse | None = None
reward: float | None = None
norm_group: str | None = None  # [MARL] enables per-agent normalization when agents share one backend
parent: InteractionWithTokenLogpReward | None = None
chat_template_type: str = "hf"
_cache: dict[str, torch.Tensor] | None = None
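The new field lets a multi-agent rollout tag each agent's interactions before they are merged downstream. A hedged sketch (run_agent and the episode shape are hypothetical stand-ins, not part of this PR):

# Hypothetical two-agent rollout: tag each agent's interactions with a
# norm_group so the shared backend can normalize rewards per agent.
async def arun_episode_two_agents(engine, data):
    merged = {}
    for agent_id in ("agent_1", "agent_2"):
        interactions = await run_agent(engine, data, agent_id)  # hypothetical helper
        for cid, interaction in interactions.items():
            interaction.norm_group = agent_id  # field added above
            merged[cid] = interaction
    return merged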
20 changes: 17 additions & 3 deletions areal/infra/remote_inf_engine.py
@@ -111,10 +111,24 @@ async def arun_episode(
isinstance(v, InteractionWithTokenLogpReward) for v in first.values()
)
):
# Merge dicts - each result is {completion_id: InteractionWithTokenLogpReward}
merged: dict[str, InteractionWithTokenLogpReward] = {}
# [MARL] bucket interactions by norm_group
agent_interactions = {}  # {norm_group: {key: interaction}}
for result in valid_results:
    for key, interaction in result.items():
        group_id = getattr(interaction, "norm_group", None)
        if group_id is None:
            group_id = "group_1"
        agent_interactions.setdefault(group_id, {})[key] = interaction

# Merge in sorted order by group_id for deterministic ordering:
# all of agent1's interactions come first, then agent2's, then agent3's, etc.
merged: dict[str, InteractionWithTokenLogpReward] = {}
for group_id in sorted(agent_interactions.keys()):
merged.update(agent_interactions[group_id])
return merged if merged else None

# Otherwise, tensor dicts - concatenate
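For clarity, the same grouping-and-merge policy restated as a standalone function (a sketch, with getattr standing in for the hasattr check; a missing or None norm_group falls back to "group_1" exactly as above):

from collections import defaultdict

def merge_by_norm_group(valid_results):
    # Bucket each {completion_id: interaction} result by norm_group.
    buckets = defaultdict(dict)
    for result in valid_results:
        for key, interaction in result.items():
            group_id = getattr(interaction, "norm_group", None)
            buckets["group_1" if group_id is None else group_id][key] = interaction
    # Merging in sorted group order keeps each agent's interactions contiguous
    # and makes the overall ordering deterministic.
    merged = {}
    for group_id in sorted(buckets):
        merged.update(buckets[group_id])
    return merged or None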
180 changes: 174 additions & 6 deletions areal/reward/__init__.py
@@ -2,13 +2,33 @@

from math_verify.metric import math_metric
from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig

from math_verify.errors import TimeoutException
import re
from areal.utils import logging

logger = logging.getLogger("RewardUtils")

VALID_REWARD_FN = ["clevr_count_70k", "geometry3k"]

_BOX = re.compile(r"\\boxed\s*\{")

def _canon_gold(gt: str) -> str:
gt = (gt or "").strip()
if not gt:
return gt
# If already boxed, keep it
if _BOX.search(gt):
return gt
# Otherwise, wrap the gold answer in \boxed{} so the extractor finds it reliably
return f"\\boxed{{{gt}}}"

def _canon_pred(resp: str) -> str:
resp = (resp or "").strip()
# Drop any chain-of-thought before </think>; keep only the final answer segment
if "</think>" in resp:
resp = resp.split("</think>")[-1].strip()
return resp
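# Illustrative examples for the two canonicalizers above (not exhaustive):
#   _canon_gold("42")                     -> "\boxed{42}"
#   _canon_gold("\boxed{42}")             -> "\boxed{42}"  (already boxed, kept)
#   _canon_pred("...</think> \boxed{7}")  -> "\boxed{7}"   (thinking dropped)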


def get_custom_reward_fn(path: str, **kwargs):
if "clevr_count_70k" in path:
@@ -25,7 +45,6 @@ def get_custom_reward_fn(path: str, **kwargs):
f"Supported reward functions are: {VALID_REWARD_FN}. "
)


class MathVerifyWorker:
"""Thin wrapper over math_verify with configurable extraction/precision.

@@ -56,20 +75,162 @@ def __init__(self, try_extract_without_anchor=True, precision: int = 6):
precision=precision,
)

def verify(self, response: str, ground_truth: str) -> float:
# ground_truth_parsable = "\\boxed{" + ground_truth + "}"
def verify_for_math500(self, response: str, ground_truth: str) -> float:
try:
# _canon_gold ensures the ground truth is wrapped as "\\boxed{" + ground_truth + "}"
gt = _canon_gold(ground_truth)
resp = _canon_pred(response)
ret_score, _ = self.verify_func([gt], [resp])
return float(ret_score)
except (Exception, TimeoutException) as e:  # TimeoutException inherits from BaseException, not Exception
logger.warning(
f"Exception {e} in MathVerifyWorker.verify for response={response} and ground_truth={ground_truth}",
exc_info=True,
)
return 0.0

def verify(self, response: str, ground_truth: str) -> float:
# For gsm8k.
# Assumes the ground truth is already parsable, e.g. "\\boxed{" + ground_truth + "}".
try:
ret_score, _ = self.verify_func([ground_truth], [response])
return float(ret_score)
except Exception:
except (Exception, TimeoutException) as e: # TimeoutException inherits from BaseException
logger.warning(
f"Exception in MathVerifyWorker.verify for response={response} and ground_truth={ground_truth}",
f"Exception {e} in MathVerifyWorker.verify for response={response} and ground_truth={ground_truth}",
exc_info=True,
)
return 0.0


# Math MC Verifier
def _canon_gold_mc(gt: str) -> str:
gt = (gt or "").strip()
if not gt:
return gt
# For MC, gold is usually A/B/C/D/E (or 5-choice variants). Keep it simple.
gt = gt.upper()
# Handle variants the dataset may store, like "(E)" or "E)".
m = re.search(r"\b([A-E])\b", gt)
if m:
gt = m.group(1)
# Wrap in \boxed{} to help extraction find the letter consistently
if not _BOX.search(gt):
gt = f"\\boxed{{{gt}}}"
return gt

def _canon_pred_mc(resp: str) -> str:
resp = (resp or "").strip()
if not resp:
return resp

# If already boxed, keep as-is
if _BOX.search(resp):
return resp

# Common patterns: "Answer: E", "(E)", "E.", "The answer is (E)", "choice E"
# Prefer an explicit letter if present.
m = re.search(r"(?i)\b(?:answer|final|choice)\b.*?\b([A-E])\b", resp)
if m:
return f"\\boxed{{{m.group(1).upper()}}}"

# Otherwise look for a standalone MC letter token.
# (Guard against matching 'A' in words by requiring boundaries.)
m = re.search(r"(?i)(?:^|[\s\(\[\{])([A-E])(?:$|[\s\)\]\}\.\,\;\:\!])", resp)
if m:
return f"\\boxed{{{m.group(1).upper()}}}"

# As a last resort, leave the response unchanged (math_verify may still extract something)
return resp
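# Illustrative extractions for _canon_pred_mc (assumed behavior of the patterns above):
#   "The answer is (E)."  -> "\boxed{E}"
#   "Answer: c"           -> "\boxed{C}"
#   "\boxed{D}"           -> "\boxed{D}"  (already boxed, kept)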


class MathMultipleChoiceVerifyWorker:
"""Verifier for multiple-choice math datasets (A/B/C/D/E style).

This mirrors MathVerifyWorker but uses MC-specific canonicalization/extraction.
It still leverages math_verify for robust parsing; we just normalize outputs
so the extractor reliably finds the selected option.

Args:
try_extract_without_anchor: If False, requires answer anchors. For MC,
leaving True is usually best because model outputs vary a lot.
precision: kept for API compatibility; not important for letter matching.
choices: string of valid choice letters.
"""

def __init__(
self,
try_extract_without_anchor: bool = True,
precision: int = 6,
choices: str = "ABCDE",
):
self.choices = "".join(sorted(set(choices.upper())))
# We still use math_verify, but our canon_* functions steer extraction to boxed letters.
self.verify_func = math_metric(
gold_extraction_target=(
ExprExtractionConfig(
try_extract_without_anchor=try_extract_without_anchor
),
LatexExtractionConfig(),
),
pred_extraction_target=(
ExprExtractionConfig(
try_extract_without_anchor=try_extract_without_anchor
),
LatexExtractionConfig(),
),
precision=precision,
)

def _normalize_gold(self, ground_truth: str) -> str:
gt = _canon_gold_mc(ground_truth)

choice_set = "".join(self.choices) if self.choices else "A-E"
pattern = rf"\b([{choice_set}])\b"

m = re.search(pattern, gt.upper())
if m and m.group(1) in self.choices:
return f"\\boxed{{{m.group(1)}}}"
return gt

def _normalize_pred(self, response: str) -> str:
resp = _canon_pred_mc(response)

choice_set = "".join(self.choices) if self.choices else "A-E"
pattern = rf"\\boxed\{{([{choice_set}])\}}"

m = re.search(pattern, resp.upper())
if m and m.group(1) in self.choices:
return f"\\boxed{{{m.group(1)}}}"
return resp

def verify(self, response: str, ground_truth: str) -> float:
try:
gt = self._normalize_gold(ground_truth)
resp = self._normalize_pred(response)

# Primary path: use math_verify scoring
ret_score, _ = self.verify_func([gt], [resp])
score = float(ret_score)

# Hard fallback: direct letter comparison (useful if parsing fails)
if score == 0.0:
    gt_letter = re.search(r"\\boxed\{([A-E])\}", gt, re.IGNORECASE)
    pr_letter = re.search(r"\\boxed\{([A-E])\}", resp, re.IGNORECASE)
    if gt_letter and pr_letter:
        return 1.0 if gt_letter.group(1).upper() == pr_letter.group(1).upper() else 0.0

return score

except (Exception, TimeoutException) as e:  # TimeoutException inherits from BaseException, not Exception
    logger.warning(
        f"Exception {e} in MathMultipleChoiceVerifyWorker.verify for response={response} and ground_truth={ground_truth}",
        exc_info=True,
    )
    return 0.0

_MATH_VERIFY_WORKER: MathVerifyWorker | None = None
_MATH_MC_VERIFY_WORKER: MathMultipleChoiceVerifyWorker | None = None


def get_math_verify_worker() -> MathVerifyWorker:
@@ -78,12 +239,19 @@ def get_math_verify_worker() -> MathVerifyWorker:
_MATH_VERIFY_WORKER = MathVerifyWorker()
return _MATH_VERIFY_WORKER

def get_math_mc_verify_worker() -> MathMultipleChoiceVerifyWorker:
global _MATH_MC_VERIFY_WORKER
if _MATH_MC_VERIFY_WORKER is None:
_MATH_MC_VERIFY_WORKER = MathMultipleChoiceVerifyWorker()
return _MATH_MC_VERIFY_WORKER

__all__ = [
"VALID_REWARD_FN",
"get_custom_reward_fn",
"MathVerifyWorker",
"get_math_verify_worker",
"MathMultipleChoiceVerifyWorker",
"get_math_mc_verify_worker",
"gsm8k_reward_fn",
"geometry3k_reward_fn",
"clevr_count_70k_reward_fn",
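A hedged usage sketch of the two singleton accessors; the scores in the comments are the outcomes expected under the canonicalization rules above, not recorded outputs:

from areal.reward import get_math_verify_worker, get_math_mc_verify_worker

math_worker = get_math_verify_worker()
# _canon_gold wraps the gold "2" as "\boxed{2}" before verification.
score = math_worker.verify_for_math500("The sum is \\boxed{2}.", "2")  # expect 1.0

mc_worker = get_math_mc_verify_worker()
# Both sides normalize to "\boxed{E}" before comparison.
score = mc_worker.verify("The answer is (E).", "(E)")  # expect 1.0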