Conversation

@am17an (Collaborator) commented Sep 20, 2025

This kernel does the following:

  1. softmax over the logits per token [n_experts, n_tokens]
  2. argmax reduce over the top-k (n_experts_used) logits
  3. write weights + ids to global memory

It is intended as a fusion of the softmax->top-k->get_rows pipeline for MoE models.

It should be more useful for TG (token generation) than for PP (prompt processing).

| Model | Test | t/s master | t/s patch | Speedup |
| --- | --- | --- | --- | --- |
| qwen3moe 30B.A3B Q4_K_M | pp512 | 6502.24 | 6561.50 | 1.01 |
| qwen3moe 30B.A3B Q4_K_M | tg128 | 194.63 | 207.75 | 1.07 |
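
For readers unfamiliar with the approach, here is a heavily simplified sketch of the three steps above. It is not the kernel from this PR: it assumes one warp per token, n_experts <= 32 (so each lane keeps its logit in a register) and n_expert_used <= n_experts; the real kernel handles larger expert counts and uses a better coalescing/writeback pattern.

```cuda
// Simplified sketch (not the PR's kernel): one warp per token, n_experts <= 32.
// Each lane loads its logit once into a register and reuses it for both the
// softmax and the iterative argmax top-k selection.
#include <cuda_runtime.h>
#include <math.h>

__global__ void topk_moe_sketch(const float * logits, float * weights, int * ids,
                                const int n_experts, const int n_expert_used) {
    const int token = blockIdx.x;
    const int lane  = threadIdx.x; // 0..31, one warp per block

    // 1. softmax over this token's logits
    const float logit = lane < n_experts ? logits[token*n_experts + lane] : -INFINITY;
    float max_val = logit;
    for (int off = 16; off > 0; off >>= 1) {
        max_val = fmaxf(max_val, __shfl_xor_sync(0xffffffff, max_val, off));
    }
    const float e = lane < n_experts ? expf(logit - max_val) : 0.0f;
    float sum = e;
    for (int off = 16; off > 0; off >>= 1) {
        sum += __shfl_xor_sync(0xffffffff, sum, off);
    }
    float p = lane < n_experts ? e/sum : -INFINITY;

    // 2. + 3. argmax-reduce n_expert_used times, writing weight + id each round
    for (int k = 0; k < n_expert_used; ++k) {
        float best = p;
        int   arg  = lane;
        for (int off = 16; off > 0; off >>= 1) {
            const float ob = __shfl_xor_sync(0xffffffff, best, off);
            const int   oa = __shfl_xor_sync(0xffffffff, arg,  off);
            if (ob > best || (ob == best && oa < arg)) { // break ties on the lower id
                best = ob;
                arg  = oa;
            }
        }
        if (lane == 0) {
            weights[token*n_expert_used + k] = best;
            ids    [token*n_expert_used + k] = arg;
        }
        if (lane == arg) {
            p = -INFINITY; // exclude the selected expert from the next round
        }
    }
}
```

A launch of the form topk_moe_sketch<<<n_tokens, 32>>>(logits, weights, ids, n_experts, n_expert_used) then replaces the separate softmax, top-k and get_rows launches with a single pass over the logits.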

@github-actions bot added the labels testing (Everything test related), Nvidia GPU (Issues specific to Nvidia GPUs), and ggml (changes relating to the ggml tensor library for machine learning) on Sep 20, 2025
@am17an (Collaborator Author) commented Sep 21, 2025

Some more performance numbers

| Model | Microbatch size | Test | t/s master | t/s patch | Speedup |
| --- | --- | --- | --- | --- | --- |
| qwen3moe 30B.A3B Q4_K_M | 1 | pp4096 | 194.25 | 205.88 | 1.06 |
| qwen3moe 30B.A3B Q4_K_M | 2 | pp4096 | 206.51 | 213.80 | 1.04 |
| qwen3moe 30B.A3B Q4_K_M | 4 | pp4096 | 355.94 | 366.72 | 1.03 |
| qwen3moe 30B.A3B Q4_K_M | 8 | pp4096 | 585.58 | 600.32 | 1.03 |
| qwen3moe 30B.A3B Q4_K_M | 16 | pp4096 | 922.33 | 940.61 | 1.02 |
| qwen3moe 30B.A3B Q4_K_M | 32 | pp4096 | 1533.54 | 1564.10 | 1.02 |
| qwen3moe 30B.A3B Q4_K_M | 256 | pp4096 | 3722.88 | 3751.81 | 1.01 |
| qwen3moe 30B.A3B Q4_K_M | 512 | pp4096 | 4559.34 | 4588.63 | 1.01 |

@JohannesGaessler (Collaborator) left a comment

I would recommend you cache the logits in registers from the start instead of reading the same data from VRAM twice.

@am17an (Collaborator Author) commented Sep 22, 2025

I did not see further perf improvements after the changes; TG still improves by 6-7%.

Comment on lines 2829 to 2836:

```cpp
//special case for topk-moe
if (ops.size() == 5 && ops.begin()[0] == GGML_OP_SOFT_MAX && ops.begin()[1] == GGML_OP_RESHAPE && ops.begin()[2] == GGML_OP_ARGSORT
    && ops.begin()[3] == GGML_OP_VIEW && ops.begin()[4] == GGML_OP_GET_ROWS) {

    for (int i = 0; i < 5; i++) {
        if (cgraph->nodes[node_idx + i]->op != ops.begin()[i]) return false;
    }
```

Member

Isn't all this redundant since it is performed in ggml_can_fuse?

@ggerganov (Member) commented Sep 22, 2025

I think the logic of this function should be that you always apply ggml_can_fuse first and only then do special-cases.

Edit: got it, the RESHAPE and the VIEW are problematic in this case.

@am17an (Collaborator Author)

Yeah, as I mentioned in #16102 (comment), if I have to remove the empty ops before passing to ggml_can_fuse then it would still be a special case.
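
To illustrate the point, skipping the empty ops before matching would itself amount to a special case. A hypothetical sketch of such a matcher (is_noop_view and match_skipping_views are illustrative names, not existing ggml/llama.cpp functions; it assumes the full ggml_cgraph definition is visible, as it is inside backend code):

```cuda
// Hypothetical matcher that skips RESHAPE/VIEW placeholder nodes before comparing
// ops; illustrative only, not code from this PR or from ggml.
#include <initializer_list>
#include "ggml.h"

static bool is_noop_view(const ggml_tensor * node) {
    return node->op == GGML_OP_RESHAPE || node->op == GGML_OP_VIEW;
}

static bool match_skipping_views(const ggml_cgraph * cgraph, int start,
                                 std::initializer_list<enum ggml_op> ops) {
    int i = start;
    for (const enum ggml_op op : ops) {
        while (i < cgraph->n_nodes && is_noop_view(cgraph->nodes[i])) {
            i++; // drop the "empty" ops from consideration
        }
        if (i >= cgraph->n_nodes || cgraph->nodes[i]->op != op) {
            return false;
        }
        i++;
    }
    return true;
}
```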

@am17an (Collaborator Author) commented Sep 23, 2025

Hmm, it looks like this causes an illegal memory access when running Qwen3-30B-A3B-Q4_0.gguf locally; the tests in test-backend-ops don't catch this.

@JohannesGaessler (Collaborator)

Compile the CUDA code with -lineinfo, then use the compute-sanitizer to get the exact line causing the issue.

@am17an (Collaborator Author) commented Sep 23, 2025

Yeah, I did that, but I don't understand it yet. It seems like some tensor is not getting propagated downstream properly:

========= Invalid __global__ read of size 16 bytes
=========     at void quantize_mmq_q8_1<(mmq_q8_1_ds_layout)1>(const float *, const int *, void *, long, long, long, long, long, int, int)+0x530
=========     by thread (96,0,0) in block (2,0,0)
=========     Access to 0xb1f1a36600 is out of bounds
=========     and is 652591982081 bytes after the nearest allocation at 0x1a00000000 of size 2097152 bytes
=========     Saved host backtrace up to driver entry point at kernel launch time
=========         Host Frame: quantize_mmq_q8_1_cuda(float const*, int const*, void*, ggml_type, long, long, long, long, long, long, long, long, CUstream_st*) in quantize.cu:181 [0x289726] in libggml-cuda.so
=========         Host Frame: ggml_cuda_mul_mat_q(ggml_backend_cuda_context&, ggml_tensor const*, ggml_tensor const*, ggml_tensor const*, ggml_tensor*) in mmq.cu:341 [0x1a5f82] in libggml-cuda.so
=========         Host Frame: ggml_cuda_mul_mat_id(ggml_backend_cuda_context&, ggml_tensor*) in ggml-cuda.cu:2110 [0x177ea7] in libggml-cuda.so
=========         Host Frame: ggml_cuda_compute_forward(ggml_backend_cuda_context&, ggml_tensor*) in ggml-cuda.cu:2410 [0x17947d] in libggml-cuda.so
=========         Host Frame: evaluate_and_capture_cuda_graph(ggml_backend_cuda_context*, ggml_cgraph*, bool&, bool&, bool&) in ggml-cuda.cu:3008 [0x17bf39] in libggml-cuda.so
=========         Host Frame: ggml_backend_cuda_graph_compute(ggml_backend*, ggml_cgraph*) in ggml-cuda.cu:3126 [0x17c787] in libggml-cuda.so
=========         Host Frame: ggml_backend_graph_compute_async in ggml-backend.cpp:359 [0x6e896] in libggml-base.so
=========         Host Frame: ggml_backend_sched_compute_splits(ggml_backend_sched*) in ggml-backend.cpp:1553 [0x738e5] in libggml-base.so
=========         Host Frame: ggml_backend_sched_graph_compute_async in ggml-backend.cpp:1753 [0x74719] in libggml-base.so
=========         Host Frame: llama_context::graph_compute(ggml_cgraph*, bool) in llama-context.cpp:1460 [0x40cea1] in libllama.so
=========         Host Frame: llama_context::process_ubatch(llama_ubatch const&, llm_graph_type, llama_memory_context_i*, ggml_status&) in llama-context.cpp:784 [0x409b57] in libllama.so
=========         Host Frame: llama_context::decode(llama_batch const&) in llama-context.cpp:1088 [0x40b0d2] in libllama.so
=========         Host Frame: llama_decode in llama-context.cpp:2726 [0x411ba6] in libllama.so
=========         Host Frame: common_init_from_params(common_params&) in common.cpp:1066 [0x256b3d] in llama-cli
=========         Host Frame: main in main.cpp:140 [0x833e2] in llama-cli

@am17an (Collaborator Author) commented Sep 23, 2025

It looks like the dimensions are not right when using llama-cli with a model, but they are correct when using test-backend-ops. With llama-cli I get n_expert = 128 and n_expert_used = 128 (where I expect n_expert_used to be 8); I'm not sure what happens to the RESHAPE/VIEW operations. @slaren could you please check if what I'm doing in llama-graph is correct?

@slaren (Member) commented Sep 23, 2025

> @slaren could you please check if what I'm doing in llama-graph is correct?

Yes, there is no problem with calling ggml_build_forward_expand to force the nodes to be in a certain order; it is already done in other places.

@am17an (Collaborator Author) commented Sep 24, 2025

It does not crash with --no-warmup. Somehow setting n_expert_used = n_expert during warmup triggers an illegal memory access, but doing the same in test-backend-ops does not, and compute-sanitizer is also clean. I don't see anything that changes during warmup except setting n_expert_used = n_expert; any help would be appreciated!

Changing this line to use n_expert_used makes everything work:

```cpp
n_expert_used (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
```

@am17an (Collaborator Author) commented Sep 24, 2025

Interestingly, if I move ggml_build_forward_expand after norm_w (in build_moe_ffn) and also fuse the norm, the warmup works as well. So my suspicion is that it has something to do with how we build the graph during warmup, but I'm not sure.

@ggerganov (Member)

I think the special-case path added in this PR does not have a bounds check for the number of nodes in the graph - could this be causing the illegal memory access?
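
For reference, the kind of check being suggested is simply that the fixed five-node window stays inside the graph before any nodes[node_idx + i] access; a hypothetical form (not the PR's exact code):

```cuda
// Illustrative only: refuse to match a fusion window that would index past the
// end of the graph. Hypothetical helper, not code from this PR.
static bool fusion_window_in_bounds(const ggml_cgraph * cgraph, int node_idx, int n_ops) {
    return node_idx >= 0 && node_idx + n_ops <= cgraph->n_nodes;
}
```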

@am17an (Collaborator Author) commented Sep 24, 2025

> I think the special-case path added in this PR does not have a bounds check for the number of nodes in the graph - could this be causing the illegal memory access?

Added the bounds check and the problem is still there. Since it only happens on warmup and test-backend-ops cannot replicate it, my suspicion is that this is somehow messing up the graph. Is there another way to debug this?

@ggerganov (Member)

Does it still happen if you keep the build forward expand and remove the new fusing logic?

@am17an (Collaborator Author) commented Sep 24, 2025

> Does it still happen if you keep the build forward expand and remove the new fusing logic?

No, it doesn't happen if I remove the fusing logic and keep the build forward expand.

@am17an (Collaborator Author) commented Sep 24, 2025

OK, the bug was that ties were not being handled properly in the kernel; after fixing that, it all works. I'm not exactly sure why, though.
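
For context on the tie bug: in a warp-level argmax reduction, every lane has to break ties the same way (for example, always preferring the lower expert id), otherwise lanes can converge on different winners when two probabilities compare equal. A minimal illustration of a tie-aware merge (illustrative, not the PR's exact code):

```cuda
// Illustrative only: merge two (value, index) candidates deterministically, so
// all lanes in the butterfly reduction agree on the winner even for equal values.
__device__ void argmax_merge(float & val, int & idx, const float other_val, const int other_idx) {
    if (other_val > val || (other_val == val && other_idx < idx)) {
        val = other_val;
        idx = other_idx;
    }
}
```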

@am17an (Collaborator Author) commented Sep 24, 2025

> If this has no side-effects on the scheduler/allocator logic, would be the best option. I don't think the backends would ever need to see the empty nodes of the graph - they should always need only the nodes that do actual reads and writes.

This would complicate fusion at least. Right now you can look at llama-graph and call ggml_build_forward_expand to get the exact sequence you see in llama-graph. I don't know how it would look once these ops are gone.

@jeffbolznv (Collaborator)

I'm in favor of removing the empty nodes from the graph. I think it will simplify fusion and graph_optimize.

@am17an (Collaborator Author) commented Sep 24, 2025

Added the optional norm + a TODO about the changes to make once we figure out how to handle empty ops. Performance results for Qwen3-30B-A3B-Q4_0.gguf on an RTX 4090:

| Model | Microbatch size | Test | t/s master | t/s patch | Speedup |
| --- | --- | --- | --- | --- | --- |
| qwen3moe 30B.A3B Q4_K_M | 1 | pp512 | 203.32 | 219.91 | 1.08 |
| qwen3moe 30B.A3B Q4_K_M | 2 | pp512 | 212.57 | 223.57 | 1.05 |
| qwen3moe 30B.A3B Q4_K_M | 4 | pp512 | 368.25 | 383.43 | 1.04 |
| qwen3moe 30B.A3B Q4_K_M | 8 | pp512 | 602.85 | 626.88 | 1.04 |
| qwen3moe 30B.A3B Q4_K_M | 16 | pp512 | 966.81 | 986.07 | 1.02 |
| qwen3moe 30B.A3B Q4_K_M | 32 | pp512 | 1644.49 | 1684.47 | 1.02 |
| qwen3moe 30B.A3B Q4_K_M | 512 | pp512 | 6324.69 | 6394.87 | 1.01 |
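
Assuming the "optional norm" is the usual top-k weight renormalization (the norm_w step in build_moe_ffn mentioned earlier, i.e. rescaling the selected weights so they sum to 1), a minimal device-side sketch could look like the following; the function name is illustrative and this is not the PR's exact code:

```cuda
// Illustrative only: renormalize the k selected expert weights in place so they
// sum to 1, matching what the separate norm step would otherwise compute.
__device__ void normalize_topk_weights(float * w, const int n_expert_used) {
    float sum = 0.0f;
    for (int k = 0; k < n_expert_used; ++k) {
        sum += w[k];
    }
    const float inv_sum = 1.0f/sum;
    for (int k = 0; k < n_expert_used; ++k) {
        w[k] *= inv_sum;
    }
}
```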

@am17an force-pushed the cuda_topk_moe branch 2 times, most recently from 4b2d2b9 to 639e954 on September 25, 2025, 02:28
@am17an (Collaborator Author) commented Sep 25, 2025

@JohannesGaessler will you merge once CI passes?

@JohannesGaessler (Collaborator)

As described in CONTRIBUTING.md: "Let other maintainers merge their own PRs". So I won't merge this PR unless you specifically ask me to.

@am17an (Collaborator Author) commented Sep 25, 2025

> As described in CONTRIBUTING.md: "Let other maintainers merge their own PRs". So I won't merge this PR unless you specifically ask me to.

I don't have write access. I'll ping you to merge once the CI passes.

@am17an (Collaborator Author) commented Sep 25, 2025

The CI failures don't seem related, so this should be good to merge. @JohannesGaessler

@JohannesGaessler merged commit 077c94d into ggml-org:master on Sep 25, 2025 (58 of 64 checks passed).
pwilkin pushed a commit to pwilkin/llama.cpp that referenced this pull request Sep 25, 2025
* CUDA: add a fused top-K MoE kernel

This kernel does the following:
1. softmax over the logits per token [n_experts, n_tokens]
2. argmax reduce over the top-k (n_experts_used) logits
3. write weights + ids to global memory

It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models

* Refactor into ggml_cuda_should_use_topk_moe

* Review: Use better coalescing pattern, use WARP_SIZE, store logits into registers before

* Review: format + micro-optimizations

* Fix bug: fix tie breakers

* Add optional norm + clean-up code

* Use smem for final write

* Add bounds check

* Use better memory pattern for writeback
@am17an deleted the cuda_topk_moe branch on September 26, 2025, 08:43
struct pushed a commit to struct/llama.cpp that referenced this pull request Sep 26, 2025