From c65c91e1780b7c24b319810b7a6cc915c5c58cfe Mon Sep 17 00:00:00 2001 From: apinge Date: Wed, 13 May 2026 06:34:26 +0000 Subject: [PATCH 1/2] add pa_gluon and flash_attn_varlen_func --- .../srt/layers/attention/aiter_backend.py | 587 ++++++++++++++++-- 1 file changed, 526 insertions(+), 61 deletions(-) diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py index a52fbfab9961..a7767e0dd4b8 100644 --- a/python/sglang/srt/layers/attention/aiter_backend.py +++ b/python/sglang/srt/layers/attention/aiter_backend.py @@ -11,6 +11,7 @@ import torch import triton +import triton.language as tl from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton @@ -28,6 +29,7 @@ try: from aiter import ( + dtypes, flash_attn_varlen_func, get_mla_metadata_info_v1, get_mla_metadata_v1, @@ -43,6 +45,7 @@ print( "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device." ) + dtypes = None # type: ignore from sglang.srt.configs.model_config import AttentionArch from sglang.srt.layers.attention.utils import pad_sequence_with_mask @@ -51,6 +54,15 @@ logger = logging.getLogger(__name__) + +def _pa_decode_gluon_available() -> bool: + try: + aiter_ops = getattr(torch.ops, "aiter", None) + return aiter_ops is not None and hasattr(aiter_ops, "pa_decode_gluon") + except Exception: + return False + + # Use aiter mla persist design for fp8-kv cache _use_mla_ps_kernel = get_bool_env_var("SGLANG_AITER_MLA_PERSIST", "True") @@ -68,6 +80,107 @@ intra_batch_mode = True if _use_mla_ps_kernel else False +@triton.jit +def reshape_and_cache_shuffle_kernel( + key_ptr, + value_ptr, + key_cache_ptr, + value_cache_ptr, + slot_mapping_ptr, + k_scale_ptr, + v_scale_ptr, + x, + k_stride0, + v_stride0, + block_size, + head_size, + num_kv_heads, + BLOCK_SIZE: tl.constexpr, + QUANT: tl.constexpr, +): + tid = tl.program_id(0) + head_id = tl.program_id(1) + offset = tl.arange(0, BLOCK_SIZE) + src_offset_k = tid * k_stride0 + head_id * head_size + src_offset_v = tid * v_stride0 + head_id * head_size + slot_id = tl.load(slot_mapping_ptr + tid) + if slot_id < 0: + return + block_id = slot_id // block_size + block_offset = slot_id % block_size + dst_offset = ( + block_id * num_kv_heads * head_size * block_size + + head_id * head_size * block_size + ) + dst_k_shuffle_offset = ( + dst_offset + offset // x * block_size * x + block_offset * x + offset % x + ) + dst_v_shuffle_offset = ( + dst_offset + block_offset // x * head_size * x + offset * x + block_offset % x + ) + k_val = tl.load(key_ptr + src_offset_k + offset) + v_val = tl.load(value_ptr + src_offset_v + offset) + if QUANT: + k_scale = tl.load(k_scale_ptr) + v_scale = tl.load(v_scale_ptr) + k_dtype = key_cache_ptr.type.element_ty + v_dtype = value_cache_ptr.type.element_ty + k_val = (k_val.to(tl.float32) / k_scale).to(k_dtype) + v_val = (v_val.to(tl.float32) / v_scale).to(v_dtype) + tl.store(key_cache_ptr + dst_k_shuffle_offset, k_val) + tl.store(value_cache_ptr + dst_v_shuffle_offset, v_val) + + +def reshape_and_cache_shuffle_triton( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scales: torch.Tensor, + v_scales: torch.Tensor, +): + num_tokens = slot_mapping.shape[0] + _, num_kv_heads, head_size = key.shape + num_blocks, block_size, _, _ = key_cache.shape + x = 16 // key_cache.element_size() + k_cache_template = torch.empty( + [num_blocks, num_kv_heads, head_size // x, block_size, x], + dtype=key_cache.dtype, + device="meta", + ) + v_cache_template = torch.empty( + [num_blocks, num_kv_heads, block_size // x, head_size, x], + dtype=value_cache.dtype, + device="meta", + ) + new_key_cache = key_cache.view_as(k_cache_template) + new_value_cache = value_cache.view_as(v_cache_template) + QUANT = kv_cache_dtype.startswith("fp8") + grid = ( + num_tokens, + num_kv_heads, + ) + reshape_and_cache_shuffle_kernel[grid]( + key, + value, + new_key_cache, + new_value_cache, + slot_mapping, + k_scales, + v_scales, + x, + key.stride(0), + value.stride(0), + block_size, + head_size, + num_kv_heads, + BLOCK_SIZE=head_size, + QUANT=QUANT, + ) + + class WrapperDispatch(Enum): SLIDING_WINDOW = auto() CROSS_ATTENTION = auto() @@ -92,6 +205,9 @@ class ForwardMetadata: custom_mask: Optional[torch.Tensor] = None mask_indptr: Optional[torch.Tensor] = None max_extend_len: Optional[int] = None + # Non-MLA page_size==1024 + pa_decode_gluon: physical block table per seq + page_table: Optional[torch.Tensor] = None + kv_lens: Optional[torch.Tensor] = None global_workspace_buffer = None @@ -173,6 +289,66 @@ def __init__( (max_bs + 1,), dtype=torch.int64, device=model_runner.device ) + self.decode_using_pa_gluon = False + if ( + not self.use_mla + and (self.page_size % 64 == 0 or 64 % self.page_size == 0) + and _pa_decode_gluon_available() + ): + self.decode_using_pa_gluon = True + self.seq_lens_for_page_table = torch.zeros( + (max_bs,), dtype=torch.int32, device=model_runner.device + ) + self.page_table = torch.zeros( + (max_bs, self.max_context_len // self.page_size), + dtype=torch.int32, + device=model_runner.device, + ) + self.strided_indices = torch.arange( + 0, self.max_context_len, self.page_size, device=model_runner.device + ) + # Workspace for torch.ops.aiter.pa_decode_gluon + self.gluon_use_ps = False + self.gluon_context_partition_size = 256 + query_group_size = max(1, self.num_head // self.num_kv_head) + props = torch.cuda.get_device_properties(self.device) + num_sm = props.multi_processor_count * 2 + max_part_ps = min( + 16, + triton.cdiv(num_sm, max(1, self.num_kv_head)), + ) + max_part_linear = triton.cdiv( + self.max_context_len, self.gluon_context_partition_size + ) + self.gluon_max_context_partition_num = max(max_part_ps, max_part_linear) + gshape = ( + max_bs, + self.num_kv_head, + self.gluon_max_context_partition_num, + query_group_size, + ) + self.gluon_exp_sums = torch.empty( + *gshape, dtype=torch.float32, device=self.device + ) + self.gluon_max_logits = torch.empty( + *gshape, dtype=torch.float32, device=self.device + ) + self.gluon_temporary_output = torch.empty( + *gshape, + self.head_dim, + dtype=self.input_dtype, + device=self.device, + ) + + logger.info( + "AiterAttnBackend: decode_using_pa_gluon=%s (use_mla=%s, page_size=%s, " + "pa_decode_gluon_available=%s)", + self.decode_using_pa_gluon, + self.use_mla, + self.page_size, + _pa_decode_gluon_available(), + ) + # Create prefill indices updater if not skip_prefill: self.indices_updater_prefill = AiterIndicesUpdaterPrefill( @@ -430,6 +606,40 @@ def make_mla_prefill_ps_meta_data( is_causal=is_causal, ) + def set_kv_buffer_with_layout_shuffle( + self, + cache_loc, + k, + v, + k_buffer, + v_buffer, + k_scale, + v_scale, + block_size, + ): + num_slots, num_kv_heads, head_dim = k_buffer.shape + num_blocks = num_slots // block_size + num_slots_with_block = num_blocks * block_size + k_buffer = k_buffer[:num_slots_with_block].view( + num_blocks, block_size, num_kv_heads, head_dim + ) + v_buffer = v_buffer[:num_slots_with_block].view( + num_blocks, block_size, num_kv_heads, head_dim + ) + kv_cache_dtype = "auto" + if k_buffer.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz): + kv_cache_dtype = "fp8" + reshape_and_cache_shuffle_triton( + k, + v, + k_buffer, + v_buffer, + cache_loc, + kv_cache_dtype, + k_scale, + v_scale, + ) + def init_forward_metadata(self, forward_batch: ForwardBatch): """Init auxiliary variables for triton attention backend.""" @@ -504,22 +714,66 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): intra_batch_mode=intra_batch_mode, ) - self.forward_metadata = ForwardMetadata( - kv_indptr, - kv_indices, - qo_indptr, - kv_last_page_len, - max_q_len, - None, - work_metadata=work_metadata, - work_info_set=work_info_set, - work_indptr=work_indptr, - reduce_indptr=reduce_indptr, - reduce_final_map=reduce_final_map, - reduce_partial_map=reduce_partial_map, - num_kv_splits=num_kv_splits, - run_graph=False, - ) + if ( + self.decode_using_pa_gluon + and spec_info is None + and not self.use_mla + ): + seq_lens_cpu = forward_batch.seq_lens_cpu + if seq_lens_cpu is None: + seq_lens_cpu = forward_batch.seq_lens.cpu() + page_table_persistent = self.page_table + seq_lens_persistent = self.seq_lens_for_page_table + seq_lens_persistent.fill_(0) + page_table_persistent.fill_(0) + seq_lens_persistent[:bs].copy_( + forward_batch.seq_lens, non_blocking=True + ) + max_seq_pages = ( + seq_lens_cpu.max().item() + self.page_size - 1 + ) // self.page_size + page_table_src = self.req_to_token[ + forward_batch.req_pool_indices[:, None], + self.strided_indices[:max_seq_pages][None, :], + ] + page_table_persistent[:bs, :max_seq_pages].copy_( + page_table_src // self.page_size, non_blocking=True + ) + self.forward_metadata = ForwardMetadata( + kv_indptr, + kv_indices, + qo_indptr, + kv_last_page_len, + max_q_len, + None, + work_metadata=work_metadata, + work_info_set=work_info_set, + work_indptr=work_indptr, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + num_kv_splits=num_kv_splits, + run_graph=False, + page_table=page_table_persistent[:bs, :].contiguous(), + kv_lens=seq_lens_persistent[:bs], + ) + else: + self.forward_metadata = ForwardMetadata( + kv_indptr, + kv_indices, + qo_indptr, + kv_last_page_len, + max_q_len, + None, + work_metadata=work_metadata, + work_info_set=work_info_set, + work_indptr=work_indptr, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + num_kv_splits=num_kv_splits, + run_graph=False, + ) elif forward_batch.forward_mode.is_draft_extend(): if self.use_mla: @@ -941,22 +1195,44 @@ def init_forward_metadata_capture_cuda_graph( reduce_final_map = self.reduce_final_map reduce_partial_map = self.reduce_partial_map - self.forward_metadata = ForwardMetadata( - kv_indptr, - kv_indices, - qo_indptr, - kv_last_page_len, - max_q_len, - kv_indptr[-1].item(), - work_metadata=work_metadata, - work_info_set=work_info_set, - work_indptr=work_indptr, - reduce_indptr=reduce_indptr, - reduce_final_map=reduce_final_map, - reduce_partial_map=reduce_partial_map, - num_kv_splits=num_kv_splits, - # num_kv_splits_indptr=num_kv_splits_indptr, - ) + if self.decode_using_pa_gluon and spec_info is None and not self.use_mla: + page_table_slice = self.page_table[:bs, :] + self.seq_lens_for_page_table[:bs].copy_(seq_lens, non_blocking=True) + seq_lens_persistent = self.seq_lens_for_page_table[:bs] + self.forward_metadata = ForwardMetadata( + kv_indptr, + kv_indices, + qo_indptr, + kv_last_page_len, + max_q_len, + kv_indptr[-1].item(), + work_metadata=work_metadata, + work_info_set=work_info_set, + work_indptr=work_indptr, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + num_kv_splits=num_kv_splits, + page_table=page_table_slice.contiguous(), + kv_lens=seq_lens_persistent, + ) + else: + self.forward_metadata = ForwardMetadata( + kv_indptr, + kv_indices, + qo_indptr, + kv_last_page_len, + max_q_len, + kv_indptr[-1].item(), + work_metadata=work_metadata, + work_info_set=work_info_set, + work_indptr=work_indptr, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + num_kv_splits=num_kv_splits, + # num_kv_splits_indptr=num_kv_splits_indptr, + ) elif forward_mode.is_target_verify(): if self.use_mla: @@ -1189,6 +1465,39 @@ def init_forward_metadata_replay_cuda_graph( kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices + if ( + self.decode_using_pa_gluon + and spec_info is None + and not self.use_mla + and seq_lens_cpu is not None + ): + page_table_persistent = self.page_table + seq_lens_persistent = self.seq_lens_for_page_table + seq_lens_persistent.fill_(0) + page_table_persistent.fill_(0) + seq_lens_persistent[:bs].copy_(seq_lens[:bs], non_blocking=True) + max_seq_pages = ( + int(seq_lens_cpu[:bs].max().item()) + self.page_size - 1 + ) // self.page_size + 1 + page_table_src = self.req_to_token[ + req_pool_indices[:bs, None], + self.strided_indices[:max_seq_pages][None, :], + ] + page_table_persistent[:bs, :max_seq_pages].copy_( + page_table_src // self.page_size, non_blocking=True + ) + self.forward_metadata = ForwardMetadata( + kv_indptr, + kv_indices, + None, + None, + 1, + None, + run_graph=True, + page_table=page_table_persistent[:bs, :].contiguous(), + kv_lens=seq_lens_persistent[:bs], + ) + elif forward_mode.is_target_verify(): bs = len(req_pool_indices) qo_indptr = self.qo_indptr[: bs + 1] @@ -1271,6 +1580,20 @@ def forward_extend( if save_kv_cache: if self.use_mla: forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) + elif self.decode_using_pa_gluon: + k_buf, v_buf = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) + self.set_kv_buffer_with_layout_shuffle( + cache_loc, + k, + v, + k_buf, + v_buf, + layer.k_scale, + layer.v_scale, + self.page_size, + ) else: forward_batch.token_to_kv_pool.set_kv_buffer( layer, cache_loc, k, v, layer.k_scale, layer.v_scale @@ -1617,6 +1940,50 @@ def forward_extend( f"Invalid forward mode for MLA prefill: {forward_batch.forward_mode=}" ) else: + if self.decode_using_pa_gluon: + # Dense flash_attn_varlen only matches "no radix prefix": q packs + # sum(extend_lens) tokens and indices_updater sets qo_indptr from extend_lens. + # With prefix cache, extend_lens = seq_lens - prefix_lens, so using seq_lens + # for cu_seqlens makes cu_seqlens[-1] > q.shape[0] → kernel OOB / core dump. + if self.is_multimodal: + extend_no_prefix = False + else: + extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu) + if extend_no_prefix: + extend_lens = ( + forward_batch.seq_lens - forward_batch.extend_prefix_lens + ) + cu_seqlens_q = torch.nn.functional.pad( + torch.cumsum(extend_lens, dim=0, dtype=torch.int32), (1, 0) + ) + else: + cu_seqlens_q = self.qo_indptr[: forward_batch.batch_size + 1] + if ( + dtypes is not None + and q.dtype != k.dtype + and k.dtype == dtypes.fp8 + ): + q = q.to(dtypes.fp8) + max_ql = self.forward_metadata.max_q_len + o = flash_attn_varlen_func( + q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + k.contiguous().view(-1, layer.tp_k_head_num, layer.head_dim), + v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim), + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_q, + max_seqlen_q=max_ql, + # K/V packed like Q (extend tokens only): per-seq K len is extend_len. + # max_kv_len is for paged mha (prefix+extend); must not use it here. + max_seqlen_k=max_ql, + min_seqlen_q=0, + dropout_p=0.0, + softmax_scale=self.scale, + causal=True, + window_size=(-1, -1, 0), + sink_ptr=None, + ) + return o.view(-1, layer.tp_q_head_num * layer.head_dim) + if ( forward_batch.forward_mode.is_target_verify() or forward_batch.forward_mode.is_draft_extend() @@ -1701,9 +2068,24 @@ def forward_decode( o = torch.empty_like(q, dtype=self.input_dtype) if save_kv_cache: - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, forward_batch.out_cache_loc, k, v - ) + if self.decode_using_pa_gluon: + k_buf, v_buf = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) + self.set_kv_buffer_with_layout_shuffle( + forward_batch.out_cache_loc, + k, + v, + k_buf, + v_buf, + layer.k_scale, + layer.v_scale, + self.page_size, + ) + else: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v + ) if self.use_mla: k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) @@ -1764,34 +2146,117 @@ def forward_decode( layer.layer_id ) - # TODO kkhuang-amd need to remove it when paged_attention_ragged support fp8-kv - if self.kv_cache_dtype == fp8_dtype: - dtype = q.dtype + if self.decode_using_pa_gluon: + batch_size = q.shape[0] + block_size = self.page_size + num_slots, num_kv_heads, head_size = k_cache.shape + num_blocks = num_slots // block_size + k_cache = k_cache[: num_blocks * block_size].view( + num_blocks, block_size, num_kv_heads, head_size + ) + v_cache = v_cache[: num_blocks * block_size].view( + num_blocks, block_size, num_kv_heads, head_size + ) + x = 16 // k_cache.element_size() + new_key_cache = k_cache.view( + num_blocks, num_kv_heads, head_size // x, block_size, x + ) + new_value_cache = v_cache.view( + num_blocks, num_kv_heads, block_size // x, head_size, x + ) + md = self.forward_metadata + q_flat = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim) + out_flat = o.reshape(-1, layer.tp_q_head_num, layer.v_head_dim) + block_tables = md.page_table + context_lengths = md.kv_lens + query_scale = None + key_scale = None + value_scale = None + if self.kv_cache_dtype == fp8_dtype: + if dtypes is None: + raise RuntimeError( + "aiter dtypes required for FP8 KV with pa_decode_gluon" + ) + compute_type = dtypes.fp8 + key_scale = self.k_scale + value_scale = self.v_scale + if q_flat.dtype == dtypes.fp8: + query_scale = ( + layer.k_scale + if isinstance(layer.k_scale, torch.Tensor) + else torch.ones( + 1, dtype=torch.float32, device=q.device + ) + ) + elif dtypes is not None and q_flat.dtype == dtypes.fp8: + compute_type = dtypes.fp8 + query_scale = ( + layer.k_scale + if isinstance(layer.k_scale, torch.Tensor) + else torch.ones(1, dtype=torch.float32, device=q.device) + ) + elif q_flat.dtype in (torch.float16, torch.bfloat16): + compute_type = q_flat.dtype + else: + compute_type = self.input_dtype - k_cache = k_cache.to(dtype) - v_cache = v_cache.to(dtype) + sw = getattr(layer, "sliding_window_size", None) + sliding_window = ( + int(sw) if sw is not None and int(sw) > 0 else 0 + ) - paged_attention_ragged( - o.view(-1, layer.tp_q_head_num, layer.qk_head_dim), - self.workspace_buffer, - q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), - k_cache.view(-1, 1, layer.tp_k_head_num, layer.qk_head_dim), - v_cache.view(-1, 1, layer.tp_v_head_num, layer.v_head_dim), - self.scale, - self.forward_metadata.kv_indptr, - self.forward_metadata.kv_indices, - self.kv_last_page_len, - 1, - self.max_num_partitions, - None, - "auto", - "NHD", - self.logits_soft_cap, - self.k_scale, - self.v_scale, - None, - _AITER_PARTITION_SIZE_ROCM, - ) + torch.ops.aiter.pa_decode_gluon( + out_flat, + q_flat, + new_key_cache, + new_value_cache, + context_lengths, + block_tables, + layer.scaling, + 1, + self.gluon_max_context_partition_num, + self.gluon_context_partition_size, + compute_type, + query_scale, + key_scale, + value_scale, + exp_sums=self.gluon_exp_sums[:batch_size], + max_logits=self.gluon_max_logits[:batch_size], + temporary_output=self.gluon_temporary_output[:batch_size], + alibi_slopes=None, + sinks=None, + sliding_window=sliding_window, + ps=False, + ) + else: + # TODO kkhuang-amd need to remove it when paged_attention_ragged support fp8-kv + if self.kv_cache_dtype == fp8_dtype: + dtype = q.dtype + + k_cache = k_cache.to(dtype) + v_cache = v_cache.to(dtype) + + paged_attention_ragged( + o.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + self.workspace_buffer, + q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + k_cache.view(-1, 1, layer.tp_k_head_num, layer.qk_head_dim), + v_cache.view(-1, 1, layer.tp_v_head_num, layer.v_head_dim), + self.scale, + self.forward_metadata.kv_indptr, + self.forward_metadata.kv_indices, + self.kv_last_page_len, + 1, + self.max_num_partitions, + None, + "auto", + "NHD", + self.logits_soft_cap, + self.k_scale, + self.v_scale, + None, + _AITER_PARTITION_SIZE_ROCM, + ) return o From d15151794e6f90422340612e3d5ef70b665d5a9d Mon Sep 17 00:00:00 2001 From: apinge Date: Thu, 14 May 2026 13:06:57 +0000 Subject: [PATCH 2/2] add vllm function --- .../srt/layers/attention/aiter_backend.py | 289 +++++++++++++++++- 1 file changed, 281 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py index a7767e0dd4b8..239ee1a840f2 100644 --- a/python/sglang/srt/layers/attention/aiter_backend.py +++ b/python/sglang/srt/layers/attention/aiter_backend.py @@ -14,6 +14,7 @@ import triton.language as tl from sglang.srt.layers.attention.base_attn_backend import AttentionBackend +from sglang.srt.layers.attention.merge_state import merge_state from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.dp_attention import ( get_attention_tp_size, @@ -181,6 +182,156 @@ def reshape_and_cache_shuffle_triton( ) +@triton.jit +def _cp_mha_gather_cache_kernel( + key_cache_ptr, + value_cache_ptr, + key_ptr, + value_ptr, + block_table_ptr, + cu_seqlens_kv_ptr, + token_to_batch_ptr, + seq_start_ptr, + k_scale_ptr, + v_scale_ptr, + num_heads, + head_size, + x, + max_block_num, + DEQUANT: tl.constexpr, + PAGE_SIZE: tl.constexpr, + CACHE_FORMAT: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + token_id = tl.program_id(0) + head_id = tl.program_id(1) + col_offsets = tl.arange(0, BLOCK_SIZE) + + key_ptr_offset = ( + key_ptr + token_id * head_size * num_heads + head_id * head_size + ) + value_ptr_offset = ( + value_ptr + token_id * head_size * num_heads + head_id * head_size + ) + batch_idx = tl.load(token_to_batch_ptr + token_id) + batch_start = tl.load(seq_start_ptr + batch_idx) + token_start = tl.load(cu_seqlens_kv_ptr + batch_idx) + batch_offset = token_id - token_start + batch_start + block_offset = batch_offset // PAGE_SIZE + block_id = tl.load( + block_table_ptr + max_block_num * batch_idx + block_offset + ).to(tl.int64) + slot_id = batch_offset % PAGE_SIZE + + if CACHE_FORMAT == "NHD": + key_cache_ptr_offset = ( + key_cache_ptr + + block_id * num_heads * head_size * PAGE_SIZE + + slot_id * num_heads * head_size + + head_id * head_size + ) + value_cache_ptr_offset = ( + value_cache_ptr + + block_id * num_heads * head_size * PAGE_SIZE + + slot_id * num_heads * head_size + + head_id * head_size + ) + k_reg = tl.load(key_cache_ptr_offset + col_offsets) + v_reg = tl.load(value_cache_ptr_offset + col_offsets) + if DEQUANT: + k_scale = tl.load(k_scale_ptr) + v_scale = tl.load(v_scale_ptr) + k_dtype = k_reg.dtype + v_dtype = v_reg.dtype + k_reg = (k_reg.to(tl.float32) * k_scale).to(k_dtype) + v_reg = (v_reg.to(tl.float32) * v_scale).to(v_dtype) + tl.store(key_ptr_offset + col_offsets, k_reg) + tl.store(value_ptr_offset + col_offsets, v_reg) + + elif CACHE_FORMAT == "SHUFFLE": + key_cache_ptr_offset = ( + key_cache_ptr + + block_id * num_heads * head_size * PAGE_SIZE + + head_id * head_size * PAGE_SIZE + + slot_id * x + ) + value_cache_ptr_offset = ( + value_cache_ptr + + block_id * num_heads * head_size * PAGE_SIZE + + head_id * head_size * PAGE_SIZE + + (slot_id // x) * head_size * x + + slot_id % x + ) + k_reg_offset = col_offsets // x * PAGE_SIZE * x + col_offsets % x + v_reg_offset = col_offsets * x + k_reg = tl.load(key_cache_ptr_offset + k_reg_offset) + v_reg = tl.load(value_cache_ptr_offset + v_reg_offset) + if DEQUANT: + k_scale = 1.0 + v_scale = 1.0 + k_reg = k_reg.to(tl.float32) * k_scale + v_reg = v_reg.to(tl.float32) * v_scale + tl.store(key_ptr_offset + col_offsets, k_reg) + tl.store(value_ptr_offset + col_offsets, v_reg) + + +def cp_mha_gather_cache( + key_cache: torch.Tensor, + value_cache: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + block_tables: torch.Tensor, + k_scales: torch.Tensor, + v_scales: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + token_to_batch: torch.Tensor, + seq_starts: torch.Tensor, + dequant: bool, + kv_cache_layout: str, + total_tokens: int, +): + assert kv_cache_layout in ("NHD", "SHUFFLE") + head_dim = key.shape[2] + x = 16 // key_cache.element_size() + assert head_dim == key_cache.shape[3], ( + "cp_mha_gather_cache expects key_cache [num_blocks, page_size, num_heads, head_dim]" + ) + page_size = key_cache.shape[1] + num_heads = key_cache.shape[2] + + grid = (total_tokens, num_heads) + _cp_mha_gather_cache_kernel[grid]( + key_cache, + value_cache, + key, + value, + block_tables, + cu_seqlens_kv, + token_to_batch, + seq_starts, + k_scales, + v_scales, + num_heads, + head_dim, + x, + block_tables.size(1), + DEQUANT=dequant, + PAGE_SIZE=page_size, + CACHE_FORMAT=kv_cache_layout, + BLOCK_SIZE=head_dim, + ) + + +def _flash_varlen_lse_to_merge_layout( + lse: torch.Tensor, num_query_heads: int +) -> torch.Tensor: + if lse is None: + return lse + if lse.dim() == 2 and lse.shape[0] == num_query_heads: + return lse.transpose(0, 1).contiguous() + return lse + + class WrapperDispatch(Enum): SLIDING_WINDOW = auto() CROSS_ATTENTION = auto() @@ -1948,7 +2099,10 @@ def forward_extend( if self.is_multimodal: extend_no_prefix = False else: - extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu) + extend_no_prefix = ( + forward_batch.extend_prefix_lens_cpu is None + or not any(forward_batch.extend_prefix_lens_cpu) + ) if extend_no_prefix: extend_lens = ( forward_batch.seq_lens - forward_batch.extend_prefix_lens @@ -1965,15 +2119,42 @@ def forward_extend( ): q = q.to(dtypes.fp8) max_ql = self.forward_metadata.max_q_len - o = flash_attn_varlen_func( - q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), - k.contiguous().view(-1, layer.tp_k_head_num, layer.head_dim), - v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim), + q3 = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) + k3 = k.contiguous().view(-1, layer.tp_k_head_num, layer.head_dim) + v3 = v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim) + + has_prefix = ( + forward_batch.extend_prefix_lens_cpu is not None + and any(forward_batch.extend_prefix_lens_cpu) + ) + if not has_prefix: + o = flash_attn_varlen_func( + q3, + k3, + v3, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_q, + max_seqlen_q=max_ql, + max_seqlen_k=max_ql, + min_seqlen_q=0, + dropout_p=0.0, + softmax_scale=self.scale, + causal=True, + window_size=(-1, -1, 0), + sink_ptr=None, + ) + return o.view(-1, layer.tp_q_head_num * layer.head_dim) + + # Radix prefix + pa_decode_gluon shuffle cache: same pattern as vLLM + # rocm_aiter_fa.extend_forward — extend self-attn, gather prefix K/V, + # second varlen (non-causal), merge with logsumexp. + suf = flash_attn_varlen_func( + q3, + k3, + v3, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_q, max_seqlen_q=max_ql, - # K/V packed like Q (extend tokens only): per-seq K len is extend_len. - # max_kv_len is for paged mha (prefix+extend); must not use it here. max_seqlen_k=max_ql, min_seqlen_q=0, dropout_p=0.0, @@ -1981,8 +2162,100 @@ def forward_extend( causal=True, window_size=(-1, -1, 0), sink_ptr=None, + return_lse=True, + ) + out_suf, lse_suf = suf if isinstance(suf, tuple) else (suf, None) + + bs = forward_batch.batch_size + pl = forward_batch.extend_prefix_lens + total_pt = int(pl.sum().item()) + if total_pt == 0: + return out_suf.view(-1, layer.tp_q_head_num * layer.head_dim) + + key_fetched = torch.empty( + (total_pt, layer.tp_k_head_num, layer.head_dim), + dtype=q3.dtype, + device=q3.device, + ) + value_fetched = torch.empty_like(key_fetched) + + cu_seqlens_prefix = torch.zeros( + bs + 1, dtype=torch.int32, device=q3.device + ) + cu_seqlens_prefix[1:] = torch.cumsum(pl, dim=0) + token_to_batch = torch.repeat_interleave( + torch.arange(bs, device=q3.device, dtype=torch.int32), + pl.to(dtype=torch.long, device=q3.device), ) - return o.view(-1, layer.tp_q_head_num * layer.head_dim) + seq_starts = torch.zeros(bs, dtype=torch.int32, device=q3.device) + + page_size = self.page_size + max_pp = (int(pl.max().item()) + page_size - 1) // page_size + strided_pos = self.strided_indices[:max_pp].to(torch.int64) + pos = strided_pos.unsqueeze(0).expand(bs, -1).clone() + pos_clamped = torch.minimum( + pos, + (forward_batch.seq_lens - 1).clamp(min=0)[:, None], + ) + req_idx = forward_batch.req_pool_indices[:, None] + slots = self.req_to_token[req_idx, pos_clamped] + block_table = (slots // page_size).to(torch.int32) + + k_buf, v_buf = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) + num_slots = k_buf.shape[0] + num_blocks = num_slots // page_size + k_nhd = k_buf[: num_blocks * page_size].view( + num_blocks, page_size, layer.tp_k_head_num, layer.head_dim + ) + v_nhd = v_buf[: num_blocks * page_size].view( + num_blocks, page_size, layer.tp_k_head_num, layer.head_dim + ) + dequant = self.kv_cache_dtype == fp8_dtype + k_scales = torch.ones(1, dtype=torch.float32, device=q3.device) + v_scales = torch.ones(1, dtype=torch.float32, device=q3.device) + + cp_mha_gather_cache( + k_nhd, + v_nhd, + key_fetched, + value_fetched, + block_table, + k_scales, + v_scales, + cu_seqlens_prefix, + token_to_batch, + seq_starts, + dequant, + "SHUFFLE", + total_pt, + ) + + max_pk = int(pl.max().item()) + pre = flash_attn_varlen_func( + q3, + key_fetched, + value_fetched, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_prefix, + max_seqlen_q=max_ql, + max_seqlen_k=max_pk, + min_seqlen_q=0, + dropout_p=0.0, + softmax_scale=self.scale, + causal=False, + window_size=(-1, -1, 0), + sink_ptr=None, + return_lse=True, + ) + out_pre, lse_pre = pre if isinstance(pre, tuple) else (pre, None) + + nhq = layer.tp_q_head_num + lse_suf_m = _flash_varlen_lse_to_merge_layout(lse_suf, nhq) + lse_pre_m = _flash_varlen_lse_to_merge_layout(lse_pre, nhq) + merged, _ = merge_state(out_pre, lse_pre_m, out_suf, lse_suf_m) + return merged.view(-1, layer.tp_q_head_num * layer.head_dim) if ( forward_batch.forward_mode.is_target_verify()