
Conversation


@zyongye zyongye commented Sep 29, 2025

Rebased dsv32, based on #25869

Run command

vllm serve deepseek-ai/DeepSeek-V3.2-Exp  --max_model_len=20000 --gpu_memory_utilization=0.9 -tp 8 --max_num_seqs=256

gsm8k

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9568|±  |0.0056|
|     |       |strict-match    |     5|exact_match|↑  |0.9575|±  |0.0056|

gsm8k, 20-shot

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|    20|exact_match|↑  |0.9507|±  | 0.006|
|     |       |strict-match    |    20|exact_match|↑  |0.9507|±  | 0.006|

heheda12345 and others added 30 commits September 20, 2025 18:24
Signed-off-by: Chen Zhang <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>

fix smoke tests

Signed-off-by: Lucas Wilkinson <[email protected]>

moved to FlashMLA repo

Signed-off-by: Lucas Wilkinson <[email protected]>

removed pytorch shim

Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Chen Zhang <[email protected]>
Signed-off-by: Chen Zhang <[email protected]>
Signed-off-by: Chen Zhang <[email protected]>
Signed-off-by: Chen Zhang <[email protected]>
Signed-off-by: Chen Zhang <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
…ild-sparse-flash-mla

Build and bind sparse-FlashMLA kernels
…integration

[Feature] DeepGEMM integration
* and env and MQA path for both prefill and decode

Signed-off-by: Lucas Wilkinson <[email protected]>

* fix shapes

Signed-off-by: Lucas Wilkinson <[email protected]>

---------

Signed-off-by: Lucas Wilkinson <[email protected]>
* code from ds

Signed-off-by: youkaichao <[email protected]>

* doc from ds

Signed-off-by: youkaichao <[email protected]>

* Fixes for support_materials/2-tilelang/

Signed-off-by: mgoin <[email protected]>

* Fix example 1

Signed-off-by: mgoin <[email protected]>

* Fix Einsum in deepgemm

* Fix `libc10.so` unimported error

* fix reference code

Signed-off-by: youkaichao <[email protected]>

* adding missing indexer args

* passing index args into the module

* init

Signed-off-by: Chen Zhang <[email protected]>

* build indexer k cache metadata

* prefill indexer, but weight_proj will output -inf

* unquantized paged indexer, still has -inf issue

* remove support material

* adding topk_indices mask

* add weight scale

* unittest infrastructure and fix weight_proj, numeric error due to quantization

* varlen prefill passed

* paged prefill

* add indices mask

---------

Signed-off-by: youkaichao <[email protected]>
Signed-off-by: mgoin <[email protected]>
Signed-off-by: Chen Zhang <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Co-authored-by: mgoin <[email protected]>
Co-authored-by: Wentao Ye <[email protected]>
Co-authored-by: Chen Zhang <[email protected]>
* prefill mla

Signed-off-by: Chen Zhang <[email protected]>

* can run now

Signed-off-by: Chen Zhang <[email protected]>

* tmp

Signed-off-by: Chen Zhang <[email protected]>

* can output the first token

Signed-off-by: Chen Zhang <[email protected]>

* fix bug

Signed-off-by: Chen Zhang <[email protected]>

* remove some debug

Signed-off-by: Chen Zhang <[email protected]>

* update

Signed-off-by: Chen Zhang <[email protected]>

* hack through cu_seqlen_ks exploding issue

* update basic.py

Signed-off-by: Chen Zhang <[email protected]>

* remove some unnecessary changes

Signed-off-by: Chen Zhang <[email protected]>

* clean up

Signed-off-by: Chen Zhang <[email protected]>

---------

Signed-off-by: Chen Zhang <[email protected]>
Co-authored-by: Yongye Zhu <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Chen Zhang <[email protected]>
Signed-off-by: NickLucche <[email protected]>
Signed-off-by: Chen Zhang <[email protected]>
Signed-off-by: youkaichao <[email protected]>
@youkaichao

locally verified this PR has correct results:

local-completions (model=deepseek-ai/DeepSeek-V3.2-Exp,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=100,max_retries=3,tokenized_requests=False), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9613|±  |0.0053|
|     |       |strict-match    |     5|exact_match|↑  |0.9613|±  |0.0053|

Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>

heheda12345 commented Sep 30, 2025

@youkaichao Can you help to try DeepSeek-R1? I got the following errors:

VLLM_USE_DEEP_GEMM=0 vllm serve deepseek-ai/DeepSeek-R1 -tp 8 --max-num-seqs 256

(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]   File "/data/zhang-chen/vllm/.venv/lib/python3.12/site-packages/torch/fx/graph_module.py", line 411, in __call__
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]     return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]   File "/data/zhang-chen/vllm/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]     return self._call_impl(*args, **kwargs)
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]   File "/data/zhang-chen/vllm/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]     return forward_call(*args, **kwargs)
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]   File "<eval_with_key>.124", line 696, in forward
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]     submod_1 = self.submod_1(getitem, s72, getitem_1, getitem_2, getitem_3);  getitem = getitem_1 = getitem_2 = submod_1 = None
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]   File "/data/zhang-chen/vllm/.venv/lib/python3.12/site-packages/torch/fx/graph_module.py", line 848, in call_wrapped
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]     return self._wrapped_call(self, *args, **kwargs)
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]   File "/data/zhang-chen/vllm/.venv/lib/python3.12/site-packages/torch/fx/graph_module.py", line 424, in __call__
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]     raise e
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]   File "/data/zhang-chen/vllm/.venv/lib/python3.12/site-packages/torch/fx/graph_module.py", line 411, in __call__
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]     return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]   File "/data/zhang-chen/vllm/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]     return self._call_impl(*args, **kwargs)
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]   File "/data/zhang-chen/vllm/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]     return forward_call(*args, **kwargs)
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]   File "<eval_with_key>.2", line 5, in forward
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]     unified_attention_with_output = torch.ops.vllm.unified_attention_with_output(q, x_11, key_rot_1, output_1, 'model.layers.0.self_attn.attn');  q = x_11 = key_rot_1 = output_1 = unified_attention_with_output = None
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]   File "/data/zhang-chen/vllm/.venv/lib/python3.12/site-packages/torch/_ops.py", line 1243, in __call__
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]     return self._op(*args, **kwargs)
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]   File "/data/zhang-chen/vllm/vllm/attention/layer.py", line 614, in unified_attention_with_output
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]     self.impl.forward(self,
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]   File "/data/zhang-chen/vllm/vllm/v1/attention/backends/mla/common.py", line 1767, in forward
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]     attn_out, lse = self._forward_decode(decode_q, kv_cache,
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]   File "/data/zhang-chen/vllm/vllm/v1/attention/backends/mla/flashmla.py", line 188, in _forward_decode
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]     o, lse = flash_mla_with_kvcache(
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]              ^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]   File "/data/zhang-chen/vllm/vllm/attention/ops/flashmla.py", line 140, in flash_mla_with_kvcache
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]     out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]   File "/data/zhang-chen/vllm/.venv/lib/python3.12/site-packages/torch/_ops.py", line 1243, in __call__
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]     return self._op(*args, **kwargs)
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671] RuntimeError: Expected q.dtype() == torch::kFloat8_e4m3fn to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)
(Worker_TP0 pid=3647913) ERROR 09-30 00:32:14 [multiproc_executor.py:671] 

Review comment context (the assert preceding the dispatch branch):

    descale_k is None
    ), "descale_q and descale_k should be both None or both not None"

    if (descale_q is not None) and (descale_k is not None):

Suggested change:

    - if (descale_q is not None) and (descale_k is not None):
    + if indices is not None:


@LucasWilkinson does this make sense? @heheda12345's error seems to indicate that DeepSeek-R1 goes into this branch and calls torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8
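A minimal, self-contained sketch of the dispatch problem discussed in this thread. All names here (`select_decode_kernel`, the string return values) are illustrative stand-ins for the real vLLM code paths, not its actual API:

```python
def select_decode_kernel(descale_q, descale_k, indices):
    """Return the kernel each guard would dispatch to for one decode call."""
    # Current guard: any caller passing both descale tensors is routed to
    # the fp8 kernel -- including dense models like DeepSeek-R1 whose q
    # tensor is still bf16, which then trips the kernel's dtype assert.
    buggy = "fp8" if (descale_q is not None and descale_k is not None) else "dense"
    # Suggested guard: only the sparse-indexer path (DeepSeek-V3.2) should
    # reach the fp8 kernel, so key the dispatch off `indices` instead.
    suggested = "fp8" if indices is not None else "dense"
    return buggy, suggested

# DeepSeek-R1-style call: descale tensors present, no sparse indices.
buggy, suggested = select_decode_kernel(descale_q=1.0, descale_k=1.0, indices=None)
print(buggy, suggested)  # fp8 dense
```

Under this reading, the two guards only disagree for dense models that pass descale tensors, which matches the R1 failure and the reviewer's suggestion.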

Comment on lines -109 to -110
# Note(hc): need revisit when we support DCP with decode query_len > 1.
return out.squeeze(1), softmax_lse.squeeze(-1)

@LucasWilkinson do we need this as well for DCP?

@youkaichao
Copy link
Member

the issue reported by @heheda12345 seems to be a kv cache dtype issue.

merging first to unblock further optimizations. @LucasWilkinson might help investigate further.
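A hedged sketch of how the kv-cache dtype mismatch could be surfaced as a clear Python-level error before reaching the CUDA extension, rather than as the opaque `Expected q.dtype() == torch::kFloat8_e4m3fn` RuntimeError above. The function name and dtype strings are illustrative, not vLLM's actual API:

```python
def validate_fp8_mla_inputs(q_dtype: str, kv_cache_dtype: str) -> None:
    """Mirror the fp8 FlashMLA kernel's requirement that q is float8_e4m3fn."""
    if q_dtype != "float8_e4m3fn":
        raise TypeError(
            "fwd_kvcache_mla_fp8 requires q dtype float8_e4m3fn, "
            f"got {q_dtype} (kv cache dtype: {kv_cache_dtype})"
        )

# A DeepSeek-R1 run with a bf16 query would fail here, at the Python layer,
# instead of deep inside torch.ops._flashmla_extension_C:
try:
    validate_fp8_mla_inputs("bfloat16", "fp8_e4m3")
except TypeError as e:
    print(e)
```

Such a pre-check would have pointed straight at the misrouted dispatch rather than at the extension's internal assert.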

@youkaichao youkaichao merged commit fa7e254 into vllm-project:main Sep 30, 2025
63 of 80 checks passed

cjackal commented Sep 30, 2025

Seems like DSV3 AWQ-quantized checkpoints are broken after this PR; the error message is as follows. Let me write an issue for it:

RuntimeError: Expected q.dtype() == torch::kFloat8_e4m3fn to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)


njhill commented Sep 30, 2025

A CI test was reportedly broken by this (now failing on main):

[2025-09-30T15:41:37Z] FAILED v1/spec_decode/test_eagle.py::test_load_model[True-1-FLASH_ATTN-eagle] - RuntimeError: generator raised StopIteration

https://buildkite.com/vllm/ci/builds/32959#01999b1b-bec0-44a2-bca0-2523a6209558

Edit: I have opened a fix here: #25978

@youkaichao

@heheda12345 @cjackal can you help check if #25956 solves the problem?


cjackal commented Oct 1, 2025

@heheda12345 @cjackal can you help check if #25956 solves the problem?

Confirmed that it works normally after your PR, thank you for the prompt bugfix!

simon-mo pushed a commit that referenced this pull request Oct 1, 2025
Signed-off-by: Chen Zhang <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: mgoin <[email protected]>
Signed-off-by: NickLucche <[email protected]>
Signed-off-by: Yongye Zhu <[email protected]>
Signed-off-by: Barry Kang <[email protected]>
Signed-off-by: Lucia Fang <[email protected]>
Co-authored-by: Chen Zhang <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>
Co-authored-by: Robert Shaw <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>
Co-authored-by: yewentao256 <[email protected]>
Co-authored-by: Wentao Ye <[email protected]>
Co-authored-by: mgoin <[email protected]>
Co-authored-by: Lucia Fang <[email protected]>
Co-authored-by: Lucia Fang <[email protected]>
Co-authored-by: NickLucche <[email protected]>
Co-authored-by: Siyuan Fu <[email protected]>
Co-authored-by: Matthew Bonanni <[email protected]>
Co-authored-by: Xiaozhu Meng <[email protected]>
Co-authored-by: Barry Kang <[email protected]>
Signed-off-by: simon-mo <[email protected]>
iboiko-habana pushed a commit to iboiko-habana/vllm-gaudi that referenced this pull request Oct 2, 2025
pdasigi pushed a commit to pdasigi/vllm that referenced this pull request Oct 2, 2025
Signed-off-by: Chen Zhang <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: mgoin <[email protected]>
Signed-off-by: NickLucche <[email protected]>
Signed-off-by: Yongye Zhu <[email protected]>
Signed-off-by: Barry Kang <[email protected]>
Signed-off-by: Lucia Fang <[email protected]>
Co-authored-by: Chen Zhang <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>
Co-authored-by: Robert Shaw <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>
Co-authored-by: yewentao256 <[email protected]>
Co-authored-by: Wentao Ye <[email protected]>
Co-authored-by: mgoin <[email protected]>
Co-authored-by: Lucia Fang <[email protected]>
Co-authored-by: Lucia Fang <[email protected]>
Co-authored-by: NickLucche <[email protected]>
Co-authored-by: Siyuan Fu <[email protected]>
Co-authored-by: Matthew Bonanni <[email protected]>
Co-authored-by: Xiaozhu Meng <[email protected]>
Co-authored-by: Barry Kang <[email protected]>
Labels
ci/build, deepseek (Related to DeepSeek models), documentation (Improvements or additions to documentation), new-model (Requests to new models), ready (ONLY add when PR is ready to merge/full CI is needed), rocm (Related to AMD ROCm), speculative-decoding, tpu (Related to Google TPUs), v1