Skip to content

Sunnysoonn/grpo_fromscrach

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

3 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

GRPO from Scratch

一个从零实现的 GRPO (Group Relative Policy Optimization) 强化学习训练框架,用于训练语言模型。本项目采用 Actor-Critic 架构,将采样和训练分离到不同的 GPU 上,实现高效的分布式强化学习训练。

📋 目录

🎯 项目简介

本项目实现了 GRPO (Group Relative Policy Optimization) 算法,这是一种用于强化学习从人类反馈 (RLHF) 训练语言模型的方法。与传统的 PPO 不同,GRPO 使用组内相对优势归一化,能够更稳定地训练语言模型。

主要应用场景:数学推理任务(GSM8K 数据集)

✨ 核心特性

  • 分离式架构:采样器(Actor)和训练器(Learner)独立运行,可充分利用多 GPU 资源
  • 热加载机制:采样器自动检测并加载最新的 LoRA 权重,保证策略实时更新
  • DeepSpeed ZeRO-2:使用 DeepSpeed 进行分布式训练,支持大模型训练
  • LoRA 微调:采用参数高效微调(PEFT),只训练约 3% 的参数
  • ZeroMQ 通信:高效的进程间通信,支持大规模并行采样
  • 双算法支持:同时实现 GRPO 和 GSPO 两种损失函数

📁 项目结构

grpo_fromscrach/
├── sampling_worker.py    # 采样工作器(Actor)- 负责生成轨迹
├── training_worker.py    # 训练工作器(Learner)- 负责模型更新
├── data.py               # 数据处理模块 - GSM8K 数据集加载
├── utils.py              # 工具函数 - GRPO/GSPO 损失函数、奖励函数
├── run.sh                # 启动脚本
├── ds_config.json        # DeepSpeed 配置文件
├── requirements.txt      # Python 依赖包列表
└── README.md             # 本文件

🔧 环境要求

硬件要求

  • GPU: 至少 3 张 GPU(推荐 2 张用于训练,1 张用于采样)
  • 显存: 每张 GPU 至少 8GB(取决于模型大小)

软件依赖

方式 1:使用 requirements.txt(推荐)

pip install -r requirements.txt

方式 2:手动安装

# Python 3.8+
pip install torch torchvision torchaudio
pip install transformers
pip install peft
pip install deepspeed
pip install datasets
pip install pyzmq
pip install accelerate

模型准备

本项目默认使用 Qwen/Qwen2.5-1.5B-Instruct 模型。你可以:

  1. 从 HuggingFace 自动下载(首次运行时会自动下载)
  2. 使用本地模型:将模型放在 ./Qwen2.5-1.5B-Instruct 目录下

🚀 快速开始

1. 克隆项目

git clone https://github.com/Sunnysoonn/grpo_fromscrach.git
cd grpo_fromscrach

2. 安装依赖

pip install -r requirements.txt

3. 运行训练

# 终端 1: 启动采样器
export CUDA_VISIBLE_DEVICES=2
python sampling_worker.py \
    --model_path "./Qwen2.5-1.5B-Instruct" \
    --lora_path "./checkpoints" \
    --batch_size 8 \
    --port 5555

# 终端 2: 启动训练器
export CUDA_VISIBLE_DEVICES=0,1
deepspeed --num_gpus=2 training_worker.py \
    --model_path "./Qwen2.5-1.5B-Instruct" \
    --save_path "./checkpoints" \
    --ds_config "ds_config.json" \
    --group_size 4

⚙️ 配置说明

采样器参数 (sampling_worker.py)

参数 默认值 说明
--model_path Qwen/Qwen2.5-1.5B-Instruct 基础模型路径
--lora_path ./checkpoints_lora LoRA 权重保存路径
--batch_size 8 采样批次大小
--max_new_tokens 256 最大生成长度
--device cuda:0 使用的 GPU 设备
--port 5555 ZeroMQ 通信端口

训练器参数 (training_worker.py)

参数 默认值 说明
--model_path Qwen/Qwen2.5-1.5B-Instruct 基础模型路径
--save_path ./checkpoints 模型保存路径
--ds_config ds_config.json DeepSpeed 配置文件
--port 5555 ZeroMQ 通信端口(需与采样器一致)
--save_steps 10 每 N 步保存一次检查点
--group_size 4 GRPO 组大小

DeepSpeed 配置 (ds_config.json)

主要配置项:

  • ZeRO Stage 2: 优化器状态分片
  • BF16 混合精度: 节省显存并加速训练
  • 梯度裁剪: 防止梯度爆炸
  • 学习率: 1e-6(可根据需要调整)

🧠 核心算法

GRPO (Group Relative Policy Optimization)

GRPO 是一种改进的强化学习算法,核心思想是:

  1. 组内归一化:将批次内的样本分组,在组内进行优势归一化

    normalized_advantage = (reward - group_mean) / (group_std + epsilon)
    
  2. Token 级重要性采样:在生成序列的每个 token 级别计算重要性采样比率

  3. PPO Clipping:使用 PPO 的裁剪机制防止策略更新过大

GSPO (Group Sequential Policy Optimization)

GSPO 是 GRPO 的变体,使用序列级几何平均重要性采样,更适合长序列生成。

奖励函数

本项目使用混合奖励函数:

  • 格式奖励 (+1.25): 检查输出是否符合 <think>...</think><answer>...</answer> 格式
  • 答案奖励 (+1.0): 检查答案是否正确

🏗️ 架构设计

Actor-Critic 分离架构

┌─────────────────┐         ZeroMQ         ┌─────────────────┐
│ Sampling Worker │ ──────────────────────> │ Training Worker │
│   (GPU 2)       │      (Port 5555)        │   (GPU 0,1)     │
│                 │                          │                 │
│ - 生成轨迹       │                          │ - 接收数据      │
│ - 计算 LogProb  │                          │ - 计算 Loss     │
│ - 计算 Reward    │                          │ - 更新模型      │
│ - 热加载 LoRA   │<─────────────────────────│ - 保存 LoRA     │
└─────────────────┘      (文件系统)          └─────────────────┘

工作流程

  1. 采样阶段(Sampling Worker):

    • 从数据集加载 prompt
    • 使用 Old Policy 生成响应
    • 计算 Old Log Prob 和 Ref Log Prob
    • 计算奖励
    • 通过 ZeroMQ 发送轨迹数据
  2. 训练阶段(Training Worker):

    • Rank 0 接收数据并广播给所有 rank
    • 使用 DeepSpeed 进行分布式训练
    • 计算 New Policy Log Prob
    • 计算 GRPO/GSPO Loss
    • 反向传播并更新模型
    • 定期保存 LoRA 权重
  3. 热加载机制

    • 采样器定期检查 LoRA 权重文件
    • 检测到更新后自动重新加载
    • 保证采样使用的是最新策略

📊 使用示例

监控训练进度

# 查看采样器日志
tail -f sampler.log

# 查看训练器输出
# 训练器会在控制台输出 Loss 和检查点信息

自定义奖励函数

编辑 utils.py 中的 reward_function 函数:

def reward_function(response: str, answer: str) -> Dict[str, float]:
    # 实现你的奖励逻辑
    format_reward = check_format(response)
    answer_reward = check_answer(response, answer)
    return {"reward": format_reward + answer_reward}

切换算法

training_worker.py 中修改损失函数调用:

# 使用 GRPO
loss = grpo_loss(...)

# 或使用 GSPO
loss = gspo_loss(...)

❓ 常见问题

Q1: 采样器和训练器无法通信?

A: 确保:

  • 采样器先启动(或同时启动)
  • 端口号一致(默认 5555)
  • 防火墙未阻止端口

Q2: 显存不足?

A: 尝试:

  • 减小 batch_size
  • 减小 max_new_tokens
  • 启用 DeepSpeed ZeRO-3(修改 ds_config.json
  • 使用更小的模型

Q3: LoRA 权重未更新?

A: 检查:

  • --save_path 路径是否正确
  • --lora_path 路径是否匹配
  • 文件系统权限

Q4: 训练 Loss 不下降?

A: 可能原因:

  • 学习率过大/过小(调整 ds_config.json 中的 lr
  • 奖励函数设计不合理
  • 组大小设置不当(调整 --group_size

📝 开发说明

代码结构

  • sampling_worker.py: 采样逻辑,包括模型加载、生成、LogProb 计算
  • training_worker.py: 训练逻辑,包括数据接收、分布式训练、模型保存
  • data.py: 数据集加载和预处理
  • utils.py: 核心算法实现(GRPO/GSPO Loss、奖励函数)

扩展建议

  1. 支持更多数据集:修改 data.py 添加新的数据集类
  2. 实现其他算法:在 utils.py 中添加新的损失函数
  3. 优化通信:使用更高效的序列化方法(如 MessagePack)
  4. 添加评估:定期在验证集上评估模型性能

🙏 致谢


注意:本项目为学习和研究目的实现,生产环境使用请谨慎评估。

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors