diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index d311f6e2a21..97083a31c6e 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -16,21 +16,36 @@ processes (TokenizerManager, DetokenizerManager, Controller).
 """
 
+import dataclasses
+import sys
 import uuid
-from dataclasses import dataclass
+from collections.abc import Sequence
 from enum import Enum
 from typing import Dict, List, Optional, Union
 
 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling.sampling_params import SamplingParams
 
+# Use sequence instead of Tensor here because Pydantic serializes Python objects
+# based on type annotations.
+TokenEmbedding = List[float]  # 1D tensor
+SingleSequenceEmbedding = List[TokenEmbedding]  # 2D tensor
+BatchSequenceEmbedding = List[SingleSequenceEmbedding]  # 3D tensor
 
 
-@dataclass
+
+@dataclasses.dataclass
 class GenerateReqInput:
+    if sys.version_info >= (3, 10):
+        _: dataclasses.KW_ONLY
+
     # The input prompt. It can be a single prompt or a batch of prompts.
     text: Optional[Union[List[str], str]] = None
-    # The token ids for text; one can either specify text or input_ids.
+    # The token ids for text; one can either specify text, input_ids, or input_embeds.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
+    # Precalculated embeddings for the input text; one can either specify text, input_ids, or input_embeds.
+    input_embeds: Optional[Union[BatchSequenceEmbedding, SingleSequenceEmbedding]] = (
+        None
+    )
     # The image input. It can be a file name, a url, or base64 encoded string.
     # See also python/sglang/srt/utils.py:load_image.
     image_data: Optional[Union[List[str], str]] = None
@@ -59,26 +74,27 @@ class GenerateReqInput:
     session_rid: Optional[Union[List[str], str]] = None
 
     def normalize_batch_and_arguments(self):
-        if (self.text is None and self.input_ids is None) or (
-            self.text is not None and self.input_ids is not None
-        ):
-            raise ValueError("Either text or input_ids should be provided.")
+        if (self.text, self.input_ids, self.input_embeds).count(None) != 2:
+            raise ValueError(
+                "Only one of text, input_ids, and input_embeds should be provided."
+            )
 
         # Derive the batch size
+        self.is_single = True
+        self.batch_size = 1
         if self.text is not None:
-            if isinstance(self.text, str):
-                self.is_single = True
-                self.batch_size = 1
-            else:
+            if isinstance(self.text, list):
                 self.is_single = False
                 self.batch_size = len(self.text)
-        else:
-            if isinstance(self.input_ids[0], int):
-                self.is_single = True
-                self.batch_size = 1
-            else:
+        elif self.input_ids is not None:
+            if isinstance(self.input_ids[0], list):
                 self.is_single = False
                 self.batch_size = len(self.input_ids)
+        else:
+            assert self.input_embeds is not None
+            if isinstance(self.input_embeds[0][0], Sequence):
+                self.is_single = False
+                self.batch_size = len(self.input_embeds)
 
         # Handle parallel sampling
         # When parallel sampling is used, we always treat the input as a batch.
@@ -123,8 +139,6 @@ def normalize_batch_and_arguments(self):
                 self.image_data = [None] * num
             elif not isinstance(self.image_data, list):
                 self.image_data = [self.image_data] * num
-            elif isinstance(self.image_data, list):
-                pass
 
             if self.sampling_params is None:
                 self.sampling_params = [{}] * num
@@ -165,6 +179,9 @@ def __getitem__(self, i):
         return GenerateReqInput(
             text=self.text[i] if self.text is not None else None,
             input_ids=self.input_ids[i] if self.input_ids is not None else None,
+            input_embeds=(
+                self.input_embeds[i] if self.input_embeds is not None else None
+            ),
             image_data=self.image_data[i],
             sampling_params=self.sampling_params[i],
             rid=self.rid[i],
@@ -178,14 +195,14 @@ def __getitem__(self, i):
         )
 
 
-@dataclass
+@dataclasses.dataclass
 class TokenizedGenerateReqInput:
     # The request id
     rid: str
     # The input text
     input_text: str
     # The input token ids
-    input_ids: List[int]
+    input_ids: Optional[List[int]]
     # The image inputs
     image_inputs: dict
     # The sampling parameters
@@ -198,7 +215,6 @@ class TokenizedGenerateReqInput:
     top_logprobs_num: int
     # Whether to stream output
     stream: bool
-
     # LoRA related
     lora_path: Optional[str] = None  # None means just use the base model
 
@@ -206,39 +222,53 @@ class TokenizedGenerateReqInput:
     session_id: Optional[int] = None
     session_rid: Optional[str] = None
 
+    if sys.version_info >= (3, 10):
+        _: dataclasses.KW_ONLY
+
+    # The precalculated embeddings for the input text
+    input_embeds: Optional[SingleSequenceEmbedding] = None
+
 
-@dataclass
+@dataclasses.dataclass
 class EmbeddingReqInput:
+    if sys.version_info >= (3, 10):
+        _: dataclasses.KW_ONLY
+
     # The input prompt. It can be a single prompt or a batch of prompts.
     text: Optional[Union[List[str], str]] = None
-    # The token ids for text; one can either specify text or input_ids.
+    # The token ids for text; one can either specify text, input_ids, or input_embeds.
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
+    # The embeddings for text; one can either specify text, input_ids, or input_embeds.
+    input_embeds: Optional[Union[BatchSequenceEmbedding, SingleSequenceEmbedding]] = (
+        None
+    )
     # The request id.
     rid: Optional[Union[List[str], str]] = None
     # Dummy sampling params for compatibility
     sampling_params: Union[List[Dict], Dict] = None
 
     def normalize_batch_and_arguments(self):
-        if (self.text is None and self.input_ids is None) or (
-            self.text is not None and self.input_ids is not None
-        ):
-            raise ValueError("Either text or input_ids should be provided.")
+        if (self.text, self.input_ids, self.input_embeds).count(None) != 2:
+            raise ValueError(
+                "Only one of text, input_ids, and input_embeds should be provided."
+            )
 
         # Derive the batch size
+        self.is_single = True
+        self.batch_size = 1
         if self.text is not None:
-            if isinstance(self.text, str):
-                self.is_single = True
-                self.batch_size = 1
-            else:
+            if isinstance(self.text, list):
                 self.is_single = False
                 self.batch_size = len(self.text)
-        else:
-            if isinstance(self.input_ids[0], int):
-                self.is_single = True
-                self.batch_size = 1
-            else:
+        elif self.input_ids is not None:
+            if isinstance(self.input_ids[0], list):
                 self.is_single = False
                 self.batch_size = len(self.input_ids)
+        else:
+            assert self.input_embeds is not None
+            if isinstance(self.input_embeds[0][0], Sequence):
+                self.is_single = False
+                self.batch_size = len(self.input_embeds)
 
         # Fill in default arguments
         if self.is_single:
@@ -266,24 +296,33 @@ def __getitem__(self, i):
         return EmbeddingReqInput(
             text=self.text[i] if self.text is not None else None,
             input_ids=self.input_ids[i] if self.input_ids is not None else None,
+            input_embeds=(
+                self.input_embeds[i] if self.input_embeds is not None else None
+            ),
             sampling_params=self.sampling_params[i],
             rid=self.rid[i],
         )
 
 
-@dataclass
+@dataclasses.dataclass
 class TokenizedEmbeddingReqInput:
     # The request id
     rid: str
     # The input text
-    input_text: str
+    input_text: Optional[str]
     # The input token ids
-    input_ids: List[int]
+    input_ids: Optional[List[int]]
     # Dummy sampling params for compatibility
     sampling_params: SamplingParams
 
+    if sys.version_info >= (3, 10):
+        _: dataclasses.KW_ONLY
+
+    # The precalculated embeddings for the input text
+    input_embeds: Optional[SingleSequenceEmbedding] = None
+
 
-@dataclass
+@dataclasses.dataclass
 class BatchTokenIDOut:
     # The request id
     rids: List[str]
@@ -303,7 +342,7 @@ class BatchTokenIDOut:
     session_ids: List[str]
 
 
-@dataclass
+@dataclasses.dataclass
 class BatchStrOut:
     # The request id
     rids: List[str]
@@ -317,7 +356,7 @@ class BatchStrOut:
     session_ids: List[str]
 
 
-@dataclass
+@dataclasses.dataclass
 class BatchEmbeddingOut:
     # The request id
     rids: List[str]
@@ -329,12 +368,12 @@ class BatchEmbeddingOut:
     finished_reason: List[BaseFinishReason]
 
 
-@dataclass
+@dataclasses.dataclass
 class FlushCacheReq:
     pass
 
 
-@dataclass
+@dataclasses.dataclass
 class UpdateWeightReqInput:
     # The model path with the new weights
     model_path: str
@@ -342,13 +381,13 @@ class UpdateWeightReqInput:
     load_format: Optional[str] = None
 
 
-@dataclass
+@dataclasses.dataclass
 class UpdateWeightReqOutput:
     success: bool
     message: str
 
 
-@dataclass
+@dataclasses.dataclass
 class AbortReq:
     # The request id
     rid: str
@@ -359,26 +398,26 @@ class ProfileReq(Enum):
     STOP_PROFILE = 2
 
 
-@dataclass
+@dataclasses.dataclass
 class GetMemPoolSizeReq:
     pass
 
 
-@dataclass
+@dataclasses.dataclass
 class GetMemPoolSizeReqOutput:
     size: int
 
 
-@dataclass
+@dataclasses.dataclass
 class OpenSessionReqInput:
     capacity_of_str_len: int
 
 
-@dataclass
+@dataclasses.dataclass
 class CloseSessionReqInput:
     session_id: str
 
 
-@dataclass
+@dataclasses.dataclass
 class OpenSessionReqOutput:
     session_id: str
 
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index ad56be197e7..ae833f69bc8 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -29,6 +29,7 @@
 
 import dataclasses
 import logging
+import sys
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -177,6 +178,8 @@ def __init__(
         origin_input_text: str,
         origin_input_ids: Tuple[int],
         sampling_params: SamplingParams,
+        *,
+        input_embeds: Optional[List[float]] = None,
         lora_path: Optional[str] = None,
         session_id: Optional[str] = None,
     ):
@@ -190,6 +193,7 @@ def __init__(
         self.session_id = session_id
 
         self.sampling_params = sampling_params
+        self.input_embeds = input_embeds
         self.lora_path = lora_path
 
         # Memory pool info
@@ -388,6 +392,7 @@ def jump_forward_and_retokenize(self, jump_forward_str, next_state):
             return False
 
         old_output_ids = self.output_ids
+        self.input_embeds = None
         self.output_ids = all_ids[prompt_tokens:]
         self.decoded_text = self.decoded_text + jump_forward_str
         self.surr_offset = prompt_tokens
@@ -431,6 +436,9 @@ def __repr__(self):
 class ScheduleBatch:
     """Store all inforamtion of a batch on the scheduler."""
 
+    if sys.version_info >= (3, 10):
+        _: dataclasses.KW_ONLY
+
     # Request, memory pool, and cache
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool = None
@@ -445,11 +453,13 @@ class ScheduleBatch:
 
     # Batched arguments to model runner
     input_ids: torch.Tensor = None
+    input_embeds: Optional[torch.Tensor] = None
     req_pool_indices: torch.Tensor = None
     seq_lens: torch.Tensor = None
     # The output locations of the KV cache
     out_cache_loc: torch.Tensor = None
     output_ids: torch.Tensor = None
+    output_embeds: Optional[torch.Tensor] = None
 
     # The sum of all sequence lengths
     seq_lens_sum: int = None
@@ -614,9 +624,10 @@ def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
     def prepare_for_extend(self, enable_overlap_schedule: bool = False):
         self.forward_mode = ForwardMode.EXTEND
 
-        bs = len(self.reqs)
         reqs = self.reqs
+        bs = len(reqs)
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
+        input_embeds = [r.input_embeds for r in reqs]
         extend_num_tokens = sum(len(ids) for ids in input_ids)
         seq_lens = []
         pre_lens = []
@@ -659,6 +670,16 @@ def prepare_for_extend(self, enable_overlap_schedule: bool = False):
         self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
             self.device, non_blocking=True
         )
+        if len(input_embeds) > 0 and input_embeds[0] is not None:
+            if not all(ie is not None for ie in input_embeds):
+                raise ValueError("input_embeds contains None")
+            self.input_embeds = torch.tensor(sum(input_embeds, [])).to(
+                self.device, non_blocking=True
+            )
+        else:
+            if not all(ie is None for ie in input_embeds):
+                raise ValueError("input_embeds contains non-None")
+            self.input_embeds = None
         self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to(
             self.device, non_blocking=True
         )
@@ -717,10 +738,15 @@ def mix_with_running(self, running_batch: "ScheduleBatch"):
             req.extend_input_len = 1
 
         input_ids = torch.cat([self.input_ids, running_batch.input_ids])
+        if self.input_embeds is not None:
+            input_embeds = torch.cat([self.input_embeds, running_batch.input_embeds])
+        else:
+            input_embeds = None
         out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
 
         self.merge_batch(running_batch)
         self.input_ids = input_ids
+        self.input_embeds = input_embeds
         self.out_cache_loc = out_cache_loc
         self.extend_num_tokens += running_bs
 
@@ -831,7 +857,7 @@ def retract_decode(self):
 
     def check_for_jump_forward(self, pad_input_ids_func):
         jump_forward_reqs = []
-        keep_indices = set(i for i in range(len(self.reqs)))
+        keep_indices = set(range(len(self.reqs)))
 
         for i, req in enumerate(self.reqs):
             if req.grammar is not None:
@@ -878,7 +904,7 @@ def check_for_jump_forward(self, pad_input_ids_func):
                 jump_forward_reqs.append(req)
                 keep_indices.remove(i)
 
-        self.filter_batch(keep_indices=list(keep_indices))
+        self.filter_batch(keep_indices=sorted(keep_indices))
 
        return jump_forward_reqs
 
@@ -889,6 +915,7 @@ def prepare_encoder_info_decode(self):
     def prepare_for_idle(self):
         self.forward_mode = ForwardMode.IDLE
         self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.input_embeds = None
         self.seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
         self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
         self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
@@ -898,8 +925,8 @@ def prepare_for_idle(self):
 
     def prepare_for_decode(self, enable_overlap: bool = False):
         self.forward_mode = ForwardMode.DECODE
-        self.input_ids = self.output_ids
-        self.output_ids = None
+        self.input_ids, self.output_ids = self.output_ids, None
+        self.input_embeds, self.output_embeds = self.output_embeds, None
         self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids)
 
         # Alloc mem
@@ -1023,6 +1050,7 @@ def get_model_worker_batch(self):
             bid=bid,
             forward_mode=self.forward_mode,
             input_ids=self.input_ids,
+            input_embeds=self.input_embeds,
             req_pool_indices=self.req_pool_indices,
             seq_lens=self.seq_lens,
             out_cache_loc=self.out_cache_loc,
@@ -1065,12 +1093,17 @@ def __str__(self):
 
 @dataclasses.dataclass
 class ModelWorkerBatch:
+    if sys.version_info >= (3, 10):
+        _: dataclasses.KW_ONLY
+
     # The batch id
     bid: int
     # The forward mode
     forward_mode: ForwardMode
     # The input ids
     input_ids: torch.Tensor
+    # The input embeddings
+    input_embeds: Optional[torch.Tensor]
     # The indices of requests in the req_to_token_pool
     req_pool_indices: torch.Tensor
     # The sequence length
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index de3c753ef08..aba575f3d6b 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -528,6 +528,7 @@ def handle_generate_request(
             recv_req.input_ids,
             recv_req.sampling_params,
             lora_path=recv_req.lora_path,
+            input_embeds=recv_req.input_embeds,
         )
         req.tokenizer = self.tokenizer
         if recv_req.session_id is not None:
@@ -621,6 +622,7 @@ def handle_embedding_request(
             recv_req.input_text,
             recv_req.input_ids,
             recv_req.sampling_params,
+            input_embeds=recv_req.input_embeds,
         )
         req.tokenizer = self.tokenizer
 
@@ -913,9 +915,11 @@ def run_batch(self, batch: ScheduleBatch):
         if self.is_generation:
             model_worker_batch = batch.get_model_worker_batch()
             if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
-                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
-                    model_worker_batch
-                )
+                (
+                    logits_output,
+                    next_token_ids,
+                    next_token_embeds,
+                ) = self.tp_worker.forward_batch_generation(model_worker_batch)
             elif batch.forward_mode.is_idle():
                 model_worker_batch = batch.get_model_worker_batch()
                 self.tp_worker.forward_batch_idle(model_worker_batch)
@@ -928,8 +932,15 @@ def run_batch(self, batch: ScheduleBatch):
                 )
             else:
                 next_token_ids = torch.full((batch.batch_size(),), 0)
+                next_token_embeds = None
             batch.output_ids = next_token_ids
-            ret = logits_output, next_token_ids, model_worker_batch.bid
+            batch.output_embeds = next_token_embeds
+            ret = (
+                logits_output,
+                next_token_ids,
+                next_token_embeds,
+                model_worker_batch.bid,
+            )
         else:  # embedding or reward model
             assert batch.extend_num_tokens != 0
             model_worker_batch = batch.get_model_worker_batch()
@@ -952,10 +963,19 @@ def process_batch_result(self, batch: ScheduleBatch, result):
 
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
         if self.is_generation:
-            logits_output, next_token_ids, bid = result
+            (
+                logits_output,
+                next_token_ids,
+                next_token_embeds,
+                bid,
+            ) = result
 
             if self.enable_overlap:
-                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+                (
+                    logits_output,
+                    next_token_ids,
+                    next_token_embeds,
+                ) = self.tp_worker.resolve_batch_result(bid)
             else:
                 # Move next_token_ids and logprobs to cpu
                 if batch.return_logprob:
@@ -983,6 +1003,8 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
                     # Inflight reqs' prefill is not finished
                     req.completion_tokens_wo_jump_forward += 1
                     req.output_ids.append(next_token_id)
+                    if next_token_embeds is not None:
+                        req.input_embeds.append(next_token_embeds[i])
                     req.check_finished()
 
                     if req.finished():
@@ -1031,11 +1053,20 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
         self.stream_output(batch.reqs)
 
     def process_batch_result_decode(self, batch: ScheduleBatch, result):
-        logits_output, next_token_ids, bid = result
+        (
+            logits_output,
+            next_token_ids,
+            next_token_embeds,
+            bid,
+        ) = result
         self.num_generated_tokens += len(batch.reqs)
 
         if self.enable_overlap:
-            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+            (
+                logits_output,
+                next_token_ids,
+                next_token_embeds,
+            ) = self.tp_worker.resolve_batch_result(bid)
             next_token_logprobs = logits_output.next_token_logprobs
         else:
             # Move next_token_ids and logprobs to cpu
@@ -1048,8 +1079,17 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result):
 
         self.token_to_kv_pool.free_group_begin()
 
+        if next_token_embeds is None:
+            next_token_embeds = [None] * len(next_token_ids)
+
         # Check finish condition
-        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
+        for i, (req, next_token_id, next_token_embed) in enumerate(
+            zip(
+                batch.reqs,
+                next_token_ids,
+                next_token_embeds,
+            )
+        ):
             if req.is_retracted:
                 continue
 
@@ -1059,6 +1099,8 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result):
 
             req.completion_tokens_wo_jump_forward += 1
             req.output_ids.append(next_token_id)
+            if next_token_embed is not None:
+                req.input_embeds.append(next_token_embed)
             req.check_finished()
 
             if req.grammar is not None:
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index cb0f8738ed8..770ca4064ed 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -201,8 +201,11 @@ async def _tokenize_one_request(
     ):
         """Tokenize one request."""
         # Tokenize
+        input_embeds = obj.input_embeds
         input_text = obj.text
-        if obj.input_ids is None:
+        if obj.input_embeds is not None:
+            input_ids = None
+        elif obj.input_ids is None:
             input_ids = self.tokenizer.encode(input_text)
         else:
             input_ids = obj.input_ids
@@ -245,6 +248,7 @@ async def _tokenize_one_request(
                 obj.lora_path,
                 session_id=session_id,
                 session_rid=session_rid,
+                input_embeds=input_embeds,
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
@@ -252,6 +256,7 @@ async def _tokenize_one_request(
                 input_text,
                 input_ids,
                 sampling_params,
+                input_embeds=input_embeds,
             )
 
         return tokenized_obj
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index a5d694e77bc..0d3ae26e888 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -147,7 +147,13 @@ def forward_batch_generation(
         if launch_done:
             launch_done.set()
         next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
-        return logits_output, next_token_ids
+        if forward_batch.input_embeds is not None:
+            next_token_embeds = self.model_runner.get_token_embeddings(
+                next_token_ids, forward_batch
+            )
+        else:
+            next_token_embeds = None
+        return logits_output, next_token_ids, next_token_embeds
 
     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py
index 3b53759a75f..1f6d07f6b75 100644
--- a/python/sglang/srt/managers/tp_worker_overlap_thread.py
+++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -115,7 +115,11 @@ def forward_thread_func_(self):
                 resolve_future_token_ids(input_ids, self.future_token_ids_map)
 
                 # Run forward
-                logits_output, next_token_ids = self.worker.forward_batch_generation(
+                (
+                    logits_output,
+                    next_token_ids,
+                    next_token_embeds,
+                ) = self.worker.forward_batch_generation(
                     model_worker_batch, self.launch_done
                 )
 
@@ -143,10 +147,22 @@ def forward_thread_func_(self):
                     next_token_ids = next_token_ids.to("cpu", non_blocking=True)
                 copy_done.record()
 
-            self.output_queue.put((copy_done, logits_output, next_token_ids))
+            self.output_queue.put(
+                (
+                    copy_done,
+                    logits_output,
+                    next_token_ids,
+                    next_token_embeds,
+                )
+            )
 
     def resolve_batch_result(self, bid: int):
-        copy_done, logits_output, next_token_ids = self.output_queue.get()
+        (
+            copy_done,
+            logits_output,
+            next_token_ids,
+            next_token_embeds,
+        ) = self.output_queue.get()
         copy_done.synchronize()
         self.launch_done.wait()
 
@@ -162,7 +178,9 @@ def resolve_batch_result(self, bid: int):
                 logits_output.normalized_prompt_logprobs.tolist()
             )
         next_token_ids = next_token_ids.tolist()
-        return logits_output, next_token_ids
+        if next_token_embeds is not None:
+            next_token_embeds = next_token_embeds.tolist()
+        return logits_output, next_token_ids, next_token_embeds
 
     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
         # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
@@ -193,7 +211,7 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
         self.future_token_ids_ct = (
             self.future_token_ids_ct + bs
         ) % self.future_token_ids_limit
-        return None, future_next_token_ids
+        return None, future_next_token_ids, None
 
     def update_weights(self, recv_req: UpdateWeightReqInput):
         success, message = self.worker.update_weights(recv_req)
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index e1e27752d09..7f3b4db3162 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -29,7 +29,8 @@
 from __future__ import annotations
 
-from dataclasses import dataclass
+import dataclasses
+import sys
 from enum import IntEnum, auto
 from typing import TYPE_CHECKING, List, Optional
 
 import torch
@@ -82,10 +83,13 @@ def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST
 
 
-@dataclass
+@dataclasses.dataclass
 class ForwardBatch:
     """Store all inputs of a forward pass."""
 
+    if sys.version_info >= (3, 10):
+        _: dataclasses.KW_ONLY
+
     # The forward mode
     forward_mode: ForwardMode
     # The batch size
@@ -102,6 +106,9 @@ class ForwardBatch:
     # The sum of all sequence lengths
     seq_lens_sum: int
 
+    # The input embeddings
+    input_embeds: Optional[torch.tensor] = None
+
     # For logprob
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None
@@ -216,6 +223,11 @@ def init_new(
             forward_mode=batch.forward_mode,
             batch_size=len(batch.seq_lens),
             input_ids=batch.input_ids,
+            input_embeds=(
+                batch.input_embeds.clone().detach().to(device)
+                if batch.input_embeds is not None
+                else None
+            ),
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             out_cache_loc=batch.out_cache_loc,
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 3ba311b8c68..eba9ecbff34 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -591,21 +591,47 @@ def apply_torch_tp(self):
         device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
         tensor_parallel(self.model, device_mesh)
 
+    def get_token_embeddings(
+        self, input_ids: torch.Tensor, forward_batch: ForwardBatch
+    ) -> torch.Tensor:
+        """Get embeddings for the input_ids."""
+        return self.model.embed_tokens(input_ids)
+
     def forward_decode(self, forward_batch: ForwardBatch):
         if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
             return self.cuda_graph_runner.replay(forward_batch)
 
         forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
         self.attn_backend.init_forward_metadata(forward_batch)
+
+        if forward_batch.input_embeds is not None:
+            dtype = next(self.model.parameters()).dtype
+            input_embeds = forward_batch.input_embeds.to(dtype=dtype)
+        else:
+            input_embeds = None
+
         return self.model.forward(
-            forward_batch.input_ids, forward_batch.positions, forward_batch
+            forward_batch.input_ids,
+            forward_batch.positions,
+            forward_batch,
+            input_embeds=input_embeds,
         )
 
     def forward_extend(self, forward_batch: ForwardBatch):
         self.attn_backend.init_forward_metadata(forward_batch)
+
+        if forward_batch.input_embeds is not None:
+            dtype = next(self.model.parameters()).dtype
+            input_embeds = forward_batch.input_embeds.to(dtype=dtype)
+        else:
+            input_embeds = None
+
         if self.is_generation:
             return self.model.forward(
-                forward_batch.input_ids, forward_batch.positions, forward_batch
+                forward_batch.input_ids,
+                forward_batch.positions,
+                forward_batch,
+                input_embeds=input_embeds,
             )
         else:
             # Only embedding models have get_embedding parameter
@@ -620,8 +646,17 @@ def forward_idle(self, forward_batch: ForwardBatch):
         if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
             return self.cuda_graph_runner.replay(forward_batch)
 
+        if forward_batch.input_embeds is not None:
+            dtype = next(self.model.parameters()).dtype
+            input_embeds = forward_batch.input_embeds.to(dtype=dtype)
+        else:
+            input_embeds = None
+
         return self.model.forward(
-            forward_batch.input_ids, forward_batch.positions, forward_batch
+            forward_batch.input_ids,
+            forward_batch.positions,
+            forward_batch,
+            input_embeds=input_embeds,
         )
 
     def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py
index 780bf36b5d9..a676e038d70 100644
--- a/python/sglang/srt/models/llava.py
+++ b/python/sglang/srt/models/llava.py
@@ -128,6 +128,7 @@ def forward(
         input_ids: torch.LongTensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         image_inputs = forward_batch.image_inputs
 
@@ -146,7 +147,8 @@ def forward(
             max_image_offset.append(-1)
 
         # Embed text inputs
-        input_embeds = self.language_model.model.embed_tokens(input_ids)
+        if input_embeds is None:
+            input_embeds = self.language_model.model.embed_tokens(input_ids)
 
         start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
         need_vision = start_positions <= np.array(max_image_offset)
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index fcba31a56bf..9a344d769f9 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -14,6 +14,7 @@
 """Common utilities."""
 
 import base64
+import importlib.util
 import ipaddress
 import json
 import logging
@@ -72,6 +73,8 @@ def is_flashinfer_available():
     """
     if os.environ.get("SGLANG_IS_FLASHINFER_AVAILABLE", "true") == "false":
         return False
+    if importlib.util.find_spec("flashinfer") is None:
+        return False
     return torch.cuda.is_available() and not is_hip()
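
Reviewer note: below is a minimal usage sketch of the new exclusive text / input_ids / input_embeds handling in GenerateReqInput. It only relies on what this patch defines; the embedding dimension, sequence length, and sampling params are arbitrary illustrative values, and it assumes the rest of normalize_batch_and_arguments() tolerates text and input_ids being None, which is what the patch is meant to enable.

    from sglang.srt.managers.io_struct import GenerateReqInput

    # One embedding vector per token. Per the type aliases added above,
    # a 2-D list is a single sequence and a 3-D list is a batch of sequences.
    single_seq = [[0.1, 0.2, 0.3, 0.4] for _ in range(8)]

    req = GenerateReqInput(
        input_embeds=single_seq, sampling_params={"max_new_tokens": 16}
    )
    req.normalize_batch_and_arguments()
    print(req.is_single, req.batch_size)  # True 1

    # Supplying more than one of text / input_ids / input_embeds is rejected.
    try:
        GenerateReqInput(text="hi", input_ids=[1, 2, 3]).normalize_batch_and_arguments()
    except ValueError as err:
        print(err)  # Only one of text, input_ids, and input_embeds should be provided.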