[examples] Add NPU-adapted GSM8K on-policy distillation launcher #5837

duesdues wants to merge 1 commit into verl-project:main
Conversation
- Source CANN/ATB set_env; device visibility; VLLM_USE_V1
- Local weights under $DATA_PATH/weights; prepare gsm8k and vLLM cache dirs
- Eager rollout/teacher inference; actor without torch.compile; console logger
Code Review
This pull request introduces a new bash script for on-policy distillation training of Qwen models on NPU hardware. The review feedback identifies a critical scheduling deadlock in Ray caused by incorrect resource pool settings on a 6-GPU configuration. Additionally, the reviewer pointed out performance inefficiencies in the vLLM and rollout configurations where the maximum number of batched tokens was set too low, preventing parallel processing of sequences.
```bash
STUDENT_WORLD_SIZE=4
# ...
TEACHER_RESOURCE_POOL=False
```
Setting TEACHER_RESOURCE_POOL=False while STUDENT_WORLD_SIZE=4 and TEACHER_WORLD_SIZE=2 on a 6-GPU setup will cause the script to hang. When enable_resource_pool is False, the teacher model is mapped to the global_pool, which is initialized with a size of trainer.n_gpus_per_node (set to 4 here). Since the student already utilizes all 4 GPUs in the global_pool, there are no resources left for the teacher manager actors to request, leading to a scheduling deadlock in Ray. Set this to True to allow the teacher to use the remaining 2 GPUs in a separate pool.
Suggested change:

```diff
- TEACHER_RESOURCE_POOL=False
+ TEACHER_RESOURCE_POOL=True
```
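To make the sizing constraint concrete, here is a hypothetical pre-launch check a launcher could run. `NUM_NPUS`, the pool arithmetic, and the check itself are assumptions derived from the review's description, not code from this PR:

```shell
# Hypothetical pre-launch sanity check. Variable names mirror the script;
# NUM_NPUS and the pool model are assumptions based on the review comment.
NUM_NPUS=6
STUDENT_WORLD_SIZE=4
TEACHER_WORLD_SIZE=2
TEACHER_RESOURCE_POOL=True

if [ "$TEACHER_RESOURCE_POOL" = "True" ]; then
  # Separate pools: the teacher gets its own devices beside the student's.
  POOL=$NUM_NPUS
else
  # Shared global pool sized to trainer.n_gpus_per_node: the teacher
  # competes with the student for the same 4 devices and starves.
  POOL=$STUDENT_WORLD_SIZE
fi

NEEDED=$((STUDENT_WORLD_SIZE + TEACHER_WORLD_SIZE))
if [ "$NEEDED" -gt "$POOL" ]; then
  echo "deadlock risk: need $NEEDED devices, pool holds $POOL" >&2
else
  echo "ok: need $NEEDED devices, pool holds $POOL"
fi
```

With `TEACHER_RESOURCE_POOL=True` the check passes (6 of 6 devices); flipping it to `False` reproduces the 6-vs-4 shortfall the reviewer describes.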
```bash
distillation.teacher_model.inference.gpu_memory_utilization=0.3
distillation.teacher_model.inference.enforce_eager=$ENFORCE_EAGER
distillation.teacher_model.inference.max_model_len=$MAX_NUM_TOKENS
distillation.teacher_model.inference.max_num_batched_tokens=$MAX_NUM_TOKENS
```
Setting max_num_batched_tokens to MAX_NUM_TOKENS (769) is extremely inefficient for vLLM. This parameter controls the maximum number of tokens processed in a single forward pass. With a training batch size of 128, setting this to the length of a single sequence (769) forces vLLM to process sequences one by one, significantly degrading throughput. It should be set to a much larger value (e.g., 8192 or TRAIN_PROMPT_BSZ * MAX_NUM_TOKENS) to allow parallel processing of multiple sequences. Additionally, ensure the Hydra override uses the + prefix as per repository standards.
Suggested change:

```diff
- distillation.teacher_model.inference.max_num_batched_tokens=$MAX_NUM_TOKENS
+ +distillation.teacher_model.inference.max_num_batched_tokens=8192
```
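The reviewer's sizing suggestion can be sketched as follows; `CAP` and the clamping step are assumptions added for illustration, not part of the PR:

```shell
# Sketch: derive max_num_batched_tokens from the batch rather than a single
# sequence. TRAIN_PROMPT_BSZ and MAX_NUM_TOKENS mirror the script's values;
# CAP is an assumed upper bound to keep KV-cache memory in check.
TRAIN_PROMPT_BSZ=128
MAX_NUM_TOKENS=769
CAP=32768

BATCHED=$((TRAIN_PROMPT_BSZ * MAX_NUM_TOKENS))  # 98432 for these values
if [ "$BATCHED" -gt "$CAP" ]; then
  BATCHED=$CAP
fi
echo "max_num_batched_tokens=$BATCHED"
```

Anything well above `MAX_NUM_TOKENS` (e.g. the flat 8192 from the suggestion) already lets vLLM pack several sequences into one forward pass.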
References
- Use `+` instead of `++` as the prefix for overriding configuration values in Hydra.
```bash
actor_rollout_ref.rollout.gpu_memory_utilization=0.3
actor_rollout_ref.rollout.calculate_log_probs=False
actor_rollout_ref.rollout.max_model_len=$MAX_NUM_TOKENS
actor_rollout_ref.rollout.max_num_batched_tokens=$MAX_NUM_TOKENS
```
Similar to the teacher model configuration, setting actor_rollout_ref.rollout.max_num_batched_tokens to 769 will severely limit rollout performance by preventing batching across sequences. Increase this to a standard value like 8192 to ensure efficient GPU/NPU utilization during generation. Ensure the Hydra override uses the + prefix.
Suggested change:

```diff
- actor_rollout_ref.rollout.max_num_batched_tokens=$MAX_NUM_TOKENS
+ +actor_rollout_ref.rollout.max_num_batched_tokens=8192
```
References
- Use `+` instead of `++` as the prefix for overriding configuration values in Hydra.
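For context on the prefix note: in Hydra's override grammar, a bare `key=value` overrides an existing key (and errors if it is absent), `+key=value` adds a key not present in the base config, and `++key=value` adds-or-overrides. A minimal illustration of assembling the override string from the suggestion:

```shell
# Building a Hydra-style override string. The leading '+' marks a key that
# is absent from the base config, per Hydra's override grammar.
KEY="actor_rollout_ref.rollout.max_num_batched_tokens"
VALUE=8192
OVERRIDE="+${KEY}=${VALUE}"
echo "$OVERRIDE"
```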
@wucong25 The existing training scripts should already switch between GPU and NPU fairly smoothly; going forward, should we try not to maintain two separate sets of scripts?
```bash
#!/usr/bin/env bash

source /usr/local/Ascend/ascend-toolkit/set_env.sh
source /usr/local/Ascend/nnal/atb/set_env.sh
```
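A guarded variant of this bootstrap could degrade gracefully on machines without a local Ascend install; this is a sketch, and the fallback device list and defaulting behavior are assumptions, not what the PR ships:

```shell
# Sketch: source the Ascend environment only if present, then default the
# device-visibility and vLLM flags. Paths and fallback values are assumptions.
for f in /usr/local/Ascend/ascend-toolkit/set_env.sh \
         /usr/local/Ascend/nnal/atb/set_env.sh; do
  if [ -f "$f" ]; then
    . "$f"
  fi
done
export ASCEND_RT_VISIBLE_DEVICES=${ASCEND_RT_VISIBLE_DEVICES:-0,1,2,3,4,5}
export NPU_VISIBLE_DEVICES=${NPU_VISIBLE_DEVICES:-$ASCEND_RT_VISIBLE_DEVICES}
export VLLM_USE_V1=${VLLM_USE_V1:-1}
```

The `${VAR:-default}` pattern keeps caller-supplied values intact, matching the PR's "defaults can be overridden in the environment" behavior.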
```bash
TEACHER_MODEL=Qwen2.5-3B-Instruct

# USE_POLICY_GRADIENT=False
# DISTILLATION_LOSS_MODE="k3"
```
What does this PR do?
Adds `examples/on_policy_distillation_trainer/run_qwen_gsm8k_npu.sh` to run the same on-policy distillation GSM8K example on Ascend (CANN / ATB + vLLM). Training hyperparameters and algorithm flags match the existing example; changes are limited to environment setup and paths appropriate for a typical local NPU deployment.

Follow-up to #5592.
Files
- `run_qwen_gsm8k_npu.sh` (new): Ascend-focused launcher.

Local paths
- `DATA_PATH`: defaults to `$PWD/../verlData`; set it if your tree is elsewhere.
- Model weights are read from `$DATA_PATH/weights/$FAMILY/<model>` (on-disk layout), not bare Hub ids like the GPU script.
- GSM8K data lives at `$DATA_PATH/gsm8k/{train,test}.parquet`; you provide the files (the script only ensures the dirs exist).
- `HF_HOME` and `VLLM_CACHE_DIR` live under `DATA_PATH`; the vLLM cache dir is created if missing.

Ascend
- `set_env.sh` paths are the usual `/usr/local/Ascend/...`; fix if your install differs.

Other script-level behavior (NPU)
- `ASCEND_RT_VISIBLE_DEVICES` / `NPU_VISIBLE_DEVICES` (defaults can be overridden in the environment).
- `VLLM_USE_V1` defaults to `1` where required by the image.
- `enforce_eager=True` for rollout and teacher; `actor_rollout_ref.actor.use_torch_compile=False`.
- `trainer.logger` is console-only to avoid a hard wandb dependency on minimal images.

Related
- Intended to stack on the on-policy distillation example line (see #5592); happy to rebase or retarget if maintainers prefer `main` or another base.

Test
- `bash run_qwen_gsm8k_npu.sh` on NPU, with weights under `$DATA_PATH/weights/...` and GSM8K parquet files present under `$DATA_PATH/gsm8k/`.
bash run_qwen_gsm8k_npu.shon NPU, weights under$DATA_PATH/weights/..., and GSM8K parquet files present under$DATA_PATH/gsm8k/.Experiment
LLM on-policy distillation on GSM8K, following the PR #5592 example recipe and default script settings: k1 KL estimator with policy gradient updates (
loss_mode=k1,use_policy_gradient=True). Remaining hyperparameters match the shipped example (e.g. top‑k / clamps / fused kernels as in the launcher).Qwen2.5-0.5BQwen2.5-3B-Instructtrain.parquet(train),test.parquet(val), paths as in the example ($DATA_PATH/gsm8k/…for the NPU launcher; same layout as the GPU script).Metrics
Plots / tables for the following (attach screenshots below):
- `actor/distillation/loss`
- `val-core/openai/gsm8k/acc/mean@1`
- `critic/score/mean`
- `response_length/mean`

Supersedes #5830