From 70135d6440773f4b82f978eeaac4fe182a721468 Mon Sep 17 00:00:00 2001 From: kavioyu Date: Sun, 13 Oct 2024 15:49:56 +0800 Subject: [PATCH 01/26] temp --- .../runtime/engine/offline_batch_inference.py | 10 +- .../sglang/srt/layers/attention/__init__.py | 4 +- .../layers/attention/flashinfer_backend.py | 1 + .../srt/layers/attention/flashinfer_utils.py | 45 +- python/sglang/srt/layers/logits_processor.py | 5 + python/sglang/srt/managers/schedule_batch.py | 20 +- python/sglang/srt/managers/scheduler.py | 27 +- python/sglang/srt/managers/tp_worker.py | 15 +- .../srt/model_executor/forward_batch_info.py | 20 +- .../sglang/srt/model_executor/model_runner.py | 18 +- python/sglang/srt/models/llama.py | 7 +- python/sglang/srt/models/llama_eagle.py | 438 +++++++++++++++++ python/sglang/srt/server_args.py | 49 +- python/sglang/srt/speculative/__init__.py | 1 + .../srt/speculative/build_egale_tree.py | 352 ++++++++++++++ python/sglang/srt/speculative/eagle_worker.py | 55 +++ .../srt/speculative/speculative_utils.py | 442 ++++++++++++++++++ .../srt/speculative/speculative_worker.py | 48 ++ 18 files changed, 1510 insertions(+), 47 deletions(-) create mode 100644 python/sglang/srt/models/llama_eagle.py create mode 100644 python/sglang/srt/speculative/__init__.py create mode 100644 python/sglang/srt/speculative/build_egale_tree.py create mode 100644 python/sglang/srt/speculative/eagle_worker.py create mode 100644 python/sglang/srt/speculative/speculative_utils.py create mode 100644 python/sglang/srt/speculative/speculative_worker.py diff --git a/examples/runtime/engine/offline_batch_inference.py b/examples/runtime/engine/offline_batch_inference.py index 7404c7e4e7f..e644f32b057 100644 --- a/examples/runtime/engine/offline_batch_inference.py +++ b/examples/runtime/engine/offline_batch_inference.py @@ -4,17 +4,13 @@ def main(): # Sample prompts. prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", + "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: Where is the capital city of France? ASSISTANT:" ] # Create a sampling params object. - sampling_params = {"temperature": 0.8, "top_p": 0.95} + sampling_params = {"temperature": 0, "max_new_tokens": 8} # Create an LLM. - llm = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct") - + llm = sgl.Engine(model_path="Llama-2-7b-chat-hf", draft_model_path='EAGLE-llama2-chat-7B', disable_cuda_graph=True, num_speculative_steps=5, num_draft_tokens=64, speculative_algorithm='EAGLE', mem_fraction_static=0.60) outputs = llm.generate(prompts, sampling_params) # Print the outputs. 
for prompt, output in zip(prompts, outputs): diff --git a/python/sglang/srt/layers/attention/__init__.py b/python/sglang/srt/layers/attention/__init__.py index 4cad3d8aa38..0ba039c320b 100644 --- a/python/sglang/srt/layers/attention/__init__.py +++ b/python/sglang/srt/layers/attention/__init__.py @@ -35,7 +35,9 @@ def get_cuda_graph_seq_len_fill_value(self): def forward(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): """Run forward on an attention layer.""" - if forward_batch.forward_mode.is_decode(): + if forward_batch.forward_mode.is_verify(): + return self.forward_extend(q, k, v, layer, forward_batch) + elif forward_batch.forward_mode.is_decode(): return self.forward_decode(q, k, v, layer, forward_batch) else: return self.forward_extend(q, k, v, layer, forward_batch) diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 0c9ca8f9d40..e04a2caddbe 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -133,6 +133,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): forward_batch.seq_lens, prefix_lens, use_ragged=use_ragged, + spec_info=forward_batch.spec_info, ) self.forward_metadata = ( diff --git a/python/sglang/srt/layers/attention/flashinfer_utils.py b/python/sglang/srt/layers/attention/flashinfer_utils.py index 796203c933c..6a5838a05c7 100644 --- a/python/sglang/srt/layers/attention/flashinfer_utils.py +++ b/python/sglang/srt/layers/attention/flashinfer_utils.py @@ -56,6 +56,7 @@ def __init__( prefix_lens, decode_wrappers=None, use_ragged=False, + spec_info=None, ): self.forward_mode = forward_mode self.model_runner = model_runner @@ -63,6 +64,7 @@ def __init__( self.seq_lens = seq_lens self.prefix_lens = prefix_lens self.use_ragged = use_ragged + self.spec_info = spec_info self.num_qo_heads = ( model_runner.model_config.num_attention_heads // model_runner.tp_size @@ -162,23 +164,32 @@ def _get_indices(self, dispatch_reason: WrapperDispatch = None, wrapper_id=0): paged_kernel_lens = self.seq_lens self.kv_start_idx = self.seq_lens - paged_kernel_lens - self.kv_indptr = torch.zeros( - (self.batch_size + 1,), dtype=torch.int32, device="cuda" - ) - self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0) - self.kv_indices = torch.empty( - self.kv_indptr[-1], dtype=torch.int32, device="cuda" - ) + if self.spec_info is not None and self.forward_mode.is_decode(): + self.kv_indices, self.kv_indptr, self.kv_last_page_len, self.qo_indptr = ( + self.spec_info.generate_attn_arg( + self.req_pool_indices, + paged_kernel_lens, + self.model_runner.req_to_token_pool, + ) + ) + else: + self.kv_indptr = torch.zeros( + (self.batch_size + 1,), dtype=torch.int32, device="cuda" + ) + self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0) + self.kv_indices = torch.empty( + self.kv_indptr[-1], dtype=torch.int32, device="cuda" + ) - create_flashinfer_kv_indices_triton[(self.batch_size,)]( - self.model_runner.req_to_token_pool.req_to_token, - self.req_pool_indices, - paged_kernel_lens, - self.kv_indptr, - self.kv_start_idx, - self.kv_indices, - self.model_runner.req_to_token_pool.req_to_token.size(1), - ) + create_flashinfer_kv_indices_triton[(self.batch_size,)]( + self.model_runner.req_to_token_pool.req_to_token, + self.req_pool_indices, + paged_kernel_lens, + self.kv_indptr, + self.kv_start_idx, + self.kv_indices, + self.model_runner.req_to_token_pool.req_to_token.size(1), + ) def 
_update_indicess_single_wrapper(self): self._get_indices() @@ -215,6 +226,7 @@ def update_flashinfer_indices( prefix_lens, decode_wrappers=None, use_ragged=False, + spec_info=None ): updater = FlashinferUpdater( forward_mode, @@ -224,6 +236,7 @@ def update_flashinfer_indices( prefix_lens, decode_wrappers, use_ragged, + spec_info, ) dispatch_reason = model_runner.attn_backend.dispatch_reason diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index f0c55af6255..00b113d52ed 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -168,7 +168,9 @@ def forward( weight, logits_metadata: Union[LogitsMetadata, ForwardBatch], ): + spec_info = None if isinstance(logits_metadata, ForwardBatch): + spec_info = getattr(logits_metadata, 'spec_info', None) logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata) assert isinstance(logits_metadata, LogitsMetadata) @@ -180,6 +182,9 @@ def forward( last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1 last_hidden = hidden_states[last_index] + if spec_info: + spec_info.hidden_states = last_hidden + last_logits = torch.matmul(last_hidden, weight.T) if self.do_tensor_parallel_all_gather: last_logits = tensor_model_parallel_all_gather(last_logits) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 156e830d185..58baec970a3 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1,3 +1,5 @@ +from __future__ import annotations + """ Copyright 2023-2024 SGLang Team Licensed under the Apache License, Version 2.0 (the "License"); @@ -29,7 +31,7 @@ import logging from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import TYPE_CHECKING, List, Optional, Tuple, Union import torch @@ -44,6 +46,9 @@ from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import ServerArgs +if TYPE_CHECKING: + from sglang.srt.speculative.speculative_utils import SpecInput + INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 # Put some global args for easy access @@ -428,9 +433,13 @@ class ScheduleBatch: # Has regex has_regex: bool = False + + # speculative decoding + spec_info: SpecInput = None + spec_algorithm: str = None @classmethod - def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache): + def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache, speculative_algorithm): return_logprob = any(req.return_logprob for req in reqs) has_stream = any(req.stream for req in reqs) has_regex = any(req.regex_fsm for req in reqs) @@ -444,6 +453,7 @@ def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache): has_stream=has_stream, device=req_to_token_pool.device, has_regex=has_regex, + spec_algorithm=speculative_algorithm ) def batch_size(self): @@ -827,6 +837,8 @@ def get_model_worker_batch(self): image_inputs=image_inputs, lora_paths=lora_paths, sampling_info=self.sampling_info, + spec_algorithm=self.spec_algorithm, + spec_info=self.spec_info ) @@ -860,3 +872,7 @@ class ModelWorkerBatch: # Sampling info sampling_info: SamplingBatchInfo + + # Speclulative decoding + spec_algorithm: str = None + spec_info: SpecInput = None diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index c6df4a2e81f..a995eee1add 100644 --- a/python/sglang/srt/managers/scheduler.py +++ 
b/python/sglang/srt/managers/scheduler.py @@ -56,6 +56,7 @@ PrefillAdder, SchedulePolicy, ) +from sglang.srt.speculative.speculative_worker import spec_worker_factory from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.radix_cache import RadixCache @@ -145,7 +146,17 @@ def __init__( nccl_port=port_args.nccl_port, ) self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group - + + # Launch Speculative worker if need + if self.server_args.speculative_algorithm is not None: + self.draft_worker = spec_worker_factory.get(self.server_args.speculative_algorithm)( + gpu_id=gpu_id, + tp_rank=tp_rank, + server_args=server_args, + nccl_port=port_args.nccl_port, + target_worker=self.tp_worker + ) + # Get token and memory info from the model worker ( self.max_total_num_tokens, @@ -594,6 +605,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: self.req_to_token_pool, self.token_to_kv_pool, self.tree_cache, + self.server_args.speculative_algorithm ) new_batch.prepare_for_extend(self.model_config.vocab_size) @@ -644,10 +656,15 @@ def get_new_batch_decode(self) -> Optional[ScheduleBatch]: def run_batch(self, batch: ScheduleBatch): if self.is_generation: if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0: - model_worker_batch = batch.get_model_worker_batch() - logits_output, next_token_ids = self.tp_worker.forward_batch_generation( - model_worker_batch - ) + if self.server_args.speculative_algorithm: + logits_output, next_token_ids, spec_info = self.draft_worker.forward_batch_speculative_generate( + batch + ) + else: + model_worker_batch = batch.get_model_worker_batch() + logits_output, next_token_ids = self.tp_worker.forward_batch_generation( + model_worker_batch + ) else: logits_output = None if self.tokenizer is not None: diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 73c4abe0864..661001130d1 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -1,3 +1,5 @@ +from __future__ import annotations + """ Copyright 2023-2024 SGLang Team Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,6 +17,8 @@ """A tensor parallel worker.""" +from typing import TYPE_CHECKING + import json import logging @@ -27,12 +31,14 @@ from sglang.srt.server_args import ServerArgs from sglang.srt.utils import broadcast_pyobj, is_multimodal_model, set_random_seed +if TYPE_CHECKING: + from sglang.srt.speculative.speculative_worker import SpeculativeWorker + logger = logging.getLogger(__name__) class TpModelWorker: """A tensor parallel model worker.""" - def __init__( self, gpu_id: int, @@ -42,14 +48,17 @@ def __init__( ): # Parse args self.tp_rank = tp_rank + self.server_args = server_args + is_draft_worker = getattr(self, 'is_draft_worker', False) # Init model and tokenizer self.model_config = ModelConfig( - server_args.model_path, + server_args.model_path if not is_draft_worker else server_args.draft_model_path, server_args.trust_remote_code, context_length=server_args.context_length, model_override_args=json.loads(server_args.json_model_override_args), ) + self.model_runner = ModelRunner( model_config=self.model_config, mem_fraction_static=server_args.mem_fraction_static, @@ -58,6 +67,7 @@ def __init__( tp_size=server_args.tp_size, nccl_port=nccl_port, server_args=server_args, + is_draft_runner=is_draft_worker ) if server_args.skip_tokenizer_init: self.tokenizer = self.processor = None @@ -113,6 
+123,7 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) logits_output = self.model_runner.forward(forward_batch) next_token_ids = self.model_runner.sample(logits_output, model_worker_batch) + model_worker_batch.spec_info = forward_batch.spec_info return logits_output, next_token_ids def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch): diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index d76b981d546..b7cc1a6799b 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -42,7 +42,7 @@ from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo - + from sglang.srt.speculative.speculative_utils import SpecInput class ForwardMode(IntEnum): # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case. @@ -53,6 +53,8 @@ class ForwardMode(IntEnum): DECODE = auto() # Contains both EXTEND and DECODE. MIXED = auto() + # Speculative Verify stage + SPECVERIFY = auto() def is_prefill(self): return self == ForwardMode.PREFILL @@ -61,10 +63,13 @@ def is_extend(self): return self == ForwardMode.EXTEND or self == ForwardMode.MIXED def is_decode(self): - return self == ForwardMode.DECODE + return self in (ForwardMode.DECODE, ForwardMode.SPECVERIFY) def is_mixed(self): return self == ForwardMode.MIXED + + def is_verify(self): + return self == ForwardMode.SPECVERIFY @dataclass @@ -111,6 +116,10 @@ class ForwardBatch: req_to_token_pool: ReqToTokenPool = None token_to_kv_pool: BaseTokenToKVPool = None attn_backend: AttentionBackend = None + + # Speculative decoding + spec_info: SpecInput = None + spec_algorithm: str = None @classmethod def init_new( @@ -119,7 +128,6 @@ def init_new( model_runner: ModelRunner, ): device = "cuda" - ret = cls( forward_mode=batch.forward_mode, batch_size=len(batch.seq_lens), @@ -131,10 +139,14 @@ def init_new( top_logprobs_nums=batch.top_logprobs_nums, lora_paths=batch.lora_paths, sampling_info=batch.sampling_info, + spec_algorithm=batch.spec_algorithm, + spec_info=batch.spec_info, ) # Init position information - if ret.forward_mode.is_decode(): + if ret.spec_info is not None and getattr(ret.spec_info, 'positions', None) is not None: + ret.positions = ret.spec_info.positions + elif ret.forward_mode.is_decode(): ret.positions = (ret.seq_lens - 1).to(torch.int64) else: ret.positions = torch.tensor( diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 5f0675de51a..a5946f693fb 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -77,6 +77,7 @@ def __init__( tp_size: int, nccl_port: int, server_args: ServerArgs, + is_draft_runner: bool ): # Parse args self.model_config = model_config @@ -90,6 +91,7 @@ def __init__( self.is_multimodal_model = is_multimodal_model( self.model_config.hf_config.architectures ) + self.is_draft_runner = is_draft_runner # Model-specific adjustment if ( @@ -150,9 +152,9 @@ def init_torch_distributed(self): if not self.server_args.enable_p2p_check: monkey_patch_vllm_p2p_access_check(self.gpu_id) if self.server_args.dist_init_addr: - dist_init_method = f"tcp://{self.server_args.dist_init_addr}" + 
dist_init_method = f"tcp://{self.server_args.dist_init_addr[1 if self.is_draft_runner else 0]}" else: - dist_init_method = f"tcp://127.0.0.1:{self.dist_port}" + dist_init_method = f"tcp://127.0.0.1:{self.dist_port[1 if self.is_draft_runner else 0]}" set_custom_all_reduce(not self.server_args.disable_custom_all_reduce) init_distributed_environment( backend=backend, @@ -161,7 +163,9 @@ def init_torch_distributed(self): local_rank=self.gpu_id, distributed_init_method=dist_init_method, ) - initialize_model_parallel(tensor_model_parallel_size=self.tp_size) + # draft model is not support parallel currently + if not self.is_draft_runner: + initialize_model_parallel(tensor_model_parallel_size=self.tp_size) min_per_gpu_memory = get_available_gpu_memory( self.device, self.gpu_id, distributed=self.tp_size > 1 ) @@ -207,7 +211,7 @@ def load_model(self): monkey_patch_vllm_dummy_weight_loader() self.load_config = LoadConfig(load_format=self.server_args.load_format) self.vllm_model_config = VllmModelConfig( - model=self.server_args.model_path, + model=self.server_args.model_path if not self.is_draft_runner else self.server_args.draft_model_path, quantization=self.server_args.quantization, tokenizer=None, tokenizer_mode=None, @@ -390,6 +394,12 @@ def init_memory_pool( ) self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory) + if self.is_draft_runner: + self.max_total_num_tokens = self.server_args.draft_runner_cache_size + else: + self.server_args.draft_runner_cache_size = self.max_total_num_tokens + + if max_total_tokens is not None: if max_total_tokens > self.max_total_num_tokens: logging.warning( diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 930c6838d85..7ebd163e8e0 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -314,9 +314,12 @@ def forward( input_embeds: torch.Tensor = None, ) -> LogitsProcessorOutput: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) - return self.logits_processor( + res = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, forward_batch ) + if forward_batch.spec_algorithm == 'EAGLE': + forward_batch.spec_info.hidden_states = hidden_states + return res def get_hidden_dim(self, module_name): # return input_dim, output_dim @@ -417,6 +420,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, self.model.embed_tokens.weight) apply_torchao_config_(self, params_dict, set(["proj.weight"])) + def get_embed_and_head(self): + return self.model.embed_tokens.weight, self.lm_head.weight class Phi3ForCausalLM(LlamaForCausalLM): pass diff --git a/python/sglang/srt/models/llama_eagle.py b/python/sglang/srt/models/llama_eagle.py new file mode 100644 index 00000000000..14c2848819c --- /dev/null +++ b/python/sglang/srt/models/llama_eagle.py @@ -0,0 +1,438 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +# Adapted from +# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1 +# and +# https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/modeling_llama_kv.py +"""Inference-only LLaMA-EAGLE model compatible with HuggingFace weights.""" + +from typing import Any, Dict, Iterable, Optional, Tuple + +import torch +from torch import nn +from transformers import LlamaConfig +from vllm.config import CacheConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.torchao_utils import apply_torchao_config_ +from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.model_executor.forward_batch_info import ForwardBatch + + +class LlamaMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." + ) + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class LlamaAttention(nn.Module): + def __init__( + self, + config: LlamaConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + layer_id: int = 0, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + rope_is_neox_style: bool = True, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr( + config, "head_dim", self.hidden_size // self.total_num_heads + ) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=rope_is_neox_style, + ) + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, forward_batch) + output, _ = self.o_proj(attn_output) + return output + + +class LlamaDecoderLayer(nn.Module): + def __init__( + self, + config: LlamaConfig, + layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.layer_id = layer_id + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None + ): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings + ) + rope_is_neox_style = getattr(config, "rope_is_neox_style", True) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + self.self_attn = LlamaAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + layer_id=layer_id, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + rope_is_neox_style=rope_is_neox_style, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = LlamaMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + residual = hidden_states + + if self.layer_id != 0: + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + + 
hidden_states = residual + hidden_states + residual = hidden_states + # Fully Connected + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states, residual + + +class LlamaModel(nn.Module): + def __init__( + self, + config: LlamaConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList( + [ + LlamaDecoderLayer( + config, i, quant_config=quant_config, prefix=f"model.layers.{i}" + ) + for i in range(config.num_hidden_layers) + ] + ) + # self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.fc = torch.nn.Linear(config.hidden_size * 2, config.hidden_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + + hidden_states = self.fc( + torch.cat( + (hidden_states, forward_batch.spec_info.hidden_states), dim=-1 + ) + ) + + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + forward_batch, + residual, + ) + # hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class LlamaForCausalLMEagle(nn.Module): + def __init__( + self, + config: LlamaConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = LlamaModel(config, quant_config=quant_config) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.logits_processor = LogitsProcessor(config) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + ) -> LogitsProcessorOutput: + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) + logits_output = self.logits_processor( + input_ids, hidden_states, self.lm_head.weight, forward_batch + ) + if isinstance(logits_output, LogitsProcessorOutput): + logits = logits_output.next_token_logits + sample_output = torch.softmax( + logits, dim=-1 + ) # TODO: Support more sampling method @kavioyu + forward_batch.spec_info.capture_for_decode( + sample_output, forward_batch.forward_mode + ) + return logits_output + + def get_hidden_dim(self, module_name): + # return input_dim, output_dim + if module_name in ["q_proj", "o_proj", "qkv_proj"]: + return self.config.hidden_size, self.config.hidden_size + elif module_name in ["kv_proj"]: + return self.config.hidden_size, self.config.hidden_size // ( + self.config.num_attention_heads // self.config.num_key_value_heads + ) + elif module_name == "gate_up_proj": + return self.config.hidden_size, self.config.intermediate_size + elif module_name == "down_proj": + return self.config.intermediate_size, self.config.hidden_size + else: + raise NotImplementedError() + + def get_module_name(self, name): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id, num_shard) + ("qkv_proj", "q_proj", "q", 3), + ("qkv_proj", "k_proj", "k", 3), + 
("qkv_proj", "v_proj", "v", 3), + ("gate_up_proj", "gate_proj", 0, 2), + ("gate_up_proj", "up_proj", 1, 2), + ] + for param_name, weight_name, shard_id, num_shard in stacked_params_mapping: + if weight_name in name: + return ( + name.replace(weight_name, param_name)[: -len(".weight")], + num_shard, + ) + return name[: -len(".weight")], 1 + + def get_num_params(self): + params_dict = dict(self.named_parameters()) + return len(params_dict) + + def load_weights( + self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None + ): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + + def load_weights_per_param(name, loaded_weight): + if "rotary_emb.inv_freq" in name or "projector" in name: + return + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + return + if name.startswith("model.vision_tower") and name not in params_dict: + return + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + return + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + if name is None or loaded_weight is None: + for name, loaded_weight in weights: + if "lm_head" not in name: + name = "model." + name + load_weights_per_param(name, loaded_weight) + else: + load_weights_per_param(name, loaded_weight) + + def set_embed_and_head(self, embed, head): + del self.model.embed_tokens.weight + del self.lm_head.weight + self.model.embed_tokens.weight = embed + self.lm_head.weight = head + torch.cuda.empty_cache() + torch.cuda.synchronize() + + +EntryClass = [LlamaForCausalLMEagle] diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 4b70b393ec9..c89d7731ddd 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -79,7 +79,7 @@ class ServerArgs: load_balance_method: str = "round_robin" # Distributed args - dist_init_addr: Optional[str] = None + dist_init_addr: Optional[List[str]] = None nnodes: int = 1 node_rank: int = 0 @@ -110,6 +110,13 @@ class ServerArgs: lora_paths: Optional[List[str]] = None max_loras_per_batch: int = 8 + #speculative decoding + draft_model_path: str = None + speculative_algorithm: str = None + num_speculative_steps: int = None + num_draft_tokens: int = None + draft_runner_cache_size: int = None + def __post_init__(self): # Set missing default values if self.tokenizer_path is None: @@ -422,8 +429,9 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--dist-init-addr", "--nccl-init-addr", # For backward compatbility. This will be removed in the future. 
- type=str, - help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).", + type=List[str], + help="""The host address for initializing distributed backend (e.g., `192.168.0.2:25000`). Shoule provide + two host address if use speculative decoding""", ) parser.add_argument( "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes." @@ -555,6 +563,33 @@ def add_cli_args(parser: argparse.ArgumentParser): default=8, help="Maximum number of adapters for a running batch, include base-only request", ) + parser.add_argument( + "--draft-model-path", + type=str, + help="The path of the draft model weights. This can be a local folder or a Hugging Face repo ID.", + required=False, + ) + parser.add_argument( + "--speculative-algorithm", + type=str, + choices=["EAGLE"], + help="Speculative algorithm.", + required=False, + ) + parser.add_argument( + "--num-speculative-steps", + type=int, + help="The number of steps sampled from draft model in Speculative Decoding.", + required=False, + default=5, + ) + parser.add_argument( + "--num-draft-tokens", + type=int, + help="The number of token sampled from draft model in Speculative Decoding.", + required=False, + default=5, + ) @classmethod def from_cli_args(cls, args: argparse.Namespace): @@ -622,13 +657,17 @@ class PortArgs: detokenizer_ipc_name: str # The port for nccl initialization (torch.dist) - nccl_port: int + # [port] if don't use speculative decoding else [tp worker port, draft worker, port] + nccl_port: List[int] @staticmethod def init_new(server_args) -> "PortArgs": + all_port = [] port = server_args.port + 1 while True: if is_port_available(port): + all_port.append(port) + if len(all_port) == 2 if server_args.speculative_algorithm is not None else 1: break port += 1 @@ -636,7 +675,7 @@ def init_new(server_args) -> "PortArgs": tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name, scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name, detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name, - nccl_port=port, + nccl_port=all_port, ) diff --git a/python/sglang/srt/speculative/__init__.py b/python/sglang/srt/speculative/__init__.py new file mode 100644 index 00000000000..0bd3481e6e2 --- /dev/null +++ b/python/sglang/srt/speculative/__init__.py @@ -0,0 +1 @@ +from .eagle_worker import EAGLEWorker \ No newline at end of file diff --git a/python/sglang/srt/speculative/build_egale_tree.py b/python/sglang/srt/speculative/build_egale_tree.py new file mode 100644 index 00000000000..b28f54d3275 --- /dev/null +++ b/python/sglang/srt/speculative/build_egale_tree.py @@ -0,0 +1,352 @@ +# import triton +# import triton.language as tl +# import torch + + +import time + +import cutex +import torch + +# parent_table [bs*(topk*depth+1)] +# selected_index [bs*(draft_token_num-1)] +# verified_seq_len [bs] +# tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] 
= [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] +# positions [bs*draft_token] +# retrive_index [b, draft_token, depth+2] +kernels = cutex.SourceModule( + """ +//cuda +__global__ void build_tree(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, + Tensor tree_mask, Tensor positions, Tensor retrive_index, int topk, int depth, int draft_token_num) { + int bid = blockIdx.x; + int tid = threadIdx.x; + if (tid >= draft_token_num){ + return; + } + int seq_tree_idx = draft_token_num * draft_token_num * bid; + for(int i=0; i SpecDraftInput: + def wrapper(info: Type[SpecDraftInput]) -> Type[SpecDraftInput]: + self.factory[name] = info + return info + + return wrapper + + def get(self, name): + if name is None: + return None + return self.factory[name] + + +DraftInfoFactory = SpecDraftInfoFactory() + + +@DraftInfoFactory.register("EAGLE") +class EAGLEDraftInput(SpecDraftInput): + hidden_states: torch.Tensor = None + verified_id: torch.Tensor = None + positions: torch.Tensor = None + evict_mask: torch.Tensor = None + + def init(self, server_args: ServerArgs): + self.prev_mode = ForwardMode.DECODE + self.sample_output = None + self.topk: int = 10 + self.num_verify_token: int = server_args.num_draft_tokens + + self.scores: torch.Tensor = None + self.score_list: List[torch.Tensor] = [] + self.token_list: List[torch.Tensor] = [] + self.parents_list: List[torch.Tensor] = [] + self.cache_list: List[torch.Tenor] = [] + self.iter = 0 + self.root_token: int = None + assert self.topk <= 10, "topk should <= 10" + + def prepare_for_extend(self, forward_batch: ForwardBatch): + seq_lens = [0] + forward_batch.extend_seq_lens_cpu + input_ids = forward_batch.input_ids.tolist() + verified_id = forward_batch.spec_info.verified_id.tolist() + model_input_ids = [] + for i in range(len(seq_lens) - 1): + model_input_ids.extend( + input_ids[seq_lens[i] + 1 : seq_lens[i + 1]] + [verified_id[i]] + ) + forward_batch.input_ids = torch.tensor( + model_input_ids, dtype=torch.int32, device="cuda" + ) + + def capture_for_decode(self, sample_output: SampleOutput, prev_mode: ForwardMode): + self.sample_output = sample_output + self.prev_mode = prev_mode + + def prepare_for_decode(self, batch: ScheduleBatch): + prob = self.sample_output # b * (1/topk), vocab + top = torch.topk(prob, self.topk, dim=-1) + topk_index, topk_p = top.indices, top.values # b * (1/topk), topk + if self.prev_mode == ForwardMode.DECODE: + scores = torch.mul( + self.scores.unsqueeze(2), topk_p.reshape(-1, self.topk, self.topk) + ) # (b, topk) mul (b * topk ,topk) -> b, topk, topk + topk_cs = torch.topk( + scores.flatten(start_dim=1), self.topk, dim=-1 + ) # (b, topk) + topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values + self.scores = topk_cs_p + + selected_input_index = topk_cs_index.flatten() // self.topk # b* topk + + batch.spec_info.hidden_states = batch.spec_info.hidden_states[ + selected_input_index, : + ] + batch.input_ids = torch.gather( + topk_index.reshape(-1, self.topk**2), index=topk_cs_index, dim=1 + ).flatten() + batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel()) + self.score_list.append(scores) + self.token_list.append(topk_index) + self.parents_list.append( + topk_cs_index.flatten() + (self.topk**2 * (self.iter - 1) + self.topk) + ) + + elif self.prev_mode == ForwardMode.EXTEND: + self.scores = topk_p # b, top_k + self.score_list.append(topk_p.unsqueeze(1)) + self.token_list.append(topk_index) + batch.spec_info.hidden_states = ( + 
batch.spec_info.hidden_states.repeat_interleave(self.topk, 0) + ) + batch.input_ids = topk_index.flatten() + batch.out_cache_loc = batch.alloc_token_slots(topk_index.numel()) + self.parents_list.append( + torch.arange(-1, self.topk, dtype=torch.int, device="cuda") + ) + self.cache_list.append(batch.out_cache_loc) + self.positions = ( + batch.seq_lens[:, None] + + torch.ones([1, self.topk], device="cuda") * self.iter + ).flatten() + + batch.req_to_token_pool.req_to_token[ + batch.req_pool_indices, + batch.seq_lens + + self.topk * self.iter : batch.seq_lens + + self.topk * (self.iter + 1), + ] = batch.out_cache_loc + self.iter += 1 + + def prepare_for_verify(self, batch: ScheduleBatch): + score_list = torch.cat(self.score_list, dim=1).view(-1) # b, 1/topk, topk + ss_token_list = torch.cat(self.token_list, dim=0).view( + -1 + ) # b * (self.topk+depth*self.topk) + top_scores = torch.topk(score_list, self.num_verify_token - 1, dim=-1) + top_scores_index = top_scores.indices + top_scores_index = torch.sort(top_scores_index).values + + draft_tokens = ss_token_list[top_scores_index] + draft_tokens = torch.cat((self.verified_id, draft_tokens), dim=0) + + parent_list = torch.cat(self.parents_list[:-1], dim=0) + + tree_mask, position, retrive_index, retrive_cum_len = build_tree_kernel( + parent_list, + top_scores_index, + batch.seq_lens, + self.topk, + self.iter - 1, + self.num_verify_token, + ) + + out_cache = torch.cat(self.cache_list, dim=0) + mem_need_free_idx = out_cache[top_scores_index] + + batch.token_to_kv_pool.free(mem_need_free_idx) + return EagleVerifyInput( + draft_tokens, + tree_mask, + position, + retrive_index, + retrive_cum_len, + self.num_verify_token, + ) + + def prepare_new_draft_stage(self, batch: ScheduleBatch): + batch.input_ids = self.verified_id + + def generate_attn_arg( + self, + req_pool_indices: torch.Tensor, + paged_kernel_lens: torch.Tensor, + req_to_token_pool: ReqToTokenPool, + ): + req_pool_indices = req_pool_indices.tolist() + paged_kernel_lens = paged_kernel_lens.tolist() + bs = self.topk * len(req_pool_indices) + seq_len = self.positions.reshape(-1).contiguous() + cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") + cum_kv_seq_len[1:] = torch.cumsum(seq_len + 1, dim=0) + kv_last_page_len = torch.ones((bs,), dtype=torch.int32, device="cuda") + kv_indices_list = [] + # TODO: reimplement it by triton @kavioyu + for i in range(len(req_pool_indices)): + for k in range(self.topk): + index = torch.arange(self.iter) * self.topk + k + kv_indices_list.append( + req_to_token_pool.req_to_token[ + req_pool_indices[i], : paged_kernel_lens[i] + ] + ) + kv_indices_list.append( + req_to_token_pool.req_to_token[ + req_pool_indices[i], paged_kernel_lens[i] + index + ] + ) + kv_indices = torch.cat(kv_indices_list, dim=0).contiguous() + return kv_indices, cum_kv_seq_len, kv_last_page_len, None + + def clear(self): + self.iter = 0 + self.score_list.clear() + self.positions = None + + +class EagleVerifyInput(SpecVerifyInput): + def __init__( + self, + draft_token: torch.Tensor, + tree_mask: torch.Tensor, + positions: torch.Tensor, + retrive_index: torch.Tensor, + retrive_cum_len: torch.Tensor, + draft_token_num: int, + ): + self.draft_token = draft_token + self.custom_mask = tree_mask + self.positions = positions + self.retrive_index = retrive_index + self.retrive_cum_len = retrive_cum_len + self.draft_token_num = draft_token_num + + def prepare_for_verify(self, batch: ScheduleBatch): + batch_size = self.retrive_cum_len.numel() + + batch.input_ids = 
self.draft_token + batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel()) + batch.req_to_token_pool.req_to_token[ + batch.req_pool_indices, + batch.seq_lens : batch.seq_lens + self.draft_token_num, + ] = batch.out_cache_loc + + def generate_attn_arg( + self, + req_pool_indices: torch.Tensor, + paged_kernel_lens: torch.Tensor, + req_to_token_pool: ReqToTokenPool, + ): + batch_size = len(req_pool_indices) + qo_indptr = torch.arange( + 0, + (1 + batch_size) * self.draft_token_num, + step=self.draft_token_num, + dtype=torch.int32, + device="cuda", + ) + + cum_kv_seq_len = torch.zeros( + (batch_size + 1,), dtype=torch.int32, device="cuda" + ) + paged_kernel_lens = paged_kernel_lens.add_(self.draft_token_num) + cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) + + kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda") + + kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda") + create_flashinfer_kv_indices_triton[(batch_size,)]( + req_to_token_pool.req_to_token, + req_pool_indices, + paged_kernel_lens, + cum_kv_seq_len, + None, + kv_indices, + req_to_token_pool.req_to_token.size(1), + ) + + return kv_indices, cum_kv_seq_len, kv_last_page_len, qo_indptr + + def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor: + predict = torch.argmax(logits_output.next_token_logits, dim=-1) + target_predict = predict[self.retrive_index] + candidates = self.draft_token[self.retrive_index] + # logits = logits_output.next_token_logits[self.retrive_index] + # target_predict = torch.argmax(logits[:, :-1], dim=-1) + accept_mask = candidates[:, 1:] == target_predict[:, :-1] + accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1) + bs = self.retrive_cum_len.numel() - 1 + + max_draft_len = self.retrive_index.shape[-1] + accept_index = torch.full( + (bs, max_draft_len), -1, dtype=torch.long, device="cuda" + ) + accept_length = torch.empty((bs,), dtype=torch.int, device="cuda") + extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda") + + eagle_verify_retrive[(bs,)]( + self.retrive_index.contiguous(), + accept_mask.contiguous(), + self.retrive_cum_len, + accept_index, + accept_length, + extract_index, + max_draft_len, + self.draft_token_num, + triton.next_power_of_2(max_draft_len), + ) + accept_index = accept_index[accept_index != -1] + extract_index = extract_index[extract_index != 0] + + batch.spec_info.verified_id = predict[extract_index] + batch.spec_info.hidden_states = batch.spec_info.hidden_states[ + extract_index + ] + + accept_length_cpu = accept_length.tolist() + verified_id_cpu = predict[accept_index].tolist() + + low = 0 + for req, verified_len in zip(batch.reqs, accept_length_cpu): + req.output_ids.extend(verified_id_cpu[low : low + verified_len + 1]) + low += verified_len + + # TODO: have memory leak, fix it @kavioyu + evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool) + evict_mask[accept_index] = False + mem_need_free_idx = batch.out_cache_loc[evict_mask] + + batch.req_to_token_pool.req_to_token[ + batch.req_pool_indices, + batch.seq_lens : batch.seq_lens + self.draft_token_num, + ] = batch.out_cache_loc + + batch.token_to_kv_pool.free(mem_need_free_idx) + batch.spec_info.evict_mask = evict_mask + + return batch.spec_info.verified_id + diff --git a/python/sglang/srt/speculative/speculative_worker.py b/python/sglang/srt/speculative/speculative_worker.py new file mode 100644 index 00000000000..f57a1d30120 --- /dev/null +++ 
b/python/sglang/srt/speculative/speculative_worker.py @@ -0,0 +1,48 @@ +from typing import Type +from sglang.srt.server_args import ServerArgs +from sglang.srt.managers.tp_worker import TpModelWorker +from sglang.srt.managers.schedule_batch import ModelWorkerBatch + + +class SpeculativeWorker(TpModelWorker): + is_draft_worker = True + + def __init__( + self, + gpu_id: int, + tp_rank: int, + server_args: ServerArgs, + nccl_port: int, + target_worker: TpModelWorker + ): + super().__init__(gpu_id=gpu_id, tp_rank=tp_rank, server_args=server_args, nccl_port=nccl_port) + self.target_worker = target_worker + + def forward_draft_decode(self, model_worker_batch: ModelWorkerBatch): + raise NotImplementedError() + + def forward_draft_extend(self, model_worker_batch: ModelWorkerBatch): + raise NotImplementedError() + + def verify(self, model_worker_batch: ModelWorkerBatch): + raise NotImplementedError() + + def forward_batch_speculative_generate(model_worker_batch: ModelWorkerBatch): + raise NotImplementedError() + + +class SpecWorkerFactory: + def __init__(self): + self.factory = {} + + def register(self, name: str) -> SpeculativeWorker: + def wrapper(info: Type[SpeculativeWorker]) -> Type[SpeculativeWorker]: + self.factory[name] = info + return info + + return wrapper + + def get(self, name): + return self.factory[name] + +spec_worker_factory = SpecWorkerFactory() \ No newline at end of file From 65fae7bf85b5b635f18b7c3021462eb5a6d312f3 Mon Sep 17 00:00:00 2001 From: kavioyu Date: Mon, 14 Oct 2024 12:17:04 +0800 Subject: [PATCH 02/26] migrated to new upstream, need implement evict memory --- .../layers/attention/flashinfer_backend.py | 8 ++- .../srt/layers/attention/flashinfer_utils.py | 27 ++++++++- python/sglang/srt/managers/schedule_batch.py | 2 + python/sglang/srt/managers/scheduler.py | 2 +- python/sglang/srt/models/llama.py | 1 + python/sglang/srt/models/llama_eagle.py | 1 + python/sglang/srt/speculative/eagle_worker.py | 60 ++++++++++++------- .../srt/speculative/speculative_utils.py | 48 ++++++++++----- .../srt/speculative/speculative_worker.py | 13 +--- 9 files changed, 111 insertions(+), 51 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index e04a2caddbe..084d741a3ac 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -222,10 +222,16 @@ def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch) forward_batch.token_to_kv_pool.set_kv_buffer( layer.layer_id, forward_batch.out_cache_loc, k, v ) + causal = True + if ( + forward_batch.spec_algorithm == "EAGLE" + and forward_batch.forward_mode == ForwardMode.SPECVERIFY + ): + causal = False o = prefill_wrapper_paged.forward( q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), - causal=True, + causal=causal, sm_scale=layer.scaling, window_left=layer.sliding_window_size, logits_soft_cap=layer.logit_cap, diff --git a/python/sglang/srt/layers/attention/flashinfer_utils.py b/python/sglang/srt/layers/attention/flashinfer_utils.py index 6a5838a05c7..9e9d1792d46 100644 --- a/python/sglang/srt/layers/attention/flashinfer_utils.py +++ b/python/sglang/srt/layers/attention/flashinfer_utils.py @@ -91,6 +91,8 @@ def __init__( def _update_decode_indices(self, decode_wrapper): assert not isinstance(decode_wrapper, list) + print('decode update') + print(self.kv_indices) 
decode_wrapper.end_forward() decode_wrapper.begin_forward( self.kv_indptr, @@ -114,6 +116,9 @@ def _update_extend_indices(self, ragged_wrapper, paged_wrapper): ) qo_indptr[1:] = torch.cumsum(self.seq_lens - self.prefix_lens, dim=0) + print('extend update') + print(self.kv_indices) + if self.use_ragged: ragged_wrapper.end_forward() ragged_wrapper.begin_forward( @@ -136,6 +141,23 @@ def _update_extend_indices(self, ragged_wrapper, paged_wrapper): self.head_dim, 1, ) + + def _update_verify_indices(self, paged_wrapper): + custom_mask = getattr(self.spec_info, "custom_mask", None) + paged_wrapper.end_forward() + print('verify update') + print(self.kv_indices) + paged_wrapper.begin_forward( + self.qo_indptr, + self.kv_indptr, + self.kv_indices, + self.kv_last_page_len, + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + 1, + custom_mask=custom_mask, + ) def _get_indices(self, dispatch_reason: WrapperDispatch = None, wrapper_id=0): if dispatch_reason is None: @@ -193,8 +215,9 @@ def _get_indices(self, dispatch_reason: WrapperDispatch = None, wrapper_id=0): def _update_indicess_single_wrapper(self): self._get_indices() - - if self.forward_mode.is_decode(): + if self.forward_mode.is_verify(): + self._update_verify_indices(self.prefill_wrappers_paged[0]) + elif self.forward_mode.is_decode(): self._update_decode_indices(self.decode_wrappers[0]) else: self._update_extend_indices( diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 58baec970a3..51a5d406f49 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -730,6 +730,8 @@ def check_for_jump_forward(self, pad_input_ids_func): def prepare_for_decode(self, input_ids=None): self.forward_mode = ForwardMode.DECODE + if self.spec_algorithm == 'EAGLE': + return if input_ids is None: input_ids = [ diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index a995eee1add..ec9ebeeb62e 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -657,7 +657,7 @@ def run_batch(self, batch: ScheduleBatch): if self.is_generation: if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0: if self.server_args.speculative_algorithm: - logits_output, next_token_ids, spec_info = self.draft_worker.forward_batch_speculative_generate( + logits_output, next_token_ids = self.draft_worker.forward_batch_speculative_generate( batch ) else: diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 7ebd163e8e0..714871306d7 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -313,6 +313,7 @@ def forward( forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> LogitsProcessorOutput: + print(forward_batch.out_cache_loc) hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) res = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, forward_batch diff --git a/python/sglang/srt/models/llama_eagle.py b/python/sglang/srt/models/llama_eagle.py index 14c2848819c..3c571d01777 100644 --- a/python/sglang/srt/models/llama_eagle.py +++ b/python/sglang/srt/models/llama_eagle.py @@ -326,6 +326,7 @@ def forward( forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> LogitsProcessorOutput: + print(forward_batch.out_cache_loc) hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) logits_output = self.logits_processor( input_ids, 
hidden_states, self.lm_head.weight, forward_batch diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 4ce5e9fda86..f4c2dbc89d9 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -1,10 +1,11 @@ import torch -from sglang.srt.speculative.speculative_worker import SpeculativeWorker, spec_worker_factory -from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch -from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.server_args import ServerArgs from sglang.srt.managers.tp_worker import TpModelWorker -from sglang.srt.speculative.speculative_utils import DraftInfoFactory +from sglang.srt.speculative.speculative_worker import SpeculativeWorker, spec_worker_factory +from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.speculative.speculative_utils import EAGLEDraftInput, EagleVerifyInput +from sglang.srt.model_executor.model_runner import ModelRunner @spec_worker_factory.register('EAGLE') class EAGLEWorker(SpeculativeWorker): @@ -19,37 +20,52 @@ def __init__( super().__init__(gpu_id=gpu_id, tp_rank=tp_rank, server_args=server_args, nccl_port=nccl_port, target_worker=target_worker) embed, head = self.target_worker.model_runner.model.get_embed_and_head() self.model_runner.model.set_embed_and_head(embed, head) - - def forward_draft_decode(self, batch: ScheduleBatch): - print('** start decode **') + batch.spec_info.prepare_for_decode(batch) model_worker_batch = batch.get_model_worker_batch() - forward_batch = ForwardBatch.init_new(model_worker_batch, self.target_worker.model_runner) + forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) self.model_runner.forward(forward_batch) - def forward_draft_extend(self, model_worker_batch: ModelWorkerBatch): - forward_batch = ForwardBatch.init_new(model_worker_batch, self.target_worker.model_runner) - forward_batch.spec_info.prepare_for_extend(forward_batch) + def forward_draft_extend(self, batch: ScheduleBatch): + self._swap_mem_pool(batch, self.model_runner) + batch.spec_info.prepare_for_extend(batch) + model_worker_batch = batch.get_model_worker_batch() + forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) + #forward_batch.spec_info.prepare_for_extend(forward_batch) logits_output = self.model_runner.forward(forward_batch) - next_token_ids = self.model_runner.sample(logits_output, model_worker_batch) - model_worker_batch.spec_info.verified_id = next_token_ids + #next_token_ids = self.model_runner.sample(logits_output, model_worker_batch) + self._swap_mem_pool(batch, self.target_worker.model_runner) def forward_batch_speculative_generate(self, batch: ScheduleBatch): if batch.forward_mode.is_decode(): + self._swap_mem_pool(batch, self.model_runner) for i in range(self.server_args.num_speculative_steps): self.forward_draft_decode(batch) - - model_worker_batch = batch.get_model_worker_batch() - self.forward_batch_generation(model_worker_batch) - return self.draft_worker.verify(model_worker_batch) + self._swap_mem_pool(batch, self.target_worker.model_runner) + self.verify(batch) else: + batch.spec_info = EAGLEDraftInput() + batch.spec_info.init(self.server_args) model_worker_batch = batch.get_model_worker_batch() - model_worker_batch.spec_info = DraftInfoFactory.get(model_worker_batch.spec_algorithm)() - 
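The greedy acceptance rule implemented in `EagleVerifyInput.verify` above (the `candidates[:, 1:] == target_predict[:, :-1]` comparison followed by a cumulative product) can be hard to read inside the batched tensor code. Below is a minimal, self-contained sketch of that same rule for a single root-to-leaf draft path; the function name and demo tensors are illustrative only and are not part of the patch, which applies the rule across an entire token tree via the `eagle_verify_retrive` kernel.

# Illustrative sketch (not patch code): greedy accept rule for one draft path.
import torch

def accept_length_for_path(candidates: torch.Tensor, target_predict: torch.Tensor) -> int:
    """Count how many speculative tokens the target model confirms.

    candidates:     [path_len] draft tokens along one root-to-leaf path,
                    where candidates[0] is the already-verified root token.
    target_predict: [path_len] argmax of the target-model logits at each position.
    """
    # Draft token i+1 is accepted only if the target model, having seen
    # tokens 0..i, predicts exactly candidates[i+1].
    agree = candidates[1:] == target_predict[:-1]
    # The cumulative product zeroes everything after the first disagreement,
    # mirroring torch.cumprod(accept_mask, dim=1).sum(dim=1) in the patch.
    return int(torch.cumprod(agree.int(), dim=0).sum().item())

if __name__ == "__main__":
    # The target agrees on the first two speculative tokens and then diverges,
    # so two draft tokens are accepted (the target's own next token is appended
    # on top of that by the verify step).
    candidates = torch.tensor([11, 42, 7, 99])
    target_predict = torch.tensor([42, 7, 13, 5])
    print(accept_length_for_path(candidates, target_predict))  # -> 2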
model_worker_batch.spec_info.init(self.server_args) logits_output, next_token_ids = self.target_worker.forward_batch_generation(model_worker_batch) model_worker_batch.spec_info.verified_id = next_token_ids - self.forward_draft_extend(model_worker_batch) - batch.spec_info = model_worker_batch.spec_info - return logits_output, next_token_ids, model_worker_batch.spec_info \ No newline at end of file + self.forward_draft_extend(batch) + return logits_output, next_token_ids + + def verify(self, batch: ScheduleBatch): + print('*'*100) + verify_input = batch.spec_info.prepare_for_verify(batch) + batch.forward_mode = ForwardMode.SPECVERIFY + verify_input.prepare_for_verify(batch) + batch.spec_info = verify_input + model_worker_batch = batch.get_model_worker_batch() + logits_output, next_token_ids = self.target_worker.forward_batch_generation(model_worker_batch) + verify_input.verify(batch, logits_output) + + def _swap_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner): + batch.token_to_kv_pool = runner.token_to_kv_pool + batch.req_to_token_pool = runner.req_to_token_pool + + \ No newline at end of file diff --git a/python/sglang/srt/speculative/speculative_utils.py b/python/sglang/srt/speculative/speculative_utils.py index 6447ecac46a..0bcc466dfa6 100644 --- a/python/sglang/srt/speculative/speculative_utils.py +++ b/python/sglang/srt/speculative/speculative_utils.py @@ -170,16 +170,36 @@ def init(self, server_args: ServerArgs): self.root_token: int = None assert self.topk <= 10, "topk should <= 10" - def prepare_for_extend(self, forward_batch: ForwardBatch): - seq_lens = [0] + forward_batch.extend_seq_lens_cpu - input_ids = forward_batch.input_ids.tolist() - verified_id = forward_batch.spec_info.verified_id.tolist() + def prepare_for_extend(self, batch: ForwardBatch): + req_pool_indices = batch.alloc_req_slots(len(batch.reqs)) + out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel()) + + pt=0 + for i, req in enumerate(batch.reqs): + req.req_pool_idx = req_pool_indices[i] + pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids) + assert seq_len - pre_len == req.extend_input_len + + if pre_len > 0: + batch.req_to_token_pool.req_to_token[req.req_pool_idx][ + :pre_len + ] = req.prefix_indices + + batch.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = ( + out_cache_loc[pt : pt + req.extend_input_len] + ) + + pt += req.extend_input_len + + seq_lens = [0] + batch.extend_lens + input_ids = batch.input_ids.tolist() + verified_id = batch.spec_info.verified_id.tolist() model_input_ids = [] for i in range(len(seq_lens) - 1): model_input_ids.extend( input_ids[seq_lens[i] + 1 : seq_lens[i + 1]] + [verified_id[i]] ) - forward_batch.input_ids = torch.tensor( + batch.input_ids = torch.tensor( model_input_ids, dtype=torch.int32, device="cuda" ) @@ -231,7 +251,7 @@ def prepare_for_decode(self, batch: ScheduleBatch): self.cache_list.append(batch.out_cache_loc) self.positions = ( batch.seq_lens[:, None] - + torch.ones([1, self.topk], device="cuda") * self.iter + + torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter - 1 ).flatten() batch.req_to_token_pool.req_to_token[ @@ -265,14 +285,14 @@ def prepare_for_verify(self, batch: ScheduleBatch): self.num_verify_token, ) - out_cache = torch.cat(self.cache_list, dim=0) - mem_need_free_idx = out_cache[top_scores_index] - - batch.token_to_kv_pool.free(mem_need_free_idx) + # out_cache = torch.cat(self.cache_list, dim=0) + # mem_need_free_idx = out_cache[top_scores_index] + # 
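prepare_for_extend above builds the draft model's input by shifting each request one token to the left and appending the token the target model just produced, so hidden state t gets paired with token t+1 as EAGLE expects. A tiny illustration with made-up token ids:

prompt_ids = [101, 9906, 1917]   # hypothetical prompt tokens for one request
verified_id = 345                # token just sampled by the target model
draft_extend_ids = prompt_ids[1:] + [verified_id]
assert draft_extend_ids == [9906, 1917, 345]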
batch.token_to_kv_pool.free(mem_need_free_idx) + return EagleVerifyInput( draft_tokens, tree_mask, - position, + position-1, retrive_index, retrive_cum_len, self.num_verify_token, @@ -336,8 +356,6 @@ def __init__( self.draft_token_num = draft_token_num def prepare_for_verify(self, batch: ScheduleBatch): - batch_size = self.retrive_cum_len.numel() - batch.input_ids = self.draft_token batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel()) batch.req_to_token_pool.req_to_token[ @@ -363,12 +381,14 @@ def generate_attn_arg( cum_kv_seq_len = torch.zeros( (batch_size + 1,), dtype=torch.int32, device="cuda" ) + paged_kernel_lens = paged_kernel_lens.add_(self.draft_token_num) cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda") kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda") + create_flashinfer_kv_indices_triton[(batch_size,)]( req_to_token_pool.req_to_token, req_pool_indices, @@ -419,13 +439,13 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten accept_length_cpu = accept_length.tolist() verified_id_cpu = predict[accept_index].tolist() + print(verified_id_cpu) low = 0 for req, verified_len in zip(batch.reqs, accept_length_cpu): req.output_ids.extend(verified_id_cpu[low : low + verified_len + 1]) low += verified_len - # TODO: have memory leak, fix it @kavioyu evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool) evict_mask[accept_index] = False mem_need_free_idx = batch.out_cache_loc[evict_mask] diff --git a/python/sglang/srt/speculative/speculative_worker.py b/python/sglang/srt/speculative/speculative_worker.py index f57a1d30120..a2022839c87 100644 --- a/python/sglang/srt/speculative/speculative_worker.py +++ b/python/sglang/srt/speculative/speculative_worker.py @@ -1,7 +1,7 @@ from typing import Type from sglang.srt.server_args import ServerArgs from sglang.srt.managers.tp_worker import TpModelWorker -from sglang.srt.managers.schedule_batch import ModelWorkerBatch +from sglang.srt.managers.schedule_batch import ScheduleBatch class SpeculativeWorker(TpModelWorker): @@ -17,17 +17,8 @@ def __init__( ): super().__init__(gpu_id=gpu_id, tp_rank=tp_rank, server_args=server_args, nccl_port=nccl_port) self.target_worker = target_worker - - def forward_draft_decode(self, model_worker_batch: ModelWorkerBatch): - raise NotImplementedError() - - def forward_draft_extend(self, model_worker_batch: ModelWorkerBatch): - raise NotImplementedError() - - def verify(self, model_worker_batch: ModelWorkerBatch): - raise NotImplementedError() - def forward_batch_speculative_generate(model_worker_batch: ModelWorkerBatch): + def forward_batch_speculative_generate(self, batch: ScheduleBatch): raise NotImplementedError() From 064cca6a434f1e5adc0974ef6261a67ff20c784b Mon Sep 17 00:00:00 2001 From: kavioyu Date: Tue, 15 Oct 2024 19:53:21 +0800 Subject: [PATCH 03/26] prove single req --- .../layers/attention/flashinfer_backend.py | 5 + .../srt/layers/attention/flashinfer_utils.py | 33 +++-- python/sglang/srt/layers/logits_processor.py | 8 +- python/sglang/srt/managers/scheduler.py | 5 +- .../srt/model_executor/forward_batch_info.py | 9 +- python/sglang/srt/models/llama.py | 1 - python/sglang/srt/models/llama_eagle.py | 1 - python/sglang/srt/speculative/eagle_worker.py | 32 ++++- .../srt/speculative/speculative_utils.py | 133 +++++++++++++++--- .../srt/speculative/speculative_worker.py | 3 + 10 files changed, 191 insertions(+), 39 deletions(-) diff 
--git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 084d741a3ac..70e285dffaa 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -112,6 +112,11 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): use_ragged = False extend_no_prefix = False total_num_tokens = None + elif forward_batch.forward_mode.is_spec_extend(): + use_ragged = False + total_num_tokens = forward_batch.spec_info.verified_id.numel() + extend_no_prefix = True + prefix_lens = None else: prefix_lens = forward_batch.extend_prefix_lens diff --git a/python/sglang/srt/layers/attention/flashinfer_utils.py b/python/sglang/srt/layers/attention/flashinfer_utils.py index 9e9d1792d46..67c1b3dab70 100644 --- a/python/sglang/srt/layers/attention/flashinfer_utils.py +++ b/python/sglang/srt/layers/attention/flashinfer_utils.py @@ -91,8 +91,6 @@ def __init__( def _update_decode_indices(self, decode_wrapper): assert not isinstance(decode_wrapper, list) - print('decode update') - print(self.kv_indices) decode_wrapper.end_forward() decode_wrapper.begin_forward( self.kv_indptr, @@ -116,9 +114,6 @@ def _update_extend_indices(self, ragged_wrapper, paged_wrapper): ) qo_indptr[1:] = torch.cumsum(self.seq_lens - self.prefix_lens, dim=0) - print('extend update') - print(self.kv_indices) - if self.use_ragged: ragged_wrapper.end_forward() ragged_wrapper.begin_forward( @@ -145,8 +140,6 @@ def _update_extend_indices(self, ragged_wrapper, paged_wrapper): def _update_verify_indices(self, paged_wrapper): custom_mask = getattr(self.spec_info, "custom_mask", None) paged_wrapper.end_forward() - print('verify update') - print(self.kv_indices) paged_wrapper.begin_forward( self.qo_indptr, self.kv_indptr, @@ -158,6 +151,20 @@ def _update_verify_indices(self, paged_wrapper): 1, custom_mask=custom_mask, ) + + def _update_spec_extend(self, paged_wrapper): + assert not isinstance(paged_wrapper, list) + paged_wrapper.end_forward() + paged_wrapper.begin_forward( + self.qo_indptr, + self.kv_indptr, + self.kv_indices, + self.kv_last_page_len, + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + 1, + ) def _get_indices(self, dispatch_reason: WrapperDispatch = None, wrapper_id=0): if dispatch_reason is None: @@ -186,7 +193,15 @@ def _get_indices(self, dispatch_reason: WrapperDispatch = None, wrapper_id=0): paged_kernel_lens = self.seq_lens self.kv_start_idx = self.seq_lens - paged_kernel_lens - if self.spec_info is not None and self.forward_mode.is_decode(): + if self.spec_info is not None and self.forward_mode.is_spec_extend(): + self.kv_indices, self.kv_indptr, self.kv_last_page_len, self.qo_indptr = ( + self.spec_info.generate_attn_arg_spec_extend( + self.req_pool_indices, + paged_kernel_lens, + self.model_runner.req_to_token_pool, + ) + ) + elif self.spec_info is not None and self.forward_mode.is_decode(): self.kv_indices, self.kv_indptr, self.kv_last_page_len, self.qo_indptr = ( self.spec_info.generate_attn_arg( self.req_pool_indices, @@ -217,6 +232,8 @@ def _update_indicess_single_wrapper(self): self._get_indices() if self.forward_mode.is_verify(): self._update_verify_indices(self.prefill_wrappers_paged[0]) + elif self.forward_mode.is_spec_extend(): + self._update_spec_extend(self.prefill_wrappers_paged[0]) elif self.forward_mode.is_decode(): self._update_decode_indices(self.decode_wrappers[0]) else: diff --git a/python/sglang/srt/layers/logits_processor.py 
b/python/sglang/srt/layers/logits_processor.py index 00b113d52ed..8fe88bd4fe7 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -59,6 +59,8 @@ class LogitsMetadata: extend_logprob_start_lens_cpu: Optional[List[int]] = None extend_logprob_pruned_lens_cpu: Optional[List[int]] = None + + is_draft_batch: bool = False @classmethod def from_forward_batch(cls, forward_batch: ForwardBatch): @@ -67,7 +69,7 @@ def from_forward_batch(cls, forward_batch: ForwardBatch): else: return_top_logprob = False - if forward_batch.forward_mode.is_extend(): + if forward_batch.forward_mode.is_extend() and not forward_batch.is_draft_batch: extend_logprob_pruned_lens_cpu = [ extend_len - start_len for extend_len, start_len in zip( @@ -75,6 +77,7 @@ def from_forward_batch(cls, forward_batch: ForwardBatch): forward_batch.extend_logprob_start_lens_cpu, ) ] + else: extend_logprob_pruned_lens_cpu = None return cls( @@ -86,6 +89,7 @@ def from_forward_batch(cls, forward_batch: ForwardBatch): extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu, extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu, extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu, + is_draft_batch=forward_batch.is_draft_batch ) @@ -196,7 +200,7 @@ def forward( last_logits.mul_(self.config.final_logit_softcapping) # Return only last_logits if logprob is not requested - if not logits_metadata.return_logprob: + if not logits_metadata.return_logprob or logits_metadata.is_draft_batch: return LogitsProcessorOutput( next_token_logits=last_logits, next_token_logprobs=None, diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index ec9ebeeb62e..6d8e3c7f576 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -790,8 +790,9 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result): # Check finish condition for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)): req.completion_tokens_wo_jump_forward += 1 - req.output_ids.append(next_token_id) - req.check_finished() + if batch.spec_algorithm is None: # speculative worker will solve the output_ids in speculative decoding + req.output_ids.append(next_token_id) + req.check_finished() # TODO: SUPPORT IT @kavioyu if req.regex_fsm is not None: req.regex_fsm_state = req.regex_fsm.get_next_state( diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index b7cc1a6799b..15e5d65b1b4 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -55,12 +55,14 @@ class ForwardMode(IntEnum): MIXED = auto() # Speculative Verify stage SPECVERIFY = auto() + # Speculative draft Extend stage which after verify stage + SPECEXTEND = auto() def is_prefill(self): return self == ForwardMode.PREFILL def is_extend(self): - return self == ForwardMode.EXTEND or self == ForwardMode.MIXED + return self in (ForwardMode.EXTEND, self == ForwardMode.MIXED, ForwardMode.SPECEXTEND) def is_decode(self): return self in (ForwardMode.DECODE, ForwardMode.SPECVERIFY) @@ -70,6 +72,9 @@ def is_mixed(self): def is_verify(self): return self == ForwardMode.SPECVERIFY + + def is_spec_extend(self): + return self == ForwardMode.SPECEXTEND @dataclass @@ -120,6 +125,7 @@ class ForwardBatch: # Speculative decoding spec_info: SpecInput = None spec_algorithm: str = None + is_draft_batch: bool = False @classmethod def 
init_new( @@ -162,6 +168,7 @@ def init_new( device=device, ).to(torch.int64) + if not ret.forward_mode.is_decode(): ret.image_inputs = batch.image_inputs ret.extend_seq_lens = torch.tensor(batch.extend_seq_lens, device=device) ret.extend_prefix_lens = torch.tensor( diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 714871306d7..7ebd163e8e0 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -313,7 +313,6 @@ def forward( forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> LogitsProcessorOutput: - print(forward_batch.out_cache_loc) hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) res = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, forward_batch diff --git a/python/sglang/srt/models/llama_eagle.py b/python/sglang/srt/models/llama_eagle.py index 3c571d01777..14c2848819c 100644 --- a/python/sglang/srt/models/llama_eagle.py +++ b/python/sglang/srt/models/llama_eagle.py @@ -326,7 +326,6 @@ def forward( forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> LogitsProcessorOutput: - print(forward_batch.out_cache_loc) hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, forward_batch diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index f4c2dbc89d9..8b951f4188a 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -22,10 +22,10 @@ def __init__( self.model_runner.model.set_embed_and_head(embed, head) def forward_draft_decode(self, batch: ScheduleBatch): - batch.spec_info.prepare_for_decode(batch) model_worker_batch = batch.get_model_worker_batch() forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) + forward_batch.is_draft_batch = True self.model_runner.forward(forward_batch) def forward_draft_extend(self, batch: ScheduleBatch): @@ -33,9 +33,7 @@ def forward_draft_extend(self, batch: ScheduleBatch): batch.spec_info.prepare_for_extend(batch) model_worker_batch = batch.get_model_worker_batch() forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) - #forward_batch.spec_info.prepare_for_extend(forward_batch) logits_output = self.model_runner.forward(forward_batch) - #next_token_ids = self.model_runner.sample(logits_output, model_worker_batch) self._swap_mem_pool(batch, self.target_worker.model_runner) def forward_batch_speculative_generate(self, batch: ScheduleBatch): @@ -43,8 +41,15 @@ def forward_batch_speculative_generate(self, batch: ScheduleBatch): self._swap_mem_pool(batch, self.model_runner) for i in range(self.server_args.num_speculative_steps): self.forward_draft_decode(batch) + batch.spec_info.clear_draft_cache(batch) self._swap_mem_pool(batch, self.target_worker.model_runner) - self.verify(batch) + next_draft_input, logits_output = self.verify(batch) + verified_id = next_draft_input.verified_id + next_draft_input.init(self.server_args) + batch.spec_info = next_draft_input + self.forward_extend_after_decode(batch) + return logits_output, verified_id + else: batch.spec_info = EAGLEDraftInput() batch.spec_info.init(self.server_args) @@ -55,17 +60,32 @@ def forward_batch_speculative_generate(self, batch: ScheduleBatch): return logits_output, next_token_ids def verify(self, batch: ScheduleBatch): - print('*'*100) verify_input = batch.spec_info.prepare_for_verify(batch) 
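forward_batch_speculative_generate above alternates between the draft and target memory pools: it drafts num_speculative_steps tokens with the small model, lets the target model verify them in one pass, then re-primes the draft model on the accepted tokens. The runnable toy below shows only the accept/correct logic, with plain functions standing in for both models (hypothetical stand-ins; EAGLE drafts a token tree rather than a single chain):

def draft_next(ids):             # cheap stand-in model: guesses id + 1
    return ids[-1] + 1

def target_next(ids):            # "true" stand-in model: id + 1, capped at 5
    return min(ids[-1] + 1, 5)

def speculative_step(ids, num_steps=4):
    draft = []
    for _ in range(num_steps):                      # sequential, cheap drafting
        draft.append(draft_next(ids + draft))
    # In the real system this is a single batched forward of the target model.
    verified = [target_next(ids + draft[:i]) for i in range(num_steps)]
    accepted = []
    for d, t in zip(draft, verified):               # accept while they agree,
        accepted.append(t)                          # always keeping the target's token
        if d != t:
            break
    return ids + accepted

print(speculative_step([2]))   # [2, 3, 4, 5, 5]: three drafts accepted plus one correction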
batch.forward_mode = ForwardMode.SPECVERIFY verify_input.prepare_for_verify(batch) batch.spec_info = verify_input model_worker_batch = batch.get_model_worker_batch() logits_output, next_token_ids = self.target_worker.forward_batch_generation(model_worker_batch) - verify_input.verify(batch, logits_output) + res = verify_input.verify(batch, logits_output) + batch.forward_mode = ForwardMode.DECODE + return res def _swap_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner): batch.token_to_kv_pool = runner.token_to_kv_pool batch.req_to_token_pool = runner.req_to_token_pool + def forward_extend_after_decode(self, batch: ScheduleBatch): + self._swap_mem_pool(batch, self.model_runner) + batch.forward_mode = ForwardMode.SPECEXTEND + batch.spec_info.prepare_extend_after_decode(batch) + model_worker_batch = batch.get_model_worker_batch() + forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) + forward_batch.is_draft_batch = True + logits_output = self.model_runner.forward(forward_batch) + batch.forward_mode = ForwardMode.DECODE + self._swap_mem_pool(batch, self.model_runner) + + def post_decode_process(self, batch): + return self.forward_extend_after_decode(batch) + \ No newline at end of file diff --git a/python/sglang/srt/speculative/speculative_utils.py b/python/sglang/srt/speculative/speculative_utils.py index 0bcc466dfa6..387803113be 100644 --- a/python/sglang/srt/speculative/speculative_utils.py +++ b/python/sglang/srt/speculative/speculative_utils.py @@ -100,6 +100,21 @@ def eagle_verify_retrive( extract_data = tl.load(extract_load_ptr) tl.store(extract_index + pid * 2, extract_data) +@triton.jit +def create_extend_spec_info(verified_id, seq_len, accept_len, accept_len_cum, positions, new_verified_id, accept_len_upper: tl.constexpr): + pid = tl.program_id(axis=0) + offset = 0 if pid ==0 else tl.load(accept_len_cum+pid-1) + seq_length = tl.load(seq_len+pid) + accept_length = tl.load(accept_len+pid) + positions_ptr = positions+offset + data = tl.arange(0, accept_len_upper) + mask = data < accept_length + tl.store(positions_ptr+data, seq_length-accept_length+data, mask) + + offset = tl.load(accept_len_cum+pid)-1 + verified_id_data = tl.load(verified_id+offset) + tl.store(new_verified_id+pid, verified_id_data) + class SpecInput: pass @@ -153,13 +168,14 @@ class EAGLEDraftInput(SpecDraftInput): hidden_states: torch.Tensor = None verified_id: torch.Tensor = None positions: torch.Tensor = None - evict_mask: torch.Tensor = None + accept_length: torch.Tensor = None def init(self, server_args: ServerArgs): self.prev_mode = ForwardMode.DECODE self.sample_output = None self.topk: int = 10 self.num_verify_token: int = server_args.num_draft_tokens + self.spec_steps = server_args.num_speculative_steps self.scores: torch.Tensor = None self.score_list: List[torch.Tensor] = [] @@ -173,6 +189,7 @@ def init(self, server_args: ServerArgs): def prepare_for_extend(self, batch: ForwardBatch): req_pool_indices = batch.alloc_req_slots(len(batch.reqs)) out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel()) + batch.out_cache_loc = out_cache_loc pt=0 for i, req in enumerate(batch.reqs): @@ -236,7 +253,7 @@ def prepare_for_decode(self, batch: ScheduleBatch): topk_cs_index.flatten() + (self.topk**2 * (self.iter - 1) + self.topk) ) - elif self.prev_mode == ForwardMode.EXTEND: + elif self.prev_mode in (ForwardMode.EXTEND, ForwardMode.SPECEXTEND) : self.scores = topk_p # b, top_k self.score_list.append(topk_p.unsqueeze(1)) self.token_list.append(topk_index) @@ -251,7 +268,7 @@ def 
prepare_for_decode(self, batch: ScheduleBatch): self.cache_list.append(batch.out_cache_loc) self.positions = ( batch.seq_lens[:, None] - + torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter - 1 + + torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter ).flatten() batch.req_to_token_pool.req_to_token[ @@ -261,6 +278,39 @@ def prepare_for_decode(self, batch: ScheduleBatch): + self.topk * (self.iter + 1), ] = batch.out_cache_loc self.iter += 1 + + def prepare_extend_after_decode(self, batch: ScheduleBatch): + #req_pool_indices = batch.alloc_req_slots(len(batch.reqs)) + batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel()) + batch.extend_lens = (self.accept_length+1).tolist() + + pt=0 + positions = [] + seq_lens = batch.seq_lens.tolist() + for i, req in enumerate(batch.reqs): + #assert seq_len - pre_len == req.extend_input_len + input_len = self.accept_length[i] + 1 + seq_len = seq_lens[i] + batch.req_to_token_pool.req_to_token[req.req_pool_idx][seq_len-input_len:seq_len] = ( + batch.out_cache_loc[pt : pt + input_len] + ) + pt += input_len + + + self.positions = torch.empty_like(self.verified_id) + new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long) + self.accept_length.add_(1) + + create_extend_spec_info[(self.accept_length.numel(),)](self.verified_id, batch.seq_lens, + self.accept_length, torch.cumsum(self.accept_length, axis=0, dtype=torch.int), + self.positions, new_verified_id, triton.next_power_of_2(self.spec_steps+1)) + + torch.save((self.verified_id, batch.seq_lens, self.accept_length, + torch.cumsum(self.accept_length, axis=0, dtype=torch.int)), 'test.pth') + + batch.input_ids = self.verified_id + self.verified_id = new_verified_id + def prepare_for_verify(self, batch: ScheduleBatch): score_list = torch.cat(self.score_list, dim=1).view(-1) # b, 1/topk, topk @@ -292,7 +342,7 @@ def prepare_for_verify(self, batch: ScheduleBatch): return EagleVerifyInput( draft_tokens, tree_mask, - position-1, + position, retrive_index, retrive_cum_len, self.num_verify_token, @@ -315,7 +365,7 @@ def generate_attn_arg( cum_kv_seq_len[1:] = torch.cumsum(seq_len + 1, dim=0) kv_last_page_len = torch.ones((bs,), dtype=torch.int32, device="cuda") kv_indices_list = [] - # TODO: reimplement it by triton @kavioyu + # TODO: reimplement it by triton if it is slow @kavioyu for i in range(len(req_pool_indices)): for k in range(self.topk): index = torch.arange(self.iter) * self.topk + k @@ -336,6 +386,41 @@ def clear(self): self.iter = 0 self.score_list.clear() self.positions = None + + def clear_draft_cache(self, batch): + draft_cache = torch.cat(self.cache_list, dim=0) + batch.token_to_kv_pool.free(draft_cache) + + def generate_attn_arg_spec_extend( + self, + req_pool_indices: torch.Tensor, + paged_kernel_lens: torch.Tensor, + req_to_token_pool: ReqToTokenPool, + ): + bs = self.accept_length.numel() + qo_indptr = torch.zeros( + (bs + 1,), dtype=torch.int32, device="cuda" + ) + qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0) + + cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") + cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) + + kv_last_page_len = torch.ones((bs,), dtype=torch.int32, device="cuda") + + kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda") + + create_flashinfer_kv_indices_triton[(bs,)]( + req_to_token_pool.req_to_token, + req_pool_indices, + paged_kernel_lens, + cum_kv_seq_len, + None, + kv_indices, + req_to_token_pool.req_to_token.size(1), + ) + + return 
kv_indices, cum_kv_seq_len, kv_last_page_len, qo_indptr class EagleVerifyInput(SpecVerifyInput): @@ -398,7 +483,7 @@ def generate_attn_arg( kv_indices, req_to_token_pool.req_to_token.size(1), ) - + paged_kernel_lens = paged_kernel_lens.sub_(self.draft_token_num) return kv_indices, cum_kv_seq_len, kv_last_page_len, qo_indptr def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor: @@ -430,16 +515,22 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten triton.next_power_of_2(max_draft_len), ) accept_index = accept_index[accept_index != -1] - extract_index = extract_index[extract_index != 0] + #extract_index = extract_index[extract_index != 0] + + draft_input = EAGLEDraftInput() - batch.spec_info.verified_id = predict[extract_index] - batch.spec_info.hidden_states = batch.spec_info.hidden_states[ - extract_index + accept_length_cpu = accept_length.tolist() + + draft_input.verified_id = predict[accept_index] + + draft_input.hidden_states = batch.spec_info.hidden_states[ + accept_index ] + draft_input.accept_length = accept_length + - accept_length_cpu = accept_length.tolist() - verified_id_cpu = predict[accept_index].tolist() - print(verified_id_cpu) + + verified_id_cpu = draft_input.verified_id.tolist() low = 0 for req, verified_len in zip(batch.reqs, accept_length_cpu): @@ -449,14 +540,20 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool) evict_mask[accept_index] = False mem_need_free_idx = batch.out_cache_loc[evict_mask] - + + # TODO: support batch inference @kavioyu batch.req_to_token_pool.req_to_token[ batch.req_pool_indices, - batch.seq_lens : batch.seq_lens + self.draft_token_num, - ] = batch.out_cache_loc + batch.seq_lens : batch.seq_lens + accept_length+1, + ] = batch.out_cache_loc[accept_index] + batch.token_to_kv_pool.free(mem_need_free_idx) - batch.spec_info.evict_mask = evict_mask + #batch.spec_info.evict_mask = evict_mask + batch.seq_lens.add_(accept_length+1) + + #print(batch.req_to_token_pool.req_to_token[0][:60]) - return batch.spec_info.verified_id + logits_output.next_token_logits = logits_output.next_token_logits[accept_index] + return draft_input, logits_output diff --git a/python/sglang/srt/speculative/speculative_worker.py b/python/sglang/srt/speculative/speculative_worker.py index a2022839c87..228e4aa4885 100644 --- a/python/sglang/srt/speculative/speculative_worker.py +++ b/python/sglang/srt/speculative/speculative_worker.py @@ -21,6 +21,9 @@ def __init__( def forward_batch_speculative_generate(self, batch: ScheduleBatch): raise NotImplementedError() + def post_decode_process(self, batch: ScheduleBatch): + # do nothing by default + pass class SpecWorkerFactory: def __init__(self): From cb01c64282846d1c8cc67bd975b4daebe44e58ca Mon Sep 17 00:00:00 2001 From: kavioyu Date: Wed, 16 Oct 2024 16:26:29 +0800 Subject: [PATCH 04/26] fix bug for long generate due to eagle_verify_retrive kernel --- python/sglang/srt/layers/logits_processor.py | 5 ++++ .../srt/speculative/speculative_utils.py | 28 +++++++++++++++++-- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 8fe88bd4fe7..14b481d8c3a 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -184,6 +184,11 @@ def forward( last_hidden = hidden_states else: last_index = 
torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1 + # print('&&&') + # print(logits_metadata.extend_seq_lens) + # print(last_index) + # print(hidden_states.shape) + # print('&&&') last_hidden = hidden_states[last_index] if spec_info: diff --git a/python/sglang/srt/speculative/speculative_utils.py b/python/sglang/srt/speculative/speculative_utils.py index 387803113be..5eb534ee60c 100644 --- a/python/sglang/srt/speculative/speculative_utils.py +++ b/python/sglang/srt/speculative/speculative_utils.py @@ -76,15 +76,27 @@ def eagle_verify_retrive( accept_len_list = tl.load( accept_ptr + accept_offset, mask=accept_load_mask, other=-1 ) - max_index = tl.argmax(accept_len_list, axis=0) + accept_len = tl.max(accept_len_list) + max_index = tl.argmax(accept_len_list, axis=0, tie_break_left=True) + # triton is not support argmax with tie_break_right, so I need implement it by some way + mask_max = accept_len_list == accept_len + + count_mask = tl.full(shape=[draft_token_num], value=0, dtype=tl.int32) + + count = tl.sum(tl.where(mask_max, 1, count_mask)) + if count>1: + index = tl.arange(0, draft_token_num) + mask_left = index != max_index + remained_index = tl.where(mask_max and mask_left, index, 0) + max_index = tl.max(remained_index) tl.store(accept_length + pid, accept_len) retrive_index_ptr = retrive_index + (retrive_start + max_index) * max_len retrive_offset = tl.arange(0, max_len_upper) retrive_load_mask = retrive_offset < accept_len + 1 data = tl.load(retrive_index_ptr + retrive_offset, mask=retrive_load_mask) - + tl.store( accept_index + pid * max_len + retrive_offset, data, mask=retrive_load_mask ) @@ -488,6 +500,7 @@ def generate_attn_arg( def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor: predict = torch.argmax(logits_output.next_token_logits, dim=-1) + predict = torch.cat([predict, torch.full([1], -1, dtype=torch.long, device='cuda')], dim=-1) target_predict = predict[self.retrive_index] candidates = self.draft_token[self.retrive_index] # logits = logits_output.next_token_logits[self.retrive_index] @@ -514,6 +527,7 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten self.draft_token_num, triton.next_power_of_2(max_draft_len), ) + old_accept_index = accept_index accept_index = accept_index[accept_index != -1] #extract_index = extract_index[extract_index != 0] @@ -528,6 +542,16 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten ] draft_input.accept_length = accept_length + + if accept_length.item() != accept_index.numel()-1: + print(target_predict) + print(candidates) + print(accept_index) + print(old_accept_index) + print(accept_length) + print(self.retrive_index) + print(accept_mask) + verified_id_cpu = draft_input.verified_id.tolist() From df3de9d63ea6bb4b578745fa8668fd39844207f7 Mon Sep 17 00:00:00 2001 From: kavioyu Date: Wed, 16 Oct 2024 19:21:20 +0800 Subject: [PATCH 05/26] fix bug of eagle spec verify --- .../srt/speculative/speculative_utils.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/python/sglang/srt/speculative/speculative_utils.py b/python/sglang/srt/speculative/speculative_utils.py index 5eb534ee60c..a7059397a22 100644 --- a/python/sglang/srt/speculative/speculative_utils.py +++ b/python/sglang/srt/speculative/speculative_utils.py @@ -83,7 +83,6 @@ def eagle_verify_retrive( mask_max = accept_len_list == accept_len count_mask = tl.full(shape=[draft_token_num], value=0, dtype=tl.int32) - count = tl.sum(tl.where(mask_max, 1, 
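The workaround being added here emulates a right-biased argmax: tl.argmax breaks ties toward the smaller index, so the kernel counts how many branches reach the maximal accept length and, when there is more than one, re-derives the largest index among them. The same intent in plain PyTorch, for reference only:

import torch

def argmax_tie_break_right(x: torch.Tensor) -> int:
    # Among all positions of a 1-D tensor attaining the maximum, return the last one.
    return int(torch.nonzero(x == x.max()).max())

assert argmax_tie_break_right(torch.tensor([2, 5, 5, 1])) == 2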
count_mask)) if count>1: index = tl.arange(0, draft_token_num) @@ -185,7 +184,7 @@ class EAGLEDraftInput(SpecDraftInput): def init(self, server_args: ServerArgs): self.prev_mode = ForwardMode.DECODE self.sample_output = None - self.topk: int = 10 + self.topk: int = 8 self.num_verify_token: int = server_args.num_draft_tokens self.spec_steps = server_args.num_speculative_steps @@ -501,8 +500,9 @@ def generate_attn_arg( def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor: predict = torch.argmax(logits_output.next_token_logits, dim=-1) predict = torch.cat([predict, torch.full([1], -1, dtype=torch.long, device='cuda')], dim=-1) + draft_token = torch.cat([self.draft_token, torch.full([1], -1, dtype=torch.long, device='cuda')], dim=-1) target_predict = predict[self.retrive_index] - candidates = self.draft_token[self.retrive_index] + candidates = draft_token[self.retrive_index] # logits = logits_output.next_token_logits[self.retrive_index] # target_predict = torch.argmax(logits[:, :-1], dim=-1) accept_mask = candidates[:, 1:] == target_predict[:, :-1] @@ -515,7 +515,6 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten ) accept_length = torch.empty((bs,), dtype=torch.int, device="cuda") extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda") - eagle_verify_retrive[(bs,)]( self.retrive_index.contiguous(), accept_mask.contiguous(), @@ -542,18 +541,6 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten ] draft_input.accept_length = accept_length - - if accept_length.item() != accept_index.numel()-1: - print(target_predict) - print(candidates) - print(accept_index) - print(old_accept_index) - print(accept_length) - print(self.retrive_index) - print(accept_mask) - - - verified_id_cpu = draft_input.verified_id.tolist() low = 0 From b7628f2c2ed96f06727c1d792b12f3c889fd6414 Mon Sep 17 00:00:00 2001 From: kavioyu Date: Sat, 19 Oct 2024 14:49:58 +0800 Subject: [PATCH 06/26] support cuda graph --- .../sglang/srt/layers/attention/__init__.py | 4 +- .../layers/attention/flashinfer_backend.py | 15 +- .../srt/layers/attention/triton_backend.py | 4 +- python/sglang/srt/managers/scheduler.py | 1 + .../srt/model_executor/cuda_graph_runner.py | 73 ++- .../srt/model_executor/forward_batch_info.py | 7 +- .../sglang/srt/model_executor/model_runner.py | 6 +- python/sglang/srt/models/llama_eagle.py | 15 +- python/sglang/srt/speculative/__init__.py | 3 +- python/sglang/srt/speculative/eagle_utils.py | 525 ++++++++++++++++++ python/sglang/srt/speculative/eagle_worker.py | 26 +- .../srt/speculative/speculative_utils.py | 511 +---------------- 12 files changed, 632 insertions(+), 558 deletions(-) create mode 100644 python/sglang/srt/speculative/eagle_utils.py diff --git a/python/sglang/srt/layers/attention/__init__.py b/python/sglang/srt/layers/attention/__init__.py index 0ba039c320b..26ce3431bf6 100644 --- a/python/sglang/srt/layers/attention/__init__.py +++ b/python/sglang/srt/layers/attention/__init__.py @@ -18,13 +18,13 @@ def init_cuda_graph_state(self, max_bs: int): raise NotImplementedError() def init_forward_metadata_capture_cuda_graph( - self, bs: int, req_pool_indices, seq_lens + self, num_token: int, req_pool_indices, seq_lens ): """Init the metadata for a forward pass for capturing a cuda graph.""" raise NotImplementedError() def init_forward_metadata_replay_cuda_graph( - self, bs: int, req_pool_indices, seq_lens + self, bs: int, num_token: int, req_pool_indices, seq_lens, spec_info ): """Init the 
metadata for a forward pass for replying a cuda graph.""" raise NotImplementedError() diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 70e285dffaa..c98884eebc5 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -170,7 +170,7 @@ def init_cuda_graph_state(self, max_bs: int): ] def init_forward_metadata_capture_cuda_graph( - self, bs: int, req_pool_indices, seq_lens + self, num_token: int, req_pool_indices, seq_lens, spec_info ): decode_wrappers = [] for i in range(self.num_wrappers): @@ -180,9 +180,9 @@ def init_forward_metadata_capture_cuda_graph( "NHD", use_cuda_graph=True, use_tensor_cores=self.decode_use_tensor_cores, - paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[i][: bs + 1], + paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[i][: num_token + 1], paged_kv_indices_buffer=self.cuda_graph_kv_indices[i], - paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:bs], + paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:num_token], ) ) @@ -193,22 +193,25 @@ def init_forward_metadata_capture_cuda_graph( seq_lens, None, decode_wrappers, + spec_info=spec_info ) - self.cuda_graph_metadata[bs] = decode_wrappers + self.cuda_graph_metadata[num_token] = decode_wrappers self.forward_metadata = (False, False, None, decode_wrappers) def init_forward_metadata_replay_cuda_graph( - self, bs: int, req_pool_indices, seq_lens + self, bs: int, num_token: int, req_pool_indices, seq_lens, spec_info ): + # num_token == bs if not use speculative decoding with eagle2 update_flashinfer_indices( ForwardMode.DECODE, self.model_runner, req_pool_indices[:bs], seq_lens[:bs], None, - self.cuda_graph_metadata[bs], + self.cuda_graph_metadata[num_token], + spec_info=spec_info ) def get_cuda_graph_seq_len_fill_value(self): diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 82b9596bf16..b9e1a6c4196 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -79,7 +79,7 @@ def init_cuda_graph_state(self, max_bs: int): ) def init_forward_metadata_capture_cuda_graph( - self, bs: int, req_pool_indices, seq_lens + self, num_token: int, req_pool_indices, seq_lens ): self.forward_metadata = ( self.cuda_graph_start_loc, @@ -89,7 +89,7 @@ def init_forward_metadata_capture_cuda_graph( ) def init_forward_metadata_replay_cuda_graph( - self, bs: int, req_pool_indices, seq_lens + self, bs: int, num_token: int, req_pool_indices, seq_lens, spec_info ): self.cuda_graph_start_loc.zero_() self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 6d8e3c7f576..85b8e162443 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -138,6 +138,7 @@ def __init__( self.model_config.hf_config.architectures, self.server_args.is_embedding ) + # Launch a tensor parallel worker self.tp_worker = TpModelWorker( gpu_id=gpu_id, diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index cdf3a77c99f..2c72e197346 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -33,6 +33,7 @@ ) from 
sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.utils import monkey_patch_vllm_all_gather +from sglang.srt.speculative.speculative_utils import DraftInfoFactory if TYPE_CHECKING: from sglang.srt.model_executor.model_runner import ModelRunner @@ -107,6 +108,7 @@ def __init__(self, model_runner: "ModelRunner"): self.disable_padding = model_runner.server_args.disable_cuda_graph_padding # Batch sizes to capture + # For speculative decoding, it means number of input token if self.model_runner.server_args.disable_cuda_graph_padding: self.capture_bs = list(range(1, 32)) + [64, 128] else: @@ -115,10 +117,17 @@ def __init__(self, model_runner: "ModelRunner"): self.capture_bs = [ bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size ] + + if model_runner.server_args.speculative_algorithm == 'EAGLE' and model_runner.is_draft_runner: + # TODO: Support edit top_k in config @kavioyu + self.num_tokens = [bs * 8 for bs in self.capture_bs] + else: + self.num_tokens = [bs for bs in self.capture_bs] + self.compile_bs = ( [ bs - for bs in self.capture_bs + for bs in self.num_tokens if bs <= self.model_runner.server_args.max_torch_compile_bs ] if self.use_torch_compile @@ -127,7 +136,8 @@ def __init__(self, model_runner: "ModelRunner"): # Attention backend self.max_bs = max(self.capture_bs) - self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs) + self.max_num_token = max(self.num_tokens) + self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token) self.seq_len_fill_value = ( self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value() ) @@ -137,12 +147,17 @@ def __init__(self, model_runner: "ModelRunner"): # Common inputs with torch.device("cuda"): - self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32) + self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int32) self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32) self.seq_lens = torch.full( (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32 ) - self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32) + self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int32) + self.positions = torch.zeros((self.max_num_token, ), dtype=torch.int64) + + # speculative_inference + if self.model_runner.server_args.speculative_algorithm == 'EAGLE': + self.hidden_states = torch.zeros((self.max_num_token, self.model_runner.model_config.hidden_size), dtype=self.model_runner.dtype) # Capture try: @@ -166,7 +181,7 @@ def can_run(self, batch_size: int): def capture(self): with graph_capture() as graph_capture_context: self.stream = graph_capture_context.stream - for bs in self.capture_bs: + for bs, num_token in zip(self.capture_bs, self.num_tokens): with patch_model( self.model_runner.model, bs in self.compile_bs, @@ -175,23 +190,31 @@ def capture(self): ( graph, output_buffers, - ) = self.capture_one_batch_size(bs, forward) + ) = self.capture_one_batch_size(bs, num_token, forward) self.graphs[bs] = graph self.output_buffers[bs] = output_buffers - def capture_one_batch_size(self, bs: int, forward: Callable): + def capture_one_batch_size(self, bs: int, num_token: int, forward: Callable): graph = torch.cuda.CUDAGraph() stream = self.stream # Common inputs - input_ids = self.input_ids[:bs] + input_ids = self.input_ids[:num_token] req_pool_indices = self.req_pool_indices[:bs] seq_lens = self.seq_lens[:bs] - out_cache_loc = self.out_cache_loc[:bs] - + out_cache_loc = self.out_cache_loc[:num_token] + positions = 
self.positions[:num_token] + + spec_info = None + if self.model_runner.server_args.speculative_algorithm == 'EAGLE' and self.model_runner.is_draft_runner: + spec_info = DraftInfoFactory.get(self.model_runner.server_args.speculative_algorithm)() + spec_info.hidden_states = self.hidden_states[:num_token] + spec_info.positions = positions + spec_info.init(self.model_runner.server_args) + # Attention backend self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph( - bs, req_pool_indices, seq_lens + num_token, req_pool_indices, seq_lens, spec_info ) # Run and capture @@ -207,8 +230,10 @@ def run_once(): attn_backend=self.model_runner.attn_backend, out_cache_loc=out_cache_loc, return_logprob=False, - top_logprobs_nums=[0] * bs, - positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64), + top_logprobs_nums=[0] * num_token, + positions=positions, + #positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64), + spec_info=spec_info, ) return forward(input_ids, forward_batch.positions, forward_batch) @@ -236,23 +261,33 @@ def run_once(): def replay(self, forward_batch: ForwardBatch): assert forward_batch.out_cache_loc is not None raw_bs = forward_batch.batch_size + # In most case, raw_bs == num_token in decode stage. + # But for speculative, the token num maybe large than raw_bs + raw_num_token = forward_batch.input_ids.numel() # Pad index = bisect.bisect_left(self.capture_bs, raw_bs) bs = self.capture_bs[index] - if bs != raw_bs: + index = bisect.bisect_left(self.num_tokens, raw_bs) + num_token = self.num_tokens[index] + if bs != raw_num_token: self.seq_lens.fill_(self.seq_len_fill_value) self.out_cache_loc.zero_() # Common inputs - self.input_ids[:raw_bs] = forward_batch.input_ids + self.input_ids[:num_token] = forward_batch.input_ids self.req_pool_indices[:raw_bs] = forward_batch.req_pool_indices self.seq_lens[:raw_bs] = forward_batch.seq_lens - self.out_cache_loc[:raw_bs] = forward_batch.out_cache_loc + self.out_cache_loc[:num_token] = forward_batch.out_cache_loc + self.positions[:num_token] = forward_batch.positions + + # EAGLE speculative decoding + if isinstance(forward_batch.spec_info, DraftInfoFactory.get('EAGLE')): + self.hidden_states[:num_token] = forward_batch.spec_info.hidden_states # Attention backend self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph( - bs, self.req_pool_indices, self.seq_lens + bs, num_token, self.req_pool_indices, self.seq_lens, forward_batch.spec_info ) # Replay @@ -260,9 +295,9 @@ def replay(self, forward_batch: ForwardBatch): logits_output = self.output_buffers[bs] # Unpad - if bs != raw_bs: + if raw_num_token != num_token: logits_output = LogitsProcessorOutput( - next_token_logits=logits_output.next_token_logits[:raw_bs], + next_token_logits=logits_output.next_token_logits[:num_token], next_token_logprobs=None, normalized_prompt_logprobs=None, input_token_logprobs=None, diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 15e5d65b1b4..e3f887c0c96 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -75,6 +75,9 @@ def is_verify(self): def is_spec_extend(self): return self == ForwardMode.SPECEXTEND + + def is_cuda_graph(self): + return self == ForwardMode.DECODE @dataclass @@ -183,7 +186,9 @@ def init_new( ret.req_to_token_pool = model_runner.req_to_token_pool ret.token_to_kv_pool = model_runner.token_to_kv_pool ret.attn_backend = model_runner.attn_backend - 
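In the replay() change above, the batch is padded up to the nearest captured size so that one CUDA graph can serve many real token counts; bisect_left over the sorted capture list finds that size. Illustrative values only:

import bisect

capture_sizes = [1, 2, 4, 8, 16, 24, 32, 64, 128]   # assumed, sorted capture sizes
raw_num_token = 21
padded = capture_sizes[bisect.bisect_left(capture_sizes, raw_num_token)]
assert padded == 24   # smallest captured size >= the real token count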
model_runner.attn_backend.init_forward_metadata(ret) + + if not batch.forward_mode.is_decode(): + model_runner.attn_backend.init_forward_metadata(ret) # Init lora information if model_runner.server_args.lora_paths is not None: diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index a5946f693fb..7af76581f70 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -502,10 +502,10 @@ def init_cuda_graphs(self): def forward_decode(self, forward_batch: ForwardBatch): if self.cuda_graph_runner and self.cuda_graph_runner.can_run( - forward_batch.batch_size - ): + forward_batch.input_ids.numel() + ) and forward_batch.forward_mode.is_cuda_graph(): return self.cuda_graph_runner.replay(forward_batch) - + self.attn_backend.init_forward_metadata(forward_batch) return self.model.forward( forward_batch.input_ids, forward_batch.positions, forward_batch ) diff --git a/python/sglang/srt/models/llama_eagle.py b/python/sglang/srt/models/llama_eagle.py index 14c2848819c..5c3dd378296 100644 --- a/python/sglang/srt/models/llama_eagle.py +++ b/python/sglang/srt/models/llama_eagle.py @@ -290,6 +290,7 @@ def forward( (hidden_states, forward_batch.spec_info.hidden_states), dim=-1 ) ) + #hidden_states = forward_batch.spec_info.hidden_states residual = None for i in range(len(self.layers)): @@ -300,7 +301,6 @@ def forward( forward_batch, residual, ) - # hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -327,18 +327,7 @@ def forward( input_embeds: torch.Tensor = None, ) -> LogitsProcessorOutput: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) - logits_output = self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch - ) - if isinstance(logits_output, LogitsProcessorOutput): - logits = logits_output.next_token_logits - sample_output = torch.softmax( - logits, dim=-1 - ) # TODO: Support more sampling method @kavioyu - forward_batch.spec_info.capture_for_decode( - sample_output, forward_batch.forward_mode - ) - return logits_output + return hidden_states def get_hidden_dim(self, module_name): # return input_dim, output_dim diff --git a/python/sglang/srt/speculative/__init__.py b/python/sglang/srt/speculative/__init__.py index 0bd3481e6e2..f56d3b89ef1 100644 --- a/python/sglang/srt/speculative/__init__.py +++ b/python/sglang/srt/speculative/__init__.py @@ -1 +1,2 @@ -from .eagle_worker import EAGLEWorker \ No newline at end of file +from .eagle_worker import EAGLEWorker +from . 
import eagle_utils \ No newline at end of file diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py new file mode 100644 index 00000000000..3aed3faf176 --- /dev/null +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -0,0 +1,525 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, List, Type + +import torch +import triton +import triton.language as tl + +from .build_egale_tree import build_tree_kernel +from sglang.srt.model_executor.forward_batch_info import ForwardMode, ForwardBatch +from sglang.srt.speculative.speculative_utils import SpecDraftInput, SpecVerifyInput, DraftInfoFactory + +if TYPE_CHECKING: + from python.sglang.srt.layers.sampler import SampleOutput + from python.sglang.srt.managers.schedule_batch import ScheduleBatch + from sglang.srt.mem_cache.memory_pool import ReqToTokenPool + from sglang.srt.server_args import ServerArgs + + +# Copy from sglang.srt.layers.flashinfer_utils.create_flashinfer_kv_indices_triton due to import error +@triton.jit +def create_flashinfer_kv_indices_triton( + req_to_token_ptr, # [max_batch, max_context_len] + req_pool_indices_ptr, + page_kernel_lens_ptr, + kv_indptr, + kv_start_idx, + kv_indices_ptr, + max_context_len: tl.constexpr, +): + BLOCK_SIZE: tl.constexpr = 512 + pid = tl.program_id(axis=0) + req_pool_index = tl.load(req_pool_indices_ptr + pid) + kv_indices_offset = tl.load(kv_indptr + pid) + + kv_start = 0 + kv_end = 0 + if kv_start_idx: + kv_start = tl.load(kv_start_idx + pid).to(tl.int32) + kv_end = kv_start + kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32) + + req_to_token_ptr += req_pool_index * max_context_len + kv_indices_ptr += kv_indices_offset + + ld_offset = kv_start + tl.arange(0, BLOCK_SIZE) + st_offset = tl.arange(0, BLOCK_SIZE) + num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE) + for _ in range(num_loop): + mask = ld_offset < kv_end + data = tl.load(req_to_token_ptr + ld_offset, mask=mask) + tl.store(kv_indices_ptr + st_offset, data, mask=mask) + ld_offset += BLOCK_SIZE + st_offset += BLOCK_SIZE + + +@triton.jit +def eagle_verify_retrive( + retrive_index, + accept_mask, + retrive_cum_len, + accept_index, + accept_length, + extract_index, + max_len: tl.constexpr, + draft_token_num: tl.constexpr, + max_len_upper: tl.constexpr, +): + pid = tl.program_id(axis=0) + + retrive_end = tl.load(retrive_cum_len + pid + 1) + retrive_start = tl.load(retrive_cum_len + pid) + retrive_len = retrive_end - retrive_start + accept_ptr = accept_mask + retrive_start + accept_offset = tl.arange(0, draft_token_num) + accept_load_mask = accept_offset < retrive_len + accept_len_list = tl.load( + accept_ptr + accept_offset, mask=accept_load_mask, other=-1 + ) + + accept_len = tl.max(accept_len_list) + max_index = tl.argmax(accept_len_list, axis=0, tie_break_left=True) + # triton is not support argmax with tie_break_right, so I need implement it by some way + mask_max = accept_len_list == accept_len + + count_mask = tl.full(shape=[draft_token_num], value=0, dtype=tl.int32) + count = tl.sum(tl.where(mask_max, 1, count_mask)) + if count>1: + index = tl.arange(0, draft_token_num) + mask_left = index != max_index + remained_index = tl.where(mask_max and mask_left, index, 0) + max_index = tl.max(remained_index) + + tl.store(accept_length + pid, accept_len) + retrive_index_ptr = retrive_index + (retrive_start + max_index) * max_len + retrive_offset = tl.arange(0, max_len_upper) + retrive_load_mask = retrive_offset < accept_len + 1 + data = tl.load(retrive_index_ptr + 
retrive_offset, mask=retrive_load_mask) + + tl.store( + accept_index + pid * max_len + retrive_offset, data, mask=retrive_load_mask + ) + + extract_load_ptr = accept_index + pid * max_len + accept_len + if accept_len == max_len - 1: + extract_data = tl.load(extract_load_ptr - 1) + tl.store(extract_index + pid * 2, extract_data) + extract_data = tl.load(extract_load_ptr) + tl.store(extract_index + pid * 2 + 1, extract_data) + + else: + extract_data = tl.load(extract_load_ptr) + tl.store(extract_index + pid * 2, extract_data) + +@triton.jit +def create_extend_spec_info(verified_id, seq_len, accept_len, accept_len_cum, positions, new_verified_id, accept_len_upper: tl.constexpr): + pid = tl.program_id(axis=0) + offset = 0 if pid ==0 else tl.load(accept_len_cum+pid-1) + seq_length = tl.load(seq_len+pid) + accept_length = tl.load(accept_len+pid) + positions_ptr = positions+offset + data = tl.arange(0, accept_len_upper) + mask = data < accept_length + tl.store(positions_ptr+data, seq_length-accept_length+data, mask) + + offset = tl.load(accept_len_cum+pid)-1 + verified_id_data = tl.load(verified_id+offset) + tl.store(new_verified_id+pid, verified_id_data) + + +@DraftInfoFactory.register("EAGLE") +class EAGLEDraftInput(SpecDraftInput): + hidden_states: torch.Tensor = None + verified_id: torch.Tensor = None + positions: torch.Tensor = None + accept_length: torch.Tensor = None + + def init(self, server_args: ServerArgs): + self.prev_mode = ForwardMode.DECODE + self.sample_output = None + self.topk: int = 8 + self.num_verify_token: int = server_args.num_draft_tokens + self.spec_steps = server_args.num_speculative_steps + + self.scores: torch.Tensor = None + self.score_list: List[torch.Tensor] = [] + self.token_list: List[torch.Tensor] = [] + self.parents_list: List[torch.Tensor] = [] + self.cache_list: List[torch.Tenor] = [] + self.iter = 0 + self.root_token: int = None + assert self.topk <= 10, "topk should <= 10" + + def prepare_for_extend(self, batch: ForwardBatch): + req_pool_indices = batch.alloc_req_slots(len(batch.reqs)) + out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel()) + batch.out_cache_loc = out_cache_loc + + pt=0 + for i, req in enumerate(batch.reqs): + req.req_pool_idx = req_pool_indices[i] + pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids) + assert seq_len - pre_len == req.extend_input_len + + if pre_len > 0: + batch.req_to_token_pool.req_to_token[req.req_pool_idx][ + :pre_len + ] = req.prefix_indices + + batch.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = ( + out_cache_loc[pt : pt + req.extend_input_len] + ) + + pt += req.extend_input_len + + seq_lens = [0] + batch.extend_lens + input_ids = batch.input_ids.tolist() + verified_id = batch.spec_info.verified_id.tolist() + model_input_ids = [] + for i in range(len(seq_lens) - 1): + model_input_ids.extend( + input_ids[seq_lens[i] + 1 : seq_lens[i + 1]] + [verified_id[i]] + ) + batch.input_ids = torch.tensor( + model_input_ids, dtype=torch.int32, device="cuda" + ) + + def capture_for_decode(self, sample_output: SampleOutput, prev_mode: ForwardMode): + self.sample_output = sample_output + self.prev_mode = prev_mode + + def prepare_for_decode(self, batch: ScheduleBatch): + prob = self.sample_output # b * (1/topk), vocab + top = torch.topk(prob, self.topk, dim=-1) + topk_index, topk_p = top.indices, top.values # b * (1/topk), topk + if self.prev_mode == ForwardMode.DECODE: + scores = torch.mul( + self.scores.unsqueeze(2), topk_p.reshape(-1, self.topk, self.topk) + ) # (b, topk) mul (b * topk 
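The comment above describes the tree expansion step: cumulative branch scores of shape (b, topk) are multiplied against the new per-branch top-k probabilities viewed as (b, topk, topk), and the best topk of the topk*topk candidates survive. A shape walk-through with random tensors:

import torch

b, topk = 2, 4
scores = torch.rand(b, topk)                  # cumulative score of each live branch
topk_p = torch.rand(b * topk, topk)           # new top-k probs per live branch
expanded = scores.unsqueeze(2) * topk_p.reshape(-1, topk, topk)   # (b, topk, topk)
best = torch.topk(expanded.flatten(start_dim=1), topk, dim=-1)    # keep topk of topk*topk
parent_branch = best.indices // topk          # which existing branch each winner extends
child_slot = best.indices % topk              # which of that branch's candidates it is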
,topk) -> b, topk, topk + topk_cs = torch.topk( + scores.flatten(start_dim=1), self.topk, dim=-1 + ) # (b, topk) + topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values + self.scores = topk_cs_p + + selected_input_index = topk_cs_index.flatten() // self.topk # b* topk + + batch.spec_info.hidden_states = batch.spec_info.hidden_states[ + selected_input_index, : + ] + batch.input_ids = torch.gather( + topk_index.reshape(-1, self.topk**2), index=topk_cs_index, dim=1 + ).flatten() + batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel()) + self.score_list.append(scores) + self.token_list.append(topk_index) + self.parents_list.append( + topk_cs_index.flatten() + (self.topk**2 * (self.iter - 1) + self.topk) + ) + + elif self.prev_mode in (ForwardMode.EXTEND, ForwardMode.SPECEXTEND) : + self.scores = topk_p # b, top_k + self.score_list.append(topk_p.unsqueeze(1)) + self.token_list.append(topk_index) + batch.spec_info.hidden_states = ( + batch.spec_info.hidden_states.repeat_interleave(self.topk, 0) + ) + batch.input_ids = topk_index.flatten() + batch.out_cache_loc = batch.alloc_token_slots(topk_index.numel()) + self.parents_list.append( + torch.arange(-1, self.topk, dtype=torch.int, device="cuda") + ) + self.cache_list.append(batch.out_cache_loc) + self.positions = ( + batch.seq_lens[:, None] + + torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter + ).flatten() + + batch.req_to_token_pool.req_to_token[ + batch.req_pool_indices, + batch.seq_lens + + self.topk * self.iter : batch.seq_lens + + self.topk * (self.iter + 1), + ] = batch.out_cache_loc + self.iter += 1 + + def prepare_extend_after_decode(self, batch: ScheduleBatch): + #req_pool_indices = batch.alloc_req_slots(len(batch.reqs)) + batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel()) + batch.extend_lens = (self.accept_length+1).tolist() + + pt=0 + positions = [] + seq_lens = batch.seq_lens.tolist() + for i, req in enumerate(batch.reqs): + #assert seq_len - pre_len == req.extend_input_len + input_len = self.accept_length[i] + 1 + seq_len = seq_lens[i] + batch.req_to_token_pool.req_to_token[req.req_pool_idx][seq_len-input_len:seq_len] = ( + batch.out_cache_loc[pt : pt + input_len] + ) + pt += input_len + + + self.positions = torch.empty_like(self.verified_id) + new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long) + self.accept_length.add_(1) + + create_extend_spec_info[(self.accept_length.numel(),)](self.verified_id, batch.seq_lens, + self.accept_length, torch.cumsum(self.accept_length, axis=0, dtype=torch.int), + self.positions, new_verified_id, triton.next_power_of_2(self.spec_steps+1)) + + torch.save((self.verified_id, batch.seq_lens, self.accept_length, + torch.cumsum(self.accept_length, axis=0, dtype=torch.int)), 'test.pth') + + batch.input_ids = self.verified_id + self.verified_id = new_verified_id + + + def prepare_for_verify(self, batch: ScheduleBatch): + score_list = torch.cat(self.score_list, dim=1).view(-1) # b, 1/topk, topk + ss_token_list = torch.cat(self.token_list, dim=0).view( + -1 + ) # b * (self.topk+depth*self.topk) + top_scores = torch.topk(score_list, self.num_verify_token - 1, dim=-1) + top_scores_index = top_scores.indices + top_scores_index = torch.sort(top_scores_index).values + + draft_tokens = ss_token_list[top_scores_index] + draft_tokens = torch.cat((self.verified_id, draft_tokens), dim=0) + + parent_list = torch.cat(self.parents_list[:-1], dim=0) + + tree_mask, position, retrive_index, retrive_cum_len = build_tree_kernel( + parent_list, + 
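Just above, prepare_for_verify flattens every accumulated score, keeps the num_draft_tokens - 1 best nodes, and sorts their indices so that nodes produced in earlier draft steps (whose cumulative scores are never smaller than their descendants') come first; the verified root token is then prepended. A toy rendering with made-up scores:

import torch

flat_scores = torch.tensor([0.50, 0.20, 0.10, 0.40, 0.05, 0.30])  # assumed values
num_draft_tokens = 4
top = torch.topk(flat_scores, num_draft_tokens - 1)
kept = torch.sort(top.indices).values
print(kept)   # tensor([0, 3, 5]) -> draft order is preserved among the kept nodes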
top_scores_index, + batch.seq_lens, + self.topk, + self.iter - 1, + self.num_verify_token, + ) + + # out_cache = torch.cat(self.cache_list, dim=0) + # mem_need_free_idx = out_cache[top_scores_index] + # batch.token_to_kv_pool.free(mem_need_free_idx) + + return EagleVerifyInput( + draft_tokens, + tree_mask, + position, + retrive_index, + retrive_cum_len, + self.num_verify_token, + ) + + def prepare_new_draft_stage(self, batch: ScheduleBatch): + batch.input_ids = self.verified_id + + def generate_attn_arg( + self, + req_pool_indices: torch.Tensor, + paged_kernel_lens: torch.Tensor, + req_to_token_pool: ReqToTokenPool, + ): + req_pool_indices = req_pool_indices.tolist() + paged_kernel_lens = paged_kernel_lens.tolist() + bs = self.topk * len(req_pool_indices) + seq_len = self.positions.reshape(-1).contiguous() + + cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") + cum_kv_seq_len[1:] = torch.cumsum(seq_len + 1, dim=0) + kv_last_page_len = torch.ones((bs,), dtype=torch.int32, device="cuda") + kv_indices_list = [] + # TODO: reimplement it by triton if it is slow @kavioyu + for i in range(len(req_pool_indices)): + for k in range(self.topk): + index = torch.arange(self.iter) * self.topk + k + kv_indices_list.append( + req_to_token_pool.req_to_token[ + req_pool_indices[i], : paged_kernel_lens[i] + ] + ) + kv_indices_list.append( + req_to_token_pool.req_to_token[ + req_pool_indices[i], paged_kernel_lens[i] + index + ] + ) + kv_indices = torch.cat(kv_indices_list, dim=0).contiguous() + return kv_indices, cum_kv_seq_len, kv_last_page_len, None + + def clear(self): + self.iter = 0 + self.score_list.clear() + self.positions = None + + def clear_draft_cache(self, batch): + draft_cache = torch.cat(self.cache_list, dim=0) + batch.token_to_kv_pool.free(draft_cache) + + def generate_attn_arg_spec_extend( + self, + req_pool_indices: torch.Tensor, + paged_kernel_lens: torch.Tensor, + req_to_token_pool: ReqToTokenPool, + ): + bs = self.accept_length.numel() + qo_indptr = torch.zeros( + (bs + 1,), dtype=torch.int32, device="cuda" + ) + qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0) + + cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") + cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) + + kv_last_page_len = torch.ones((bs,), dtype=torch.int32, device="cuda") + + kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda") + + create_flashinfer_kv_indices_triton[(bs,)]( + req_to_token_pool.req_to_token, + req_pool_indices, + paged_kernel_lens, + cum_kv_seq_len, + None, + kv_indices, + req_to_token_pool.req_to_token.size(1), + ) + + return kv_indices, cum_kv_seq_len, kv_last_page_len, qo_indptr + + +class EagleVerifyInput(SpecVerifyInput): + def __init__( + self, + draft_token: torch.Tensor, + tree_mask: torch.Tensor, + positions: torch.Tensor, + retrive_index: torch.Tensor, + retrive_cum_len: torch.Tensor, + draft_token_num: int, + ): + self.draft_token = draft_token + self.custom_mask = tree_mask + self.positions = positions + self.retrive_index = retrive_index + self.retrive_cum_len = retrive_cum_len + self.draft_token_num = draft_token_num + + def prepare_for_verify(self, batch: ScheduleBatch): + batch.input_ids = self.draft_token + batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel()) + batch.req_to_token_pool.req_to_token[ + batch.req_pool_indices, + batch.seq_lens : batch.seq_lens + self.draft_token_num, + ] = batch.out_cache_loc + + def generate_attn_arg( + self, + req_pool_indices: torch.Tensor, + 
paged_kernel_lens: torch.Tensor, + req_to_token_pool: ReqToTokenPool, + ): + batch_size = len(req_pool_indices) + qo_indptr = torch.arange( + 0, + (1 + batch_size) * self.draft_token_num, + step=self.draft_token_num, + dtype=torch.int32, + device="cuda", + ) + + cum_kv_seq_len = torch.zeros( + (batch_size + 1,), dtype=torch.int32, device="cuda" + ) + + paged_kernel_lens = paged_kernel_lens.add_(self.draft_token_num) + cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) + + kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda") + + kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda") + + create_flashinfer_kv_indices_triton[(batch_size,)]( + req_to_token_pool.req_to_token, + req_pool_indices, + paged_kernel_lens, + cum_kv_seq_len, + None, + kv_indices, + req_to_token_pool.req_to_token.size(1), + ) + paged_kernel_lens = paged_kernel_lens.sub_(self.draft_token_num) + return kv_indices, cum_kv_seq_len, kv_last_page_len, qo_indptr + + def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor: + predict = torch.argmax(logits_output.next_token_logits, dim=-1) + predict = torch.cat([predict, torch.full([1], -1, dtype=torch.long, device='cuda')], dim=-1) + draft_token = torch.cat([self.draft_token, torch.full([1], -1, dtype=torch.long, device='cuda')], dim=-1) + target_predict = predict[self.retrive_index] + candidates = draft_token[self.retrive_index] + # logits = logits_output.next_token_logits[self.retrive_index] + # target_predict = torch.argmax(logits[:, :-1], dim=-1) + accept_mask = candidates[:, 1:] == target_predict[:, :-1] + accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1) + bs = self.retrive_cum_len.numel() - 1 + + max_draft_len = self.retrive_index.shape[-1] + accept_index = torch.full( + (bs, max_draft_len), -1, dtype=torch.long, device="cuda" + ) + accept_length = torch.empty((bs,), dtype=torch.int, device="cuda") + extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda") + eagle_verify_retrive[(bs,)]( + self.retrive_index.contiguous(), + accept_mask.contiguous(), + self.retrive_cum_len, + accept_index, + accept_length, + extract_index, + max_draft_len, + self.draft_token_num, + triton.next_power_of_2(max_draft_len), + ) + old_accept_index = accept_index + accept_index = accept_index[accept_index != -1] + #extract_index = extract_index[extract_index != 0] + + draft_input = EAGLEDraftInput() + + accept_length_cpu = accept_length.tolist() + + draft_input.verified_id = predict[accept_index] + + draft_input.hidden_states = batch.spec_info.hidden_states[ + accept_index + ] + draft_input.accept_length = accept_length + + verified_id_cpu = draft_input.verified_id.tolist() + + low = 0 + for req, verified_len in zip(batch.reqs, accept_length_cpu): + req.output_ids.extend(verified_id_cpu[low : low + verified_len + 1]) + low += verified_len + + evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool) + evict_mask[accept_index] = False + mem_need_free_idx = batch.out_cache_loc[evict_mask] + + # TODO: support batch inference @kavioyu + batch.req_to_token_pool.req_to_token[ + batch.req_pool_indices, + batch.seq_lens : batch.seq_lens + accept_length+1, + ] = batch.out_cache_loc[accept_index] + + + batch.token_to_kv_pool.free(mem_need_free_idx) + #batch.spec_info.evict_mask = evict_mask + batch.seq_lens.add_(accept_length+1) + + #print(batch.req_to_token_pool.req_to_token[0][:60]) + + logits_output.next_token_logits = logits_output.next_token_logits[accept_index] + return 
draft_input, logits_output + diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 8b951f4188a..7336f0ee383 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -4,7 +4,8 @@ from sglang.srt.speculative.speculative_worker import SpeculativeWorker, spec_worker_factory from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode -from sglang.srt.speculative.speculative_utils import EAGLEDraftInput, EagleVerifyInput +from sglang.srt.speculative.eagle_utils import EAGLEDraftInput, EagleVerifyInput +from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.model_executor.model_runner import ModelRunner @spec_worker_factory.register('EAGLE') @@ -17,16 +18,21 @@ def __init__( nccl_port: int, target_worker: TpModelWorker ): + disable_cuda_graph = server_args.disable_cuda_graph + server_args.disable_cuda_graph = True super().__init__(gpu_id=gpu_id, tp_rank=tp_rank, server_args=server_args, nccl_port=nccl_port, target_worker=target_worker) embed, head = self.target_worker.model_runner.model.get_embed_and_head() self.model_runner.model.set_embed_and_head(embed, head) + self.model_runner.server_args.disable_cuda_graph = disable_cuda_graph + self.model_runner.init_cuda_graphs() def forward_draft_decode(self, batch: ScheduleBatch): batch.spec_info.prepare_for_decode(batch) model_worker_batch = batch.get_model_worker_batch() forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) forward_batch.is_draft_batch = True - self.model_runner.forward(forward_batch) + logits_output = self.model_runner.forward(forward_batch) + self.capture_for_decode(logits_output, forward_batch) def forward_draft_extend(self, batch: ScheduleBatch): self._swap_mem_pool(batch, self.model_runner) @@ -34,6 +40,7 @@ def forward_draft_extend(self, batch: ScheduleBatch): model_worker_batch = batch.get_model_worker_batch() forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) logits_output = self.model_runner.forward(forward_batch) + self.capture_for_decode(logits_output, forward_batch) self._swap_mem_pool(batch, self.target_worker.model_runner) def forward_batch_speculative_generate(self, batch: ScheduleBatch): @@ -82,10 +89,25 @@ def forward_extend_after_decode(self, batch: ScheduleBatch): forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) forward_batch.is_draft_batch = True logits_output = self.model_runner.forward(forward_batch) + self.capture_for_decode(logits_output, forward_batch) batch.forward_mode = ForwardMode.DECODE self._swap_mem_pool(batch, self.model_runner) def post_decode_process(self, batch): return self.forward_extend_after_decode(batch) + def capture_for_decode(self, hidden_states, forward_batch): + # lm head is not support cuda graph currently. But it could be support theoretically. + # TODO: Support it. 
@kavioyu + logits_output = self.model_runner.model.logits_processor( + None, hidden_states, self.model_runner.model.lm_head.weight, forward_batch + ) + if isinstance(logits_output, LogitsProcessorOutput): + logits = logits_output.next_token_logits + sample_output = torch.softmax( + logits, dim=-1 + ) # TODO: Support more sampling method @kavioyu + forward_batch.spec_info.capture_for_decode( + sample_output, forward_batch.forward_mode + ) \ No newline at end of file diff --git a/python/sglang/srt/speculative/speculative_utils.py b/python/sglang/srt/speculative/speculative_utils.py index a7059397a22..005282bd3c6 100644 --- a/python/sglang/srt/speculative/speculative_utils.py +++ b/python/sglang/srt/speculative/speculative_utils.py @@ -16,120 +16,11 @@ from sglang.srt.server_args import ServerArgs -# Copy from sglang.srt.layers.flashinfer_utils.create_flashinfer_kv_indices_triton due to import error -@triton.jit -def create_flashinfer_kv_indices_triton( - req_to_token_ptr, # [max_batch, max_context_len] - req_pool_indices_ptr, - page_kernel_lens_ptr, - kv_indptr, - kv_start_idx, - kv_indices_ptr, - max_context_len: tl.constexpr, -): - BLOCK_SIZE: tl.constexpr = 512 - pid = tl.program_id(axis=0) - req_pool_index = tl.load(req_pool_indices_ptr + pid) - kv_indices_offset = tl.load(kv_indptr + pid) - - kv_start = 0 - kv_end = 0 - if kv_start_idx: - kv_start = tl.load(kv_start_idx + pid).to(tl.int32) - kv_end = kv_start - kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32) - - req_to_token_ptr += req_pool_index * max_context_len - kv_indices_ptr += kv_indices_offset - - ld_offset = kv_start + tl.arange(0, BLOCK_SIZE) - st_offset = tl.arange(0, BLOCK_SIZE) - num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE) - for _ in range(num_loop): - mask = ld_offset < kv_end - data = tl.load(req_to_token_ptr + ld_offset, mask=mask) - tl.store(kv_indices_ptr + st_offset, data, mask=mask) - ld_offset += BLOCK_SIZE - st_offset += BLOCK_SIZE - - -@triton.jit -def eagle_verify_retrive( - retrive_index, - accept_mask, - retrive_cum_len, - accept_index, - accept_length, - extract_index, - max_len: tl.constexpr, - draft_token_num: tl.constexpr, - max_len_upper: tl.constexpr, -): - pid = tl.program_id(axis=0) - - retrive_end = tl.load(retrive_cum_len + pid + 1) - retrive_start = tl.load(retrive_cum_len + pid) - retrive_len = retrive_end - retrive_start - accept_ptr = accept_mask + retrive_start - accept_offset = tl.arange(0, draft_token_num) - accept_load_mask = accept_offset < retrive_len - accept_len_list = tl.load( - accept_ptr + accept_offset, mask=accept_load_mask, other=-1 - ) - - accept_len = tl.max(accept_len_list) - max_index = tl.argmax(accept_len_list, axis=0, tie_break_left=True) - # triton is not support argmax with tie_break_right, so I need implement it by some way - mask_max = accept_len_list == accept_len - - count_mask = tl.full(shape=[draft_token_num], value=0, dtype=tl.int32) - count = tl.sum(tl.where(mask_max, 1, count_mask)) - if count>1: - index = tl.arange(0, draft_token_num) - mask_left = index != max_index - remained_index = tl.where(mask_max and mask_left, index, 0) - max_index = tl.max(remained_index) - - tl.store(accept_length + pid, accept_len) - retrive_index_ptr = retrive_index + (retrive_start + max_index) * max_len - retrive_offset = tl.arange(0, max_len_upper) - retrive_load_mask = retrive_offset < accept_len + 1 - data = tl.load(retrive_index_ptr + retrive_offset, mask=retrive_load_mask) - - tl.store( - accept_index + pid * max_len + retrive_offset, data, mask=retrive_load_mask 
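# ---- Editor's aside: illustrative sketch, not part of this patch ----
# The accept_mask consumed by this retrieval kernel is produced in
# EagleVerifyInput.verify() (shown elsewhere in this series) with a
# cumulative-product trick: a draft token on a path counts only if every earlier
# token on the same path was also accepted. A self-contained toy with assumed
# values; the .int() cast is only for clarity:
import torch

candidates = torch.tensor([[7, 3, 9, 4]])      # draft tokens along one path, root first
target_predict = torch.tensor([[3, 9, 1, 5]])  # target model's greedy choice at each node
match = candidates[:, 1:] == target_predict[:, :-1]        # tensor([[True, True, False]])
accept_len = torch.cumprod(match.int(), dim=1).sum(dim=1)  # tensor([2]); stops at first mismatch
# ----------------------------------------------------------------------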
- ) - - extract_load_ptr = accept_index + pid * max_len + accept_len - if accept_len == max_len - 1: - extract_data = tl.load(extract_load_ptr - 1) - tl.store(extract_index + pid * 2, extract_data) - extract_data = tl.load(extract_load_ptr) - tl.store(extract_index + pid * 2 + 1, extract_data) - - else: - extract_data = tl.load(extract_load_ptr) - tl.store(extract_index + pid * 2, extract_data) - -@triton.jit -def create_extend_spec_info(verified_id, seq_len, accept_len, accept_len_cum, positions, new_verified_id, accept_len_upper: tl.constexpr): - pid = tl.program_id(axis=0) - offset = 0 if pid ==0 else tl.load(accept_len_cum+pid-1) - seq_length = tl.load(seq_len+pid) - accept_length = tl.load(accept_len+pid) - positions_ptr = positions+offset - data = tl.arange(0, accept_len_upper) - mask = data < accept_length - tl.store(positions_ptr+data, seq_length-accept_length+data, mask) - - offset = tl.load(accept_len_cum+pid)-1 - verified_id_data = tl.load(verified_id+offset) - tl.store(new_verified_id+pid, verified_id_data) - - class SpecInput: pass +class SpecVerifyInput(SpecInput): + pass class SpecDraftInput(SpecInput): def prepare_for_extend(self, batch): @@ -150,10 +41,6 @@ def clear(): pass -class SpecVerifyInput(SpecInput): - pass - - class SpecDraftInfoFactory: def __init__(self): self.factory = {} @@ -174,397 +61,3 @@ def get(self, name): DraftInfoFactory = SpecDraftInfoFactory() -@DraftInfoFactory.register("EAGLE") -class EAGLEDraftInput(SpecDraftInput): - hidden_states: torch.Tensor = None - verified_id: torch.Tensor = None - positions: torch.Tensor = None - accept_length: torch.Tensor = None - - def init(self, server_args: ServerArgs): - self.prev_mode = ForwardMode.DECODE - self.sample_output = None - self.topk: int = 8 - self.num_verify_token: int = server_args.num_draft_tokens - self.spec_steps = server_args.num_speculative_steps - - self.scores: torch.Tensor = None - self.score_list: List[torch.Tensor] = [] - self.token_list: List[torch.Tensor] = [] - self.parents_list: List[torch.Tensor] = [] - self.cache_list: List[torch.Tenor] = [] - self.iter = 0 - self.root_token: int = None - assert self.topk <= 10, "topk should <= 10" - - def prepare_for_extend(self, batch: ForwardBatch): - req_pool_indices = batch.alloc_req_slots(len(batch.reqs)) - out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel()) - batch.out_cache_loc = out_cache_loc - - pt=0 - for i, req in enumerate(batch.reqs): - req.req_pool_idx = req_pool_indices[i] - pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids) - assert seq_len - pre_len == req.extend_input_len - - if pre_len > 0: - batch.req_to_token_pool.req_to_token[req.req_pool_idx][ - :pre_len - ] = req.prefix_indices - - batch.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = ( - out_cache_loc[pt : pt + req.extend_input_len] - ) - - pt += req.extend_input_len - - seq_lens = [0] + batch.extend_lens - input_ids = batch.input_ids.tolist() - verified_id = batch.spec_info.verified_id.tolist() - model_input_ids = [] - for i in range(len(seq_lens) - 1): - model_input_ids.extend( - input_ids[seq_lens[i] + 1 : seq_lens[i + 1]] + [verified_id[i]] - ) - batch.input_ids = torch.tensor( - model_input_ids, dtype=torch.int32, device="cuda" - ) - - def capture_for_decode(self, sample_output: SampleOutput, prev_mode: ForwardMode): - self.sample_output = sample_output - self.prev_mode = prev_mode - - def prepare_for_decode(self, batch: ScheduleBatch): - prob = self.sample_output # b * (1/topk), vocab - top = torch.topk(prob, self.topk, dim=-1) 
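# ---- Editor's aside: illustrative sketch, not part of this patch ----
# The decode branch below scores each candidate continuation by multiplying the
# parent path's cumulative score with the child's draft probability, then keeps
# only the topk best of the topk*topk candidates per request. A self-contained
# toy of that expansion, assuming batch size b, beam width topk and vocab size V:
import torch

b, topk, V = 2, 4, 1000
prev_scores = torch.rand(b, topk)                        # cumulative scores of current paths
probs = torch.softmax(torch.randn(b * topk, V), dim=-1)  # draft output, one row per path

child = torch.topk(probs, topk, dim=-1)                  # expand every path by topk children
topk_p, topk_index = child.values, child.indices         # both are (b * topk, topk)

scores = prev_scores.unsqueeze(2) * topk_p.reshape(b, topk, topk)  # joint scores (b, topk, topk)
best = torch.topk(scores.flatten(start_dim=1), topk, dim=-1)       # best topk of topk*topk
parent_of_survivor = best.indices // topk                # which parent each survivor extends
next_input_ids = torch.gather(topk_index.reshape(b, topk * topk), 1, best.indices)
# ----------------------------------------------------------------------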
- topk_index, topk_p = top.indices, top.values # b * (1/topk), topk - if self.prev_mode == ForwardMode.DECODE: - scores = torch.mul( - self.scores.unsqueeze(2), topk_p.reshape(-1, self.topk, self.topk) - ) # (b, topk) mul (b * topk ,topk) -> b, topk, topk - topk_cs = torch.topk( - scores.flatten(start_dim=1), self.topk, dim=-1 - ) # (b, topk) - topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values - self.scores = topk_cs_p - - selected_input_index = topk_cs_index.flatten() // self.topk # b* topk - - batch.spec_info.hidden_states = batch.spec_info.hidden_states[ - selected_input_index, : - ] - batch.input_ids = torch.gather( - topk_index.reshape(-1, self.topk**2), index=topk_cs_index, dim=1 - ).flatten() - batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel()) - self.score_list.append(scores) - self.token_list.append(topk_index) - self.parents_list.append( - topk_cs_index.flatten() + (self.topk**2 * (self.iter - 1) + self.topk) - ) - - elif self.prev_mode in (ForwardMode.EXTEND, ForwardMode.SPECEXTEND) : - self.scores = topk_p # b, top_k - self.score_list.append(topk_p.unsqueeze(1)) - self.token_list.append(topk_index) - batch.spec_info.hidden_states = ( - batch.spec_info.hidden_states.repeat_interleave(self.topk, 0) - ) - batch.input_ids = topk_index.flatten() - batch.out_cache_loc = batch.alloc_token_slots(topk_index.numel()) - self.parents_list.append( - torch.arange(-1, self.topk, dtype=torch.int, device="cuda") - ) - self.cache_list.append(batch.out_cache_loc) - self.positions = ( - batch.seq_lens[:, None] - + torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter - ).flatten() - - batch.req_to_token_pool.req_to_token[ - batch.req_pool_indices, - batch.seq_lens - + self.topk * self.iter : batch.seq_lens - + self.topk * (self.iter + 1), - ] = batch.out_cache_loc - self.iter += 1 - - def prepare_extend_after_decode(self, batch: ScheduleBatch): - #req_pool_indices = batch.alloc_req_slots(len(batch.reqs)) - batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel()) - batch.extend_lens = (self.accept_length+1).tolist() - - pt=0 - positions = [] - seq_lens = batch.seq_lens.tolist() - for i, req in enumerate(batch.reqs): - #assert seq_len - pre_len == req.extend_input_len - input_len = self.accept_length[i] + 1 - seq_len = seq_lens[i] - batch.req_to_token_pool.req_to_token[req.req_pool_idx][seq_len-input_len:seq_len] = ( - batch.out_cache_loc[pt : pt + input_len] - ) - pt += input_len - - - self.positions = torch.empty_like(self.verified_id) - new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long) - self.accept_length.add_(1) - - create_extend_spec_info[(self.accept_length.numel(),)](self.verified_id, batch.seq_lens, - self.accept_length, torch.cumsum(self.accept_length, axis=0, dtype=torch.int), - self.positions, new_verified_id, triton.next_power_of_2(self.spec_steps+1)) - - torch.save((self.verified_id, batch.seq_lens, self.accept_length, - torch.cumsum(self.accept_length, axis=0, dtype=torch.int)), 'test.pth') - - batch.input_ids = self.verified_id - self.verified_id = new_verified_id - - - def prepare_for_verify(self, batch: ScheduleBatch): - score_list = torch.cat(self.score_list, dim=1).view(-1) # b, 1/topk, topk - ss_token_list = torch.cat(self.token_list, dim=0).view( - -1 - ) # b * (self.topk+depth*self.topk) - top_scores = torch.topk(score_list, self.num_verify_token - 1, dim=-1) - top_scores_index = top_scores.indices - top_scores_index = torch.sort(top_scores_index).values - - draft_tokens = 
ss_token_list[top_scores_index] - draft_tokens = torch.cat((self.verified_id, draft_tokens), dim=0) - - parent_list = torch.cat(self.parents_list[:-1], dim=0) - - tree_mask, position, retrive_index, retrive_cum_len = build_tree_kernel( - parent_list, - top_scores_index, - batch.seq_lens, - self.topk, - self.iter - 1, - self.num_verify_token, - ) - - # out_cache = torch.cat(self.cache_list, dim=0) - # mem_need_free_idx = out_cache[top_scores_index] - # batch.token_to_kv_pool.free(mem_need_free_idx) - - return EagleVerifyInput( - draft_tokens, - tree_mask, - position, - retrive_index, - retrive_cum_len, - self.num_verify_token, - ) - - def prepare_new_draft_stage(self, batch: ScheduleBatch): - batch.input_ids = self.verified_id - - def generate_attn_arg( - self, - req_pool_indices: torch.Tensor, - paged_kernel_lens: torch.Tensor, - req_to_token_pool: ReqToTokenPool, - ): - req_pool_indices = req_pool_indices.tolist() - paged_kernel_lens = paged_kernel_lens.tolist() - bs = self.topk * len(req_pool_indices) - seq_len = self.positions.reshape(-1).contiguous() - cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") - cum_kv_seq_len[1:] = torch.cumsum(seq_len + 1, dim=0) - kv_last_page_len = torch.ones((bs,), dtype=torch.int32, device="cuda") - kv_indices_list = [] - # TODO: reimplement it by triton if it is slow @kavioyu - for i in range(len(req_pool_indices)): - for k in range(self.topk): - index = torch.arange(self.iter) * self.topk + k - kv_indices_list.append( - req_to_token_pool.req_to_token[ - req_pool_indices[i], : paged_kernel_lens[i] - ] - ) - kv_indices_list.append( - req_to_token_pool.req_to_token[ - req_pool_indices[i], paged_kernel_lens[i] + index - ] - ) - kv_indices = torch.cat(kv_indices_list, dim=0).contiguous() - return kv_indices, cum_kv_seq_len, kv_last_page_len, None - - def clear(self): - self.iter = 0 - self.score_list.clear() - self.positions = None - - def clear_draft_cache(self, batch): - draft_cache = torch.cat(self.cache_list, dim=0) - batch.token_to_kv_pool.free(draft_cache) - - def generate_attn_arg_spec_extend( - self, - req_pool_indices: torch.Tensor, - paged_kernel_lens: torch.Tensor, - req_to_token_pool: ReqToTokenPool, - ): - bs = self.accept_length.numel() - qo_indptr = torch.zeros( - (bs + 1,), dtype=torch.int32, device="cuda" - ) - qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0) - - cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") - cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) - - kv_last_page_len = torch.ones((bs,), dtype=torch.int32, device="cuda") - - kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda") - - create_flashinfer_kv_indices_triton[(bs,)]( - req_to_token_pool.req_to_token, - req_pool_indices, - paged_kernel_lens, - cum_kv_seq_len, - None, - kv_indices, - req_to_token_pool.req_to_token.size(1), - ) - - return kv_indices, cum_kv_seq_len, kv_last_page_len, qo_indptr - - -class EagleVerifyInput(SpecVerifyInput): - def __init__( - self, - draft_token: torch.Tensor, - tree_mask: torch.Tensor, - positions: torch.Tensor, - retrive_index: torch.Tensor, - retrive_cum_len: torch.Tensor, - draft_token_num: int, - ): - self.draft_token = draft_token - self.custom_mask = tree_mask - self.positions = positions - self.retrive_index = retrive_index - self.retrive_cum_len = retrive_cum_len - self.draft_token_num = draft_token_num - - def prepare_for_verify(self, batch: ScheduleBatch): - batch.input_ids = self.draft_token - batch.out_cache_loc = 
batch.alloc_token_slots(batch.input_ids.numel()) - batch.req_to_token_pool.req_to_token[ - batch.req_pool_indices, - batch.seq_lens : batch.seq_lens + self.draft_token_num, - ] = batch.out_cache_loc - - def generate_attn_arg( - self, - req_pool_indices: torch.Tensor, - paged_kernel_lens: torch.Tensor, - req_to_token_pool: ReqToTokenPool, - ): - batch_size = len(req_pool_indices) - qo_indptr = torch.arange( - 0, - (1 + batch_size) * self.draft_token_num, - step=self.draft_token_num, - dtype=torch.int32, - device="cuda", - ) - - cum_kv_seq_len = torch.zeros( - (batch_size + 1,), dtype=torch.int32, device="cuda" - ) - - paged_kernel_lens = paged_kernel_lens.add_(self.draft_token_num) - cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) - - kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda") - - kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda") - - create_flashinfer_kv_indices_triton[(batch_size,)]( - req_to_token_pool.req_to_token, - req_pool_indices, - paged_kernel_lens, - cum_kv_seq_len, - None, - kv_indices, - req_to_token_pool.req_to_token.size(1), - ) - paged_kernel_lens = paged_kernel_lens.sub_(self.draft_token_num) - return kv_indices, cum_kv_seq_len, kv_last_page_len, qo_indptr - - def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor: - predict = torch.argmax(logits_output.next_token_logits, dim=-1) - predict = torch.cat([predict, torch.full([1], -1, dtype=torch.long, device='cuda')], dim=-1) - draft_token = torch.cat([self.draft_token, torch.full([1], -1, dtype=torch.long, device='cuda')], dim=-1) - target_predict = predict[self.retrive_index] - candidates = draft_token[self.retrive_index] - # logits = logits_output.next_token_logits[self.retrive_index] - # target_predict = torch.argmax(logits[:, :-1], dim=-1) - accept_mask = candidates[:, 1:] == target_predict[:, :-1] - accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1) - bs = self.retrive_cum_len.numel() - 1 - - max_draft_len = self.retrive_index.shape[-1] - accept_index = torch.full( - (bs, max_draft_len), -1, dtype=torch.long, device="cuda" - ) - accept_length = torch.empty((bs,), dtype=torch.int, device="cuda") - extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda") - eagle_verify_retrive[(bs,)]( - self.retrive_index.contiguous(), - accept_mask.contiguous(), - self.retrive_cum_len, - accept_index, - accept_length, - extract_index, - max_draft_len, - self.draft_token_num, - triton.next_power_of_2(max_draft_len), - ) - old_accept_index = accept_index - accept_index = accept_index[accept_index != -1] - #extract_index = extract_index[extract_index != 0] - - draft_input = EAGLEDraftInput() - - accept_length_cpu = accept_length.tolist() - - draft_input.verified_id = predict[accept_index] - - draft_input.hidden_states = batch.spec_info.hidden_states[ - accept_index - ] - draft_input.accept_length = accept_length - - verified_id_cpu = draft_input.verified_id.tolist() - - low = 0 - for req, verified_len in zip(batch.reqs, accept_length_cpu): - req.output_ids.extend(verified_id_cpu[low : low + verified_len + 1]) - low += verified_len - - evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool) - evict_mask[accept_index] = False - mem_need_free_idx = batch.out_cache_loc[evict_mask] - - # TODO: support batch inference @kavioyu - batch.req_to_token_pool.req_to_token[ - batch.req_pool_indices, - batch.seq_lens : batch.seq_lens + accept_length+1, - ] = batch.out_cache_loc[accept_index] - - - 
batch.token_to_kv_pool.free(mem_need_free_idx) - #batch.spec_info.evict_mask = evict_mask - batch.seq_lens.add_(accept_length+1) - - #print(batch.req_to_token_pool.req_to_token[0][:60]) - - logits_output.next_token_logits = logits_output.next_token_logits[accept_index] - return draft_input, logits_output - From e2634e962471d6625e69db6d1875e93dfebc7d91 Mon Sep 17 00:00:00 2001 From: kavioyu Date: Tue, 22 Oct 2024 19:28:52 +0800 Subject: [PATCH 07/26] support batch inference --- python/sglang/srt/layers/logits_processor.py | 5 - python/sglang/srt/managers/schedule_batch.py | 3 + python/sglang/srt/managers/scheduler.py | 2 +- python/sglang/srt/managers/tp_worker.py | 7 +- .../srt/model_executor/cuda_graph_runner.py | 6 +- .../srt/model_executor/forward_batch_info.py | 2 +- python/sglang/srt/server_args.py | 16 +- ...uild_egale_tree.py => build_eagle_tree.py} | 16 +- python/sglang/srt/speculative/eagle_utils.py | 179 ++++++++++++------ python/sglang/srt/speculative/eagle_worker.py | 16 +- .../srt/speculative/speculative_utils.py | 5 +- 11 files changed, 177 insertions(+), 80 deletions(-) rename python/sglang/srt/speculative/{build_egale_tree.py => build_eagle_tree.py} (94%) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 14b481d8c3a..8fe88bd4fe7 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -184,11 +184,6 @@ def forward( last_hidden = hidden_states else: last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1 - # print('&&&') - # print(logits_metadata.extend_seq_lens) - # print(last_index) - # print(hidden_states.shape) - # print('&&&') last_hidden = hidden_states[last_index] if spec_info: diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 51a5d406f49..37c959ee14f 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -804,6 +804,9 @@ def merge_batch(self, other: "ScheduleBatch"): self.return_logprob = self.return_logprob or other.return_logprob self.has_stream = self.has_stream or other.has_stream self.has_regex = self.has_regex or other.has_regex + + if self.spec_info is not None: + self.spec_info.merge_batch(other.spec_info) def get_model_worker_batch(self): if self.forward_mode.is_decode(): diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 85b8e162443..2a506a775b2 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -793,7 +793,7 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result): req.completion_tokens_wo_jump_forward += 1 if batch.spec_algorithm is None: # speculative worker will solve the output_ids in speculative decoding req.output_ids.append(next_token_id) - req.check_finished() # TODO: SUPPORT IT @kavioyu + req.check_finished() if req.regex_fsm is not None: req.regex_fsm_state = req.regex_fsm.get_next_state( diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 661001130d1..47ceebb37dd 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -119,10 +119,13 @@ def get_token_and_memory_info(self): self.random_seed, ) - def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): + def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch, need_token_id=True): forward_batch = 
ForwardBatch.init_new(model_worker_batch, self.model_runner) logits_output = self.model_runner.forward(forward_batch) - next_token_ids = self.model_runner.sample(logits_output, model_worker_batch) + if need_token_id: + next_token_ids = self.model_runner.sample(logits_output, model_worker_batch) + else: + next_token_ids = None model_worker_batch.spec_info = forward_batch.spec_info return logits_output, next_token_ids diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 2c72e197346..267febfb088 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -118,9 +118,10 @@ def __init__(self, model_runner: "ModelRunner"): bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size ] - if model_runner.server_args.speculative_algorithm == 'EAGLE' and model_runner.is_draft_runner: + if model_runner.server_args.speculative_algorithm == 'EAGLE' and self.model_runner.is_draft_runner: # TODO: Support edit top_k in config @kavioyu - self.num_tokens = [bs * 8 for bs in self.capture_bs] + expand_num = self.model_runner.server_args.eagle_topk + self.num_tokens = [bs * expand_num for bs in self.capture_bs] else: self.num_tokens = [bs for bs in self.capture_bs] @@ -268,7 +269,6 @@ def replay(self, forward_batch: ForwardBatch): # Pad index = bisect.bisect_left(self.capture_bs, raw_bs) bs = self.capture_bs[index] - index = bisect.bisect_left(self.num_tokens, raw_bs) num_token = self.num_tokens[index] if bs != raw_num_token: self.seq_lens.fill_(self.seq_len_fill_value) diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index e3f887c0c96..815a45391eb 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -77,7 +77,7 @@ def is_spec_extend(self): return self == ForwardMode.SPECEXTEND def is_cuda_graph(self): - return self == ForwardMode.DECODE + return self in (ForwardMode.DECODE, ) @dataclass diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index c89d7731ddd..f84ebd3c586 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -114,8 +114,14 @@ class ServerArgs: draft_model_path: str = None speculative_algorithm: str = None num_speculative_steps: int = None + # should been set as 2^n num_draft_tokens: int = None + # should been set as [1, 2, 4, 8] + eagle_topk: int = None + # should not been set by cli, it is only a placeholder + # which would be set and used in model_runner draft_runner_cache_size: int = None + def __post_init__(self): # Set missing default values @@ -588,7 +594,15 @@ def add_cli_args(parser: argparse.ArgumentParser): type=int, help="The number of token sampled from draft model in Speculative Decoding.", required=False, - default=5, + default=64, + ) + parser.add_argument( + "--eagle-topk", + type=int, + help="The number of token sampled from draft model in eagle2 each step.", + required=False, + choices=[1, 2, 4, 8], + default=8, ) @classmethod diff --git a/python/sglang/srt/speculative/build_egale_tree.py b/python/sglang/srt/speculative/build_eagle_tree.py similarity index 94% rename from python/sglang/srt/speculative/build_egale_tree.py rename to python/sglang/srt/speculative/build_eagle_tree.py index b28f54d3275..40538c3cda3 100644 --- a/python/sglang/srt/speculative/build_egale_tree.py +++ 
b/python/sglang/srt/speculative/build_eagle_tree.py @@ -17,7 +17,7 @@ kernels = cutex.SourceModule( """ //cuda -__global__ void build_tree(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, +__global__ void build_tree(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, Tensor tree_mask, Tensor positions, Tensor retrive_index, int topk, int depth, int draft_token_num) { int bid = blockIdx.x; int tid = threadIdx.x; @@ -48,13 +48,14 @@ depends_order[position] = cur_position+1; position += 1; tree_mask[token_tree_idx+cur_position] = true; - int parent_tb_idx = selected_index[bid*draft_token_num+cur_position]/topk; + int parent_tb_idx = selected_index[bid][cur_position]/topk; if(parent_tb_idx==0){ break; } - int token_idx = parent_list[parent_tb_idx]; + + int token_idx = parent_list[bid][parent_tb_idx]; for(cur_position=0; cur_position torch.Ten self.draft_token_num, triton.next_power_of_2(max_draft_len), ) - old_accept_index = accept_index + accept_index = accept_index[accept_index != -1] #extract_index = extract_index[extract_index != 0] draft_input = EAGLEDraftInput() accept_length_cpu = accept_length.tolist() + verified_id = predict[accept_index] + verified_id_cpu = verified_id.tolist() - draft_input.verified_id = predict[accept_index] - - draft_input.hidden_states = batch.spec_info.hidden_states[ - accept_index - ] - draft_input.accept_length = accept_length - - verified_id_cpu = draft_input.verified_id.tolist() + new_accept_index = [] + unfinished_index = [] + low = 0 - for req, verified_len in zip(batch.reqs, accept_length_cpu): + for i, (req, verified_len) in enumerate(zip(batch.reqs, accept_length_cpu)): req.output_ids.extend(verified_id_cpu[low : low + verified_len + 1]) - low += verified_len + req.check_finished() + if req.finished(): + draft_input.has_finished = True + else: + new_accept_index.append(accept_index[low: low+verified_len+1]) + unfinished_index.append(i) + low += verified_len + 1 + + if len(new_accept_index)>0: + accept_index = torch.cat(new_accept_index, dim=0) + draft_input.verified_id = predict[accept_index] + draft_input.hidden_states = batch.spec_info.hidden_states[ + accept_index + ] + draft_input.accept_length = accept_length[unfinished_index] + draft_input.unfinished_index = unfinished_index + evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool) evict_mask[accept_index] = False mem_need_free_idx = batch.out_cache_loc[evict_mask] - # TODO: support batch inference @kavioyu - batch.req_to_token_pool.req_to_token[ - batch.req_pool_indices, - batch.seq_lens : batch.seq_lens + accept_length+1, - ] = batch.out_cache_loc[accept_index] + assign_req_to_token_pool[(bs, )](batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + batch.seq_lens, + batch.seq_lens+accept_length+1, + batch.out_cache_loc[accept_index], + batch.req_to_token_pool.req_to_token.shape[1], + triton.next_power_of_2(bs) + ) + batch.token_to_kv_pool.free(mem_need_free_idx) #batch.spec_info.evict_mask = evict_mask batch.seq_lens.add_(accept_length+1) - - #print(batch.req_to_token_pool.req_to_token[0][:60]) logits_output.next_token_logits = logits_output.next_token_logits[accept_index] - return draft_input, logits_output + return draft_input, logits_output, verified_id diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 7336f0ee383..3ab918cd9b4 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -50,11 +50,12 @@ def 
forward_batch_speculative_generate(self, batch: ScheduleBatch): self.forward_draft_decode(batch) batch.spec_info.clear_draft_cache(batch) self._swap_mem_pool(batch, self.target_worker.model_runner) - next_draft_input, logits_output = self.verify(batch) - verified_id = next_draft_input.verified_id + next_draft_input, logits_output, verified_id = self.verify(batch) next_draft_input.init(self.server_args) batch.spec_info = next_draft_input - self.forward_extend_after_decode(batch) + # if it is None, means all requsets are finished + if batch.spec_info.verified_id is not None: + self.forward_extend_after_decode(batch) return logits_output, verified_id else: @@ -72,7 +73,7 @@ def verify(self, batch: ScheduleBatch): verify_input.prepare_for_verify(batch) batch.spec_info = verify_input model_worker_batch = batch.get_model_worker_batch() - logits_output, next_token_ids = self.target_worker.forward_batch_generation(model_worker_batch) + logits_output, _ = self.target_worker.forward_batch_generation(model_worker_batch, need_token_id=False) res = verify_input.verify(batch, logits_output) batch.forward_mode = ForwardMode.DECODE return res @@ -84,13 +85,20 @@ def _swap_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner): def forward_extend_after_decode(self, batch: ScheduleBatch): self._swap_mem_pool(batch, self.model_runner) batch.forward_mode = ForwardMode.SPECEXTEND + if batch.spec_info.has_finished: + index = batch.spec_info.unfinished_index + seq_lens = batch.seq_lens + batch.seq_lens = batch.seq_lens[index] batch.spec_info.prepare_extend_after_decode(batch) model_worker_batch = batch.get_model_worker_batch() forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) + forward_batch.is_draft_batch = True logits_output = self.model_runner.forward(forward_batch) self.capture_for_decode(logits_output, forward_batch) batch.forward_mode = ForwardMode.DECODE + if batch.spec_info.has_finished: + batch.seq_lens = seq_lens self._swap_mem_pool(batch, self.model_runner) def post_decode_process(self, batch): diff --git a/python/sglang/srt/speculative/speculative_utils.py b/python/sglang/srt/speculative/speculative_utils.py index 005282bd3c6..1aa56139ddb 100644 --- a/python/sglang/srt/speculative/speculative_utils.py +++ b/python/sglang/srt/speculative/speculative_utils.py @@ -6,7 +6,7 @@ import triton import triton.language as tl -from .build_egale_tree import build_tree_kernel +from .build_eagle_tree import build_tree_kernel from sglang.srt.model_executor.forward_batch_info import ForwardMode, ForwardBatch if TYPE_CHECKING: @@ -39,6 +39,9 @@ def generate_attn_arg( def clear(): pass + + def merge_batch(self, batch: SpecDraftInput): + raise NotImplementedError() class SpecDraftInfoFactory: From f557a06ab5cf909e5f22c4db368e9024fb9f93fe Mon Sep 17 00:00:00 2001 From: kavioyu Date: Wed, 23 Oct 2024 10:19:52 +0800 Subject: [PATCH 08/26] temp --- examples/runtime/engine/offline_batch_inference.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/examples/runtime/engine/offline_batch_inference.py b/examples/runtime/engine/offline_batch_inference.py index e644f32b057..8918c6ef8b6 100644 --- a/examples/runtime/engine/offline_batch_inference.py +++ b/examples/runtime/engine/offline_batch_inference.py @@ -1,17 +1,23 @@ import sglang as sgl - +import time def main(): # Sample prompts. prompts = [ - "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. 
USER: Where is the capital city of France? ASSISTANT:" + "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: Where is the capital city of France? ASSISTANT:", + "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: 北京今天天气怎么样? ASSISTANT:" ] # Create a sampling params object. - sampling_params = {"temperature": 0, "max_new_tokens": 8} + sampling_params = {"temperature": 0, "max_new_tokens": 30} + # Create an LLM. - llm = sgl.Engine(model_path="Llama-2-7b-chat-hf", draft_model_path='EAGLE-llama2-chat-7B', disable_cuda_graph=True, num_speculative_steps=5, num_draft_tokens=64, speculative_algorithm='EAGLE', mem_fraction_static=0.60) + llm = sgl.Engine(model_path="Llama-2-7b-chat-hf", draft_model_path='EAGLE-llama2-chat-7B', disable_cuda_graph=True, num_speculative_steps=5, eagle_topk=8, num_draft_tokens=64, speculative_algorithm='EAGLE', mem_fraction_static=0.60) + #llm = sgl.Engine(model_path="Llama-2-7b-chat-hf", disable_cuda_graph=False) + #outputs = llm.generate(prompts, sampling_params) + start = time.time() outputs = llm.generate(prompts, sampling_params) + print(time.time()-start) # Print the outputs. for prompt, output in zip(prompts, outputs): print("===============================") From 9987741b17f7381c68de39fac4f3c8b5ce7e2f6a Mon Sep 17 00:00:00 2001 From: kavioyu Date: Thu, 24 Oct 2024 11:04:01 +0800 Subject: [PATCH 09/26] fix memeory leak --- python/sglang/srt/speculative/eagle_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index af4c4e64012..e127eded9df 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -561,10 +561,10 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten low += verified_len + 1 if len(new_accept_index)>0: - accept_index = torch.cat(new_accept_index, dim=0) - draft_input.verified_id = predict[accept_index] + new_accept_index = torch.cat(new_accept_index, dim=0) + draft_input.verified_id = predict[new_accept_index] draft_input.hidden_states = batch.spec_info.hidden_states[ - accept_index + new_accept_index ] draft_input.accept_length = accept_length[unfinished_index] draft_input.unfinished_index = unfinished_index From dcbc11c56c1d212d3dca0d71de0fa3f01bab52fc Mon Sep 17 00:00:00 2001 From: kavioyu Date: Thu, 24 Oct 2024 14:29:04 +0800 Subject: [PATCH 10/26] add sampling score --- python/sglang/srt/speculative/eagle_utils.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index e127eded9df..29965348802 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -174,6 +174,7 @@ def init(self, server_args: ServerArgs): self.scores: torch.Tensor = None self.score_list: List[torch.Tensor] = [] self.token_list: List[torch.Tensor] = [] + self.origin_score_list: List[torch.Tensor] = [] #used for sampling self.parents_list: List[torch.Tensor] = [] self.cache_list: List[torch.Tenor] = [] self.iter = 0 @@ -245,6 +246,7 @@ def prepare_for_decode(self, batch: ScheduleBatch): batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel()) self.score_list.append(scores) # 
b, topk, topk self.token_list.append(topk_index) # b, topk*topk + self.origin_score_list.append(topk_p.reshape(topk_index.shape)) self.parents_list.append( topk_cs_index + (self.topk**2 * (self.iter - 1) + self.topk) ) # b, topk @@ -253,6 +255,7 @@ def prepare_for_decode(self, batch: ScheduleBatch): self.scores = topk_p # b, top_k self.score_list.append(topk_p.unsqueeze(1)) self.token_list.append(topk_index) + self.origin_score_list.append(topk_p) batch.spec_info.hidden_states = ( batch.spec_info.hidden_states.repeat_interleave(self.topk, 0) ) @@ -322,13 +325,16 @@ def prepare_extend_after_decode(self, batch: ScheduleBatch): def prepare_for_verify(self, batch: ScheduleBatch): score_list = torch.cat(self.score_list, dim=1).flatten(1) # b, n, topk; n= 1+(self.iter-1)*self.topk ss_token_list = torch.cat(self.token_list, dim=1) # b, (self.topk+(self.iter-1)*self.topk) + origin_token_list = torch.cat(self.origin_score_list, dim=1) top_scores = torch.topk(score_list, self.num_verify_token - 1, dim=-1) top_scores_index = top_scores.indices top_scores_index = torch.sort(top_scores_index).values draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1) + scores = torch.gather(origin_token_list, index=top_scores_index, dim=1) draft_tokens = torch.cat((self.verified_id.unsqueeze(1), draft_tokens), dim=1) parent_list = torch.cat(self.parents_list[:-1], dim=1) + tree_mask, position, retrive_index, retrive_cum_len = build_tree_kernel( parent_list, @@ -339,12 +345,9 @@ def prepare_for_verify(self, batch: ScheduleBatch): self.num_verify_token, ) - # out_cache = torch.cat(self.cache_list, dim=0) - # mem_need_free_idx = out_cache[top_scores_index] - # batch.token_to_kv_pool.free(mem_need_free_idx) - return EagleVerifyInput( draft_tokens.flatten(), + scores.flatten(), tree_mask, position, retrive_index, @@ -436,6 +439,7 @@ class EagleVerifyInput(SpecVerifyInput): def __init__( self, draft_token: torch.Tensor, + draft_score: torch.Tensor, tree_mask: torch.Tensor, positions: torch.Tensor, retrive_index: torch.Tensor, @@ -443,6 +447,7 @@ def __init__( draft_token_num: int, ): self.draft_token = draft_token + self.draft_score = draft_score self.custom_mask = tree_mask self.positions = positions self.retrive_index = retrive_index From 5578b1898e98f4695bcebfc831967e5c75bb9861 Mon Sep 17 00:00:00 2001 From: kavioyu Date: Thu, 24 Oct 2024 18:46:57 +0800 Subject: [PATCH 11/26] support target model cuda graph --- .../sglang/srt/layers/attention/__init__.py | 2 +- .../layers/attention/flashinfer_backend.py | 41 +++++++++++++------ .../srt/layers/attention/triton_backend.py | 2 +- .../srt/model_executor/cuda_graph_runner.py | 40 ++++++++++-------- .../srt/model_executor/forward_batch_info.py | 2 +- python/sglang/srt/speculative/eagle_utils.py | 4 +- .../srt/speculative/speculative_utils.py | 19 +++++---- 7 files changed, 68 insertions(+), 42 deletions(-) diff --git a/python/sglang/srt/layers/attention/__init__.py b/python/sglang/srt/layers/attention/__init__.py index 26ce3431bf6..2dde3355edc 100644 --- a/python/sglang/srt/layers/attention/__init__.py +++ b/python/sglang/srt/layers/attention/__init__.py @@ -18,7 +18,7 @@ def init_cuda_graph_state(self, max_bs: int): raise NotImplementedError() def init_forward_metadata_capture_cuda_graph( - self, num_token: int, req_pool_indices, seq_lens + self, num_token: int, req_pool_indices, seq_lens, spec_info, is_draft_runner ): """Init the metadata for a forward pass for capturing a cuda graph.""" raise NotImplementedError() diff --git 
a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index c98884eebc5..5b85205ec1c 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -20,6 +20,7 @@ ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.utils import is_flashinfer_available +from sglang.srt.speculative.speculative_utils import SpecInput if TYPE_CHECKING: from sglang.srt.model_executor.model_runner import ModelRunner @@ -152,6 +153,7 @@ def init_cuda_graph_state(self, max_bs: int): self.cuda_graph_kv_indptr = torch.zeros( (max_bs + 1,), dtype=torch.int32, device="cuda" ) + self.cuda_graph_q_indptr = self.cuda_graph_kv_indptr.clone() self.cuda_graph_kv_indices = torch.zeros( (max_bs * self.model_runner.model_config.context_len,), dtype=torch.int32, @@ -170,24 +172,39 @@ def init_cuda_graph_state(self, max_bs: int): ] def init_forward_metadata_capture_cuda_graph( - self, num_token: int, req_pool_indices, seq_lens, spec_info + self, num_token: int, req_pool_indices, seq_lens, + spec_info:SpecInput, is_draft_runner: bool=False ): decode_wrappers = [] for i in range(self.num_wrappers): - decode_wrappers.append( - BatchDecodeWithPagedKVCacheWrapper( - self.workspace_buffer, - "NHD", - use_cuda_graph=True, - use_tensor_cores=self.decode_use_tensor_cores, - paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[i][: num_token + 1], - paged_kv_indices_buffer=self.cuda_graph_kv_indices[i], - paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:num_token], + if spec_info is not None and not is_draft_runner: + decode_wrappers.append( + BatchPrefillWithPagedKVCacheWrapper( + self.workspace_buffer, + "NHD", + use_cuda_graph=True, + qo_indptr_buf=self.cuda_graph_q_indptr[:num_token+1], + paged_kv_indptr_buf=self.cuda_graph_kv_indptr[i][: num_token + 1], + paged_kv_indices_buf=self.cuda_graph_kv_indices[i], + paged_kv_last_page_len_buf=self.cuda_graph_kv_last_page_len[:num_token], + ) ) - ) + else: + decode_wrappers.append( + BatchDecodeWithPagedKVCacheWrapper( + self.workspace_buffer, + "NHD", + use_cuda_graph=True, + use_tensor_cores=self.decode_use_tensor_cores, + paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[i][: num_token + 1], + paged_kv_indices_buffer=self.cuda_graph_kv_indices[i], + paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:num_token], + ) + ) + mode = ForwardMode.SPECVERIFY if spec_info is not None and not is_draft_runner else ForwardMode.DECODE update_flashinfer_indices( - ForwardMode.DECODE, + mode, self.model_runner, req_pool_indices, seq_lens, diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index b9e1a6c4196..3099b6c4177 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -79,7 +79,7 @@ def init_cuda_graph_state(self, max_bs: int): ) def init_forward_metadata_capture_cuda_graph( - self, num_token: int, req_pool_indices, seq_lens + self, num_token: int, req_pool_indices, seq_lens, spec_info, is_draft_runner ): self.forward_metadata = ( self.cuda_graph_start_loc, diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 267febfb088..04e21046127 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -108,7 +108,6 @@ 
def __init__(self, model_runner: "ModelRunner"): self.disable_padding = model_runner.server_args.disable_cuda_graph_padding # Batch sizes to capture - # For speculative decoding, it means number of input token if self.model_runner.server_args.disable_cuda_graph_padding: self.capture_bs = list(range(1, 32)) + [64, 128] else: @@ -117,10 +116,13 @@ def __init__(self, model_runner: "ModelRunner"): self.capture_bs = [ bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size ] - - if model_runner.server_args.speculative_algorithm == 'EAGLE' and self.model_runner.is_draft_runner: - # TODO: Support edit top_k in config @kavioyu - expand_num = self.model_runner.server_args.eagle_topk + self.capture_forward_mode = ForwardMode.DECODE + if model_runner.server_args.speculative_algorithm == 'EAGLE': + if self.model_runner.is_draft_runner: + expand_num = self.model_runner.server_args.eagle_topk + else: + self.capture_forward_mode = ForwardMode.SPECVERIFY + expand_num = self.model_runner.server_args.num_draft_tokens self.num_tokens = [bs * expand_num for bs in self.capture_bs] else: self.num_tokens = [bs for bs in self.capture_bs] @@ -207,21 +209,25 @@ def capture_one_batch_size(self, bs: int, num_token: int, forward: Callable): positions = self.positions[:num_token] spec_info = None - if self.model_runner.server_args.speculative_algorithm == 'EAGLE' and self.model_runner.is_draft_runner: - spec_info = DraftInfoFactory.get(self.model_runner.server_args.speculative_algorithm)() - spec_info.hidden_states = self.hidden_states[:num_token] - spec_info.positions = positions - spec_info.init(self.model_runner.server_args) - + if self.model_runner.server_args.speculative_algorithm == 'EAGLE': + if self.model_runner.is_draft_runner: + spec_info = DraftInfoFactory.get(self.model_runner.server_args.speculative_algorithm, 'DraftInput')() + spec_info.hidden_states = self.hidden_states[:num_token] + spec_info.positions = positions + spec_info.init(self.model_runner.server_args) + else: + spec_info = DraftInfoFactory.get(self.model_runner.server_args.speculative_algorithm, 'VerifyInput')( + None, None, None, None, None, None, self.model_runner.server_args.num_draft_tokens) + # Attention backend self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph( - num_token, req_pool_indices, seq_lens, spec_info + num_token, req_pool_indices, seq_lens, spec_info, self.model_runner.is_draft_runner ) # Run and capture - def run_once(): + def run_once(mode): forward_batch = ForwardBatch( - forward_mode=ForwardMode.DECODE, + forward_mode=mode, batch_size=bs, input_ids=input_ids, req_pool_indices=req_pool_indices, @@ -242,7 +248,7 @@ def run_once(): torch.cuda.synchronize() self.model_runner.tp_group.barrier() - run_once() + run_once(self.capture_forward_mode) torch.cuda.synchronize() self.model_runner.tp_group.barrier() @@ -251,7 +257,7 @@ def run_once(): self.model_runner.tp_group.barrier() with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream): - out = run_once() + out = run_once(self.capture_forward_mode) torch.cuda.synchronize() self.model_runner.tp_group.barrier() @@ -282,7 +288,7 @@ def replay(self, forward_batch: ForwardBatch): self.positions[:num_token] = forward_batch.positions # EAGLE speculative decoding - if isinstance(forward_batch.spec_info, DraftInfoFactory.get('EAGLE')): + if isinstance(forward_batch.spec_info, DraftInfoFactory.get('EAGLE', 'DraftInput')): self.hidden_states[:num_token] = forward_batch.spec_info.hidden_states # Attention backend diff --git 
a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 815a45391eb..3309df2e720 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -77,7 +77,7 @@ def is_spec_extend(self): return self == ForwardMode.SPECEXTEND def is_cuda_graph(self): - return self in (ForwardMode.DECODE, ) + return self in (ForwardMode.DECODE, ForwardMode) @dataclass diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index 29965348802..270a4d706d9 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -155,7 +155,7 @@ def assign_req_to_token_pool(req_pool_indices, req_to_token, start_offset, end_o load_offset += BLOCK_SIZE -@DraftInfoFactory.register("EAGLE") +@DraftInfoFactory.register("EAGLE", "DraftInput") class EAGLEDraftInput(SpecDraftInput): hidden_states: torch.Tensor = None verified_id: torch.Tensor = None @@ -434,7 +434,7 @@ def merge_batch(self, spec_info: EAGLEDraftInput): #self.positions = torch.cat([self.positions, spec_info.positions], axis=0) self.sample_output = torch.cat([self.sample_output, spec_info.sample_output]) - +@DraftInfoFactory.register("EAGLE", "VerifyInput") class EagleVerifyInput(SpecVerifyInput): def __init__( self, diff --git a/python/sglang/srt/speculative/speculative_utils.py b/python/sglang/srt/speculative/speculative_utils.py index 1aa56139ddb..7de6fc1e5a1 100644 --- a/python/sglang/srt/speculative/speculative_utils.py +++ b/python/sglang/srt/speculative/speculative_utils.py @@ -44,23 +44,26 @@ def merge_batch(self, batch: SpecDraftInput): raise NotImplementedError() -class SpecDraftInfoFactory: +class SpecInfoFactory: def __init__(self): self.factory = {} - def register(self, name: str) -> SpecDraftInput: - def wrapper(info: Type[SpecDraftInput]) -> Type[SpecDraftInput]: - self.factory[name] = info + def register(self, alg_name: str, type_name: str) -> SpecInput: + def wrapper(info: Type[SpecInput]) -> Type[SpecInput]: + assert type_name in ['DraftInput', 'VerifyInput'] + if alg_name not in self.factory: + self.factory[alg_name] = {} + self.factory[alg_name].update({type_name: info}) return info return wrapper - def get(self, name): - if name is None: + def get(self, alg_name, type_name: str): + if alg_name is None: return None - return self.factory[name] + return self.factory[alg_name][type_name] -DraftInfoFactory = SpecDraftInfoFactory() +DraftInfoFactory = SpecInfoFactory() From af2e79a0a949ecc35f8f4f2c055ff12740f28fa1 Mon Sep 17 00:00:00 2001 From: kavioyu Date: Thu, 24 Oct 2024 19:21:24 +0800 Subject: [PATCH 12/26] disable target model cuda graph --- python/sglang/srt/model_executor/model_runner.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 7af76581f70..abb9495ac1a 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -496,6 +496,11 @@ def init_cuda_graphs(self): if self.server_args.disable_cuda_graph: return + + # Target model don't need cuda graph due to it have big batch size in verify stage. + # Disable it to save gpu memory. + if self.server_args.speculative_algorithm is not None and not self.is_draft_runner: + return logger.info("Capture cuda graph begin. 
This can take up to several minutes.") self.cuda_graph_runner = CudaGraphRunner(self) From 0fdd0b1725153b31468034c33039c0ecbfc7a74a Mon Sep 17 00:00:00 2001 From: kavioyu Date: Thu, 24 Oct 2024 20:09:53 +0800 Subject: [PATCH 13/26] fix batch bug --- python/sglang/srt/managers/scheduler.py | 19 ++++++++++++++----- python/sglang/srt/server_args.py | 5 +++++ 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 2a506a775b2..4c0563604b6 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -262,11 +262,20 @@ def __init__( def event_loop(self): while True: recv_reqs = self.recv_requests() - self.process_input_requests(recv_reqs) - - self.run_step() - - self.send_results() + if self.server_args.split_prefill_batch: + if len(recv_reqs)==0: + self.process_input_requests(recv_reqs) + self.run_step() + self.send_results() + for recv_req in recv_reqs: + self.process_input_requests([recv_req]) + self.run_step() + self.send_results() + else: + for recv_req in recv_reqs: + self.process_input_requests(recv_req) + self.run_step() + self.send_results() def recv_requests(self): if self.tp_rank == 0: diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index f84ebd3c586..e46b389496c 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -57,6 +57,7 @@ class ServerArgs: max_prefill_tokens: int = 16384 schedule_policy: str = "lpm" schedule_conservativeness: float = 1.0 + split_prefill_batch: bool = False # Other runtime options tp_size: int = 1 @@ -186,6 +187,10 @@ def __post_init__(self): if "gemma-2" in self.model_path.lower(): logger.info("When using sliding window in gemma-2, turn on flashinfer.") self.attention_backend = "flashinfer" + + # Speculative Decoding + if self.speculative_algorithm=='EAGLE': + self.split_prefill_batch = True @staticmethod def add_cli_args(parser: argparse.ArgumentParser): From 923523b3d00fc753c2eae988b9e0600963574e24 Mon Sep 17 00:00:00 2001 From: kavioyu Date: Fri, 25 Oct 2024 10:07:23 +0800 Subject: [PATCH 14/26] disable cuda graph pad in eagle --- python/sglang/srt/model_executor/cuda_graph_runner.py | 8 ++++---- python/sglang/srt/models/llama_eagle.py | 1 - python/sglang/srt/server_args.py | 2 ++ python/sglang/srt/speculative/build_eagle_tree.py | 2 +- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 04e21046127..237b10ca253 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -281,15 +281,15 @@ def replay(self, forward_batch: ForwardBatch): self.out_cache_loc.zero_() # Common inputs - self.input_ids[:num_token] = forward_batch.input_ids + self.input_ids[:raw_num_token] = forward_batch.input_ids self.req_pool_indices[:raw_bs] = forward_batch.req_pool_indices self.seq_lens[:raw_bs] = forward_batch.seq_lens - self.out_cache_loc[:num_token] = forward_batch.out_cache_loc - self.positions[:num_token] = forward_batch.positions + self.out_cache_loc[:raw_num_token] = forward_batch.out_cache_loc + self.positions[:raw_num_token] = forward_batch.positions # EAGLE speculative decoding if isinstance(forward_batch.spec_info, DraftInfoFactory.get('EAGLE', 'DraftInput')): - self.hidden_states[:num_token] = forward_batch.spec_info.hidden_states + self.hidden_states[:raw_num_token] = 
forward_batch.spec_info.hidden_states # Attention backend self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph( diff --git a/python/sglang/srt/models/llama_eagle.py b/python/sglang/srt/models/llama_eagle.py index 5c3dd378296..dfe087badc0 100644 --- a/python/sglang/srt/models/llama_eagle.py +++ b/python/sglang/srt/models/llama_eagle.py @@ -290,7 +290,6 @@ def forward( (hidden_states, forward_batch.spec_info.hidden_states), dim=-1 ) ) - #hidden_states = forward_batch.spec_info.hidden_states residual = None for i in range(len(self.layers)): diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index e46b389496c..efc310e2d58 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -191,6 +191,8 @@ def __post_init__(self): # Speculative Decoding if self.speculative_algorithm=='EAGLE': self.split_prefill_batch = True + # EAGLE don't support it currently. + self.disable_cuda_graph_padding = True @staticmethod def add_cli_args(parser: argparse.ArgumentParser): diff --git a/python/sglang/srt/speculative/build_eagle_tree.py b/python/sglang/srt/speculative/build_eagle_tree.py index 40538c3cda3..3ff19b6a567 100644 --- a/python/sglang/srt/speculative/build_eagle_tree.py +++ b/python/sglang/srt/speculative/build_eagle_tree.py @@ -82,7 +82,7 @@ //!cuda """, float_bits=16, # change to 16 to use half precision as `float` type in the above source code. - boundscheck=True, # turning on for debug and off for performance (to use full threads of a block), default is on. + boundscheck=False, # turning on for debug and off for performance (to use full threads of a block), default is on. ) From 0e3fea28b467bdf760b16603bee851a9d5fee1de Mon Sep 17 00:00:00 2001 From: kavioyu Date: Fri, 25 Oct 2024 10:43:05 +0800 Subject: [PATCH 15/26] fix server args --- python/sglang/srt/server_args.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index efc310e2d58..bc16a4f75a8 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -611,6 +611,21 @@ def add_cli_args(parser: argparse.ArgumentParser): choices=[1, 2, 4, 8], default=8, ) + parser.add_argument( + "--split-prefill-batch", + type=bool, + help="Whether to inference prefill sample one by one.", + required=False, + default=False, + ) + parser.add_argument( + "--draft-runner-cache-size", + type=int, + help="""It should not been set by cli, it is only a placeholder which + would be set and used in model_runner when using speculative inference.""", + required=False, + default=-1, + ) @classmethod def from_cli_args(cls, args: argparse.Namespace): From 4faaa31fbc2d0437a680708a05fd6dab3fc68c96 Mon Sep 17 00:00:00 2001 From: kavioyu Date: Fri, 25 Oct 2024 11:15:07 +0800 Subject: [PATCH 16/26] fix cuda graph and split prefill --- python/sglang/srt/managers/scheduler.py | 7 +++---- python/sglang/srt/model_executor/cuda_graph_runner.py | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 4c0563604b6..92945e7a7c9 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -272,10 +272,9 @@ def event_loop(self): self.run_step() self.send_results() else: - for recv_req in recv_reqs: - self.process_input_requests(recv_req) - self.run_step() - self.send_results() + self.process_input_requests(recv_reqs) + self.run_step() + self.send_results() def 
recv_requests(self): if self.tp_rank == 0: diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 237b10ca253..040c9031537 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -303,7 +303,7 @@ def replay(self, forward_batch: ForwardBatch): # Unpad if raw_num_token != num_token: logits_output = LogitsProcessorOutput( - next_token_logits=logits_output.next_token_logits[:num_token], + next_token_logits=logits_output.next_token_logits[:raw_num_token], next_token_logprobs=None, normalized_prompt_logprobs=None, input_token_logprobs=None, From 33d8aef55b21384fc36f8393b19a5c5e8c76b355 Mon Sep 17 00:00:00 2001 From: kavioyu Date: Fri, 25 Oct 2024 15:32:38 +0800 Subject: [PATCH 17/26] optimize generate attn arg --- python/sglang/srt/speculative/eagle_utils.py | 86 ++++++++++++++------ 1 file changed, 59 insertions(+), 27 deletions(-) diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index 270a4d706d9..32da12892fb 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -153,6 +153,37 @@ def assign_req_to_token_pool(req_pool_indices, req_to_token, start_offset, end_o tl.store(token_pool+save_offset, data, mask=mask) save_offset += BLOCK_SIZE load_offset += BLOCK_SIZE + + +@triton.jit +def generate_draft_decode_kv_indices(req_pool_indices, req_to_token, paged_kernel_lens, kv_indices, + iters: tl.constexpr, topk: tl.constexpr, pool_len: tl.constexpr, + bs_upper: tl.constexpr, iter_upper: tl.constexpr): + BLOCK_SIZE: tl.constexpr = 128 + bid = tl.program_id(axis=0) + topk_id = tl.program_id(axis=1) + + load_offset = tl.arange(0, bs_upper) + seq_lens = tl.load(paged_kernel_lens+load_offset, mask=load_offset Date: Fri, 25 Oct 2024 22:16:13 +0800 Subject: [PATCH 18/26] fix parent list dtype --- python/sglang/srt/speculative/eagle_utils.py | 20 +------------------- 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index 32da12892fb..07ffba110a1 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -293,7 +293,7 @@ def prepare_for_decode(self, batch: ScheduleBatch): batch.input_ids = topk_index.flatten() batch.out_cache_loc = batch.alloc_token_slots(topk_index.numel()) self.parents_list.append( - torch.arange(-1, self.topk, dtype=torch.int, device="cuda").unsqueeze(0).repeat(self.scores.shape[0], 1) + torch.arange(-1, self.topk, dtype=torch.long, device="cuda").unsqueeze(0).repeat(self.scores.shape[0], 1) ) # b, topk+1 self.cache_list.append(batch.out_cache_loc) self.positions = ( @@ -357,7 +357,6 @@ def prepare_for_verify(self, batch: ScheduleBatch): scores = torch.gather(origin_token_list, index=top_scores_index, dim=1) draft_tokens = torch.cat((self.verified_id.unsqueeze(1), draft_tokens), dim=1) parent_list = torch.cat(self.parents_list[:-1], dim=1) - tree_mask, position, retrive_index, retrive_cum_len = build_tree_kernel( parent_list, @@ -395,23 +394,6 @@ def generate_attn_arg( kv_indices = torch.empty((total_len * self.topk + seq_num*self.iter*self.topk, ), dtype=torch.int32, device='cuda') - - # kv_indices_list = [] - # req_pool_indice = req_pool_indices.tolist() - # paged_kernel_len = paged_kernel_lens.tolist() - # for i in range(len(req_pool_indice)): - # for k in range(self.topk): - # 
index = torch.arange(self.iter) * self.topk + k - # kv_indices_list.append( - # req_to_token_pool.req_to_token[ - # req_pool_indice[i], : paged_kernel_len[i] - # ] - # ) - # kv_indices_list.append( - # req_to_token_pool.req_to_token[ - # req_pool_indice[i], paged_kernel_len[i] + index - # ] - # ) generate_draft_decode_kv_indices[(req_pool_indices.numel(), self.topk)](req_pool_indices, req_to_token_pool.req_to_token, paged_kernel_lens, kv_indices, self.iter, self.topk, From 2b3cb228f82130917f062137c9e8d0a46b162a5b Mon Sep 17 00:00:00 2001 From: kavioyu Date: Sat, 26 Oct 2024 23:47:41 +0800 Subject: [PATCH 19/26] fix draft worker memory problem --- python/sglang/srt/managers/scheduler.py | 4 +++ .../sglang/srt/model_executor/model_runner.py | 26 ++++++++++--------- python/sglang/srt/speculative/eagle_utils.py | 4 ++- python/sglang/srt/speculative/eagle_worker.py | 18 +++++++++---- .../srt/speculative/speculative_worker.py | 7 +++-- 5 files changed, 37 insertions(+), 22 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 92945e7a7c9..ce439a6e64f 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -157,6 +157,8 @@ def __init__( nccl_port=port_args.nccl_port, target_worker=self.tp_worker ) + else: + self.draft_worker = None # Get token and memory info from the model worker ( @@ -925,6 +927,8 @@ def handle_finished_requests(self, batch: ScheduleBatch): or len(req.output_ids) == 1 ) ): + if req.finished() and self.draft_worker is not None: + self.draft_worker.finish_request(req) output_rids.append(req.rid) output_finished_reason.append(req.finished_reason) if self.is_generation: diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index abb9495ac1a..59192fdcd48 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -394,10 +394,23 @@ def init_memory_pool( ) self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory) + + if max_num_reqs is None: + max_num_reqs = min( + max( + int( + self.max_total_num_tokens / self.model_config.context_len * 512 + ), + 2048, + ), + 4096, + ) + if self.is_draft_runner: self.max_total_num_tokens = self.server_args.draft_runner_cache_size else: - self.server_args.draft_runner_cache_size = self.max_total_num_tokens + self.server_args.draft_runner_cache_size = self.max_total_num_tokens + \ + max_num_reqs * self.server_args.num_speculative_steps + 100 if max_total_tokens is not None: @@ -414,17 +427,6 @@ def init_memory_pool( "Not enough memory. Please try to increase --mem-fraction-static." 
) - if max_num_reqs is None: - max_num_reqs = min( - max( - int( - self.max_total_num_tokens / self.model_config.context_len * 512 - ), - 2048, - ), - 4096, - ) - self.req_to_token_pool = ReqToTokenPool( size=max_num_reqs + 1, max_context_len=self.model_config.context_len + 4, diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index 07ffba110a1..c7b0cd91d4d 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -566,6 +566,7 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten new_accept_index = [] unfinished_index = [] + finished_extend_len = {} # {rid:accept_length + 1} low = 0 @@ -574,6 +575,7 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten req.check_finished() if req.finished(): draft_input.has_finished = True + finished_extend_len[req.rid] = verified_len+1 else: new_accept_index.append(accept_index[low: low+verified_len+1]) unfinished_index.append(i) @@ -609,5 +611,5 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten batch.seq_lens.add_(accept_length+1) logits_output.next_token_logits = logits_output.next_token_logits[accept_index] - return draft_input, logits_output, verified_id + return draft_input, logits_output, verified_id, finished_extend_len diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 3ab918cd9b4..7588aa63e63 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -2,7 +2,7 @@ from sglang.srt.server_args import ServerArgs from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.speculative.speculative_worker import SpeculativeWorker, spec_worker_factory -from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch +from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch, Req from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.speculative.eagle_utils import EAGLEDraftInput, EagleVerifyInput from sglang.srt.layers.logits_processor import LogitsProcessorOutput @@ -25,6 +25,7 @@ def __init__( self.model_runner.model.set_embed_and_head(embed, head) self.model_runner.server_args.disable_cuda_graph = disable_cuda_graph self.model_runner.init_cuda_graphs() + self.finish_extend_len = None def forward_draft_decode(self, batch: ScheduleBatch): batch.spec_info.prepare_for_decode(batch) @@ -50,7 +51,8 @@ def forward_batch_speculative_generate(self, batch: ScheduleBatch): self.forward_draft_decode(batch) batch.spec_info.clear_draft_cache(batch) self._swap_mem_pool(batch, self.target_worker.model_runner) - next_draft_input, logits_output, verified_id = self.verify(batch) + next_draft_input, logits_output, verified_id, self.finish_extend_len = self.verify(batch) + print('aval', self.model_runner.token_to_kv_pool.available_size()) next_draft_input.init(self.server_args) batch.spec_info = next_draft_input # if it is None, means all requsets are finished @@ -100,9 +102,6 @@ def forward_extend_after_decode(self, batch: ScheduleBatch): if batch.spec_info.has_finished: batch.seq_lens = seq_lens self._swap_mem_pool(batch, self.model_runner) - - def post_decode_process(self, batch): - return self.forward_extend_after_decode(batch) def capture_for_decode(self, hidden_states, forward_batch): # lm head is not support cuda graph currently. 
But it could be support theoretically. @@ -118,4 +117,13 @@ def capture_for_decode(self, hidden_states, forward_batch): forward_batch.spec_info.capture_for_decode( sample_output, forward_batch.forward_mode ) + + # Don't support prefix share now. + def finish_request(self, req): + req_len = len(req.origin_input_ids) + len(req.output_ids) - self.finish_extend_len[req.rid] - 1 + kv_indices = self.model_runner.req_to_token_pool.req_to_token[ + req.req_pool_idx + ][:req_len] + self.model_runner.token_to_kv_pool.free(kv_indices) + self.model_runner.req_to_token_pool.free(req.req_pool_idx) \ No newline at end of file diff --git a/python/sglang/srt/speculative/speculative_worker.py b/python/sglang/srt/speculative/speculative_worker.py index 228e4aa4885..2d6fc31da8b 100644 --- a/python/sglang/srt/speculative/speculative_worker.py +++ b/python/sglang/srt/speculative/speculative_worker.py @@ -1,7 +1,7 @@ from typing import Type from sglang.srt.server_args import ServerArgs from sglang.srt.managers.tp_worker import TpModelWorker -from sglang.srt.managers.schedule_batch import ScheduleBatch +from sglang.srt.managers.schedule_batch import ScheduleBatch, Req class SpeculativeWorker(TpModelWorker): @@ -21,9 +21,8 @@ def __init__( def forward_batch_speculative_generate(self, batch: ScheduleBatch): raise NotImplementedError() - def post_decode_process(self, batch: ScheduleBatch): - # do nothing by default - pass + def finish_request(self, req: Req): + raise NotImplementedError() class SpecWorkerFactory: def __init__(self): From 7aa0affad49bea149fa5b85c99ba94d16e9dcb91 Mon Sep 17 00:00:00 2001 From: kavioyu Date: Sun, 27 Oct 2024 02:11:50 +0800 Subject: [PATCH 20/26] need to fix decode error when request retract happend --- python/sglang/srt/managers/schedule_batch.py | 5 ++- python/sglang/srt/managers/scheduler.py | 5 ++- python/sglang/srt/speculative/eagle_utils.py | 37 ++++++++----------- python/sglang/srt/speculative/eagle_worker.py | 21 ++++++----- .../srt/speculative/speculative_worker.py | 4 +- 5 files changed, 37 insertions(+), 35 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 37c959ee14f..718cf066999 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -570,8 +570,9 @@ def mix_with_running(self, running_batch: "ScheduleBatch"): self.extend_lens.extend([1] * running_bs) self.extend_logprob_start_lens.extend([0] * running_bs) - def check_decode_mem(self): - bs = len(self.reqs) + def check_decode_mem(self, buf_multiplier=1): + bs = len(self.reqs)*buf_multiplier + print('aval', self.token_to_kv_pool.available_size(), buf_multiplier, bs) if self.token_to_kv_pool.available_size() >= bs: return True diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index ce439a6e64f..4a85dcb1567 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -635,11 +635,14 @@ def get_new_batch_decode(self) -> Optional[ScheduleBatch]: batch = self.running_batch # Check if decode out of memory - if not batch.check_decode_mem(): + buf_multiplier = 1 if self.server_args.speculative_algorithm is None else self.server_args.num_draft_tokens + if not batch.check_decode_mem(buf_multiplier): old_ratio = self.new_token_ratio retracted_reqs, new_token_ratio = batch.retract_decode() self.new_token_ratio = new_token_ratio + if self.draft_worker is not None: + self.draft_worker.finish_request(retracted_reqs) logger.info( 
"Decode out of memory happened. " diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index c7b0cd91d4d..f4a40eb18be 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -564,10 +564,25 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten verified_id = predict[accept_index] verified_id_cpu = verified_id.tolist() + evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool) + evict_mask[accept_index] = False + mem_need_free_idx = batch.out_cache_loc[evict_mask] + batch.token_to_kv_pool.free(mem_need_free_idx) + batch.seq_lens.add_(accept_length+1) + + assign_req_to_token_pool[(bs, )](batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + batch.seq_lens, + batch.seq_lens+accept_length+1, + batch.out_cache_loc[accept_index], + batch.req_to_token_pool.req_to_token.shape[1], + triton.next_power_of_2(bs) + ) + new_accept_index = [] unfinished_index = [] finished_extend_len = {} # {rid:accept_length + 1} - + retracted_reqs, new_token_ratio = batch.retract_decode() low = 0 for i, (req, verified_len) in enumerate(zip(batch.reqs, accept_length_cpu)): @@ -589,26 +604,6 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten ] draft_input.accept_length = accept_length[unfinished_index] draft_input.unfinished_index = unfinished_index - - - evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool) - evict_mask[accept_index] = False - mem_need_free_idx = batch.out_cache_loc[evict_mask] - - assign_req_to_token_pool[(bs, )](batch.req_pool_indices, - batch.req_to_token_pool.req_to_token, - batch.seq_lens, - batch.seq_lens+accept_length+1, - batch.out_cache_loc[accept_index], - batch.req_to_token_pool.req_to_token.shape[1], - triton.next_power_of_2(bs) - ) - - - - batch.token_to_kv_pool.free(mem_need_free_idx) - #batch.spec_info.evict_mask = evict_mask - batch.seq_lens.add_(accept_length+1) logits_output.next_token_logits = logits_output.next_token_logits[accept_index] return draft_input, logits_output, verified_id, finished_extend_len diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 7588aa63e63..b754bbc20ca 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -1,4 +1,5 @@ import torch +from typing import Union, List from sglang.srt.server_args import ServerArgs from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.speculative.speculative_worker import SpeculativeWorker, spec_worker_factory @@ -52,7 +53,6 @@ def forward_batch_speculative_generate(self, batch: ScheduleBatch): batch.spec_info.clear_draft_cache(batch) self._swap_mem_pool(batch, self.target_worker.model_runner) next_draft_input, logits_output, verified_id, self.finish_extend_len = self.verify(batch) - print('aval', self.model_runner.token_to_kv_pool.available_size()) next_draft_input.init(self.server_args) batch.spec_info = next_draft_input # if it is None, means all requsets are finished @@ -101,7 +101,7 @@ def forward_extend_after_decode(self, batch: ScheduleBatch): batch.forward_mode = ForwardMode.DECODE if batch.spec_info.has_finished: batch.seq_lens = seq_lens - self._swap_mem_pool(batch, self.model_runner) + self._swap_mem_pool(batch, self.target_worker.model_runner) def capture_for_decode(self, hidden_states, forward_batch): # lm head is not support cuda graph currently. 
But it could be support theoretically. @@ -119,11 +119,14 @@ def capture_for_decode(self, hidden_states, forward_batch): ) # Don't support prefix share now. - def finish_request(self, req): - req_len = len(req.origin_input_ids) + len(req.output_ids) - self.finish_extend_len[req.rid] - 1 - kv_indices = self.model_runner.req_to_token_pool.req_to_token[ - req.req_pool_idx - ][:req_len] - self.model_runner.token_to_kv_pool.free(kv_indices) - self.model_runner.req_to_token_pool.free(req.req_pool_idx) + def finish_request(self, reqs: Union[Req, List[Req]]): + if not isinstance(reqs, List): + reqs = [reqs] + for req in reqs: + req_len = len(req.origin_input_ids) + len(req.output_ids) - self.finish_extend_len[req.rid] - 1 + kv_indices = self.model_runner.req_to_token_pool.req_to_token[ + req.req_pool_idx + ][:req_len] + self.model_runner.token_to_kv_pool.free(kv_indices) + self.model_runner.req_to_token_pool.free(req.req_pool_idx) \ No newline at end of file diff --git a/python/sglang/srt/speculative/speculative_worker.py b/python/sglang/srt/speculative/speculative_worker.py index 2d6fc31da8b..66f89268247 100644 --- a/python/sglang/srt/speculative/speculative_worker.py +++ b/python/sglang/srt/speculative/speculative_worker.py @@ -1,4 +1,4 @@ -from typing import Type +from typing import Type, Union, List from sglang.srt.server_args import ServerArgs from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.managers.schedule_batch import ScheduleBatch, Req @@ -21,7 +21,7 @@ def __init__( def forward_batch_speculative_generate(self, batch: ScheduleBatch): raise NotImplementedError() - def finish_request(self, req: Req): + def finish_request(self, reqs: Union[Req, List[Req]]): raise NotImplementedError() class SpecWorkerFactory: From 404c5ab4a061df0356d29d994d25e1aa95f195bf Mon Sep 17 00:00:00 2001 From: kavioyu Date: Sun, 27 Oct 2024 02:17:51 +0800 Subject: [PATCH 21/26] remove debug info --- python/sglang/srt/managers/schedule_batch.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 718cf066999..25aeb050d17 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -572,7 +572,6 @@ def mix_with_running(self, running_batch: "ScheduleBatch"): def check_decode_mem(self, buf_multiplier=1): bs = len(self.reqs)*buf_multiplier - print('aval', self.token_to_kv_pool.available_size(), buf_multiplier, bs) if self.token_to_kv_pool.available_size() >= bs: return True From e095ec07c4b669de45898b1f7fb1fd11f9673d28 Mon Sep 17 00:00:00 2001 From: kavioyu Date: Sun, 27 Oct 2024 14:24:54 +0800 Subject: [PATCH 22/26] fix bug --- python/sglang/srt/model_executor/model_runner.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 59192fdcd48..c6beedbbf14 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -406,11 +406,12 @@ def init_memory_pool( 4096, ) - if self.is_draft_runner: - self.max_total_num_tokens = self.server_args.draft_runner_cache_size - else: - self.server_args.draft_runner_cache_size = self.max_total_num_tokens + \ - max_num_reqs * self.server_args.num_speculative_steps + 100 + if self.server_args.speculative_algorithm is not None: + if self.is_draft_runner: + self.max_total_num_tokens = self.server_args.draft_runner_cache_size + else: + 
self.server_args.draft_runner_cache_size = self.max_total_num_tokens + \ + max_num_reqs * self.server_args.num_speculative_steps + 100 if max_total_tokens is not None: From 9f0a0c2d657c263b66a1e5bc1da9f1649ed11977 Mon Sep 17 00:00:00 2001 From: kavioyu Date: Tue, 29 Oct 2024 14:00:13 +0800 Subject: [PATCH 23/26] fix some bug and support target model use cuda graph --- .../sglang/srt/layers/attention/__init__.py | 2 +- .../layers/attention/flashinfer_backend.py | 65 ++++++++++++------- .../srt/layers/attention/flashinfer_utils.py | 11 +++- .../srt/layers/attention/triton_backend.py | 2 +- python/sglang/srt/layers/logits_processor.py | 16 +++-- .../srt/model_executor/cuda_graph_runner.py | 9 ++- .../srt/model_executor/forward_batch_info.py | 3 +- .../sglang/srt/model_executor/model_runner.py | 7 +- python/sglang/srt/models/llama_eagle.py | 5 +- python/sglang/srt/speculative/eagle_utils.py | 19 ++---- python/sglang/srt/speculative/eagle_worker.py | 10 +-- 11 files changed, 87 insertions(+), 62 deletions(-) diff --git a/python/sglang/srt/layers/attention/__init__.py b/python/sglang/srt/layers/attention/__init__.py index 2dde3355edc..11c15e19f7c 100644 --- a/python/sglang/srt/layers/attention/__init__.py +++ b/python/sglang/srt/layers/attention/__init__.py @@ -18,7 +18,7 @@ def init_cuda_graph_state(self, max_bs: int): raise NotImplementedError() def init_forward_metadata_capture_cuda_graph( - self, num_token: int, req_pool_indices, seq_lens, spec_info, is_draft_runner + self, num_token: int, bs: int, req_pool_indices, seq_lens, spec_info, is_draft_runner ): """Init the metadata for a forward pass for capturing a cuda graph.""" raise NotImplementedError() diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 5b85205ec1c..d9cbbfdb144 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -125,7 +125,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): use_ragged = False if ( torch.sum(forward_batch.seq_lens).item() >= 4096 - and self.num_wrappers == 1 + and self.num_wrappers == 1 and not forward_batch.forward_mode.is_verify() ): use_ragged = True @@ -162,7 +162,14 @@ def init_cuda_graph_state(self, max_bs: int): self.cuda_graph_kv_last_page_len = torch.ones( (max_bs,), dtype=torch.int32, device="cuda" ) - + + self.cuda_graph_custom_mask = torch.zeros( + (max_bs * (self.model_runner.model_config.context_len+7)//8), + dtype=torch.uint8, + device="cuda", + ) + self.cuda_graph_qk_indptr = self.cuda_graph_kv_indptr.clone() + # NOTE: the buffers are always in the form of list self.cuda_graph_kv_indptr = [self.cuda_graph_kv_indptr] + [ self.cuda_graph_kv_indptr.clone() for _ in range(self.num_wrappers - 1) @@ -172,7 +179,7 @@ def init_cuda_graph_state(self, max_bs: int): ] def init_forward_metadata_capture_cuda_graph( - self, num_token: int, req_pool_indices, seq_lens, + self, num_token: int, bs: int, req_pool_indices, seq_lens, spec_info:SpecInput, is_draft_runner: bool=False ): decode_wrappers = [] @@ -183,10 +190,12 @@ def init_forward_metadata_capture_cuda_graph( self.workspace_buffer, "NHD", use_cuda_graph=True, - qo_indptr_buf=self.cuda_graph_q_indptr[:num_token+1], - paged_kv_indptr_buf=self.cuda_graph_kv_indptr[i][: num_token + 1], + qo_indptr_buf=self.cuda_graph_q_indptr[:bs+1], + paged_kv_indptr_buf=self.cuda_graph_kv_indptr[i][: bs + 1], paged_kv_indices_buf=self.cuda_graph_kv_indices[i], - 
paged_kv_last_page_len_buf=self.cuda_graph_kv_last_page_len[:num_token], + paged_kv_last_page_len_buf=self.cuda_graph_kv_last_page_len[:bs], + custom_mask_buf=self.cuda_graph_custom_mask, + qk_indptr_buf=self.cuda_graph_qk_indptr[:bs+1] ) ) else: @@ -210,7 +219,8 @@ def init_forward_metadata_capture_cuda_graph( seq_lens, None, decode_wrappers, - spec_info=spec_info + spec_info=spec_info, + use_cuda_graph=True, ) self.cuda_graph_metadata[num_token] = decode_wrappers @@ -218,17 +228,18 @@ def init_forward_metadata_capture_cuda_graph( self.forward_metadata = (False, False, None, decode_wrappers) def init_forward_metadata_replay_cuda_graph( - self, bs: int, num_token: int, req_pool_indices, seq_lens, spec_info + self, bs: int, num_token: int, req_pool_indices, seq_lens, spec_info, forward_mode ): # num_token == bs if not use speculative decoding with eagle2 update_flashinfer_indices( - ForwardMode.DECODE, + forward_mode, self.model_runner, req_pool_indices[:bs], seq_lens[:bs], None, self.cuda_graph_metadata[num_token], - spec_info=spec_info + spec_info=spec_info, + use_cuda_graph=True, ) def get_cuda_graph_seq_len_fill_value(self): @@ -247,20 +258,26 @@ def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch) forward_batch.token_to_kv_pool.set_kv_buffer( layer.layer_id, forward_batch.out_cache_loc, k, v ) - causal = True - if ( - forward_batch.spec_algorithm == "EAGLE" - and forward_batch.forward_mode == ForwardMode.SPECVERIFY - ): - causal = False - o = prefill_wrapper_paged.forward( - q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), - forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), - causal=causal, - sm_scale=layer.scaling, - window_left=layer.sliding_window_size, - logits_soft_cap=layer.logit_cap, - ) + + if forward_batch.forward_mode == ForwardMode.SPECVERIFY and forward_batch.is_cuda_graph: + decode_wrapper = self.forward_metadata[-1][self._get_wrapper_idx(layer)] + o = decode_wrapper.forward( + q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), + causal=False, + sm_scale=layer.scaling, + window_left=layer.sliding_window_size, + logits_soft_cap=layer.logit_cap, + ) + else: + o = prefill_wrapper_paged.forward( + q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), + causal=True, + sm_scale=layer.scaling, + window_left=layer.sliding_window_size, + logits_soft_cap=layer.logit_cap, + ) else: o1, s1 = self.prefill_wrapper_ragged.forward_return_lse( q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), diff --git a/python/sglang/srt/layers/attention/flashinfer_utils.py b/python/sglang/srt/layers/attention/flashinfer_utils.py index 67c1b3dab70..58964aaf65e 100644 --- a/python/sglang/srt/layers/attention/flashinfer_utils.py +++ b/python/sglang/srt/layers/attention/flashinfer_utils.py @@ -57,6 +57,7 @@ def __init__( decode_wrappers=None, use_ragged=False, spec_info=None, + use_cuda_graph=False, ): self.forward_mode = forward_mode self.model_runner = model_runner @@ -65,6 +66,7 @@ def __init__( self.prefix_lens = prefix_lens self.use_ragged = use_ragged self.spec_info = spec_info + self.use_cuda_graph = use_cuda_graph self.num_qo_heads = ( model_runner.model_config.num_attention_heads // model_runner.tp_size @@ -231,7 +233,10 @@ def _get_indices(self, dispatch_reason: WrapperDispatch = None, wrapper_id=0): def _update_indicess_single_wrapper(self): self._get_indices() if 
self.forward_mode.is_verify(): - self._update_verify_indices(self.prefill_wrappers_paged[0]) + if self.use_cuda_graph: + self._update_verify_indices(self.decode_wrappers[0]) + else: + self._update_verify_indices(self.prefill_wrappers_paged[0]) elif self.forward_mode.is_spec_extend(): self._update_spec_extend(self.prefill_wrappers_paged[0]) elif self.forward_mode.is_decode(): @@ -266,7 +271,8 @@ def update_flashinfer_indices( prefix_lens, decode_wrappers=None, use_ragged=False, - spec_info=None + spec_info=None, + use_cuda_graph=False, ): updater = FlashinferUpdater( forward_mode, @@ -277,6 +283,7 @@ def update_flashinfer_indices( decode_wrappers, use_ragged, spec_info, + use_cuda_graph, ) dispatch_reason = model_runner.attn_backend.dispatch_reason diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 3099b6c4177..995a5a0fbe7 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -79,7 +79,7 @@ def init_cuda_graph_state(self, max_bs: int): ) def init_forward_metadata_capture_cuda_graph( - self, num_token: int, req_pool_indices, seq_lens, spec_info, is_draft_runner + self, num_token: int, bs: int, req_pool_indices, seq_lens, spec_info, is_draft_runner ): self.forward_metadata = ( self.cuda_graph_start_loc, diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 8fe88bd4fe7..39296052bb6 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -34,6 +34,13 @@ class LogitsProcessorOutput: next_token_logits: torch.Tensor # The logprobs of the next tokens. shape: [#seq, vocab_size] next_token_logprobs: torch.Tensor + + # Used by speculative inference + # The output of transformer layers + hidden_states: Optional[torch.Tensor] + # backup of next_token_logits when use cuda graph + # id(next_token_logits_bak) == id(next_token_logits) + next_token_logits_bak: Optional[torch.Tensor] # The normlaized logprobs of prompts. 
shape: [#seq] normalized_prompt_logprobs: torch.Tensor @@ -172,9 +179,9 @@ def forward( weight, logits_metadata: Union[LogitsMetadata, ForwardBatch], ): - spec_info = None + need_hidden_states = False if isinstance(logits_metadata, ForwardBatch): - spec_info = getattr(logits_metadata, 'spec_info', None) + need_hidden_states = logits_metadata.spec_algorithm == 'EAGLE' logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata) assert isinstance(logits_metadata, LogitsMetadata) @@ -186,9 +193,6 @@ def forward( last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1 last_hidden = hidden_states[last_index] - if spec_info: - spec_info.hidden_states = last_hidden - last_logits = torch.matmul(last_hidden, weight.T) if self.do_tensor_parallel_all_gather: last_logits = tensor_model_parallel_all_gather(last_logits) @@ -208,6 +212,8 @@ def forward( input_token_logprobs=None, input_top_logprobs=None, output_top_logprobs=None, + hidden_states=last_hidden if need_hidden_states else None, + next_token_logits_bak=last_logits, ) else: last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 040c9031537..2344a3b4628 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -218,10 +218,12 @@ def capture_one_batch_size(self, bs: int, num_token: int, forward: Callable): else: spec_info = DraftInfoFactory.get(self.model_runner.server_args.speculative_algorithm, 'VerifyInput')( None, None, None, None, None, None, self.model_runner.server_args.num_draft_tokens) + spec_info.custom_mask = torch.zeros((num_token*self.model_runner.model_config.context_len), dtype=torch.bool, + device="cuda",) # Attention backend self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph( - num_token, req_pool_indices, seq_lens, spec_info, self.model_runner.is_draft_runner + num_token, bs, req_pool_indices, seq_lens, spec_info, self.model_runner.is_draft_runner ) # Run and capture @@ -241,6 +243,8 @@ def run_once(mode): positions=positions, #positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64), spec_info=spec_info, + spec_algorithm=self.model_runner.server_args.speculative_algorithm, + is_cuda_graph=True, ) return forward(input_ids, forward_batch.positions, forward_batch) @@ -267,6 +271,7 @@ def run_once(mode): def replay(self, forward_batch: ForwardBatch): assert forward_batch.out_cache_loc is not None + forward_batch.is_cuda_graph = True raw_bs = forward_batch.batch_size # In most case, raw_bs == num_token in decode stage. 
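A small illustration of the raw versus padded sizing that the replay path here deals with: the captured graph always writes into fixed-size buffers, so outputs are narrowed back to the real token count before anything downstream consumes them. The sketch below uses illustrative names (GraphOutputs, narrow_graph_outputs) rather than the actual SGLang classes, and the padded capture size is only an assumed example.

import torch
from dataclasses import dataclass
from typing import Optional

@dataclass
class GraphOutputs:
    # Narrowed view handed back to the caller.
    next_token_logits: torch.Tensor
    # Same storage as next_token_logits; kept so verification can re-index it.
    next_token_logits_bak: Optional[torch.Tensor] = None
    # Last-position hidden states, consumed by an EAGLE-style draft model.
    hidden_states: Optional[torch.Tensor] = None

def narrow_graph_outputs(logits_buf: torch.Tensor,
                         hidden_buf: Optional[torch.Tensor],
                         raw_num_token: int) -> GraphOutputs:
    # The graph was captured at a padded token count; only the first
    # raw_num_token rows belong to the current batch.
    logits = logits_buf[:raw_num_token]
    hidden = hidden_buf[:raw_num_token] if hidden_buf is not None else None
    return GraphOutputs(next_token_logits=logits,
                        next_token_logits_bak=logits,
                        hidden_states=hidden)

if __name__ == "__main__":
    padded = 64  # e.g. bs * num_draft_tokens at capture time (assumed here)
    out = narrow_graph_outputs(torch.randn(padded, 32000),
                               torch.randn(padded, 4096),
                               raw_num_token=24)
    print(out.next_token_logits.shape, out.hidden_states.shape)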
# But for speculative, the token num maybe large than raw_bs @@ -293,7 +298,7 @@ def replay(self, forward_batch: ForwardBatch): # Attention backend self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph( - bs, num_token, self.req_pool_indices, self.seq_lens, forward_batch.spec_info + bs, num_token, self.req_pool_indices, self.seq_lens, forward_batch.spec_info, self.capture_forward_mode ) # Replay diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 3309df2e720..f346907afe4 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -77,7 +77,7 @@ def is_spec_extend(self): return self == ForwardMode.SPECEXTEND def is_cuda_graph(self): - return self in (ForwardMode.DECODE, ForwardMode) + return self in (ForwardMode.DECODE, ForwardMode.SPECVERIFY) @dataclass @@ -129,6 +129,7 @@ class ForwardBatch: spec_info: SpecInput = None spec_algorithm: str = None is_draft_batch: bool = False + is_cuda_graph: bool = False @classmethod def init_new( diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index c6beedbbf14..d24d03097dd 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -499,18 +499,13 @@ def init_cuda_graphs(self): if self.server_args.disable_cuda_graph: return - - # Target model don't need cuda graph due to it have big batch size in verify stage. - # Disable it to save gpu memory. - if self.server_args.speculative_algorithm is not None and not self.is_draft_runner: - return logger.info("Capture cuda graph begin. This can take up to several minutes.") self.cuda_graph_runner = CudaGraphRunner(self) def forward_decode(self, forward_batch: ForwardBatch): if self.cuda_graph_runner and self.cuda_graph_runner.can_run( - forward_batch.input_ids.numel() + forward_batch.batch_size ) and forward_batch.forward_mode.is_cuda_graph(): return self.cuda_graph_runner.replay(forward_batch) self.attn_backend.init_forward_metadata(forward_batch) diff --git a/python/sglang/srt/models/llama_eagle.py b/python/sglang/srt/models/llama_eagle.py index dfe087badc0..87082f8561d 100644 --- a/python/sglang/srt/models/llama_eagle.py +++ b/python/sglang/srt/models/llama_eagle.py @@ -326,7 +326,10 @@ def forward( input_embeds: torch.Tensor = None, ) -> LogitsProcessorOutput: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) - return hidden_states + logits_output = self.logits_processor( + None, hidden_states, self.lm_head.weight, forward_batch + ) + return logits_output def get_hidden_dim(self, module_name): # return input_dim, output_dim diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index f4a40eb18be..db828221db9 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -198,7 +198,7 @@ class EAGLEDraftInput(SpecDraftInput): def init(self, server_args: ServerArgs): self.prev_mode = ForwardMode.DECODE self.sample_output = None - self.topk: int = 8 + self.topk: int = server_args.eagle_topk self.num_verify_token: int = server_args.num_draft_tokens self.spec_steps = server_args.num_speculative_steps @@ -247,9 +247,11 @@ def prepare_for_extend(self, batch: ForwardBatch): model_input_ids, dtype=torch.int32, device="cuda" ) - def capture_for_decode(self, sample_output: SampleOutput, prev_mode: 
ForwardMode): + def capture_for_decode(self, sample_output: SampleOutput, hidden_states: torch.Tensor, + prev_mode: ForwardMode): self.sample_output = sample_output self.prev_mode = prev_mode + self.hidden_states = hidden_states def prepare_for_decode(self, batch: ScheduleBatch): prob = self.sample_output # b * (1/topk), vocab @@ -471,12 +473,6 @@ def __init__( def prepare_for_verify(self, batch: ScheduleBatch): batch.input_ids = self.draft_token batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel()) - # batch.req_to_token_pool.req_to_token[ - # batch.req_pool_indices, - # batch.seq_lens : batch.seq_lens + self.draft_token_num, - # ] = batch.out_cache_loc - - #print(batch.req_to_token_pool.req_to_token[0][:100]) bs = batch.seq_lens.numel() assign_req_to_token_pool[(bs, )](batch.req_pool_indices, batch.req_to_token_pool.req_to_token, @@ -506,7 +502,7 @@ def generate_attn_arg( (batch_size + 1,), dtype=torch.int32, device="cuda" ) - paged_kernel_lens = paged_kernel_lens.add_(self.draft_token_num) + paged_kernel_lens = paged_kernel_lens + self.draft_token_num cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda") @@ -522,11 +518,10 @@ def generate_attn_arg( kv_indices, req_to_token_pool.req_to_token.size(1), ) - paged_kernel_lens = paged_kernel_lens.sub_(self.draft_token_num) return kv_indices, cum_kv_seq_len, kv_last_page_len, qo_indptr def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor: - predict = torch.argmax(logits_output.next_token_logits, dim=-1) + predict = torch.argmax(logits_output.next_token_logits_bak, dim=-1) predict = torch.cat([predict, torch.full([1], -1, dtype=torch.long, device='cuda')], dim=-1) draft_token = torch.cat([self.draft_token, torch.full([1], -1, dtype=torch.long, device='cuda')], dim=-1) target_predict = predict[self.retrive_index] @@ -605,6 +600,6 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten draft_input.accept_length = accept_length[unfinished_index] draft_input.unfinished_index = unfinished_index - logits_output.next_token_logits = logits_output.next_token_logits[accept_index] + logits_output.next_token_logits = logits_output.next_token_logits_bak[accept_index] return draft_input, logits_output, verified_id, finished_extend_len diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index b754bbc20ca..98a85180717 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -76,6 +76,7 @@ def verify(self, batch: ScheduleBatch): batch.spec_info = verify_input model_worker_batch = batch.get_model_worker_batch() logits_output, _ = self.target_worker.forward_batch_generation(model_worker_batch, need_token_id=False) + verify_input.hidden_states = logits_output.hidden_states res = verify_input.verify(batch, logits_output) batch.forward_mode = ForwardMode.DECODE return res @@ -103,19 +104,14 @@ def forward_extend_after_decode(self, batch: ScheduleBatch): batch.seq_lens = seq_lens self._swap_mem_pool(batch, self.target_worker.model_runner) - def capture_for_decode(self, hidden_states, forward_batch): - # lm head is not support cuda graph currently. But it could be support theoretically. - # TODO: Support it. 
@kavioyu - logits_output = self.model_runner.model.logits_processor( - None, hidden_states, self.model_runner.model.lm_head.weight, forward_batch - ) + def capture_for_decode(self, logits_output, forward_batch): if isinstance(logits_output, LogitsProcessorOutput): logits = logits_output.next_token_logits sample_output = torch.softmax( logits, dim=-1 ) # TODO: Support more sampling method @kavioyu forward_batch.spec_info.capture_for_decode( - sample_output, forward_batch.forward_mode + sample_output, logits_output.hidden_states, forward_batch.forward_mode ) # Don't support prefix share now. From b647a707491b211c1731a8a264483a00126b1fdf Mon Sep 17 00:00:00 2001 From: kavioyu Date: Fri, 1 Nov 2024 20:22:37 +0800 Subject: [PATCH 24/26] fix naive cuda graph --- .../layers/attention/flashinfer_backend.py | 18 +++--- python/sglang/srt/layers/logits_processor.py | 6 +- .../srt/model_executor/cuda_graph_runner.py | 61 ++++++++++--------- .../sglang/srt/model_executor/model_runner.py | 3 +- 4 files changed, 47 insertions(+), 41 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 2ddb079ac4c..94c47cc07fb 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -235,15 +235,15 @@ def init_forward_metadata_capture_cuda_graph( ) self.forward_metadata = (decode_wrappers,) - # seq_lens_sum = seq_lens.sum().item() - # self.indices_updater_decode.update( - # req_pool_indices, - # seq_lens, - # seq_lens_sum, - # decode_wrappers=decode_wrappers, - # encoder_lens=encoder_lens, - # forward_batch=forward_batch, - # ) + seq_lens_sum = seq_lens.sum().item() + self.indices_updater_decode.update( + req_pool_indices, + seq_lens, + seq_lens_sum, + decode_wrappers=decode_wrappers, + encoder_lens=encoder_lens, + forward_batch=forward_batch, + ) self.cuda_graph_metadata[num_token] = decode_wrappers diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index fdf5012410c..d95cf4da097 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -33,14 +33,14 @@ class LogitsProcessorOutput: # The logits of the next tokens. shape: [#seq, vocab_size] next_token_logits: torch.Tensor # The logprobs of the next tokens. shape: [#seq, vocab_size] - next_token_logprobs: torch.Tensor + next_token_logprobs: torch.Tensor = None # Used by speculative inference # The output of transformer layers - hidden_states: Optional[torch.Tensor] + hidden_states: Optional[torch.Tensor] = None # backup of next_token_logits when use cuda graph # id(next_token_logits_bak) == id(next_token_logits) - next_token_logits_bak: Optional[torch.Tensor] + next_token_logits_bak: Optional[torch.Tensor] = None # The normlaized logprobs of prompts. 
shape: [#seq] normalized_prompt_logprobs: torch.Tensor = None diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 7ab9de9a419..14058981c24 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -272,6 +272,28 @@ def capture_one_batch_size(self, bs: int, num_token: int, forward: Callable): seq_lens_sum = seq_lens.sum().item() mrope_positions = self.mrope_positions[:, :bs] + forward_batch = ForwardBatch( + forward_mode=self.capture_forward_mode, + batch_size=bs, + input_ids=input_ids, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool=self.model_runner.token_to_kv_pool, + attn_backend=self.model_runner.attn_backend, + out_cache_loc=out_cache_loc, + seq_lens_sum=seq_lens_sum, + encoder_lens=encoder_lens, + return_logprob=False, + top_logprobs_nums=[0] * num_token, + positions=positions, + #positions=clamp_position(seq_lens), + spec_info=spec_info, + spec_algorithm=self.model_runner.server_args.speculative_algorithm, + is_cuda_graph=True, + mrope_positions=mrope_positions, + ) + # Attention backend self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph( bs, @@ -281,41 +303,18 @@ def capture_one_batch_size(self, bs: int, num_token: int, forward: Callable): encoder_lens, spec_info, self.model_runner.is_draft_runner, - forward_batch=None, + forward_batch=forward_batch, ) # Run and capture - def run_once(mode): - forward_batch = ForwardBatch( - forward_mode=mode, - batch_size=bs, - input_ids=input_ids, - req_pool_indices=req_pool_indices, - seq_lens=seq_lens, - req_to_token_pool=self.model_runner.req_to_token_pool, - token_to_kv_pool=self.model_runner.token_to_kv_pool, - attn_backend=self.model_runner.attn_backend, - out_cache_loc=out_cache_loc, - seq_lens_sum=seq_lens_sum, - encoder_lens=encoder_lens, - return_logprob=False, - top_logprobs_nums=[0] * num_token, - positions=positions, - #positions=clamp_position(seq_lens), - spec_info=spec_info, - spec_algorithm=self.model_runner.server_args.speculative_algorithm, - is_cuda_graph=True, - mrope_positions=mrope_positions, - ) + def run_once(): logits_output = forward(input_ids, forward_batch.positions, forward_batch) return logits_output.next_token_logits for _ in range(2): torch.cuda.synchronize() self.model_runner.tp_group.barrier() - - run_once(self.capture_forward_mode) - + run_once() torch.cuda.synchronize() self.model_runner.tp_group.barrier() @@ -323,7 +322,7 @@ def run_once(mode): self.model_runner.tp_group.barrier() with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream): - out = run_once(self.capture_forward_mode) + out = run_once() torch.cuda.synchronize() self.model_runner.tp_group.barrier() @@ -352,7 +351,10 @@ def replay(self, forward_batch: ForwardBatch): self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens) self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc) - self.positions[:raw_num_token].copy_(forward_batch.positions) + positions = forward_batch.positions + if positions is None: + positions = clamp_position(forward_batch.seq_lens) + self.positions[:raw_num_token].copy_(positions) if self.is_encoder_decoder: self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens) @@ -367,6 +369,7 @@ def replay(self, forward_batch: ForwardBatch): # Attention backend 
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph( bs, + num_token, self.req_pool_indices, self.seq_lens, forward_batch.seq_lens_sum + (bs - raw_bs), @@ -386,6 +389,7 @@ def replay(self, forward_batch: ForwardBatch): logits_output = LogitsProcessorOutput( next_token_logits=next_token_logits, next_token_logprobs=next_token_logprobs, + next_token_logits_bak=next_token_logits ) return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums) if return_top_logprob: @@ -399,6 +403,7 @@ def replay(self, forward_batch: ForwardBatch): else: logits_output = LogitsProcessorOutput( next_token_logits=next_token_logits, + next_token_logits_bak=next_token_logits, ) return logits_output diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 1eeedf14132..920bb2b3b4e 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -570,8 +570,9 @@ def init_cuda_graphs(self): self.cuda_graph_runner = CudaGraphRunner(self) def forward_decode(self, forward_batch: ForwardBatch): + if self.cuda_graph_runner and self.cuda_graph_runner.can_run( - forward_batch.batch_size + forward_batch ) and forward_batch.forward_mode.is_cuda_graph(): return self.cuda_graph_runner.replay(forward_batch) if hasattr(forward_batch.spec_info, 'positions'): From 722698702742328545003b2815632b976821ea1d Mon Sep 17 00:00:00 2001 From: kavioyu Date: Sat, 2 Nov 2024 15:48:32 +0800 Subject: [PATCH 25/26] fix cuda graph --- .../layers/attention/flashinfer_backend.py | 23 ++++++++++--------- .../srt/model_executor/cuda_graph_runner.py | 11 ++++++--- python/sglang/srt/models/llama_eagle.py | 3 --- 3 files changed, 20 insertions(+), 17 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 94c47cc07fb..a2c7ded54e0 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -190,7 +190,7 @@ def init_cuda_graph_state(self, max_bs: int): device="cuda", ) self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr] - self.cuda_graph_q_indptr = [x.clone() for x in self.kv_indptr] + self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr] def init_forward_metadata_capture_cuda_graph( self, @@ -212,7 +212,7 @@ def init_forward_metadata_capture_cuda_graph( self.workspace_buffer, "NHD", use_cuda_graph=True, - qo_indptr_buf=self.cuda_graph_qk_indptr[i][:bs+1], + qo_indptr_buf=self.cuda_graph_qo_indptr[i][:bs+1], paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1], paged_kv_indices_buf=self.cuda_graph_kv_indices[i], paged_kv_last_page_len_buf=self.kv_last_page_len[:bs], @@ -221,6 +221,7 @@ def init_forward_metadata_capture_cuda_graph( ) ) self.forward_metadata = (False, False, decode_wrappers) + else: decode_wrappers.append( BatchDecodeWithPagedKVCacheWrapper( @@ -235,15 +236,15 @@ def init_forward_metadata_capture_cuda_graph( ) self.forward_metadata = (decode_wrappers,) - seq_lens_sum = seq_lens.sum().item() - self.indices_updater_decode.update( - req_pool_indices, - seq_lens, - seq_lens_sum, - decode_wrappers=decode_wrappers, - encoder_lens=encoder_lens, - forward_batch=forward_batch, - ) + seq_lens_sum = seq_lens.sum().item() + self.indices_updater_decode.update( + req_pool_indices, + seq_lens, + seq_lens_sum, + decode_wrappers=decode_wrappers, + encoder_lens=encoder_lens, + forward_batch=forward_batch, + ) 
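For readers unfamiliar with the ragged-batch pointers that the qo_indptr and paged_kv_indptr buffers above describe, this is a minimal sketch of how such prefix sums can be built for a verify batch, assuming a page size of one token; build_verify_indptrs is a hypothetical helper for illustration, not a flashinfer or SGLang API.

import torch

def build_verify_indptrs(seq_lens: torch.Tensor, num_draft_tokens: int):
    """Build ragged-batch pointers for a verify pass.

    seq_lens: (bs,) tokens already in the KV cache per request.
    Each request additionally queries num_draft_tokens positions and attends
    to its cached tokens plus those draft tokens (page size 1 assumed).
    """
    bs = seq_lens.numel()
    qo_lens = torch.full((bs,), num_draft_tokens, dtype=torch.int32)
    kv_lens = seq_lens.to(torch.int32) + num_draft_tokens

    qo_indptr = torch.zeros(bs + 1, dtype=torch.int32)
    qo_indptr[1:] = torch.cumsum(qo_lens, dim=0)

    kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(kv_lens, dim=0)

    # With page size 1 every page holds exactly one token, so the last page is full.
    kv_last_page_len = torch.ones(bs, dtype=torch.int32)
    return qo_indptr, kv_indptr, kv_last_page_len

if __name__ == "__main__":
    qo, kv, last = build_verify_indptrs(torch.tensor([10, 3, 7]), num_draft_tokens=4)
    print(qo.tolist())   # [0, 4, 8, 12]
    print(kv.tolist())   # [0, 14, 21, 32]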
self.cuda_graph_metadata[num_token] = decode_wrappers diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 14058981c24..3685c07a905 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -309,7 +309,7 @@ def capture_one_batch_size(self, bs: int, num_token: int, forward: Callable): # Run and capture def run_once(): logits_output = forward(input_ids, forward_batch.positions, forward_batch) - return logits_output.next_token_logits + return logits_output.next_token_logits, logits_output.hidden_states for _ in range(2): torch.cuda.synchronize() @@ -379,7 +379,10 @@ def replay(self, forward_batch: ForwardBatch): # Replay self.graphs[bs].replay() - next_token_logits = self.output_buffers[bs][:raw_bs] + next_token_logits, hidden_states = self.output_buffers[bs] + next_token_logits = next_token_logits[:raw_num_token] + if hidden_states is not None: + hidden_states = hidden_states[:raw_num_token] # Extract logprobs if forward_batch.return_logprob: @@ -389,7 +392,8 @@ def replay(self, forward_batch: ForwardBatch): logits_output = LogitsProcessorOutput( next_token_logits=next_token_logits, next_token_logprobs=next_token_logprobs, - next_token_logits_bak=next_token_logits + next_token_logits_bak=next_token_logits, + hidden_states=hidden_states, ) return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums) if return_top_logprob: @@ -404,6 +408,7 @@ def replay(self, forward_batch: ForwardBatch): logits_output = LogitsProcessorOutput( next_token_logits=next_token_logits, next_token_logits_bak=next_token_logits, + hidden_states=hidden_states, ) return logits_output diff --git a/python/sglang/srt/models/llama_eagle.py b/python/sglang/srt/models/llama_eagle.py index fd897fd59c2..87082f8561d 100644 --- a/python/sglang/srt/models/llama_eagle.py +++ b/python/sglang/srt/models/llama_eagle.py @@ -325,9 +325,6 @@ def forward( forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> LogitsProcessorOutput: - # if forward_batch.forward_mode.is_spec_extend(): - # print('input_ids', input_ids) - # print('positions', positions) hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) logits_output = self.logits_processor( None, hidden_states, self.lm_head.weight, forward_batch From aaf1cae4a06072020eda32408d56ae9d716df239 Mon Sep 17 00:00:00 2001 From: kavioyu Date: Sat, 2 Nov 2024 16:26:41 +0800 Subject: [PATCH 26/26] support split prefill batch --- python/sglang/srt/managers/scheduler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 3ca987c0f91..b1e80c7c93e 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -398,6 +398,8 @@ def recv_requests(self): except zmq.ZMQError: break recv_reqs.append(recv_req) + if self.server_args.split_prefill_batch: + break else: recv_reqs = None
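As background for the verify logic these patches build up, below is a deliberately simplified, chain-only sketch of greedy speculative acceptance; the real EagleVerifyInput.verify walks a token tree through retrive_index and also frees the KV slots of rejected tokens. Names and shapes here are illustrative.

import torch

def greedy_chain_accept(draft_tokens: torch.Tensor, target_logits: torch.Tensor):
    """Accept the longest draft prefix the target model agrees with.

    draft_tokens: (k,) where draft_tokens[0] is the last verified token and
                  draft_tokens[1:] are speculative continuations.
    target_logits: (k, vocab) target-model logits after consuming each prefix.
    Returns (accept_length, bonus_token): the number of accepted draft tokens
    and the target model's own prediction at the first disagreement (or at the
    end of the chain).
    """
    target_predict = target_logits.argmax(dim=-1)  # greedy target tokens
    accept_length = 0
    for i in range(1, draft_tokens.numel()):
        if draft_tokens[i].item() == target_predict[i - 1].item():
            accept_length += 1                     # draft token confirmed
        else:
            break
    bonus_token = target_predict[accept_length].item()  # always gain one token
    return accept_length, bonus_token

if __name__ == "__main__":
    vocab = 100
    draft = torch.tensor([7, 12, 33, 5])
    logits = torch.zeros(4, vocab)
    logits[0, 12] = 1.0   # target agrees with draft[1]
    logits[1, 33] = 1.0   # target agrees with draft[2]
    logits[2, 44] = 1.0   # disagreement: draft proposed 5, target wants 44
    print(greedy_chain_accept(draft, logits))  # (2, 44)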