Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 26 additions & 12 deletions atom/model_ops/attention_mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,14 @@ def _can_attempt_prefill_sink_asm(self, fwd_ctx: ForwardContext) -> bool:
return False
if getattr(attn_metadata, "dropout_p", 0.0) != 0.0:
return False
if getattr(attn_metadata, "has_cached", False):
return False
# Prefix-cache hit (has_cached) is supported: prefill_attention gathers
# the cached+new KV into a dense packed [total_kv, ...] tensor and the
# gfx1250 sink varlen ASM kernel handles bottom-right causal for
# sq != sk (chunked-prefill). cu_seqlens_q / cu_seqlens_k carry the
# per-request new-token vs cached+new lengths, so we no longer require
# max_seqlen_q == max_seqlen_k.
if attn_metadata.cu_seqlens_q is None or attn_metadata.cu_seqlens_k is None:
return False
if attn_metadata.max_seqlen_q != attn_metadata.max_seqlen_k:
return False
return True

def _can_use_prefill_sink_asm(
Expand Down Expand Up @@ -366,14 +368,11 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext):
)
self._cache_format = "SHUFFLE" if asm_layout else "NHD"

# Prefix cache hit: gather cached KV from paged cache and concat with new tokens
if attn_metadata.has_cached:
q, k, v, k_cache, v_cache, k_scale, v_scale = (
self._gather_prefix_and_concat_kv(
q, k, v, k_cache, v_cache, k_scale, v_scale, attn_metadata
)
)

# NOTE: on a prefix-cache hit the cached+new KV is gathered into a dense
# packed tensor inside prefill_attention (the ASM varlen path that needs
# it). The Triton path reads the paged KV cache directly, so it never
# gathers. Keeping the gather out of here also means dispatch_backend
# sees q/k with matching token counts (sq == sk).
return q, k, v, k_cache, v_cache, k_scale, v_scale

def _gather_prefix_and_concat_kv(
Expand Down Expand Up @@ -734,6 +733,18 @@ def prefill_attention(

# variable-length attention takes key and value as input
attn_metadata = fwd_ctx.attn_metadata
# Prefix-cache hit: gather cached+new KV from the paged cache into a
# dense packed [total_kv, ...] tensor (new tokens were already written
# during rope_cache). flash_attn_varlen_func then attends over the full
# sequence; cu_seqlens_q / cu_seqlens_k carry the new vs cached+new
# lengths (sq < sk), which the varlen kernel handles via bottom-right
# causal.
if attn_metadata.has_cached:
q, k, v, k_cache, v_cache, k_scale, v_scale = (
self._gather_prefix_and_concat_kv(
q, k, v, k_cache, v_cache, k_scale, v_scale, attn_metadata
)
)
sliding_window = (
(self.sliding_window, 0, 0) if self.sliding_window > 0 else (-1, -1, 0)
)
Expand Down Expand Up @@ -861,6 +872,9 @@ def dispatch_backend(
v: torch.Tensor,
):
if fwd_ctx.context.is_prefill:
# q/k/v here still hold only the new tokens (the prefix gather happens
# inside prefill_attention), so the q.shape[0] == k.shape[0] check in
# _can_use_prefill_sink_asm is valid.
if self._can_use_prefill_sink_asm(q, k, v, fwd_ctx):
return self.prefill_attention
if envs.ATOM_USE_UNIFIED_ATTN or self.use_flash_layout:
Expand Down
Loading