diff --git a/examples/runtime/engine/offline_batch_inference.py b/examples/runtime/engine/offline_batch_inference.py
index 7404c7e4e7f..8918c6ef8b6 100644
--- a/examples/runtime/engine/offline_batch_inference.py
+++ b/examples/runtime/engine/offline_batch_inference.py
@@ -1,21 +1,23 @@
 import sglang as sgl
-
+import time
 
 def main():
     # Sample prompts.
     prompts = [
-        "Hello, my name is",
-        "The president of the United States is",
-        "The capital of France is",
-        "The future of AI is",
+        "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: Where is the capital city of France? ASSISTANT:",
+        "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: 北京今天天气怎么样? ASSISTANT:"
     ]
     # Create a sampling params object.
-    sampling_params = {"temperature": 0.8, "top_p": 0.95}
+    sampling_params = {"temperature": 0, "max_new_tokens": 30}
 
-    # Create an LLM.
-    llm = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
+    # Create an LLM.
+    llm = sgl.Engine(model_path="Llama-2-7b-chat-hf", draft_model_path='EAGLE-llama2-chat-7B', disable_cuda_graph=True, num_speculative_steps=5, eagle_topk=8, num_draft_tokens=64, speculative_algorithm='EAGLE', mem_fraction_static=0.60)
+    #llm = sgl.Engine(model_path="Llama-2-7b-chat-hf", disable_cuda_graph=False)
+    #outputs = llm.generate(prompts, sampling_params)
 
+    start = time.time()
     outputs = llm.generate(prompts, sampling_params)
+    print(time.time()-start)
     # Print the outputs.
     for prompt, output in zip(prompts, outputs):
         print("===============================")
diff --git a/python/sglang/srt/layers/attention/__init__.py b/python/sglang/srt/layers/attention/__init__.py
index f5d573f5f7b..0b503d8b05f 100644
--- a/python/sglang/srt/layers/attention/__init__.py
+++ b/python/sglang/srt/layers/attention/__init__.py
@@ -23,9 +23,12 @@ def init_cuda_graph_state(self, max_bs: int):
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
+        num_token: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
-        encoder_lens: Optional[torch.Tensor] = None,
+        encoder_lens: torch.Tensor = None,
+        spec_info=None,
+        is_draft_runner=False,
     ):
         """Init the metadata for a forward pass for capturing a cuda graph."""
         raise NotImplementedError()
@@ -33,10 +36,12 @@ def init_forward_metadata_capture_cuda_graph(
     def init_forward_metadata_replay_cuda_graph(
         self,
         bs: int,
+        num_token: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        encoder_lens: Optional[torch.Tensor] = None,
+        encoder_lens=None,
+        spec_info=None,
     ):
         """Init the metadata for a forward pass for replying a cuda graph."""
         raise NotImplementedError()
@@ -54,7 +59,9 @@ def forward(
         forward_batch: ForwardBatch,
     ):
         """Run forward on an attention layer."""
-        if forward_batch.forward_mode.is_decode():
+        if forward_batch.forward_mode.is_verify():
+            return self.forward_extend(q, k, v, layer, forward_batch)
+        elif forward_batch.forward_mode.is_decode():
             return self.forward_decode(q, k, v, layer, forward_batch)
         else:
             return self.forward_extend(q, k, v, layer, forward_batch)
diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index c6b5393ee92..a2c7ded54e0 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -16,12 +16,13 @@ from 
sglang.global_config import global_config from sglang.srt.layers.attention import AttentionBackend -from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.utils import is_flashinfer_available if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner + from sglang.srt.speculative.speculative_utils import SpecInput if is_flashinfer_available(): from flashinfer import ( @@ -124,20 +125,40 @@ def __init__(self, model_runner: ModelRunner): def init_forward_metadata(self, forward_batch: ForwardBatch): if forward_batch.forward_mode.is_decode(): + wrappers = self.prefill_wrappers_paged if forward_batch.forward_mode.is_verify() \ + else self.decode_wrappers self.indices_updater_decode.update( forward_batch.req_pool_indices, forward_batch.seq_lens, forward_batch.seq_lens_sum, - decode_wrappers=None, + decode_wrappers=wrappers, encoder_lens=forward_batch.encoder_lens, + forward_batch=forward_batch, ) - self.forward_metadata = (self.decode_wrappers,) + if forward_batch.forward_mode.is_verify(): + self.forward_metadata = (False, False, None) + else: + self.forward_metadata = (wrappers,) + elif forward_batch.forward_mode.is_spec_extend(): + use_ragged = False + extend_no_prefix = True + prefix_lens = None + self.indices_updater_prefill.update( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + prefix_lens, + use_ragged=use_ragged, + encoder_lens=forward_batch.encoder_lens, + forward_batch=forward_batch, + ) + self.forward_metadata = (use_ragged, extend_no_prefix, None) else: prefix_lens = forward_batch.extend_prefix_lens # Some heuristics to check whether to use ragged forward use_ragged = False - if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1: + if (forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1 and + not forward_batch.forward_mode.is_verify()): use_ragged = True extend_no_prefix = not torch.any(forward_batch.extend_prefix_lens).item() @@ -148,9 +169,10 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): prefix_lens, use_ragged=use_ragged, encoder_lens=forward_batch.encoder_lens, + forward_batch=forward_batch, ) - self.forward_metadata = (use_ragged, extend_no_prefix) + self.forward_metadata = (use_ragged, extend_no_prefix, None) def init_cuda_graph_state(self, max_bs: int): cuda_graph_kv_indices = torch.zeros( @@ -161,53 +183,88 @@ def init_cuda_graph_state(self, max_bs: int): self.cuda_graph_kv_indices = [cuda_graph_kv_indices] + [ cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1) ] + + self.cuda_graph_custom_mask = torch.zeros( + (max_bs * (self.max_context_len+7)//8), + dtype=torch.uint8, + device="cuda", + ) + self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr] + self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr] def init_forward_metadata_capture_cuda_graph( self, bs: int, + num_token: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, encoder_lens: torch.Tensor = None, + spec_info:SpecInput=None, + is_draft_runner:bool=False, + forward_batch: ForwardBatch=None ): decode_wrappers = [] for i in range(self.num_wrappers): - decode_wrappers.append( - BatchDecodeWithPagedKVCacheWrapper( - self.workspace_buffer, - "NHD", - use_cuda_graph=True, - use_tensor_cores=self.decode_use_tensor_cores, - paged_kv_indptr_buffer=self.kv_indptr[i][: bs + 1], - 
paged_kv_indices_buffer=self.cuda_graph_kv_indices[i], - paged_kv_last_page_len_buffer=self.kv_last_page_len[:bs], + # speculative decodign verify stage + if spec_info is not None and not is_draft_runner: + decode_wrappers.append( + BatchPrefillWithPagedKVCacheWrapper( + self.workspace_buffer, + "NHD", + use_cuda_graph=True, + qo_indptr_buf=self.cuda_graph_qo_indptr[i][:bs+1], + paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1], + paged_kv_indices_buf=self.cuda_graph_kv_indices[i], + paged_kv_last_page_len_buf=self.kv_last_page_len[:bs], + custom_mask_buf=self.cuda_graph_custom_mask, + qk_indptr_buf=self.cuda_graph_qk_indptr[i][:bs+1], + ) ) - ) + self.forward_metadata = (False, False, decode_wrappers) + + else: + decode_wrappers.append( + BatchDecodeWithPagedKVCacheWrapper( + self.workspace_buffer, + "NHD", + use_cuda_graph=True, + use_tensor_cores=self.decode_use_tensor_cores, + paged_kv_indptr_buffer=self.kv_indptr[i][: num_token + 1], + paged_kv_indices_buffer=self.cuda_graph_kv_indices[i], + paged_kv_last_page_len_buffer=self.kv_last_page_len[:num_token], + ) + ) + self.forward_metadata = (decode_wrappers,) - seq_lens_sum = seq_lens.sum().item() - self.indices_updater_decode.update( - req_pool_indices, - seq_lens, - seq_lens_sum, - decode_wrappers=decode_wrappers, - encoder_lens=encoder_lens, - ) - self.cuda_graph_metadata[bs] = decode_wrappers - self.forward_metadata = (decode_wrappers,) + seq_lens_sum = seq_lens.sum().item() + self.indices_updater_decode.update( + req_pool_indices, + seq_lens, + seq_lens_sum, + decode_wrappers=decode_wrappers, + encoder_lens=encoder_lens, + forward_batch=forward_batch, + ) + self.cuda_graph_metadata[num_token] = decode_wrappers + def init_forward_metadata_replay_cuda_graph( self, bs: int, + num_token: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, - encoder_lens: torch.Tensor = None, + encoder_lens=None, + forward_batch=None, ): self.indices_updater_decode.update( req_pool_indices[:bs], seq_lens[:bs], seq_lens_sum, - decode_wrappers=self.cuda_graph_metadata[bs], + decode_wrappers=self.cuda_graph_metadata[num_token], encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None, + forward_batch=forward_batch ) def get_cuda_graph_seq_len_fill_value(self): @@ -220,7 +277,7 @@ def forward_extend( self._get_wrapper_idx(layer) ] - use_ragged, extend_no_prefix = self.forward_metadata + use_ragged, extend_no_prefix, graph_wrapper = self.forward_metadata cache_loc = ( forward_batch.out_cache_loc if not layer.is_cross_attention @@ -231,15 +288,24 @@ def forward_extend( if k is not None: assert v is not None forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) - - o = prefill_wrapper_paged.forward( - q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), - forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), - causal=not layer.is_cross_attention, - sm_scale=layer.scaling, - window_left=layer.sliding_window_size, - logits_soft_cap=layer.logit_cap, - ) + if graph_wrapper is not None and forward_batch.forward_mode.is_verify(): + o = graph_wrapper[self._get_wrapper_idx(layer)].forward( + q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), + causal=False, + sm_scale=layer.scaling, + window_left=layer.sliding_window_size, + logits_soft_cap=layer.logit_cap, + ) + else: + o = prefill_wrapper_paged.forward( + q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + 
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), + causal=not layer.is_cross_attention, + sm_scale=layer.scaling, + window_left=layer.sliding_window_size, + logits_soft_cap=layer.logit_cap, + ) else: o1, s1 = self.prefill_wrapper_ragged.forward_return_lse( q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), @@ -335,7 +401,8 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend): self.update = self.update_single_wrapper def update( - self, req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers, encoder_lens + self, req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers, + encoder_lens, forward_batch ): # Keep the signature for type checking. It will be assigned during runtime. raise NotImplementedError() @@ -347,6 +414,7 @@ def update_single_wrapper( seq_lens_sum: int, decode_wrappers=None, encoder_lens=None, + forward_batch=None, ): decode_wrappers = decode_wrappers or self.decode_wrappers self.call_begin_forward( @@ -356,6 +424,7 @@ def update_single_wrapper( seq_lens_sum, self.kv_indptr[0], None, + forward_batch, ) def update_sliding_window( @@ -365,6 +434,7 @@ def update_sliding_window( seq_lens_sum: int, decode_wrappers=None, encoder_lens=None, + forward_batch=None, ): decode_wrappers = decode_wrappers or self.decode_wrappers @@ -399,6 +469,7 @@ def update_cross_attention( seq_lens_sum, decode_wrappers=None, encoder_lens=None, + forward_batch=None, ): decode_wrappers = decode_wrappers or self.decode_wrappers @@ -430,36 +501,63 @@ def call_begin_forward( paged_kernel_lens_sum, kv_indptr, kv_start_idx, + forward_batch=None, ): - bs = len(req_pool_indices) - kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) - kv_indptr = kv_indptr[: bs + 1] - kv_indices = torch.empty( - paged_kernel_lens_sum, dtype=torch.int32, device="cuda" - ) + + bs = forward_batch.input_ids.numel() + if forward_batch.spec_info is not None: + kv_indices, kv_indptr, kv_last_page_len, qo_indptr = ( + forward_batch.spec_info.generate_attn_arg( + req_pool_indices, + paged_kernel_lens, + self.req_to_token, + ) + ) + else: + kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = torch.empty( + paged_kernel_lens_sum, dtype=torch.int32, device="cuda" + ) + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + paged_kernel_lens, + kv_indptr, + kv_start_idx, + kv_indices, + self.max_context_len, + ) - create_flashinfer_kv_indices_triton[(bs,)]( - self.req_to_token, - req_pool_indices, - paged_kernel_lens, - kv_indptr, - kv_start_idx, - kv_indices, - self.max_context_len, - ) - wrapper.end_forward() - wrapper.begin_forward( - kv_indptr, - kv_indices, - self.kv_last_page_len[:bs], - self.num_qo_heads, - self.num_kv_heads, - self.head_dim, - 1, - data_type=self.data_type, - q_data_type=self.q_data_type, - ) + if forward_batch.forward_mode.is_verify(): + bs = len(req_pool_indices) + custom_mask = getattr(forward_batch.spec_info, "custom_mask", None) + wrapper.end_forward() + wrapper.begin_forward( + qo_indptr, + kv_indptr, + kv_indices, + self.kv_last_page_len[:bs], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + 1, + custom_mask=custom_mask, + ) + else: + wrapper.end_forward() + wrapper.begin_forward( + kv_indptr, + kv_indices, + self.kv_last_page_len[:bs], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + 1, + data_type=self.data_type, + q_data_type=self.q_data_type, + ) class FlashInferIndicesUpdaterPrefill: @@ -496,12 +594,14 @@ def __init__(self, 
model_runner: ModelRunner, attn_backend: AttentionBackend): assert self.attn_backend.num_wrappers == 1 self.update = self.update_single_wrapper - def update(self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens): + def update(self, req_pool_indices, seq_lens, prefix_lens, use_ragged, + encoder_lens, forward_batch): # Keep the signature for type checking. It will be assigned during runtime. raise NotImplementedError() def update_single_wrapper( - self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens + self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens, + forward_batch ): if use_ragged: paged_kernel_lens = prefix_lens @@ -519,10 +619,12 @@ def update_single_wrapper( self.kv_indptr[0], self.qo_indptr[0], use_ragged, + forward_batch ) def update_sliding_window( - self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens + self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens, + forward_batch ): for wrapper_id in range(2): if wrapper_id == 0: @@ -547,10 +649,12 @@ def update_sliding_window( self.kv_indptr[wrapper_id], self.qo_indptr[wrapper_id], use_ragged, + forward_batch, ) def update_cross_attention( - self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens + self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens, + forward_batch ): for wrapper_id in range(2): if wrapper_id == 0: @@ -573,6 +677,7 @@ def update_cross_attention( self.kv_indptr[wrapper_id], self.qo_indptr[wrapper_id], use_ragged, + forward_batch ) def call_begin_forward( @@ -587,23 +692,35 @@ def call_begin_forward( kv_indptr, qo_indptr, use_ragged, + forward_batch, ): bs = len(req_pool_indices) - kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) - kv_indptr = kv_indptr[: bs + 1] - kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda") - create_flashinfer_kv_indices_triton[(bs,)]( - self.req_to_token, - req_pool_indices, - paged_kernel_lens, - kv_indptr, - kv_start_idx, - kv_indices, - self.max_context_len, - ) + if forward_batch.forward_mode.is_spec_extend(): + # spec extend update generate arg + kv_indices, kv_indptr, kv_last_page_len, qo_indptr = ( + forward_batch.spec_info.generate_attn_arg_spec_extend( + req_pool_indices, + paged_kernel_lens, + self.req_to_token, + ) + ) + else: + kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda") + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + paged_kernel_lens, + kv_indptr, + kv_start_idx, + kv_indices, + self.max_context_len, + ) - qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0) - qo_indptr = qo_indptr[: bs + 1] + qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0) + qo_indptr = qo_indptr[: bs + 1] + kv_last_page_len = self.kv_last_page_len[:bs] # extend part if use_ragged: @@ -622,7 +739,7 @@ def call_begin_forward( qo_indptr, kv_indptr, kv_indices, - self.kv_last_page_len[:bs], + kv_last_page_len[:bs], self.num_qo_heads, self.num_kv_heads, self.head_dim, @@ -663,4 +780,4 @@ def create_flashinfer_kv_indices_triton( data = tl.load(req_to_token_ptr + ld_offset, mask=mask) tl.store(kv_indices_ptr + st_offset, data, mask=mask) ld_offset += BLOCK_SIZE - st_offset += BLOCK_SIZE + st_offset += BLOCK_SIZE \ No newline at end of file diff --git a/python/sglang/srt/layers/attention/flashinfer_utils.py 
b/python/sglang/srt/layers/attention/flashinfer_utils.py new file mode 100644 index 00000000000..58964aaf65e --- /dev/null +++ b/python/sglang/srt/layers/attention/flashinfer_utils.py @@ -0,0 +1,297 @@ +from enum import Enum, auto + +import torch +import triton +import triton.language as tl + + +class WrapperDispatch(Enum): + SLIDING_WINDOW = auto() + CROSS_ATTENTION = auto() + + +@triton.jit +def create_flashinfer_kv_indices_triton( + req_to_token_ptr, # [max_batch, max_context_len] + req_pool_indices_ptr, + page_kernel_lens_ptr, + kv_indptr, + kv_start_idx, + kv_indices_ptr, + max_context_len: tl.constexpr, +): + BLOCK_SIZE: tl.constexpr = 512 + pid = tl.program_id(axis=0) + req_pool_index = tl.load(req_pool_indices_ptr + pid) + kv_indices_offset = tl.load(kv_indptr + pid) + + kv_start = 0 + kv_end = 0 + if kv_start_idx: + kv_start = tl.load(kv_start_idx + pid).to(tl.int32) + kv_end = kv_start + kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32) + + req_to_token_ptr += req_pool_index * max_context_len + kv_indices_ptr += kv_indices_offset + + ld_offset = kv_start + tl.arange(0, BLOCK_SIZE) + st_offset = tl.arange(0, BLOCK_SIZE) + num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE) + for _ in range(num_loop): + mask = ld_offset < kv_end + data = tl.load(req_to_token_ptr + ld_offset, mask=mask) + tl.store(kv_indices_ptr + st_offset, data, mask=mask) + ld_offset += BLOCK_SIZE + st_offset += BLOCK_SIZE + + +class FlashinferUpdater: + def __init__( + self, + forward_mode, + model_runner, + req_pool_indices, + seq_lens, + prefix_lens, + decode_wrappers=None, + use_ragged=False, + spec_info=None, + use_cuda_graph=False, + ): + self.forward_mode = forward_mode + self.model_runner = model_runner + self.req_pool_indices = req_pool_indices + self.seq_lens = seq_lens + self.prefix_lens = prefix_lens + self.use_ragged = use_ragged + self.spec_info = spec_info + self.use_cuda_graph = use_cuda_graph + + self.num_qo_heads = ( + model_runner.model_config.num_attention_heads // model_runner.tp_size + ) + self.num_kv_heads = model_runner.model_config.get_num_kv_heads( + model_runner.tp_size + ) + self.head_dim = model_runner.model_config.head_dim + self.batch_size = len(req_pool_indices) + + self.decode_wrappers = ( + decode_wrappers or self.model_runner.attn_backend.decode_wrappers + ) + self.prefill_wrapper_ragged = ( + self.model_runner.attn_backend.prefill_wrapper_ragged + ) + self.prefill_wrappers_paged = ( + self.model_runner.attn_backend.prefill_wrappers_paged + ) + + self.kv_last_page_len = torch.ones( + (self.batch_size,), dtype=torch.int32, device="cuda" + ) + + def _update_decode_indices(self, decode_wrapper): + assert not isinstance(decode_wrapper, list) + decode_wrapper.end_forward() + decode_wrapper.begin_forward( + self.kv_indptr, + self.kv_indices, + self.kv_last_page_len, + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + 1, + data_type=self.model_runner.kv_cache_dtype, + q_data_type=self.model_runner.dtype, + ) + + def _update_extend_indices(self, ragged_wrapper, paged_wrapper): + assert not isinstance(paged_wrapper, list) + assert not isinstance(ragged_wrapper, list) + + # extend part + qo_indptr = torch.zeros( + (self.batch_size + 1,), dtype=torch.int32, device="cuda" + ) + qo_indptr[1:] = torch.cumsum(self.seq_lens - self.prefix_lens, dim=0) + + if self.use_ragged: + ragged_wrapper.end_forward() + ragged_wrapper.begin_forward( + qo_indptr, + qo_indptr, + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + ) + + # cached part + paged_wrapper.end_forward() + 
paged_wrapper.begin_forward( + qo_indptr, + self.kv_indptr, + self.kv_indices, + self.kv_last_page_len, + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + 1, + ) + + def _update_verify_indices(self, paged_wrapper): + custom_mask = getattr(self.spec_info, "custom_mask", None) + paged_wrapper.end_forward() + paged_wrapper.begin_forward( + self.qo_indptr, + self.kv_indptr, + self.kv_indices, + self.kv_last_page_len, + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + 1, + custom_mask=custom_mask, + ) + + def _update_spec_extend(self, paged_wrapper): + assert not isinstance(paged_wrapper, list) + paged_wrapper.end_forward() + paged_wrapper.begin_forward( + self.qo_indptr, + self.kv_indptr, + self.kv_indices, + self.kv_last_page_len, + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + 1, + ) + + def _get_indices(self, dispatch_reason: WrapperDispatch = None, wrapper_id=0): + if dispatch_reason is None: + if self.use_ragged: + paged_kernel_lens = self.prefix_lens + else: + paged_kernel_lens = self.seq_lens + self.kv_start_idx = None + elif dispatch_reason == WrapperDispatch.SLIDING_WINDOW: + if wrapper_id == 0: + # window attention use paged only + if self.forward_mode.is_decode(): + paged_kernel_lens = torch.minimum( + self.seq_lens, + torch.tensor(self.model_runner.sliding_window_size + 1), + ) + else: + paged_kernel_lens = torch.minimum( + self.seq_lens, + torch.tensor(self.model_runner.sliding_window_size) + + self.seq_lens + - self.prefix_lens, + ) + else: + # full attention + paged_kernel_lens = self.seq_lens + self.kv_start_idx = self.seq_lens - paged_kernel_lens + + if self.spec_info is not None and self.forward_mode.is_spec_extend(): + self.kv_indices, self.kv_indptr, self.kv_last_page_len, self.qo_indptr = ( + self.spec_info.generate_attn_arg_spec_extend( + self.req_pool_indices, + paged_kernel_lens, + self.model_runner.req_to_token_pool, + ) + ) + elif self.spec_info is not None and self.forward_mode.is_decode(): + self.kv_indices, self.kv_indptr, self.kv_last_page_len, self.qo_indptr = ( + self.spec_info.generate_attn_arg( + self.req_pool_indices, + paged_kernel_lens, + self.model_runner.req_to_token_pool, + ) + ) + else: + self.kv_indptr = torch.zeros( + (self.batch_size + 1,), dtype=torch.int32, device="cuda" + ) + self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0) + self.kv_indices = torch.empty( + self.kv_indptr[-1], dtype=torch.int32, device="cuda" + ) + + create_flashinfer_kv_indices_triton[(self.batch_size,)]( + self.model_runner.req_to_token_pool.req_to_token, + self.req_pool_indices, + paged_kernel_lens, + self.kv_indptr, + self.kv_start_idx, + self.kv_indices, + self.model_runner.req_to_token_pool.req_to_token.size(1), + ) + + def _update_indicess_single_wrapper(self): + self._get_indices() + if self.forward_mode.is_verify(): + if self.use_cuda_graph: + self._update_verify_indices(self.decode_wrappers[0]) + else: + self._update_verify_indices(self.prefill_wrappers_paged[0]) + elif self.forward_mode.is_spec_extend(): + self._update_spec_extend(self.prefill_wrappers_paged[0]) + elif self.forward_mode.is_decode(): + self._update_decode_indices(self.decode_wrappers[0]) + else: + self._update_extend_indices( + self.prefill_wrapper_ragged, + self.prefill_wrappers_paged[0], + ) + + def _update_indices_cross_attention(self): + pass + + def _update_indices_sliding_window(self): + assert self.use_ragged is False + for wrapper_id in range(2): + self._get_indices(WrapperDispatch.SLIDING_WINDOW, wrapper_id) + if self.forward_mode.is_decode(): + 
self._update_decode_indices(self.decode_wrappers[wrapper_id]) + else: + self._update_extend_indices( + None, + self.prefill_wrappers_paged[wrapper_id], + ) + + +def update_flashinfer_indices( + forward_mode, + model_runner, + req_pool_indices, + seq_lens, + prefix_lens, + decode_wrappers=None, + use_ragged=False, + spec_info=None, + use_cuda_graph=False, +): + updater = FlashinferUpdater( + forward_mode, + model_runner, + req_pool_indices, + seq_lens, + prefix_lens, + decode_wrappers, + use_ragged, + spec_info, + use_cuda_graph, + ) + + dispatch_reason = model_runner.attn_backend.dispatch_reason + + if dispatch_reason == WrapperDispatch.SLIDING_WINDOW: + updater._update_indices_sliding_window() + elif dispatch_reason == WrapperDispatch.CROSS_ATTENTION: + updater._update_indices_cross_attention() + else: + assert model_runner.attn_backend.num_wrappers == 1 + updater._update_indicess_single_wrapper() diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 47b8d3cd56d..4c1a5489881 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -84,9 +84,12 @@ def init_cuda_graph_state(self, max_bs: int): def init_forward_metadata_capture_cuda_graph( self, bs: int, + num_token: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, - encoder_lens=None, + encoder_lens: torch.Tensor = None, + spec_info=None, + is_draft_runner=False, ): # NOTE: encoder_lens expected to be zeros or None self.forward_metadata = ( @@ -99,10 +102,12 @@ def init_forward_metadata_capture_cuda_graph( def init_forward_metadata_replay_cuda_graph( self, bs: int, + num_token: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, encoder_lens=None, + spec_info=None, ): # NOTE: encoder_lens expected to be zeros or None self.cuda_graph_start_loc.zero_() diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index eda2c7738d0..d95cf4da097 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -34,6 +34,13 @@ class LogitsProcessorOutput: next_token_logits: torch.Tensor # The logprobs of the next tokens. shape: [#seq, vocab_size] next_token_logprobs: torch.Tensor = None + + # Used by speculative inference + # The output of transformer layers + hidden_states: Optional[torch.Tensor] = None + # backup of next_token_logits when use cuda graph + # id(next_token_logits_bak) == id(next_token_logits) + next_token_logits_bak: Optional[torch.Tensor] = None # The normlaized logprobs of prompts. 
shape: [#seq] normalized_prompt_logprobs: torch.Tensor = None @@ -59,6 +66,8 @@ class LogitsMetadata: extend_logprob_start_lens_cpu: Optional[List[int]] = None extend_logprob_pruned_lens_cpu: Optional[List[int]] = None + + is_draft_batch: bool = False @classmethod def from_forward_batch(cls, forward_batch: ForwardBatch): @@ -67,7 +76,7 @@ def from_forward_batch(cls, forward_batch: ForwardBatch): else: return_top_logprob = False - if forward_batch.forward_mode.is_extend(): + if forward_batch.forward_mode.is_extend() and not forward_batch.is_draft_batch: extend_logprob_pruned_lens_cpu = [ extend_len - start_len for extend_len, start_len in zip( @@ -75,6 +84,7 @@ def from_forward_batch(cls, forward_batch: ForwardBatch): forward_batch.extend_logprob_start_lens_cpu, ) ] + else: extend_logprob_pruned_lens_cpu = None return cls( @@ -86,6 +96,7 @@ def from_forward_batch(cls, forward_batch: ForwardBatch): extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu, extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu, extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu, + is_draft_batch=forward_batch.is_draft_batch ) @@ -168,7 +179,9 @@ def forward( weight, logits_metadata: Union[LogitsMetadata, ForwardBatch], ): + need_hidden_states = False if isinstance(logits_metadata, ForwardBatch): + need_hidden_states = logits_metadata.spec_algorithm == 'EAGLE' logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata) assert isinstance(logits_metadata, LogitsMetadata) @@ -191,7 +204,7 @@ def forward( last_logits.mul_(self.config.final_logit_softcapping) # Return only last_logits if logprob is not requested - if not logits_metadata.return_logprob: + if not logits_metadata.return_logprob or logits_metadata.is_draft_batch: return LogitsProcessorOutput( next_token_logits=last_logits, next_token_logprobs=None, @@ -199,6 +212,8 @@ def forward( input_token_logprobs=None, input_top_logprobs=None, output_top_logprobs=None, + hidden_states=last_hidden if need_hidden_states else None, + next_token_logits_bak=last_logits, ) else: last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 85ca560a926..0870de8d5bf 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1,3 +1,5 @@ +from __future__ import annotations + """ Copyright 2023-2024 SGLang Team Licensed under the Apache License, Version 2.0 (the "License"); @@ -31,7 +33,8 @@ import dataclasses import logging -from typing import List, Optional, Tuple, Union +from dataclasses import dataclass +from typing import TYPE_CHECKING, List, Optional, Tuple, Union import torch @@ -46,6 +49,9 @@ from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import ServerArgs +if TYPE_CHECKING: + from sglang.srt.speculative.speculative_utils import SpecInput + INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 # Put some global args for easy access @@ -474,6 +480,10 @@ class ScheduleBatch: # device device: str = "cuda" + # speculative decoding + spec_info: SpecInput = None + spec_algorithm: str = None + @classmethod def init_new( cls, @@ -482,6 +492,7 @@ def init_new( token_to_kv_pool, tree_cache, model_config, + speculative_algorithm, ): return cls( reqs=reqs, @@ -493,6 +504,7 @@ def init_new( has_stream=any(req.stream for req in reqs), has_grammar=any(req.grammar for req in reqs), device=req_to_token_pool.device, + 
spec_algorithm=speculative_algorithm ) def batch_size(self): @@ -705,8 +717,8 @@ def mix_with_running(self, running_batch: "ScheduleBatch"): self.extend_lens.extend([1] * running_bs) self.extend_logprob_start_lens.extend([0] * running_bs) - def check_decode_mem(self): - bs = len(self.reqs) + def check_decode_mem(self, buf_multiplier=1): + bs = len(self.reqs)*buf_multiplier if self.token_to_kv_pool.available_size() >= bs: return True @@ -854,6 +866,8 @@ def prepare_encoder_info_decode(self): def prepare_for_decode(self, enable_overlap: bool = False): self.forward_mode = ForwardMode.DECODE + if self.spec_algorithm == 'EAGLE': + return self.input_ids = self.output_ids self.output_ids = None @@ -962,6 +976,8 @@ def merge_batch(self, other: "ScheduleBatch"): self.return_logprob = self.return_logprob or other.return_logprob self.has_stream = self.has_stream or other.has_stream self.has_grammar = self.has_grammar or other.has_grammar + if self.spec_info is not None: + self.spec_info.merge_batch(other.spec_info) def get_model_worker_batch(self): if self.forward_mode.is_decode(): @@ -1003,6 +1019,8 @@ def get_model_worker_batch(self): encoder_out_cache_loc=self.encoder_out_cache_loc, lora_paths=[req.lora_path for req in self.reqs], sampling_info=self.sampling_info, + spec_algorithm=self.spec_algorithm, + spec_info=self.spec_info, mrope_positions_delta=mrope_positions_delta, ) @@ -1069,9 +1087,14 @@ class ModelWorkerBatch: # Sampling info sampling_info: SamplingBatchInfo - + # For Qwen2-VL mrope_positions_delta: List[List[int]] + + # Speclulative decoding + spec_algorithm: str = None + spec_info: SpecInput = None + def copy(self): return dataclasses.replace(self, sampling_info=self.sampling_info.copy()) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index f876847e1d3..b1e80c7c93e 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -60,6 +60,7 @@ PrefillAdder, SchedulePolicy, ) +from sglang.srt.speculative.speculative_worker import spec_worker_factory from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient from sglang.srt.mem_cache.chunk_cache import ChunkCache @@ -158,6 +159,7 @@ def __init__( self.model_config.hf_config.architectures, self.server_args.is_embedding ) + # Launch a tensor parallel worker if self.enable_overlap: TpWorkerClass = TpModelWorkerClient @@ -171,7 +173,20 @@ def __init__( dp_rank=dp_rank, nccl_port=port_args.nccl_port, ) - + + # Launch Speculative worker if need + if self.server_args.speculative_algorithm is not None: + self.draft_worker = spec_worker_factory.get(self.server_args.speculative_algorithm)( + gpu_id=gpu_id, + tp_rank=tp_rank, + server_args=server_args, + nccl_port=port_args.nccl_port, + target_worker=self.tp_worker, + dp_rank=dp_rank, + ) + else: + self.draft_worker = None + # Get token and memory info from the model worker ( self.max_total_num_tokens, @@ -318,6 +333,7 @@ def event_loop_normal(self): """A normal blocking scheduler loop.""" self.last_batch = None + # TODO: need to edit it to support batch by split prefill @kavioyu while True: recv_reqs = self.recv_requests() self.process_input_requests(recv_reqs) @@ -375,13 +391,15 @@ def event_loop_overlap(self): def recv_requests(self): if self.tp_rank == 0: recv_reqs = [] - + while True: try: recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK) except zmq.ZMQError: break recv_reqs.append(recv_req) + if 
self.server_args.split_prefill_batch: + break else: recv_reqs = None @@ -715,6 +733,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: self.token_to_kv_pool, self.tree_cache, self.model_config, + self.server_args.speculative_algorithm ) new_batch.prepare_for_extend() @@ -740,11 +759,14 @@ def update_running_batch(self): return # Check if decode out of memory - if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10): + buf_multiplier = 1 if self.server_args.speculative_algorithm is None else self.server_args.num_draft_tokens + if not batch.check_decode_mem(buf_multiplier) or (test_retract and batch.batch_size() > 10): old_ratio = self.new_token_ratio retracted_reqs, new_token_ratio = batch.retract_decode() self.new_token_ratio = new_token_ratio + if self.draft_worker is not None: + self.draft_worker.finish_request(retracted_reqs) logger.info( "Decode out of memory happened. " @@ -772,13 +794,18 @@ def update_running_batch(self): def run_batch(self, batch: ScheduleBatch): """Run a batch.""" self.forward_ct += 1 - if self.is_generation: if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0: - model_worker_batch = batch.get_model_worker_batch() - logits_output, next_token_ids = self.tp_worker.forward_batch_generation( - model_worker_batch - ) + if self.server_args.speculative_algorithm: + logits_output, next_token_ids, model_worker_batch = self.draft_worker.forward_batch_speculative_generate( + batch + ) + else: + model_worker_batch = batch.get_model_worker_batch() + logits_output, next_token_ids = self.tp_worker.forward_batch_generation( + model_worker_batch + ) + else: logits_output = None if self.skip_tokenizer_init: @@ -898,7 +925,8 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result): continue req.completion_tokens_wo_jump_forward += 1 - req.output_ids.append(next_token_id) + if batch.spec_algorithm is None: # speculative worker will solve the output_ids in speculative decoding + req.output_ids.append(next_token_id) req.check_finished() if req.grammar is not None: @@ -1018,6 +1046,8 @@ def stream_output(self, reqs: List[Req]): if req.finished() or ( req.stream and (is_stream_iter or len(req.output_ids) == 1) ): + if req.finished() and self.draft_worker is not None: + self.draft_worker.finish_request(req) output_rids.append(req.rid) output_finished_reason.append(req.finished_reason) if self.is_generation: diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 561bfd77c5a..e1572e5fd00 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -1,3 +1,5 @@ +from __future__ import annotations + """ Copyright 2023-2024 SGLang Team Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,6 +17,8 @@ """A tensor parallel worker.""" +from typing import TYPE_CHECKING + import json import logging from typing import Optional @@ -28,12 +32,14 @@ from sglang.srt.server_args import ServerArgs from sglang.srt.utils import broadcast_pyobj, is_multimodal_model, set_random_seed +if TYPE_CHECKING: + from sglang.srt.speculative.speculative_worker import SpeculativeWorker + logger = logging.getLogger(__name__) class TpModelWorker: """A tensor parallel model worker.""" - def __init__( self, server_args: ServerArgs, @@ -44,14 +50,17 @@ def __init__( ): # Parse args self.tp_rank = tp_rank + self.server_args = server_args + is_draft_worker = getattr(self, 'is_draft_worker', False) # Init model and tokenizer self.model_config = ModelConfig( - 
server_args.model_path, + server_args.model_path if not is_draft_worker else server_args.draft_model_path, server_args.trust_remote_code, context_length=server_args.context_length, model_override_args=json.loads(server_args.json_model_override_args), ) + self.model_runner = ModelRunner( model_config=self.model_config, mem_fraction_static=server_args.mem_fraction_static, @@ -60,6 +69,7 @@ def __init__( tp_size=server_args.tp_size, nccl_port=nccl_port, server_args=server_args, + is_draft_runner=is_draft_worker ) if server_args.skip_tokenizer_init: self.tokenizer = self.processor = None @@ -134,10 +144,14 @@ def get_memory_pool(self): self.model_runner.token_to_kv_pool, ) - def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): + def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch, need_token_id=True): forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) logits_output = self.model_runner.forward(forward_batch) - next_token_ids = self.model_runner.sample(logits_output, model_worker_batch) + if need_token_id: + next_token_ids = self.model_runner.sample(logits_output, model_worker_batch) + else: + next_token_ids = None + model_worker_batch.spec_info = forward_batch.spec_info return logits_output, next_token_ids def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch): diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index e91fbac6523..3685c07a905 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -33,6 +33,7 @@ ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.utils import monkey_patch_vllm_all_gather +from sglang.srt.speculative.speculative_utils import DraftInfoFactory if TYPE_CHECKING: from sglang.srt.model_executor.model_runner import ModelRunner @@ -123,6 +124,17 @@ def __init__(self, model_runner: "ModelRunner"): if bs <= model_runner.req_to_token_pool.size and bs <= model_runner.server_args.cuda_graph_max_bs ] + self.capture_forward_mode = ForwardMode.DECODE + if model_runner.server_args.speculative_algorithm == 'EAGLE': + if self.model_runner.is_draft_runner: + expand_num = self.model_runner.server_args.eagle_topk + else: + self.capture_forward_mode = ForwardMode.SPECVERIFY + expand_num = self.model_runner.server_args.num_draft_tokens + self.num_tokens = [bs * expand_num for bs in self.capture_bs] + else: + self.num_tokens = [bs for bs in self.capture_bs] + self.compile_bs = ( [ bs @@ -135,8 +147,8 @@ def __init__(self, model_runner: "ModelRunner"): # Attention backend self.max_bs = max(self.capture_bs) - self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs) - + self.max_num_token = max(self.num_tokens) + self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token) self.seq_len_fill_value = ( self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value() ) @@ -149,13 +161,18 @@ def __init__(self, model_runner: "ModelRunner"): # Common inputs with torch.device("cuda"): - self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32) + self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int32) self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32) self.seq_lens = torch.full( (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32 ) - self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32) + self.out_cache_loc = 
torch.zeros((self.max_num_token,), dtype=torch.int32) + self.positions = torch.zeros((self.max_num_token, ), dtype=torch.int64) self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32) + + # speculative_inference + if self.model_runner.server_args.speculative_algorithm == 'EAGLE': + self.hidden_states = torch.zeros((self.max_num_token, self.model_runner.model_config.hidden_size), dtype=self.model_runner.dtype) if self.is_encoder_decoder: # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch @@ -209,7 +226,7 @@ def can_run(self, forward_batch: ForwardBatch): def capture(self): with graph_capture() as graph_capture_context: self.stream = graph_capture_context.stream - for bs in self.capture_bs: + for bs, num_token in zip(self.capture_bs, self.num_tokens): with patch_model( self.model_runner.model, bs in self.compile_bs, @@ -218,19 +235,35 @@ def capture(self): ( graph, output_buffers, - ) = self.capture_one_batch_size(bs, forward) + ) = self.capture_one_batch_size(bs, num_token, forward) self.graphs[bs] = graph self.output_buffers[bs] = output_buffers - def capture_one_batch_size(self, bs: int, forward: Callable): + def capture_one_batch_size(self, bs: int, num_token: int, forward: Callable): graph = torch.cuda.CUDAGraph() stream = self.stream # Common inputs - input_ids = self.input_ids[:bs] + input_ids = self.input_ids[:num_token] req_pool_indices = self.req_pool_indices[:bs] seq_lens = self.seq_lens[:bs] - out_cache_loc = self.out_cache_loc[:bs] + out_cache_loc = self.out_cache_loc[:num_token] + positions = self.positions[:num_token] + + spec_info = None + if self.model_runner.server_args.speculative_algorithm == 'EAGLE': + if self.model_runner.is_draft_runner: + spec_info = DraftInfoFactory.get(self.model_runner.server_args.speculative_algorithm, 'DraftInput')() + spec_info.hidden_states = self.hidden_states[:num_token] + spec_info.positions = positions + spec_info.init(self.model_runner.server_args) + else: + spec_info = DraftInfoFactory.get(self.model_runner.server_args.speculative_algorithm, 'VerifyInput')( + None, None, None, None, None, None, self.model_runner.server_args.num_draft_tokens) + spec_info.custom_mask = torch.zeros((num_token*self.model_runner.model_config.context_len), dtype=torch.bool, + device="cuda",) + + if self.is_encoder_decoder: encoder_lens = self.encoder_lens[:bs] else: @@ -239,42 +272,49 @@ def capture_one_batch_size(self, bs: int, forward: Callable): seq_lens_sum = seq_lens.sum().item() mrope_positions = self.mrope_positions[:, :bs] + forward_batch = ForwardBatch( + forward_mode=self.capture_forward_mode, + batch_size=bs, + input_ids=input_ids, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool=self.model_runner.token_to_kv_pool, + attn_backend=self.model_runner.attn_backend, + out_cache_loc=out_cache_loc, + seq_lens_sum=seq_lens_sum, + encoder_lens=encoder_lens, + return_logprob=False, + top_logprobs_nums=[0] * num_token, + positions=positions, + #positions=clamp_position(seq_lens), + spec_info=spec_info, + spec_algorithm=self.model_runner.server_args.speculative_algorithm, + is_cuda_graph=True, + mrope_positions=mrope_positions, + ) + # Attention backend self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph( bs, + num_token, req_pool_indices, seq_lens, encoder_lens, + spec_info, + self.model_runner.is_draft_runner, + forward_batch=forward_batch, ) # Run and capture def run_once(): - forward_batch = 
ForwardBatch( - forward_mode=ForwardMode.DECODE, - batch_size=bs, - input_ids=input_ids, - req_pool_indices=req_pool_indices, - seq_lens=seq_lens, - req_to_token_pool=self.model_runner.req_to_token_pool, - token_to_kv_pool=self.model_runner.token_to_kv_pool, - attn_backend=self.model_runner.attn_backend, - out_cache_loc=out_cache_loc, - seq_lens_sum=seq_lens_sum, - encoder_lens=encoder_lens, - return_logprob=False, - top_logprobs_nums=[0] * bs, - positions=clamp_position(seq_lens), - mrope_positions=mrope_positions, - ) logits_output = forward(input_ids, forward_batch.positions, forward_batch) - return logits_output.next_token_logits + return logits_output.next_token_logits, logits_output.hidden_states for _ in range(2): torch.cuda.synchronize() self.model_runner.tp_group.barrier() - run_once() - torch.cuda.synchronize() self.model_runner.tp_group.barrier() @@ -292,37 +332,57 @@ def run_once(): def replay(self, forward_batch: ForwardBatch): assert forward_batch.out_cache_loc is not None + forward_batch.is_cuda_graph = True raw_bs = forward_batch.batch_size + # In most case, raw_bs == num_token in decode stage. + # But for speculative, the token num maybe large than raw_bs + raw_num_token = forward_batch.input_ids.numel() # Pad index = bisect.bisect_left(self.capture_bs, raw_bs) bs = self.capture_bs[index] + num_token = self.num_tokens[index] if bs != raw_bs: self.seq_lens.fill_(1) self.out_cache_loc.zero_() # Common inputs - self.input_ids[:raw_bs].copy_(forward_batch.input_ids) + self.input_ids[:raw_num_token].copy_(forward_batch.input_ids) self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens) - self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc) + self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc) + positions = forward_batch.positions + if positions is None: + positions = clamp_position(forward_batch.seq_lens) + self.positions[:raw_num_token].copy_(positions) + if self.is_encoder_decoder: self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens) if forward_batch.mrope_positions is not None: self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions) + + # EAGLE speculative decoding + if isinstance(forward_batch.spec_info, DraftInfoFactory.get('EAGLE', 'DraftInput')): + self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states + # Attention backend self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph( bs, + num_token, self.req_pool_indices, self.seq_lens, forward_batch.seq_lens_sum + (bs - raw_bs), self.encoder_lens, + forward_batch=forward_batch ) # Replay self.graphs[bs].replay() - next_token_logits = self.output_buffers[bs][:raw_bs] + next_token_logits, hidden_states = self.output_buffers[bs] + next_token_logits = next_token_logits[:raw_num_token] + if hidden_states is not None: + hidden_states = hidden_states[:raw_num_token] # Extract logprobs if forward_batch.return_logprob: @@ -332,6 +392,8 @@ def replay(self, forward_batch: ForwardBatch): logits_output = LogitsProcessorOutput( next_token_logits=next_token_logits, next_token_logprobs=next_token_logprobs, + next_token_logits_bak=next_token_logits, + hidden_states=hidden_states, ) return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums) if return_top_logprob: @@ -345,6 +407,8 @@ def replay(self, forward_batch: ForwardBatch): else: logits_output = LogitsProcessorOutput( next_token_logits=next_token_logits, + next_token_logits_bak=next_token_logits, + 
hidden_states=hidden_states, ) return logits_output diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index d314af944ef..38089286e7b 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -45,7 +45,7 @@ from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo - + from sglang.srt.speculative.speculative_utils import SpecInput class ForwardMode(IntEnum): # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case. @@ -56,18 +56,31 @@ class ForwardMode(IntEnum): DECODE = auto() # Contains both EXTEND and DECODE. MIXED = auto() + # Speculative Verify stage + SPECVERIFY = auto() + # Speculative draft Extend stage which after verify stage + SPECEXTEND = auto() def is_prefill(self): return self == ForwardMode.PREFILL def is_extend(self): - return self == ForwardMode.EXTEND or self == ForwardMode.MIXED + return self in (ForwardMode.EXTEND, self == ForwardMode.MIXED, ForwardMode.SPECEXTEND) def is_decode(self): - return self == ForwardMode.DECODE + return self in (ForwardMode.DECODE, ForwardMode.SPECVERIFY) def is_mixed(self): return self == ForwardMode.MIXED + + def is_verify(self): + return self == ForwardMode.SPECVERIFY + + def is_spec_extend(self): + return self == ForwardMode.SPECEXTEND + + def is_cuda_graph(self): + return self in (ForwardMode.DECODE, ForwardMode.SPECVERIFY) @dataclass @@ -124,6 +137,12 @@ class ForwardBatch: req_to_token_pool: ReqToTokenPool = None token_to_kv_pool: BaseTokenToKVPool = None attn_backend: AttentionBackend = None + + # Speculative decoding + spec_info: SpecInput = None + spec_algorithm: str = None + is_draft_batch: bool = False + is_cuda_graph: bool = False # For Qwen2-VL mrope_positions: torch.Tensor = None @@ -207,8 +226,12 @@ def init_new( top_logprobs_nums=batch.top_logprobs_nums, lora_paths=batch.lora_paths, sampling_info=batch.sampling_info, + spec_algorithm=batch.spec_algorithm, + spec_info=batch.spec_info, ) + if ret.spec_info is not None and getattr(ret.spec_info, 'positions', None) is not None: + ret.positions = ret.spec_info.positions # Init position information if not ret.forward_mode.is_decode(): ret.positions = torch.concat( diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 2bc048197e3..920bb2b3b4e 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -83,6 +83,7 @@ def __init__( tp_size: int, nccl_port: int, server_args: ServerArgs, + is_draft_runner: bool ): # Parse args self.model_config = model_config @@ -96,6 +97,7 @@ def __init__( self.is_multimodal_model = is_multimodal_model( self.model_config.hf_config.architectures ) + self.is_draft_runner = is_draft_runner # Model-specific adjustment if ( @@ -183,9 +185,9 @@ def init_torch_distributed(self): if not self.server_args.enable_p2p_check: monkey_patch_vllm_p2p_access_check(self.gpu_id) if self.server_args.dist_init_addr: - dist_init_method = f"tcp://{self.server_args.dist_init_addr}" + dist_init_method = f"tcp://{self.server_args.dist_init_addr[1 if self.is_draft_runner else 0]}" else: - dist_init_method = f"tcp://127.0.0.1:{self.dist_port}" + dist_init_method = f"tcp://127.0.0.1:{self.dist_port[1 if self.is_draft_runner else 0]}" 
set_custom_all_reduce(not self.server_args.disable_custom_all_reduce) init_distributed_environment( backend=backend, @@ -194,7 +196,9 @@ def init_torch_distributed(self): local_rank=self.gpu_id, distributed_init_method=dist_init_method, ) - initialize_model_parallel(tensor_model_parallel_size=self.tp_size) + # draft model is not support parallel currently + if not self.is_draft_runner: + initialize_model_parallel(tensor_model_parallel_size=self.tp_size) min_per_gpu_memory = get_available_gpu_memory( self.device, self.gpu_id, distributed=self.tp_size > 1 ) @@ -240,7 +244,7 @@ def load_model(self): monkey_patch_vllm_dummy_weight_loader() self.load_config = LoadConfig(load_format=self.server_args.load_format) self.vllm_model_config = VllmModelConfig( - model=self.server_args.model_path, + model=self.server_args.model_path if not self.is_draft_runner else self.server_args.draft_model_path, quantization=self.server_args.quantization, tokenizer=None, tokenizer_mode=None, @@ -424,6 +428,26 @@ def init_memory_pool( ) self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory) + + if max_num_reqs is None: + max_num_reqs = min( + max( + int( + self.max_total_num_tokens / self.model_config.context_len * 512 + ), + 2048, + ), + 4096, + ) + + if self.server_args.speculative_algorithm is not None: + if self.is_draft_runner: + self.max_total_num_tokens = self.server_args.draft_runner_cache_size + else: + self.server_args.draft_runner_cache_size = self.max_total_num_tokens + \ + max_num_reqs * self.server_args.num_speculative_steps + 100 + + if max_total_tokens is not None: if max_total_tokens > self.max_total_num_tokens: logging.warning( @@ -438,17 +462,6 @@ def init_memory_pool( "Not enough memory. Please try to increase --mem-fraction-static." ) - if max_num_reqs is None: - max_num_reqs = min( - max( - int( - self.max_total_num_tokens / self.model_config.context_len * 512 - ), - 2048, - ), - 4096, - ) - self.req_to_token_pool = ReqToTokenPool( size=max_num_reqs + 1, max_context_len=self.model_config.context_len + 4, @@ -557,10 +570,15 @@ def init_cuda_graphs(self): self.cuda_graph_runner = CudaGraphRunner(self) def forward_decode(self, forward_batch: ForwardBatch): - if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch): + + if self.cuda_graph_runner and self.cuda_graph_runner.can_run( + forward_batch + ) and forward_batch.forward_mode.is_cuda_graph(): return self.cuda_graph_runner.replay(forward_batch) - - forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64) + if hasattr(forward_batch.spec_info, 'positions'): + forward_batch.positions = forward_batch.spec_info.positions + else: + forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64) self.attn_backend.init_forward_metadata(forward_batch) return self.model.forward( forward_batch.input_ids, forward_batch.positions, forward_batch @@ -569,6 +587,8 @@ def forward_decode(self, forward_batch: ForwardBatch): def forward_extend(self, forward_batch: ForwardBatch): self.attn_backend.init_forward_metadata(forward_batch) if self.is_generation: + if getattr(forward_batch.spec_info, 'positions', None) is not None: + forward_batch.positions = forward_batch.spec_info.positions return self.model.forward( forward_batch.input_ids, forward_batch.positions, forward_batch ) diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 543703c230b..bdd001896d9 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -313,9 +313,12 @@ def forward( 
input_embeds: torch.Tensor = None, ) -> LogitsProcessorOutput: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) - return self.logits_processor( + res = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, forward_batch ) + if forward_batch.spec_algorithm == 'EAGLE': + forward_batch.spec_info.hidden_states = hidden_states + return res def get_hidden_dim(self, module_name): # return input_dim, output_dim @@ -416,6 +419,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, self.model.embed_tokens.weight) apply_torchao_config_(self, params_dict, set(["proj.weight"])) + def get_embed_and_head(self): + return self.model.embed_tokens.weight, self.lm_head.weight class Phi3ForCausalLM(LlamaForCausalLM): pass diff --git a/python/sglang/srt/models/llama_eagle.py b/python/sglang/srt/models/llama_eagle.py new file mode 100644 index 00000000000..87082f8561d --- /dev/null +++ b/python/sglang/srt/models/llama_eagle.py @@ -0,0 +1,429 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +# Adapted from +# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1 +# and +# https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/modeling_llama_kv.py +"""Inference-only LLaMA-EAGLE model compatible with HuggingFace weights.""" + +from typing import Any, Dict, Iterable, Optional, Tuple + +import torch +from torch import nn +from transformers import LlamaConfig +from vllm.config import CacheConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.torchao_utils import apply_torchao_config_ +from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.model_executor.forward_batch_info import ForwardBatch + + +class LlamaMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + 
quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." + ) + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class LlamaAttention(nn.Module): + def __init__( + self, + config: LlamaConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + layer_id: int = 0, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + rope_is_neox_style: bool = True, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr( + config, "head_dim", self.hidden_size // self.total_num_heads + ) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=rope_is_neox_style, + ) + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, forward_batch) + output, _ = self.o_proj(attn_output) + return output + + +class LlamaDecoderLayer(nn.Module): + def __init__( + self, + config: LlamaConfig, + layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.layer_id = layer_id + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None + ): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings 
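The head-partitioning logic in LlamaAttention above follows the usual GQA rule: KV heads are sharded across tensor-parallel ranks when there are enough of them, and replicated otherwise. A small sketch of just that rule, with example head counts:

    def kv_heads_per_rank(total_kv_heads: int, tp_size: int) -> int:
        if total_kv_heads >= tp_size:
            # Enough KV heads: shard them across the TP ranks.
            assert total_kv_heads % tp_size == 0
        else:
            # Fewer KV heads than ranks: replicate, so several ranks share a head.
            assert tp_size % total_kv_heads == 0
        return max(1, total_kv_heads // tp_size)

    print(kv_heads_per_rank(32, 4))   # 8  (Llama-2-7B-style model on tp=4)
    print(kv_heads_per_rank(8, 16))   # 1  (GQA model with 8 KV heads, replicated)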
+ ) + rope_is_neox_style = getattr(config, "rope_is_neox_style", True) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + self.self_attn = LlamaAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + layer_id=layer_id, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + rope_is_neox_style=rope_is_neox_style, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = LlamaMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + residual = hidden_states + + if self.layer_id != 0: + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + + hidden_states = residual + hidden_states + residual = hidden_states + # Fully Connected + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states, residual + + +class LlamaModel(nn.Module): + def __init__( + self, + config: LlamaConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList( + [ + LlamaDecoderLayer( + config, i, quant_config=quant_config, prefix=f"model.layers.{i}" + ) + for i in range(config.num_hidden_layers) + ] + ) + # self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.fc = torch.nn.Linear(config.hidden_size * 2, config.hidden_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + + hidden_states = self.fc( + torch.cat( + (hidden_states, forward_batch.spec_info.hidden_states), dim=-1 + ) + ) + + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + forward_batch, + residual, + ) + return hidden_states + + +class LlamaForCausalLMEagle(nn.Module): + def __init__( + self, + config: LlamaConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = LlamaModel(config, quant_config=quant_config) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.logits_processor = LogitsProcessor(config) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + ) -> 
LogitsProcessorOutput: + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) + logits_output = self.logits_processor( + None, hidden_states, self.lm_head.weight, forward_batch + ) + return logits_output + + def get_hidden_dim(self, module_name): + # return input_dim, output_dim + if module_name in ["q_proj", "o_proj", "qkv_proj"]: + return self.config.hidden_size, self.config.hidden_size + elif module_name in ["kv_proj"]: + return self.config.hidden_size, self.config.hidden_size // ( + self.config.num_attention_heads // self.config.num_key_value_heads + ) + elif module_name == "gate_up_proj": + return self.config.hidden_size, self.config.intermediate_size + elif module_name == "down_proj": + return self.config.intermediate_size, self.config.hidden_size + else: + raise NotImplementedError() + + def get_module_name(self, name): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id, num_shard) + ("qkv_proj", "q_proj", "q", 3), + ("qkv_proj", "k_proj", "k", 3), + ("qkv_proj", "v_proj", "v", 3), + ("gate_up_proj", "gate_proj", 0, 2), + ("gate_up_proj", "up_proj", 1, 2), + ] + for param_name, weight_name, shard_id, num_shard in stacked_params_mapping: + if weight_name in name: + return ( + name.replace(weight_name, param_name)[: -len(".weight")], + num_shard, + ) + return name[: -len(".weight")], 1 + + def get_num_params(self): + params_dict = dict(self.named_parameters()) + return len(params_dict) + + def load_weights( + self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None + ): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + + def load_weights_per_param(name, loaded_weight): + if "rotary_emb.inv_freq" in name or "projector" in name: + return + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + return + if name.startswith("model.vision_tower") and name not in params_dict: + return + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + return + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + if name is None or loaded_weight is None: + for name, loaded_weight in weights: + if "lm_head" not in name: + name = "model." 
+ name + load_weights_per_param(name, loaded_weight) + else: + load_weights_per_param(name, loaded_weight) + + def set_embed_and_head(self, embed, head): + del self.model.embed_tokens.weight + del self.lm_head.weight + self.model.embed_tokens.weight = embed + self.lm_head.weight = head + torch.cuda.empty_cache() + torch.cuda.synchronize() + + +EntryClass = [LlamaForCausalLMEagle] diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 7d23cb8bd58..debd1b407d5 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -57,6 +57,7 @@ class ServerArgs: max_prefill_tokens: int = 16384 schedule_policy: str = "lpm" schedule_conservativeness: float = 1.0 + split_prefill_batch: bool = False # Other runtime options tp_size: int = 1 @@ -81,7 +82,7 @@ class ServerArgs: load_balance_method: str = "round_robin" # Distributed args - dist_init_addr: Optional[str] = None + dist_init_addr: Optional[List[str]] = None nnodes: int = 1 node_rank: int = 0 @@ -127,6 +128,19 @@ class ServerArgs: triton_attention_reduce_in_fp32: bool = False num_continuous_decode_steps: int = 1 + #speculative decoding + draft_model_path: str = None + speculative_algorithm: str = None + num_speculative_steps: int = None + # should been set as 2^n + num_draft_tokens: int = None + # should been set as [1, 2, 4, 8] + eagle_topk: int = None + # should not been set by cli, it is only a placeholder + # which would be set and used in model_runner + draft_runner_cache_size: int = None + + def __post_init__(self): # Set missing default values if self.tokenizer_path is None: @@ -200,6 +214,12 @@ def __post_init__(self): if "gemma-2" in self.model_path.lower(): logger.info("When using sliding window in gemma-2, turn on flashinfer.") self.attention_backend = "flashinfer" + + # Speculative Decoding + if self.speculative_algorithm=='EAGLE': + self.split_prefill_batch = True + # EAGLE don't support it currently. + self.disable_cuda_graph_padding = True @staticmethod def add_cli_args(parser: argparse.ArgumentParser): @@ -460,8 +480,9 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--dist-init-addr", "--nccl-init-addr", # For backward compatbility. This will be removed in the future. - type=str, - help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).", + type=List[str], + help="""The host address for initializing distributed backend (e.g., `192.168.0.2:25000`). Shoule provide + two host address if use speculative decoding""", ) parser.add_argument( "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes." @@ -663,6 +684,56 @@ def add_cli_args(parser: argparse.ArgumentParser): "This can potentially increase throughput but may also increase time-to-first-token latency. " "The default value is 1, meaning only run one decoding step at a time.", ) + parser.add_argument( + "--draft-model-path", + type=str, + help="The path of the draft model weights. 
This can be a local folder or a Hugging Face repo ID.", + required=False, + ) + parser.add_argument( + "--speculative-algorithm", + type=str, + choices=["EAGLE"], + help="Speculative algorithm.", + required=False, + ) + parser.add_argument( + "--num-speculative-steps", + type=int, + help="The number of steps sampled from draft model in Speculative Decoding.", + required=False, + default=5, + ) + parser.add_argument( + "--num-draft-tokens", + type=int, + help="The number of token sampled from draft model in Speculative Decoding.", + required=False, + default=64, + ) + parser.add_argument( + "--eagle-topk", + type=int, + help="The number of token sampled from draft model in eagle2 each step.", + required=False, + choices=[1, 2, 4, 8], + default=8, + ) + parser.add_argument( + "--split-prefill-batch", + type=bool, + help="Whether to inference prefill sample one by one.", + required=False, + default=False, + ) + parser.add_argument( + "--draft-runner-cache-size", + type=int, + help="""It should not been set by cli, it is only a placeholder which + would be set and used in model_runner when using speculative inference.""", + required=False, + default=-1, + ) @classmethod def from_cli_args(cls, args: argparse.Namespace): @@ -730,13 +801,17 @@ class PortArgs: detokenizer_ipc_name: str # The port for nccl initialization (torch.dist) - nccl_port: int + # [port] if don't use speculative decoding else [tp worker port, draft worker, port] + nccl_port: List[int] @staticmethod def init_new(server_args) -> "PortArgs": + all_port = [] port = server_args.port + 42 while True: if is_port_available(port): + all_port.append(port) + if len(all_port) == 2 if server_args.speculative_algorithm is not None else 1: break port += 42 @@ -744,7 +819,7 @@ def init_new(server_args) -> "PortArgs": tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name, scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name, detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name, - nccl_port=port, + nccl_port=all_port, ) diff --git a/python/sglang/srt/speculative/__init__.py b/python/sglang/srt/speculative/__init__.py new file mode 100644 index 00000000000..f56d3b89ef1 --- /dev/null +++ b/python/sglang/srt/speculative/__init__.py @@ -0,0 +1,2 @@ +from .eagle_worker import EAGLEWorker +from . import eagle_utils \ No newline at end of file diff --git a/python/sglang/srt/speculative/build_eagle_tree.py b/python/sglang/srt/speculative/build_eagle_tree.py new file mode 100644 index 00000000000..a7341a51850 --- /dev/null +++ b/python/sglang/srt/speculative/build_eagle_tree.py @@ -0,0 +1,354 @@ +# import triton +# import triton.language as tl +# import torch + + +import time + +import cutex +import torch + +# parent_table [bs,topk*depth+)] +# selected_index [bs,draft_token_num-1)] +# verified_seq_len [bs] +# tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] 
= [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] +# positions [bs*draft_token] +# retrive_index [b, draft_token, depth+2] +kernels = cutex.SourceModule( + """ +//cuda +__global__ void build_tree(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, + Tensor tree_mask, Tensor positions, Tensor retrive_index, int topk, int depth, int draft_token_num) { + int bid = blockIdx.x; + int tid = threadIdx.x; + if (tid >= draft_token_num){ + return; + } + int seq_tree_idx = draft_token_num * draft_token_num * bid; + for(int i=0; i1: + index = tl.arange(0, draft_token_num) + mask_left = index != max_index + remained_index = tl.where(mask_max and mask_left, index, 0) + max_index = tl.max(remained_index) + + tl.store(accept_length + pid, accept_len) + retrive_index_ptr = retrive_index + (retrive_start + max_index) * max_len + retrive_offset = tl.arange(0, max_len_upper) + retrive_load_mask = retrive_offset < accept_len + 1 + data = tl.load(retrive_index_ptr + retrive_offset, mask=retrive_load_mask) + + tl.store( + accept_index + pid * max_len + retrive_offset, data, mask=retrive_load_mask + ) + + extract_load_ptr = accept_index + pid * max_len + accept_len + if accept_len == max_len - 1: + extract_data = tl.load(extract_load_ptr - 1) + tl.store(extract_index + pid * 2, extract_data) + extract_data = tl.load(extract_load_ptr) + tl.store(extract_index + pid * 2 + 1, extract_data) + + else: + extract_data = tl.load(extract_load_ptr) + tl.store(extract_index + pid * 2, extract_data) + +@triton.jit +def create_extend_spec_info(verified_id, seq_len, accept_len, accept_len_cum, positions, new_verified_id, accept_len_upper: tl.constexpr): + pid = tl.program_id(axis=0) + offset = 0 if pid ==0 else tl.load(accept_len_cum+pid-1) + seq_length = tl.load(seq_len+pid) + accept_length = tl.load(accept_len+pid) + positions_ptr = positions+offset + data = tl.arange(0, accept_len_upper) + mask = data < accept_length + tl.store(positions_ptr+data, seq_length-accept_length+data, mask) + + offset = tl.load(accept_len_cum+pid)-1 + verified_id_data = tl.load(verified_id+offset) + tl.store(new_verified_id+pid, verified_id_data) + + +@triton.jit +def assign_req_to_token_pool(req_pool_indices, req_to_token, start_offset, end_offset, out_cache_loc, pool_len: tl.constexpr, bs_upper: tl.constexpr): + BLOCK_SIZE: tl.constexpr = 128 + pid = tl.program_id(axis=0) + kv_start = tl.load(start_offset+pid) + kv_end = tl.load(end_offset+pid) + token_pool = req_to_token + tl.load(req_pool_indices+pid) * pool_len + + length_offset = tl.arange(0, bs_upper) + start = tl.load(start_offset+length_offset, mask=length_offset 0: + batch.req_to_token_pool.req_to_token[req.req_pool_idx][ + :pre_len + ] = req.prefix_indices + + batch.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = ( + out_cache_loc[pt : pt + req.extend_input_len] + ) + + pt += req.extend_input_len + + seq_lens = [0] + batch.extend_lens + input_ids = batch.input_ids.tolist() + verified_id = batch.spec_info.verified_id.tolist() + model_input_ids = [] + for i in range(len(seq_lens) - 1): + model_input_ids.extend( + input_ids[seq_lens[i] + 1 : seq_lens[i + 1]] + [verified_id[i]] + ) + batch.input_ids = torch.tensor( + model_input_ids, dtype=torch.int32, device="cuda" + ) + + def capture_for_decode(self, sample_output: SampleOutput, hidden_states: torch.Tensor, + prev_mode: ForwardMode): + self.sample_output = sample_output + self.prev_mode = prev_mode + self.hidden_states = hidden_states + + def prepare_for_decode(self, batch: 
ScheduleBatch): + prob = self.sample_output # b * (1/topk), vocab + top = torch.topk(prob, self.topk, dim=-1) + topk_index, topk_p = top.indices, top.values # b * (1/topk), topk + if self.prev_mode == ForwardMode.DECODE: + scores = torch.mul( + self.scores.unsqueeze(2), topk_p.reshape(-1, self.topk, self.topk) + ) # (b, topk) mul (b * topk ,topk) -> b, topk, topk + topk_cs = torch.topk( + scores.flatten(start_dim=1), self.topk, dim=-1 + ) # (b, topk) + topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values + self.scores = topk_cs_p + + selected_input_index = topk_cs_index.flatten() // self.topk # b* topk + + batch.spec_info.hidden_states = batch.spec_info.hidden_states[ + selected_input_index, : + ] + topk_index = topk_index.reshape(-1, self.topk**2) + batch.input_ids = torch.gather( + topk_index, index=topk_cs_index, dim=1 + ).flatten() + batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel()) + self.score_list.append(scores) # b, topk, topk + self.token_list.append(topk_index) # b, topk*topk + self.origin_score_list.append(topk_p.reshape(topk_index.shape)) + self.parents_list.append( + topk_cs_index + (self.topk**2 * (self.iter - 1) + self.topk) + ) # b, topk + + elif self.prev_mode in (ForwardMode.EXTEND, ForwardMode.SPECEXTEND) : + self.scores = topk_p # b, top_k + self.score_list.append(topk_p.unsqueeze(1)) + self.token_list.append(topk_index) + self.origin_score_list.append(topk_p) + batch.spec_info.hidden_states = ( + batch.spec_info.hidden_states.repeat_interleave(self.topk, 0) + ) + batch.input_ids = topk_index.flatten() + batch.out_cache_loc = batch.alloc_token_slots(topk_index.numel()) + self.parents_list.append( + torch.arange(-1, self.topk, dtype=torch.long, device="cuda").unsqueeze(0).repeat(self.scores.shape[0], 1) + ) # b, topk+1 + self.cache_list.append(batch.out_cache_loc) + self.positions = ( + batch.seq_lens[:, None] + + torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter + ).flatten() + + bs = batch.seq_lens.numel() + assign_req_to_token_pool[(bs, )](batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + batch.seq_lens+self.topk*self.iter, + batch.seq_lens+self.topk*(self.iter+1), + batch.out_cache_loc, + batch.req_to_token_pool.req_to_token.shape[1], + triton.next_power_of_2(bs) + ) + self.iter += 1 + + def prepare_extend_after_decode(self, batch: ScheduleBatch): + batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel()) + batch.extend_lens = (self.accept_length+1).tolist() + + pt=0 + seq_lens = batch.seq_lens.tolist() + + i = 0 + # TODO: Chage it to triton kernel @kavioyu + for req in batch.reqs: + if req.finished(): + continue + #assert seq_len - pre_len == req.extend_input_len + input_len = self.accept_length[i] + 1 + seq_len = seq_lens[i] + batch.req_to_token_pool.req_to_token[req.req_pool_idx][seq_len-input_len:seq_len] = ( + batch.out_cache_loc[pt : pt + input_len] + ) + pt += input_len + i += 1 + + self.positions = torch.empty_like(self.verified_id) + new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long) + self.accept_length.add_(1) + + create_extend_spec_info[(self.accept_length.numel(),)](self.verified_id, batch.seq_lens, + self.accept_length, torch.cumsum(self.accept_length, axis=0, dtype=torch.int), + self.positions, new_verified_id, triton.next_power_of_2(self.spec_steps+1)) + + batch.input_ids = self.verified_id + self.verified_id = new_verified_id + + + def prepare_for_verify(self, batch: ScheduleBatch): + score_list = torch.cat(self.score_list, dim=1).flatten(1) # b, n, 
topk; n= 1+(self.iter-1)*self.topk + ss_token_list = torch.cat(self.token_list, dim=1) # b, (self.topk+(self.iter-1)*self.topk) + origin_token_list = torch.cat(self.origin_score_list, dim=1) + top_scores = torch.topk(score_list, self.num_verify_token - 1, dim=-1) + top_scores_index = top_scores.indices + top_scores_index = torch.sort(top_scores_index).values + + draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1) + scores = torch.gather(origin_token_list, index=top_scores_index, dim=1) + draft_tokens = torch.cat((self.verified_id.unsqueeze(1), draft_tokens), dim=1) + parent_list = torch.cat(self.parents_list[:-1], dim=1) + + tree_mask, position, retrive_index, retrive_cum_len = build_tree_kernel( + parent_list, + top_scores_index, + batch.seq_lens, + self.topk, + self.iter - 1, + self.num_verify_token, + ) + + return EagleVerifyInput( + draft_tokens.flatten(), + scores.flatten(), + tree_mask, + position, + retrive_index, + retrive_cum_len, + self.num_verify_token, + ) + + def generate_attn_arg( + self, + req_pool_indices: torch.Tensor, + paged_kernel_lens: torch.Tensor, + req_to_token: torch.Tensor, + ): + seq_num = req_pool_indices.numel() + bs = self.topk * req_pool_indices.numel() + seq_len = self.positions.reshape(-1).contiguous() + + cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") + cum_kv_seq_len[1:] = torch.cumsum(seq_len + 1, dim=0) + kv_last_page_len = torch.ones((bs,), dtype=torch.int32, device="cuda") + total_len = torch.sum(paged_kernel_lens).item() + + kv_indices = torch.empty((total_len * self.topk + seq_num*self.iter*self.topk, ), + dtype=torch.int32, device='cuda') + + generate_draft_decode_kv_indices[(req_pool_indices.numel(), self.topk)](req_pool_indices, req_to_token, + paged_kernel_lens, kv_indices, self.iter, self.topk, + req_to_token.shape[1], triton.next_power_of_2(seq_num), + triton.next_power_of_2(self.spec_steps)) + return kv_indices, cum_kv_seq_len, kv_last_page_len, None + + def clear(self): + self.iter = 0 + self.score_list.clear() + self.positions = None + + def clear_draft_cache(self, batch): + draft_cache = torch.cat(self.cache_list, dim=0) + batch.token_to_kv_pool.free(draft_cache) + + def generate_attn_arg_spec_extend( + self, + req_pool_indices: torch.Tensor, + paged_kernel_lens: torch.Tensor, + req_to_token: torch.Tensor, + ): + bs = self.accept_length.numel() + qo_indptr = torch.zeros( + (bs + 1,), dtype=torch.int32, device="cuda" + ) + qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0) + + cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") + cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) + + kv_last_page_len = torch.ones((bs,), dtype=torch.int32, device="cuda") + + kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda") + + create_flashinfer_kv_indices_triton[(bs,)]( + req_to_token, + req_pool_indices, + paged_kernel_lens, + cum_kv_seq_len, + None, + kv_indices, + req_to_token.size(1), + ) + + return kv_indices, cum_kv_seq_len, kv_last_page_len, qo_indptr + + def merge_batch(self, spec_info: EAGLEDraftInput): + + self.hidden_states = torch.cat([self.hidden_states, spec_info.hidden_states], axis=0) + self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0) + #self.positions = torch.cat([self.positions, spec_info.positions], axis=0) + self.sample_output = torch.cat([self.sample_output, spec_info.sample_output]) + +@DraftInfoFactory.register("EAGLE", "VerifyInput") +class EagleVerifyInput(SpecVerifyInput): + def __init__( 
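The metadata assembled in generate_attn_arg_spec_extend above is the standard ragged-batch layout for the flashinfer wrappers: exclusive prefix sums over per-request query and KV lengths. A standalone illustration with made-up lengths for three requests:

    import torch

    accept_length = torch.tensor([3, 1, 2], dtype=torch.int32)  # queries per request
    kv_lens = torch.tensor([17, 9, 25], dtype=torch.int32)      # visible KV entries per request
    bs = accept_length.numel()

    qo_indptr = torch.zeros(bs + 1, dtype=torch.int32)
    qo_indptr[1:] = torch.cumsum(accept_length, dim=0)          # [0, 3, 4, 6]

    kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(kv_lens, dim=0)                # [0, 17, 26, 51]

    # Request i's queries occupy rows qo_indptr[i]:qo_indptr[i+1], and they attend
    # to the cache slots listed in kv_indices[kv_indptr[i]:kv_indptr[i+1]].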
+ self, + draft_token: torch.Tensor, + draft_score: torch.Tensor, + tree_mask: torch.Tensor, + positions: torch.Tensor, + retrive_index: torch.Tensor, + retrive_cum_len: torch.Tensor, + draft_token_num: int, + ): + self.draft_token = draft_token + self.draft_score = draft_score + self.custom_mask = tree_mask + self.positions = positions + self.retrive_index = retrive_index + self.retrive_cum_len = retrive_cum_len + self.draft_token_num = draft_token_num + + def prepare_for_verify(self, batch: ScheduleBatch): + batch.input_ids = self.draft_token + batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel()) + bs = batch.seq_lens.numel() + assign_req_to_token_pool[(bs, )](batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + batch.seq_lens, + batch.seq_lens+self.draft_token_num, + batch.out_cache_loc, + batch.req_to_token_pool.req_to_token.shape[1], + triton.next_power_of_2(bs) + ) + + def generate_attn_arg( + self, + req_pool_indices: torch.Tensor, + paged_kernel_lens: torch.Tensor, + req_to_token: torch.Tensor, + ): + batch_size = len(req_pool_indices) + qo_indptr = torch.arange( + 0, + (1 + batch_size) * self.draft_token_num, + step=self.draft_token_num, + dtype=torch.int32, + device="cuda", + ) + + cum_kv_seq_len = torch.zeros( + (batch_size + 1,), dtype=torch.int32, device="cuda" + ) + + paged_kernel_lens = paged_kernel_lens + self.draft_token_num + cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) + + kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda") + + kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda") + + create_flashinfer_kv_indices_triton[(batch_size,)]( + req_to_token, + req_pool_indices, + paged_kernel_lens, + cum_kv_seq_len, + None, + kv_indices, + req_to_token.size(1), + ) + return kv_indices, cum_kv_seq_len, kv_last_page_len, qo_indptr + + def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor: + predict = torch.argmax(logits_output.next_token_logits_bak, dim=-1) + predict = torch.cat([predict, torch.full([1], -1, dtype=torch.long, device='cuda')], dim=-1) + draft_token = torch.cat([self.draft_token, torch.full([1], -1, dtype=torch.long, device='cuda')], dim=-1) + target_predict = predict[self.retrive_index] + candidates = draft_token[self.retrive_index] + # logits = logits_output.next_token_logits[self.retrive_index] + # target_predict = torch.argmax(logits[:, :-1], dim=-1) + accept_mask = candidates[:, 1:] == target_predict[:, :-1] + accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1) + bs = self.retrive_cum_len.numel() - 1 + + max_draft_len = self.retrive_index.shape[-1] + accept_index = torch.full( + (bs, max_draft_len), -1, dtype=torch.long, device="cuda" + ) + accept_length = torch.empty((bs,), dtype=torch.int, device="cuda") + extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda") + eagle_verify_retrive[(bs,)]( + self.retrive_index.contiguous(), + accept_mask.contiguous(), + self.retrive_cum_len, + accept_index, + accept_length, + extract_index, + max_draft_len, + self.draft_token_num, + triton.next_power_of_2(max_draft_len), + ) + + accept_index = accept_index[accept_index != -1] + #extract_index = extract_index[extract_index != 0] + + draft_input = EAGLEDraftInput() + + accept_length_cpu = accept_length.tolist() + verified_id = predict[accept_index] + verified_id_cpu = verified_id.tolist() + + evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool) + evict_mask[accept_index] = False + mem_need_free_idx = 
batch.out_cache_loc[evict_mask] + batch.token_to_kv_pool.free(mem_need_free_idx) + batch.seq_lens.add_(accept_length+1) + + assign_req_to_token_pool[(bs, )](batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + batch.seq_lens, + batch.seq_lens+accept_length+1, + batch.out_cache_loc[accept_index], + batch.req_to_token_pool.req_to_token.shape[1], + triton.next_power_of_2(bs) + ) + + new_accept_index = [] + unfinished_index = [] + finished_extend_len = {} # {rid:accept_length + 1} + #retracted_reqs, new_token_ratio = batch.retract_decode() + + low = 0 + for i, (req, verified_len) in enumerate(zip(batch.reqs, accept_length_cpu)): + req.output_ids.extend(verified_id_cpu[low : low + verified_len + 1]) + req.check_finished() + if req.finished(): + draft_input.has_finished = True + finished_extend_len[req.rid] = verified_len+1 + else: + new_accept_index.append(accept_index[low: low+verified_len+1]) + unfinished_index.append(i) + low += verified_len + 1 + + if len(new_accept_index)>0: + new_accept_index = torch.cat(new_accept_index, dim=0) + draft_input.verified_id = predict[new_accept_index] + draft_input.hidden_states = batch.spec_info.hidden_states[ + new_accept_index + ] + draft_input.accept_length = accept_length[unfinished_index] + draft_input.unfinished_index = unfinished_index + + logits_output.next_token_logits = logits_output.next_token_logits_bak[accept_index] + return draft_input, logits_output, verified_id, finished_extend_len + diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py new file mode 100644 index 00000000000..d8f0aa440d7 --- /dev/null +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -0,0 +1,132 @@ +import torch +from typing import Union, List, Optional +from sglang.srt.server_args import ServerArgs +from sglang.srt.managers.tp_worker import TpModelWorker +from sglang.srt.speculative.speculative_worker import SpeculativeWorker, spec_worker_factory +from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch, Req +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.speculative.eagle_utils import EAGLEDraftInput, EagleVerifyInput +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.model_executor.model_runner import ModelRunner + +@spec_worker_factory.register('EAGLE') +class EAGLEWorker(SpeculativeWorker): + def __init__( + self, + server_args: ServerArgs, + gpu_id: int, + tp_rank: int, + dp_rank: Optional[int], + nccl_port: int, + target_worker: TpModelWorker + ): + disable_cuda_graph = server_args.disable_cuda_graph + server_args.disable_cuda_graph = True + super().__init__(gpu_id=gpu_id, tp_rank=tp_rank, server_args=server_args, nccl_port=nccl_port, target_worker=target_worker, dp_rank=dp_rank) + embed, head = self.target_worker.model_runner.model.get_embed_and_head() + self.model_runner.model.set_embed_and_head(embed, head) + self.model_runner.server_args.disable_cuda_graph = disable_cuda_graph + self.model_runner.init_cuda_graphs() + self.finish_extend_len = None + + def forward_draft_decode(self, batch: ScheduleBatch): + batch.spec_info.prepare_for_decode(batch) + model_worker_batch = batch.get_model_worker_batch() + forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) + forward_batch.is_draft_batch = True + logits_output = self.model_runner.forward(forward_batch) + self.capture_for_decode(logits_output, forward_batch) + + def forward_draft_extend(self, batch: ScheduleBatch): + 
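The acceptance rule applied in EagleVerifyInput.verify above is greedy path matching: along each retrieved path of the draft tree, a drafted token is accepted while it equals the target model's argmax at the previous position, and the cumulative product stops counting at the first mismatch. A small self-contained example with three candidate paths of one request (token ids are arbitrary):

    import torch

    # Column 0 is the already-verified root token; columns 1+ are drafted tokens.
    candidates = torch.tensor([[11, 23, 42,  7],
                               [11, 23, 19,  5],
                               [11, 40, 42,  7]])
    # Target model's greedy prediction at every position of each path.
    target_predict = torch.tensor([[23, 42,  7, 99],
                                   [23, 42,  7, 99],
                                   [23, 42,  7, 99]])

    accept_mask = (candidates[:, 1:] == target_predict[:, :-1]).long()
    accept_len = torch.cumprod(accept_mask, dim=1).sum(dim=1)
    print(accept_len)  # tensor([3, 1, 0]) -> the first path accepts 3 extra tokens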
self._swap_mem_pool(batch, self.model_runner) + batch.spec_info.prepare_for_extend(batch) + model_worker_batch = batch.get_model_worker_batch() + forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) + logits_output = self.model_runner.forward(forward_batch) + self.capture_for_decode(logits_output, forward_batch) + self._swap_mem_pool(batch, self.target_worker.model_runner) + + def forward_batch_speculative_generate(self, batch: ScheduleBatch): + if batch.forward_mode.is_decode(): + self._swap_mem_pool(batch, self.model_runner) + for i in range(self.server_args.num_speculative_steps): + self.forward_draft_decode(batch) + batch.spec_info.clear_draft_cache(batch) + self._swap_mem_pool(batch, self.target_worker.model_runner) + next_draft_input, logits_output, verified_id, self.finish_extend_len, model_worker_batch = self.verify(batch) + next_draft_input.init(self.server_args) + batch.spec_info = next_draft_input + # if it is None, means all requsets are finished + if batch.spec_info.verified_id is not None: + self.forward_extend_after_decode(batch) + torch.cuda.synchronize() + + return logits_output, verified_id, model_worker_batch + + else: + batch.spec_info = EAGLEDraftInput() + batch.spec_info.init(self.server_args) + model_worker_batch = batch.get_model_worker_batch() + logits_output, next_token_ids = self.target_worker.forward_batch_generation(model_worker_batch) + model_worker_batch.spec_info.verified_id = next_token_ids + self.forward_draft_extend(batch) + return logits_output, next_token_ids, model_worker_batch + + def verify(self, batch: ScheduleBatch): + verify_input = batch.spec_info.prepare_for_verify(batch) + batch.forward_mode = ForwardMode.SPECVERIFY + verify_input.prepare_for_verify(batch) + batch.spec_info = verify_input + model_worker_batch = batch.get_model_worker_batch() + logits_output, _ = self.target_worker.forward_batch_generation(model_worker_batch, need_token_id=False) + verify_input.hidden_states = logits_output.hidden_states + res = verify_input.verify(batch, logits_output) + batch.forward_mode = ForwardMode.DECODE + return res + (model_worker_batch,) + + def _swap_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner): + batch.token_to_kv_pool = runner.token_to_kv_pool + batch.req_to_token_pool = runner.req_to_token_pool + + def forward_extend_after_decode(self, batch: ScheduleBatch): + self._swap_mem_pool(batch, self.model_runner) + batch.forward_mode = ForwardMode.SPECEXTEND + if batch.spec_info.has_finished: + index = batch.spec_info.unfinished_index + seq_lens = batch.seq_lens + batch.seq_lens = batch.seq_lens[index] + batch.spec_info.prepare_extend_after_decode(batch) + model_worker_batch = batch.get_model_worker_batch() + forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) + + forward_batch.is_draft_batch = True + logits_output = self.model_runner.forward(forward_batch) + torch.cuda.synchronize() + self.capture_for_decode(logits_output, forward_batch) + batch.forward_mode = ForwardMode.DECODE + if batch.spec_info.has_finished: + batch.seq_lens = seq_lens + self._swap_mem_pool(batch, self.target_worker.model_runner) + + def capture_for_decode(self, logits_output, forward_batch): + if isinstance(logits_output, LogitsProcessorOutput): + logits = logits_output.next_token_logits + sample_output = torch.softmax( + logits, dim=-1 + ) # TODO: Support more sampling method @kavioyu + forward_batch.spec_info.capture_for_decode( + sample_output, logits_output.hidden_states, forward_batch.forward_mode + ) + + # Don't 
support prefix share now. + def finish_request(self, reqs: Union[Req, List[Req]]): + if not isinstance(reqs, List): + reqs = [reqs] + for req in reqs: + req_len = len(req.origin_input_ids) + len(req.output_ids) - self.finish_extend_len[req.rid] - 1 + kv_indices = self.model_runner.req_to_token_pool.req_to_token[ + req.req_pool_idx + ][:req_len] + self.model_runner.token_to_kv_pool.free(kv_indices) + self.model_runner.req_to_token_pool.free(req.req_pool_idx) + \ No newline at end of file diff --git a/python/sglang/srt/speculative/speculative_utils.py b/python/sglang/srt/speculative/speculative_utils.py new file mode 100644 index 00000000000..7de6fc1e5a1 --- /dev/null +++ b/python/sglang/srt/speculative/speculative_utils.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, List, Type + +import torch +import triton +import triton.language as tl + +from .build_eagle_tree import build_tree_kernel +from sglang.srt.model_executor.forward_batch_info import ForwardMode, ForwardBatch + +if TYPE_CHECKING: + from python.sglang.srt.layers.sampler import SampleOutput + from python.sglang.srt.managers.schedule_batch import ScheduleBatch + from sglang.srt.mem_cache.memory_pool import ReqToTokenPool + from sglang.srt.server_args import ServerArgs + + +class SpecInput: + pass + +class SpecVerifyInput(SpecInput): + pass + +class SpecDraftInput(SpecInput): + def prepare_for_extend(self, batch): + raise NotImplementedError() + + def prepare_for_decode(self, batch): + raise NotImplementedError() + + def generate_attn_arg( + self, + req_pool_indices: List, + paged_kernel_lens: List, + req_to_token_pool: ReqToTokenPool, + ): + raise NotImplementedError() + + def clear(): + pass + + def merge_batch(self, batch: SpecDraftInput): + raise NotImplementedError() + + +class SpecInfoFactory: + def __init__(self): + self.factory = {} + + def register(self, alg_name: str, type_name: str) -> SpecInput: + def wrapper(info: Type[SpecInput]) -> Type[SpecInput]: + assert type_name in ['DraftInput', 'VerifyInput'] + if alg_name not in self.factory: + self.factory[alg_name] = {} + self.factory[alg_name].update({type_name: info}) + return info + + return wrapper + + def get(self, alg_name, type_name: str): + if alg_name is None: + return None + return self.factory[alg_name][type_name] + + +DraftInfoFactory = SpecInfoFactory() + + diff --git a/python/sglang/srt/speculative/speculative_worker.py b/python/sglang/srt/speculative/speculative_worker.py new file mode 100644 index 00000000000..48dba1d50c6 --- /dev/null +++ b/python/sglang/srt/speculative/speculative_worker.py @@ -0,0 +1,41 @@ +from typing import Type, Union, List, Optional +from sglang.srt.server_args import ServerArgs +from sglang.srt.managers.tp_worker import TpModelWorker +from sglang.srt.managers.schedule_batch import ScheduleBatch, Req + + +class SpeculativeWorker(TpModelWorker): + is_draft_worker = True + def __init__( + self, + server_args: ServerArgs, + gpu_id: int, + tp_rank: int, + dp_rank: Optional[int], + nccl_port: int, + target_worker: TpModelWorker + ): + super().__init__(gpu_id=gpu_id, tp_rank=tp_rank, server_args=server_args, nccl_port=nccl_port, dp_rank=dp_rank) + self.target_worker = target_worker + + def forward_batch_speculative_generate(self, batch: ScheduleBatch): + raise NotImplementedError() + + def finish_request(self, reqs: Union[Req, List[Req]]): + raise NotImplementedError() + +class SpecWorkerFactory: + def __init__(self): + self.factory = {} + + def register(self, name: str) -> SpeculativeWorker: + def 
wrapper(info: Type[SpeculativeWorker]) -> Type[SpeculativeWorker]: + self.factory[name] = info + return info + + return wrapper + + def get(self, name): + return self.factory[name] + +spec_worker_factory = SpecWorkerFactory() \ No newline at end of file
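Both DraftInfoFactory and spec_worker_factory rely on the same decorator-registry pattern: classes register themselves under an algorithm name at import time, and callers resolve them later by the configured speculative_algorithm. A condensed, self-contained illustration (the Registry and DummyEagleWorker names are placeholders):

    class Registry:
        def __init__(self):
            self._entries = {}

        def register(self, name):
            def wrapper(cls):
                self._entries[name] = cls
                return cls
            return wrapper

        def get(self, name):
            return self._entries[name]

    spec_worker_registry = Registry()

    @spec_worker_registry.register("EAGLE")
    class DummyEagleWorker:
        pass

    assert spec_worker_registry.get("EAGLE") is DummyEagleWorker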