Skip to content

ShareableXue/mini-trl

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

11 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Mini-TRL: 简化版后训练框架

一个教育性质的后训练(Post-Training)框架,旨在用最简洁的代码帮助初学者理解TRL的核心思想。

📚 项目目标

本项目从 HuggingFace TRL 中提取核心逻辑,保留最本质的算法实现,剔除生产级别的复杂性,帮助你理解:

  1. SFT (Supervised Fine-Tuning): 监督微调的本质
  2. Reward Model: 如何训练奖励模型
  3. DPO (Direct Preference Optimization): 直接偏好优化的数学原理
  4. GRPO (Group Relative Policy Optimization): 群组相对策略优化

🏗️ 项目结构

mini-trl/
├── mini_trl/
│   ├── __init__.py
│   ├── configs/                    # 配置类
│   │   ├── __init__.py
│   │   ├── sft_config.py
│   │   ├── dpo_config.py
│   │   ├── reward_config.py
│   │   └── grpo_config.py
│   ├── trainers/                   # 训练器
│   │   ├── __init__.py
│   │   ├── base_trainer.py         # 基础训练器
│   │   ├── sft_trainer.py          # SFT训练器
│   │   ├── reward_trainer.py       # 奖励模型训练器
│   │   ├── dpo_trainer.py          # DPO训练器
│   │   └── grpo_trainer.py         # GRPO训练器
│   ├── data/                       # 数据处理
│   │   ├── __init__.py
│   │   ├── data_utils.py           # 数据工具函数
│   │   └── collators.py            # DataCollator
│   ├── models/                     # 模型工具
│   │   ├── __init__.py
│   │   └── utils.py                # 参考模型创建等
│   └── rewards/                    # 奖励函数
│       ├── __init__.py
│       └── basic_rewards.py        # 基础奖励函数
├── examples/                       # 示例脚本
│   ├── sft_example.py
│   ├── reward_example.py
│   ├── dpo_example.py
│   └── grpo_example.py
├── notebooks/                      # Jupyter教程
│   ├── 01_sft_tutorial.ipynb
│   ├── 02_reward_model.ipynb
│   ├── 03_dpo_explained.ipynb
│   └── 04_grpo_explained.ipynb
├── pyproject.toml                  # 包配置 (pip install -e .)
├── requirements.txt
└── README.md

🏛️ 架构设计

所有训练器继承自 BaseTrainer,共享通用功能:

            transformers.Trainer
                    ↑
              BaseTrainer (共用功能)
                    ↑
    ┌───────┬───────┼───────┬────────┐
    │       │       │       │        │
  SFT     DPO   Reward   GRPO    ...

BaseTrainer 提供:

  • 模型加载: load_model(), load_tokenizer() 静态方法
  • Dropout控制: disable_dropout() 用于偏好学习
  • 指标记录: record_metric(), record_metrics() 统一的日志系统
  • 日志输出: 重写 log() 自动平均并清除指标
# 示例:使用 BaseTrainer 的静态方法
from mini_trl import BaseTrainer

model = BaseTrainer.load_model("Qwen/Qwen2.5-0.5B")
tokenizer = BaseTrainer.load_tokenizer("Qwen/Qwen2.5-0.5B")
BaseTrainer.disable_dropout(model)  # DPO/Reward 推荐

🔑 核心概念

1. SFT (Supervised Fine-Tuning)

最基础的微调方法,本质是语言模型的继续预训练:

Loss = -log P(completion | prompt)

2. Reward Model

训练一个模型来判断回答的好坏:

Loss = -log σ(r(chosen) - r(rejected))

3. DPO (Direct Preference Optimization)

直接用偏好数据优化策略,无需训练奖励模型:

Loss = -log σ(β * (log π(chosen)/π_ref(chosen) - log π(rejected)/π_ref(rejected)))

其中:

  • π: 当前策略(正在训练的模型)
  • π_ref: 参考策略(冻结的初始模型)
  • β: 控制偏离参考模型的程度

4. GRPO (Group Relative Policy Optimization)

生成多个回答,用组内相对奖励进行优化:

1. 对每个prompt生成G个completion
2. 计算每个completion的reward
3. 计算组内normalized reward: Â = (r - mean(r)) / std(r)
4. Loss = -Â * log π(completion) + β * KL(π || π_ref)

📊 算法对比

方法 需要奖励模型 需要生成 数据格式 复杂度
SFT prompt + completion
Reward chosen + rejected
DPO chosen + rejected
GRPO ✅ (或函数) prompt + reward_fn

🚀 快速开始

安装

方式一:可编辑安装(推荐)

cd mini-trl
pip install -e .

安装后可以在任何地方使用:

from mini_trl import SFTTrainer, DPOTrainer, BaseTrainer

方式二:仅安装依赖

pip install -r requirements.txt
# 需要在 mini-trl 目录下运行脚本

SFT 示例

from mini_trl import SFTTrainer, SFTConfig
from datasets import load_dataset

dataset = load_dataset("trl-lib/Capybara", split="train[:1000]")

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    args=SFTConfig(output_dir="./sft_output"),
    train_dataset=dataset,
)
trainer.train()

DPO 示例

from mini_trl import DPOTrainer, DPOConfig
from datasets import load_dataset

dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train[:1000]")

trainer = DPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    args=DPOConfig(output_dir="./dpo_output", beta=0.1),
    train_dataset=dataset,
)
trainer.train()

GRPO 示例

from mini_trl import GRPOTrainer, GRPOConfig
from datasets import load_dataset

def accuracy_reward(completions, **kwargs):
    # 自定义奖励函数
    return [1.0 if "正确" in c else 0.0 for c in completions]

dataset = load_dataset("your-dataset", split="train")

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=accuracy_reward,
    args=GRPOConfig(output_dir="./grpo_output"),
    train_dataset=dataset,
)
trainer.train()

📖 学习路线

  1. 理解SFT,阅读 sft_trainer.py
  2. 理解Reward Model,阅读 reward_trainer.py
  3. 深入DPO算法,阅读论文 + dpo_trainer.py
  4. 掌握GRPO/PPO,阅读 grpo_trainer.py

📄 参考论文

🙏 致谢

本项目基于 HuggingFace TRL 简化而来,感谢TRL团队的开源贡献。

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors