diff --git a/areal/dataset/__init__.py b/areal/dataset/__init__.py
index d7161a7287..ace1bf7e5b 100644
--- a/areal/dataset/__init__.py
+++ b/areal/dataset/__init__.py
@@ -19,6 +19,8 @@
     "virl39k",
     "hh-rlhf",
     "torl_data",
+    "MATH-500",
+    "amc12",
 ]
 
 logger = logging.getLogger("Dataset")
@@ -123,6 +125,24 @@ def _get_custom_dataset(
             max_length=max_length,
             **kwargs,
         )
+    elif "MATH-500" in path and type == "rl":
+        from .math500 import get_math500_rl_dataset
+        return get_math500_rl_dataset(
+            path=path,
+            split=split,
+            tokenizer=tokenizer,
+            max_length=max_length,
+            **kwargs,
+        )
+    elif "amc12" in path and type == "rl":
+        from .amc12 import get_amc12_rl_dataset
+        return get_amc12_rl_dataset(
+            path=path,
+            split=split,
+            tokenizer=tokenizer,
+            max_length=max_length,
+            **kwargs,
+        )
     else:
         # Fallback: try loading as a generic HuggingFace dataset from disk.
         # This supports arbitrary datasets saved via dataset.save_to_disk().
diff --git a/areal/dataset/amc12.py b/areal/dataset/amc12.py
new file mode 100644
index 0000000000..8285b6ba1e
--- /dev/null
+++ b/areal/dataset/amc12.py
@@ -0,0 +1,41 @@
+from datasets import load_dataset
+import os
+
+def get_amc12_rl_dataset(
+    path: str,
+    split: str,
+    tokenizer,
+    max_length: int | None = None,
+):
+    # Load the raw JSONL split, e.g. <path>/train.jsonl.
+    data_file = os.path.join(path, f"{split}.jsonl")
+    dataset = load_dataset(
+        "json",
+        data_files={split: data_file},
+        split=split,
+    )
+
+    def process(sample):
+        messages = [{
+            "role": "user",
+            "content": (
+                sample["question"]
+                + "\n\nReturn ONLY the final answer in LaTeX as: \\boxed{...}."
+                + "\nDo not include any other text."
+                + "\nStop immediately after the boxed answer."
+            ),
+        }]
+        return {"messages": messages, "answer": sample["answer"]}
+
+    dataset = dataset.map(process).remove_columns(["question"])
+
+    if max_length is not None:
+        def filter_length(sample):
+            return (
+                len(tokenizer.encode(sample["messages"][0]["content"]))
+                <= max_length
+            )
+
+        dataset = dataset.filter(filter_length)
+
+    return dataset
\ No newline at end of file
diff --git a/areal/dataset/math500.py b/areal/dataset/math500.py
new file mode 100644
index 0000000000..33ae7eccc0
--- /dev/null
+++ b/areal/dataset/math500.py
@@ -0,0 +1,36 @@
+from datasets import load_dataset
+import os
+
+def get_math500_rl_dataset(
+    path: str,
+    split: str,
+    tokenizer,
+    max_length: int | None = None,
+):
+    # Load the raw JSONL split, e.g. <path>/test.jsonl.
+    data_file = os.path.join(path, f"{split}.jsonl")
+    dataset = load_dataset(
+        "json",
+        data_files={split: data_file},
+        split=split,
+    )
+
+    def process(sample):
+        messages = [{
+            "role": "user",
+            "content": sample["problem"] + "\nPlease put your final answer within \\boxed{}.",
+        }]
+        return {"messages": messages, "answer": sample["answer"]}
+
+    dataset = dataset.map(process).remove_columns(["problem"])
+
+    if max_length is not None:
+        def filter_length(sample):
+            return (
+                len(tokenizer.encode(sample["messages"][0]["content"]))
+                <= max_length
+            )
+
+        dataset = dataset.filter(filter_length)
+
+    return dataset
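Both loaders share the same contract: a directory containing `{split}.jsonl` with `question`/`problem` and `answer` fields. A quick smoke test, assuming a hypothetical local path and tokenizer (both placeholders, not part of this change):

```python
# Smoke test for the new AMC12 loader. The /data/amc12 path and tokenizer name
# are illustrative assumptions; any directory with train.jsonl works.
from transformers import AutoTokenizer

from areal.dataset.amc12 import get_amc12_rl_dataset

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct")
ds = get_amc12_rl_dataset(
    path="/data/amc12", split="train", tokenizer=tokenizer, max_length=1024
)
print(ds[0]["messages"][0]["content"][:120])
print(ds[0]["answer"])
```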
diff --git a/areal/experimental/openai/types.py b/areal/experimental/openai/types.py
index 8af3d01a26..c999f7aae5 100644
--- a/areal/experimental/openai/types.py
+++ b/areal/experimental/openai/types.py
@@ -39,6 +39,7 @@ class InteractionWithTokenLogpReward:
     # Common
     model_response: ModelResponse | None = None
     reward: float | None = None
+    norm_group: str | None = None  # [MARL] normalization group for the shared-backend case
     parent: InteractionWithTokenLogpReward | None = None
     chat_template_type: str = "hf"
     _cache: dict[str, torch.Tensor] | None = None
diff --git a/areal/infra/remote_inf_engine.py b/areal/infra/remote_inf_engine.py
index 63946842e0..cacddf2610 100644
--- a/areal/infra/remote_inf_engine.py
+++ b/areal/infra/remote_inf_engine.py
@@ -111,10 +111,24 @@ async def arun_episode(
                 isinstance(v, InteractionWithTokenLogpReward)
                 for v in first.values()
             )
         ):
-            # Merge dicts - each result is {completion_id: InteractionWithTokenLogpReward}
-            merged: dict[str, InteractionWithTokenLogpReward] = {}
+            # [MARL] Partition interactions by norm_group.
+            agent_interactions: dict[str, dict[str, InteractionWithTokenLogpReward]] = {}  # {norm_group: {key: interaction}}
             for result in valid_results:
-                merged.update(result)
+                for key, interaction in result.items():
+                    if getattr(interaction, "norm_group", None) is not None:
+                        group_id = interaction.norm_group
+                    else:
+                        group_id = "group_1"
+
+                    if group_id not in agent_interactions:
+                        agent_interactions[group_id] = {}
+                    agent_interactions[group_id][key] = interaction
+
+            # Merge in sorted order of group_id so the ordering is deterministic:
+            # all of agent 1's interactions come first, then agent 2's, then agent 3's, etc.
+            merged: dict[str, InteractionWithTokenLogpReward] = {}
+            for group_id in sorted(agent_interactions.keys()):
+                merged.update(agent_interactions[group_id])
 
             return merged if merged else None
 
     # Otherwise, tensor dicts - concatenate
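The merge above keeps each agent's completions contiguous before they reach group-wise reward normalization. A self-contained sketch of the same semantics, using a stand-in dataclass rather than the real interaction type:

```python
# Stand-alone illustration of the [MARL] merge: interactions without a
# norm_group fall back to "group_1", and groups are emitted in sorted order.
from dataclasses import dataclass


@dataclass
class FakeInteraction:
    norm_group: str | None = None


results = [
    {"c1": FakeInteraction("group_2"), "c2": FakeInteraction("group_1")},
    {"c3": FakeInteraction(None)},  # no group -> "group_1"
]

groups: dict[str, dict[str, FakeInteraction]] = {}
for result in results:
    for key, it in result.items():
        groups.setdefault(it.norm_group or "group_1", {})[key] = it

merged = {}
for gid in sorted(groups):
    merged.update(groups[gid])

print(list(merged))  # ['c2', 'c3', 'c1'] - group_1's completions, then group_2's
```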
diff --git a/areal/reward/__init__.py b/areal/reward/__init__.py
index 7396261bdd..346e285fe1 100644
--- a/areal/reward/__init__.py
+++ b/areal/reward/__init__.py
@@ -2,13 +2,33 @@
 from math_verify.metric import math_metric
 from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig
-
+from math_verify.errors import TimeoutException
+import re
 from areal.utils import logging
 
 logger = logging.getLogger("RewardUtils")
 
 VALID_REWARD_FN = ["clevr_count_70k", "geometry3k"]
 
+_BOX = re.compile(r"\\boxed\s*\{")
+
+def _canon_gold(gt: str) -> str:
+    gt = (gt or "").strip()
+    if not gt:
+        return gt
+    # If already boxed, keep it.
+    if _BOX.search(gt):
+        return gt
+    # Wrap the gold answer in a strong LaTeX "answer target".
+    return f"\\boxed{{{gt}}}"
+
+def _canon_pred(resp: str) -> str:
+    resp = (resp or "").strip()
+    # Drop the thinking prefix: keep only the text after the last </think> tag.
+    if "</think>" in resp:
+        resp = resp.split("</think>")[-1].strip()
+    return resp
+
 
 def get_custom_reward_fn(path: str, **kwargs):
     if "clevr_count_70k" in path:
@@ -25,7 +45,6 @@ def get_custom_reward_fn(path: str, **kwargs):
             f"Supported reward functions are: {VALID_REWARD_FN}. "
         )
 
-
 class MathVerifyWorker:
     """Thin wrapper over math_verify with configurable extraction/precision.
@@ -56,20 +75,162 @@ def __init__(self, try_extract_without_anchor=True, precision: int = 6):
             precision=precision,
         )
 
-    def verify(self, response: str, ground_truth: str) -> float:
-        # ground_truth_parsable = "\\boxed{" + ground_truth + "}"
+    def verify_for_math500(self, response: str, ground_truth: str) -> float:
+        try:
+            # _canon_gold ensures the ground truth has the form "\\boxed{" + ground_truth + "}".
+            gt = _canon_gold(ground_truth)
+            resp = _canon_pred(response)
+            ret_score, _ = self.verify_func([gt], [resp])
+            return float(ret_score)
+        except (Exception, TimeoutException) as e:  # TimeoutException inherits from BaseException, not Exception
+            logger.warning(
+                f"Exception {e} in MathVerifyWorker.verify_for_math500 for response={response} and ground_truth={ground_truth}",
+                exc_info=True,
+            )
+            return 0.0
+
+    def verify(self, response: str, ground_truth: str) -> float:
+        # For gsm8k: assumes the ground truth is already parsable,
+        # i.e. ground_truth = "\\boxed{" + answer + "}".
         try:
             ret_score, _ = self.verify_func([ground_truth], [response])
             return float(ret_score)
-        except Exception:
+        except (Exception, TimeoutException) as e:  # TimeoutException inherits from BaseException
             logger.warning(
-                f"Exception in MathVerifyWorker.verify for response={response} and ground_truth={ground_truth}",
+                f"Exception {e} in MathVerifyWorker.verify for response={response} and ground_truth={ground_truth}",
                 exc_info=True,
             )
             return 0.0
+
+
+# Math MC Verifier
+def _canon_gold_mc(gt: str) -> str:
+    gt = (gt or "").strip()
+    if not gt:
+        return gt
+    # For MC, gold is usually A/B/C/D/E (or 5-choice variants). Keep it simple.
+    gt = gt.upper()
+    # Handle datasets that store "(E)", "E)", etc.
+    m = re.search(r"\b([A-E])\b", gt)
+    if m:
+        gt = m.group(1)
+    # Wrap it in a strong LaTeX "answer target" so extraction is consistent.
+    if not _BOX.search(gt):
+        gt = f"\\boxed{{{gt}}}"
+    return gt
+
+def _canon_pred_mc(resp: str) -> str:
+    resp = (resp or "").strip()
+    if not resp:
+        return resp
+
+    # If already boxed, keep as-is.
+    if _BOX.search(resp):
+        return resp
+
+    # Common patterns: "Answer: E", "(E)", "E.", "The answer is (E)", "choice E".
+    # Prefer an explicit letter if present.
+    m = re.search(r"(?i)\b(?:answer|final|choice)\b.*?\b([A-E])\b", resp)
+    if m:
+        return f"\\boxed{{{m.group(1).upper()}}}"
+
+    # Otherwise look for a standalone MC letter token.
+    # (Guard against matching 'A' inside words by requiring boundaries.)
+    m = re.search(r"(?i)(?:^|[\s\(\[\{])([A-E])(?:$|[\s\)\]\}\.\,\;\:\!])", resp)
+    if m:
+        return f"\\boxed{{{m.group(1).upper()}}}"
+
+    # As a last resort, leave the response unchanged (math_verify may still extract something).
+    return resp
+
+
+class MathMultipleChoiceVerifyWorker:
+    """Verifier for multiple-choice math datasets (A/B/C/D/E style).
+
+    This mirrors MathVerifyWorker but uses MC-specific canonicalization/extraction.
+    It still leverages math_verify for robust parsing; we just normalize outputs
+    so the extractor reliably finds the selected option.
+
+    Args:
+        try_extract_without_anchor: If False, answer anchors are required. For MC,
+            leaving this True is usually best because model outputs vary a lot.
+        precision: Kept for API compatibility; not important for letter matching.
+        choices: String of valid choice letters.
+    """
+
+    def __init__(
+        self,
+        try_extract_without_anchor: bool = True,
+        precision: int = 6,
+        choices: str = "ABCDE",
+    ):
+        self.choices = "".join(sorted(set(choices.upper())))
+        # We still use math_verify; the _canon_*_mc helpers steer extraction to boxed letters.
+        self.verify_func = math_metric(
+            gold_extraction_target=(
+                ExprExtractionConfig(
+                    try_extract_without_anchor=try_extract_without_anchor
+                ),
+                LatexExtractionConfig(),
+            ),
+            pred_extraction_target=(
+                ExprExtractionConfig(
+                    try_extract_without_anchor=try_extract_without_anchor
+                ),
+                LatexExtractionConfig(),
+            ),
+            precision=precision,
+        )
+
+    def _normalize_gold(self, ground_truth: str) -> str:
+        gt = _canon_gold_mc(ground_truth)
+
+        choice_set = self.choices if self.choices else "A-E"
+        pattern = rf"\b([{choice_set}])\b"
+
+        m = re.search(pattern, gt.upper())
+        if m and m.group(1) in self.choices:
+            return f"\\boxed{{{m.group(1)}}}"
+        return gt
+
+    def _normalize_pred(self, response: str) -> str:
+        resp = _canon_pred_mc(response)
+
+        choice_set = self.choices if self.choices else "A-E"
+        pattern = rf"\\boxed\{{([{choice_set}])\}}"
+
+        # Case-insensitive search: _canon_pred_mc emits lowercase "\boxed{...}".
+        m = re.search(pattern, resp, flags=re.IGNORECASE)
+        if m and m.group(1).upper() in self.choices:
+            return f"\\boxed{{{m.group(1).upper()}}}"
+        return resp
+
+    def verify(self, response: str, ground_truth: str) -> float:
+        try:
+            gt = self._normalize_gold(ground_truth)
+            resp = self._normalize_pred(response)
+
+            # Primary path: score with math_verify.
+            ret_score, _ = self.verify_func([gt], [resp])
+            score = float(ret_score)
+
+            # Hard fallback: direct letter comparison (useful if parsing fails).
+            if score == 0.0:
+                gt_letter = re.search(r"\\BOXED\{([A-E])\}", gt.upper())
+                pr_letter = re.search(r"\\BOXED\{([A-E])\}", resp.upper())
+                if gt_letter and pr_letter:
+                    return 1.0 if gt_letter.group(1) == pr_letter.group(1) else 0.0
+
+            return score
+
+        except (Exception, TimeoutException) as e:  # TimeoutException inherits from BaseException
+            logger.warning(
+                f"Exception {e} in MathMultipleChoiceVerifyWorker.verify for response={response} and ground_truth={ground_truth}",
+                exc_info=True,
+            )
+            return 0.0
 
 _MATH_VERIFY_WORKER: MathVerifyWorker | None = None
+_MATH_MC_VERIFY_WORKER: MathMultipleChoiceVerifyWorker | None = None
 
 
 def get_math_verify_worker() -> MathVerifyWorker:
@@ -78,12 +239,19 @@
         _MATH_VERIFY_WORKER = MathVerifyWorker()
     return _MATH_VERIFY_WORKER
 
+def get_math_mc_verify_worker() -> MathMultipleChoiceVerifyWorker:
+    global _MATH_MC_VERIFY_WORKER
+    if _MATH_MC_VERIFY_WORKER is None:
+        _MATH_MC_VERIFY_WORKER = MathMultipleChoiceVerifyWorker()
+    return _MATH_MC_VERIFY_WORKER
 
 __all__ = [
     "VALID_REWARD_FN",
     "get_custom_reward_fn",
     "MathVerifyWorker",
     "get_math_verify_worker",
+    "MathMultipleChoiceVerifyWorker",
+    "get_math_mc_verify_worker",
     "gsm8k_reward_fn",
     "geometry3k_reward_fn",
     "clevr_count_70k_reward_fn",
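A few illustrative checks for the MC verifier. The exact math_verify score for bare letter symbols can vary with its parser, but the letter-comparison fallback keeps these cases stable:

```python
# Sanity checks for MathMultipleChoiceVerifyWorker (illustrative inputs only).
from areal.reward import get_math_mc_verify_worker

worker = get_math_mc_verify_worker()
print(worker.verify("The answer is (E).", "E"))  # expected: 1.0
print(worker.verify("\\boxed{C}", "(C)"))        # expected: 1.0
print(worker.verify("I think it is B.", "D"))    # expected: 0.0
```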
diff --git a/areal/reward/amc12.py b/areal/reward/amc12.py
new file mode 100644
index 0000000000..105a2e75b5
--- /dev/null
+++ b/areal/reward/amc12.py
@@ -0,0 +1,56 @@
+import re
+
+from areal.reward import get_math_mc_verify_worker
+from areal.utils import logging
+
+logger = logging.getLogger("AMC12Reward")
+
+
+def extract_mc_answer(text: str) -> str:
+    """
+    Extract a multiple-choice answer (A/B/C/D/E) from text.
+    Prefers explicit answer-style patterns, then falls back to
+    standalone letter detection.
+    """
+    if not text:
+        return ""
+
+    text = str(text).upper()
+
+    # Pattern 1: "Answer: E", "Final answer is D", etc.
+    m = re.search(r"\b(?:ANSWER|FINAL|CHOICE)\b.*?\b([A-E])\b", text)
+    if m:
+        return m.group(1)
+
+    # Pattern 2: boxed format \boxed{E}
+    m = re.search(r"\\BOXED\s*\{\s*([A-E])\s*\}", text)
+    if m:
+        return m.group(1)
+
+    # Pattern 3: standalone letter token
+    m = re.search(r"(?:^|[\s\(\[\{])([A-E])(?:$|[\s\)\]\}\.\,\;\:\!])", text)
+    if m:
+        return m.group(1)
+
+    return ""
+
+
+def amc12_mc_reward_fn(
+    prompt, completions, prompt_ids, completion_ids, answer, **kwargs
+) -> float:
+    """
+    Simple 0/1 reward for AMC12 multiple-choice questions.
+    """
+    try:
+        pred_letter = extract_mc_answer(str(completions))
+        gold_letter = extract_mc_answer(str(answer)) or str(answer).strip().upper()
+
+        if not pred_letter or not gold_letter:
+            return 0.0
+
+        worker = get_math_mc_verify_worker()
+        return worker.verify(pred_letter, gold_letter)
+
+    except Exception:
+        logger.warning("Exception in amc12_mc_reward_fn", exc_info=True)
+        return 0.0
diff --git a/areal/reward/math500.py b/areal/reward/math500.py
new file mode 100644
index 0000000000..de8a52dcfc
--- /dev/null
+++ b/areal/reward/math500.py
@@ -0,0 +1,15 @@
+from areal.reward import get_math_verify_worker
+from areal.utils import logging
+
+logger = logging.getLogger("MATH500_Reward")
+
+
+def math500_reward_fn(
+    prompt, completions, prompt_ids, completion_ids, answer, **kwargs
+) -> float:
+    try:
+        worker = get_math_verify_worker()
+        return worker.verify_for_math500(str(completions), str(answer))
+    except Exception:
+        logger.warning("Exception in math500_reward_fn", exc_info=True)
+        return 0.0
diff --git a/areal/utils/episode_data.py b/areal/utils/episode_data.py
new file mode 100644
index 0000000000..61cd5e5762
--- /dev/null
+++ b/areal/utils/episode_data.py
@@ -0,0 +1,18 @@
+import copy
+from typing import Any, Mapping
+
+
+def clone_episode_data(data: Mapping[str, Any]) -> dict[str, Any]:
+    """
+    Clone per-episode mutable inputs (especially chat `messages`) to avoid
+    cross-episode mutation when rollouts run concurrently.
+
+    This intentionally deep-copies only `messages` (if present) and shallow-copies the
+    rest of the mapping, avoiding heavy or unsafe deep copies of tensors or other objects.
+    """
+    cloned = dict(data)
+    if "messages" in data:
+        cloned["messages"] = copy.deepcopy(data["messages"])
+    return cloned
+
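Why the shallow-plus-messages clone matters when rollouts run concurrently, in miniature:

```python
# Appending to `messages` in one episode must not leak into another episode
# that was handed the same underlying dict.
from areal.utils.episode_data import clone_episode_data

shared = {"messages": [{"role": "user", "content": "2+2?"}], "answer": "4"}

episode = clone_episode_data(shared)
episode["messages"].append({"role": "assistant", "content": "4"})

assert len(shared["messages"]) == 1   # the original is untouched
assert len(episode["messages"]) == 2  # the clone evolved independently
```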
diff --git a/examples/openai_agents/config_marti_grpo-amc.yaml b/examples/openai_agents/config_marti_grpo-amc.yaml
new file mode 100644
index 0000000000..cfa626f24f
--- /dev/null
+++ b/examples/openai_agents/config_marti_grpo-amc.yaml
@@ -0,0 +1,194 @@
+experiment_name: multi-agents-amc12-grpo-marti-COA
+trial_name: trial0
+
+seed: 1
+enable_offload: false
+total_train_epochs: 100
+tokenizer_path: ${actor.path}
+
+reward_fn_path: "areal.reward.amc12.amc12_mc_reward_fn"
+
+cluster:
+  n_nodes: 1
+  n_gpus_per_node: 16
+  fileroot: /path/to/experiments # Please update this path to your local experiment directory
+  name_resolve:
+    type: nfs
+    nfs_record_root: /path/to/name_resolve # Please update this path to your local NFS record directory
+
+allocation_mode: vllm:d8p1t1+d8p1t1
+
+scheduler:
+  type: local
+
+rollout:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  max_concurrent_rollouts: 256
+  queue_size: null
+  consumer_batch_size: ${train_dataset.batch_size}
+  max_head_offpolicyness: 2
+  enable_rollout_tracing: false
+  scheduling_spec: ${actor.scheduling_spec}
+  fileroot: ${cluster.fileroot}
+  tokenizer_path: ${tokenizer_path}
+  dump_to_file: true
+  openai:
+    # Token limit for the whole session (prompt plus all turns of responses).
+    engine_max_tokens: 16384
+
+gconfig:
+  n_samples: 4
+  min_new_tokens: 0
+  max_tokens: 1600
+  greedy: false
+  temperature: 1.0
+  top_k: 100000000
+  top_p: 1.0
+
+actor:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  path: /path/to/model/Qwen2.5-3B-Instruct # Please update this path to your local model weights directory
+  dtype: bfloat16
+  init_from_scratch: false
+  disable_dropout: true
+  gradient_checkpointing: true
+  mb_spec:
+    max_tokens_per_mb: 16384
+  optimizer:
+    type: adam
+    lr: 1.0e-5
+    weight_decay: 0.017
+    beta1: 0.9
+    beta2: 0.999
+    eps: 1e-8
+    lr_scheduler_type: constant
+    gradient_clipping: 1.0
+    warmup_steps_proportion: 0.001
+  eps_clip: 0.2
+  temperature: ${gconfig.temperature}
+  reward_scaling: 10.0
+  reward_bias: -0.5
+  kl_ctl: 0.0
+  ppo_n_minibatches: 1
+  recompute_logprob: true
+  use_decoupled_loss: true
+  behave_imp_weight_cap: 5.0
+  reward_norm:
+    mean_level: group
+    std_level: group
+    group_size: ${gconfig.n_samples}
+  adv_norm:
+    mean_level: batch
+    std_level: batch
+  weight_update_mode: xccl
+  max_new_tokens: ${gconfig.max_new_tokens}
+  scheduling_spec:
+    - task_type: worker
+      port_count: 2
+      gpu: 1
+      mem: 32
+      cmd: python3 -m areal.infra.rpc.rpc_server
+      env_vars:
+        VLLM_USE_V1: "1"
+        VLLM_ALLOW_LONG_MAX_MODEL_LEN: "1"
+        HCCL_EXEC_TIMEOUT: "14400"
+        HCCL_OP_EXPANSION_MODE: "HOST"
+        ACL_DEVICE_SYNC_TIMEOUT: "14400"
+        HCCL_EVENT_TIMEOUT: "14500"
+        HCCL_ASYNC_ERROR_HANDLING: "0"
+        ACL_STREAM_TIMEOUT: "14500000"
+        HCCL_CONNECT_TIMEOUT: "7200"
+        PYTORCH_NPU_ALLOC_CONF: "expandable_segments:True"
+        HCCL_IF_BASE_PORT: "60000"
+        HCCL_HOST_SOCKET_PORT_RANGE: "60000-60099"
+        HCCL_NPU_SOCKET_PORT_RANGE: "60000-60099"
+
+
+ref:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  path: ${actor.path}
+  init_from_scratch: false
+  disable_dropout: true
+  dtype: ${actor.dtype}
+  mb_spec:
+    max_tokens_per_mb: 10240
+  optimizer: null
+  scheduling_strategy:
+    type: colocation
+    target: actor
+  scheduling_spec: ${actor.scheduling_spec}
+
+vllm:
+  model: ${actor.path}
+  seed: ${seed}
+  skip_tokenizer_init: false
+  dtype: ${actor.dtype}
+  max_model_len: 32768
+  gpu_memory_utilization: 0.8
+  enforce_eager: false
+
+# datasets
+train_dataset:
+  batch_size: 256
+  shuffle: true
+  pin_memory: true
+  num_workers: 4
+  path: /path/to/data/amc12 # Please update this path to your local training dataset directory
+  type: rl
+  max_length: 1024
+
+valid_dataset:
+  batch_size: 256
+  pin_memory: true
+  num_workers: 4
+  path: /path/to/data/amc12 # Please update this path to your local validation dataset directory
+  type: rl
+
+# Utilities
+saver:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: null
+
+recover:
+  mode: disabled
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: 3600
+
+evaluator:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: null
+
+stats_logger:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  wandb:
+    mode: disabled
+  # SwanLab integration is optional and can be configured here
+  # swanlab:
+  #   project: areal-marl-project # Please update as needed
+  #   name: marti_COA_grpo_shared_amc12
+  #   mode: cloud
+  #   api_key: PLS-FILL-UR-KEY
+
+perf_tracer:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  enabled: false
+  session_tracer:
+    enabled: false
\ No newline at end of file
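All three example configs point `reward_fn_path` at a dotted import string that the workflow resolves at runtime. A one-liner to validate such a path before launching a run:

```python
# reward_fn_path must be an importable "module.attr" string; this fails fast
# if the YAML value has a typo.
from areal.utils.dynamic_import import import_from_string

fn = import_from_string("areal.reward.amc12.amc12_mc_reward_fn")
print(fn.__name__)  # amc12_mc_reward_fn
```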
diff --git a/examples/openai_agents/config_marti_grpo-gsm8k.yaml b/examples/openai_agents/config_marti_grpo-gsm8k.yaml
new file mode 100644
index 0000000000..9b1ba26dd6
--- /dev/null
+++ b/examples/openai_agents/config_marti_grpo-gsm8k.yaml
@@ -0,0 +1,195 @@
+experiment_name: multi-agents-gsm8k-grpo-marti-COA
+trial_name: trial0
+
+seed: 1
+enable_offload: false
+total_train_epochs: 100
+tokenizer_path: ${actor.path}
+
+reward_fn_path: "areal.reward.gsm8k.gsm8k_reward_fn"
+
+cluster:
+  n_nodes: 1
+  n_gpus_per_node: 16
+  fileroot: /path/to/experiments # Please update this path to your local experiment directory
+  name_resolve:
+    type: nfs
+    nfs_record_root: /path/to/name_resolve # Please update this path to your local NFS record directory
+
+allocation_mode: vllm:d8p1t1+d8p1t1
+
+scheduler:
+  type: local
+
+rollout:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  max_concurrent_rollouts: 256
+  queue_size: null
+  consumer_batch_size: ${train_dataset.batch_size}
+  max_head_offpolicyness: 2
+  enable_rollout_tracing: false
+  scheduling_spec: ${actor.scheduling_spec}
+  fileroot: ${cluster.fileroot}
+  tokenizer_path: ${tokenizer_path}
+  dump_to_file: true
+  openai:
+    # Token limit for the whole session (prompt plus all turns of responses).
+    engine_max_tokens: 16384
+
+gconfig:
+  n_samples: 4
+  min_new_tokens: 0
+  max_tokens: 1600
+  greedy: false
+  temperature: 1.0
+  top_k: 100000000
+  top_p: 1.0
+
+actor:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  path: /path/to/model/Qwen2.5-3B-Instruct # Please update this path to your local model weights directory
+  dtype: bfloat16
+  init_from_scratch: false
+  disable_dropout: true
+  gradient_checkpointing: true
+  mb_spec:
+    max_tokens_per_mb: 16384
+  optimizer:
+    type: adam
+    lr: 1.0e-5
+    weight_decay: 0.017
+    beta1: 0.9
+    beta2: 0.999
+    eps: 1e-8
+    lr_scheduler_type: constant
+    gradient_clipping: 1.0
+    warmup_steps_proportion: 0.001
+  eps_clip: 0.2
+  temperature: ${gconfig.temperature}
+  reward_scaling: 10.0
+  reward_bias: -0.5
+  kl_ctl: 0.0
+  ppo_n_minibatches: 1
+  recompute_logprob: true
+  use_decoupled_loss: true
+  behave_imp_weight_cap: 5.0
+  reward_norm:
+    mean_level: group
+    std_level: group
+    group_size: ${gconfig.n_samples}
+  adv_norm:
+    mean_level: batch
+    std_level: batch
+  weight_update_mode: xccl
+  max_new_tokens: ${gconfig.max_new_tokens}
+  scheduling_spec:
+    - task_type: worker
+      port_count: 2
+      gpu: 1
+      mem: 32
+      cmd: python3 -m areal.infra.rpc.rpc_server
+      env_vars:
+        VLLM_USE_V1: "1"
+        VLLM_ALLOW_LONG_MAX_MODEL_LEN: "1"
+        HCCL_EXEC_TIMEOUT: "14400"
+        HCCL_OP_EXPANSION_MODE: "HOST"
+        ACL_DEVICE_SYNC_TIMEOUT: "14400"
+        HCCL_EVENT_TIMEOUT: "14500"
+        HCCL_ASYNC_ERROR_HANDLING: "0"
+        ACL_STREAM_TIMEOUT: "14500000"
+        HCCL_CONNECT_TIMEOUT: "7200"
+        PYTORCH_NPU_ALLOC_CONF: "expandable_segments:True"
+        HCCL_IF_BASE_PORT: "60000"
+        HCCL_HOST_SOCKET_PORT_RANGE: "60000-60099"
+        HCCL_NPU_SOCKET_PORT_RANGE: "60000-60099"
+
+
+ref:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  path: ${actor.path}
+  init_from_scratch: false
+  disable_dropout: true
+  dtype: ${actor.dtype}
+  mb_spec:
+    max_tokens_per_mb: 10240
+  optimizer: null
+  scheduling_strategy:
+    type: colocation
+    target: actor
+  scheduling_spec: ${actor.scheduling_spec}
+
+vllm:
+  model: ${actor.path}
+  seed: ${seed}
+  skip_tokenizer_init: false
+  dtype: ${actor.dtype}
+  max_model_len: 32768
+  gpu_memory_utilization: 0.8
+  enforce_eager: false
+
+# datasets
+train_dataset:
+  batch_size: 256
+  shuffle: true
+  pin_memory: true
+  num_workers: 4
+  path: /path/to/data/gsm8k # Please update this path to your local training dataset directory
+  type: rl
+  max_length: 1024
+
+valid_dataset:
+  batch_size: 256
+  pin_memory: true
+  num_workers: 4
+  path: /path/to/data/gsm8k # Please update this path to your local validation dataset directory
+  type: rl
+
+# Utilities
+saver:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: null
+
+recover:
+  mode: disabled
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: 3600
+
+evaluator:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: null
+
+stats_logger:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  wandb:
+    mode: disabled
+  # SwanLab integration is optional and can be configured here
+  # swanlab:
+  #   project: areal-marl-project # Please update as needed
+  #   name: marti_COA_grpo_shared_gsm8k
+  #   mode: cloud
+  #   api_key: PLS-FILL-UR-KEY
+
+perf_tracer:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  enabled: false
+  session_tracer:
+    enabled: false
\ No newline at end of file
diff --git a/examples/openai_agents/config_marti_grpo-math500.yaml b/examples/openai_agents/config_marti_grpo-math500.yaml
new file mode 100644
index 0000000000..4a56924608
--- /dev/null
+++ b/examples/openai_agents/config_marti_grpo-math500.yaml
@@ -0,0 +1,195 @@
+experiment_name: multi-agents-math500-grpo-marti-COA
+trial_name: trial0
+
+seed: 1
+enable_offload: false
+total_train_epochs: 100
+tokenizer_path: ${actor.path}
+
+reward_fn_path: "areal.reward.math500.math500_reward_fn"
+
+cluster:
+  n_nodes: 1
+  n_gpus_per_node: 16
+  fileroot: /path/to/experiments # Please update this path to your local experiment directory
+  name_resolve:
+    type: nfs
+    nfs_record_root: /path/to/name_resolve # Please update this path to your local NFS record directory
+
+allocation_mode: vllm:d8p1t1+d8p1t1
+
+scheduler:
+  type: local
+
+rollout:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  max_concurrent_rollouts: 256
+  queue_size: null
+  consumer_batch_size: ${train_dataset.batch_size}
+  max_head_offpolicyness: 2
+  enable_rollout_tracing: false
+  scheduling_spec: ${actor.scheduling_spec}
+  fileroot: ${cluster.fileroot}
+  tokenizer_path: ${tokenizer_path}
+  dump_to_file: true
+  openai:
+    # Token limit for the whole session (prompt plus all turns of responses).
+    engine_max_tokens: 16384
+
+gconfig:
+  n_samples: 4
+  min_new_tokens: 0
+  max_tokens: 1600
+  greedy: false
+  temperature: 1.0
+  top_k: 100000000
+  top_p: 1.0
+
+actor:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  path: /path/to/model/Qwen2.5-3B-Instruct # Please update this path to your local model weights directory
+  dtype: bfloat16
+  init_from_scratch: false
+  disable_dropout: true
+  gradient_checkpointing: true
+  mb_spec:
+    max_tokens_per_mb: 16384
+  optimizer:
+    type: adam
+    lr: 1.0e-5
+    weight_decay: 0.017
+    beta1: 0.9
+    beta2: 0.999
+    eps: 1e-8
+    lr_scheduler_type: constant
+    gradient_clipping: 1.0
+    warmup_steps_proportion: 0.001
+  eps_clip: 0.2
+  temperature: ${gconfig.temperature}
+  reward_scaling: 10.0
+  reward_bias: -0.5
+  kl_ctl: 0.0
+  ppo_n_minibatches: 1
+  recompute_logprob: true
+  use_decoupled_loss: true
+  behave_imp_weight_cap: 5.0
+  reward_norm:
+    mean_level: group
+    std_level: group
+    group_size: ${gconfig.n_samples}
+  adv_norm:
+    mean_level: batch
+    std_level: batch
+  weight_update_mode: xccl
+  max_new_tokens: ${gconfig.max_new_tokens}
+  scheduling_spec:
+    - task_type: worker
+      port_count: 2
+      gpu: 1
+      mem: 32
+      cmd: python3 -m areal.infra.rpc.rpc_server
+      env_vars:
+        VLLM_USE_V1: "1"
+        VLLM_ALLOW_LONG_MAX_MODEL_LEN: "1"
+        HCCL_EXEC_TIMEOUT: "14400"
+        HCCL_OP_EXPANSION_MODE: "HOST"
+        ACL_DEVICE_SYNC_TIMEOUT: "14400"
+        HCCL_EVENT_TIMEOUT: "14500"
+        HCCL_ASYNC_ERROR_HANDLING: "0"
+        ACL_STREAM_TIMEOUT: "14500000"
+        HCCL_CONNECT_TIMEOUT: "7200"
+        PYTORCH_NPU_ALLOC_CONF: "expandable_segments:True"
+        HCCL_IF_BASE_PORT: "60000"
+        HCCL_HOST_SOCKET_PORT_RANGE: "60000-60099"
+        HCCL_NPU_SOCKET_PORT_RANGE: "60000-60099"
+
+
+ref:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  path: ${actor.path}
+  init_from_scratch: false
+  disable_dropout: true
+  dtype: ${actor.dtype}
+  mb_spec:
+    max_tokens_per_mb: 10240
+  optimizer: null
+  scheduling_strategy:
+    type: colocation
+    target: actor
+  scheduling_spec: ${actor.scheduling_spec}
+
+vllm:
+  model: ${actor.path}
+  seed: ${seed}
+  skip_tokenizer_init: false
+  dtype: ${actor.dtype}
+  max_model_len: 32768
+  gpu_memory_utilization: 0.8
+  enforce_eager: false
+
+# datasets
+train_dataset:
+  batch_size: 256
+  shuffle: true
+  pin_memory: true
+  num_workers: 4
+  path: /path/to/data/MATH-500 # Please update this path to your local training dataset directory
+  type: rl
+  max_length: 1024
+
+valid_dataset:
+  batch_size: 256
+  pin_memory: true
+  num_workers: 4
+  path: /path/to/data/MATH-500 # Please update this path to your local validation dataset directory
+  type: rl
+
+# Utilities
+saver:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: null
+
+recover:
+  mode: disabled
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: 3600
+
+evaluator:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: null
+
+stats_logger:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  wandb:
+    mode: disabled
+  # SwanLab integration is optional and can be configured here
+  # swanlab:
+  #   project: areal-marl-project # Please update as needed
+  #   name: marti_COA_grpo_shared_math500
+  #   mode: cloud
+  #   api_key: PLS-FILL-UR-KEY
+
+perf_tracer:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  enabled: false
+  session_tracer:
+    enabled: false
\ No newline at end of file
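Each agent stage tags its interactions with a `norm_group`, and `reward_norm` (group-level mean/std with `group_size: ${gconfig.n_samples}`) then normalizes within each contiguous group. A sketch of that normalization with hypothetical rewards:

```python
# Hypothetical per-group rewards across n_samples=4 rollouts; GRPO-style group
# normalization is applied within each norm_group, not across agents.
rewards = {
    "group_1": [1.0, 0.0, 1.0, 1.0],  # generator
    "group_2": [1.0, 0.0, 1.0, 1.0],  # verifier
    "group_3": [1.0, 0.0, 1.0, 1.0],  # refiner
}

for group, rs in rewards.items():
    mean = sum(rs) / len(rs)
    std = (sum((r - mean) ** 2 for r in rs) / len(rs)) ** 0.5
    print(group, [round((r - mean) / (std + 1e-6), 3) for r in rs])
```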
diff --git a/examples/openai_agents/train_math_marti_shared.py b/examples/openai_agents/train_math_marti_shared.py
new file mode 100644
index 0000000000..37882f8a77
--- /dev/null
+++ b/examples/openai_agents/train_math_marti_shared.py
@@ -0,0 +1,284 @@
+from dataclasses import dataclass, field
+from string import Template
+from typing import Any
+
+from agents import Agent as OpenAIAgent
+from agents import ModelSettings, OpenAIProvider, RunConfig
+from agents import Runner as OpenAIRunner
+from transformers import PreTrainedTokenizerFast
+
+from areal import PPOTrainer, workflow_context
+from areal.api import AsyncRewardWrapper, RolloutWorkflow
+from areal.api.cli_args import GenerationHyperparameters, GRPOConfig, load_expr_config
+from areal.dataset import get_custom_dataset
+from areal.experimental.openai import ArealOpenAI
+from areal.utils import logging, stats_tracker
+from areal.utils.dynamic_import import import_from_string
+from areal.utils.episode_data import clone_episode_data
+from areal.utils.hf_utils import load_hf_tokenizer
+
+
+logger = logging.getLogger("OpenAIAgentWorkflow")
+
+
+GENERATOR_TEMPLATE = (
+    "$question\n\nPlease reason step by step, and put your final answer within \\boxed{}."
+)
+VERIFIER_TEMPLATE = (
+    "You are tasked with analyzing an answer to a problem and providing constructive feedback.\n\n"
+    "Problem: $question\n\n"
+    "Solution: $generator\n\n"
+    "Do NOT provide direct solutions."
+)
+REFINER_TEMPLATE = (
+    "You are tasked with revising a draft solution to a problem based on the critique provided. "
+    "Please provide a revised solution that addresses the feedback and improves the overall quality of the solution.\n\n"
+    "Problem: $question\n\n"
+    "Solution: $generator\n\n"
+    "Critique: $verifier\n\n"
+    "Please reason step by step, and put your final answer within \\boxed{}."
+)
+
+
+def render_template(template: str, **kwargs: str) -> str:
+    return Template(template).safe_substitute(**kwargs)
+
+
+def apply_template_with_tokenizer(
+    tokenizer: PreTrainedTokenizerFast,
+    prompt: str,
+    tools: list[dict[str, Any]] | None = None,
+    enable_thinking: bool = False,
+) -> str:
+    kwargs: dict[str, Any] = {
+        "tokenize": False,
+        "add_generation_prompt": True,
+        "enable_thinking": enable_thinking,
+    }
+    if tools is not None:
+        kwargs["tools"] = tools
+    return tokenizer.apply_chat_template([{"role": "user", "content": prompt}], **kwargs)
+
+
+class OpenAIAgentWrapper:
+    def __init__(
+        self,
+        agent_name: str,
+        reward_fn_path: str,
+        is_last: bool = False,
+        temperature: float = 1.0,
+        max_tokens: int = 512,
+    ):
+        self.agent = OpenAIAgent(name=agent_name)
+        self.is_last = is_last
+
+        self.async_reward_fn: AsyncRewardWrapper | None = None
+        if self.is_last and reward_fn_path:
+            self.async_reward_fn = AsyncRewardWrapper(import_from_string(reward_fn_path))
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+
+    async def run_agent(self, data: dict[str, Any], client: ArealOpenAI) -> tuple[str, float]:
+        run_config = RunConfig(
+            model_provider=OpenAIProvider(openai_client=client),
+            tracing_disabled=True,
+            model_settings=ModelSettings(
+                temperature=self.temperature,
+                max_tokens=self.max_tokens,
+            ),
+        )
+        try:
+            result = await OpenAIRunner.run(
+                self.agent, input=data["messages"][-1]["content"], run_config=run_config
+            )
+        except Exception as e:
+            logger.error(f"!! Agent {self.agent.name} inference failed: {e}")
+            return "Error: Inference failed", 0.0
+
+        reward = 0.0
+        if self.is_last and self.async_reward_fn:
+            reward = await self.async_reward_fn(
+                completions=result.final_output,
+                answer=data["answer"],
+                prompt=data.get("prompt"),
+                prompt_ids=data.get("prompt_ids"),
+                completion_ids=data.get("completion_ids"),
+                **{
+                    k: v
+                    for k, v in data.items()
+                    if k
+                    not in ["messages", "answer", "prompt", "prompt_ids", "completion_ids"]
+                },
+            )
+            client.set_last_reward(reward)
+
+        return result.final_output, reward
+
+
+class OpenAIAgentWorkflow(RolloutWorkflow):
+    def __init__(
+        self,
+        reward_fn_path: str,
+        gconfig: GenerationHyperparameters,
+        tokenizer: PreTrainedTokenizerFast | str,
+    ):
+        if isinstance(tokenizer, str):
+            tokenizer = load_hf_tokenizer(tokenizer)
+        self.gconfig = gconfig.new_with_stop_and_pad_token_ids(tokenizer)
+        self.tokenizer = tokenizer
+
+        # generator agent
+        self.generator = OpenAIAgentWrapper(
+            temperature=gconfig.temperature,
+            max_tokens=gconfig.max_tokens,
+            agent_name="generator",
+            reward_fn_path=reward_fn_path,
+            is_last=False,
+        )
+        # verifier agent
+        self.verifier = OpenAIAgentWrapper(
+            temperature=gconfig.temperature,
+            max_tokens=gconfig.max_tokens,
+            agent_name="verifier",
+            reward_fn_path=reward_fn_path,
+            is_last=False,
+        )
+        # refiner agent
+        self.refiner = OpenAIAgentWrapper(
+            temperature=gconfig.temperature,
+            max_tokens=gconfig.max_tokens,
+            agent_name="refiner",
+            reward_fn_path=reward_fn_path,
+            is_last=True,
+        )
+
+    def _set_norm_group_if_missing(self, client: ArealOpenAI, norm_group: str) -> None:
+        for interaction in client._cache.values():
+            if interaction.norm_group is None:
+                interaction.norm_group = norm_group
+
+    async def _run_stage(
+        self,
+        *,
+        wrapper: OpenAIAgentWrapper,
+        client: ArealOpenAI,
+        data: dict[str, Any],
+        prompt: str,
+        norm_group: str,
+    ) -> tuple[str, float]:
+        input_prompt = apply_template_with_tokenizer(client.tokenizer, prompt)
+        data["messages"].append({"role": "user", "content": input_prompt})
"user", "content": input_prompt}) + output, reward = await wrapper.run_agent(data=data, client=client) + output = output.strip() + data["messages"].append({"role": "assistant", "content": output}) + self._set_norm_group_if_missing(client, norm_group) + return output, reward + + async def arun_episode(self, engine, data: dict[str, Any]): + data = clone_episode_data(data) + client = ArealOpenAI( + engine=engine, tokenizer=self.tokenizer, tool_call_parser="qwen25" + ) + + if not data.get("messages"): + raise ValueError("Expected episode data to contain non-empty 'messages'.") + + question = data["messages"][-1]["content"] + + generator_prompt = render_template(GENERATOR_TEMPLATE, question=question) + generator_output, _ = await self._run_stage( + wrapper=self.generator, + client=client, + data=data, + prompt=generator_prompt, + norm_group="group_1", + ) + + verifier_prompt = render_template( + VERIFIER_TEMPLATE, question=question, generator=generator_output + ) + verifier_output, _ = await self._run_stage( + wrapper=self.verifier, + client=client, + data=data, + prompt=verifier_prompt, + norm_group="group_2", + ) + + refiner_prompt = render_template( + REFINER_TEMPLATE, + question=question, + generator=generator_output, + verifier=verifier_output, + ) + refiner_output, final_reward = await self._run_stage( + wrapper=self.refiner, + client=client, + data=data, + prompt=refiner_prompt, + norm_group="group_3", + ) + logger.debug("Refiner output: %s", refiner_output) + + # equally distribute rewards + for interaction in client._cache.values(): + interaction.reward = final_reward + + stats_tracker.get(workflow_context.stat_scope()).scalar(reward=final_reward) + + # client.apply_reward_discount(turn_discount=0.9) + interactions_with_reward = client.export_interactions(style="individual") + return interactions_with_reward + + +@dataclass +class AgentRLConfig(GRPOConfig): + reward_fn_path: str = field( + default="areal.reward.gsm8k.gsm8k_reward_fn", + metadata={ + "help": "The path to the reward function. Should follow the API in `areal/api/reward_api.py`." + }, + ) + + +def main(args): + config, _ = load_expr_config(args, AgentRLConfig) + + tokenizer = load_hf_tokenizer(config.tokenizer_path) + + train_dataset = get_custom_dataset( + split="train", + dataset_config=config.train_dataset, + tokenizer=tokenizer, + ) + + valid_dataset = get_custom_dataset( + split="test", + dataset_config=config.valid_dataset, + tokenizer=tokenizer, + ) + + workflow_kwargs = dict( + reward_fn_path=config.reward_fn_path, + gconfig=config.gconfig, + tokenizer=config.tokenizer_path, + ) + eval_workflow_kwargs = workflow_kwargs.copy() + + with PPOTrainer( + config, + train_dataset=train_dataset, + valid_dataset=valid_dataset, + ) as trainer: + trainer.train( + workflow="examples.openai_agents.train_math_marti_shared.OpenAIAgentWorkflow", + workflow_kwargs=workflow_kwargs, + eval_workflow="examples.openai_agents.train_math_marti_shared.OpenAIAgentWorkflow", + eval_workflow_kwargs=eval_workflow_kwargs, + ) + + +if __name__ == "__main__": + import sys + + main(sys.argv[1:])