Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 46 additions & 13 deletions python/sglang/srt/layers/attention/aiter_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
mla_prefill_ps_asm_fwd,
mla_reduce_v1,
paged_attention_ragged,
paged_attention_ragged_nhd,
)
from aiter.mla import mla_decode_fwd, mla_prefill_fwd
except ImportError:
Expand Down Expand Up @@ -96,8 +97,7 @@ class ForwardMetadata:

global_workspace_buffer = None

_AITER_PARTITION_SIZE_ROCM = 256

_AITER_PARTITION_SIZE_ROCM = 256 # 256 512 1024 for paged_attention_ragged_nhd

class AiterAttnBackend(AttentionBackend):
def __init__(
Expand Down Expand Up @@ -187,17 +187,28 @@ def __init__(
self.max_num_partitions = (
self.max_context_len + _AITER_PARTITION_SIZE_ROCM - 1
) // _AITER_PARTITION_SIZE_ROCM

nbyes_per_qo_elem = torch.finfo(torch.float32).bits // 8

if not self.use_mla:
self.workspace_buffer = torch.empty(
(max_bs * self.num_head * self.max_num_partitions * self.head_dim)
* nbyes_per_qo_elem
+ 2 * (max_bs * self.num_head * self.max_num_partitions) * 4,
dtype=torch.uint8,
device=self.device,
)
# self.workspace_buffer = torch.empty(
# (max_bs * self.num_head * self.max_num_partitions * self.head_dim)
# * nbyes_per_qo_elem
# + 2 * (max_bs * self.num_head * self.max_num_partitions) * 4,
# dtype=torch.uint8,
# device=self.device,
# ).contiguous()
# aiter pa_ragged_nhd: three tensors (no single uint8 blob + host pointer math).
_nhp = max_bs * self.num_head * self.max_num_partitions
self.pa_nhd_exp_sums = torch.empty(
(_nhp,), dtype=torch.float32, device=self.device
).contiguous()
self.pa_nhd_max_logits = torch.empty(
(_nhp,), dtype=torch.float32, device=self.device
).contiguous()
self.pa_nhd_tmp_out = torch.empty(
(_nhp * self.head_dim,), dtype=torch.bfloat16, device=self.device
).contiguous()


self.scale = float(1.0 / (self.head_dim**0.5))
self.k_scale = self.v_scale = torch.tensor([1.0], dtype=torch.float32).to(
Expand Down Expand Up @@ -1770,10 +1781,11 @@ def forward_decode(

k_cache = k_cache.to(dtype)
v_cache = v_cache.to(dtype)

paged_attention_ragged(
paged_attention_ragged_nhd(
o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
self.workspace_buffer,
self.pa_nhd_exp_sums,
self.pa_nhd_max_logits,
self.pa_nhd_tmp_out,
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
k_cache.view(-1, 1, layer.tp_k_head_num, layer.qk_head_dim),
v_cache.view(-1, 1, layer.tp_v_head_num, layer.v_head_dim),
Expand All @@ -1792,6 +1804,27 @@ def forward_decode(
None,
_AITER_PARTITION_SIZE_ROCM,
)
# paged_attention_ragged(
# o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
# self.workspace_buffer,
# q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
# k_cache.view(-1, 1, layer.tp_k_head_num, layer.qk_head_dim),
# v_cache.view(-1, 1, layer.tp_v_head_num, layer.v_head_dim),
# self.scale,
# self.forward_metadata.kv_indptr,
# self.forward_metadata.kv_indices,
# self.kv_last_page_len,
# 1,
# self.max_num_partitions,
# None,
# "auto",
# "NHD",
# self.logits_soft_cap,
# self.k_scale,
# self.v_scale,
# None,
# _AITER_PARTITION_SIZE_ROCM,
# )

return o

Expand Down
Loading