diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py index a52fbfab9961..dbcfd96730d0 100644 --- a/python/sglang/srt/layers/attention/aiter_backend.py +++ b/python/sglang/srt/layers/attention/aiter_backend.py @@ -37,6 +37,7 @@ mla_prefill_ps_asm_fwd, mla_reduce_v1, paged_attention_ragged, + paged_attention_ragged_nhd, ) from aiter.mla import mla_decode_fwd, mla_prefill_fwd except ImportError: @@ -96,8 +97,7 @@ class ForwardMetadata: global_workspace_buffer = None -_AITER_PARTITION_SIZE_ROCM = 256 - +_AITER_PARTITION_SIZE_ROCM = 256 # 256 512 1024 for paged_attention_ragged_nhd class AiterAttnBackend(AttentionBackend): def __init__( @@ -187,17 +187,28 @@ def __init__( self.max_num_partitions = ( self.max_context_len + _AITER_PARTITION_SIZE_ROCM - 1 ) // _AITER_PARTITION_SIZE_ROCM - nbyes_per_qo_elem = torch.finfo(torch.float32).bits // 8 if not self.use_mla: - self.workspace_buffer = torch.empty( - (max_bs * self.num_head * self.max_num_partitions * self.head_dim) - * nbyes_per_qo_elem - + 2 * (max_bs * self.num_head * self.max_num_partitions) * 4, - dtype=torch.uint8, - device=self.device, - ) + # self.workspace_buffer = torch.empty( + # (max_bs * self.num_head * self.max_num_partitions * self.head_dim) + # * nbyes_per_qo_elem + # + 2 * (max_bs * self.num_head * self.max_num_partitions) * 4, + # dtype=torch.uint8, + # device=self.device, + # ).contiguous() + # aiter pa_ragged_nhd: three tensors (no single uint8 blob + host pointer math). + _nhp = max_bs * self.num_head * self.max_num_partitions + self.pa_nhd_exp_sums = torch.empty( + (_nhp,), dtype=torch.float32, device=self.device + ).contiguous() + self.pa_nhd_max_logits = torch.empty( + (_nhp,), dtype=torch.float32, device=self.device + ).contiguous() + self.pa_nhd_tmp_out = torch.empty( + (_nhp * self.head_dim,), dtype=torch.bfloat16, device=self.device + ).contiguous() + self.scale = float(1.0 / (self.head_dim**0.5)) self.k_scale = self.v_scale = torch.tensor([1.0], dtype=torch.float32).to( @@ -1770,10 +1781,11 @@ def forward_decode( k_cache = k_cache.to(dtype) v_cache = v_cache.to(dtype) - - paged_attention_ragged( + paged_attention_ragged_nhd( o.view(-1, layer.tp_q_head_num, layer.qk_head_dim), - self.workspace_buffer, + self.pa_nhd_exp_sums, + self.pa_nhd_max_logits, + self.pa_nhd_tmp_out, q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), k_cache.view(-1, 1, layer.tp_k_head_num, layer.qk_head_dim), v_cache.view(-1, 1, layer.tp_v_head_num, layer.v_head_dim), @@ -1792,6 +1804,27 @@ def forward_decode( None, _AITER_PARTITION_SIZE_ROCM, ) + # paged_attention_ragged( + # o.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + # self.workspace_buffer, + # q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + # k_cache.view(-1, 1, layer.tp_k_head_num, layer.qk_head_dim), + # v_cache.view(-1, 1, layer.tp_v_head_num, layer.v_head_dim), + # self.scale, + # self.forward_metadata.kv_indptr, + # self.forward_metadata.kv_indices, + # self.kv_last_page_len, + # 1, + # self.max_num_partitions, + # None, + # "auto", + # "NHD", + # self.logits_soft_cap, + # self.k_scale, + # self.v_scale, + # None, + # _AITER_PARTITION_SIZE_ROCM, + # ) return o