一个从零实现的 GRPO (Group Relative Policy Optimization) 强化学习训练框架,用于训练语言模型。本项目采用 Actor-Critic 架构,将采样和训练分离到不同的 GPU 上,实现高效的分布式强化学习训练。
本项目实现了 GRPO (Group Relative Policy Optimization) 算法,这是一种用于强化学习从人类反馈 (RLHF) 训练语言模型的方法。与传统的 PPO 不同,GRPO 使用组内相对优势归一化,能够更稳定地训练语言模型。
主要应用场景:数学推理任务(GSM8K 数据集)
- 分离式架构:采样器(Actor)和训练器(Learner)独立运行,可充分利用多 GPU 资源
- 热加载机制:采样器自动检测并加载最新的 LoRA 权重,保证策略实时更新
- DeepSpeed ZeRO-2:使用 DeepSpeed 进行分布式训练,支持大模型训练
- LoRA 微调:采用参数高效微调(PEFT),只训练约 3% 的参数
- ZeroMQ 通信:高效的进程间通信,支持大规模并行采样
- 双算法支持:同时实现 GRPO 和 GSPO 两种损失函数
grpo_fromscrach/
├── sampling_worker.py # 采样工作器(Actor)- 负责生成轨迹
├── training_worker.py # 训练工作器(Learner)- 负责模型更新
├── data.py # 数据处理模块 - GSM8K 数据集加载
├── utils.py # 工具函数 - GRPO/GSPO 损失函数、奖励函数
├── run.sh # 启动脚本
├── ds_config.json # DeepSpeed 配置文件
├── requirements.txt # Python 依赖包列表
└── README.md # 本文件
- GPU: 至少 3 张 GPU(推荐 2 张用于训练,1 张用于采样)
- 显存: 每张 GPU 至少 8GB(取决于模型大小)
方式 1:使用 requirements.txt(推荐)
pip install -r requirements.txt方式 2:手动安装
# Python 3.8+
pip install torch torchvision torchaudio
pip install transformers
pip install peft
pip install deepspeed
pip install datasets
pip install pyzmq
pip install accelerate本项目默认使用 Qwen/Qwen2.5-1.5B-Instruct 模型。你可以:
- 从 HuggingFace 自动下载(首次运行时会自动下载)
- 使用本地模型:将模型放在
./Qwen2.5-1.5B-Instruct目录下
git clone https://github.com/Sunnysoonn/grpo_fromscrach.git
cd grpo_fromscrachpip install -r requirements.txt# 终端 1: 启动采样器
export CUDA_VISIBLE_DEVICES=2
python sampling_worker.py \
--model_path "./Qwen2.5-1.5B-Instruct" \
--lora_path "./checkpoints" \
--batch_size 8 \
--port 5555
# 终端 2: 启动训练器
export CUDA_VISIBLE_DEVICES=0,1
deepspeed --num_gpus=2 training_worker.py \
--model_path "./Qwen2.5-1.5B-Instruct" \
--save_path "./checkpoints" \
--ds_config "ds_config.json" \
--group_size 4| 参数 | 默认值 | 说明 |
|---|---|---|
--model_path |
Qwen/Qwen2.5-1.5B-Instruct |
基础模型路径 |
--lora_path |
./checkpoints_lora |
LoRA 权重保存路径 |
--batch_size |
8 | 采样批次大小 |
--max_new_tokens |
256 | 最大生成长度 |
--device |
cuda:0 |
使用的 GPU 设备 |
--port |
5555 | ZeroMQ 通信端口 |
| 参数 | 默认值 | 说明 |
|---|---|---|
--model_path |
Qwen/Qwen2.5-1.5B-Instruct |
基础模型路径 |
--save_path |
./checkpoints |
模型保存路径 |
--ds_config |
ds_config.json |
DeepSpeed 配置文件 |
--port |
5555 | ZeroMQ 通信端口(需与采样器一致) |
--save_steps |
10 | 每 N 步保存一次检查点 |
--group_size |
4 | GRPO 组大小 |
主要配置项:
- ZeRO Stage 2: 优化器状态分片
- BF16 混合精度: 节省显存并加速训练
- 梯度裁剪: 防止梯度爆炸
- 学习率: 1e-6(可根据需要调整)
GRPO 是一种改进的强化学习算法,核心思想是:
-
组内归一化:将批次内的样本分组,在组内进行优势归一化
normalized_advantage = (reward - group_mean) / (group_std + epsilon) -
Token 级重要性采样:在生成序列的每个 token 级别计算重要性采样比率
-
PPO Clipping:使用 PPO 的裁剪机制防止策略更新过大
GSPO 是 GRPO 的变体,使用序列级几何平均重要性采样,更适合长序列生成。
本项目使用混合奖励函数:
- 格式奖励 (+1.25): 检查输出是否符合
<think>...</think><answer>...</answer>格式 - 答案奖励 (+1.0): 检查答案是否正确
┌─────────────────┐ ZeroMQ ┌─────────────────┐
│ Sampling Worker │ ──────────────────────> │ Training Worker │
│ (GPU 2) │ (Port 5555) │ (GPU 0,1) │
│ │ │ │
│ - 生成轨迹 │ │ - 接收数据 │
│ - 计算 LogProb │ │ - 计算 Loss │
│ - 计算 Reward │ │ - 更新模型 │
│ - 热加载 LoRA │<─────────────────────────│ - 保存 LoRA │
└─────────────────┘ (文件系统) └─────────────────┘
-
采样阶段(Sampling Worker):
- 从数据集加载 prompt
- 使用 Old Policy 生成响应
- 计算 Old Log Prob 和 Ref Log Prob
- 计算奖励
- 通过 ZeroMQ 发送轨迹数据
-
训练阶段(Training Worker):
- Rank 0 接收数据并广播给所有 rank
- 使用 DeepSpeed 进行分布式训练
- 计算 New Policy Log Prob
- 计算 GRPO/GSPO Loss
- 反向传播并更新模型
- 定期保存 LoRA 权重
-
热加载机制:
- 采样器定期检查 LoRA 权重文件
- 检测到更新后自动重新加载
- 保证采样使用的是最新策略
# 查看采样器日志
tail -f sampler.log
# 查看训练器输出
# 训练器会在控制台输出 Loss 和检查点信息编辑 utils.py 中的 reward_function 函数:
def reward_function(response: str, answer: str) -> Dict[str, float]:
# 实现你的奖励逻辑
format_reward = check_format(response)
answer_reward = check_answer(response, answer)
return {"reward": format_reward + answer_reward}在 training_worker.py 中修改损失函数调用:
# 使用 GRPO
loss = grpo_loss(...)
# 或使用 GSPO
loss = gspo_loss(...)A: 确保:
- 采样器先启动(或同时启动)
- 端口号一致(默认 5555)
- 防火墙未阻止端口
A: 尝试:
- 减小
batch_size - 减小
max_new_tokens - 启用 DeepSpeed ZeRO-3(修改
ds_config.json) - 使用更小的模型
A: 检查:
--save_path路径是否正确--lora_path路径是否匹配- 文件系统权限
A: 可能原因:
- 学习率过大/过小(调整
ds_config.json中的lr) - 奖励函数设计不合理
- 组大小设置不当(调整
--group_size)
sampling_worker.py: 采样逻辑,包括模型加载、生成、LogProb 计算training_worker.py: 训练逻辑,包括数据接收、分布式训练、模型保存data.py: 数据集加载和预处理utils.py: 核心算法实现(GRPO/GSPO Loss、奖励函数)
- 支持更多数据集:修改
data.py添加新的数据集类 - 实现其他算法:在
utils.py中添加新的损失函数 - 优化通信:使用更高效的序列化方法(如 MessagePack)
- 添加评估:定期在验证集上评估模型性能
- DeepSeek - GRPO 算法原始论文
- HuggingFace - Transformers 和 PEFT 库
- Microsoft DeepSpeed - 分布式训练框架
注意:本项目为学习和研究目的实现,生产环境使用请谨慎评估。