Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docker/Dockerfile.rocm_base
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ ARG PYTORCH_AUDIO_BRANCH="v2.9.0"
ARG PYTORCH_AUDIO_REPO="https://github.com/pytorch/audio.git"
ARG FA_BRANCH="0e60e394"
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
ARG AITER_BRANCH="v0.1.13"
ARG AITER_BRANCH="v0.1.14"
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is release cadence for ROCm major lib version bumps.

cc @micah-wil @Rohan138 @dllehr-amd

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Fangzhou-Ai on upstream we do not bump other dependencies version. So this PR will only be continued after aiter is upgraded.

ARG AITER_REPO="https://github.com/ROCm/aiter.git"
ARG MORI_BRANCH="v1.1.0"
ARG MORI_REPO="https://github.com/ROCm/mori.git"
Expand Down
67 changes: 36 additions & 31 deletions vllm/model_executor/layers/mhc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,20 @@
# this import will also register the custom ops
# import vllm.model_executor.kernels.mhc # noqa: F401
import vllm.model_executor.kernels.mhc as mhc_kernels
from vllm._aiter_ops import is_aiter_found_and_supported
from vllm.model_executor.custom_op import CustomOp
from vllm.utils.import_utils import has_tilelang

HAS_TILELANG = has_tilelang()

# mHC dispatch order on ROCm: aiter pre/post kernels (preferred, fastest) ->
# tilelang fused post+pre -> torch/triton reference. The aiter mHC kernels are
# the default whenever aiter is available on a supported ROCm device; they
# require aiter >= 0.1.14 (sqrsum race-condition fix in
# ``mhc_pre_gemm_sqrsum_kernel``, commit b639cb6) and a hidden size that is a
# multiple of 256, otherwise we fall back to the tilelang/reference paths.
HAS_AITER_MHC = is_aiter_found_and_supported()


# --8<-- [start:mhc_pre]
@CustomOp.register("mhc_pre")
Expand Down Expand Up @@ -71,24 +80,22 @@ def forward_hip(
norm_weight: torch.Tensor | None = None,
norm_eps: float = 0.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# TODO: Reenable aiter after we are at the aiter
# version that has this bugfix
# https://github.com/ROCm/aiter/commit/b639cb63bcac4672dce33a731fad042a65cb3649
# It has accuracy problem at large number of tokens.
# hidden_size = residual.shape[-1]
# if hidden_size % 256 == 0:
# return torch.ops.vllm.mhc_pre_aiter(
# residual,
# fn,
# hc_scale,
# hc_base,
# rms_eps,
# hc_pre_eps,
# hc_sinkhorn_eps,
# hc_post_mult_value,
# sinkhorn_repeat,
# )
# else:
# The aiter mhc_pre kernel only supports hidden sizes that are a
# multiple of 256. Requires aiter >= 0.1.14 for correct results at
# large token counts (sqrsum race-condition fix, commit b639cb6).
hidden_size = residual.shape[-1]
if HAS_AITER_MHC and hidden_size % 256 == 0:
return torch.ops.vllm.mhc_pre_aiter(
residual,
fn,
hc_scale,
hc_base,
rms_eps,
hc_pre_eps,
hc_sinkhorn_eps,
hc_post_mult_value,
sinkhorn_repeat,
)
if HAS_TILELANG:
return torch.ops.vllm.mhc_pre_tilelang(
residual,
Expand Down Expand Up @@ -181,19 +188,17 @@ def forward_hip(
post_layer_mix: torch.Tensor,
comb_res_mix: torch.Tensor,
) -> torch.Tensor:
# TODO: Reenable aiter after we are at the aiter
# version that has this bugfix
# https://github.com/ROCm/aiter/commit/b639cb63bcac4672dce33a731fad042a65cb3649
# It has accuracy problem at large number of tokens.
# hidden_size = residual.shape[-1]
# if hidden_size % 256 == 0:
# return torch.ops.vllm.mhc_post_aiter(
# x,
# residual,
# post_layer_mix,
# comb_res_mix,
# )
# else:
# The aiter mhc_post kernel only supports hidden sizes that are a
# multiple of 256. Requires aiter >= 0.1.14 for correct results at
# large token counts (sqrsum race-condition fix, commit b639cb6).
hidden_size = residual.shape[-1]
if HAS_AITER_MHC and hidden_size % 256 == 0:
return torch.ops.vllm.mhc_post_aiter(
x,
residual,
post_layer_mix,
comb_res_mix,
)
if HAS_TILELANG:
return torch.ops.vllm.mhc_post_tilelang(
x, residual, post_layer_mix, comb_res_mix
Expand Down
14 changes: 12 additions & 2 deletions vllm/models/deepseek_v4/amd/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mhc import (
HAS_AITER_MHC,
HCHeadOp,
MHCFusedPostPreOp,
MHCPostOp,
Expand Down Expand Up @@ -595,7 +596,13 @@ def forward(
) -> tuple[
torch.Tensor, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None
]:
if not self.has_tilelang:
# Select the mHC path:
# - aiter unfused pre/post (preferred ROCm path when aiter is available),
# or the torch/triton fallback when tilelang is unavailable ->
# _forward_unfused_post_pre
# - tilelang fused post+pre (CUDA, or ROCm without aiter) ->
# _forward_fused_post_pre
if not self.has_tilelang or HAS_AITER_MHC:
return self._forward_unfused_post_pre(
x, positions, input_ids, post_mix, res_mix, residual
)
Expand Down Expand Up @@ -750,7 +757,10 @@ def forward(
res_mix,
residual,
)
if layer is not None and self.has_tilelang:
# The fused post+pre path (tilelang) defers the final hc_post and
# returns the residual streams; the unfused path (aiter / torch on
# ROCm) applies hc_post inline and returns None, so skip it here.
if layer is not None and residual is not None:
hidden_states = layer.hc_post(hidden_states, residual, post_mix, res_mix)

if not get_pp_group().is_last_rank:
Expand Down
5 changes: 4 additions & 1 deletion vllm/models/deepseek_v4/amd/mtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,10 @@ def forward(
hidden_states, residual, post_mix, res_mix = self.mtp_block(
positions=positions, x=hidden_states, input_ids=None
)
if self.has_tilelang:
# The fused post+pre path (tilelang, on CUDA or ROCm) defers the final
# hc_post and returns the residual streams; the unfused path (aiter /
# torch on ROCm) applies hc_post inline and returns None.
if residual is not None:
hidden_states = self.mtp_block.hc_post(
hidden_states, residual, post_mix, res_mix
)
Expand Down
Loading