Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def update_weights(self) -> None:
if (
not skip_base_sync
and self.quantization_config
and self.quantization_config["quant_method"] in ["compressed-tensors"]
and self.quantization_config["quant_method"] in ["compressed-tensors", "fp8"]
):
post_process_weights(
rollout_engines=self.rollout_engines,
Expand Down
37 changes: 36 additions & 1 deletion scripts/run_qwen3_30b_a3b.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class ScriptArgs(U.ExecuteTrainConfig):
model_name: str = "Qwen3-30B-A3B"
megatron_model_type: str = "qwen3-30B-A3B"
num_gpus_per_node: int | None = None
hardware: Literal["H100", "B200", "B300", "GB200", "GB300"] = "H100"
hardware: Literal["H100", "B200", "B300", "GB200", "GB300", "MI350X", "MI355X"] = "H100"
enable_eval: bool = True
extra_args: str = ""
data_dir: str = "/root/datasets"
Expand Down Expand Up @@ -205,6 +205,23 @@ def execute(args: ScriptArgs):
misc_env_vars |= {
"NVTE_FP8_BLOCK_SCALING_FP32_SCALES": "1",
}
case "MI350X" | "MI355X":
# ROCm gfx950: blockwise FP8 via ported Triton kernels.
# ROCm has no wgrad fusion yet, so turn off gradient-accumulation-fusion.
misc_args += (
"--transformer-impl transformer_engine "
"--bf16 "
"--fp8-format e4m3 "
"--fp8-recipe blockwise "
"--no-gradient-accumulation-fusion "
)
misc_env_vars |= {
"NVTE_FP8_BLOCK_SCALING_FP32_SCALES": "1",
"NVTE_ROCM_ENABLE_FP8_BLOCK_SCALING": "1",
# keep Ray from blanking HIP/CUDA visibility for the job entrypoint
"RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES": "1",
"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1",
}

if args.enable_megatron_bridge:
misc_args += "--megatron-to-hf-mode bridge "
Expand Down Expand Up @@ -268,6 +285,24 @@ def execute(args: ScriptArgs):
)
else:
sglang_args += "--rollout-num-gpus-per-engine 4 " "--sglang-cuda-graph-max-bs 512 "
case ("MI350X" | "MI355X", 1):
perf_args += (
"--tensor-model-parallel-size 1 "
"--sequence-parallel "
"--pipeline-model-parallel-size 2 "
"--context-parallel-size 2 "
"--expert-model-parallel-size 4 "
"--expert-tensor-parallel-size 1 "
"--max-tokens-per-gpu 16384 "
)
sglang_args = (
"--rollout-num-gpus-per-engine 2 "
"--sglang-mem-fraction-static 0.7 "
"--sglang-max-running-requests 512 "
)
optimizer_args += (
"--optimizer-cpu-offload " "--overlap-cpu-optimizer-d2h-h2d " "--use-precision-aware-optimizer "
)
case _:
raise NotImplementedError

Expand Down
Loading