5 changes: 1 addition & 4 deletions fastdeploy/entrypoints/openai/protocol.py
@@ -671,10 +671,7 @@ def to_dict_for_infer(self, request_id=None):
if request_id is not None:
req_dict["request_id"] = request_id

if "prompt_token_ids" in req_dict:
if "messages" in req_dict:
del req_dict["messages"]
else:
if "prompt_token_ids" not in req_dict:
# If disable_chat_template is set, then the first message in messages will be used as the prompt.
assert (
len(req_dict["messages"]) > 0
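Illustrative sketch (not part of this PR): the rewritten branch keeps "messages" alongside a caller-supplied "prompt_token_ids" instead of deleting it, and the non-empty-messages assertion now only applies when no token ids are given. A minimal example with a hypothetical req_dict:

req_dict = {
    "request_id": "req-0",
    "prompt_token_ids": [1, 2, 3],
    "messages": [{"role": "user", "content": "Hello"}],
}

# Before this change, "messages" was deleted whenever "prompt_token_ids" was
# present; now both keys survive, so the multimodal processor can still read
# the messages while reusing the given token ids.
if "prompt_token_ids" not in req_dict:
    # Only the chat-template path requires at least one message.
    assert len(req_dict["messages"]) > 0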
12 changes: 10 additions & 2 deletions fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py
@@ -219,7 +219,13 @@ def process_request_dict(self, request, max_model_len=None):
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request["bad_words_token_ids"] = bad_words_token_ids

if request.get("prompt"):
if request.get("prompt_token_ids"):
messages = request.get("messages")
if messages:
self._check_mm_limits(messages)
request.setdefault("enable_thinking", True)
outputs = self.ernie4_5_processor.prompt_token_ids2outputs(request)
elif request.get("prompt"):
multimodal_data = request.get("multimodal_data")
if multimodal_data is None:
multimodal_data = {}
@@ -256,7 +262,9 @@ def process_request_dict(self, request, max_model_len=None):
self.append_completion_tokens(outputs, request["completion_token_ids"])

outputs = self.pack_outputs(outputs)
request["prompt_token_ids"] = outputs["input_ids"].tolist()
request["prompt_token_ids"] = (
outputs["input_ids"].tolist() if "prompt_token_ids" not in request else request["prompt_token_ids"]
)
request["prompt_token_ids_len"] = len(request["prompt_token_ids"])
request["multimodal_inputs"] = outputs

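Illustrative sketch (not part of this PR): with this change, a request that already carries prompt_token_ids is routed through prompt_token_ids2outputs instead of re-rendering the chat template, any attached messages are still validated with _check_mm_limits, enable_thinking defaults to True on that path, and the caller-supplied token ids are preserved rather than overwritten. A rough usage example; the processor variable and the image URL are placeholders:

request = {
    "messages": [
        {"role": "user", "content": [
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            {"type": "text", "text": "Describe the image."},
        ]},
    ],
    "prompt_token_ids": [101, 102, 103],  # pre-tokenized prompt from the caller
}

# With a loaded Ernie4_5_VLProcessor instance (hypothetical variable `processor`):
# processor.process_request_dict(request, max_model_len=4096)
# Expected effects per the diff above: request["prompt_token_ids"] stays unchanged,
# while request["prompt_token_ids_len"] and request["multimodal_inputs"] are filled in.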
151 changes: 139 additions & 12 deletions fastdeploy/input/ernie4_5_vl_processor/process.py
@@ -136,7 +136,9 @@ def __init__(
self.video_end = self.VID_END
self.image_patch_id = self.tokenizer.convert_tokens_to_ids("<|IMAGE_PLACEHOLDER|>")
self.image_start_id = self.tokenizer.convert_tokens_to_ids(self.image_start)
self.image_end_id = self.tokenizer.convert_tokens_to_ids(self.image_end)
self.video_start_id = self.tokenizer.convert_tokens_to_ids(self.video_start)
self.video_end_id = self.tokenizer.convert_tokens_to_ids(self.video_end)
self.sep_token_id = self.tokenizer.convert_tokens_to_ids(self.sep_token)
self.eos_token_id = self.tokenizer.convert_tokens_to_ids(self.eos_token)

@@ -243,14 +245,7 @@ def text2ids(self, text, images=None, videos=None, image_uuid=None, video_uuid=N

return outputs

def request2ids(
self, request: Dict[str, Any], tgts: List[str] = None
) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
"""
Convert chat messages into model inputs.
Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels.
"""

def extract_mm_items(self, request: Dict[str, Any]):
messages = parse_chat_messages(request.get("messages"))
mm_items = []
for msg in messages:
@@ -273,6 +268,7 @@ def request2ids(
if len(missing_hashes) > 0 and not self.enable_processor_cache:
raise ValueError("Missing items cannot be retrieved without processor cache.")

dealer = None
if self.enable_processor_cache:
context = zmq.Context()
dealer = context.socket(zmq.DEALER)
@@ -295,6 +291,16 @@ def request2ids(
video_uuid.append(item["uuid"])
else:
raise ValueError(f"Unsupported multimodal type: {item.get('type')}")
return images, videos, image_uuid, video_uuid, dealer, missing_idx, mm_items

def request2ids(
self, request: Dict[str, Any], tgts: List[str] = None
) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
"""
Convert chat messages into model inputs.
Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels.
"""
images, videos, image_uuid, video_uuid, dealer, missing_idx, mm_items = self.extract_mm_items(request)

if self.tokenizer.chat_template is None:
raise ValueError("This model does not support chat template.")
@@ -329,6 +335,115 @@

return outputs

def prompt_token_ids2outputs(
self, request: Dict[str, Any], tgts: List[str] = None
) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
outputs = {
"input_ids": [],
"token_type_ids": [],
"position_ids": [],
"images": [],
"grid_thw": [],
"image_type_ids": [],
"labels": [],
"cur_position": 0,
"video_cnt": 0,
"num_input_image_tokens": 0,
"num_input_video_tokens": 0,
"mm_positions": [],
"mm_hashes": [],
}
prompt_token_ids = request.get("prompt_token_ids", [])
prompt_token_ids_len = len(prompt_token_ids)
if not request.get("messages"):
outputs["input_ids"].append(prompt_token_ids)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * prompt_token_ids_len)
for i in range(prompt_token_ids_len):
outputs["position_ids"].append([i] * 3)
outputs["cur_position"] += prompt_token_ids_len
return outputs
images, videos, image_uuid, video_uuid, dealer, missing_idx, mm_items = self.extract_mm_items(request)
st, image_idx, video_idx = 0, 0, 0
while st < prompt_token_ids_len:
cur_token_id = prompt_token_ids[st]
if cur_token_id == self.image_start_id:
if image_idx >= len(images):
raise ValueError("prompt token ids has more image placeholder than in messages")
# append image_start_id
outputs["input_ids"].extend([cur_token_id])
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]])
outputs["position_ids"].append([outputs["cur_position"]] * 3)
outputs["cur_position"] += 1
st += 1
# process placeholder token ids
cur_idx = st
while cur_idx < prompt_token_ids_len and prompt_token_ids[cur_idx] != self.image_end_id:
cur_idx += 1
if cur_idx >= prompt_token_ids_len:
raise ValueError("image token ids not complete")
image = images[image_idx]
uuid = image_uuid[image_idx] if image_uuid else None
token_len = cur_idx - st
if not isinstance(image, tuple):
self._add_image(image, outputs, uuid, token_len)
else:
self._add_processed_image(image, outputs, uuid, token_len)
image_idx += 1
st = cur_idx
elif cur_token_id == self.video_start_id:
if video_idx >= len(videos):
raise ValueError("prompt token ids has more video placeholder than in messages")
# append video_start_id
outputs["input_ids"].extend([cur_token_id])
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]])
outputs["position_ids"].append([outputs["cur_position"]] * 3)
outputs["cur_position"] += 1
st += 1
# process placeholder token ids
cur_idx = st
while cur_idx < prompt_token_ids_len and prompt_token_ids[cur_idx] != self.video_end_id:
cur_idx += 1
if cur_idx >= prompt_token_ids_len:
raise ValueError("video token ids not complete")
video = videos[video_idx]
uuid = video_uuid[video_idx] if video_uuid else None
token_len = cur_idx - st
if not isinstance(video, tuple):
if isinstance(video, dict):
frames = self._load_and_process_video(video["video"], video)
else:
frames = self._load_and_process_video(video, {})
self._add_video(frames, outputs, uuid, token_len)
else:
self._add_processed_video(video, outputs, uuid, token_len)
video_idx += 1
st = cur_idx
else:
outputs["input_ids"].extend([cur_token_id])
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]])
outputs["position_ids"].append([outputs["cur_position"]] * 3)
outputs["cur_position"] += 1
st += 1
if image_idx != len(images):
raise ValueError("number of images does not match")
if video_idx != len(videos):
raise ValueError("number of videos does not match")

if self.enable_processor_cache:
missing_idx = set(missing_idx)
hashes_to_cache, items_to_cache = [], []
for idx in range(len(mm_items)):
if idx in missing_idx:
continue
meta = {}
t, h, w = outputs["grid_thw"][idx][0]
meta["thw"] = (t, h, w)
hashes_to_cache.append(outputs["mm_hashes"][idx])
items_to_cache.append((outputs["images"][idx], meta))
self.update_processor_cache(dealer, hashes_to_cache, items_to_cache)

return outputs

def _add_special_token(self, token: Union[str, int], outputs: Dict) -> None:
token_id = token if isinstance(token, int) else self.tokenizer.convert_tokens_to_ids(token)
outputs["input_ids"].append(token_id)
@@ -348,14 +463,16 @@ def _add_text(self, tokens, outputs: Dict) -> None:
outputs["position_ids"].append([start + i] * 3)
outputs["cur_position"] += len(tokens)

def _add_image(self, img, outputs: Dict, uuid: Optional[str]) -> None:
def _add_image(self, img, outputs: Dict, uuid: Optional[str], token_len=None) -> None:
patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
img.height,
img.width,
min_pixels=self.image_min_pixels,
max_pixels=self.image_max_pixels,
)[1]
num_tokens = (patches_h * patches_w) // (self.spatial_conv_size**2)
if token_len and token_len != num_tokens:
raise ValueError("image tokens num not match the size")

outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
@@ -383,9 +500,13 @@ def _add_image(self, img, outputs: Dict, uuid: Optional[str]) -> None:
outputs["grid_thw"].append(ret["image_grid_thw"])
outputs["image_type_ids"].append(0)

def _add_processed_image(self, img_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
def _add_processed_image(
self, img_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str, token_len=None
) -> None:
img, meta = img_cache
num_tokens = img.shape[0] // (self.spatial_conv_size**2)
if token_len and num_tokens != token_len:
raise ValueError("image tokens num not match the size")

outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
@@ -401,7 +522,7 @@ def _add_processed_image(self, img_cache: Tuple[np.ndarray, dict], outputs: Dict
outputs["grid_thw"].append(np.array([[1, h, w]]))
outputs["image_type_ids"].append(0)

def _add_video(self, frames, outputs: Dict, uuid: Optional[str]) -> None:
def _add_video(self, frames, outputs: Dict, uuid: Optional[str], token_len=None) -> None:
patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
frames[0].height,
frames[0].width,
@@ -410,6 +531,8 @@ def _add_video(self, frames, outputs: Dict, uuid: Optional[str]) -> None:
)[1]
num_frames = len(frames)
num_tokens = (num_frames * patches_h * patches_w) // (self.spatial_conv_size**2 * self.temporal_conv_size)
if token_len and num_tokens != token_len:
raise ValueError("video tokens num not match the size")

pixel_stack = np.stack([np.array(f.convert("RGB")) for f in frames], axis=0)
ret = self.image_preprocessor.preprocess(
@@ -438,9 +561,13 @@ def _add_video(self, frames, outputs: Dict, uuid: Optional[str]) -> None:
outputs["position_ids"].extend(pos_ids)
outputs["cur_position"] = np.max(pos_ids) + 1

def _add_processed_video(self, frames_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
def _add_processed_video(
self, frames_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str, token_len=None
) -> None:
frames, meta = frames_cache
num_tokens = frames.shape[0] // (self.spatial_conv_size**2 * self.temporal_conv_size)
if token_len and num_tokens != token_len:
raise ValueError("video tokens num not match the size")

t, h, w = meta["thw"]
outputs["images"].append(frames)
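Illustrative sketch (not part of this PR): the heart of prompt_token_ids2outputs is a scan over the supplied token ids that, for each image span delimited by image_start_id/image_end_id, checks that the placeholder run is exactly as long as the token count the preprocessor derives for the matching image; video spans are handled the same way with the video markers. A simplified, self-contained version of that check, with made-up token ids:

from typing import List

def check_image_spans(
    token_ids: List[int],
    image_start_id: int,
    image_end_id: int,
    expected_tokens_per_image: List[int],
) -> None:
    """Validate that every image span has the expected number of placeholder tokens."""
    image_idx, i = 0, 0
    while i < len(token_ids):
        if token_ids[i] != image_start_id:
            i += 1
            continue
        if image_idx >= len(expected_tokens_per_image):
            raise ValueError("more image spans in the token ids than images provided")
        # Measure the placeholder run between the start and end markers.
        j = i + 1
        while j < len(token_ids) and token_ids[j] != image_end_id:
            j += 1
        if j >= len(token_ids):
            raise ValueError("image span is not terminated by image_end_id")
        if j - (i + 1) != expected_tokens_per_image[image_idx]:
            raise ValueError("image span length does not match the preprocessed image")
        image_idx += 1
        i = j + 1
    if image_idx != len(expected_tokens_per_image):
        raise ValueError("number of image spans does not match the number of images")

# Example: one image whose preprocessed form occupies 4 placeholder tokens.
# 7 = image start marker, 8 = image end marker, 0 = placeholder, 1/2 = text ids.
check_image_spans([1, 7, 0, 0, 0, 0, 8, 2], image_start_id=7, image_end_id=8,
                  expected_tokens_per_image=[4])

In the PR the expected count comes from the same formula _add_image uses, (patches_h * patches_w) // self.spatial_conv_size**2, and a mismatch raises a ValueError much like the checks above.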
27 changes: 24 additions & 3 deletions tests/input/test_ernie_vl_processor.py
@@ -77,7 +77,7 @@ def test_process_request_dict_with_options(self):
"prompt_token_ids": [1, 1, 1],
}
self.processor.process_request_dict(request_dict, 100)
self.assertEqual(request_dict["enable_thinking"], False)
self.assertEqual(request_dict["enable_thinking"], True)

request_dict = {
"messages": [{"role": "user", "content": "Hello"}],
@@ -93,15 +93,15 @@ def test_process_request_dict_with_options(self):
"prompt_token_ids": [1, 1, 1],
}
self.processor.process_request_dict(request_dict, 100)
self.assertEqual(request_dict["enable_thinking"], False)
self.assertEqual(request_dict["enable_thinking"], True)

request_dict = {
"messages": [{"role": "user", "content": "Hello"}],
"chat_template_kwargs": {"options": {"thinking_mode": "false"}},
"prompt_token_ids": [1, 1, 1],
}
self.processor.process_request_dict(request_dict, 100)
self.assertEqual(request_dict["enable_thinking"], False)
self.assertEqual(request_dict["enable_thinking"], True)

request_dict = {
"messages": [{"role": "user", "content": "Hello"}],
@@ -111,6 +111,27 @@ def test_process_request_dict_with_options(self):
self.processor.process_request_dict(request_dict, 100)
self.assertEqual(request_dict["enable_thinking"], True)

request_dict = {
"messages": [{"role": "user", "content": "Hello"}],
"chat_template_kwargs": {"options": {"thinking_mode": "close"}},
}
self.processor.process_request_dict(request_dict, 100)
self.assertEqual(request_dict["enable_thinking"], False)

request_dict = {
"messages": [{"role": "user", "content": "Hello"}],
"chat_template_kwargs": {"options": {"thinking_mode": "false"}},
}
self.processor.process_request_dict(request_dict, 100)
self.assertEqual(request_dict["enable_thinking"], False)

request_dict = {
"messages": [{"role": "user", "content": "Hello"}],
"chat_template_kwargs": {"enable_thinking": False},
}
self.processor.process_request_dict(request_dict, 100)
self.assertEqual(request_dict["enable_thinking"], False)


if __name__ == "__main__":
unittest.main()