diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index 00be33f2521..2e90a7fab71 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -7,15 +7,19 @@ from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group +from vllm.forward_context import get_forward_context +from vllm.logger import logger from vllm.model_executor.layers.linear import (LinearBase, ReplicatedLinear, UnquantizedLinearMethod) from vllm.triton_utils import HAS_TRITON from vllm.v1.attention.backends.utils import AttentionCGSupport +from vllm_ascend import envs from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, + trans_rope_weight, transdata, wait_for_kv_layer_from_connector) from vllm_ascend.ops.shared_weight_layer import ( is_hidden_layer, post_process_after_loading_for_shared_weight_series, @@ -23,6 +27,7 @@ register_layer_to_shared_weight_series) from vllm_ascend.ops.triton.rope import rope_forward_triton from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch +from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, _round_up, dispose_layer, enable_sp, is_enable_nz, replace_layer) @@ -341,12 +346,13 @@ def __init__( self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp self.enable_prefetch = ascend_config.weight_prefetch_config.enabled self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz - self.vllm_config = get_current_vllm_config() + self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO + assert self.indexer is not None, "Indexer is required for DSA." self.enable_sfa_cp = enable_sp() self.local_num_heads = self.num_heads - + self.vllm_config = get_current_vllm_config() if self.enable_sfa_cp: self.local_num_heads = self.num_heads * self.tp_size @@ -454,6 +460,29 @@ def get_and_maybe_dequant_weights(layer: LinearBase): post_process_after_loading_for_shared_weight_series( self.o_proj) + if self.enable_mlapo: + quant_method = getattr( + getattr(self.fused_qkv_a_proj, "quant_method", None), + "quant_method", + None, + ) + reasons = [] + if self.fused_qkv_a_proj is None or not isinstance( + quant_method, AscendW8A8LinearMethod): + reasons.append( + "Currently mlapo only supports W8A8 quantization in MLA scenario." + "Some layers in your model are not quantized with W8A8," + "thus mlapo is disabled for these layers.") + if self.enable_sfa_cp: + reasons.append("Currently mlapo does not support SFA with CP," + "thus mlapo is disabled for these layers.") + if reasons: + self.enable_mlapo = False + for msg in reasons: + logger.warning_once(msg) + else: + self._process_weights_for_fused_mlapo(act_dtype) + def _v_up_proj(self, x): if self.W_UV.shape[0] * self.W_UV.shape[1] < 65536: x = x.view(-1, self.local_num_heads, self.kv_lora_rank) @@ -555,6 +584,161 @@ def rope_single( x = torch_npu.npu_interleave_rope(x, cos, sin) return x.view(B, N, D) + # Processing the input parameters for MLAPO by reordering and transposing + # QKV(and part of Q) weight, applying RoPE-related dimension transformations, + # and handling quantization parameters. + def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype): + assert self.kv_a_proj_with_mqa is None + assert self.fused_qkv_a_proj is not None + + kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[ + ..., self.q_lora_rank:].contiguous() + q_a_proj_wt = self.fused_qkv_a_proj.weight.data[ + ..., :self.q_lora_rank].contiguous() + + self.fused_qkv_a_proj.weight = None + + kv_a_proj_wt = kv_a_proj_wt.t().contiguous() + kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim) + kv_a_proj_wt = kv_a_proj_wt.t().contiguous() + wd_qkv = torch.cat((kv_a_proj_wt, q_a_proj_wt), dim=-1) + wd_qkv = wd_qkv.t().contiguous() + wd_qkv = transdata(wd_qkv, + block_size=(16, 32)).unsqueeze(0).contiguous() + self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29) + + kv_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[ + self.q_lora_rank:].contiguous() + q_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[:self. + q_lora_rank].contiguous( + ) + kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape( + self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous() + kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl, + self.qk_rope_head_dim) + kv_a_proj_deq_scl = kv_a_proj_deq_scl.view( + self.kv_lora_rank + self.qk_rope_head_dim).contiguous() + self.deq_scale_qkv = torch.cat((kv_a_proj_deq_scl, q_a_proj_deq_scl), + dim=-1).contiguous() + + kv_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[ + self.q_lora_rank:].contiguous() + q_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[:self. + q_lora_rank].contiguous( + ) + + kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape( + self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous() + kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias, + self.qk_rope_head_dim) + kv_a_proj_qt_bias = kv_a_proj_qt_bias.view( + self.kv_lora_rank + self.qk_rope_head_dim).contiguous() + self.quant_bias_qkv = torch.cat((kv_a_proj_qt_bias, q_a_proj_qt_bias), + dim=-1).contiguous() + + wu_q = self.q_proj.weight.data + wu_q = wu_q.t().reshape(self.num_heads, + self.qk_nope_head_dim + self.qk_rope_head_dim, + -1) + wu_q = trans_rope_weight(wu_q, self.qk_rope_head_dim) + wu_q = wu_q.reshape( + self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim), + -1) + wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous() + self.wu_q = torch_npu.npu_format_cast(wu_q, 29) + + qb_deq_scl = self.q_proj.deq_scale.data + qb_deq_scl = qb_deq_scl.reshape( + self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1) + qb_deq_scl = trans_rope_weight(qb_deq_scl, self.qk_rope_head_dim) + self.qb_deq_scl = qb_deq_scl.reshape( + self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)) + + qb_qt_bias = self.q_proj.quant_bias.data + qb_qt_bias = qb_qt_bias.reshape( + self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1) + qb_qt_bias = trans_rope_weight(qb_qt_bias, self.qk_rope_head_dim) + self.qb_qt_bias = qb_qt_bias.reshape( + self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)) + + device = self.q_proj.weight.device + self.gamma1 = self.q_a_layernorm.weight.data + self.beta1 = self.q_a_layernorm.bias.data + self.gamma2 = self.kv_a_layernorm.weight.data + self.quant_scale0 = self.fused_qkv_a_proj.input_scale.data + self.quant_offset0 = self.fused_qkv_a_proj.input_offset.data + self.quant_scale1 = self.q_proj.input_scale.data + self.quant_offset1 = self.q_proj.input_offset.data + self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device) + self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device) + + if self.vllm_config.kv_transfer_config is not None: + self.fused_qkv_a_proj.deq_scale = None + self.fused_qkv_a_proj.quant_bias = None + self.q_proj.deq_scale = None + self.q_proj.quant_bias = None + torch.npu.empty_cache() + + def _sfa_preprocessc_decode( + self, + hidden_states: torch.Tensor, + kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + attn_metadata: M, + need_gather_q_kv: bool, + num_actual_tokens: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + hidden_states.contiguous(), need_gather_q_kv) + k_nope, k_pe = kv_cache[0], kv_cache[1] + ql_nope = torch.empty( + (num_actual_tokens, self.W_UK_T.shape[0], k_nope.shape[-1]), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + q_pe = torch.empty( + (num_actual_tokens, self.W_UK_T.shape[0], k_pe.shape[-1]), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + q_c = torch.empty( + (num_actual_tokens, self.q_lora_rank), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + torch.ops._C_ascend.mla_preprocess( + hidden_states, + self.wd_qkv, + self.deq_scale_qkv, + self.gamma1, + self.beta1, + self.wu_q, + self.qb_deq_scl, + self.gamma2, + attn_metadata.cos, + attn_metadata.sin, + self.W_UK_T, + k_nope, + k_pe, + attn_metadata.slot_mapping[:num_actual_tokens].flatten(), + quant_scale0=self.quant_scale0, + quant_offset0=self.quant_offset0, + bias0=self.quant_bias_qkv, + quant_scale1=self.quant_scale1, + quant_offset1=self.quant_offset1, + bias1=self.qb_qt_bias, + ctkv_scale=self.ctkv_scale, + q_nope_scale=self.q_nope_scale, + cache_mode="krope_ctkv", + quant_mode="per_tensor_quant_asymm", + enable_inner_out=True, + q_out0=ql_nope, + kv_cache_out0=k_nope, + q_out1=q_pe, + kv_cache_out1=k_pe, + inner_out=q_c, + ) + return hidden_states, ql_nope, q_pe, q_c + def forward( self, layer_name, @@ -565,69 +749,76 @@ def forward( output: Optional[torch.Tensor] = None, ) -> torch.Tensor: assert output is not None, "Output tensor must be provided." + forward_context = get_forward_context() if attn_metadata is None: # Profiling run. - if self.enable_sfa_cp: - from vllm.forward_context import get_forward_context - if not get_forward_context().in_profile_run: - if is_hidden_layer(self.vllm_config, self.q_proj): - reach_layer_for_shared_weight_series(self.q_proj) - if is_hidden_layer(self.vllm_config, self.o_proj): - reach_layer_for_shared_weight_series(self.o_proj) - + if self.enable_sfa_cp and not forward_context.in_profile_run: + if is_hidden_layer(self.vllm_config, self.q_proj): + reach_layer_for_shared_weight_series(self.q_proj) + if is_hidden_layer(self.vllm_config, self.o_proj): + reach_layer_for_shared_weight_series(self.o_proj) return output.fill_(0) has_prefill = attn_metadata.has_prefill num_actual_tokens = attn_metadata.num_actual_tokens + cos = attn_metadata.cos + sin = attn_metadata.sin + actual_seq_lengths_query = attn_metadata.cum_query_lens + actual_seq_lengths_key = attn_metadata.seq_lens hidden_states = hidden_states[:num_actual_tokens] if self.enable_sfa_cp: need_gather_q_kv = False # Inputs and outputs may be padded for CUDA graphs output_padded = output output = output[:num_actual_tokens] - assert self.fused_qkv_a_proj is not None, "q lora is required for DSA." - maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight, - dependency=hidden_states, - enabled=self.enable_prefetch) - qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] - q_c, kv_no_split = qkv_lora.split( - [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], - dim=-1, - ) - q_c = self.q_a_layernorm(q_c) - - # Process for Flash Comm V1 - if need_gather_q_kv: - q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( - q_c.contiguous(), need_gather_q_kv) - kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( - kv_no_split.contiguous(), need_gather_q_kv) - - if has_prefill: - wait_for_kv_layer_from_connector(layer_name) - cos = attn_metadata.cos - sin = attn_metadata.sin - slot_mapping = attn_metadata.slot_mapping[:num_actual_tokens] - slot_mapping_cp = None - actual_seq_lengths_query = attn_metadata.cum_query_lens - actual_seq_lengths_key = attn_metadata.seq_lens - if self.enable_sfa_cp: - assert attn_metadata.sfa_cp_context is not None - slot_mapping_cp = attn_metadata.sfa_cp_context.slot_mapping_cp - actual_seq_lengths_query = attn_metadata.sfa_cp_context.actual_seq_lengths_query - actual_seq_lengths_key = attn_metadata.sfa_cp_context.actual_seq_lengths_key - - self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping, - slot_mapping_cp) - - if self.enable_sfa_cp and attn_metadata.sfa_cp_context is not None: - if is_hidden_layer(self.vllm_config, self.q_proj): - reach_layer_for_shared_weight_series(self.q_proj) - if is_hidden_layer(self.vllm_config, self.o_proj): - reach_layer_for_shared_weight_series(self.o_proj) - - ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c) - q_pe = self.rope_single(q_pe, cos, sin) + if self.enable_mlapo and not forward_context.with_prefill: + hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocessc_decode( + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + need_gather_q_kv=need_gather_q_kv, + num_actual_tokens=num_actual_tokens, + ) + else: + assert self.fused_qkv_a_proj is not None, "q lora is required for DSA." + maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight, + dependency=hidden_states, + enabled=self.enable_prefetch) + qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] + q_c, kv_no_split = qkv_lora.split( + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], + dim=-1, + ) + q_c = self.q_a_layernorm(q_c) + # Process for Flash Comm V1 + if need_gather_q_kv: + q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + q_c.contiguous(), need_gather_q_kv) + kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + kv_no_split.contiguous(), need_gather_q_kv) + + if has_prefill: + wait_for_kv_layer_from_connector(layer_name) + + slot_mapping = attn_metadata.slot_mapping[:num_actual_tokens] + slot_mapping_cp = None + if self.enable_sfa_cp: + assert attn_metadata.sfa_cp_context is not None + slot_mapping_cp = attn_metadata.sfa_cp_context.slot_mapping_cp + actual_seq_lengths_query = attn_metadata.sfa_cp_context.actual_seq_lengths_query + actual_seq_lengths_key = attn_metadata.sfa_cp_context.actual_seq_lengths_key + + self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping, + slot_mapping_cp) + + if self.enable_sfa_cp and attn_metadata.sfa_cp_context is not None: + if is_hidden_layer(self.vllm_config, self.q_proj): + reach_layer_for_shared_weight_series(self.q_proj) + if is_hidden_layer(self.vllm_config, self.o_proj): + reach_layer_for_shared_weight_series(self.o_proj) + + ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c) + q_pe = self.rope_single(q_pe, cos, sin) topk_indices = self.indexer_select( x=hidden_states,