diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py index dce9dc865c..34002d1f0e 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py @@ -195,7 +195,7 @@ def update_weights(self) -> None: if ( not skip_base_sync and self.quantization_config - and self.quantization_config["quant_method"] in ["compressed-tensors"] + and self.quantization_config["quant_method"] in ["compressed-tensors", "fp8"] ): post_process_weights( rollout_engines=self.rollout_engines, diff --git a/scripts/run_qwen3_30b_a3b.py b/scripts/run_qwen3_30b_a3b.py index cb9e225f7c..110ab4ee86 100644 --- a/scripts/run_qwen3_30b_a3b.py +++ b/scripts/run_qwen3_30b_a3b.py @@ -13,7 +13,7 @@ class ScriptArgs(U.ExecuteTrainConfig): model_name: str = "Qwen3-30B-A3B" megatron_model_type: str = "qwen3-30B-A3B" num_gpus_per_node: int | None = None - hardware: Literal["H100", "B200", "B300", "GB200", "GB300"] = "H100" + hardware: Literal["H100", "B200", "B300", "GB200", "GB300", "MI350X", "MI355X"] = "H100" enable_eval: bool = True extra_args: str = "" data_dir: str = "/root/datasets" @@ -205,6 +205,23 @@ def execute(args: ScriptArgs): misc_env_vars |= { "NVTE_FP8_BLOCK_SCALING_FP32_SCALES": "1", } + case "MI350X" | "MI355X": + # ROCm gfx950: blockwise FP8 via ported Triton kernels. + # ROCm has no wgrad fusion yet, so turn off gradient-accumulation-fusion. + misc_args += ( + "--transformer-impl transformer_engine " + "--bf16 " + "--fp8-format e4m3 " + "--fp8-recipe blockwise " + "--no-gradient-accumulation-fusion " + ) + misc_env_vars |= { + "NVTE_FP8_BLOCK_SCALING_FP32_SCALES": "1", + "NVTE_ROCM_ENABLE_FP8_BLOCK_SCALING": "1", + # keep Ray from blanking HIP/CUDA visibility for the job entrypoint + "RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES": "1", + "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1", + } if args.enable_megatron_bridge: misc_args += "--megatron-to-hf-mode bridge " @@ -268,6 +285,24 @@ def execute(args: ScriptArgs): ) else: sglang_args += "--rollout-num-gpus-per-engine 4 " "--sglang-cuda-graph-max-bs 512 " + case ("MI350X" | "MI355X", 1): + perf_args += ( + "--tensor-model-parallel-size 1 " + "--sequence-parallel " + "--pipeline-model-parallel-size 2 " + "--context-parallel-size 2 " + "--expert-model-parallel-size 4 " + "--expert-tensor-parallel-size 1 " + "--max-tokens-per-gpu 16384 " + ) + sglang_args = ( + "--rollout-num-gpus-per-engine 2 " + "--sglang-mem-fraction-static 0.7 " + "--sglang-max-running-requests 512 " + ) + optimizer_args += ( + "--optimizer-cpu-offload " "--overlap-cpu-optimizer-d2h-h2d " "--use-precision-aware-optimizer " + ) case _: raise NotImplementedError