Team 牛马挣劳务 submits patch and README #106
Open
Hongzunhei wants to merge 4 commits into mindspore-lab:dev from Hongzunhei:feature/牛马挣劳务
2025-Ascend-Innovation-Contest/S1/MultiModal/牛马挣劳务/README.md (77 additions, 0 deletions)
# Team: 牛马挣劳务

## 1. Qwen2-VL / Janus Pro Multimodal Inference Optimization Notes

This repository focuses on inference-side optimization of multimodal large models (primarily **Qwen2-VL / Janus Pro**) under **MindNLP + MindSpore**. The goals are to:

- keep **model behavior and accuracy unchanged**,
- **reduce end-to-end inference latency** and **increase throughput**,
- and, as far as possible, **reduce memory usage and clean up redundant implementations**, as a foundation for future maintenance and extension.

All changes are concentrated in the following files:

- `mindnlp/transformers/models/qwen2_vl/modeling_qwen2_vl.py`
- `mindnlp/transformers/models/llama/modeling_llama.py`
- `llm/inference/janus_pro/janus/models/*` and the related preprocessing logic

---

## 2. Overview of the Optimization Approach

The optimizations address three areas: **prefill latency, decode latency, and peak memory usage**:
1. **Prefill latency**

   Prefill latency is the time from receiving the input to producing the first token; it can be viewed as the combined cost of preprocessing + image forward pass + text encoding.
   deepseek-ai/Janus-Pro-7B:

   - In deepseek-ai/Janus-Pro-7B, the prefill latency bottleneck comes mainly from preprocessing; see llm/inference/janus_pro/janus/models/processing_vlm.py for details.
     I took a rather blunt approach: accelerating these operations with the [JIT decorator](https://www.mindspore.cn/docs/zh-CN/r2.6.0/api_python/mindspore/mindspore.jit.html). The JIT decorator lets multiple operators with the same input shape be dispatched together, which saves time. The shapes really must be identical; otherwise the graph recompiles, and compilation can take longer than JIT saves. By padding with dummy data to force a common shape and stripping the dummy data afterwards to restore the original shape, the compilation cost moves into warm-up, and the actual inference runs save time. I no longer remember the exact numbers, but the gain was substantial. The padding trick is sketched below.
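   A minimal sketch of that trick, assuming a hypothetical `FIXED_LEN` and a placeholder preprocessing body; this is an illustration, not the actual code in processing_vlm.py:

   ```python
   import mindspore
   from mindspore import ops

   FIXED_LEN = 1024  # hypothetical padded length; every call must share this shape

   @mindspore.jit
   def fused_preprocess(x):
       # Placeholder for the real preprocessing math. Because every call sees
       # the same static shape, the graph compiles once (during warm-up) and
       # the compiled graph is reused by all later calls.
       return ops.softmax(x.astype(mindspore.float32), axis=-1)

   def preprocess(x):
       n = x.shape[0]
       # Pad with dummy rows up to FIXED_LEN so the jitted graph never recompiles.
       pad = ops.zeros((FIXED_LEN - n,) + x.shape[1:], x.dtype)
       out = fused_preprocess(ops.cat((x, pad), axis=0))
       return out[:n]  # strip the dummy rows to restore the original shape
   ```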
   Qwen/Qwen2-VL-2B-Instruct:

   - In Qwen/Qwen2-VL-2B-Instruct, the prefill latency bottleneck comes mainly from the vision part; see mindnlp/transformers/models/qwen2_vl/modeling_qwen2_vl.py for details.
     The Profiler shows one remarkably expensive Conv3D operator; replacing nn.Conv3d with mindspore.mint.nn.Conv3d cuts this cost dramatically. Large gain. (A sketch of the swap follows this list.)
   - Another part of the bottleneck is the **VisionAttention** computation. The **flash_attention_score** fused operator greatly reduces prefill latency. Large gain.
   ```python
   # `mops` is the file's alias for mindspore.ops. pre_tokens/next_tokens at
   # INT32_MAX leave the attention span unrestricted, and keep_prob = 1 - dropout_p
   # maps attention dropout onto the fused kernel's keep probability.
   output = mops.flash_attention_score(query, key, value, head_num=head_num, input_layout='BSND',
                                       real_shift=None, padding_mask=None, attn_mask=attn_mask,
                                       scalar_value=scale_factor, keep_prob=1 - dropout_p,
                                       pre_tokens=2147483647, next_tokens=2147483647,
                                       inner_precise=0, drop_mask=None, prefix=None,
                                       actual_seq_qlen=None, actual_seq_kvlen=None,
                                       sparse_mode=sparse_mode)
   ```
   - Add a check on the attn_mask in VisionAttention: when the mask to be constructed would mask nothing, pass None instead of building it. Small gain.
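   A sketch of the Conv3d swap mentioned in the first bullet. The channel counts and kernel sizes are illustrative values for a ViT-style patch embedding, not a verified excerpt from modeling_qwen2_vl.py:

   ```python
   from mindspore import mint

   # Before (slow path, profiled as the hotspot):
   #   nn.Conv3d(3, 1280, kernel_size=(2, 14, 14), stride=(2, 14, 14), has_bias=False)
   # After: the mint variant follows the PyTorch-style signature (`bias`
   # rather than `has_bias`) and dispatches to a faster kernel.
   conv = mint.nn.Conv3d(3, 1280, kernel_size=(2, 14, 14),
                         stride=(2, 14, 14), bias=False)

   x = mint.randn(1, 3, 2, 224, 224)  # (N, C, T, H, W), hypothetical input
   y = conv(x)                        # -> (1, 1280, 1, 16, 16)
   ```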
2. **Decode latency**

   Decode latency is the average cost of producing each additional token during generation; the bottlenecks are usually attention computation, KV cache management, and mask construction.
   deepseek-ai/Janus-Pro-7B:

   - In deepseek-ai/Janus-Pro-7B, the decode latency bottleneck comes mainly from apply_rotary_pos_emb. Unfortunately, using the rotary_position_embedding fused operator there causes a mismatch, so decode optimization for deepseek-ai/Janus-Pro-7B is limited; see mindnlp/transformers/models/llama/modeling_llama.py for details.
   - For apply_rotary_pos_emb, rotating q and k in one pass gives a small but real gain:
   ```python
   # Concatenate q and k along dim 0 so one multiply-add pass applies the
   # rotary embedding to both, then split them apart again. This assumes q and
   # k each have batch size 1, so splitting with chunk size 1 along dim 0
   # yields exactly two tensors.
   qk = mindspore.mint.cat((q, k), dim=0)
   qk_embed = mindspore.mint.mul(qk, cos) + mindspore.mint.mul(rotate_half(qk), sin)
   q_embed, k_embed = mindspore.mint.split(qk_embed, 1, 0)
   return q_embed, k_embed
   ```
   - Replace repeat_kv with the repeat_interleave fused operator (sketched below).
   - Many ops calls can be swapped for mint equivalents, and some of those swaps help. Each gain is small and hard to measure on its own, but they add up.
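   A sketch of the repeat_kv replacement; the function shape follows the common HF-style repeat_kv helper and is an illustration, not the exact code in modeling_llama.py:

   ```python
   import mindspore
   from mindspore import mint

   def repeat_kv(hidden_states: mindspore.Tensor, n_rep: int) -> mindspore.Tensor:
       """Expand (batch, num_kv_heads, seq, dim) to (batch, num_kv_heads * n_rep, seq, dim)."""
       if n_rep == 1:
           return hidden_states
       # One fused repeat_interleave along the head axis replaces the usual
       # broadcast-expand + reshape pair while preserving head order.
       return mint.repeat_interleave(hidden_states, n_rep, dim=1)
   ```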
   Qwen/Qwen2-VL-2B-Instruct:

   - In Qwen/Qwen2-VL-2B-Instruct, the decode latency bottleneck comes mainly from Qwen2RMSNorm, apply_rotary_pos_emb_vision, and similar pieces; see mindnlp/transformers/models/qwen2_vl/modeling_qwen2_vl.py for details.
   - Replace the original rotary_pos_emb logic with the rotary_position_embedding fused operator.
   - Replace repeat_kv with the repeat_interleave fused operator.
   - Move the apply_mrope_mask operation from Qwen2VLAttention.forward into Qwen2VLModel.forward, so it is not recomputed in each of the N layers.
   - Add a prefill/decode stage check in Qwen2VLModel to skip _prepare_4d_causal_attention_mask_with_cache_position during decode (sketched below).
   - Many ops calls can be swapped for mint equivalents, and some of those swaps help. Each gain is small and hard to measure on its own, but they add up.
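   A sketch of the stage check; the helper name is hypothetical, and only the shape-based prefill/decode test reflects the description above:

   ```python
   def should_build_causal_mask(input_ids, past_len: int) -> bool:
       # Prefill: the whole prompt is processed at once, so the full 4D causal
       # mask is needed. Decode: the single new token attends to every cached
       # position anyway, so _prepare_4d_causal_attention_mask_with_cache_position
       # can be skipped for that step.
       is_decode = past_len > 0 and input_ids.shape[1] == 1
       return not is_decode
   ```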
3. **Peak memory usage**

   Peak memory is dominated by the KV cache, attention intermediates, and the multimodal embeddings.
   For peak memory, this round of optimization succeeded only on Qwen/Qwen2-VL-2B-Instruct:

   - **Remove the softmax inside VisionAttention**, which also improves precision; the **flash_attention_score** fused operator achieves the same effect.
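   The saving can be seen by contrasting the two attention paths. Shapes and function names here are illustrative, not the actual VisionAttention code:

   ```python
   from mindspore import ops as mops

   # Eager path: materializes the full (B, N, S, S) score matrix plus its
   # softmax, which dominates activation memory for long vision sequences.
   def eager_attention(q, k, v, scale):
       scores = mops.matmul(q, k.swapaxes(-2, -1)) * scale
       probs = mops.softmax(scores, axis=-1)  # second (B, N, S, S) buffer
       return mops.matmul(probs, v)

   # Fused path: flash_attention_score computes the same result tile by tile
   # (fp16/bf16 inputs on Ascend), so the (S, S) intermediates are never
   # stored in full.
   def fused_attention(q, k, v, scale, head_num):
       return mops.flash_attention_score(q, k, v, head_num=head_num,
                                         input_layout='BNSD', scalar_value=scale)
   ```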
## 3. Summary

There were many changes; the ones listed above are those with the largest gains. The remaining changes can be found in the files themselves. Apologies for not enumerating every one.