diff --git a/atom/model_ops/base_attention.py b/atom/model_ops/base_attention.py index a086a87b6..f99e85357 100644 --- a/atom/model_ops/base_attention.py +++ b/atom/model_ops/base_attention.py @@ -324,6 +324,8 @@ def fake_( k: torch.Tensor, v: torch.Tensor, positions: torch.Tensor, + kv_cache: torch.Tensor, + kv_scale: torch.Tensor, layer_name: str, use_mla: bool, qkv: torch.Tensor, @@ -342,13 +344,19 @@ def fake_( # Dynamo will not try to inspect any of the internal operations for prefill or decode # This way, although attention operation is complicated, # we can still capture the model's computation graph as a full-graph -@mark_spliting_op(is_custom=True, gen_fake=fake_, mutates_args=[]) +@mark_spliting_op( + is_custom=True, + gen_fake=fake_, + mutates_args=["kv_cache", "kv_scale"], +) def unified_attention_with_output_base( q: torch.Tensor, q_scale: Optional[torch.Tensor], k: torch.Tensor, v: torch.Tensor, positions: torch.Tensor, + kv_cache: torch.Tensor, + kv_scale: torch.Tensor, layer_name: str, use_mla: bool, qkv: torch.Tensor, @@ -368,9 +376,11 @@ def unified_attention_with_output_base( query=q, key=k, value=v, - position=positions, + positions=positions, q_scale=q_scale, qkv=qkv, + kv_cache=kv_cache, + kv_scale=kv_scale, ) @@ -379,6 +389,7 @@ def linear_attention_with_output_base_fake( b: torch.Tensor, a: torch.Tensor, core_attn_out: torch.Tensor, + kv_cache: torch.Tensor, layer_name: str, ) -> torch.Tensor: return torch.empty_like(core_attn_out) @@ -387,19 +398,20 @@ def linear_attention_with_output_base_fake( @mark_spliting_op( is_custom=True, gen_fake=linear_attention_with_output_base_fake, - mutates_args=[], + mutates_args=["kv_cache"], ) def linear_attention_with_output_base( mixed_qkv: torch.Tensor, b: torch.Tensor, a: torch.Tensor, core_attn_out: torch.Tensor, + kv_cache: torch.Tensor, layer_name: str, ) -> torch.Tensor: atom_config = get_current_atom_config() self = atom_config.compilation_config.static_forward_context[layer_name] ret = torch.empty_like(core_attn_out) - ret = self.impl.forward(mixed_qkv, b, a, ret, layer_name) + ret = self.impl.forward(mixed_qkv, b, a, ret, kv_cache, layer_name) return ret @@ -503,6 +515,7 @@ def __init__( compilation_config = atom_config.compilation_config default_name = f"Linear_{layer_num}" self.layer_name = prefix if prefix is not None else default_name + self.kv_cache = torch.tensor([]) if self.layer_name in compilation_config.static_forward_context: raise ValueError("Duplicate layer: {}".format(self.layer_name)) compilation_config.static_forward_context[self.layer_name] = self @@ -515,6 +528,6 @@ def forward( core_attn_out: torch.Tensor, ): output = torch.ops.aiter.linear_attention_with_output_base( - mixed_qkv, b, a, core_attn_out, self.layer_name + mixed_qkv, b, a, core_attn_out, self.kv_cache, self.layer_name ) return output diff --git a/atom/model_ops/paged_attention.py b/atom/model_ops/paged_attention.py index 7937ee111..8a320fa1f 100644 --- a/atom/model_ops/paged_attention.py +++ b/atom/model_ops/paged_attention.py @@ -118,7 +118,23 @@ def forward( qkv: torch.Tensor = None, **kwargs, ): + kv_cache = getattr(self.impl, "kv_cache", query.new_empty(0)) + kv_scale_fn = getattr(self.impl, "_kv_scale_arg_for_forward", None) + kv_scale = ( + kv_scale_fn(query, kv_cache) + if kv_scale_fn is not None + else query.new_empty(0) + ) output = torch.ops.aiter.unified_attention_with_output_base( - query, q_scale, key, value, positions, self.layer_name, self.use_mla, qkv + query, + q_scale, + key, + value, + positions, + kv_cache, + kv_scale, + self.layer_name, + self.use_mla, + qkv, ) return output diff --git a/atom/plugin/vllm/attention/layer_gdn.py b/atom/plugin/vllm/attention/layer_gdn.py index 12b6b446a..73f972d31 100644 --- a/atom/plugin/vllm/attention/layer_gdn.py +++ b/atom/plugin/vllm/attention/layer_gdn.py @@ -420,6 +420,7 @@ def forward( b: torch.Tensor, a: torch.Tensor, core_attn_out: torch.Tensor, + kv_cache: torch.Tensor, layer_name: str, ): """ @@ -448,8 +449,10 @@ def forward( non_spec_state_indices_tensor = ( attn_metadata.non_spec_state_indices_tensor ) # noqa: E501 - compilation_config = forward_context.no_compile_layers - self_kv_cache = compilation_config[layer_name].kv_cache + if kv_cache is None or kv_cache.numel() == 0: + compilation_config = forward_context.no_compile_layers + kv_cache = compilation_config[layer_name].kv_cache + self_kv_cache = kv_cache conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] num_actual_tokens = attn_metadata.num_actual_tokens diff --git a/atom/plugin/vllm/attention/layer_mha.py b/atom/plugin/vllm/attention/layer_mha.py index e0f622146..db9a4687e 100644 --- a/atom/plugin/vllm/attention/layer_mha.py +++ b/atom/plugin/vllm/attention/layer_mha.py @@ -63,6 +63,14 @@ def _init_vllm_mha_layer_state( _init_kv_cache_quant(layer, quant_config, layer_name) +def _normalize_aiter_kv_cache_dtype(kv_cache_dtype: str) -> str: + if kv_cache_dtype == "bfloat16": + return "bf16" + if kv_cache_dtype == "float16": + return "fp16" + return kv_cache_dtype + + def _set_default_mha_scales(layer) -> None: from vllm.model_executor.layers.attention.attention import set_default_quant_scales @@ -129,6 +137,7 @@ def __init__( else 1.0 ) self.kv_scale = torch.tensor(self.kv_scale_float, dtype=torch.float32) + self.per_tensor_scale = self.kv_scale self.per_token_quant = True self.sinks = sinks self.sliding_window = ( @@ -150,6 +159,7 @@ def __init__( calculate_kv_scales=calculate_kv_scales, quant_config=quant_config, ) + self.kv_cache_dtype = _normalize_aiter_kv_cache_dtype(self.kv_cache_dtype) _register_vllm_static_forward_context(self) @@ -174,6 +184,75 @@ def calc_kv_scales(self, query, key, value): self._v_scale_float = self._v_scale.item() self.calculate_kv_scales = False + def _scalar_kv_scale_for_device(self, reference: torch.Tensor) -> torch.Tensor: + if self.kv_scale.dim() != 0 or self.kv_scale.device != reference.device: + self.kv_scale = torch.tensor( + self.kv_scale_float, dtype=torch.float32, device=reference.device + ) + self.per_tensor_scale = self.kv_scale + self.k_scale = self.v_scale = None + return self.kv_scale + + def _ensure_fp8_kv_scale( + self, kv_cache: torch.Tensor, reference: torch.Tensor + ) -> torch.Tensor: + if kv_cache is None or kv_cache.numel() == 0: + return self._scalar_kv_scale_for_device(reference) + + k_cache, _ = kv_cache.unbind(0) + num_blocks, block_size, num_kv_heads, _ = k_cache.shape + expected_shape = (2, num_blocks, num_kv_heads, block_size) + if ( + self.kv_scale.shape != expected_shape + or self.kv_scale.device != reference.device + ): + self.per_tensor_scale = self._scalar_kv_scale_for_device(reference) + self.kv_scale = torch.full( + expected_shape, + self.kv_scale_float, + dtype=torch.float32, + device=reference.device, + ) + self.k_scale = self.kv_scale[0] + self.v_scale = self.kv_scale[1] + elif self.k_scale is None or self.v_scale is None: + self.k_scale = self.kv_scale[0] + self.v_scale = self.kv_scale[1] + + return self.kv_scale + + def _kv_scale_arg_for_forward( + self, reference: torch.Tensor, kv_cache: torch.Tensor = None + ) -> torch.Tensor: + kv_cache = self.kv_cache if kv_cache is None else kv_cache + if self.kv_cache_dtype == "fp8": + return self._ensure_fp8_kv_scale(kv_cache, reference) + return self._scalar_kv_scale_for_device(reference) + + @property + def kv_cache(self): + return self._kv_cache + + @kv_cache.setter + def kv_cache(self, kv_cache): + self._kv_cache = kv_cache + if ( + not hasattr(self, "kv_cache_dtype") + or self.kv_cache_dtype != "fp8" + or not isinstance(kv_cache, torch.Tensor) + ): + return + + if kv_cache.numel() == 0: + # vLLM clears kv_cache on shutdown/profiling cleanup. Drop the large + # side buffer with it so the next bind allocates a fresh graph target. + self.kv_scale = torch.tensor(self.kv_scale_float, dtype=torch.float32) + self.per_tensor_scale = self.kv_scale + self.k_scale = self.v_scale = None + return + + self._ensure_fp8_kv_scale(kv_cache, kv_cache) + def forward( self, query: torch.Tensor, @@ -182,15 +261,31 @@ def forward( positions: torch.Tensor = None, q_scale: Optional[torch.Tensor] = None, qkv: torch.Tensor = None, + kv_cache: torch.Tensor = None, + kv_scale: torch.Tensor = None, **kwargs, ): if self.calculate_kv_scales and key is not None and value is not None: self.calc_kv_scales(query, key, value) + if positions is None: + from vllm.forward_context import ( + get_forward_context as get_vllm_forward_context, + is_forward_context_available, + ) + + if is_forward_context_available(): + positions = get_vllm_forward_context().additional_kwargs.get( + "atom_positions" + ) + kv_cache = self.kv_cache if kv_cache is None else kv_cache + if kv_scale is None: + kv_scale = self._kv_scale_arg_for_forward(query, kv_cache) return torch.ops.aiter.atom_vllm_mha_attention( query, key, value, - self.kv_cache, + kv_cache, + kv_scale, self.layer_name, positions, q_scale, @@ -404,7 +499,7 @@ def paged_attention_triton( query_group_size = max_qlen * (num_q_heads_total // num_kv_heads) context_partition_size = 256 - use_ps = True + use_ps = self.kv_cache_dtype != "fp8" if use_ps: max_context_partition_num = get_recommended_splits( num_decodes, num_kv_heads @@ -694,6 +789,9 @@ def extend_forward( ) def _dispatch_decode_backend(self, num_decodes): + if self.kv_cache_dtype == "fp8": + return self.paged_attention_triton + # use asm pa for models without setting gluon pa decode bs gluon_pa_decode_bs = _GLUON_PA_DECODE_BS_MAPPING.get(self.model_type, -1) if self.use_triton_attn: @@ -710,6 +808,7 @@ def forward_impl( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, + kv_scale: torch.Tensor = None, attn_metadata: "AiterMhaMetadataForVllm" = None, position: torch.Tensor = None, q_scale: torch.Tensor = None, @@ -756,22 +855,22 @@ def forward_impl( k_cache = k_cache.view(target_dtype) v_cache = v_cache.view(target_dtype) - # create kv scale according to the num_blocks - # usually it is created when cuda graph capture for decode phase + # Keep FP8 KV scales on the explicit mutable custom-op input. FULL + # cudagraph replay must read and write the same scale storage that the + # cache-write kernel updates for newly generated tokens. + active_k_scale = self.k_scale + active_v_scale = self.v_scale if self.kv_cache_dtype == "fp8": - if self.k_scale is None or self.v_scale is None: - # origin kv_scale is per tensor scale of value one. - self.per_tensor_scale = self.kv_scale - self.kv_scale = torch.zeros( - 2, - num_blocks, - num_kv_heads, - block_size, - dtype=dtypes.fp32, - device=self.device, - ) - self.k_scale = self.kv_scale[0] - self.v_scale = self.kv_scale[1] + scale_buffer = self._ensure_fp8_kv_scale(kv_cache, query) + active_k_scale = self.k_scale + active_v_scale = self.v_scale + if ( + isinstance(kv_scale, torch.Tensor) + and kv_scale.shape == scale_buffer.shape + and kv_scale.device == scale_buffer.device + ): + active_k_scale = kv_scale[0] + active_v_scale = kv_scale[1] # as vLLM cuda graph capture padding mechanism, here split the qkvo with # the actual tokens @@ -797,8 +896,8 @@ def forward_impl( attention_metadata=attn_metadata, k_cache=k_cache, v_cache=v_cache, - k_scale=self.k_scale, - v_scale=self.v_scale, + k_scale=active_k_scale, + v_scale=active_v_scale, flash_layout=False, ) query, key, value, k_cache, v_cache, k_scale, v_scale = result diff --git a/atom/plugin/vllm/attention/ops.py b/atom/plugin/vllm/attention/ops.py index 264a70901..7df134d59 100644 --- a/atom/plugin/vllm/attention/ops.py +++ b/atom/plugin/vllm/attention/ops.py @@ -21,6 +21,7 @@ def atom_vllm_mha_attention_fake( key: Optional[torch.Tensor], value: Optional[torch.Tensor], kv_cache: torch.Tensor, + kv_scale: torch.Tensor, layer_name: str, positions: Optional[torch.Tensor] = None, q_scale: Optional[torch.Tensor] = None, @@ -32,13 +33,14 @@ def atom_vllm_mha_attention_fake( @mark_spliting_op( is_custom=True, gen_fake=atom_vllm_mha_attention_fake, - mutates_args=["kv_cache"], + mutates_args=["kv_cache", "kv_scale"], ) def atom_vllm_mha_attention( query: torch.Tensor, key: Optional[torch.Tensor], value: Optional[torch.Tensor], kv_cache: torch.Tensor, + kv_scale: torch.Tensor, layer_name: str, positions: Optional[torch.Tensor] = None, q_scale: Optional[torch.Tensor] = None, @@ -50,6 +52,7 @@ def atom_vllm_mha_attention( key, value, kv_cache, + kv_scale, attn_metadata=attn_metadata, position=positions, q_scale=q_scale,