
[examples] Add NPU-adapted GSM8K on-policy distillation launcher#5837

Open
duesdues wants to merge 1 commit into verl-project:main from duesdues:opd-gsm8k-npu-on-main

Conversation


@duesdues duesdues commented Apr 1, 2026

What does this PR do?

Adds examples/on_policy_distillation_trainer/run_qwen_gsm8k_npu.sh to run the same on-policy distillation GSM8K example on Ascend (CANN / ATB + vLLM). Training hyperparameters and algorithm flags match the existing example; changes are limited to environment setup and paths appropriate for a typical local NPU deployment.

Follow-up to #5592.

Files

  • run_qwen_gsm8k_npu.sh (new): Ascend-focused launcher.

Local paths

  • DATA_PATH: defaults to $PWD/../verlData — set it if your tree is elsewhere.
  • Weights: student/teacher use $DATA_PATH/weights/$FAMILY/<model> (on-disk layout), not bare Hub ids like the GPU script.
  • GSM8K data: $DATA_PATH/gsm8k/{train,test}.parquet — you provide the files (script only ensures dirs exist).
  • Caches: HF_HOME and VLLM_CACHE_DIR live under DATA_PATH; vLLM cache dir is created if missing.
  • Ascend set_env.sh scripts are sourced from the usual /usr/local/Ascend/... locations; adjust if your install differs.
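A minimal sketch of this layout (illustrative only; DATA_PATH, HF_HOME, and VLLM_CACHE_DIR are named above, but FAMILY and the cache subdirectory names are assumptions, not the script's actual values):

```shell
#!/usr/bin/env bash
# Sketch of the local path layout described above.
DATA_PATH="${DATA_PATH:-$PWD/../verlData}"      # override if your tree is elsewhere
FAMILY="Qwen"                                   # assumed family directory name
STUDENT_PATH="$DATA_PATH/weights/$FAMILY/Qwen2.5-0.5B"
TEACHER_PATH="$DATA_PATH/weights/$FAMILY/Qwen2.5-3B-Instruct"
export HF_HOME="$DATA_PATH/hf_home"             # assumed cache subdir name
export VLLM_CACHE_DIR="$DATA_PATH/vllm_cache"   # assumed cache subdir name
# The script only ensures directories exist; the parquet files are user-supplied.
mkdir -p "$DATA_PATH/gsm8k" "$VLLM_CACHE_DIR"
```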

Other script-level behavior (NPU)

  • Devices: ASCEND_RT_VISIBLE_DEVICES / NPU_VISIBLE_DEVICES (defaults can be overridden in the environment).
  • vLLM on NPU: VLLM_USE_V1 defaults to 1 where the image requires it.
  • Stability defaults: enforce_eager=True for rollout and teacher; actor_rollout_ref.actor.use_torch_compile=False.
  • Logging: trainer.logger is console-only to avoid a hard wandb dependency on minimal images.
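The defaults above can be sketched as overridable shell variables (a sketch only; the device list "0,1,2,3,4,5" and the variable names ENFORCE_EAGER, USE_TORCH_COMPILE, and LOGGER are assumptions for illustration):

```shell
#!/usr/bin/env bash
# Overridable NPU device and stability defaults, per the list above.
export ASCEND_RT_VISIBLE_DEVICES="${ASCEND_RT_VISIBLE_DEVICES:-0,1,2,3,4,5}"
export NPU_VISIBLE_DEVICES="${NPU_VISIBLE_DEVICES:-$ASCEND_RT_VISIBLE_DEVICES}"
export VLLM_USE_V1="${VLLM_USE_V1:-1}"  # required by some NPU images
ENFORCE_EAGER=True                      # rollout and teacher inference stay eager
USE_TORCH_COMPILE=False                 # maps to actor_rollout_ref.actor.use_torch_compile
LOGGER="['console']"                    # console-only: no hard wandb dependency
```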

Related

Intended to stack on the on-policy distillation example line (see #5592); happy to rebase or retarget if maintainers prefer main or another base.

Test

  • Run bash run_qwen_gsm8k_npu.sh on an NPU host, with weights under $DATA_PATH/weights/... and the GSM8K parquet files present under $DATA_PATH/gsm8k/.

Experiment

LLM on-policy distillation on GSM8K, following the PR #5592 example recipe and default script settings: k1 KL estimator with policy gradient updates (loss_mode=k1, use_policy_gradient=True). Remaining hyperparameters match the shipped example (e.g. top‑k / clamps / fused kernels as in the launcher).
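For reference, the k1 estimator is conventionally the per-sample log-probability ratio under student samples (stated here from the common KL-estimator convention; the PR itself does not restate the formula):

```latex
% k1 estimator of KL(student || teacher), with x drawn from the student policy
k_1(x) = \log \frac{\pi_\theta(x)}{\pi_{\mathrm{teacher}}(x)},
\qquad
\mathbb{E}_{x \sim \pi_\theta}\!\left[k_1(x)\right]
  = \mathrm{KL}\!\left(\pi_\theta \,\|\, \pi_{\mathrm{teacher}}\right)
```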

  • Student: Qwen/Qwen2.5-0.5B
  • Teacher: Qwen/Qwen2.5-3B-Instruct
  • Dataset: GSM8K — train.parquet (train), test.parquet (val); paths as in the example ($DATA_PATH/gsm8k/… for the NPU launcher, same layout as the GPU script).

Metrics

Plots / tables for the following (attach screenshots below):

  • actor/distillation/loss
  • val-core/openai/gsm8k/acc/mean@1
  • critic/score/mean
  • response_length/mean

Supersedes #5830

- Source CANN/ATB set_env; device visibility; VLLM_USE_V1
- Local weights under $DATA_PATH/weights; prepare gsm8k and vLLM cache dirs
- Eager rollout/teacher inference; actor without torch.compile; console logger
@CLAassistant

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.
You have signed the CLA already but the status is still pending? Let us recheck it.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new bash script for on-policy distillation training of Qwen models on NPU hardware. The review feedback identifies a critical scheduling deadlock in Ray caused by incorrect resource pool settings on a 6-GPU configuration. Additionally, the reviewer pointed out performance inefficiencies in the vLLM and rollout configurations where the maximum number of batched tokens was set too low, preventing parallel processing of sequences.


STUDENT_WORLD_SIZE=4

TEACHER_RESOURCE_POOL=False
Contributor


Severity: high

Setting TEACHER_RESOURCE_POOL=False while STUDENT_WORLD_SIZE=4 and TEACHER_WORLD_SIZE=2 on a 6-GPU setup will cause the script to hang. When enable_resource_pool is False, the teacher model is mapped to the global_pool, which is initialized with a size of trainer.n_gpus_per_node (set to 4 here). Since the student already utilizes all 4 GPUs in the global_pool, there are no resources left for the teacher manager actors to request, leading to a scheduling deadlock in Ray. Set this to True to allow the teacher to use the remaining 2 GPUs in a separate pool.

Suggested change
TEACHER_RESOURCE_POOL=False
TEACHER_RESOURCE_POOL=True
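The device accounting behind this suggestion can be sketched as follows (illustrative only; the world-size values come from the comment above, and the variable names mirror the quoted script):

```shell
#!/usr/bin/env bash
# With enable_resource_pool=False, the teacher is mapped into global_pool
# (size n_gpus_per_node=4), which the student already fills -> Ray deadlock.
STUDENT_WORLD_SIZE=4         # fills the 4-device global_pool
TEACHER_WORLD_SIZE=2         # needs its own pool on the remaining 2 devices
TEACHER_RESOURCE_POOL=True   # carve out a dedicated teacher pool
TOTAL_DEVICES=$(( STUDENT_WORLD_SIZE + TEACHER_WORLD_SIZE ))
echo "$TOTAL_DEVICES"        # prints 6, matching the 6-GPU node
```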

distillation.teacher_model.inference.gpu_memory_utilization=0.3
distillation.teacher_model.inference.enforce_eager=$ENFORCE_EAGER
distillation.teacher_model.inference.max_model_len=$MAX_NUM_TOKENS
distillation.teacher_model.inference.max_num_batched_tokens=$MAX_NUM_TOKENS
Contributor


Severity: high

Setting max_num_batched_tokens to MAX_NUM_TOKENS (769) is extremely inefficient for vLLM. This parameter controls the maximum number of tokens processed in a single forward pass. With a training batch size of 128, setting this to the length of a single sequence (769) forces vLLM to process sequences one by one, significantly degrading throughput. It should be set to a much larger value (e.g., 8192 or TRAIN_PROMPT_BSZ * MAX_NUM_TOKENS) to allow parallel processing of multiple sequences. Additionally, ensure the Hydra override uses the + prefix as per repository standards.

Suggested change
distillation.teacher_model.inference.max_num_batched_tokens=$MAX_NUM_TOKENS
+distillation.teacher_model.inference.max_num_batched_tokens=8192
References
  1. Use + instead of ++ as the prefix for overriding configuration values in Hydra.
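The batch-wide budget the reviewer proposes can be derived rather than hard-coded, as a sketch using the values cited in the comment (TRAIN_PROMPT_BSZ and MAX_NUM_TOKENS are assumed shell variables from the launcher):

```shell
#!/usr/bin/env bash
# Budget tokens per forward pass across the whole batch, not per sequence.
TRAIN_PROMPT_BSZ=128
MAX_NUM_TOKENS=769
MAX_NUM_BATCHED_TOKENS=$(( TRAIN_PROMPT_BSZ * MAX_NUM_TOKENS ))
echo "$MAX_NUM_BATCHED_TOKENS"   # prints 98432
```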

actor_rollout_ref.rollout.gpu_memory_utilization=0.3
actor_rollout_ref.rollout.calculate_log_probs=False
actor_rollout_ref.rollout.max_model_len=$MAX_NUM_TOKENS
actor_rollout_ref.rollout.max_num_batched_tokens=$MAX_NUM_TOKENS
Contributor


Severity: high

Similar to the teacher model configuration, setting actor_rollout_ref.rollout.max_num_batched_tokens to 769 will severely limit rollout performance by preventing batching across sequences. Increase this to a standard value like 8192 to ensure efficient GPU/NPU utilization during generation. Ensure the Hydra override uses the + prefix.

Suggested change
actor_rollout_ref.rollout.max_num_batched_tokens=$MAX_NUM_TOKENS
+actor_rollout_ref.rollout.max_num_batched_tokens=8192
References
  1. Use + instead of ++ as the prefix for overriding configuration values in Hydra.

@wuxibin89
Collaborator

@wucong25 The existing training scripts should already switch between GPU and NPU fairly smoothly; going forward, should we avoid maintaining two separate sets of scripts?

#!/usr/bin/env bash

source /usr/local/Ascend/ascend-toolkit/set_env.sh
source /usr/local/Ascend/nnal/atb/set_env.sh
Collaborator


Have the accuracy curves been validated? And how do they compare against the GPU run?

TEACHER_MODEL=Qwen2.5-3B-Instruct

# USE_POLICY_GRADIENT=False
# DISTILLATION_LOSS_MODE="k3"
Collaborator


These dead commented-out lines can be removed.

