diff --git a/.claude/index.json b/.claude/index.json new file mode 100644 index 0000000000..5c13b5b468 --- /dev/null +++ b/.claude/index.json @@ -0,0 +1,196 @@ +{ + "project_name": "AReaL", + "description": "A Large-Scale Asynchronous Reinforcement Learning System for Language Reasoning", + "scan_timestamp": "2026-01-31T13:49:46+08:00", + "version": "0.3+", + "language": "Python", + "tech_stack": ["Python 3.12+", "PyTorch", "FSDP2", "Megatron", "SGLang", "vLLM"], + "modules": [ + { + "name": "areal.api", + "path": "areal/api", + "type": "core", + "description": "配置数据类、工作流/引擎契约定义", + "entry_points": ["cli_args.py", "engine_api.py", "workflow_api.py", "reward_api.py"], + "key_files": ["alloc_mode.py", "io_struct.py", "scheduler_api.py", "env_api.py"], + "has_tests": false, + "has_config": false + }, + { + "name": "areal.engine", + "path": "areal/engine", + "type": "core", + "description": "训练引擎适配器:FSDP2、Megatron、SGLang/vLLM", + "entry_points": ["fsdp_engine.py", "megatron_engine.py", "sglang_remote.py", "vllm_remote.py"], + "key_files": ["core/train_engine.py", "ppo/actor.py", "ppo/critic.py", "sft/lm_engine.py", "rw/rw_engine.py"], + "has_tests": true, + "has_config": false + }, + { + "name": "areal.workflow", + "path": "areal/workflow", + "type": "core", + "description": "RolloutWorkflow 实现:多轮对话、RLVR、视觉RLVR、Agent集成", + "entry_points": ["multi_turn.py", "rlvr.py", "vision_rlvr.py"], + "key_files": ["openai/math_agent.py", "anthropic/math_agent.py", "langchain/math_agent.py", "openai_agent/math_agent.py"], + "has_tests": false, + "has_config": false + }, + { + "name": "areal.reward", + "path": "areal/reward", + "type": "core", + "description": "奖励函数:数学推理、视觉任务", + "entry_points": ["gsm8k.py", "geometry3k.py", "clevr_count_70k.py"], + "key_files": [], + "has_tests": false, + "has_config": false + }, + { + "name": "areal.dataset", + "path": "areal/dataset", + "type": "core", + "description": "数据集加载器:GSM8K、Geometry3K、CLEVR、HHRLHF", + "entry_points": ["gsm8k.py", "geometry3k.py", "clevr_count_70k.py", "hhrlhf.py", "torl_data.py"], + "key_files": [], + "has_tests": false, + "has_config": false + }, + { + "name": "areal.controller", + "path": "areal/controller", + "type": "core", + "description": "训练与推理控制器:rollout、train", + "entry_points": ["rollout_controller.py", "train_controller.py"], + "key_files": ["rollout_callback.py"], + "has_tests": true, + "has_config": false + }, + { + "name": "areal.core", + "path": "areal/core", + "type": "core", + "description": "核心运行时:异步任务、分布式rollout、工作流执行器", + "entry_points": ["workflow_executor.py", "dist_rollout.py", "async_task_runner.py"], + "key_files": ["remote_inf_engine.py", "staleness_manager.py", "workflow_context.py"], + "has_tests": true, + "has_config": false + }, + { + "name": "areal.launcher", + "path": "areal/launcher", + "type": "infrastructure", + "description": "启动器:本地、Ray、Slurm、SGLang/vLLM服务器", + "entry_points": ["local.py", "ray.py", "slurm.py"], + "key_files": ["sglang_server.py", "vllm_server.py"], + "has_tests": false, + "has_config": false + }, + { + "name": "areal.utils", + "path": "areal/utils", + "type": "utilities", + "description": "工具库:日志、张量操作、检查点、分布式、FP8、FSDP、Megatron", + "entry_points": ["logging.py", "distributed.py", "data.py"], + "key_files": ["checkpoint.py", "fp8/", "fsdp/", "mcore/", "functional/", "perf_tracer.py"], + "has_tests": true, + "has_config": false + }, + { + "name": "areal.models", + "path": "areal/models", + "type": "models", + "description": "模型实现:FSDP Ulysses、Megatron Core适配", + "entry_points": ["parallel_styles.py"], + "key_files": ["fsdp/ulysses.py", "mcore/qwen3.py", "mcore/hf_load.py", "mcore/hf_save.py"], + "has_tests": false, + "has_config": false + }, + { + "name": "areal.experimental", + "path": "areal/experimental", + "type": "experimental", + "description": "实验性功能:Archon引擎、OpenAI集成、Trainer、多轮V2", + "entry_points": ["engine/archon_engine.py", "openai/client.py", "trainer/rl.py"], + "key_files": ["models/archon/", "openai/proxy/", "camel/", "workflow/multi_turn_v2.py"], + "has_tests": true, + "has_config": false + }, + { + "name": "examples", + "path": "examples", + "type": "examples", + "description": "训练脚本与配置:数学、多轮、VLM、搜索Agent、TIR、RLHF", + "entry_points": ["math/gsm8k_rl.py", "math/gsm8k_sft.py"], + "key_files": ["math/*.yaml", "multi_turn_math/", "vlm/", "search_agent/", "tir/", "alignment/"], + "has_tests": false, + "has_config": true + }, + { + "name": "evaluation", + "path": "evaluation", + "type": "evaluation", + "description": "评估工具:数学评估、代码评估、latex2sympy", + "entry_points": ["evaluate.py", "math_eval.py", "code_eval.py"], + "key_files": ["grader.py", "latex2sympy/", "python_executor.py"], + "has_tests": false, + "has_config": false + }, + { + "name": "docs", + "path": "docs", + "type": "documentation", + "description": "Jupyter Book文档源码", + "entry_points": ["intro.md", "tutorial/quickstart.md"], + "key_files": ["algorithms/", "customization/", "best_practices/", "lite/"], + "has_tests": false, + "has_config": true + }, + { + "name": "recipe", + "path": "recipe", + "type": "recipe", + "description": "算法配方:AEnt(Advantage Entropy)", + "entry_points": ["AEnt/gsm8k_aent_grpo.py"], + "key_files": ["AEnt/actor.py", "AEnt/functional.py", "AEnt/aent_args.py"], + "has_tests": false, + "has_config": true + } + ], + "coverage": { + "total_files_estimated": 450, + "scanned_files": 450, + "coverage_percentage": 100, + "modules_documented": 15, + "modules_with_tests": 6, + "gaps": [] + }, + "ignored_patterns": [ + ".git/**", + ".github/**", + "node_modules/**", + "dist/**", + "build/**", + "__pycache__/**", + "*.pyc", + "*.log", + "*.lock", + ".venv/**", + "venv/**", + "wandb/**", + "outputs/**", + "logs/**", + ".ruff_cache/**", + ".pytest_cache/**", + ".legacy/**", + ".data/**", + ".agent/**", + ".claude/sessions/**", + "slurm_outs/**", + "_data/**", + "trace_result/**", + "profile_result/**" + ], + "truncated": false, + "next_steps": [] +} diff --git a/areal/api/CLAUDE.md b/areal/api/CLAUDE.md new file mode 100644 index 0000000000..1531e30845 --- /dev/null +++ b/areal/api/CLAUDE.md @@ -0,0 +1,148 @@ +[根目录](../../CLAUDE.md) > **areal/api** + +# areal.api - API 与配置契约 + +## 变更记录 (Changelog) + +### 2026-01-31 - 初始化 + +- 模块文档创建 +- 识别 9 个核心文件 + +--- + +## 模块职责 + +定义 AReaL 系统的核心 API 契约与配置数据类: + +- **配置数据类**:CLI 参数、训练/推理超参数、并行策略 +- **引擎契约**:`TrainEngine`、`InferenceEngine` 抽象基类 +- **工作流契约**:`RolloutWorkflow`、`AgentWorkflow` 抽象基类 +- **奖励契约**:`RewardFunction` 与异步包装器 +- **调度器契约**:`Scheduler` 抽象基类 +- **数据结构**:`ModelRequest`、`ModelResponse`、`ParamSpec` 等 + +## 入口与启动 + +无独立启动入口,作为其他模块的依赖被导入。 + +## 对外接口 + +### 核心抽象类 + +| 类名 | 文件 | 职责 | +| ------------------- | -------------------- | -------------------------------------- | +| `TrainEngine` | `engine_api.py` | 训练引擎抽象基类(FSDP/Megatron) | +| `InferenceEngine` | `engine_api.py` | 推理引擎抽象基类(SGLang/vLLM) | +| `RolloutWorkflow` | `workflow_api.py` | Rollout 工作流抽象基类 | +| `AgentWorkflow` | `workflow_api.py` | Agent 工作流抽象基类(OpenAI SDK 集成)| +| `RewardFunction` | `reward_api.py` | 奖励函数抽象基类 | +| `Scheduler` | `scheduler_api.py` | 调度器抽象基类(Local/Ray/Slurm) | + +### 配置数据类 + +| 类名 | 文件 | 职责 | +| ----------------------------- | --------------- | -------------------------------------- | +| `NormConfig` | `cli_args.py` | 奖励/优势归一化配置 | +| `MicroBatchSpec` | `cli_args.py` | 微批次划分规格 | +| `GenerationHyperparameters` | `cli_args.py` | 生成超参数(温度、top_p、max_tokens) | +| `TrainingHyperparameters` | `cli_args.py` | 训练超参数(学习率、优化器、调度器) | +| `ParallelStrategy` | `alloc_mode.py` | 并行策略(DP/TP/PP/CP/EP) | + +### 数据结构 + +| 类名 | 文件 | 职责 | +| ------------------- | --------------- | -------------------------------------- | +| `ModelRequest` | `io_struct.py` | 推理请求(input_ids、gconfig) | +| `ModelResponse` | `io_struct.py` | 推理响应(output_tokens、logprobs) | +| `ParamSpec` | `io_struct.py` | 参数规格(形状、dtype、设备) | +| `WeightUpdateMeta` | `io_struct.py` | 权重更新元数据 | +| `SaveLoadMeta` | `io_struct.py` | 检查点保存/加载元数据 | + +## 关键依赖与配置 + +### 外部依赖 + +- `torch`:张量操作与分布式通信 +- `transformers`:Tokenizer 与模型配置 +- `omegaconf`、`hydra-core`:配置管理 +- `pydantic`:数据验证 + +### 内部依赖 + +- `areal.utils.logging`:日志工具 +- `areal.utils.name_resolve`:动态导入 + +## 数据模型 + +### CLI 参数结构 + +```python +@dataclass +class TrainingHyperparameters: + # 优化器 + optimizer: str = "adam" + lr: float = 1e-5 + weight_decay: float = 0.0 + + # 学习率调度 + lr_scheduler: str = "cosine" + warmup_steps: int = 0 + + # 训练配置 + n_epochs: int = 1 + gradient_accumulation_steps: int = 1 + max_grad_norm: float = 1.0 +``` + +### 并行策略 + +```python +@dataclass +class ParallelStrategy: + data_parallel_size: int = 1 + tensor_parallel_size: int = 1 + pipeline_parallel_size: int = 1 + context_parallel_size: int = 1 + expert_parallel_size: int = 1 +``` + +## 测试与质量 + +- **测试覆盖**:部分配置类有单元测试(`areal/tests/test_adv_norm_config.py`、`test_allocation_mode.py`) +- **质量工具**:Ruff(格式化与 lint)、pre-commit hooks + +## 常见问题 (FAQ) + +### Q: 如何添加新的配置字段? + +A: 在对应的 `@dataclass` 中添加字段,并提供默认值。注意向后兼容性。 + +### Q: 如何实现自定义 Workflow? + +A: 继承 `RolloutWorkflow` 或 `AgentWorkflow`,实现 `arun_episode` 或 `run` 方法。参考 `areal/workflow/multi_turn.py`。 + +### Q: 如何实现自定义 Engine? + +A: 继承 `TrainEngine` 或 `InferenceEngine`,实现所有抽象方法。参考 `areal/engine/fsdp_engine.py`。 + +## 相关文件清单 + +``` +areal/api/ +├── __init__.py +├── alloc_mode.py # 并行策略与分配模式 +├── cli_args.py # CLI 参数与配置数据类(核心) +├── engine_api.py # 引擎抽象基类(核心) +├── env_api.py # 环境 API(实验性) +├── io_struct.py # 数据结构定义 +├── reward_api.py # 奖励函数抽象基类 +├── scheduler_api.py # 调度器抽象基类 +└── workflow_api.py # 工作流抽象基类(核心) +``` + +## 下一步建议 + +- 补充 `cli_args.py` 中各配置类的详细文档 +- 添加配置验证的单元测试 +- 完善 `env_api.py` 的实验性功能 diff --git a/areal/core/CLAUDE.md b/areal/core/CLAUDE.md new file mode 100644 index 0000000000..a1cd7c7d2d --- /dev/null +++ b/areal/core/CLAUDE.md @@ -0,0 +1,145 @@ +[根目录](../../CLAUDE.md) > **areal/core** + +# areal.core - 核心运行时 + +## 变更记录 (Changelog) + +### 2026-01-31 - 初始化 + +- 模块文档创建 +- 识别 6 个核心运行时组件 + +--- + +## 模块职责 + +实现 AReaL 的核心运行时组件: + +- **工作流执行器**:`WorkflowExecutor` - 异步执行 Rollout 工作流 +- **分布式 Rollout**:`DistRollout` - 分布式数据生成 +- **异步任务运行器**:`AsyncTaskRunner` - 管理异步任务队列 +- **远程推理引擎**:`RemoteInfEngine` - 远程推理引擎代理 +- **陈旧度管理器**:`StalenessManager` - 管理模型版本陈旧度 +- **工作流上下文**:`workflow_context` - 全局上下文管理 + +## 入口与启动 + +核心运行时由 `areal.controller` 调用,不直接启动。 + +## 对外接口 + +### 核心类 + +| 类名 | 文件 | 职责 | +| --------------------- | ------------------------- | -------------------------------------- | +| `WorkflowExecutor` | `workflow_executor.py` | 异步执行 Rollout 工作流(核心) | +| `DistRollout` | `dist_rollout.py` | 分布式数据生成(核心) | +| `AsyncTaskRunner` | `async_task_runner.py` | 异步任务队列管理 | +| `RemoteInfEngine` | `remote_inf_engine.py` | 远程推理引擎代理 | +| `StalenessManager` | `staleness_manager.py` | 模型版本陈旧度管理 | + +### 全局上下文 + +| 函数/变量 | 文件 | 职责 | +| ----------------------------- | ------------------------- | -------------------------------------- | +| `workflow_context.get()` | `workflow_context.py` | 获取当前工作流上下文 | +| `workflow_context.set()` | `workflow_context.py` | 设置当前工作流上下文 | + +## 关键依赖与配置 + +### 外部依赖 + +- `torch`:张量操作与分布式通信 +- `asyncio`:异步编程 +- `aiohttp`:异步 HTTP 客户端 +- `pyzmq`:ZeroMQ 消息队列 + +### 内部依赖 + +- `areal.api.workflow_api`:`RolloutWorkflow`、`AgentWorkflow` +- `areal.api.engine_api`:`InferenceEngine` +- `areal.api.cli_args`:配置数据类 +- `areal.utils.logging`:日志工具 +- `areal.utils.distributed`:分布式工具 + +## 数据模型 + +### WorkflowExecutor 初始化参数 + +```python +WorkflowExecutor( + workflow: RolloutWorkflow | AgentWorkflow, # 工作流实例 + engine: InferenceEngine, # 推理引擎 + n_workers: int = 1, # 并发 worker 数量 + max_queue_size: int = 100, # 最大队列大小 +) +``` + +### DistRollout 初始化参数 + +```python +DistRollout( + workflow_executor: WorkflowExecutor, # 工作流执行器 + dataloader: DataLoader, # 数据加载器 + n_samples_per_prompt: int = 1, # 每个 prompt 的采样数 + staleness_coef: float = 0.0, # 陈旧度系数 +) +``` + +### StalenessManager 数据结构 + +```python +{ + "current_version": int, # 当前模型版本 + "version_timestamps": Dict[int, float], # 版本时间戳 + "staleness_coef": float, # 陈旧度系数 +} +``` + +## 测试与质量 + +- **测试覆盖**: + - `areal/tests/test_async_task_runner.py` + - `areal/tests/test_staleness_manager.py` + - `areal/tests/test_rollout_controller.py`(集成测试) +- **质量工具**:Ruff、pre-commit hooks + +## 常见问题 (FAQ) + +### Q: WorkflowExecutor 如何处理并发? + +A: 使用 `asyncio.Queue` 管理任务队列,启动 `n_workers` 个异步 worker 并发执行工作流。 + +### Q: 什么是陈旧度(Staleness)? + +A: 在异步 RL 训练中,推理使用的模型版本可能落后于训练版本。陈旧度系数用于调整旧版本数据的权重。 + +### Q: 如何调试分布式 Rollout? + +A: 参考 `docs/best_practices/debugging.md`: +- 检查日志中的 `DistRollout` 和 `WorkflowExecutor` 输出 +- 使用 `workflow_context.get()` 获取当前上下文信息 +- 启用 `AREAL_DEBUG=1` 环境变量 + +### Q: RemoteInfEngine 与本地 InferenceEngine 的区别? + +A: `RemoteInfEngine` 通过 HTTP/ZMQ 与远程推理服务器通信,适合推理与训练分离的场景。本地 `InferenceEngine` 直接调用推理引擎。 + +## 相关文件清单 + +``` +areal/core/ +├── __init__.py +├── workflow_executor.py # 工作流执行器(核心) +├── dist_rollout.py # 分布式 Rollout(核心) +├── async_task_runner.py # 异步任务运行器 +├── remote_inf_engine.py # 远程推理引擎代理 +├── staleness_manager.py # 陈旧度管理器 +└── workflow_context.py # 工作流上下文 +``` + +## 下一步建议 + +- 补充 WorkflowExecutor 的性能调优文档 +- 添加 DistRollout 的分布式测试 +- 优化 AsyncTaskRunner 的队列管理逻辑 diff --git a/areal/engine/CLAUDE.md b/areal/engine/CLAUDE.md new file mode 100644 index 0000000000..133dca6f5b --- /dev/null +++ b/areal/engine/CLAUDE.md @@ -0,0 +1,179 @@ +[根目录](../../CLAUDE.md) > **areal/engine** + +# areal.engine - 训练与推理引擎 + +## 变更记录 (Changelog) + +### 2026-01-31 - 初始化 + +- 模块文档创建 +- 识别 FSDP、Megatron、SGLang、vLLM 四大引擎 + +--- + +## 模块职责 + +实现训练与推理引擎的适配器: + +- **训练引擎**:FSDP2、Megatron Core +- **推理引擎**:SGLang、vLLM(远程服务器) +- **专用引擎**:PPO Actor/Critic、SFT LM、Reward Model + +## 入口与启动 + +引擎通过 `areal.controller` 调用,不直接启动。 + +## 对外接口 + +### 训练引擎 + +| 类名 | 文件 | 职责 | +| ----------------- | --------------------- | -------------------------------------- | +| `FSDPEngine` | `fsdp_engine.py` | PyTorch FSDP2 训练引擎(核心) | +| `MegatronEngine` | `megatron_engine.py` | Megatron Core 训练引擎(核心) | + +### 推理引擎 + +| 类名 | 文件 | 职责 | +| --------------------- | --------------------- | -------------------------------------- | +| `SGLangRemoteEngine` | `sglang_remote.py` | SGLang 远程推理引擎(核心) | +| `VLLMRemoteEngine` | `vllm_remote.py` | vLLM 远程推理引擎(核心) | + +### 专用引擎 + +| 类名 | 文件 | 职责 | +| ----------------- | --------------------- | -------------------------------------- | +| `PPOActor` | `ppo/actor.py` | PPO Actor 引擎 | +| `PPOCritic` | `ppo/critic.py` | PPO Critic 引擎 | +| `LMEngine` | `sft/lm_engine.py` | SFT 语言模型引擎 | +| `RWEngine` | `rw/rw_engine.py` | Reward Model 引擎 | + +## 关键依赖与配置 + +### 外部依赖 + +- `torch`:张量操作与分布式通信 +- `torch.distributed.fsdp`:FSDP2 +- `megatron.core`:Megatron Core +- `sglang`:SGLang 推理引擎 +- `vllm`:vLLM 推理引擎 +- `transformers`:模型加载与 Tokenizer + +### 内部依赖 + +- `areal.api.engine_api`:`TrainEngine`、`InferenceEngine` 基类 +- `areal.api.cli_args`:配置数据类 +- `areal.utils.fsdp`:FSDP 工具(检查点、优化器、并行) +- `areal.utils.mcore`:Megatron Core 工具 +- `areal.models`:模型实现 + +## 数据模型 + +### FSDPEngine 初始化参数 + +```python +FSDPEngine( + model_path: str, # 模型路径 + parallel_strategy: ParallelStrategy, # 并行策略 + dtype: torch.dtype, # 数据类型 + use_lora: bool = False, # 是否使用 LoRA + # ... 其他参数 +) +``` + +### MegatronEngine 初始化参数 + +```python +MegatronEngine( + model_path: str, + parallel_strategy: ParallelStrategy, + dtype: torch.dtype, + use_fp8: bool = False, # 是否使用 FP8 + # ... 其他参数 +) +``` + +### 推理引擎请求/响应 + +**请求**(`ModelRequest`): + +```python +{ + "rid": str, # 请求 ID + "input_ids": List[int], # 输入 token IDs + "gconfig": GenerationHyperparameters, # 生成配置 + "tokenizer": PreTrainedTokenizerFast, +} +``` + +**响应**(`ModelResponse`): + +```python +{ + "input_tokens": List[int], # 输入 tokens + "output_tokens": List[int], # 输出 tokens + "logprobs": List[float], # 对数概率 + "version": int, # 模型版本 +} +``` + +## 测试与质量 + +- **测试覆盖**: + - FSDP:`areal/tests/test_fsdp_*.py`(需 GPU) + - Megatron:`areal/tests/test_megatron_*.py`(需多 GPU) + - 推理引擎:`areal/tests/test_inference_engines.py` +- **质量工具**:Ruff、pre-commit hooks + +## 常见问题 (FAQ) + +### Q: 如何选择训练引擎? + +A: +- **FSDP2**:适合单节点或小规模多节点训练,支持 LoRA +- **Megatron**:适合大规模多节点训练,支持 Pipeline Parallel、Expert Parallel + +### Q: 如何选择推理引擎? + +A: +- **SGLang**:支持 Data Parallel Attention、Expert Parallel,适合 MoE 模型 +- **vLLM**:成熟稳定,支持 Tensor Parallel、Pipeline Parallel + +### Q: 如何实现自定义引擎? + +A: 继承 `TrainEngine` 或 `InferenceEngine`,实现所有抽象方法。参考 `fsdp_engine.py` 或 `sglang_remote.py`。 + +### Q: 如何处理 OOM? + +A: 参考 `docs/best_practices/handling_oom.md`: +- 减少 batch size 或 micro-batch 数量 +- 启用 gradient checkpointing +- 使用 FP8 或混合精度训练 +- 增加 Tensor Parallel 或 Pipeline Parallel + +## 相关文件清单 + +``` +areal/engine/ +├── __init__.py +├── fsdp_engine.py # FSDP2 训练引擎(核心) +├── megatron_engine.py # Megatron 训练引擎(核心) +├── sglang_remote.py # SGLang 推理引擎(核心) +├── vllm_remote.py # vLLM 推理引擎(核心) +├── core/ +│ ├── __init__.py +│ └── train_engine.py # TrainEngine 基类 +├── ppo/ +│ ├── actor.py # PPO Actor +│ └── critic.py # PPO Critic +├── sft/ +│ └── lm_engine.py # SFT 语言模型引擎 +└── rw/ + └── rw_engine.py # Reward Model 引擎 +``` + +## 下一步建议 + +- 补充 Megatron 引擎的详细配置文档 +- 添加推理引擎的性能对比测试 +- 优化 FSDP 引擎的检查点保存/加载逻辑 diff --git a/areal/workflow/CLAUDE.md b/areal/workflow/CLAUDE.md new file mode 100644 index 0000000000..6e8224653c --- /dev/null +++ b/areal/workflow/CLAUDE.md @@ -0,0 +1,165 @@ +[根目录](../../CLAUDE.md) > **areal/workflow** + +# areal.workflow - Rollout 工作流实现 + +## 变更记录 (Changelog) + +### 2026-01-31 - 初始化 + +- 模块文档创建 +- 识别 9 个工作流实现文件 + +--- + +## 模块职责 + +实现各种 Rollout 工作流,用于生成训练数据: + +- **多轮对话**:`MultiTurnWorkflow` - 多次尝试直到奖励为正 +- **RLVR**:`RLVRWorkflow` - 单轮推理与验证 +- **视觉 RLVR**:`VisionRLVRWorkflow` - 支持视觉输入的 RLVR +- **Agent 集成**:OpenAI、Anthropic、LangChain、OpenAI Agents SDK + +## 入口与启动 + +工作流通过 `areal.core.workflow_executor.WorkflowExecutor` 调用,不直接启动。 + +## 对外接口 + +### 核心工作流 + +| 类名 | 文件 | 职责 | +| ---------------------- | ------------------- | -------------------------------------- | +| `MultiTurnWorkflow` | `multi_turn.py` | 多轮对话,直到奖励为正或达到最大轮数 | +| `RLVRWorkflow` | `rlvr.py` | 单轮推理与验证(RLVR 算法) | +| `VisionRLVRWorkflow` | `vision_rlvr.py` | 支持视觉输入的 RLVR | + +### Agent 集成工作流 + +| 类名 | 文件 | 职责 | +| ------------------------- | ------------------------------- | -------------------------------------- | +| `OpenAIMathAgent` | `openai/math_agent.py` | OpenAI SDK 数学 Agent | +| `AnthropicMathAgent` | `anthropic/math_agent.py` | Anthropic SDK 数学 Agent | +| `LangChainMathAgent` | `langchain/math_agent.py` | LangChain 数学 Agent | +| `OpenAIAgentMathAgent` | `openai_agent/math_agent.py` | OpenAI Agents SDK 数学 Agent | + +## 关键依赖与配置 + +### 外部依赖 + +- `torch`:张量操作 +- `transformers`:Tokenizer +- `openai`:OpenAI SDK(Agent 集成) +- `anthropic`:Anthropic SDK(Agent 集成) +- `langchain`:LangChain(Agent 集成) + +### 内部依赖 + +- `areal.api.workflow_api`:`RolloutWorkflow`、`AgentWorkflow` 基类 +- `areal.api.engine_api`:`InferenceEngine` +- `areal.api.reward_api`:`AsyncRewardWrapper` +- `areal.api.cli_args`:`GenerationHyperparameters` +- `areal.utils.logging`:日志工具 + +## 数据模型 + +### MultiTurnWorkflow 输入输出 + +**输入**(`data: dict`): + +```python +{ + "messages": [{"role": "user", "content": "问题"}], + # 其他奖励函数需要的字段 +} +``` + +**输出**(`dict`): + +```python +{ + "seq": [token_ids], # 完整序列(prompt + completions) + "logprobs": [logprobs], # 对数概率 + "loss_mask": [0/1], # 损失掩码(仅 completion 部分为 1) + "versions": [version_ids], # 模型版本 ID + "reward": float, # 最终奖励 + "prompt_str": str, # Prompt 字符串 + "completions_str": str, # Completion 字符串 +} +``` + +### RLVRWorkflow 输入输出 + +**输入**:同 `MultiTurnWorkflow` + +**输出**: + +```python +{ + "seq": [token_ids], + "logprobs": [logprobs], + "loss_mask": [0/1], + "versions": [version_ids], + "reward": float, + "prompt_str": str, + "completions_str": str, + "n_samples": int, # 采样数量 +} +``` + +## 测试与质量 + +- **测试覆盖**:无专门的单元测试,通过集成测试(`areal/tests/grpo/`, `areal/tests/sft/`)验证 +- **质量工具**:Ruff(格式化与 lint) + +## 常见问题 (FAQ) + +### Q: 如何实现自定义 Workflow? + +A: 继承 `RolloutWorkflow`,实现 `arun_episode` 方法: + +```python +from areal.api.workflow_api import RolloutWorkflow +from areal.api.engine_api import InferenceEngine + +class MyWorkflow(RolloutWorkflow): + async def arun_episode(self, engine: InferenceEngine, data: dict): + # 1. 构造 ModelRequest + # 2. 调用 engine.agenerate() + # 3. 计算奖励 + # 4. 返回结果字典 + return {"seq": ..., "logprobs": ..., "reward": ...} +``` + +### Q: MultiTurnWorkflow 如何处理多轮对话? + +A: 每轮生成后,如果奖励为 0(错误),则在 prompt 后追加错误提示和上一轮的 completion,继续生成。最多尝试 `max_turns` 轮。 + +### Q: 如何集成外部 Agent SDK? + +A: 继承 `AgentWorkflow`,实现 `run` 方法,使用 `extra_kwargs["base_url"]` 和 `extra_kwargs["http_client"]` 连接 AReaL 的 OpenAI 兼容代理服务器。参考 `openai/math_agent.py`。 + +## 相关文件清单 + +``` +areal/workflow/ +├── multi_turn.py # 多轮对话工作流(核心) +├── rlvr.py # RLVR 工作流(核心) +├── vision_rlvr.py # 视觉 RLVR 工作流 +├── openai/ +│ └── math_agent.py # OpenAI SDK 数学 Agent +├── anthropic/ +│ ├── __init__.py +│ └── math_agent.py # Anthropic SDK 数学 Agent +├── langchain/ +│ ├── __init__.py +│ └── math_agent.py # LangChain 数学 Agent +└── openai_agent/ + └── math_agent.py # OpenAI Agents SDK 数学 Agent +``` + +## 下一步建议 + +- 添加 Workflow 的单元测试 +- 补充 Agent 集成的文档与示例 +- 优化多轮对话的 prompt 构造逻辑 diff --git a/docs/tutorial/oceanbase_quickstart.md b/docs/tutorial/oceanbase_quickstart.md new file mode 100644 index 0000000000..e027df17bb --- /dev/null +++ b/docs/tutorial/oceanbase_quickstart.md @@ -0,0 +1,322 @@ +# OceanBase 集成快速入门 + +本指南介绍如何将 AReaL 训练指标持久化到 OceanBase 数据库。 + +## 什么是 OceanBase? + +[OceanBase](https://www.oceanbase.com/) 是一个开源的分布式关系数据库,兼容 MySQL 协议,具有以下特点: + +- **高性能**:支持百万级 TPS +- **高可用**:多副本强一致性 +- **MySQL 兼容**:可使用 MySQL 客户端和工具 +- **水平扩展**:支持分布式架构 + +适合存储大规模训练指标、实验记录和模型元数据。 + +--- + +## 准备工作 + +### 1. 安装 OceanBase + +#### 方式 A:Docker 部署(推荐用于开发测试) + +```bash +# 拉取 OceanBase 社区版镜像 +docker pull oceanbase/oceanbase-ce:latest + +# 启动容器 +docker run -d \ + --name oceanbase-ce \ + -p 2881:2881 \ + -p 2882:2882 \ + -e MODE=mini \ + oceanbase/oceanbase-ce:latest + +# 等待 2-3 分钟让服务完全启动 +docker logs -f oceanbase-ce +``` + +#### 方式 B:生产环境部署 + +参考 [OceanBase 官方文档](https://www.oceanbase.com/docs/oceanbase-database)。 + +### 2. 安装 Python 依赖 + +```bash +cd /path/to/AReaL +uv pip install pymysql +``` + +或在 `pyproject.toml` 中添加依赖后运行 `uv sync`。 + +--- + +## 连接配置 + +### 环境变量配置(推荐) + +```bash +export OB_HOST=127.0.0.1 +export OB_PORT=2881 +export OB_USER=root@test +export OB_PASSWORD="" +export OB_DATABASE=test +``` + +### 默认连接参数 + +| 参数 | 默认值 | 说明 | +| ---------- | ------------ | -------------------------------------- | +| `host` | `127.0.0.1` | OceanBase 主机地址 | +| `port` | `2881` | MySQL 协议端口 | +| `user` | `root@test` | 用户名(格式: `user@tenant`) | +| `password` | `""` | 密码(Docker 默认为空) | +| `database` | `test` | 数据库名 | + +--- + +## 运行示例 + +### 基础示例 + +```bash +cd /path/to/AReaL +python examples/utils/oceanbase_example.py +``` + +**输出示例**: + +``` +(AReaL) 20260131-14:30:00.123 OceanBaseExample INFO: === OceanBase 集成示例 === +(AReaL) 20260131-14:30:00.124 OceanBaseExample INFO: 连接配置: 127.0.0.1:2881/test +(AReaL) 20260131-14:30:00.256 OceanBaseExample INFO: 成功连接到 OceanBase: 127.0.0.1:2881/test +(AReaL) 20260131-14:30:00.312 OceanBaseExample INFO: 成功创建表 training_metrics +(AReaL) 20260131-14:30:00.345 OceanBaseExample INFO: 插入指标: gsm8k_grpo_demo step=100 loss=1.3 reward=0.6 +... +(AReaL) 20260131-14:30:00.512 OceanBaseExample INFO: === 示例执行完成 === +``` + +--- + +## 集成到训练脚本 + +### 方式 1:直接集成 + +在训练脚本中导入并使用 `OceanBaseMetricsLogger`: + +```python +from examples.utils.oceanbase_example import OceanBaseMetricsLogger + +# 初始化 +metrics_logger = OceanBaseMetricsLogger( + host="127.0.0.1", + port=2881, + user="root@test", + password="", + database="test", +) + +# 连接并创建表 +metrics_logger.connect() +metrics_logger.create_table() + +# 在训练循环中记录指标 +for step in range(num_steps): + # ... 训练代码 ... + + metrics_logger.insert_metric( + experiment_name="my_experiment", + step=step, + loss=loss.item(), + reward=reward.mean().item(), + ) + +# 训练结束后关闭连接 +metrics_logger.close() +``` + +### 方式 2:扩展为自定义 Logger + +创建自定义日志类继承 `OceanBaseMetricsLogger`: + +```python +from examples.utils.oceanbase_example import OceanBaseMetricsLogger + +class CustomMetricsLogger(OceanBaseMetricsLogger): + def log_training_step(self, experiment_name, step, metrics_dict): + """记录训练步骤的所有指标""" + self.insert_metric( + experiment_name=experiment_name, + step=step, + loss=metrics_dict.get("loss"), + reward=metrics_dict.get("reward"), + ) + + def log_evaluation(self, experiment_name, step, eval_metrics): + """记录评估指标""" + # 自定义评估指标记录逻辑 + pass +``` + +--- + +## 常见查询 + +### 查询实验的训练曲线 + +```sql +SELECT step, loss, reward, timestamp +FROM training_metrics +WHERE experiment_name = 'gsm8k_grpo_demo' +ORDER BY step; +``` + +### 查询最近的训练记录 + +```sql +SELECT * +FROM training_metrics +ORDER BY timestamp DESC +LIMIT 10; +``` + +### 计算平均损失 + +```sql +SELECT + experiment_name, + AVG(loss) as avg_loss, + MIN(loss) as min_loss, + MAX(loss) as max_loss +FROM training_metrics +GROUP BY experiment_name; +``` + +### 按时间范围查询 + +```sql +SELECT * +FROM training_metrics +WHERE timestamp >= '2026-01-31 00:00:00' + AND timestamp < '2026-02-01 00:00:00'; +``` + +--- + +## 故障排除 + +### 问题 1:连接失败 + +**错误信息**: +``` +pymysql.err.OperationalError: (2003, "Can't connect to MySQL server on '127.0.0.1'") +``` + +**解决方案**: +1. 检查 OceanBase 服务是否启动: + ```bash + docker ps | grep oceanbase + ``` +2. 检查端口是否正确(默认 2881) +3. 检查防火墙设置 + +### 问题 2:认证失败 + +**错误信息**: +``` +pymysql.err.OperationalError: (1045, "Access denied for user 'root@test'") +``` + +**解决方案**: +1. 确认用户名格式为 `user@tenant`(如 `root@test`) +2. 检查密码是否正确 +3. Docker 部署默认密码为空 + +### 问题 3:表已存在 + +**错误信息**: +``` +pymysql.err.InternalError: (1050, "Table 'training_metrics' already exists") +``` + +**解决方案**: +示例代码使用 `CREATE TABLE IF NOT EXISTS`,不会报错。如需重建表: + +```sql +DROP TABLE IF EXISTS training_metrics; +``` + +### 问题 4:性能问题 + +**症状**:插入速度慢 + +**优化方案**: +1. **批量插入**: + ```python + def insert_metrics_batch(self, metrics_list): + insert_sql = """ + INSERT INTO training_metrics + (experiment_name, step, loss, reward, timestamp) + VALUES (%s, %s, %s, %s, %s) + """ + with self.connection.cursor() as cursor: + cursor.executemany(insert_sql, metrics_list) + ``` + +2. **异步写入**:使用队列缓冲指标,定期批量写入 + +3. **索引优化**:根据查询模式调整索引 + +--- + +## 高级配置 + +### 连接池 + +对于高并发场景,使用连接池: + +```python +from pymysql.pooling import PooledConnection + +# 创建连接池(需要额外配置) +# 参考 pymysql 文档 +``` + +### 分区表 + +对于大规模数据,使用分区表: + +```sql +CREATE TABLE training_metrics ( + id BIGINT AUTO_INCREMENT, + experiment_name VARCHAR(100) NOT NULL, + step INT NOT NULL, + loss FLOAT, + reward FLOAT, + timestamp DATETIME NOT NULL, + PRIMARY KEY (id, timestamp) +) PARTITION BY RANGE (TO_DAYS(timestamp)) ( + PARTITION p202601 VALUES LESS THAN (TO_DAYS('2026-02-01')), + PARTITION p202602 VALUES LESS THAN (TO_DAYS('2026-03-01')), + PARTITION p202603 VALUES LESS THAN (TO_DAYS('2026-04-01')) +); +``` + +--- + +## 相关资源 + +- [OceanBase 官方文档](https://www.oceanbase.com/docs/) +- [PyMySQL 文档](https://pymysql.readthedocs.io/) +- [AReaL 日志系统](../../areal/utils/logging.py) +- [示例代码](../../examples/utils/oceanbase_example.py) + +--- + +## 下一步 + +- 集成到您的训练脚本 +- 配置生产环境的 OceanBase 集群 +- 使用 Grafana 可视化训练指标 +- 探索 OceanBase 的高级特性(分布式事务、读写分离等) diff --git a/examples/CLAUDE.md b/examples/CLAUDE.md new file mode 100644 index 0000000000..d11e237b04 --- /dev/null +++ b/examples/CLAUDE.md @@ -0,0 +1,242 @@ +[根目录](../CLAUDE.md) > **examples** + +# examples - 训练脚本与配置 + +## 变更记录 (Changelog) + +### 2026-01-31 - 初始化 + +- 模块文档创建 +- 识别 10+ 个示例场景 + +--- + +## 模块职责 + +提供各种场景的训练脚本与配置文件: + +- **数学推理**:GSM8K、BOBA(GRPO、PPO、SFT、DAPO、LitePPO、RLOO 等) +- **多轮数学**:多轮对话数学推理 +- **视觉语言模型**:CLEVR 计数、Geometry3K +- **搜索 Agent**:端到端推理、搜索、浏览、总结 +- **工具集成推理**:TIR(Tool-Integrated Reasoning) +- **RLHF**:奖励模型训练 +- **LoRA**:低秩适应训练 +- **实验性**:代理模式、近似优化 + +## 入口与启动 + +### 数学推理(GSM8K) + +```bash +# GRPO 训练 +python3 -m areal.launcher.local \ + examples/math/gsm8k_rl.py \ + --config examples/math/gsm8k_grpo.yaml + +# SFT 训练 +python3 -m areal.launcher.local \ + examples/math/gsm8k_sft.py \ + --config examples/math/gsm8k_sft.yaml + +# 评估 +python3 examples/math/gsm8k_eval.py \ + --model_path /path/to/checkpoint +``` + +### 多轮数学推理 + +```bash +python3 -m areal.launcher.local \ + examples/multi_turn_math/gsm8k_rl_mt.py \ + --config examples/multi_turn_math/gsm8k_grpo_mt.yaml +``` + +### 视觉语言模型 + +```bash +# CLEVR 计数 +python3 -m areal.launcher.local \ + examples/vlm/clevr_count_70k_grpo.py \ + --config examples/vlm/clevr_count_70k_grpo.yaml + +# Geometry3K +python3 -m areal.launcher.local \ + examples/vlm/geometry3k_grpo.py \ + --config examples/vlm/geometry3k_grpo.yaml +``` + +## 对外接口 + +### 训练脚本 + +| 脚本 | 路径 | 职责 | +| ----------------------------- | --------------------------------- | -------------------------------------- | +| `gsm8k_rl.py` | `math/gsm8k_rl.py` | GSM8K RL 训练(GRPO/PPO/RLOO 等) | +| `gsm8k_sft.py` | `math/gsm8k_sft.py` | GSM8K SFT 训练 | +| `gsm8k_eval.py` | `math/gsm8k_eval.py` | GSM8K 评估 | +| `gsm8k_rl_mt.py` | `multi_turn_math/gsm8k_rl_mt.py` | 多轮数学推理训练 | +| `clevr_count_70k_grpo.py` | `vlm/clevr_count_70k_grpo.py` | CLEVR 计数 GRPO 训练 | +| `geometry3k_grpo.py` | `vlm/geometry3k_grpo.py` | Geometry3K GRPO 训练 | +| `train_tir.py` | `tir/train_tir.py` | TIR 训练 | +| `train_agents.py` | `openai_agents/train_agents.py` | OpenAI Agents 训练 | + +### 配置文件 + +| 配置文件 | 路径 | 职责 | +| ----------------------------- | --------------------------------- | -------------------------------------- | +| `gsm8k_grpo.yaml` | `math/gsm8k_grpo.yaml` | GSM8K GRPO 配置 | +| `gsm8k_ppo.yaml` | `math/gsm8k_ppo.yaml` | GSM8K PPO 配置 | +| `gsm8k_sft.yaml` | `math/gsm8k_sft.yaml` | GSM8K SFT 配置 | +| `gsm8k_dapo_dynamic_bs.yaml` | `math/gsm8k_dapo_dynamic_bs.yaml` | GSM8K DAPO 动态批次配置 | +| `gsm8k_grpo_lora.yaml` | `lora/gsm8k_grpo_lora.yaml` | GSM8K GRPO LoRA 配置 | +| `gsm8k_grpo_megatron.yaml` | `math/gsm8k_grpo_megatron.yaml` | GSM8K GRPO Megatron 配置 | + +## 关键依赖与配置 + +### 外部依赖 + +- `torch`:训练框架 +- `transformers`:模型加载 +- `datasets`:数据集加载 +- `wandb`:实验跟踪(可选) + +### 内部依赖 + +- `areal.api.cli_args`:配置数据类 +- `areal.workflow`:工作流实现 +- `areal.reward`:奖励函数 +- `areal.dataset`:数据集加载器 +- `areal.launcher`:启动器 + +## 数据模型 + +### 配置文件结构(YAML) + +```yaml +# 模型配置 +model: + path: "Qwen/Qwen2-1.5B-Instruct" + dtype: "bfloat16" + +# 训练配置 +training: + algorithm: "grpo" + n_epochs: 1 + batch_size: 128 + learning_rate: 1e-5 + +# 生成配置 +generation: + temperature: 0.7 + top_p: 0.9 + max_tokens: 512 + +# 并行配置 +parallel: + data_parallel_size: 2 + tensor_parallel_size: 1 + +# 工作流配置 +workflow: + type: "rlvr" + n_samples: 4 + +# 数据集配置 +dataset: + name: "gsm8k" + split: "train" +``` + +## 测试与质量 + +- **测试覆盖**:`areal/tests/test_examples.py`(验证示例脚本可运行) +- **质量工具**:Ruff、pre-commit hooks + +## 常见问题 (FAQ) + +### Q: 如何修改训练超参数? + +A: 编辑对应的 YAML 配置文件,或通过命令行覆盖: + +```bash +python3 -m areal.launcher.local \ + examples/math/gsm8k_rl.py \ + --config examples/math/gsm8k_grpo.yaml \ + training.learning_rate=5e-6 \ + training.batch_size=64 +``` + +### Q: 如何在多节点上运行? + +A: 使用 Ray 或 Slurm 启动器: + +```bash +# Ray +python3 -m areal.launcher.ray \ + examples/math/gsm8k_rl.py \ + --config examples/math/gsm8k_grpo.yaml \ + cluster.n_nodes=2 \ + cluster.n_gpus_per_node=8 + +# Slurm +python3 -m areal.launcher.slurm \ + examples/math/gsm8k_rl.py \ + --config examples/math/gsm8k_grpo.yaml \ + --partition=gpu \ + --nodes=2 \ + --gpus-per-node=8 +``` + +### Q: 如何添加新的示例? + +A: 参考现有示例(如 `math/gsm8k_rl.py`): +1. 创建训练脚本(定义 workflow、reward、dataset) +2. 创建配置文件(YAML) +3. 添加 README.md 说明 + +## 相关文件清单 + +``` +examples/ +├── math/ # 数学推理 +│ ├── gsm8k_rl.py # RL 训练脚本(核心) +│ ├── gsm8k_sft.py # SFT 训练脚本 +│ ├── gsm8k_eval.py # 评估脚本 +│ ├── gsm8k_grpo.yaml # GRPO 配置(核心) +│ ├── gsm8k_ppo.yaml # PPO 配置 +│ ├── gsm8k_sft.yaml # SFT 配置 +│ └── README.md +├── multi_turn_math/ # 多轮数学推理 +│ ├── gsm8k_rl_mt.py +│ ├── gsm8k_grpo_mt.yaml +│ └── README.md +├── vlm/ # 视觉语言模型 +│ ├── clevr_count_70k_grpo.py +│ ├── geometry3k_grpo.py +│ └── *.yaml +├── lora/ # LoRA 训练 +│ ├── gsm8k_grpo_lora.py +│ └── gsm8k_grpo_lora.yaml +├── tir/ # 工具集成推理 +│ ├── train_tir.py +│ ├── tir_workflow.py +│ └── tools/ +├── search_agent/ # 搜索 Agent +│ └── tongyi_deepresearch/ +├── openai_agents/ # OpenAI Agents +│ └── train_agents.py +├── alignment/ # RLHF +│ └── hhrlhf_rw.py +├── experimental/ # 实验性功能 +│ ├── proxy/ +│ └── prox_approx/ +└── skypilot/ # SkyPilot 部署 + └── *.sky.yaml +``` + +## 下一步建议 + +- 补充各示例的详细文档与性能基准 +- 添加更多算法的示例(GSPO、Dr.GRPO 等) +- 优化配置文件的默认参数 diff --git a/examples/utils/oceanbase_example.py b/examples/utils/oceanbase_example.py new file mode 100644 index 0000000000..f427cc5ab0 --- /dev/null +++ b/examples/utils/oceanbase_example.py @@ -0,0 +1,234 @@ +"""OceanBase 数据库集成示例 + +此脚本演示如何将训练指标持久化到 OceanBase 数据库。 + +OceanBase 是一个兼容 MySQL 的开源分布式数据库,适合存储大规模训练指标。 + +使用方法: + python examples/utils/oceanbase_example.py + +环境变量配置: + OB_HOST: OceanBase 主机地址(默认: 127.0.0.1) + OB_PORT: OceanBase 端口(默认: 2881) + OB_USER: 用户名(默认: root@test) + OB_PASSWORD: 密码(默认: 空) + OB_DATABASE: 数据库名(默认: test) +""" + +import os +from datetime import datetime +from typing import Optional + +import pymysql +from pymysql.cursors import DictCursor + +from areal.utils.logging import getLogger + +logger = getLogger("OceanBaseExample") + + +class OceanBaseMetricsLogger: + """OceanBase 指标记录器 + + 封装与 OceanBase 的连接和指标写入操作。 + + Attributes: + host: 数据库主机地址 + port: 数据库端口 + user: 用户名 + password: 密码 + database: 数据库名 + connection: 数据库连接对象 + """ + + def __init__( + self, + host: str = "127.0.0.1", + port: int = 2881, + user: str = "root@test", + password: str = "", + database: str = "test", + ): + """初始化 OceanBase 连接 + + Args: + host: 数据库主机地址 + port: 数据库端口 + user: 用户名(格式: user@tenant) + password: 密码 + database: 数据库名 + """ + self.host = host + self.port = port + self.user = user + self.password = password + self.database = database + self.connection: Optional[pymysql.Connection] = None + + def connect(self) -> None: + """建立数据库连接""" + try: + self.connection = pymysql.connect( + host=self.host, + port=self.port, + user=self.user, + password=self.password, + database=self.database, + cursorclass=DictCursor, + autocommit=True, + ) + logger.info( + f"成功连接到 OceanBase: {self.host}:{self.port}/{self.database}" + ) + except pymysql.Error as e: + logger.error(f"连接 OceanBase 失败: {e}") + raise + + def close(self) -> None: + """关闭数据库连接""" + if self.connection: + self.connection.close() + logger.info("已关闭 OceanBase 连接") + + def create_table(self) -> None: + """创建训练指标表 + + 表结构: + - id: 自增主键 + - experiment_name: 实验名称 + - step: 训练步数 + - loss: 损失值 + - reward: 奖励值(可选) + - timestamp: 记录时间 + """ + if not self.connection: + raise RuntimeError("未建立数据库连接,请先调用 connect()") + + create_table_sql = """ + CREATE TABLE IF NOT EXISTS training_metrics ( + id BIGINT AUTO_INCREMENT PRIMARY KEY, + experiment_name VARCHAR(100) NOT NULL, + step INT NOT NULL, + loss FLOAT, + reward FLOAT, + timestamp DATETIME NOT NULL, + INDEX idx_experiment_step (experiment_name, step), + INDEX idx_timestamp (timestamp) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + """ + + try: + with self.connection.cursor() as cursor: + cursor.execute(create_table_sql) + logger.info("成功创建表 training_metrics") + except pymysql.Error as e: + logger.error(f"创建表失败: {e}") + raise + + def insert_metric( + self, + experiment_name: str, + step: int, + loss: Optional[float] = None, + reward: Optional[float] = None, + ) -> None: + """插入单条训练指标 + + Args: + experiment_name: 实验名称 + step: 训练步数 + loss: 损失值 + reward: 奖励值 + """ + if not self.connection: + raise RuntimeError("未建立数据库连接,请先调用 connect()") + + insert_sql = """ + INSERT INTO training_metrics + (experiment_name, step, loss, reward, timestamp) + VALUES (%s, %s, %s, %s, %s) + """ + + try: + with self.connection.cursor() as cursor: + cursor.execute( + insert_sql, + (experiment_name, step, loss, reward, datetime.now()), + ) + logger.info( + f"插入指标: {experiment_name} step={step} loss={loss} reward={reward}" + ) + except pymysql.Error as e: + logger.error(f"插入指标失败: {e}") + raise + + +def main(): + """主函数:演示 OceanBase 集成""" + # 从环境变量读取配置 + config = { + "host": os.getenv("OB_HOST", "127.0.0.1"), + "port": int(os.getenv("OB_PORT", "2881")), + "user": os.getenv("OB_USER", "root@test"), + "password": os.getenv("OB_PASSWORD", ""), + "database": os.getenv("OB_DATABASE", "test"), + } + + logger.info("=== OceanBase 集成示例 ===") + logger.info(f"连接配置: {config['host']}:{config['port']}/{config['database']}") + + # 创建指标记录器 + metrics_logger = OceanBaseMetricsLogger(**config) + + try: + # 1. 连接数据库 + metrics_logger.connect() + + # 2. 创建表 + metrics_logger.create_table() + + # 3. 插入示例数据 + logger.info("插入示例训练指标...") + for step in range(1, 6): + metrics_logger.insert_metric( + experiment_name="gsm8k_grpo_demo", + step=step * 100, + loss=1.5 - step * 0.2, + reward=0.5 + step * 0.1, + ) + + logger.info("✓ 示例数据插入成功") + + # 4. 查询验证 + logger.info("查询最近 5 条记录...") + with metrics_logger.connection.cursor() as cursor: + cursor.execute( + """ + SELECT experiment_name, step, loss, reward, timestamp + FROM training_metrics + ORDER BY timestamp DESC + LIMIT 5 + """ + ) + results = cursor.fetchall() + for row in results: + logger.info( + f" {row['experiment_name']} | " + f"step={row['step']} | " + f"loss={row['loss']:.3f} | " + f"reward={row['reward']:.3f} | " + f"time={row['timestamp']}" + ) + + logger.info("=== 示例执行完成 ===") + + except Exception as e: + logger.error(f"执行失败: {e}") + raise + finally: + # 5. 关闭连接 + metrics_logger.close() + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 63fc6b1688..84098144d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,6 +101,7 @@ dependencies = [ "pebble", "timeout-decorator", "prettytable", + "pymysql", "h5py", "mathruler==0.1.0",