vllm_ascend/attention/sfa_v1.py (142 changes: 130 additions & 12 deletions)
@@ -15,6 +15,7 @@
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
trans_rope_weight, transdata,
wait_for_kv_layer_from_connector)
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
@@ -131,6 +132,8 @@
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
) -> AscendSFAMetadata:
cos = common_attn_metadata.cos
sin = common_attn_metadata.sin
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
@@ -164,10 +167,10 @@
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs].to(
torch.int32).to(device, non_blocking=True)

cos = self.cos_cache[input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
sin = self.sin_cache[input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
cos[:num_actual_tokens,
...] = self.cos_cache[input_positions].unsqueeze(1).unsqueeze(2)

Check failure on line 171 in vllm_ascend/attention/sfa_v1.py (GitHub Actions / lint / pre-commit): Value of type "None" is not indexable [index]
sin[:num_actual_tokens,
...] = self.sin_cache[input_positions].unsqueeze(1).unsqueeze(2)

Check failure on line 173 in vllm_ascend/attention/sfa_v1.py (GitHub Actions / lint / pre-commit): Value of type "None" is not indexable [index]

return self.metadata_cls( # type: ignore
has_prefill=has_prefill,
@@ -180,8 +183,8 @@
attn_mask=common_attn_metadata.attn_mask,
attn_state=common_attn_metadata.attn_state,
block_tables=block_table,
sin=sin,
cos=cos)
sin=sin[:num_actual_tokens],
cos=cos[:num_actual_tokens])
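For readers skimming the hunk above: instead of materializing fresh cos/sin tensors on every metadata build, the builder now writes into persistent buffers handed in via common_attn_metadata and passes sliced views downstream. A minimal standalone sketch of that pattern, with illustrative (assumed) shapes rather than the runner's real configuration:

```python
import torch

# Assumed, illustrative sizes -- not the runner's actual values.
max_tokens, max_position, rope_dim = 1024, 4096, 64

cos_buf = torch.ones(max_tokens, 1, 1, rope_dim)   # persistent buffer, reused every step
cos_cache = torch.randn(max_position, rope_dim)    # per-position rope cache

input_positions = torch.randint(0, max_position, (7,))
num_actual_tokens = input_positions.numel()

# In-place write into the leading slice of the persistent buffer ...
cos_buf[:num_actual_tokens, ...] = cos_cache[input_positions].unsqueeze(1).unsqueeze(2)
# ... and only a view of that slice is passed on, avoiding a new allocation.
cos_view = cos_buf[:num_actual_tokens]
```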

def build_for_graph_capture(
self,
@@ -332,6 +335,7 @@
# Waiting for BMM NZ support
# self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
# self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
self._process_weights_for_fused_mlapo(act_dtype)

def _v_up_proj(self, x):
if self.W_UV.shape[0] * self.W_UV.shape[1] < 65536:
@@ -404,6 +408,84 @@
x = torch_npu.npu_interleave_rope(x, cos, sin)
return x.view(B, N, D)

def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[
..., self.q_lora_rank:].contiguous()
q_a_proj_wt = self.fused_qkv_a_proj.weight.data[
..., :self.q_lora_rank].contiguous()
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim)
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
wd_qkv = torch.cat((kv_a_proj_wt, q_a_proj_wt), dim=-1)
wd_qkv = wd_qkv.t().contiguous()
wd_qkv = transdata(wd_qkv,
block_size=(16, 32)).unsqueeze(0).contiguous()
self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)

kv_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[
self.q_lora_rank:].contiguous()
q_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[:self.
q_lora_rank].contiguous(
)
kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape(
self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl,
self.qk_rope_head_dim)
kv_a_proj_deq_scl = kv_a_proj_deq_scl.view(
self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
self.deq_scale_qkv = torch.cat((kv_a_proj_deq_scl, q_a_proj_deq_scl),
dim=-1).contiguous()

kv_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[
self.q_lora_rank:].contiguous()
q_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[:self.
q_lora_rank].contiguous(
)
kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape(
self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias,
self.qk_rope_head_dim)
kv_a_proj_qt_bias = kv_a_proj_qt_bias.view(
self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
self.quant_bias_qkv = torch.cat((kv_a_proj_qt_bias, q_a_proj_qt_bias),
dim=-1).contiguous()

wu_q = self.q_proj.weight.data
wu_q = wu_q.t().reshape(self.num_heads,
self.qk_nope_head_dim + self.qk_rope_head_dim,
-1)
wu_q = trans_rope_weight(wu_q, self.qk_rope_head_dim)
wu_q = wu_q.reshape(
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim),
-1)
wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous()
self.wu_q = torch_npu.npu_format_cast(wu_q, 29)

qb_deq_scl = self.q_proj.deq_scale.data
qb_deq_scl = qb_deq_scl.reshape(
self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
qb_deq_scl = trans_rope_weight(qb_deq_scl, self.qk_rope_head_dim)
self.qb_deq_scl = qb_deq_scl.reshape(
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))

qb_qt_bias = self.q_proj.quant_bias.data
qb_qt_bias = qb_qt_bias.reshape(
self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
qb_qt_bias = trans_rope_weight(qb_qt_bias, self.qk_rope_head_dim)
self.qb_qt_bias = qb_qt_bias.reshape(
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))

device = self.q_proj.weight.device
self.gamma1 = self.q_a_layernorm.weight.data
self.beta1 = self.q_a_layernorm.bias.data
self.gamma2 = self.kv_a_layernorm.weight.data
self.quant_scale0 = self.fused_qkv_a_proj.input_scale.data
self.quant_offset0 = self.fused_qkv_a_proj.input_offset.data
self.quant_scale1 = self.q_proj.input_scale.data
self.quant_offset1 = self.q_proj.input_offset.data
self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)
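The weight preprocessing above splits the fused QKV-A projection at q_lora_rank, moves the kv/rope block in front of the q block, and then casts to the NZ layout consumed by the fused kernel. A rough shape-level sketch of just the split-and-reorder step; the dimensions are illustrative assumptions, and trans_rope_weight / transdata are the repo's own helpers, whose internals are not reproduced here:

```python
import torch

# Illustrative dimensions only (assumed, not taken from a real config object).
hidden_size, q_lora_rank, kv_lora_rank, qk_rope_head_dim = 7168, 1536, 512, 64
out_dim = q_lora_rank + kv_lora_rank + qk_rope_head_dim

# Fused A-projection weight: q block first, then kv + rope block.
fused_weight = torch.randn(hidden_size, out_dim)

kv_a_proj_wt = fused_weight[..., q_lora_rank:].contiguous()   # kv + rope slice
q_a_proj_wt = fused_weight[..., :q_lora_rank].contiguous()    # q slice

# Re-concatenate kv-first, presumably the order the fused kernel expects; the
# real code additionally applies trans_rope_weight and an NZ transdata/format cast.
wd_qkv = torch.cat((kv_a_proj_wt, q_a_proj_wt), dim=-1)
assert wd_qkv.shape == (hidden_size, out_dim)
```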

def forward(
self,
layer_name,
@@ -443,12 +525,48 @@
if has_prefill:
wait_for_kv_layer_from_connector(layer_name)

slot_mapping = attn_metadata.slot_mapping[:num_actual_tokens]
ql_nope, q_pe = \
self._q_proj_and_k_up_proj(q_c)
q_pe = self.rope_single(q_pe, attn_metadata.cos, attn_metadata.sin)
k_pe, k_nope = self.exec_kv(kv_no_split, attn_metadata.cos,
attn_metadata.sin, kv_cache, slot_mapping)
k_nope, k_pe = kv_cache[0], kv_cache[1]
ql_nope = torch.empty(
(hidden_states.shape[0], self.W_UK_T.shape[0], k_nope.shape[-1]),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
q_pe = torch.empty(
(hidden_states.shape[0], self.W_UK_T.shape[0], k_pe.shape[-1]),
dtype=hidden_states.dtype,
device=hidden_states.device,
)

torch.ops._C_ascend.mla_preprocess(
hidden_states,
self.wd_qkv,
self.deq_scale_qkv,
self.gamma1,
self.beta1,
self.wu_q,
self.qb_deq_scl,
self.gamma2,
attn_metadata.cos,
attn_metadata.sin,
self.W_UK_T,
k_nope,
k_pe,
attn_metadata.slot_mapping[:num_actual_tokens].flatten(),
quant_scale0=self.quant_scale0,
quant_offset0=self.quant_offset0,
bias0=self.quant_bias_qkv,
quant_scale1=self.quant_scale1,
quant_offset1=self.quant_offset1,
bias1=self.qb_qt_bias,
ctkv_scale=self.ctkv_scale,
q_nope_scale=self.q_nope_scale,
cache_mode="krope_ctkv",
quant_mode="per_tensor_quant_asymm",
q_out0=ql_nope,
kv_cache_out0=k_nope,
q_out1=q_pe,
kv_cache_out1=k_pe,
)
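For orientation, the unfused sequence this single torch.ops._C_ascend.mla_preprocess call replaces is visible in the lines removed by this hunk; it is reproduced below with brief comments (not runnable on its own, since it quotes the deleted method body):

```python
# Previous (unfused) sequence, as removed in this hunk:
slot_mapping = attn_metadata.slot_mapping[:num_actual_tokens]
ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)                      # q down/up projection
q_pe = self.rope_single(q_pe, attn_metadata.cos, attn_metadata.sin)  # rope on q_pe
k_pe, k_nope = self.exec_kv(kv_no_split, attn_metadata.cos,
                            attn_metadata.sin, kv_cache, slot_mapping)  # kv proj + rope + cache write

# The fused call appears to cover the same quant / layernorm / projection /
# rope / cache-scatter work in one kernel, writing results into the
# preallocated ql_nope, q_pe and kv_cache tensors in place.
```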

topk_indices = self.indexer_select(x=hidden_states,
qr=q_c,
vllm_ascend/worker/model_runner_v1.py (3 changes: 1 addition & 2 deletions)
@@ -409,8 +409,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
dtype=torch.int32,
device=self.device)

if self.vllm_config.model_config.use_mla and \
self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
if self.vllm_config.model_config.use_mla:
rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
self.cos = torch.ones(self.max_num_reqs *
self.decode_token_per_req,
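The runner change above drops the FULL_DECODE_ONLY-cudagraph condition, so the persistent cos/sin rope buffers are now allocated whenever the model uses MLA. The diff truncates the allocation's trailing arguments; the following is only a sketch of the pattern, with the dtype, device, and the matching sin buffer treated as assumptions:

```python
import torch

# Illustrative values; the runner derives these from its config.
max_num_reqs, decode_token_per_req = 1024, 1
rope_dim = 64  # hf_text_config.qk_rope_head_dim

# Shapes chosen to line up with the [:num_actual_tokens, ...] writes in
# sfa_v1.py's metadata builder (tokens, 1, 1, rope_dim); dtype is an
# assumption here, since the diff cuts the call off.
cos = torch.ones(max_num_reqs * decode_token_per_req, 1, 1, rope_dim,
                 dtype=torch.bfloat16)
sin = torch.zeros(max_num_reqs * decode_token_per_req, 1, 1, rope_dim,
                  dtype=torch.bfloat16)
```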