Fix grid size in Triton decoding kernel #2134

Merged
merged 2 commits into sgl-project:main on Nov 23, 2024

Conversation

ispobock (Collaborator)

Motivation

Fix the issue mentioned in #1935.

python -m sglang.bench_one_batch --batch-size 128 --input 128 --output 128 --model meta-llama/Llama-3.1-8B-Instruct --attention-backend triton

Prefill. latency: 0.37300 s, throughput:  43925.00 token/s
Decode.  latency: 0.01022 s, throughput:  12529.66 token/s
Decode.  latency: 0.01028 s, throughput:  12455.53 token/s
Decode.  latency: 0.01024 s, throughput:  12494.67 token/s
Decode.  latency: 0.01015 s, throughput:  12611.19 token/s
Decode.  latency: 0.01016 s, throughput:  12596.69 token/s
Decode.  median latency: 0.01000 s, median throughput:  12796.35 token/s
python -m sglang.bench_offline_throughput --model meta-llama/Llama-3.1-8B-Instruct --disable-radix --num-prompt 3000 --attention-backend triton

====== Offline Throughput Benchmark Result =======
Backend:                                 engine    
Successful requests:                     3000      
Benchmark duration (s):                  60.98     
Total input tokens:                      673672    
Total generated tokens:                  581627    
Request throughput (req/s):              49.20     
Input token throughput (tok/s):          11047.56  
Output token throughput (tok/s):         9538.11   
Total token throughput (tok/s):          20585.67  
==================================================

@@ -189,11 +186,12 @@ def _decode_att_m_fwd(
logit_cap,
):
BLOCK = 32
SPLIT_K = 8
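
The changed lines set a fixed split factor for the decode kernel. As a minimal, hypothetical sketch (not the PR's actual code), SPLIT_K typically enters the launch grid of a split-K decode attention kernel like this: each (batch, head) pair is processed by SPLIT_K program instances, each reducing a slice of the KV sequence in BLOCK-sized chunks, and a second pass then combines the partial results.

```python
import math

# Hypothetical helper, for illustration only: how SPLIT_K and BLOCK could
# translate into a launch grid and per-program work for decode attention.
def decode_grid(batch_size: int, num_heads: int, max_kv_len: int,
                split_k: int = 8, block: int = 32):
    # One program instance per (batch, head, kv-split).
    grid = (batch_size, num_heads, split_k)
    # Each split reduces over roughly max_kv_len / split_k cached tokens,
    kv_per_split = math.ceil(max_kv_len / split_k)
    # iterating over them in BLOCK-sized chunks inside the kernel.
    blocks_per_split = math.ceil(kv_per_split / block)
    return grid, kv_per_split, blocks_per_split

# Example: batch 128, 32 query heads, 2000 cached tokens ->
# grid (128, 32, 8), each split scanning 250 tokens in 8 chunks of 32.
print(decode_grid(128, 32, 2000))
```
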
Member:

Is this parameter applicable to various cases?

Collaborator Author:

I tested it on ShareGPT, 8 is an optimal selection.

Collaborator Author:

That may need some tuning for different situations.
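
As a purely hypothetical illustration (not part of this PR), such tuning could take the form of a small heuristic that picks the split factor from the number of cached KV tokens a decode step has to scan:

```python
# Hypothetical heuristic, not from this PR: choose the split factor based on
# the KV length so that short and long contexts both get reasonable grids.
def pick_split_k(max_kv_len: int) -> int:
    if max_kv_len <= 2048:
        return 4   # short contexts: fewer splits, less reduction overhead
    if max_kv_len <= 8192:
        return 8   # the value found to work well on ShareGPT-like traffic
    return 16      # very long contexts: more splits to keep the GPU busy
```
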

Member:

ref https://fireworks.ai/blog/why-gpus-on-demand

| Prompt Length (Tokens) | Fireworks Latency | vLLM Latency |
|---|---|---|
| Long prompt (4000 input, 200 output) | 2117 ms (at 7.5 QPS) | 2877 ms (at 0.348 QPS) |
| Medium prompt (2000 input, 100 output) | 740 ms (at 1.33 QPS) | 1509 ms (at 0.663 QPS) |
| Short prompt (128 input, 4 output) | 43.3 ms (at 22.51 QPS) | 247 ms (at 4.056 QPS) |

Could we also tune the Medium prompt and Long prompt cases?

Member:

BTW, I don't think a 4k prompt really counts as "long", even though the blog defines it as such. In reality, some cases are around 30k-50k tokens.

Collaborator Author:

I tested the throughput (req/s) for these cases; split = 8 is also good.

python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache --trust-remote-code --tp 1 --attention-backend triton
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 2000 --random-output 100 --random-range-ratio 1 --num-prompts 1000
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 4000 --random-output 200 --random-range-ratio 1 --num-prompts 1000
| Prompt Length (Tokens) | Split = 4 | Split = 8 | Split = 16 | Split = 32 |
|---|---|---|---|---|
| Long prompt (4000 input, 200 output) | 5.32 | 5.34 | 5.35 | 5.32 |
| Medium prompt (2000 input, 100 output) | 14.14 | 14.15 | 14.06 | 13.94 |

Member:

The throughput looks good. How about the latency?

Collaborator Author:

Median decode latency (ms):

python3 -m sglang.bench_one_batch --batch-size 1 --input 2000 --output 100 --model meta-llama/Llama-3.1-8B-Instruct --attention-backend triton
python3 -m sglang.bench_one_batch --batch-size 1 --input 4000 --output 200 --model meta-llama/Llama-3.1-8B-Instruct --attention-backend triton
| Prompt Length (Tokens) | Split = 4 | Split = 8 | Split = 16 | Split = 32 |
|---|---|---|---|---|
| Long prompt (4000 input, 200 output) | 9.78 | 9.37 | 9.02 | 8.91 |
| Medium prompt (2000 input, 100 output) | 8.33 | 8.06 | 7.94 | 7.94 |

zhyncs (Member) commented on Nov 23, 2024:

I think the failure in Unit Test 3 is introduced by #2081 (comment).

zhyncs (Member) left a review comment:

LGTM!

zhyncs merged commit c5f8650 into sgl-project:main on Nov 23, 2024
12 of 13 checks passed