diff --git a/atom/model_ops/attention_mha.py b/atom/model_ops/attention_mha.py index 051980ff9..8953c40e4 100644 --- a/atom/model_ops/attention_mha.py +++ b/atom/model_ops/attention_mha.py @@ -124,12 +124,14 @@ def _can_attempt_prefill_sink_asm(self, fwd_ctx: ForwardContext) -> bool: return False if getattr(attn_metadata, "dropout_p", 0.0) != 0.0: return False - if getattr(attn_metadata, "has_cached", False): - return False + # Prefix-cache hit (has_cached) is supported: prefill_attention gathers + # the cached+new KV into a dense packed [total_kv, ...] tensor and the + # gfx1250 sink varlen ASM kernel handles bottom-right causal for + # sq != sk (chunked-prefill). cu_seqlens_q / cu_seqlens_k carry the + # per-request new-token vs cached+new lengths, so we no longer require + # max_seqlen_q == max_seqlen_k. if attn_metadata.cu_seqlens_q is None or attn_metadata.cu_seqlens_k is None: return False - if attn_metadata.max_seqlen_q != attn_metadata.max_seqlen_k: - return False return True def _can_use_prefill_sink_asm( @@ -366,14 +368,11 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext): ) self._cache_format = "SHUFFLE" if asm_layout else "NHD" - # Prefix cache hit: gather cached KV from paged cache and concat with new tokens - if attn_metadata.has_cached: - q, k, v, k_cache, v_cache, k_scale, v_scale = ( - self._gather_prefix_and_concat_kv( - q, k, v, k_cache, v_cache, k_scale, v_scale, attn_metadata - ) - ) - + # NOTE: on a prefix-cache hit the cached+new KV is gathered into a dense + # packed tensor inside prefill_attention (the ASM varlen path that needs + # it). The Triton path reads the paged KV cache directly, so it never + # gathers. Keeping the gather out of here also means dispatch_backend + # sees q/k with matching token counts (sq == sk). return q, k, v, k_cache, v_cache, k_scale, v_scale def _gather_prefix_and_concat_kv( @@ -734,6 +733,18 @@ def prefill_attention( # variable lenth attention use key value as input attn_metadata = fwd_ctx.attn_metadata + # Prefix-cache hit: gather cached+new KV from the paged cache into a + # dense packed [total_kv, ...] tensor (new tokens were already written + # during rope_cache). flash_attn_varlen_func then attends over the full + # sequence; cu_seqlens_q / cu_seqlens_k carry the new vs cached+new + # lengths (sq < sk), which the varlen kernel handles via bottom-right + # causal. + if attn_metadata.has_cached: + q, k, v, k_cache, v_cache, k_scale, v_scale = ( + self._gather_prefix_and_concat_kv( + q, k, v, k_cache, v_cache, k_scale, v_scale, attn_metadata + ) + ) sliding_window = ( (self.sliding_window, 0, 0) if self.sliding_window > 0 else (-1, -1, 0) ) @@ -861,6 +872,9 @@ def dispatch_backend( v: torch.Tensor, ): if fwd_ctx.context.is_prefill: + # q/k/v here still hold only the new tokens (the prefix gather happens + # inside prefill_attention), so the q.shape[0] == k.shape[0] check in + # _can_use_prefill_sink_asm is valid. if self._can_use_prefill_sink_asm(q, k, v, fwd_ctx): return self.prefill_attention if envs.ATOM_USE_UNIFIED_ATTN or self.use_flash_layout: