297 changes: 244 additions & 53 deletions vllm_ascend/attention/sfa_v1.py
@@ -7,22 +7,27 @@
from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.forward_context import get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.linear import (LinearBase, ReplicatedLinear,
UnquantizedLinearMethod)
from vllm.triton_utils import HAS_TRITON
from vllm.v1.attention.backends.utils import AttentionCGSupport

from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
trans_rope_weight, transdata,
wait_for_kv_layer_from_connector)
from vllm_ascend.ops.shared_weight_layer import (
is_hidden_layer, post_process_after_loading_for_shared_weight_series,
reach_layer_for_shared_weight_series,
register_layer_to_shared_weight_series)
from vllm_ascend.ops.triton.rope import rope_forward_triton
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
_round_up, dispose_layer, enable_sp,
is_enable_nz, replace_layer)
@@ -341,12 +346,13 @@ def __init__(
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
self.vllm_config = get_current_vllm_config()
self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO

assert self.indexer is not None, "Indexer is required for DSA."

self.enable_sfa_cp = enable_sp()
self.local_num_heads = self.num_heads

self.vllm_config = get_current_vllm_config()
if self.enable_sfa_cp:
self.local_num_heads = self.num_heads * self.tp_size

@@ -454,6 +460,29 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
post_process_after_loading_for_shared_weight_series(
self.o_proj)

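# MLAPO fuses the MLA-preprocess stage (RMSNorm, fused QKV-a projection,
# RoPE and quantization) into a single NPU op (mla_preprocess below); it
# currently requires W8A8-quantized weights and is incompatible with SFA
# CP, so otherwise the layer falls back to the unfused path.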
if self.enable_mlapo:
quant_method = getattr(
getattr(self.fused_qkv_a_proj, "quant_method", None),
"quant_method",
None,
)
reasons = []
if self.fused_qkv_a_proj is None or not isinstance(
quant_method, AscendW8A8LinearMethod):
reasons.append(
"Currently mlapo only supports W8A8 quantization in the MLA scenario. "
"Some layers in your model are not quantized with W8A8, "
"so mlapo is disabled for those layers.")
if self.enable_sfa_cp:
reasons.append("Currently mlapo does not support SFA with CP,"
"thus mlapo is disabled for these layers.")
if reasons:
self.enable_mlapo = False
for msg in reasons:
logger.warning_once(msg)
else:
self._process_weights_for_fused_mlapo(act_dtype)

def _v_up_proj(self, x):
if self.W_UV.shape[0] * self.W_UV.shape[1] < 65536:
x = x.view(-1, self.local_num_heads, self.kv_lora_rank)
@@ -555,6 +584,161 @@ def rope_single(
x = torch_npu.npu_interleave_rope(x, cos, sin)
return x.view(B, N, D)

# Prepare the parameters consumed by MLAPO: reorder and transpose the fused
# QKV-a (and Q up-projection) weights, apply the RoPE-related dimension
# transformations, and convert the matching quantization parameters.
def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
assert self.kv_a_proj_with_mqa is None
assert self.fused_qkv_a_proj is not None

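# Split the fused QKV-a projection weight into its KV-a and Q-a parts;
# each part is re-laid-out separately for the fused kernel below.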
kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[
..., self.q_lora_rank:].contiguous()
q_a_proj_wt = self.fused_qkv_a_proj.weight.data[
..., :self.q_lora_rank].contiguous()

self.fused_qkv_a_proj.weight = None

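# Reorder the RoPE rows of the KV-a part (trans_rope_weight), rebuild the
# fused weight with KV-a first, then convert it to the 16x32-blocked NZ
# layout (format 29, i.e. ACL_FORMAT_FRACTAL_NZ) expected by mla_preprocess.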
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim)
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
wd_qkv = torch.cat((kv_a_proj_wt, q_a_proj_wt), dim=-1)
wd_qkv = wd_qkv.t().contiguous()
wd_qkv = transdata(wd_qkv,
block_size=(16, 32)).unsqueeze(0).contiguous()
self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)

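# The per-channel dequant scales get the same RoPE-row reordering so they
# stay aligned with the reordered weight channels.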
kv_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[
self.q_lora_rank:].contiguous()
q_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[
:self.q_lora_rank].contiguous()
kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape(
self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl,
self.qk_rope_head_dim)
kv_a_proj_deq_scl = kv_a_proj_deq_scl.view(
self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
self.deq_scale_qkv = torch.cat((kv_a_proj_deq_scl, q_a_proj_deq_scl),
dim=-1).contiguous()

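# quant_bias receives the identical reordering and concatenation
# treatment as deq_scale above.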
kv_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[
self.q_lora_rank:].contiguous()
q_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[
:self.q_lora_rank].contiguous()

kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape(
self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias,
self.qk_rope_head_dim)
kv_a_proj_qt_bias = kv_a_proj_qt_bias.view(
self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
self.quant_bias_qkv = torch.cat((kv_a_proj_qt_bias, q_a_proj_qt_bias),
dim=-1).contiguous()

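# The Q up-projection (q_b) weight is reordered per head on its RoPE
# dims, then cast to the blocked NZ layout as well.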
wu_q = self.q_proj.weight.data
wu_q = wu_q.t().reshape(self.num_heads,
self.qk_nope_head_dim + self.qk_rope_head_dim,
-1)
wu_q = trans_rope_weight(wu_q, self.qk_rope_head_dim)
wu_q = wu_q.reshape(
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim),
-1)
wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous()
self.wu_q = torch_npu.npu_format_cast(wu_q, 29)

qb_deq_scl = self.q_proj.deq_scale.data
qb_deq_scl = qb_deq_scl.reshape(
self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
qb_deq_scl = trans_rope_weight(qb_deq_scl, self.qk_rope_head_dim)
self.qb_deq_scl = qb_deq_scl.reshape(
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))

qb_qt_bias = self.q_proj.quant_bias.data
qb_qt_bias = qb_qt_bias.reshape(
self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
qb_qt_bias = trans_rope_weight(qb_qt_bias, self.qk_rope_head_dim)
self.qb_qt_bias = qb_qt_bias.reshape(
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))

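# Capture the norm weights and activation-quant parameters as plain
# attributes in the form the fused op consumes; ctkv_scale and
# q_nope_scale are identity (1.0) scales here.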
device = self.q_proj.weight.device
self.gamma1 = self.q_a_layernorm.weight.data
self.beta1 = self.q_a_layernorm.bias.data
self.gamma2 = self.kv_a_layernorm.weight.data
self.quant_scale0 = self.fused_qkv_a_proj.input_scale.data
self.quant_offset0 = self.fused_qkv_a_proj.input_offset.data
self.quant_scale1 = self.q_proj.input_scale.data
self.quant_offset1 = self.q_proj.input_offset.data
self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)

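# When a KV-transfer connector is configured, drop the dequant
# parameters that are now duplicated in the fused buffers and release
# cached NPU memory.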
if self.vllm_config.kv_transfer_config is not None:
self.fused_qkv_a_proj.deq_scale = None
self.fused_qkv_a_proj.quant_bias = None
self.q_proj.deq_scale = None
self.q_proj.quant_bias = None
torch.npu.empty_cache()

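# Decode fast path: a single fused mla_preprocess op covers RMSNorm, the
# QKV-a and Q-b projections, RoPE, W_UK absorption and the KV-cache
# scatter, writing its results in place into the preallocated
# ql_nope / q_pe / q_c tensors.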
def _sfa_preprocess_decode(
self,
hidden_states: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
attn_metadata: M,
need_gather_q_kv: bool,
num_actual_tokens: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
hidden_states.contiguous(), need_gather_q_kv)
k_nope, k_pe = kv_cache[0], kv_cache[1]
ql_nope = torch.empty(
(num_actual_tokens, self.W_UK_T.shape[0], k_nope.shape[-1]),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
q_pe = torch.empty(
(num_actual_tokens, self.W_UK_T.shape[0], k_pe.shape[-1]),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
q_c = torch.empty(
(num_actual_tokens, self.q_lora_rank),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
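# The fused op runs with per-tensor asymmetric activation quantization
# ("per_tensor_quant_asymm") and the split "krope_ctkv" cache layout;
# all outputs are written through the *_out keyword tensors.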
torch.ops._C_ascend.mla_preprocess(
hidden_states,
self.wd_qkv,
self.deq_scale_qkv,
self.gamma1,
self.beta1,
self.wu_q,
self.qb_deq_scl,
self.gamma2,
attn_metadata.cos,
attn_metadata.sin,
self.W_UK_T,
k_nope,
k_pe,
attn_metadata.slot_mapping[:num_actual_tokens].flatten(),
quant_scale0=self.quant_scale0,
quant_offset0=self.quant_offset0,
bias0=self.quant_bias_qkv,
quant_scale1=self.quant_scale1,
quant_offset1=self.quant_offset1,
bias1=self.qb_qt_bias,
ctkv_scale=self.ctkv_scale,
q_nope_scale=self.q_nope_scale,
cache_mode="krope_ctkv",
quant_mode="per_tensor_quant_asymm",
enable_inner_out=True,
q_out0=ql_nope,
kv_cache_out0=k_nope,
q_out1=q_pe,
kv_cache_out1=k_pe,
inner_out=q_c,
)
return hidden_states, ql_nope, q_pe, q_c

def forward(
self,
layer_name,
@@ -565,69 +749,76 @@ def forward(
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert output is not None, "Output tensor must be provided."
forward_context = get_forward_context()
if attn_metadata is None:
# Profiling run.
if self.enable_sfa_cp:
from vllm.forward_context import get_forward_context
if not get_forward_context().in_profile_run:
if is_hidden_layer(self.vllm_config, self.q_proj):
reach_layer_for_shared_weight_series(self.q_proj)
if is_hidden_layer(self.vllm_config, self.o_proj):
reach_layer_for_shared_weight_series(self.o_proj)

if self.enable_sfa_cp and not forward_context.in_profile_run:
if is_hidden_layer(self.vllm_config, self.q_proj):
reach_layer_for_shared_weight_series(self.q_proj)
if is_hidden_layer(self.vllm_config, self.o_proj):
reach_layer_for_shared_weight_series(self.o_proj)
return output.fill_(0)
has_prefill = attn_metadata.has_prefill
num_actual_tokens = attn_metadata.num_actual_tokens
cos = attn_metadata.cos
sin = attn_metadata.sin
actual_seq_lengths_query = attn_metadata.cum_query_lens
actual_seq_lengths_key = attn_metadata.seq_lens
hidden_states = hidden_states[:num_actual_tokens]
if self.enable_sfa_cp:
need_gather_q_kv = False
# Inputs and outputs may be padded for CUDA graphs
output_padded = output
output = output[:num_actual_tokens]
assert self.fused_qkv_a_proj is not None, "q lora is required for DSA."
maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight,
dependency=hidden_states,
enabled=self.enable_prefetch)
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
q_c, kv_no_split = qkv_lora.split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
dim=-1,
)
q_c = self.q_a_layernorm(q_c)

# Process for Flash Comm V1
if need_gather_q_kv:
q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
q_c.contiguous(), need_gather_q_kv)
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
kv_no_split.contiguous(), need_gather_q_kv)

if has_prefill:
wait_for_kv_layer_from_connector(layer_name)

cos = attn_metadata.cos
sin = attn_metadata.sin
slot_mapping = attn_metadata.slot_mapping[:num_actual_tokens]
slot_mapping_cp = None
actual_seq_lengths_query = attn_metadata.cum_query_lens
actual_seq_lengths_key = attn_metadata.seq_lens
if self.enable_sfa_cp:
assert attn_metadata.sfa_cp_context is not None
slot_mapping_cp = attn_metadata.sfa_cp_context.slot_mapping_cp
actual_seq_lengths_query = attn_metadata.sfa_cp_context.actual_seq_lengths_query
actual_seq_lengths_key = attn_metadata.sfa_cp_context.actual_seq_lengths_key

self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping,
slot_mapping_cp)

if self.enable_sfa_cp and attn_metadata.sfa_cp_context is not None:
if is_hidden_layer(self.vllm_config, self.q_proj):
reach_layer_for_shared_weight_series(self.q_proj)
if is_hidden_layer(self.vllm_config, self.o_proj):
reach_layer_for_shared_weight_series(self.o_proj)

ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
q_pe = self.rope_single(q_pe, cos, sin)
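# Pure-decode steps take the fused MLAPO preprocess; any batch with
# prefill falls back to the original unfused path below.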
if self.enable_mlapo and not forward_context.with_prefill:
hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_decode(
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
need_gather_q_kv=need_gather_q_kv,
num_actual_tokens=num_actual_tokens,
)
else:
assert self.fused_qkv_a_proj is not None, "q lora is required for DSA."
maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight,
dependency=hidden_states,
enabled=self.enable_prefetch)
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
q_c, kv_no_split = qkv_lora.split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
dim=-1,
)
q_c = self.q_a_layernorm(q_c)
# Process for Flash Comm V1
if need_gather_q_kv:
q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
q_c.contiguous(), need_gather_q_kv)
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
kv_no_split.contiguous(), need_gather_q_kv)

if has_prefill:
wait_for_kv_layer_from_connector(layer_name)

slot_mapping = attn_metadata.slot_mapping[:num_actual_tokens]
slot_mapping_cp = None
if self.enable_sfa_cp:
assert attn_metadata.sfa_cp_context is not None
slot_mapping_cp = attn_metadata.sfa_cp_context.slot_mapping_cp
actual_seq_lengths_query = attn_metadata.sfa_cp_context.actual_seq_lengths_query
actual_seq_lengths_key = attn_metadata.sfa_cp_context.actual_seq_lengths_key

self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping,
slot_mapping_cp)

if self.enable_sfa_cp and attn_metadata.sfa_cp_context is not None:
if is_hidden_layer(self.vllm_config, self.q_proj):
reach_layer_for_shared_weight_series(self.q_proj)
if is_hidden_layer(self.vllm_config, self.o_proj):
reach_layer_for_shared_weight_series(self.o_proj)

ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
q_pe = self.rope_single(q_pe, cos, sin)

topk_indices = self.indexer_select(
x=hidden_states,