
feat(srt): support prefill and generate with input_embeds #2082

Closed
wants to merge 5 commits into from
141 changes: 90 additions & 51 deletions python/sglang/srt/managers/io_struct.py
@@ -16,21 +16,36 @@
processes (TokenizerManager, DetokenizerManager, Controller).
"""

import dataclasses
import sys
import uuid
from dataclasses import dataclass
from collections.abc import Sequence
from enum import Enum
from typing import Dict, List, Optional, Union

from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling.sampling_params import SamplingParams

# Use sequence instead of Tensor here because Pydantic serializes Python objects
# based on type annotations.
TokenEmbedding = List[float] # 1D tensor
SingleSequenceEmbedding = List[TokenEmbedding] # 2D tensor
BatchSequenceEmbedding = List[SingleSequenceEmbedding] # 3D tensor

Review comment (Contributor): sequence or list?
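To make the nesting concrete, a minimal sketch (values invented for illustration; not part of the diff):

token: TokenEmbedding = [0.1, 0.2, 0.3]  # one token, hidden size 3 (1D)
single: SingleSequenceEmbedding = [token, token]  # one 2-token sequence (2D)
batch: BatchSequenceEmbedding = [single]  # a batch of one sequence (3D)

The isinstance(input_embeds[0][0], Sequence) checks below rely on exactly this nesting: lists at depth 2 mean a batched (3D) input, floats mean a single (2D) sequence.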

@dataclass

@dataclasses.dataclass
class GenerateReqInput:
if sys.version_info >= (3, 10):
_: dataclasses.KW_ONLY

# The input prompt. It can be a single prompt or a batch of prompts.
text: Optional[Union[List[str], str]] = None
# The token ids for text; one can either specify text or input_ids.
# The token ids for text; one can either specify text, input_ids, or input_embeds.
input_ids: Optional[Union[List[List[int]], List[int]]] = None
# Precalculated embeddings for the input text; one can either specify text, input_ids, or input_embeds.
input_embeds: Optional[Union[BatchSequenceEmbedding, SingleSequenceEmbedding]] = (
None
)
# The image input. It can be a file name, a url, or base64 encoded string.
# See also python/sglang/srt/utils.py:load_image.
image_data: Optional[Union[List[str], str]] = None
@@ -59,26 +74,27 @@ class GenerateReqInput:
session_rid: Optional[Union[List[str], str]] = None

def normalize_batch_and_arguments(self):
if (self.text is None and self.input_ids is None) or (
self.text is not None and self.input_ids is not None
):
raise ValueError("Either text or input_ids should be provided.")
if (self.text, self.input_ids, self.input_embeds).count(None) != 2:
raise ValueError(
"Only one of text, input_ids, and input_embeds should be provided."
)
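# (A count of exactly two Nones means exactly one of the three inputs was set.)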

# Derive the batch size
self.is_single = True
self.batch_size = 1
if self.text is not None:
if isinstance(self.text, str):
self.is_single = True
self.batch_size = 1
else:
if isinstance(self.text, list):
self.is_single = False
self.batch_size = len(self.text)
else:
if isinstance(self.input_ids[0], int):
self.is_single = True
self.batch_size = 1
else:
elif self.input_ids is not None:
if isinstance(self.input_ids[0], list):
self.is_single = False
self.batch_size = len(self.input_ids)
else:
assert self.input_embeds is not None
if isinstance(self.input_embeds[0][0], Sequence):
self.is_single = False
self.batch_size = len(self.input_embeds)

# Handle parallel sampling
# When parallel sampling is used, we always treat the input as a batch.
@@ -123,8 +139,6 @@ def normalize_batch_and_arguments(self):
self.image_data = [None] * num
elif not isinstance(self.image_data, list):
self.image_data = [self.image_data] * num
elif isinstance(self.image_data, list):
pass

if self.sampling_params is None:
self.sampling_params = [{}] * num
@@ -165,6 +179,9 @@ def __getitem__(self, i):
return GenerateReqInput(
text=self.text[i] if self.text is not None else None,
input_ids=self.input_ids[i] if self.input_ids is not None else None,
input_embeds=(
self.input_embeds[i] if self.input_embeds is not None else None
),
image_data=self.image_data[i],
sampling_params=self.sampling_params[i],
rid=self.rid[i],
@@ -178,14 +195,14 @@ def __getitem__(self, i):
)


@dataclass
@dataclasses.dataclass
class TokenizedGenerateReqInput:
# The request id
rid: str
# The input text
input_text: str
# The input token ids
input_ids: List[int]
input_ids: Optional[List[int]]
# The image inputs
image_inputs: dict
# The sampling parameters
@@ -198,47 +215,60 @@ class TokenizedGenerateReqInput:
top_logprobs_num: int
# Whether to stream output
stream: bool

# LoRA related
lora_path: Optional[str] = None # None means just use the base model

# Session id info for continual prompting
session_id: Optional[int] = None
session_rid: Optional[str] = None

if sys.version_info >= (3, 10):
_: dataclasses.KW_ONLY
Review comment on lines +225 to +226 (Contributor): What is this used for?


# The precalculated embeddings for the input text
input_embeds: Optional[SingleSequenceEmbedding] = None
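For context on the question above: in Python 3.10+, dataclasses.KW_ONLY is a sentinel annotation, and every field declared after it becomes keyword-only in the generated __init__. A minimal sketch (class and values invented for illustration):

import dataclasses

@dataclasses.dataclass
class Example:
    rid: str
    _: dataclasses.KW_ONLY
    input_embeds: list | None = None

Example("req-1", input_embeds=[[0.1]])  # OK: passed by keyword
# Example("req-1", [[0.1]])  # TypeError: too many positional arguments

Presumably the guard keeps the new defaulted field from being bound positionally as fields are added or reordered, while the version check keeps the module importable on Python < 3.10.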


@dataclass
@dataclasses.dataclass
class EmbeddingReqInput:
if sys.version_info >= (3, 10):
_: dataclasses.KW_ONLY

# The input prompt. It can be a single prompt or a batch of prompts.
text: Optional[Union[List[str], str]] = None
# The token ids for text; one can either specify text or input_ids.
# The token ids for text; one can either specify text, input_ids, or input_embeds.
input_ids: Optional[Union[List[List[int]], List[int]]] = None
# The embeddings for text; one can either specify text, input_ids, or input_embeds.
input_embeds: Optional[Union[BatchSequenceEmbedding, SingleSequenceEmbedding]] = (
None
)
# The request id.
rid: Optional[Union[List[str], str]] = None
# Dummy sampling params for compatibility
sampling_params: Union[List[Dict], Dict] = None

def normalize_batch_and_arguments(self):
if (self.text is None and self.input_ids is None) or (
self.text is not None and self.input_ids is not None
):
raise ValueError("Either text or input_ids should be provided.")
if (self.text, self.input_ids, self.input_embeds).count(None) != 2:
raise ValueError(
"Only one of text, input_ids, and input_embeds should be provided."
)

# Derive the batch size
self.is_single = True
self.batch_size = 1
if self.text is not None:
if isinstance(self.text, str):
self.is_single = True
self.batch_size = 1
else:
if isinstance(self.text, list):
self.is_single = False
self.batch_size = len(self.text)
else:
if isinstance(self.input_ids[0], int):
self.is_single = True
self.batch_size = 1
else:
elif self.input_ids is not None:
if isinstance(self.input_ids[0], list):
self.is_single = False
self.batch_size = len(self.input_ids)
else:
assert self.input_embeds is not None
if isinstance(self.input_embeds[0][0], Sequence):
self.is_single = False
self.batch_size = len(self.input_embeds)

# Fill in default arguments
if self.is_single:
@@ -266,24 +296,33 @@ def __getitem__(self, i):
return EmbeddingReqInput(
text=self.text[i] if self.text is not None else None,
input_ids=self.input_ids[i] if self.input_ids is not None else None,
input_embeds=(
self.input_embeds[i] if self.input_embeds is not None else None
),
sampling_params=self.sampling_params[i],
rid=self.rid[i],
)


@dataclass
@dataclasses.dataclass
class TokenizedEmbeddingReqInput:
# The request id
rid: str
# The input text
input_text: str
input_text: Optional[str]
# The input token ids
input_ids: List[int]
input_ids: Optional[List[int]]
# Dummy sampling params for compatibility
sampling_params: SamplingParams

if sys.version_info >= (3, 10):
_: dataclasses.KW_ONLY

# The precalculated embeddings for the input text
input_embeds: Optional[SingleSequenceEmbedding] = None


@dataclass
@dataclasses.dataclass
class BatchTokenIDOut:
# The request id
rids: List[str]
@@ -303,7 +342,7 @@ class BatchTokenIDOut:
session_ids: List[str]


@dataclass
@dataclasses.dataclass
class BatchStrOut:
# The request id
rids: List[str]
@@ -317,7 +356,7 @@ class BatchStrOut:
session_ids: List[str]


@dataclass
@dataclasses.dataclass
class BatchEmbeddingOut:
# The request id
rids: List[str]
@@ -329,26 +368,26 @@ class BatchEmbeddingOut:
finished_reason: List[BaseFinishReason]


@dataclass
@dataclasses.dataclass
class FlushCacheReq:
pass


@dataclass
@dataclasses.dataclass
class UpdateWeightReqInput:
# The model path with the new weights
model_path: str
# The format to load the weights
load_format: Optional[str] = None


@dataclass
@dataclasses.dataclass
class UpdateWeightReqOutput:
success: bool
message: str


@dataclass
@dataclasses.dataclass
class AbortReq:
# The request id
rid: str
@@ -359,26 +398,26 @@ class ProfileReq(Enum):
STOP_PROFILE = 2


@dataclass
@dataclasses.dataclass
class GetMemPoolSizeReq:
pass


@dataclass
@dataclasses.dataclass
class GetMemPoolSizeReqOutput:
size: int


@dataclass
@dataclasses.dataclass
class OpenSessionReqInput:
capacity_of_str_len: int


@dataclass
@dataclasses.dataclass
class CloseSessionReqInput:
session_id: str


@dataclass
@dataclasses.dataclass
class OpenSessionReqOutput:
session_id: str
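End to end, the new field is exercised by sending input_embeds in place of text or input_ids in a generate request. A hypothetical client call (endpoint path, port, and hidden size are assumptions, not taken from this diff):

import requests

embeds = [[0.0] * 4096 for _ in range(8)]  # one 8-token sequence, assumed hidden size 4096
resp = requests.post(
    "http://localhost:30000/generate",  # default SGLang server address (assumed)
    json={
        "input_embeds": embeds,  # 2D -> treated as a single request
        "sampling_params": {"max_new_tokens": 16},
    },
)
print(resp.json())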