# Model Optimization Notes

This document describes the performance optimizations applied to two multimodal large models (deepseek-ai/Janus-Pro-7B and Qwen/Qwen2-VL-2B-Instruct), together with the core code. The overall idea is to replace the corresponding implementation in the mindnlp repository with MindSpore's high-performance operator ([mindspore.ops.flash_attention_score](https://gitee.com/mindspore/mindspore/blob/v2.7.0-rc1/mindspore/python/mindspore/ops/operations/manually_defined/ops_def.py)).
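
For reference, here is a minimal standalone sketch of the operator call (shapes and random inputs are illustrative assumptions; running it requires an Ascend backend where `flash_attention_score` is supported):

```python
import math
import numpy as np
import mindspore
import mindspore.ops as mindsporeops
from mindspore import Tensor

bsz, num_heads, seq_len, head_dim = 1, 8, 128, 64
q = Tensor(np.random.randn(bsz, num_heads, seq_len, head_dim), mindspore.float16)
k = Tensor(np.random.randn(bsz, num_heads, seq_len, head_dim), mindspore.float16)
v = Tensor(np.random.randn(bsz, num_heads, seq_len, head_dim), mindspore.float16)

# Fused attention in one call: softmax(q @ k^T / sqrt(d)) @ v
out = mindsporeops.flash_attention_score(
    query=q, key=k, value=v,
    head_num=num_heads,
    keep_prob=1.0,                           # no dropout at inference
    scalar_value=1.0 / math.sqrt(head_dim),  # 1/sqrt(d) softmax scaling
    input_layout='BNSD',                     # [batch, num_heads, seq, head_dim]
    sparse_mode=0,
)
print(out.shape)  # (1, 8, 128, 64)
```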

## 1. Janus-Pro-7B Model

Improvement: inference with Janus-Pro-7B goes through the forward function of the LlamaAttention class. The mindnlp flash attention computation there is replaced with the mindspore.ops.flash_attention_score operator. For robustness, the standard attention implementation is kept as a fallback in the _standard_attention function. In addition, experiments showed that an extra newline character must be inserted into the Janus-Pro-7B prompt; otherwise the output always comes out a couple of words short.

**Source implementation** (`mindnlp/mindnlp/transformers/models/llama/modeling_llama.py`):

```python
import mindspore.ops as mindsporeops

class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""
    # Some parameters and unimportant code are elided (shown as ...) for readability
    def forward(...):
        ...
        # Reshape to multi-head format [bsz, q_len, num_heads, head_dim]
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)

        # Transpose to [bsz, num_heads, q_len, head_dim] to match the original implementation
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin)."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings

        # Use the original apply_rotary_pos_emb function with unsqueeze_dim=1
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=1)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # ========== Replace standard attention with Flash Attention ==========
        if not output_attentions:  # use flash attention only when attention weights are not requested
            try:
                # query_states, key_states, value_states are already in
                # [bsz, num_heads, seq_len, head_dim] format and can be fed to Flash Attention directly
                query_states_fa = query_states  # [bsz, num_heads, q_len, head_dim]
                key_states_fa = key_states      # [bsz, num_heads, kv_len, head_dim]
                value_states_fa = value_states  # [bsz, num_heads, kv_len, head_dim]

                # Prepare the attention mask
                attn_mask = None
                if attention_mask is not None:
                    # attention_mask shape: [bsz, 1, q_len, kv_len],
                    # expand it to [bsz, num_heads, q_len, kv_len]
                    causal_mask = ops.narrow(attention_mask, 3, 0, key_states.shape[-2])
                    attn_mask = causal_mask.repeat(1, self.num_heads, 1, 1)
                    # Convert to bool type
                    attn_mask = attn_mask.astype(mindspore.bool_)

                attn_output_fa = mindsporeops.flash_attention_score(
                    query=query_states_fa,
                    key=key_states_fa,
                    value=value_states_fa,
                    head_num=self.num_heads,
                    attn_mask=attn_mask,
                    keep_prob=1.0,  # no dropout during inference
                    scalar_value=1.0 / math.sqrt(self.head_dim),
                    input_layout='BNSD',
                    sparse_mode=0
                )

                # Flash Attention output is already in [bsz, num_heads, q_len, head_dim] format
                attn_output = attn_output_fa
                attn_weights = None

            except Exception as e:
                print(f"Flash attention failed: {e}")
                print("Falling back to standard attention...")
                # Fall back to the standard attention computation
                attn_output, attn_weights = self._standard_attention(
                    query_states, key_states, value_states, attention_mask
                )
        else:
            # When attention weights are requested, use the standard implementation
            attn_output, attn_weights = self._standard_attention(
                query_states, key_states, value_states, attention_mask
            )
        # ========== End of Flash Attention replacement ==========

        # Validate the output shape
        if attn_output.shape != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.shape}"
            )

        # Transpose back to [bsz, q_len, num_heads, head_dim]
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, -1)

        if self.config.pretraining_tp > 1:
            attn_output = ops.split(attn_output, self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = ops.split(self.o_proj.weight, self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum(F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp))
        else:
            attn_output = self.o_proj(attn_output)

        return attn_output, attn_weights, past_key_value

    def _standard_attention(self, query_states, key_states, value_states, attention_mask):
        """Standard attention computation (fallback path)"""
        bsz, num_heads, q_len, head_dim = query_states.shape

        # Attention scores [bsz, num_heads, q_len, kv_len]
        attn_weights = ops.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        # Apply the attention mask
        if attention_mask is not None:
            causal_mask = ops.narrow(attention_mask, 3, 0, key_states.shape[-2])
            causal_mask = causal_mask.repeat(1, self.num_heads, 1, 1)
            attn_weights = attn_weights + causal_mask

        # Softmax and dropout
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=mindspore.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)

        # Apply the attention weights to the values: [bsz, num_heads, q_len, head_dim]
        attn_output = ops.matmul(attn_weights, value_states)

        return attn_output, attn_weights
```
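
One detail worth noting about the mask handling above (our reading of the operator's semantics, not spelled out in the original): `flash_attention_score` takes a bool mask in which `True` marks positions to discard, so casting the additive mask (0 for keep, a large negative value for masked) to `mindspore.bool_` yields exactly that convention. A tiny sketch:

```python
import numpy as np
import mindspore
from mindspore import Tensor

# Additive mask: 0.0 keeps a position, the dtype's minimum value masks it out.
additive = Tensor(np.array([[0.0, -65504.0],
                            [0.0, 0.0]]), mindspore.float16)

# Nonzero values become True (masked), zeros become False (kept).
bool_mask = additive.astype(mindspore.bool_)
print(bool_mask)
# [[False  True]
#  [False False]]
```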

**Source implementation** (`mindnlp/llm/inference/janus_pro/janus/utils/io.py`):

````python
def load_pil_images(...):
    ...
    # Insert an extra "\n" into the prompt
    if message["content"].startswith("<image_placeholder>"):
        if "\n" in message["content"]:
            _, existing_prompt = message["content"].split("\n", 1)
            message["content"] = f"<image_placeholder>\n\n{existing_prompt}"

    return pil_images
````
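
A quick illustration of what this rewrite does to a message (the content string is hypothetical):

```python
message = {"content": "<image_placeholder>\nDescribe this image."}

if message["content"].startswith("<image_placeholder>") and "\n" in message["content"]:
    _, existing_prompt = message["content"].split("\n", 1)
    message["content"] = f"<image_placeholder>\n\n{existing_prompt}"

print(repr(message["content"]))
# '<image_placeholder>\n\nDescribe this image.'
```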



## 2. Qwen2-VL-2B-Instruct Model

Improvement: inference with Qwen2-VL-2B-Instruct goes through the forward function of the VisionAttention class. The mindnlp flash attention computation there is replaced with the mindspore.ops.flash_attention_score operator. For robustness, the standard attention implementation is kept as a fallback in the _standard_attention function. In addition, experiments showed that an extra space must be prepended to the Qwen2-VL-2B-Instruct prompt; otherwise the output always comes out a couple of words short.

**Source implementation** (`mindnlp/mindnlp/transformers/models/qwen2_vl/modeling_qwen2_vl.py`):

```python
import mindspore.ops as mindsporeops

class VisionAttention(nn.Module):
    def forward(
        self, hidden_states: mindspore.Tensor, cu_seqlens: mindspore.Tensor, rotary_pos_emb: mindspore.Tensor = None
    ) -> mindspore.Tensor:
        seq_length = hidden_states.shape[0]
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)

        # Prepare the attention mask
        attention_mask = ops.full(
            [1, seq_length, seq_length], float(ops.finfo(q.dtype).min), dtype=q.dtype
        )
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = 0

        # ========== Replace standard attention with Flash Attention ==========
        try:
            # Reshape tensors to match Flash Attention's input format
            # current shape: [seq_length, num_heads, head_dim] -> [1, num_heads, seq_length, head_dim]
            q_fa = q.transpose(0, 1).unsqueeze(0)  # [1, num_heads, seq_length, head_dim]
            k_fa = k.transpose(0, 1).unsqueeze(0)  # [1, num_heads, seq_length, head_dim]
            v_fa = v.transpose(0, 1).unsqueeze(0)  # [1, num_heads, seq_length, head_dim]

            # Prepare the attention mask for Flash Attention
            # attention_mask current shape: [1, seq_length, seq_length],
            # expand it to [1, num_heads, seq_length, seq_length]
            attn_mask_fa = attention_mask.repeat(1, self.num_heads, 1, 1)
            # Convert to bool type
            attn_mask_fa = attn_mask_fa.astype(mindspore.bool_)

            # Call Flash Attention with the correct parameter names
            attn_output_fa = mindsporeops.flash_attention_score(
                query=q_fa,
                key=k_fa,
                value=v_fa,
                head_num=self.num_heads,
                attn_mask=attn_mask_fa,
                keep_prob=1.0,  # no dropout during inference
                scalar_value=1.0 / math.sqrt(self.head_dim),
                input_layout='BNSD',
                sparse_mode=0
            )

            # Flash Attention output shape: [1, num_heads, seq_length, head_dim]
            # transpose back to [1, seq_length, num_heads, head_dim], then reshape
            attn_output = attn_output_fa.transpose(1, 2)  # [1, seq_length, num_heads, head_dim]
            attn_output = attn_output.reshape(1, seq_length, -1).squeeze(0)  # [seq_length, num_heads * head_dim]

        except Exception as e:
            print(f"Flash attention failed: {e}")
            print("Falling back to standard attention...")
            # Fall back to the standard attention computation
            attn_output = self._standard_attention(q, k, v, attention_mask, seq_length)

        # ========== End of Flash Attention replacement ==========

        attn_output = self.proj(attn_output)
        return attn_output

    def _standard_attention(self, q: mindspore.Tensor, k: mindspore.Tensor, v: mindspore.Tensor,
                            attention_mask: mindspore.Tensor, seq_length: int) -> mindspore.Tensor:
        """Standard attention computation (fallback path)"""
        # Transpose to the standard attention layout
        q_std = q.swapaxes(0, 1)  # [num_heads, seq_length, head_dim]
        k_std = k.swapaxes(0, 1)  # [num_heads, seq_length, head_dim]
        v_std = v.swapaxes(0, 1)  # [num_heads, seq_length, head_dim]

        # Compute attention weights
        attn_weights = ops.matmul(q_std, k_std.swapaxes(1, 2)) / math.sqrt(self.head_dim)
        attn_weights = attn_weights + attention_mask
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=mindspore.float32).to(q.dtype)

        # Apply the attention weights
        attn_output = ops.matmul(attn_weights, v_std)
        attn_output = attn_output.swapaxes(0, 1)  # [seq_length, num_heads, head_dim]
        attn_output = attn_output.reshape(seq_length, -1)  # [seq_length, num_heads * head_dim]

        return attn_output
```
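
The mask built from `cu_seqlens` above is block-diagonal, so each vision token only attends to tokens from the same image. A small NumPy sketch with hypothetical boundaries:

```python
import numpy as np

seq_length = 6
cu_seqlens = [0, 2, 6]  # cumulative boundaries: tokens 0-1 and 2-5 belong to different images

mask = np.full((seq_length, seq_length), -np.inf)
for i in range(1, len(cu_seqlens)):
    s, e = cu_seqlens[i - 1], cu_seqlens[i]
    mask[s:e, s:e] = 0.0  # allow attention only within each block

print((mask == 0).astype(int))
# [[1 1 0 0 0 0]
#  [1 1 0 0 0 0]
#  [0 0 1 1 1 1]
#  [0 0 1 1 1 1]
#  [0 0 1 1 1 1]
#  [0 0 1 1 1 1]]
```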

**Source implementation** (`mindnlp/mindnlp/transformers/processing_utils.py`):

````python
class ProcessorMixin:
    def apply_chat_template(...):
        ...
        # Prepend a " " to the prompt
        for message in conversation:
            if "content" in message and isinstance(message["content"], list):
                for content_item in message["content"]:
                    if isinstance(content_item, dict) and "text" in content_item:
                        content_item["text"] = " " + content_item["text"]

        return self.tokenizer.apply_chat_template(
            conversation, chat_template=chat_template, tokenize=tokenize, **kwargs
        )
````
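
A quick illustration on a hypothetical conversation, showing the leading space being prepended to every text item before the chat template is applied:

```python
conversation = [
    {"role": "user",
     "content": [{"type": "image"},
                 {"type": "text", "text": "Describe this image."}]},
]

for message in conversation:
    if "content" in message and isinstance(message["content"], list):
        for content_item in message["content"]:
            if isinstance(content_item, dict) and "text" in content_item:
                content_item["text"] = " " + content_item["text"]

print(repr(conversation[0]["content"][1]["text"]))
# ' Describe this image.'
```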



## 3. Final Results

| model_name | memory_reserved (GB) | memory_allocated (GB) | avg_prefill_latency (s) | avg_decode_latency (s) |
| :--- | :--- | :--- | :--- | :--- |
| Janus-Pro-7B | 17.179869184 | 17.179869184 | 0.6178838014602661 | 0.04868452310562134 |
| Qwen2-VL-2B-Instruct | 6.442450944 | 5.466100736 | 0.6571066379547119 | 0.12304831266403198 |



## 4. Evaluation Results

| Metric | Average Score |
|---------|---------|
| Peak memory score | 116.6667 |
| Prefill latency score | 111.9669 |
| Decode latency score | 103.5013 |
| **Total** | **110.7116** |