Input_embeds support #2052

Merged 15 commits on Nov 26, 2024
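Before the per-file diffs, a rough end-to-end sketch of how the new field is meant to be used. This is illustrative only: it assumes a locally running sglang server on the default port whose /generate endpoint maps JSON fields onto GenerateReqInput, and the Hugging Face model name is a placeholder.

import requests
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Compute input embeddings offline (placeholder model name; the hidden size
# must match the model being served).
name = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.bfloat16)

ids = tokenizer("The capital of France is", return_tensors="pt").input_ids
with torch.no_grad():
    embeds = model.get_input_embeddings()(ids)[0]  # (seq_len, hidden_size)

# Send embeddings instead of text/input_ids; the scheduler fills in dummy token ids.
resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "input_embeds": embeds.float().tolist(),  # List[List[float]] for one request
        "sampling_params": {"max_new_tokens": 32, "temperature": 0},
    },
)
print(resp.json())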
28 changes: 24 additions & 4 deletions python/sglang/srt/managers/io_struct.py
@@ -33,6 +33,8 @@ class GenerateReqInput:
text: Optional[Union[List[str], str]] = None
# The token ids for text; one can either specify text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None
# The embeddings for input_ids
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
# The image input. It can be a file name, a url, or base64 encoded string.
# See also python/sglang/srt/utils.py:load_image.
image_data: Optional[Union[List[str], str]] = None
@@ -57,10 +59,16 @@ class GenerateReqInput:
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

def normalize_batch_and_arguments(self):
-        if (self.text is None and self.input_ids is None) or (
-            self.text is not None and self.input_ids is not None
+        if (
+            self.text is None and self.input_ids is None and self.input_embeds is None
+        ) or (
+            self.text is not None
+            and self.input_ids is not None
+            and self.input_embeds is not None
        ):
-            raise ValueError("Either text or input_ids should be provided.")
+            raise ValueError(
+                "Either text, input_ids or input_embeds should be provided."
+            )

# Derive the batch size
if self.text is not None:
@@ -70,13 +78,21 @@ def normalize_batch_and_arguments(self):
else:
self.is_single = False
self.batch_size = len(self.text)
-        else:
+            self.input_embeds = None
+        elif self.input_ids is not None:
if isinstance(self.input_ids[0], int):
self.is_single = True
self.batch_size = 1
else:
self.is_single = False
self.batch_size = len(self.input_ids)
self.input_embeds = None
else:
if isinstance(self.input_embeds[0][0], float):
self.is_single = True
self.batch_size = 1
else:
self.batch_size = len(self.input_embeds)

# Handle parallel sampling
# When parallel sampling is used, we always treat the input as a batch.
@@ -199,6 +215,8 @@ class TokenizedGenerateReqInput:

# LoRA related
lora_path: Optional[str] = None # None means just use the base model
# The input embeds
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None


@dataclass
@@ -211,6 +229,8 @@ class EmbeddingReqInput:
rid: Optional[Union[List[str], str]] = None
# Dummy sampling params for compatibility
sampling_params: Union[List[Dict], Dict] = None
# Dummy input embeds for compatibility
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

def normalize_batch_and_arguments(self):
if (self.text is None and self.input_ids is None) or (
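To make the shape convention behind the isinstance checks above concrete, here is a small stand-alone sketch (toy hidden size, not sglang code): a single request carries a (seq_len, hidden_size) nested list of floats, while a batch carries one such list per request.

# Single request: input_embeds[i][j] is a float     -> shape (seq_len, hidden_size)
# Batched input:  input_embeds[b][i][j] is a float  -> shape (batch, seq_len, hidden_size)
single = [[0.0] * 16 for _ in range(8)]              # one 8-token prompt, hidden_size=16
batch = [single, [[0.0] * 16 for _ in range(3)]]     # two prompts

def derive_batch_size(input_embeds):
    # Mirrors the isinstance(input_embeds[0][0], float) check in the diff above.
    if isinstance(input_embeds[0][0], float):
        return 1
    return len(input_embeds)

assert derive_batch_size(single) == 1
assert derive_batch_size(batch) == 2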
18 changes: 18 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
@@ -177,6 +177,7 @@ def __init__(
origin_input_ids: Tuple[int],
sampling_params: SamplingParams,
lora_path: Optional[str] = None,
input_embeds: Optional = None,
):
# Input and output info
self.rid = rid
@@ -188,6 +189,7 @@ def __init__(

self.sampling_params = sampling_params
self.lora_path = lora_path
self.input_embeds = input_embeds

# Memory info
self.req_pool_idx = None
@@ -442,6 +444,7 @@ class ScheduleBatch:

# Batched arguments to model runner
input_ids: torch.Tensor = None
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
req_pool_indices: torch.Tensor = None
seq_lens: torch.Tensor = None
# The output locations of the KV cache
@@ -620,6 +623,8 @@ def prepare_for_extend(self):
req_pool_indices = self.alloc_req_slots(bs)
out_cache_loc = self.alloc_token_slots(extend_num_tokens)

input_embeds = []

pt = 0
for i, req in enumerate(reqs):
already_computed = (
@@ -643,6 +648,11 @@ def prepare_for_extend(self):
out_cache_loc[pt : pt + req.extend_input_len],
)

# If input_embeds are available, store them
if req.input_embeds is not None:
# If req.input_embeds is already a list, append its content directly
input_embeds.extend(req.input_embeds) # Use extend to avoid nesting

# Compute the relative logprob_start_len in an extend batch
if req.logprob_start_len >= pre_len:
extend_logprob_start_len = min(
@@ -666,6 +676,10 @@ def prepare_for_extend(self):
self.device, non_blocking=True
)

self.input_embeds = torch.tensor(input_embeds).to(
self.device, non_blocking=True
)

self.out_cache_loc = out_cache_loc

self.seq_lens_sum = sum(seq_lens)
@@ -1022,6 +1036,7 @@ def get_model_worker_batch(self):
encoder_out_cache_loc=self.encoder_out_cache_loc,
lora_paths=[req.lora_path for req in self.reqs],
sampling_info=self.sampling_info,
input_embeds=self.input_embeds,
)

def copy(self):
@@ -1091,6 +1106,9 @@ class ModelWorkerBatch:
# Sampling info
sampling_info: SamplingBatchInfo

# The input Embeds
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

def copy(self):
return dataclasses.replace(self, sampling_info=self.sampling_info.copy())

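prepare_for_extend above collects every request's embeddings into one flat, token-major list before turning it into a device tensor. A toy illustration of that flattening (toy sizes; not the actual batching code, which also tracks cache locations and positions):

import torch

req_a = [[0.0] * 4 for _ in range(3)]   # request A: 3 tokens, hidden_size=4
req_b = [[1.0] * 4 for _ in range(2)]   # request B: 2 tokens

input_embeds = []
for req_embeds in (req_a, req_b):
    input_embeds.extend(req_embeds)     # extend(), not append(), to avoid nesting

batch_embeds = torch.tensor(input_embeds)
print(batch_embeds.shape)               # torch.Size([5, 4]) -> (total extend tokens, hidden_size)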
16 changes: 15 additions & 1 deletion python/sglang/srt/managers/scheduler.py
@@ -490,12 +490,26 @@ def handle_generate_request(
self,
recv_req: TokenizedGenerateReqInput,
):
# Check if input_embeds is present and create dummy input_ids
if recv_req.input_embeds is not None:
input_embeds_tensor = torch.tensor(recv_req.input_embeds)
# Generate fake input_ids based on the length of input_embeds
seq_length = input_embeds_tensor.shape[
0
] # A single request's embeddings are shaped (seq_length, hidden_size)
fake_input_ids = [
1
] * seq_length # Create dummy input_ids as ones (or use zeros if preferred)
recv_req.input_ids = fake_input_ids
# Pass input_embeds to Req if present, otherwise pass as None

req = Req(
recv_req.rid,
recv_req.input_text,
-            recv_req.input_ids,
+            recv_req.input_ids,  # This is either the real or dummy input_ids
recv_req.sampling_params,
lora_path=recv_req.lora_path,
input_embeds=recv_req.input_embeds,
)
req.tokenizer = self.tokenizer

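The dummy-token trick above only has to produce input_ids of the right length so that downstream length and KV-cache bookkeeping keep working; the embeddings carry the actual content. A minimal illustration with a toy hidden size:

import torch

input_embeds = [[0.0] * 8 for _ in range(10)]      # one request: (seq_length=10, hidden_size=8)
seq_length = torch.tensor(input_embeds).shape[0]   # 10
fake_input_ids = [1] * seq_length                  # placeholder ids, one per embedded token
assert len(fake_input_ids) == len(input_embeds)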
20 changes: 13 additions & 7 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -195,9 +195,13 @@ async def _tokenize_one_request(
obj: Union[GenerateReqInput, EmbeddingReqInput],
):
"""Tokenize one request."""
input_embeds = None
# Tokenize
input_text = obj.text
-        if obj.input_ids is None:
+        if obj.input_embeds is not None:
+            input_embeds = obj.input_embeds
+            input_ids = obj.input_ids
+        elif obj.input_ids is None:
input_ids = self.tokenizer.encode(input_text)
else:
input_ids = obj.input_ids
@@ -212,11 +216,12 @@ async def _tokenize_one_request(
logprob_start_len = obj.logprob_start_len
top_logprobs_num = obj.top_logprobs_num

-        if len(input_ids) >= self.context_len:
-            raise ValueError(
-                f"The input ({len(input_ids)} tokens) is longer than the "
-                f"model's context length ({self.context_len} tokens)."
-            )
+        if obj.input_ids is not None:
+            if len(input_ids) >= self.context_len:
+                raise ValueError(
+                    f"The input ({len(input_ids)} tokens) is longer than the "
+                    f"model's context length ({self.context_len} tokens)."
+                )

# Parse sampling parameters
sampling_params = SamplingParams(**obj.sampling_params)
@@ -235,7 +240,8 @@ async def _tokenize_one_request(
logprob_start_len,
top_logprobs_num,
obj.stream,
-                obj.lora_path,
+                input_embeds=input_embeds,
+                lora_path=obj.lora_path,
)
elif isinstance(obj, EmbeddingReqInput):
tokenized_obj = TokenizedEmbeddingReqInput(
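Condensing the branching above into a stand-alone sketch (simplified and hypothetical, not the actual _tokenize_one_request code): when input_embeds is supplied, tokenization is skipped, input_ids may remain None, and the context-length check is bypassed.

def resolve_inputs(obj, tokenizer, context_len):
    # Simplified mirror of the logic above, for illustration only.
    if obj.input_embeds is not None:
        return obj.input_ids, obj.input_embeds        # input_ids is typically still None here
    input_ids = obj.input_ids if obj.input_ids is not None else tokenizer.encode(obj.text)
    if len(input_ids) >= context_len:
        raise ValueError("The input is longer than the model's context length.")
    return input_ids, None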
8 changes: 8 additions & 0 deletions python/sglang/srt/model_executor/forward_batch_info.py
@@ -122,6 +122,9 @@ class ForwardBatch:
# For LoRA
lora_paths: Optional[List[str]] = None

# For input embeddings
input_embeds: Optional[torch.tensor] = None

# Sampling info
sampling_info: SamplingBatchInfo = None

@@ -221,6 +224,11 @@ def init_new(
global_num_tokens=batch.global_num_tokens,
lora_paths=batch.lora_paths,
sampling_info=batch.sampling_info,
input_embeds=(
batch.input_embeds.clone().detach().to(device)
if batch.input_embeds is not None
else None
),
)

if ret.global_num_tokens is not None:
17 changes: 14 additions & 3 deletions python/sglang/srt/model_executor/model_runner.py
@@ -581,9 +581,20 @@ def forward_decode(self, forward_batch: ForwardBatch):
def forward_extend(self, forward_batch: ForwardBatch):
self.attn_backend.init_forward_metadata(forward_batch)
if self.is_generation:
-            return self.model.forward(
-                forward_batch.input_ids, forward_batch.positions, forward_batch
-            )
+            if (
+                forward_batch.input_embeds is not None
+                and forward_batch.input_embeds.numel() != 0
+            ):
+                return self.model.forward(
+                    forward_batch.input_ids,
+                    forward_batch.positions,
+                    forward_batch,
+                    input_embeds=forward_batch.input_embeds.bfloat16(),
+                )
+            else:
+                return self.model.forward(
+                    forward_batch.input_ids, forward_batch.positions, forward_batch
+                )
else:
# Only embedding models have get_embedding parameter
return self.model.forward(
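For the input_embeds branch above to work, the served model's forward must accept an input_embeds keyword and use it in place of the token-embedding lookup (the model-side changes are not shown in this excerpt). A toy module illustrating that contract, not the actual sglang model code:

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self, vocab_size=128, hidden_size=16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_ids, positions, forward_batch=None, input_embeds=None):
        # Use the provided embeddings when present; otherwise fall back to the lookup.
        if input_embeds is not None:
            hidden = input_embeds
        else:
            hidden = self.embed_tokens(input_ids)
        return self.lm_head(hidden)

model = ToyModel()
dummy_ids = torch.tensor([1, 1, 1])          # as created by the scheduler
embeds = torch.randn(3, 16)                  # (seq_len, hidden_size)
logits = model(dummy_ids, positions=torch.arange(3), input_embeds=embeds)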