Skip to content

Commit d2bcde2

Browse files
authored
[train][inference] qwen_gr00t supports ascend and musa platform (flagos-ai#1178)
### PR Category <!-- One of [ Train | Inference | Compress | Serve | RL | Core | Hardware | CICD | Tools | Others ] --> Train | Inference ### PR Types <!-- One of [ User Experience | New Features | Bug Fixes | Improvements | Performance | Breaking Change| Deprecations | Test Case | Docs | Others ] --> Others ### PR Description <!-- Describe what you’ve done --> 1. Adapt qwen_gr00t training/inference to the Huawei Ascend and Moore Threads MUSA platforms
1 parent 006d645 commit d2bcde2

File tree

5 files changed

+19
-13
lines changed

5 files changed

+19
-13
lines changed

examples/qwen_gr00t/README.md

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ Install FlagScale and training dependencies:
2424

2525
```sh
2626
cd FlagScale/
27+
# "[cuda-train]" is for NVIDIA GPUs; replace with "[ascend-train]" on Huawei Ascend, or "[musa-train]" on Moore Threads MUSA
2728
pip install ".[cuda-train]" --verbose
2829
```
2930

@@ -116,10 +117,12 @@ vim examples/qwen_gr00t/conf/train.yaml
116117

117118
Configure the following fields:
118119

119-
- `experiment.envs.CUDA_VISIBLE_DEVICES` - GPU devices to use (default: `"0,1,2,3,4,5,6,7"` for 8 GPUs)
120-
- `experiment.envs.CUDA_DEVICE_MAX_CONNECTIONS` - Connection limit (typically `1`)
120+
- `experiment.envs.CUDA_VISIBLE_DEVICES` - GPU devices to use (e.g., `"0,1,2,3"` for 4 GPUs). Use `ASCEND_RT_VISIBLE_DEVICES` for Huawei Ascend, `MUSA_VISIBLE_DEVICES` for Moore Threads MUSA
121+
- `experiment.envs.CUDA_DEVICE_MAX_CONNECTIONS` - Connection limit (typically `1`). Use `MUSA_DEVICE_MAX_CONNECTIONS` for Moore Threads MUSA
122+
- `experiment.envs.MUSA_LAUNCH_BLOCKING` - Set to `"1"` on Moore Threads MUSA to enable synchronous kernel execution, useful for debugging
121123
- `experiment.exp_name` - Experiment name
122124
- `experiment.exp_dir` - Output directory for checkpoints and logs
125+
- `experiment.runner.nproc_per_node` - Number of processes per node for multi-GPU training (required for Huawei Ascend)
123126

124127
#### Task-Level Config
125128

@@ -199,7 +202,7 @@ model:
199202
- `data.vla_data.obs` - Observation image keys (default: `["image_0"]`)
200203
- `data.observation_delta_indices` - Observation delta indices (default: `[0]`)
201204
- `data.action_delta_indices` - Action delta indices (default: `[0,1,2,3,4,5,6,7]`)
202-
- `data.preprocessor` - Preprocessor pipeline configuration
205+
- `data.preprocessor` - Preprocessor pipeline configuration. For Moore Threads MUSA, set `device_processor.config.device` to `"musa"`; for Huawei Ascend, set it to `"npu"`.
203206
- `data.postprocessor` - Postprocessor pipeline configuration
204207

205208
### Start Training
@@ -250,7 +253,7 @@ Configure the following fields:
250253
**Engine settings:**
251254
- `engine.model_variant` - Model variant (default: `"QwenGr00t"`)
252255
- `engine.model` - Path to trained checkpoint (e.g., `/workspace/outputs/qwen_gr00t_train/checkpoints/last`)
253-
- `engine.device` - Device to use (e.g., `"cuda"`)
256+
- `engine.device` - Device to use (e.g., `"cuda"`, `"musa"`, or `"npu"`)
254257

255258
**Generate settings:**
256259
- `generate.images` - Dictionary mapping image keys to file paths:

flagscale/inference/inference_qwen_gr00t.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from flagscale.logger import logger
99
from flagscale.models.utils.constants import OBS_STATE
1010
from flagscale.models.vla import TrainablePolicy
11+
from flagscale.platforms import get_platform # noqa: F401 must be before model imports
1112
from flagscale.train.processor import PolicyProcessorPipeline
1213

1314

flagscale/models/vla/base_policy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def from_pretrained(cls, pretrained_path, device="cpu", *, config=None):
157157
missing, unexpected = load_model(
158158
model,
159159
str(weights_path),
160-
device=device,
160+
device="cpu" if str(device) == "musa" else device,
161161
strict=False,
162162
)
163163
if missing:

flagscale/models/vla/qwen_gr00t/modeling_qwen_gr00t.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from flagscale.models.vla.base_policy import TrainablePolicy
3232
from flagscale.models.vla.registry import build_action_model, build_vlm
3333
from flagscale.models.vla.utils import get_vlm_config
34+
from flagscale.platforms.platform_manager import get_platform
3435

3536

3637
class QwenGr00t(TrainablePolicy):
@@ -91,7 +92,7 @@ def forward(
9192
qwen_inputs = self.vlm.build_qwenvl_inputs(images, instructions)
9293

9394
# TODO: (yupu) Hard-coded autocast and dtype, matches starVLA
94-
with torch.autocast("cuda", dtype=torch.bfloat16):
95+
with torch.autocast(get_platform().amp_device_type(), dtype=torch.bfloat16):
9596
vlm_output = self.vlm.forward(qwen_inputs, output_attentions=False)
9697
# last_hidden_state: [B, seq_len, H]
9798
last_hidden = vlm_output["hidden_states"][-1] # [B, L, H]
@@ -122,7 +123,7 @@ def forward(
122123
padded_actions.append(final_a)
123124
action_masks.append(mask)
124125

125-
with torch.autocast("cuda", dtype=torch.float32):
126+
with torch.autocast(get_platform().amp_device_type(), dtype=torch.float32):
126127
# TODO: (yupu) Is this a bug or a feature? The action dtype would stay as bf16 under this autocast.
127128
actions = torch.stack(padded_actions).to(
128129
device=last_hidden.device, dtype=last_hidden.dtype
@@ -156,7 +157,7 @@ def forward(
156157
result = {"loss": output["loss"]}
157158

158159
if vlm_batch is not None:
159-
with torch.autocast("cuda", dtype=torch.bfloat16):
160+
with torch.autocast(get_platform().amp_device_type(), dtype=torch.bfloat16):
160161
vlm_loss = self.vlm.model(**vlm_batch, return_dict=True).loss
161162
result["vlm_loss"] = vlm_loss
162163

@@ -194,7 +195,7 @@ def predict_action(self, batch: list[dict] | dict) -> dict:
194195

195196
qwen_inputs = self.vlm.build_qwenvl_inputs(images, instructions)
196197

197-
with torch.autocast("cuda", dtype=torch.bfloat16):
198+
with torch.autocast(get_platform().amp_device_type(), dtype=torch.bfloat16):
198199
vlm_output = self.vlm.forward(qwen_inputs, output_attentions=False)
199200
# last_hidden_state: [B, seq_len, H]
200201
last_hidden = vlm_output["hidden_states"][-1] # [B, L, H]
@@ -207,7 +208,7 @@ def predict_action(self, batch: list[dict] | dict) -> dict:
207208
state = state.to(device=last_hidden.device, dtype=last_hidden.dtype)
208209

209210
# Step 4: Action Expert Forward
210-
with torch.autocast("cuda", dtype=torch.float32):
211+
with torch.autocast(get_platform().amp_device_type(), dtype=torch.float32):
211212
vlm_output_for_action = {"hidden_states": last_hidden}
212213
action_input = {"state": state}
213214
output = self.action_model.predict_action(vlm_output_for_action, action_input)

flagscale/models/vla/vlm/qwenvl_backbone.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from flagscale.logger import logger
2222
from flagscale.models.vla.registry import register_vlm
23+
from flagscale.platforms.platform_manager import get_platform
2324

2425

2526
@dataclass
@@ -142,7 +143,7 @@ def forward(self, batch: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.T
142143
f"[VLM.forward] input keys={list(batch.keys())} "
143144
+ " ".join(f"{k}={v.shape}" for k, v in batch.items() if isinstance(v, torch.Tensor))
144145
)
145-
with torch.autocast("cuda", dtype=torch.bfloat16):
146+
with torch.autocast(get_platform().amp_device_type(), dtype=torch.bfloat16):
146147
outputs = self.model(
147148
**batch,
148149
output_hidden_states=True,
@@ -205,7 +206,7 @@ def build_qwenvl_inputs(
205206

206207
# Use current CUDA device instead of self.model.device, which returns
207208
# a DTensor device under FSDP2 and causes mixed Tensor/DTensor errors.
208-
return batch_input.to(f"cuda:{torch.cuda.current_device()}")
209+
return batch_input.to(get_platform().device())
209210

210211

211212
@register_vlm("qwen3-vl")
@@ -253,4 +254,4 @@ def build_qwenvl_inputs(
253254

254255
# Use current CUDA device instead of self.model.device, which returns
255256
# a DTensor device under FSDP2 and causes mixed Tensor/DTensor errors.
256-
return batch_inputs.to(f"cuda:{torch.cuda.current_device()}")
257+
return batch_inputs.to(get_platform().device())

0 commit comments

Comments
 (0)