diff --git a/torchtitan/grpo/test/gsm8k_server.py b/torchtitan/grpo/test/gsm8k_server.py
new file mode 100644
index 0000000000..c9b6c17042
--- /dev/null
+++ b/torchtitan/grpo/test/gsm8k_server.py
@@ -0,0 +1,396 @@
+import random
+import time
+from typing import Dict, List, Optional, Tuple, TypedDict, Union
+
+from datasets import load_dataset
+from latex2sympy2_extended import NormalizationConfig
+from math_verify import LatexExtractionConfig, parse, verify
+from tqdm.asyncio import tqdm_asyncio
+
+from atroposlib.envs.base import (
+    APIServerConfig,
+    BaseEnv,
+    BaseEnvConfig,
+    ScoredDataGroup,
+)
+from atroposlib.type_definitions import Item
+
+system_prompt = (
+    "You are a deep thinking AI, you may use extremely long chains of thought "
+    "to deeply consider the problem and deliberate with yourself via systematic "
+    "reasoning processes to help come to a correct solution prior to answering. "
+    "You should enclose your thoughts and internal monologue inside <think> "
+    "</think> tags, and then provide your solution or response to the problem.\n\n"
+)
+
+system_prompt += """You are allocated a maximum of 2048 tokens, please strive to use less.
+
+You will then provide your answer like this: \\boxed{your answer here}
+It is important that you provide your answer in the correct format.
+If you do not, you will not receive credit for your answer.
+So please end your answer with \\boxed{your answer here}"""
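+
+# For illustration only (this block is not used by the code): a completion
+# that follows the prompt above looks like
+#     <think> 4 pencils + 5 pencils = 9 pencils ... </think>
+#     The answer is \boxed{9}
+# i.e. free-form reasoning inside <think></think> tags, then a final \boxed{}
+# answer that the scoring code below extracts with math_verify.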
item["question"], + "gold_answer": item["answer"] + .split("#")[-1] + .strip() + .replace(",", ""), + } + ) + self.iter = 0 + + def save_checkpoint(self, step, data=None): + if data is None: + data = {} + data["iter"] = self.iter + super().save_checkpoint(step, data) + + async def rollout_and_score_eval(self, question: str, answer: str) -> dict: + """Rollout and score evaluation with detailed sample data collection.""" + + async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + completion = await managed.chat_completion( + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": question}, + ], + n=1, + max_tokens=self.config.max_token_length, + temperature=0.6, + ) + + response_content = completion.choices[0].message.content + + # Parse gold answer + gold_parsed = parse( + "\\boxed{" + answer + "}", + extraction_mode="first_match", + extraction_config=[LatexExtractionConfig()], + ) + + # Parse model answer + answer_parsed = parse( + response_content.split("")[-1], + extraction_config=[ + LatexExtractionConfig( + normalization_config=NormalizationConfig( + nits=False, + malformed_operators=False, + basic_latex=True, + equations=True, + boxed="all", + units=True, + ), + boxed_match_priority=0, + try_extract_without_anchor=False, + ) + ], + extraction_mode="first_match", + ) + + score = 1 if verify(answer_parsed, gold_parsed) else 0 + + sample = { + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": question}, + {"role": "assistant", "content": response_content}, + ], + "question": question, + "gold_answer": answer, + "gold_parsed": str(gold_parsed) if gold_parsed else None, + "model_parsed": str(answer_parsed) if answer_parsed else None, + "score": int(score), + "correct": bool(score), + "finish_reason": completion.choices[0].finish_reason, + "response_after_think": ( + response_content.split("")[-1] + if "" in response_content + else response_content + ), + } + + return {"score": score, "sample": sample} + + async def evaluate(self, *args, **kwargs): + start_time = time.time() + + eval_tasks = [] + for item in self.test: + eval_tasks.append( + self.rollout_and_score_eval(item["question"], item["gold_answer"]) + ) + results = await tqdm_asyncio.gather(*eval_tasks) + + # Extract scores and samples + scores = [result["score"] for result in results] + samples = [result["sample"] for result in results] + + percent_correct = sum(scores) / len(scores) + + end_time = time.time() + + # Add to existing metrics for wandb + self.eval_metrics.append(("eval/percent_correct", percent_correct)) + + # Log evaluation results + eval_metrics = { + "eval/percent_correct": percent_correct, + } + + await self.evaluate_log( + metrics=eval_metrics, + samples=samples, + start_time=start_time, + end_time=end_time, + generation_parameters={ + "temperature": 0.0, + "max_tokens": self.config.max_token_length, + }, + ) + + async def collect_trajectories( + self, item: GSM8kRow + ) -> Tuple[ScoredDataGroup, list[Item]]: + print(f"DEBUG: collect_trajectories() called for question: {item['question'][:80]}...") + try: + user_message = {"role": "user", "content": item["question"]} + gold_answer = ( + "\\boxed{" + item["answer"].split("#")[-1].strip().replace(",", "") + "}" + ) + + print(f"DEBUG: About to call managed_server...") + async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + print(f"DEBUG: Inside managed_server context, about to call chat_completion...") + + chat_completions = await 
+
+    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
+        if wandb_metrics is None:
+            wandb_metrics = {}
+
+        # Try to calculate percent_correct, pass if there's a division by zero
+        try:
+            wandb_metrics["train/percent_correct"] = sum(
+                self.percent_correct_buffer
+            ) / len(self.percent_correct_buffer)
+        except ZeroDivisionError:
+            # Skip if buffer is empty
+            pass
+
+        self.percent_correct_buffer = list()
+        for item in self.eval_metrics:
+            wandb_metrics[item[0]] = item[1]
+        self.eval_metrics = list()
+        # Call the parent method to handle the server metrics
+        await super().wandb_log(wandb_metrics)
+
+    async def setup(self):
+        self.train = load_dataset("gsm8k", "main", split="train").shuffle(seed=42)
+        test_data = load_dataset("gsm8k", "main", split="test").shuffle(seed=42)
+        self.test = list()
+        for item in test_data:
+            self.test.append(
+                {
+                    "question": item["question"],
+                    # GSM8k answers end with "#### <number>", e.g. "#### 1,080";
+                    # keep only the final number, with commas stripped.
+                    "gold_answer": item["answer"]
+                    .split("#")[-1]
+                    .strip()
+                    .replace(",", ""),
+                }
+            )
+        self.iter = 0
+
+    def save_checkpoint(self, step, data=None):
+        if data is None:
+            data = {}
+        data["iter"] = self.iter
+        super().save_checkpoint(step, data)
+
+    async def rollout_and_score_eval(self, question: str, answer: str) -> dict:
+        """Rollout and score evaluation with detailed sample data collection."""
+
+        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+            completion = await managed.chat_completion(
+                messages=[
+                    {"role": "system", "content": system_prompt},
+                    {"role": "user", "content": question},
+                ],
+                n=1,
+                max_tokens=self.config.max_token_length,
+                temperature=0.6,
+            )
+
+        response_content = completion.choices[0].message.content
+
+        # Parse gold answer
+        gold_parsed = parse(
+            "\\boxed{" + answer + "}",
+            extraction_mode="first_match",
+            extraction_config=[LatexExtractionConfig()],
+        )
+
+        # Parse model answer
+        answer_parsed = parse(
+            response_content.split("</think>")[-1],
+            extraction_config=[
+                LatexExtractionConfig(
+                    normalization_config=NormalizationConfig(
+                        nits=False,
+                        malformed_operators=False,
+                        basic_latex=True,
+                        equations=True,
+                        boxed="all",
+                        units=True,
+                    ),
+                    boxed_match_priority=0,
+                    try_extract_without_anchor=False,
+                )
+            ],
+            extraction_mode="first_match",
+        )
+
+        score = 1 if verify(answer_parsed, gold_parsed) else 0
+
+        sample = {
+            "messages": [
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": question},
+                {"role": "assistant", "content": response_content},
+            ],
+            "question": question,
+            "gold_answer": answer,
+            "gold_parsed": str(gold_parsed) if gold_parsed else None,
+            "model_parsed": str(answer_parsed) if answer_parsed else None,
+            "score": int(score),
+            "correct": bool(score),
+            "finish_reason": completion.choices[0].finish_reason,
+            "response_after_think": (
+                response_content.split("</think>")[-1]
+                if "</think>" in response_content
+                else response_content
+            ),
+        }
+
+        return {"score": score, "sample": sample}
+
+    async def evaluate(self, *args, **kwargs):
+        start_time = time.time()
+
+        eval_tasks = []
+        for item in self.test:
+            eval_tasks.append(
+                self.rollout_and_score_eval(item["question"], item["gold_answer"])
+            )
+        results = await tqdm_asyncio.gather(*eval_tasks)
+
+        # Extract scores and samples
+        scores = [result["score"] for result in results]
+        samples = [result["sample"] for result in results]
+
+        percent_correct = sum(scores) / len(scores)
+
+        end_time = time.time()
+
+        # Add to existing metrics for wandb
+        self.eval_metrics.append(("eval/percent_correct", percent_correct))
+
+        # Log evaluation results
+        eval_metrics = {
+            "eval/percent_correct": percent_correct,
+        }
+
+        await self.evaluate_log(
+            metrics=eval_metrics,
+            samples=samples,
+            start_time=start_time,
+            end_time=end_time,
+            generation_parameters={
+                "temperature": 0.6,
+                "max_tokens": self.config.max_token_length,
+            },
+        )
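+
+    # Worked example of the scoring path above (illustrative; the exact parse
+    # output depends on the installed math_verify version): for the gold
+    # answer "72", parse("\\boxed{72}", ...) extracts the value 72; a response
+    # ending in "... so the total is \\boxed{72}" parses to the same value,
+    # so verify(answer_parsed, gold_parsed) is True and the sample scores 1.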
+
+    async def collect_trajectories(
+        self, item: GSM8kRow
+    ) -> Tuple[ScoredDataGroup, list[Item]]:
+        print(
+            f"DEBUG: collect_trajectories() called for question: {item['question'][:80]}..."
+        )
+        try:
+            user_message = {"role": "user", "content": item["question"]}
+            gold_answer = (
+                "\\boxed{" + item["answer"].split("#")[-1].strip().replace(",", "") + "}"
+            )
+
+            print("DEBUG: About to call managed_server...")
+            async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
+                print("DEBUG: Inside managed_server context, about to call chat_completion...")
+
+                chat_completions = await managed.chat_completion(
+                    messages=[{"role": "system", "content": system_prompt}, user_message],
+                    n=self.config.group_size,
+                    max_tokens=self.config.max_token_length,
+                    temperature=1.0,
+                )
+                print(
+                    f"DEBUG: chat_completion returned, got {len(chat_completions.choices)} completions"
+                )
+
+                state = managed.get_state()
+                nodes = state["nodes"]
+                print(f"DEBUG: Got state with {len(nodes)} nodes")
+        except Exception as e:
+            print(f"ERROR in collect_trajectories: {type(e).__name__}: {e}")
+            import traceback
+
+            traceback.print_exc()
+            return None, []
+
+        to_score = list()
+        to_backlog = list()
+        for i, chat_completion in enumerate(chat_completions.choices):
+            messages = (
+                {"role": "system", "content": system_prompt},
+                user_message,
+                {"role": "assistant", "content": chat_completion.message.content},
+            )
+            to_score.append(
+                {
+                    "messages": messages,
+                    "gold_answer": gold_answer,
+                    "finish_reason": chat_completion.finish_reason,
+                    "tokens": nodes[i].tokens,
+                    "masks": nodes[i].masked_tokens,
+                    "logprobs": nodes[i].logprobs,
+                }
+            )
+        to_postprocess = await self.score(to_score)
+        return to_postprocess, to_backlog
+
+    async def score(
+        self, rollout_group_data
+    ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
+        print(f"DEBUG: score() called with {len(rollout_group_data)} rollouts")
+        scores = ScoredDataGroup()
+        scores["tokens"] = list()
+        scores["masks"] = list()
+        scores["scores"] = list()
+        scores["inference_logprobs"] = list()
+        gold_parsed = parse(
+            rollout_group_data[0]["gold_answer"],
+            extraction_mode="first_match",
+            extraction_config=[LatexExtractionConfig()],
+        )
+        print(f"DEBUG: Gold answer parsed: {len(gold_parsed)} elements")
+        if len(gold_parsed) != 0:
+            # We require the answer to be provided in correct latex (no malformed operators)
+            random.shuffle(rollout_group_data)
+            for item in rollout_group_data:
+                answer_parsed = parse(
+                    item["messages"][-1]["content"].split("</think>")[-1],
+                    extraction_config=[
+                        LatexExtractionConfig(
+                            normalization_config=NormalizationConfig(
+                                nits=False,
+                                malformed_operators=False,
+                                basic_latex=True,
+                                equations=True,
+                                boxed="all",
+                                units=True,
+                            ),
+                            # Ensures that boxed is tried first
+                            boxed_match_priority=0,
+                            try_extract_without_anchor=False,
+                        )
+                    ],
+                    extraction_mode="first_match",
+                )
+                # Reward 1 if the content is the same as the ground truth, 0 otherwise
+                reward = verify(answer_parsed, gold_parsed)
+
+                tokens = item["tokens"]
+                masks = item["masks"]
+                logprobs = item["logprobs"]
+
+                # Remove obviously bad examples
+                num_valid_tokens = len([1 for i in masks if i != -100])
+                if num_valid_tokens < 5:  # Lowered from 10 to 5 to be less strict
+                    print(f"Filtering out sample with only {num_valid_tokens} valid tokens")
+                    continue
+                scores["tokens"].append(tokens)
+                scores["masks"].append(masks)
+                scores["inference_logprobs"].append(logprobs)
+                scores["scores"].append(1.0 if reward else -1.0)
+
+                if len(scores["tokens"]) >= self.config.group_size:
+                    break
+
+            # Check if we have enough valid samples after filtering
+            if len(scores["tokens"]) < self.config.group_size:
+                print(
+                    f"Warning: Only got {len(scores['tokens'])} samples after filtering, "
+                    f"need {self.config.group_size}"
+                )
+                return None
+
+            for score in scores["scores"]:
+                self.percent_correct_buffer.append(max(score, 0))
+
+            # If every rollout in the group is correct, replace the flat +1
+            # reward with a length-based reward that prefers shorter answers.
+            if all([score == 1 for score in scores["scores"]]):
+                # Do length penalty :)
+                token_lengths = [len(token) for token in scores["tokens"]]
+                if max(token_lengths) == 0:
+                    # Should not happen, but don't crash the run just in case.
+                    return None
+
+                # Get max allowed token length from config
+                max_allowed_length = self.config.max_token_length
+                # Set threshold at 50% of max_token_length - no penalty below this
+                length_threshold = max_allowed_length * 0.5
+
+                # Apply modified length penalty with threshold
+                scores["scores"] = []
+                for length in token_lengths:
+                    if length <= length_threshold:
+                        # No penalty for responses under threshold
+                        scores["scores"].append(1.0)
+                    else:
+                        # How far we are between threshold and max, as a fraction
+                        percentage_of_range = (length - length_threshold) / (
+                            max_allowed_length - length_threshold
+                        )
+                        # Cap at 1.0 in case length exceeds max_allowed_length
+                        percentage_of_range = min(percentage_of_range, 1.0)
+                        # Apply linear penalty scaling from 1.0 down to 0.0
+                        scores["scores"].append(1.0 - percentage_of_range)
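+
+                # Worked example: with max_token_length=2048 the threshold is
+                # 1024 tokens. A 900-token correct answer keeps score 1.0; a
+                # 1536-token one is (1536 - 1024) / (2048 - 1024) = 0.5 of the
+                # way through the penalty range, so it scores 1.0 - 0.5 = 0.5;
+                # at 2048 tokens or more the score bottoms out at 0.0.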
+
+            # Allow training even when all scores are identical
+            # if all([scores["scores"][0] == score for score in scores["scores"]]):
+            #     return None  # If all the same, we return None
+            print(
+                f"DEBUG: Returning scores with {len(scores['tokens'])} samples, "
+                f"scores: {scores['scores']}"
+            )
+            return scores
+        else:
+            # If the gold solution is not parseable, we return None
+            print("DEBUG: Gold solution not parseable, returning None")
+            return None
+
+    async def get_next_item(self) -> GSM8kRow:
+        next_item = self.train[self.iter % len(self.train)]
+        self.iter += 1
+        return next_item
+
+
+if __name__ == "__main__":
+    GSM8kEnv.cli()
diff --git a/torchtitan/grpo/test/test_config.toml b/torchtitan/grpo/test/test_config.toml
new file mode 100644
index 0000000000..e5ca401c41
--- /dev/null
+++ b/torchtitan/grpo/test/test_config.toml
@@ -0,0 +1,100 @@
+# torchtitan config.toml - GSM8k Test Configuration
+# Test configuration for Qwen2.5-7B on the GSM8k environment
+
+[job]
+dump_folder = "/tmp/gsm8k_test_run"
+description = "gsm8k_qwen2.5_7b_test"
+
+[profiling]
+enable_profiling = false
+save_traces_folder = "profile_trace"
+profile_freq = 100
+
+[metrics]
+log_freq = 1
+enable_tensorboard = false
+save_tb_folder = "tb"
+enable_wandb = true
+
+[model]
+name = "qwen2"
+flavor = "7B"
+tokenizer_path = "Qwen/Qwen2.5-7B"
+
+[optimizer]
+name = "AdamW"
+lr = 1e-6
+beta1 = 0.9
+beta2 = 0.95
+weight_decay = 0.1
+
+[lr_scheduler]
+warmup_steps = 10
+decay_type = "linear"
+decay_ratio = 0.1
+
+[training]
+local_batch_size = 1
+seq_len = 2048
+global_batch_size = 32  # 8 rollouts * 4 gradient accumulation steps
+
+max_norm = 0.25  # grad norm clipping
+steps = 100  # test run - 100 steps to validate the pipeline
+
+[compile]
+enable = false
+components = ["model", "loss"]
+
+[parallelism]
+data_parallel_replicate_degree = 1
+data_parallel_shard_degree = -1  # use all available GPUs for data-parallel sharding
+tensor_parallel_degree = 1
+context_parallel_degree = 1
+pipeline_parallel_degree = 1
+
+[grpo]
+sglang_tp = 1
+sglang_urls = []
+sglang_slurm_num_nodes = 1
+sglang_port = 26756
+
+# GRPO hyperparameters
+logit_loss_weight = 0.0
+entropy_loss_weight = 0.0000
+kl_beta = 0.000
+kl_estimator_type = "k3"
+ref_model_ema = 0.999
+clip_ratio_lower_bound = 0.0003
+clip_ratio_upper_bound = 0.0004
+policy_ratio_type = "sequence"
+pos_scaler = 1.00
+neg_scaler = 1.00
+grpo_by_token = true
+scale_adv_by_len = false
+num_microbatches = 2
+onpolicy_logp_threshold = 0.0
+rollout_is_level = "sequence"
+rollout_is_mode = "truncate"
+rollout_is_threshold = 4.0
+
+# disabled for this test
+ptx_mixin_batchsize = 0
+ptx_scale_by_tokens = false
+
+[checkpoint]
+enable = true
+folder = "checkpoints"
+# Update this path to point to your converted Qwen2.5-7B checkpoint
+initial_load_path = "/home/shared/torchtitan-conversions/qwen_2-5_7b"
+initial_load_legacy = true
+interval = 50
+export_dtype = "float32"
+async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]
+
+[activation_checkpoint]
+mode = 'selective'  # ['none', 'selective', 'full']
+selective_ac_option = 'op'  # an int N = AC every Nth layer; 'op' = AC based on the ops policy
+
+[quantize.linear.float8]
+enable_fsdp_float8_all_gather = false
+precompute_float8_dynamic_scale_for_fsdp = false
diff --git a/torchtitan/grpo/test/test_full_rl.slurm b/torchtitan/grpo/test/test_full_rl.slurm
new file mode 100644
index 0000000000..f2513da7a5
--- /dev/null
+++ b/torchtitan/grpo/test/test_full_rl.slurm
@@ -0,0 +1,52 @@
+#!/bin/bash
+#SBATCH --job-name=grpo_full_test
+#SBATCH --output=logs/%j.out
+#SBATCH --error=logs/%j.err
+#SBATCH --nodes=2
+#SBATCH --ntasks-per-node=1
+#SBATCH --exclusive
+#SBATCH --gpus-per-task=8
+#SBATCH --cpus-per-task=64
+
+# Create logs directory
+mkdir -p logs/$SLURM_JOB_ID
+
+# Raise the open-file-descriptor limit
+ulimit -n 32000
+export LOGDIR="$(pwd)/logs/${SLURM_JOB_ID}"
+
+# Echo the allocated nodes
+echo "SLURM nodes: $SLURM_JOB_NODELIST"
+
+# Basic config - pointing to the test setup
+export CONFIG_FILE="$(pwd)/torchtitan/grpo/test/test_config.toml"
+export MODEL_NAME="/home/nightwing/Projects/torchtitan/tmp/qwen2.5-7b"
+export PYTHON_SCRIPT="$(pwd)/torchtitan/grpo/test/gsm8k_server.py"
+export PYTHON_ARGS=""
+export TRAINING_ARGS=""
+export NUM_TRAINING_NODES=1
+export NUM_INFERENCE_NODES=1
+
+# NCCL settings
+export NCCL_BUFFSIZE=33554432
+export CUDA_DEVICE_ORDER=PCI_BUS_ID
+export NCCL_IB_AR_THRESHOLD=0
+export NCCL_IB_PCI_RELAXED_ORDERING=1
+export NCCL_IB_QPS_PER_CONNECTION=2
+export NCCL_IB_SPLIT_DATA_ON_QPS=0
+export NCCL_IGNORE_CPU_AFFINITY=1
+
+# Define environment paths
+export TRAIN_PATH="$(pwd)"
+export TRAIN_ENV="/home/nightwing/miniconda3/envs/torchtitan/"
+export VLLM_ENV="/home/nightwing/miniconda3/envs/vllm/"
+export API_ENV="${TRAIN_ENV}"
+
+# Get head node info
+nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") )
+nodes_array=( "${nodes[@]}" )
+head_node=${nodes_array[0]}
+export head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
+
+# Use the proven working vllm_launch.sh script
+srun -l --export=ALL ./vllm_launch.sh
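+
+# To run this test, submit from the repo root with:
+#     sbatch torchtitan/grpo/test/test_full_rl.slurm
+# Logs land under logs/<job id>/ (api.log and env_server.log from
+# vllm_launch.sh, plus the SLURM stdout/stderr files configured above). The
+# conda env and model checkpoint paths above are specific to this test
+# machine and will need adjusting elsewhere.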
diff --git a/vllm_launch.sh b/vllm_launch.sh
index b84d3d3c0c..4a77950da5 100644
--- a/vllm_launch.sh
+++ b/vllm_launch.sh
@@ -30,12 +30,13 @@ if [[ "$SLURM_NODEID" -lt "$NUM_TRAINING_NODES" ]]; then
 
     # export NCCL_P2P_DISABLE=1
     # export NCCL_IB_DISABLE=1
+    export NCCL_IB_DISABLE=1
+    export NCCL_P2P_LEVEL=SYS
+
     # debugging flags (optional)
-    export NCCL_DEBUG=WARN
+    export NCCL_DEBUG=INFO
+    export NCCL_DEBUG_SUBSYS=NET
     export PYTHONFAULTHANDLER=1
-    # optional debug settings
-    # export NCCL_DEBUG=INFO
-    # NCCL_DEBUG_SUBSYS=INIT,GRAPH,ENV
 
     # export LD_LIBRARY_PATH=/opt/amazon/efa/lib:$LD_LIBRARY_PATH
     export LD_LIBRARY_PATH=/usr/local/lib/:$LD_LIBRARY_PATH
@@ -44,10 +45,10 @@
 
 # on your cluster you might need these:
 # set the network interface
-# export NCCL_SOCKET_IFNAME="eth0,en,eth,em,bond"
-# export NCCL_BUFFSIZE=2097152
-# export TORCH_DIST_INIT_BARRIER=1
-# export FI_EFA_SET_CUDA_SYNC_MEMOPS=0
+export NCCL_SOCKET_IFNAME=bond0
+export NCCL_BUFFSIZE=2097152
+export TORCH_DIST_INIT_BARRIER=1
+export FI_EFA_SET_CUDA_SYNC_MEMOPS=0
 
 # dcgmi profile --pause
 # adjust sbatch --ntasks and sbatch --nodes above and --nnodes below
@@ -69,6 +70,16 @@ else
 
     source ${VLLM_ENV}/bin/activate
 
+    # Set NCCL network settings for vLLM weight sync
+    export NCCL_SOCKET_IFNAME=bond0
+    export NCCL_IB_DISABLE=1
+    export NCCL_P2P_LEVEL=SYS
+    export NCCL_BUFFSIZE=2097152
+    export TORCH_DIST_INIT_BARRIER=1
+    export FI_EFA_SET_CUDA_SYNC_MEMOPS=0
+    export NCCL_DEBUG=INFO
+    export NCCL_DEBUG_SUBSYS=NET
+
     PORT_BASE=9000
 
     # Start 8 vllm instances on GPUs 0-3