diff --git a/docker/Dockerfile.rocm_base b/docker/Dockerfile.rocm_base index 195067b51a2a..1c31af83dbd2 100644 --- a/docker/Dockerfile.rocm_base +++ b/docker/Dockerfile.rocm_base @@ -9,7 +9,7 @@ ARG PYTORCH_AUDIO_BRANCH="v2.9.0" ARG PYTORCH_AUDIO_REPO="https://github.com/pytorch/audio.git" ARG FA_BRANCH="0e60e394" ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git" -ARG AITER_BRANCH="v0.1.13" +ARG AITER_BRANCH="v0.1.14" ARG AITER_REPO="https://github.com/ROCm/aiter.git" ARG MORI_BRANCH="v1.1.0" ARG MORI_REPO="https://github.com/ROCm/mori.git" diff --git a/vllm/model_executor/layers/mhc.py b/vllm/model_executor/layers/mhc.py index b720fa1f6fe2..61d05a7cba11 100644 --- a/vllm/model_executor/layers/mhc.py +++ b/vllm/model_executor/layers/mhc.py @@ -5,11 +5,20 @@ # this import will also register the custom ops # import vllm.model_executor.kernels.mhc # noqa: F401 import vllm.model_executor.kernels.mhc as mhc_kernels +from vllm._aiter_ops import is_aiter_found_and_supported from vllm.model_executor.custom_op import CustomOp from vllm.utils.import_utils import has_tilelang HAS_TILELANG = has_tilelang() +# mHC dispatch order on ROCm: aiter pre/post kernels (preferred, fastest) -> +# tilelang fused post+pre -> torch/triton reference. The aiter mHC kernels are +# the default whenever aiter is available on a supported ROCm device; they +# require aiter >= 0.1.14 (sqrsum race-condition fix in +# ``mhc_pre_gemm_sqrsum_kernel``, commit b639cb6) and a hidden size that is a +# multiple of 256, otherwise we fall back to the tilelang/reference paths. +HAS_AITER_MHC = is_aiter_found_and_supported() + # --8<-- [start:mhc_pre] @CustomOp.register("mhc_pre") @@ -71,24 +80,22 @@ def forward_hip( norm_weight: torch.Tensor | None = None, norm_eps: float = 0.0, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - # TODO: Reenable aiter after we are at the aiter - # version that has this bugfix - # https://github.com/ROCm/aiter/commit/b639cb63bcac4672dce33a731fad042a65cb3649 - # It has accuracy problem at large number of tokens. - # hidden_size = residual.shape[-1] - # if hidden_size % 256 == 0: - # return torch.ops.vllm.mhc_pre_aiter( - # residual, - # fn, - # hc_scale, - # hc_base, - # rms_eps, - # hc_pre_eps, - # hc_sinkhorn_eps, - # hc_post_mult_value, - # sinkhorn_repeat, - # ) - # else: + # The aiter mhc_pre kernel only supports hidden sizes that are a + # multiple of 256. Requires aiter >= 0.1.14 for correct results at + # large token counts (sqrsum race-condition fix, commit b639cb6). + hidden_size = residual.shape[-1] + if HAS_AITER_MHC and hidden_size % 256 == 0: + return torch.ops.vllm.mhc_pre_aiter( + residual, + fn, + hc_scale, + hc_base, + rms_eps, + hc_pre_eps, + hc_sinkhorn_eps, + hc_post_mult_value, + sinkhorn_repeat, + ) if HAS_TILELANG: return torch.ops.vllm.mhc_pre_tilelang( residual, @@ -181,19 +188,17 @@ def forward_hip( post_layer_mix: torch.Tensor, comb_res_mix: torch.Tensor, ) -> torch.Tensor: - # TODO: Reenable aiter after we are at the aiter - # version that has this bugfix - # https://github.com/ROCm/aiter/commit/b639cb63bcac4672dce33a731fad042a65cb3649 - # It has accuracy problem at large number of tokens. - # hidden_size = residual.shape[-1] - # if hidden_size % 256 == 0: - # return torch.ops.vllm.mhc_post_aiter( - # x, - # residual, - # post_layer_mix, - # comb_res_mix, - # ) - # else: + # The aiter mhc_post kernel only supports hidden sizes that are a + # multiple of 256. Requires aiter >= 0.1.14 for correct results at + # large token counts (sqrsum race-condition fix, commit b639cb6). + hidden_size = residual.shape[-1] + if HAS_AITER_MHC and hidden_size % 256 == 0: + return torch.ops.vllm.mhc_post_aiter( + x, + residual, + post_layer_mix, + comb_res_mix, + ) if HAS_TILELANG: return torch.ops.vllm.mhc_post_tilelang( x, residual, post_layer_mix, comb_res_mix diff --git a/vllm/models/deepseek_v4/amd/model.py b/vllm/models/deepseek_v4/amd/model.py index 28836a2b1432..5d937c692dce 100644 --- a/vllm/models/deepseek_v4/amd/model.py +++ b/vllm/models/deepseek_v4/amd/model.py @@ -24,6 +24,7 @@ ) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mhc import ( + HAS_AITER_MHC, HCHeadOp, MHCFusedPostPreOp, MHCPostOp, @@ -595,7 +596,13 @@ def forward( ) -> tuple[ torch.Tensor, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None ]: - if not self.has_tilelang: + # Select the mHC path: + # - aiter unfused pre/post (preferred ROCm path when aiter is available), + # or the torch/triton fallback when tilelang is unavailable -> + # _forward_unfused_post_pre + # - tilelang fused post+pre (CUDA, or ROCm without aiter) -> + # _forward_fused_post_pre + if not self.has_tilelang or HAS_AITER_MHC: return self._forward_unfused_post_pre( x, positions, input_ids, post_mix, res_mix, residual ) @@ -750,7 +757,10 @@ def forward( res_mix, residual, ) - if layer is not None and self.has_tilelang: + # The fused post+pre path (tilelang) defers the final hc_post and + # returns the residual streams; the unfused path (aiter / torch on + # ROCm) applies hc_post inline and returns None, so skip it here. + if layer is not None and residual is not None: hidden_states = layer.hc_post(hidden_states, residual, post_mix, res_mix) if not get_pp_group().is_last_rank: diff --git a/vllm/models/deepseek_v4/amd/mtp.py b/vllm/models/deepseek_v4/amd/mtp.py index 5938cde6959c..a4035b9c5969 100644 --- a/vllm/models/deepseek_v4/amd/mtp.py +++ b/vllm/models/deepseek_v4/amd/mtp.py @@ -155,7 +155,10 @@ def forward( hidden_states, residual, post_mix, res_mix = self.mtp_block( positions=positions, x=hidden_states, input_ids=None ) - if self.has_tilelang: + # The fused post+pre path (tilelang, on CUDA or ROCm) defers the final + # hc_post and returns the residual streams; the unfused path (aiter / + # torch on ROCm) applies hc_post inline and returns None. + if residual is not None: hidden_states = self.mtp_block.hc_post( hidden_states, residual, post_mix, res_mix )