Simplify Scheduler.update_running_batch #2154

Merged: 10 commits, Nov 24, 2024

README.md: 1 addition, 1 deletion
@@ -35,7 +35,7 @@ SGLang is a fast serving framework for large language models and vision language
It makes your interaction with models faster and more controllable by co-designing the backend runtime and frontend language.
The core features include:

- **Fast Backend Runtime**: Provides efficient serving with RadixAttention for prefix caching, jump-forward constrained decoding, continuous batching, token attention (paged attention), tensor parallelism, FlashInfer kernels, chunked prefill, and quantization (INT4/FP8/AWQ/GPTQ).
- **Fast Backend Runtime**: Provides efficient serving with RadixAttention for prefix caching, jump-forward constrained decoding, overhead-free CPU scheduler, continuous batching, token attention (paged attention), tensor parallelism, FlashInfer kernels, chunked prefill, and quantization (FP8/INT4/AWQ/GPTQ).
- **Flexible Frontend Language**: Offers an intuitive interface for programming LLM applications, including chained generation calls, advanced prompting, control flow, multi-modal inputs, parallelism, and external interactions.
- **Extensive Model Support**: Supports a wide range of generative models (Llama, Gemma, Mistral, QWen, DeepSeek, LLaVA, etc.), embedding models (e5-mistral, gte, mcdse) and reward models (Skywork), with easy extensibility for integrating new models.
- **Active Community**: SGLang is open-source and backed by an active community with industry adoption.

docs/backend/backend.md: 1 deletion
@@ -79,7 +79,6 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
```
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --chunked-prefill-size 4096
```
- To enable the experimental overlapped scheduler, add `--enable-overlap-schedule`. It overlaps CPU scheduler with GPU computation and can accelerate almost all workloads. This does not work for constrained decoding currently.
- To enable torch.compile acceleration, add `--enable-torch-compile`. It accelerates small models on small batch sizes. This does not work for FP8 currently.
- To enable torchao quantization, add `--torchao-config int4wo-128`. It supports various quantization strategies.
- To enable fp8 weight quantization, add `--quantization fp8` on a fp16 checkpoint or directly load a fp8 checkpoint without specifying any arguments.

docs/index.rst: 1 addition, 1 deletion
@@ -5,7 +5,7 @@ SGLang is a fast serving framework for large language models and vision language
It makes your interaction with models faster and more controllable by co-designing the backend runtime and frontend language.
The core features include:

- **Fast Backend Runtime**: Provides efficient serving with RadixAttention for prefix caching, jump-forward constrained decoding, continuous batching, token attention (paged attention), tensor parallelism, FlashInfer kernels, chunked prefill, and quantization (INT4/FP8/AWQ/GPTQ).
- **Fast Backend Runtime**: Provides efficient serving with RadixAttention for prefix caching, jump-forward constrained decoding, overhead-free CPU scheduler, continuous batching, token attention (paged attention), tensor parallelism, FlashInfer kernels, chunked prefill, and quantization (FP8/INT4/AWQ/GPTQ).
- **Flexible Frontend Language**: Offers an intuitive interface for programming LLM applications, including chained generation calls, advanced prompting, control flow, multi-modal inputs, parallelism, and external interactions.
- **Extensive Model Support**: Supports a wide range of generative models (Llama, Gemma, Mistral, QWen, DeepSeek, LLaVA, etc.), embedding models (e5-mistral, gte) and reward models (Skywork), with easy extensibility for integrating new models.
- **Active Community**: SGLang is open-source and backed by an active community with industry adoption.

docs/references/hyperparameter_tuning.md: 1 deletion
@@ -31,7 +31,6 @@ If you see out of memory (OOM) errors, you can try to tune the following paramet
- You can also try to decrease `--mem-fraction-static`, which reduces the memory usage of the KV cache memory pool and helps both prefill and decoding.

### Try Advanced Options
- To enable the experimental overlapped scheduler, add `--enable-overlap-schedule`. It overlaps CPU scheduler with GPU computation and can accelerate almost all workloads. This does not work for constrained decoding currently.
- To enable torch.compile acceleration, add `--enable-torch-compile`. It accelerates small models on small batch sizes. This does not work for FP8 currently.

### Tune `--schedule-policy`

python/sglang/srt/managers/schedule_batch.py: 3 additions, 1 deletion
@@ -467,6 +467,7 @@ class ScheduleBatch:
extend_lens: List[int] = None
extend_num_tokens: int = None
decoding_reqs: List[Req] = None
extend_logprob_start_lens: List[int] = None

# For encoder-decoder
encoder_cached: Optional[List[bool]] = None
@@ -722,7 +723,6 @@ def mix_with_running(self, running_batch: "ScheduleBatch"):
self.merge_batch(running_batch)
self.input_ids = input_ids
self.out_cache_loc = out_cache_loc
self.extend_num_tokens += running_bs

# NOTE: prefix_indices is what has been cached, but we don't cache each decode step
self.prefix_lens.extend(
@@ -732,6 +732,8 @@ def mix_with_running(self, running_batch: "ScheduleBatch"):
]
)
self.extend_lens.extend([1] * running_bs)
self.extend_num_tokens += running_bs
# TODO (lianmin): Revisit this. It should be seq_len - 1
self.extend_logprob_start_lens.extend([0] * running_bs)

def check_decode_mem(self):
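
For readers skimming the hunk above: every running decode request that gets mixed into a prefill batch contributes exactly one new token, so its extend length is 1, its prefix is everything already generated, its logprob start length is pinned to 0 for now (per the TODO), and extend_num_tokens is bumped only after the per-request lists are extended. A minimal standalone sketch of that accounting, using a simplified stand-in class rather than the real ScheduleBatch (field names mirror the diff; the class, its constructor, and the running_seq_lens argument are illustrative):

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class MixedBatchMeta:
    """Simplified stand-in for the extend-related fields of ScheduleBatch."""

    prefix_lens: List[int] = field(default_factory=list)
    extend_lens: List[int] = field(default_factory=list)
    extend_logprob_start_lens: List[int] = field(default_factory=list)
    extend_num_tokens: int = 0

    def mix_with_running(self, running_seq_lens: List[int]) -> None:
        running_bs = len(running_seq_lens)
        # Each running decode request re-extends by exactly one token, so
        # everything before its last token counts as cached prefix.
        self.prefix_lens.extend(seq_len - 1 for seq_len in running_seq_lens)
        self.extend_lens.extend([1] * running_bs)
        self.extend_num_tokens += running_bs
        # Mirrors the TODO above: arguably this should be seq_len - 1.
        self.extend_logprob_start_lens.extend([0] * running_bs)


meta = MixedBatchMeta(prefix_lens=[0], extend_lens=[32], extend_num_tokens=32)
meta.mix_with_running([10, 17])      # two in-flight decode requests
assert meta.extend_num_tokens == 34  # 32 prefill tokens + 2 decode tokens
```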

python/sglang/srt/managers/scheduler.py: 41 additions, 40 deletions
@@ -13,7 +13,6 @@
# ==============================================================================
"""A scheduler that manages a tensor parallel GPU worker."""

import dataclasses
import logging
import os
import threading
@@ -28,7 +27,7 @@
import zmq

from sglang.global_config import global_config
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
@@ -302,6 +301,9 @@ def __init__(
) / global_config.default_new_token_ratio_decay_steps
self.new_token_ratio = self.init_new_token_ratio

# Tells whether the current running batch is full so that we can skip
# the check of whether to prefill new requests.
# This is an optimization to reduce the overhead of the prefill check.
self.batch_is_full = False

# Init watchdog thread
@@ -721,40 +723,30 @@ def check_memory(self):

def get_next_batch_to_run(self):
# Merge the prefill batch into the running batch
if (
self.last_batch
and not self.last_batch.forward_mode.is_decode()
and not self.last_batch.is_empty()
):
if self.last_batch and self.last_batch.forward_mode.is_extend():
if self.being_chunked_req:
# Move the chunked request out of the batch
self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
self.tree_cache.cache_unfinished_req(self.being_chunked_req)
# Inflight request keeps its rid but will get a new req_pool_idx.
# Inflight request keeps its rid but will get a new req_pool_idx
self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
self.batch_is_full = False

if not self.last_batch.is_empty():
if self.running_batch is None:
self.running_batch = self.last_batch
else:
self.running_batch.merge_batch(self.last_batch)

# Prefill first
# Run prefill first if possible
new_batch = self.get_new_batch_prefill()
if new_batch is not None:
return new_batch

# Check memory
if self.running_batch is None:
return

# Run decode
before_bs = self.running_batch.batch_size()
self.update_running_batch()
if not self.running_batch:
self.batch_is_full = False
if self.running_batch is None:
return None
if before_bs != self.running_batch.batch_size():
self.batch_is_full = False
self.running_batch = self.update_running_batch(self.running_batch)
return self.running_batch

def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
@@ -866,15 +858,16 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:

return new_batch

def update_running_batch(self):
def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
"""Update the current running decoding batch."""
global test_retract
batch = self.running_batch

initial_bs = batch.batch_size()

batch.filter_batch()
if batch.is_empty():
self.running_batch = None
return
self.batch_is_full = False
return None

# Check if decode out of memory
if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
@@ -900,11 +893,15 @@ def update_running_batch(self):
jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
self.waiting_queue.extend(jump_forward_reqs)
if batch.is_empty():
self.running_batch = None
return
self.batch_is_full = False
return None

if batch.batch_size() < initial_bs:
self.batch_is_full = False

# Update batch tensors
batch.prepare_for_decode(self.enable_overlap)
return batch

def run_batch(self, batch: ScheduleBatch):
"""Run a batch."""
@@ -979,8 +976,10 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
if req.is_retracted:
continue

if self.is_mixed_chunk and self.enable_overlap and req.finished():
raise ValueError("Unhandled error!")

if req.is_being_chunked <= 0:
# Inflight reqs' prefill is not finished
req.completion_tokens_wo_jump_forward += 1
req.output_ids.append(next_token_id)
req.check_finished()
@@ -990,14 +989,16 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
elif not batch.decoding_reqs or req not in batch.decoding_reqs:
self.tree_cache.cache_unfinished_req(req)

if req.grammar is not None:
req.grammar.accept_token(next_token_id)

if req.return_logprob:
# TODO (lianmin): need to think the case w/ mixed chunked prefill
logprob_pt += self.add_logprob_return_values(
i, req, logprob_pt, next_token_ids, logits_output
)

if req.grammar is not None:
req.grammar.accept_token(next_token_id)
else:
# Inflight reqs' prefill is not finished
req.is_being_chunked -= 1

if batch.next_batch_sampling_info:
@@ -1015,18 +1016,18 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
continue

req.embedding = embeddings[i]
if req.is_being_chunked > 0:
req.is_being_chunked -= 1
else:
# Inflight reqs' prefill is not finished
# dummy output token for embedding models
if req.is_being_chunked <= 0:
# Dummy output token for embedding models
req.output_ids.append(0)
req.check_finished()

if req.finished():
self.tree_cache.cache_finished_req(req)
if req.finished():
self.tree_cache.cache_finished_req(req)
else:
self.tree_cache.cache_unfinished_req(req)
else:
self.tree_cache.cache_unfinished_req(req)
# Inflight reqs' prefill is not finished
req.is_being_chunked -= 1

self.stream_output(batch.reqs)

@@ -1061,9 +1062,6 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result):
req.output_ids.append(next_token_id)
req.check_finished()

if req.grammar is not None:
req.grammar.accept_token(next_token_id)

if req.finished():
self.tree_cache.cache_finished_req(req)

@@ -1074,6 +1072,9 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result):
if req.top_logprobs_num > 0:
req.output_top_logprobs.append(logits_output.output_top_logprobs[i])

if req.grammar is not None:
req.grammar.accept_token(next_token_id)

if batch.next_batch_sampling_info:
batch.next_batch_sampling_info.update_regex_vocab_mask()
torch.cuda.current_stream().synchronize()
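
Read as a whole, the scheduler changes above turn update_running_batch from a method that mutated self.running_batch in place into one that takes the batch as an argument and returns the possibly emptied result, and they move the batch_is_full resets inside it, so the caller reduces to self.running_batch = self.update_running_batch(self.running_batch). A minimal, self-contained sketch of that contract, using a stand-in FakeBatch and a step_decode helper that are illustrative rather than actual SGLang APIs (retraction and jump-forward handling are elided):

```python
from typing import List, Optional


class FakeBatch:
    """Stand-in for ScheduleBatch: a list of request ids plus a finished set."""

    def __init__(self, reqs: List[int]):
        self.reqs = list(reqs)
        self.finished = set()

    def batch_size(self) -> int:
        return len(self.reqs)

    def filter_batch(self) -> None:
        # Drop requests that finished in the previous decode step.
        self.reqs = [r for r in self.reqs if r not in self.finished]

    def is_empty(self) -> bool:
        return not self.reqs

    def prepare_for_decode(self) -> None:
        pass  # the real code builds the decode-step tensors here


class SchedulerSketch:
    def __init__(self) -> None:
        self.running_batch: Optional[FakeBatch] = None
        self.batch_is_full = False

    def update_running_batch(self, batch: FakeBatch) -> Optional[FakeBatch]:
        """Mirror of the new contract: take the batch, return it (or None)."""
        initial_bs = batch.batch_size()
        batch.filter_batch()
        if batch.is_empty():
            self.batch_is_full = False  # reset happens here now, not in the caller
            return None
        # (Decode-OOM retraction and jump-forward checks would go here.)
        if batch.batch_size() < initial_bs:
            self.batch_is_full = False  # the batch shrank, so there is room again
        batch.prepare_for_decode()
        return batch

    def step_decode(self) -> Optional[FakeBatch]:
        # The caller no longer compares batch sizes before and after the call;
        # it simply reassigns whatever update_running_batch returns.
        if self.running_batch is None:
            return None
        self.running_batch = self.update_running_batch(self.running_batch)
        return self.running_batch


sched = SchedulerSketch()
sched.batch_is_full = True
sched.running_batch = FakeBatch([1, 2, 3])
sched.running_batch.finished.add(3)      # pretend request 3 finished last step
print(sched.step_decode().batch_size())  # -> 2
print(sched.batch_is_full)               # -> False
```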

python/sglang/test/few_shot_gsm8k.py: 1 addition, 1 deletion
@@ -135,7 +135,7 @@ def few_shot_gsm8k(s, question):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num-shots", type=int, default=5)
parser.add_argument("--data-path", type=str, default="test.jsonl")
parser.add_argument("--data-path", type=str)
parser.add_argument("--num-questions", type=int, default=200)
parser.add_argument("--max-new-tokens", type=int, default=512)
parser.add_argument("--parallel", type=int, default=128)