
Eagle speculative decoding part 4: Add EAGLE2 worker #2150

Merged · 94 commits · Jan 2, 2025

Conversation

@yukavio (Collaborator) commented Nov 24, 2024

Support EAGLE speculative decoding. The following results were obtained on a single H100.

Official EAGLE code: 200 token/s (see https://github.com/SafeAILab/EAGLE)

Normal decoding speed (SGLang): 156 token/s

python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf

Eagle decoding speed (SGLang): 297 token/s

python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf  --speculative-algo EAGLE --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.7

Eagle decoding speed (SGLang w/ torch.compile): 316 token/s

python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf  --speculative-algo EAGLE --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.7 --enable-torch-compile --cuda-graph-max-bs 2

Benchmark script

import time
import requests

tic = time.time()
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "[INST] Give me a simple FastAPI server. Show the python code. [/INST]",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 256,
        },
    },
)
latency = time.time() - tic
ret = response.json()

print(ret["text"])
speed = ret["meta_info"]["completion_tokens"]
print(f"speed: {speed / latency:.2f} token/s")

Some sub PRs:

@merrymercy merrymercy changed the title Speculative EAGLE2 Eagle speculative decoding part 4: Add EAGLE2 worker Jan 2, 2025
@merrymercy merrymercy merged commit 815dce0 into sgl-project:main Jan 2, 2025
15 checks passed
@zhyncs (Member) commented Jan 2, 2025

🎉🎉🎉

YAMY1234 pushed a commit to YAMY1234/sglang that referenced this pull request Jan 2, 2025
XiaotongJiang pushed a commit to XiaotongJiang/sglang that referenced this pull request Jan 3, 2025
@Xu-Chen (Contributor) commented Jan 16, 2025

When the batch size increases, the time taken by eagle_verify_retrive increases considerably. When the batch size is increased to 10, the eagle_verify_retrive time grows to 0.15s for the 70B model on 4*A100, resulting in slow overall throughput.

@yukavio (Collaborator, Author) commented Jan 17, 2025

Thanks for your report. I'll confirm this and look for possible solutions.

@mmdbhs commented Jan 20, 2025

Can it be used with an AWQ model?

@Xu-Chen (Contributor) commented Jan 21, 2025

Can it be used with an AWQ model?

You need to train the EAGLE draft model based on the AWQ model, and then you can use it.
Since the hidden states from transformers and sglang have some precision differences, the acceleration will be somewhat reduced.
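For illustration only (not from this thread): launching such a setup would follow the same command pattern used earlier in this PR, with an AWQ-quantized base model and a draft model trained against that AWQ base. The model paths below are hypothetical placeholders.

python3 -m sglang.launch_server --model <your-awq-quantized-base-model> --quantization awq --speculative-algo EAGLE --speculative-draft <your-eagle-draft-trained-on-the-awq-base> --speculative-num-steps 5 --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.7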

@Xu-Chen (Contributor) commented Jan 21, 2025

Thanks for your report. I'll confirm this and look for possible solutions.

Is there any progress? @yukavio

@yukavio (Collaborator, Author) commented Jan 21, 2025

Thanks for your report. I'll confirm this and look for possible solutions.

Is there any progress? @yukavio

I found that only the first execution of the kernel is slow, and it does not reduce overall throughput.
Tested on 4*H800, batch size = 10, model: Llama-2-Chat-70B.
[screenshot: kernel timing results]
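As an editorial aside, here is a minimal warmup sketch based on the benchmark script at the top of this PR (assuming a server on localhost:30000): send one untimed request first so the one-time slow kernel execution does not skew the measured speed.

import time
import requests

URL = "http://localhost:30000/generate"  # assumed endpoint, as in the benchmark script above
payload = {
    "text": "[INST] Give me a simple FastAPI server. Show the python code. [/INST]",
    "sampling_params": {"temperature": 0, "max_new_tokens": 256},
}

# Untimed warmup request: absorbs the one-time slow first kernel execution.
requests.post(URL, json=payload)

# Timed request: reflects steady-state decoding speed.
tic = time.time()
ret = requests.post(URL, json=payload).json()
latency = time.time() - tic
print(f"speed: {ret['meta_info']['completion_tokens'] / latency:.2f} token/s")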

@Xu-Chen (Contributor) commented Jan 21, 2025

Not sure if it's an A800 (sm80) issue.
According to my test results on 4 * A800 with Qwen2-72B-Instruct: when BS=1, the TPS of a single request is 75; when BS=10, the overall TPS is 230-260, and the average TPS of a single request is only 25.

Test Code

import argparse
import time
from openai import OpenAI
from multiprocessing import Pool
def infer_one(pid, openai_api_base, model_name, num):
    openai_api_key = "EMPTY"
    client = OpenAI(api_key=openai_api_key, base_url=openai_api_base)
    start_time = time.time()
    messages = [{"role": "user", "content": "以模型部署与推理加速为题,写一篇少于 2000 字的文章"}]  # prompt: "Write an article of under 2,000 words on model deployment and inference acceleration"
    llm_response = client.chat.completions.create(
        messages=messages,
        model=model_name,
        max_tokens=2048,
        temperature=0,
        stream=True
    )
    if pid == 0:
        print("\n########## start ##########")
    answer = ""
    num_token = 0
    for each in llm_response:
        if len(each.choices) == 0:
            continue
        content = each.choices[0].delta.content
        if content is not None:
            answer += content
            num_token += 1
            if pid == 0:
                print(content,end="",  flush=True)
    if pid == 0:
        print("\n########## end  ############")

def infer_main(pid, openai_api_base, model_name):
    for i in range(3):
        infer_one(pid, openai_api_base, model_name, i)

def parser_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pid-num",
        type=int,
        help="work num",
        default=5,
    )
    parser.add_argument(
        "--host",
        type=str,
        help="host",
        default="",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        help="model name",
        default="",
    )
    args = parser.parse_args()
    return args

if __name__ == '__main__':
    args = parser_args()
    print(args)
    pid_num = args.pid_num
    model_name = args.model_name
    openai_api_base = f"{args.host}/v1"
    assert openai_api_base != ""
    pool = Pool(processes=pid_num)
    process_list = []
    start_time = time.time()
    for i in range(pid_num):
        p = pool.apply_async(infer_main, (i, openai_api_base, model_name,))
        process_list.append(p)
    pool.close()
    pool.join()

@yukavio (Collaborator, Author) commented Jan 22, 2025

Sorry, I don't have an A800 for testing. Could you please send me an Nsight Systems profile of your test? (Don't profile too many requests; the file will become too large to open on a laptop.) I can help you find what is slowing the server down. @Xu-Chen

@Xu-Chen (Contributor) commented Jan 22, 2025

Sorry, I don't have an A800 for testing. Could you please send me an Nsight Systems profile of your test? (Don't profile too many requests; the file will become too large to open on a laptop.) I can help you find what is slowing the server down. @Xu-Chen

I generated a profile with nsight-systems-cli on 4 * A100: sglang.out-v1.nsys-rep.zip. Thanks for your help. @yukavio

@yukavio (Collaborator, Author) commented Jan 23, 2025

Not sure if it's an A800 (sm80) issue. According to my test results on 4 * A800 with Qwen2-72B-Instruct: when BS=1, the TPS of a single request is 75; when BS=10, the overall TPS is 230-260, and the average TPS of a single request is only 25.

What command line did you use to start the service?

@Xu-Chen (Contributor) commented Jan 23, 2025

What command line did you use to start the service?

/opt/conda/bin/python3 -m sglang.launch_server --model Qwen/Qwen2-72B-Instruct --speculative-algo EAGLE --speculative-draft ./eagle-qwen2-72b-instruct --speculative-num-steps 5 --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.7 --cuda-graph-max-bs 16 --context-length 32768 --tp 4 --port 8080 --dtype bfloat16

./eagle-qwen2-72b-instruct is EAGLE-Qwen2-72B-Instruct, with Qwen2ForCausalLM changed to Qwen2ForCausalLMEagle in its config.
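A minimal sketch of that rename (editorial, assuming the draft model ships a standard Hugging Face config.json): update the architectures field so the EAGLE draft class is loaded.

import json

cfg_path = "./eagle-qwen2-72b-instruct/config.json"  # local draft-model path from the command above
with open(cfg_path) as f:
    cfg = json.load(f)

# Replace the base architecture name with the EAGLE draft variant.
cfg["architectures"] = ["Qwen2ForCausalLMEagle"]

with open(cfg_path, "w") as f:
    json.dump(cfg, f, indent=2)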

The model in the test is the acceleration (draft) model we trained for Qwen2.5. You can test with the Qwen2 model and change the test input prompt to English.

We have collected open-source data in both Chinese and English, and trained an Eagle2 model based on the Qwen 2.5 model. Currently, we are facing issues with high-concurrency inference speed. Once this issue is resolved, we plan to open-source the model. We tested some samples on sglang, and under single-request scenarios, achieved a 2x acceleration for both Chinese and English.

@yukavio (Collaborator, Author) commented Jan 23, 2025

It is reasonable that the TPS of each request decreases as the batch size increases, because speculative decoding is a method that improves computational efficiency at small batch sizes.
But you can try smaller values of 'speculative-num-steps', 'speculative-eagle-topk', and 'speculative-num-draft-tokens', which will give you better performance at large batch sizes. @Xu-Chen

@Xu-Chen (Contributor) commented Jan 23, 2025

It is reasonable that the TPS of each request decreases as the batch size increases, because speculative decoding is a method that improves computational efficiency at small batch sizes. But you can try smaller values of 'speculative-num-steps', 'speculative-eagle-topk', and 'speculative-num-draft-tokens', which will give you better performance at large batch sizes. @Xu-Chen

You are right. The following parameters achieve up to 1.5x the speed at high concurrency. The speculative-num-draft-tokens parameter particularly affects performance at high batch size (we tested 10).

--speculative-num-steps 3
--speculative-eagle-topk 4
--speculative-num-draft-tokens 16
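
For reference, a sketch of the launch command from earlier in this thread with these smaller values substituted (unverified, same model paths as above):

python3 -m sglang.launch_server --model Qwen/Qwen2-72B-Instruct --speculative-algo EAGLE --speculative-draft ./eagle-qwen2-72b-instruct --speculative-num-steps 3 --speculative-eagle-topk 4 --speculative-num-draft-tokens 16 --mem-fraction 0.7 --cuda-graph-max-bs 16 --context-length 32768 --tp 4 --port 8080 --dtype bfloat16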
