12 changes: 1 addition & 11 deletions docs/source/Instruction/命令行参数.md
@@ -678,7 +678,7 @@ App arguments inherit from the [deployment arguments](#部署参数) and [Web-UI arguments](#Web-UI参数)

- prm_model: The type of process reward model. It can be a model ID (launched via `pt`) or a prm key defined in a plugin (for a custom inference process).
- orm_model: The type of outcome reward model, usually a wildcard or a set of test cases, generally defined in a plugin.
- sampler_type: The sampling type. Currently supports sample and mcts; dvts will be supported in the future.
- sampler_type: The sampling type. Currently supports sample and distill.
- sampler_engine: Supports `pt`, `lmdeploy`, `vllm`, `client`, `no`; defaults to `pt`. The inference engine of the sampling model.
- output_dir: The output directory. Defaults to `sample_output`.
- output_file: The output file name. Defaults to `None`, in which case a timestamp is used as the file name. When provided, pass only the file name without a directory; only the JSONL format is supported.
@@ -695,16 +695,6 @@ App arguments inherit from the [deployment arguments](#部署参数) and [Web-UI arguments](#Web-UI参数)
- cache_files: To avoid GPU memory OOM caused by loading the prm and the generator at the same time, sampling can be split into two steps. In the first step, set prm and orm to `None`; all results are written to a file. In the second run, set sampler_engine to `no` and pass `--cache_files` with the output file of the first run; the cached results are then used for prm and orm evaluation to produce the final results.
- Note: When using cache_files, `--dataset` still needs to be passed, because the IDs in cache_files are MD5 hashes computed from the original data, and the two pieces of information must be combined.

#### MCTS
- rollout_depth: The maximum depth during rollout. Defaults to `5`.
- rollout_start_depth: The depth at which rollout starts; nodes below this depth only undergo expand operations. Defaults to `3`.
- max_iterations: The maximum number of MCTS iterations. Defaults to `100`.
- process_reward_rate: The proportion of the process reward used when computing value during select. Defaults to `0.0`, i.e. the PRM is not used.
- exploration_rate: The exploration parameter of the UCT algorithm; larger values favor nodes with fewer visits. Defaults to `0.5`.
- api_key: Required when using client as the inference engine. Defaults to `EMPTY`.
- base_url: Required when using client as the inference engine. Defaults to 'https://dashscope.aliyuncs.com/compatible-mode/v1'


## Specific Model Arguments
In addition to the arguments above, some models support extra model-specific arguments. Their meaning can usually be found in the corresponding model's official repo or its inference code. **ms-swift introduces these arguments to ensure that the trained model aligns with the official inference code.**
- Specific model arguments can be set via `--model_kwargs` or environment variables, e.g. `--model_kwargs '{"fps_max_frames": 12}'` or `FPS_MAX_FRAMES=12`.
12 changes: 1 addition & 11 deletions docs/source/Instruction/强化微调.md
@@ -66,11 +66,7 @@ DeepSeek-R1 used the GRPO algorithm to make CoT capability emerge in the base model from scratch; this method requires

SWIFT supports the sample command, which is used for model sampling. The currently supported sampling methods are:

- do_sample: samples the model with the sample method; supports sampling open-source models, and model distillation will be supported later
  - the sample method will later also support URL sampling for large-model distillation

- mcts: Monte Carlo tree sampling; currently in a PR and will be supported later
- dvts: under investigation
- sample: samples the model via generate

We currently provide a fairly general [RFT script](https://github.com/modelscope/ms-swift/tree/main/examples/train/rft/rft.py). The script is suited to self-improvement style training, supports dynamically adjusting hyperparameters such as the sampling temperature and the PRM threshold, and allows flexible training setups (fine-tuning, DPO, etc.; each iteration can retrain the original model, continue training the model from the previous iteration, or even load all training states of the previous iteration). Developers can add further data filtering in this script (in the generated dataset, rows with the same id come from the same query), e.g. diversity checks or language checks.

@@ -95,9 +91,3 @@ SWIFT supports the sample command, which is used for model sampling. The currently supported
| Qwen2.5_math_7b_instruct | 92.8 | 91.6 |

As can be seen, the GSM8K metric changes little after RFT training, and the previously mentioned performance drop does not occur.

## Future Plans

1. More sampling methods, such as MCTS
2. Distillation training from very large models
3. On-policy training, primarily PPO
11 changes: 1 addition & 10 deletions docs/source_en/Instruction/Command-line-parameters.md
@@ -697,7 +697,7 @@ Export Arguments include the [basic arguments](#base-arguments) and [merge argum

- prm_model: The type of process reward model. It can be a model ID (launched via `pt`) or a `prm` key defined in a plugin (for custom inference processes).
- orm_model: The type of outcome reward model, typically a wildcard or test case, usually defined in a plugin.
- sampler_type: The type of sampling. Currently supports `sample` (using `do_sample` method). Future support will include `mcts` and `dvts`.
- sampler_type: The type of sampling. Currently supports `sample` and `distill`.
- sampler_engine: Supports `pt`, `lmdeploy`, `vllm`, `no`. Defaults to `pt`. Specifies the inference engine for the sampling model.
- output_dir: The output directory. Defaults to `sample_output`.
- output_file: The name of the output file. Defaults to `None`, which uses a timestamp as the filename. When provided, only the filename should be passed without the directory, and only JSONL format is supported.
@@ -714,15 +714,6 @@ Export Arguments include the [basic arguments](#base-arguments) and [merge argum
- cache_files: To avoid loading both `prm` and `generator` simultaneously and causing GPU memory OOM, sampling can be done in two steps. In the first step, set `prm` and `orm` to `None`, and all results will be output to a file. In the second run, set `sampler_engine` to `no` and pass `--cache_files` with the output file from the first sampling. This will use the results from the first run for `prm` and `orm` evaluation and output the final results.
- Note: When using `cache_files`, the `--dataset` still needs to be provided because the ID for `cache_files` is calculated using the MD5 of the original data. Both pieces of information need to be used together.
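
Taken together, the two-step flow might look like the following minimal sketch. The model and dataset names are placeholders, the PRM/ORM values simply mirror the example YAML added in this PR, and the flags are the ones documented above:

```shell
# Step 1: sample only -- prm/orm left unset, raw completions are written to a file
swift sample \
  --model Qwen/Qwen2.5-7B-Instruct \
  --dataset my_math_data.jsonl \
  --sampler_engine vllm \
  --output_file step1_samples.jsonl

# Step 2: score the cached samples -- no generator is loaded
swift sample \
  --model Qwen/Qwen2.5-7B-Instruct \
  --dataset my_math_data.jsonl \
  --sampler_engine no \
  --cache_files step1_samples.jsonl \
  --prm_model Qwen/Qwen2.5-Math-PRM-7B \
  --orm_model math \
  --output_file final_samples.jsonl
```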

#### MCTS
- rollout_depth: The maximum depth during rollouts, default is `5`.
- rollout_start_depth: The depth at which rollouts begin; nodes below this depth will only undergo expand operations, default is `3`.
- max_iterations: The maximum number of iterations for MCTS, default is `100`.
- process_reward_rate: The proportion of process reward used in calculating value during selection, default is `0.0`, meaning PRM is not used.
- exploration_rate: A parameter in the UCT algorithm that balances exploration; a higher value gives more weight to nodes with fewer explorations, default is `0.5`.
- api_key: Required when using the client as an inference engine, default is `EMPTY`.
- base_url: Required when using the client as an inference engine, default is 'https://dashscope.aliyuncs.com/compatible-mode/v1'.

## Specific Model Arguments

In addition to the parameters listed above, some models support additional model-specific arguments. The meanings of these parameters can usually be found in the corresponding model's official repository or its inference code. **MS-Swift includes these parameters to ensure that the trained model aligns with the behavior of the official inference implementation**.
11 changes: 1 addition & 10 deletions docs/source_en/Instruction/Reinforced-Fine-tuning.md
@@ -66,11 +66,7 @@ Reinforced fine-tuning heavily depends on the accuracy of reward evaluations. If

SWIFT supports the `sample` command, which is used for model sampling. Currently supported sampling methods include:

- **do_sample**: A sampling method for open-source models; future updates will include support for model distillation.
- URL sampling will also be supported in the future for large-model distillation.

- **mcts**: Monte Carlo sampling, currently under review, with future support planned.
- **dvts**: Currently under investigation.
- **sample**: Samples from the model via `generate`.

We have provided a general [RFT script](https://github.com/modelscope/ms-swift/tree/main/examples/train/rft/rft.py). This script supports self-improvement training and allows dynamic adjustments of sampling temperature, PRM thresholds, and other hyperparameters. The training method is flexible (e.g., fine-tuning, DPO) and supports iterative retraining of the original model or continued training from the previous iteration, even loading all training states from the previous iteration. Developers can incorporate additional data filtering (e.g., ensuring rows with the same ID come from the same query), including diversity checks, language filtering, etc.
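
As a rough illustration of one self-improvement iteration in that flow (the actual loop, filtering, and hyperparameter schedule live in rft.py; the model and dataset names below are placeholders, the sampling flags mirror the sampler arguments documented in this PR, and the `swift sft` step is only an assumed minimal invocation):

```shell
# One illustrative iteration: sample, keep only answers passing the PRM threshold, then fine-tune.
swift sample \
  --model ./checkpoint_prev_iter \
  --dataset train_queries.jsonl \
  --temperature 1.0 \
  --prm_model Qwen/Qwen2.5-Math-PRM-7B \
  --prm_threshold 0.8 \
  --output_file rft_samples.jsonl

swift sft \
  --model ./checkpoint_prev_iter \
  --dataset rft_samples.jsonl \
  --output_dir ./checkpoint_next_iter
```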

@@ -96,8 +92,3 @@ Specifically, we tested the GSM8K metric for `Qwen2.5_math_7b_instruct`:

As shown, RFT training did not significantly change the GSM8K score, avoiding the previously mentioned performance degradation phenomenon.

## Future Roadmap

1. More sampling methods, e.g. MCTS
2. Distillation from very large models
3. On-policy RFT, such as PPO
116 changes: 0 additions & 116 deletions examples/sampler/mcts/mcts.py

This file was deleted.

35 changes: 0 additions & 35 deletions examples/sampler/mcts/mcts.sh

This file was deleted.

7 changes: 0 additions & 7 deletions examples/sampler/mcts/system_prompt.txt

This file was deleted.

1 change: 1 addition & 0 deletions examples/sampler/ray/sample.sh
@@ -0,0 +1 @@
swift sample --config sampling.yaml
35 changes: 35 additions & 0 deletions examples/sampler/ray/sampling.yaml
@@ -0,0 +1,35 @@
ray_exp_name: sampling

use_ray: true

model: Qwen/Qwen2.5-VL-3B-Instruct
dataset: tastelikefeet/competition_math#16
num_return_sequences: 2
max_length: 2048
system: "You are a math model, you should **think step by step** carefully, and always consider the basic math principles to avoid making calculating mistakes. Give the final answer wrapped with \\boxed{{}}"
load_args: false
sampler_engine: vllm
max_new_tokens: 768
orm_model: math
prm_model: Qwen/Qwen2.5-Math-PRM-7B
override_exist_file: true
num_sampling_per_gpu_batch_size: 4
top_p: 1.0
temperature: 1.0
prm_threshold: 0.8
output_file: sampling.jsonl
engine_kwargs: "{\"mm_processor_cache_gb\":0.0}"

device_groups:
nproc_per_node: 4
sample_group:
device: GPU
ranks: list(range(0, 2))
workers:
- sampler
rm_group:
device: GPU
ranks: list(range(2, 4))
workers:
- prm
- orm
1 change: 1 addition & 0 deletions requirements/framework.txt
@@ -27,6 +27,7 @@ requests
rouge
safetensors
scipy
omegaconf
sentencepiece
simplejson>=3.3.0
sortedcontainers>=1.5.9
1 change: 1 addition & 0 deletions requirements/ray.txt
@@ -0,0 +1 @@
ray
2 changes: 2 additions & 0 deletions setup.py
@@ -121,9 +121,11 @@ def gen_packages_items():
all_requires = []
extra_requires['eval'], _ = parse_requirements('requirements/eval.txt')
extra_requires['swanlab'], _ = parse_requirements('requirements/swanlab.txt')
extra_requires['ray'], _ = parse_requirements('requirements/ray.txt')
all_requires.extend(install_requires)
all_requires.extend(extra_requires['eval'])
all_requires.extend(extra_requires['swanlab'])
all_requires.extend(extra_requires['ray'])
extra_requires['all'] = all_requires

setup(
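
With the extra registered above, the Ray dependencies could be pulled in as shown below, assuming the published package name is ms-swift (otherwise install from a source checkout):

```shell
# assuming the PyPI package name is ms-swift
pip install 'ms-swift[ray]'

# or from a source checkout of this repository
pip install -e '.[ray]'
```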
44 changes: 43 additions & 1 deletion swift/cli/main.py
@@ -3,7 +3,8 @@
import os
import subprocess
import sys
from typing import Dict, List, Optional
import json
from typing import Dict, List, Optional, Any

from swift.utils import get_logger

@@ -45,6 +46,44 @@ def get_torchrun_args() -> Optional[List[str]]:
    return torchrun_args


def prepare_config_args(argv):
    for i in range(0, len(argv[1:]), 2):
        arg_name = argv[i]
        arg_value = argv[i + 1]
        if arg_name == '--config':
Comment on lines +50 to +53

Contributor

critical

This loop iterates through argv[1:], but accesses elements using argv[i] and argv[i + 1]. This will cause an IndexError when i is the last index in the loop, as argv[i + 1] will be out of bounds. The loop should iterate through the indices of argv directly.

To fix this, iterate through range(1, len(argv), 2) and adjust the indexing accordingly.

Suggested change
    for i in range(0, len(argv[1:]), 2):
        arg_name = argv[i]
        arg_value = argv[i + 1]
        if arg_name == '--config':
    for i in range(1, len(argv), 2):
        arg_name = argv[i]
        if i + 1 < len(argv):
            arg_value = argv[i + 1]
        else:
            break  # Handle the case where there is no value for the last argument

            from omegaconf import OmegaConf, DictConfig
            from swift.ray import RayHelper
            config = OmegaConf.load(arg_value)

            def parse_dict_config(cfg: DictConfig) -> Dict[str, Any]:
                result = {}
                def _traverse(config: Any, parent_key: str = ""):
                    if isinstance(config, DictConfig):
                        for key, value in config.items():
                            if key == 'device_groups':
                                result[key] = json.dumps(OmegaConf.to_container(value))
                            else:
                                current_path = f"{parent_key}.{key}" if parent_key else key
                                _traverse(value, current_path)
                    else:
                        last_key = parent_key.split('.')[-1] if parent_key else ""
                        result[last_key] = config

                _traverse(cfg)
                return result

            cfg = parse_dict_config(config)
            for key, value in cfg.items():
                argv.append(f'--{key}')
                if not isinstance(value, str):
                    value = str(value)
                argv.append(value)

            argv.pop(i)
            argv.pop(i)
            break


def _compat_web_ui(argv):
    # [compat]
    method_name = argv[0]
@@ -56,11 +95,14 @@ def _compat_web_ui(argv):
def cli_main(route_mapping: Optional[Dict[str, str]] = None) -> None:
    route_mapping = route_mapping or ROUTE_MAPPING
    argv = sys.argv[1:]
    if 'local-rank' in argv[0]:
        argv = argv[1:]
    _compat_web_ui(argv)
    method_name = argv[0].replace('_', '-')
    argv = argv[1:]
    file_path = importlib.util.find_spec(route_mapping[method_name]).origin
    torchrun_args = get_torchrun_args()
    prepare_config_args(argv)
    python_cmd = sys.executable
    if torchrun_args is None or method_name not in {'pt', 'sft', 'rlhf', 'infer'}:
        args = [python_cmd, file_path, *argv]