diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py
index 4a1e4ef647f..cdcc5cb9d87 100644
--- a/fastdeploy/entrypoints/openai/protocol.py
+++ b/fastdeploy/entrypoints/openai/protocol.py
@@ -671,10 +671,7 @@ def to_dict_for_infer(self, request_id=None):
         if request_id is not None:
             req_dict["request_id"] = request_id
 
-        if "prompt_token_ids" in req_dict:
-            if "messages" in req_dict:
-                del req_dict["messages"]
-        else:
+        if "prompt_token_ids" not in req_dict:
             # If disable_chat_template is set, then the first message in messages will be used as the prompt.
             assert (
                 len(req_dict["messages"]) > 0
diff --git a/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py b/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py
index 77c62125c7a..d9eec5275c2 100644
--- a/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py
+++ b/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py
@@ -219,7 +219,13 @@ def process_request_dict(self, request, max_model_len=None):
             bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
             request["bad_words_token_ids"] = bad_words_token_ids
 
-        if request.get("prompt"):
+        if request.get("prompt_token_ids"):
+            messages = request.get("messages")
+            if messages:
+                self._check_mm_limits(messages)
+            request.setdefault("enable_thinking", True)
+            outputs = self.ernie4_5_processor.prompt_token_ids2outputs(request)
+        elif request.get("prompt"):
             multimodal_data = request.get("multimodal_data")
             if multimodal_data is None:
                 multimodal_data = {}
@@ -256,7 +262,9 @@ def process_request_dict(self, request, max_model_len=None):
             self.append_completion_tokens(outputs, request["completion_token_ids"])
 
         outputs = self.pack_outputs(outputs)
-        request["prompt_token_ids"] = outputs["input_ids"].tolist()
+        request["prompt_token_ids"] = (
+            outputs["input_ids"].tolist() if "prompt_token_ids" not in request else request["prompt_token_ids"]
+        )
         request["prompt_token_ids_len"] = len(request["prompt_token_ids"])
         request["multimodal_inputs"] = outputs
 
diff --git a/fastdeploy/input/ernie4_5_vl_processor/process.py b/fastdeploy/input/ernie4_5_vl_processor/process.py
index 4ccdf287f20..efbb3452607 100644
--- a/fastdeploy/input/ernie4_5_vl_processor/process.py
+++ b/fastdeploy/input/ernie4_5_vl_processor/process.py
@@ -136,7 +136,9 @@ def __init__(
         self.video_end = self.VID_END
         self.image_patch_id = self.tokenizer.convert_tokens_to_ids("<|IMAGE_PLACEHOLDER|>")
         self.image_start_id = self.tokenizer.convert_tokens_to_ids(self.image_start)
+        self.image_end_id = self.tokenizer.convert_tokens_to_ids(self.image_end)
         self.video_start_id = self.tokenizer.convert_tokens_to_ids(self.video_start)
+        self.video_end_id = self.tokenizer.convert_tokens_to_ids(self.video_end)
         self.sep_token_id = self.tokenizer.convert_tokens_to_ids(self.sep_token)
         self.eos_token_id = self.tokenizer.convert_tokens_to_ids(self.eos_token)
 
@@ -243,14 +245,7 @@ def text2ids(self, text, images=None, videos=None, image_uuid=None, video_uuid=N
 
         return outputs
 
-    def request2ids(
-        self, request: Dict[str, Any], tgts: List[str] = None
-    ) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
-        """
-        Convert chat messages into model inputs.
-        Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels.
-        """
-
+    def extract_mm_items(self, request: Dict[str, Any]):
         messages = parse_chat_messages(request.get("messages"))
         mm_items = []
         for msg in messages:
@@ -273,6 +268,7 @@ def request2ids(
         if len(missing_hashes) > 0 and not self.enable_processor_cache:
             raise ValueError("Missing items cannot be retrieved without processor cache.")
 
+        dealer = None
        if self.enable_processor_cache:
             context = zmq.Context()
             dealer = context.socket(zmq.DEALER)
@@ -295,6 +291,16 @@ def request2ids(
                 video_uuid.append(item["uuid"])
             else:
                 raise ValueError(f"Unsupported multimodal type: {item.get('type')}")
+        return images, videos, image_uuid, video_uuid, dealer, missing_idx, mm_items
+
+    def request2ids(
+        self, request: Dict[str, Any], tgts: List[str] = None
+    ) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
+        """
+        Convert chat messages into model inputs.
+        Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels.
+        """
+        images, videos, image_uuid, video_uuid, dealer, missing_idx, mm_items = self.extract_mm_items(request)
 
         if self.tokenizer.chat_template is None:
             raise ValueError("This model does not support chat template.")
@@ -329,6 +335,115 @@ def request2ids(
 
         return outputs
 
+    def prompt_token_ids2outputs(
+        self, request: Dict[str, Any], tgts: List[str] = None
+    ) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
+        outputs = {
+            "input_ids": [],
+            "token_type_ids": [],
+            "position_ids": [],
+            "images": [],
+            "grid_thw": [],
+            "image_type_ids": [],
+            "labels": [],
+            "cur_position": 0,
+            "video_cnt": 0,
+            "num_input_image_tokens": 0,
+            "num_input_video_tokens": 0,
+            "mm_positions": [],
+            "mm_hashes": [],
+        }
+        prompt_token_ids = request.get("prompt_token_ids", [])
+        prompt_token_ids_len = len(prompt_token_ids)
+        if not request.get("messages"):
+            outputs["input_ids"].append(prompt_token_ids)
+            outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * prompt_token_ids_len)
+            for i in range(prompt_token_ids_len):
+                outputs["position_ids"].append([i] * 3)
+            outputs["cur_position"] += prompt_token_ids_len
+            return outputs
+        images, videos, image_uuid, video_uuid, dealer, missing_idx, mm_items = self.extract_mm_items(request)
+        st, image_idx, video_idx = 0, 0, 0
+        while st < prompt_token_ids_len:
+            cur_token_id = prompt_token_ids[st]
+            if cur_token_id == self.image_start_id:
+                if image_idx >= len(images):
+                    raise ValueError("prompt token ids has more image placeholder than in messages")
+                # append image_start_id
+                outputs["input_ids"].extend([cur_token_id])
+                outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]])
+                outputs["position_ids"].append([outputs["cur_position"]] * 3)
+                outputs["cur_position"] += 1
+                st += 1
+                # process placeholder token ids
+                cur_idx = st
+                while cur_idx < prompt_token_ids_len and prompt_token_ids[cur_idx] != self.image_end_id:
+                    cur_idx += 1
+                if cur_idx >= prompt_token_ids_len:
+                    raise ValueError("image token ids not complete")
+                image = images[image_idx]
+                uuid = image_uuid[image_idx] if image_uuid else None
+                token_len = cur_idx - st
+                if not isinstance(image, tuple):
+                    self._add_image(image, outputs, uuid, token_len)
+                else:
+                    self._add_processed_image(image, outputs, uuid, token_len)
+                image_idx += 1
+                st = cur_idx
+            elif cur_token_id == self.video_start_id:
+                if video_idx >= len(videos):
+                    raise ValueError("prompt token ids has more video placeholder than in messages")
+                # append video_start_id
+                outputs["input_ids"].extend([cur_token_id])
+                outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]])
+                outputs["position_ids"].append([outputs["cur_position"]] * 3)
+                outputs["cur_position"] += 1
+                st += 1
+                # process placeholder token ids
+                cur_idx = st
+                while cur_idx < prompt_token_ids_len and prompt_token_ids[cur_idx] != self.video_end_id:
+                    cur_idx += 1
+                if cur_idx >= prompt_token_ids_len:
+                    raise ValueError("video token ids not complete")
+                video = videos[video_idx]
+                uuid = video_uuid[video_idx] if video_uuid else None
+                token_len = cur_idx - st
+                if not isinstance(video, tuple):
+                    if isinstance(video, dict):
+                        frames = self._load_and_process_video(video["video"], video)
+                    else:
+                        frames = self._load_and_process_video(video, {})
+                    self._add_video(frames, outputs, uuid, token_len)
+                else:
+                    self._add_processed_video(video, outputs, uuid, token_len)
+                video_idx += 1
+                st = cur_idx
+            else:
+                outputs["input_ids"].extend([cur_token_id])
+                outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]])
+                outputs["position_ids"].append([outputs["cur_position"]] * 3)
+                outputs["cur_position"] += 1
+                st += 1
+        if image_idx != len(images):
+            raise ValueError("number of images does not match")
+        if video_idx != len(videos):
+            raise ValueError("number of videos does not match")
+
+        if self.enable_processor_cache:
+            missing_idx = set(missing_idx)
+            hashes_to_cache, items_to_cache = [], []
+            for idx in range(len(mm_items)):
+                if idx in missing_idx:
+                    continue
+                meta = {}
+                t, h, w = outputs["grid_thw"][idx][0]
+                meta["thw"] = (t, h, w)
+                hashes_to_cache.append(outputs["mm_hashes"][idx])
+                items_to_cache.append((outputs["images"][idx], meta))
+            self.update_processor_cache(dealer, hashes_to_cache, items_to_cache)
+
+        return outputs
+
     def _add_special_token(self, token: Union[str, int], outputs: Dict) -> None:
         token_id = token if isinstance(token, int) else self.tokenizer.convert_tokens_to_ids(token)
         outputs["input_ids"].append(token_id)
@@ -348,7 +463,7 @@ def _add_text(self, tokens, outputs: Dict) -> None:
             outputs["position_ids"].append([start + i] * 3)
         outputs["cur_position"] += len(tokens)
 
-    def _add_image(self, img, outputs: Dict, uuid: Optional[str]) -> None:
+    def _add_image(self, img, outputs: Dict, uuid: Optional[str], token_len=None) -> None:
         patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
             img.height,
             img.width,
@@ -356,6 +471,8 @@ def _add_image(self, img, outputs: Dict, uuid: Optional[str]) -> None:
             max_pixels=self.image_max_pixels,
         )[1]
         num_tokens = (patches_h * patches_w) // (self.spatial_conv_size**2)
+        if token_len and token_len != num_tokens:
+            raise ValueError("image tokens num not match the size")
         outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
         outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
 
@@ -383,9 +500,13 @@ def _add_image(self, img, outputs: Dict, uuid: Optional[str]) -> None:
         outputs["grid_thw"].append(ret["image_grid_thw"])
         outputs["image_type_ids"].append(0)
 
-    def _add_processed_image(self, img_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
+    def _add_processed_image(
+        self, img_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str, token_len=None
+    ) -> None:
         img, meta = img_cache
         num_tokens = img.shape[0] // (self.spatial_conv_size**2)
+        if token_len and num_tokens != token_len:
+            raise ValueError("image tokens num not match the size")
         outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
         outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
 
@@ -401,7 +522,7 @@ def _add_processed_image(self, img_cache: Tuple[np.ndarray, dict], outputs: Dict
         outputs["grid_thw"].append(np.array([[1, h, w]]))
         outputs["image_type_ids"].append(0)
 
-    def _add_video(self, frames, outputs: Dict, uuid: Optional[str]) -> None:
+    def _add_video(self, frames, outputs: Dict, uuid: Optional[str], token_len=None) -> None:
         patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
             frames[0].height,
             frames[0].width,
@@ -410,6 +531,8 @@ def _add_video(self, frames, outputs: Dict, uuid: Optional[str]) -> None:
         )[1]
         num_frames = len(frames)
         num_tokens = (num_frames * patches_h * patches_w) // (self.spatial_conv_size**2 * self.temporal_conv_size)
+        if token_len and num_tokens != token_len:
+            raise ValueError("video tokens num not match the size")
 
         pixel_stack = np.stack([np.array(f.convert("RGB")) for f in frames], axis=0)
         ret = self.image_preprocessor.preprocess(
@@ -438,9 +561,13 @@ def _add_video(self, frames, outputs: Dict, uuid: Optional[str]) -> None:
         outputs["position_ids"].extend(pos_ids)
         outputs["cur_position"] = np.max(pos_ids) + 1
 
-    def _add_processed_video(self, frames_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
+    def _add_processed_video(
+        self, frames_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str, token_len=None
+    ) -> None:
         frames, meta = frames_cache
         num_tokens = frames.shape[0] // (self.spatial_conv_size**2 * self.temporal_conv_size)
+        if token_len and num_tokens != token_len:
+            raise ValueError("video tokens num not match the size")
 
         t, h, w = meta["thw"]
         outputs["images"].append(frames)
diff --git a/tests/input/test_ernie_vl_processor.py b/tests/input/test_ernie_vl_processor.py
index 92d24d5b96f..b9bc22d4cc2 100644
--- a/tests/input/test_ernie_vl_processor.py
+++ b/tests/input/test_ernie_vl_processor.py
@@ -77,7 +77,7 @@ def test_process_request_dict_with_options(self):
             "prompt_token_ids": [1, 1, 1],
         }
         self.processor.process_request_dict(request_dict, 100)
-        self.assertEqual(request_dict["enable_thinking"], False)
+        self.assertEqual(request_dict["enable_thinking"], True)
 
         request_dict = {
             "messages": [{"role": "user", "content": "Hello"}],
@@ -93,7 +93,7 @@ def test_process_request_dict_with_options(self):
             "prompt_token_ids": [1, 1, 1],
         }
         self.processor.process_request_dict(request_dict, 100)
-        self.assertEqual(request_dict["enable_thinking"], False)
+        self.assertEqual(request_dict["enable_thinking"], True)
 
         request_dict = {
             "messages": [{"role": "user", "content": "Hello"}],
@@ -101,7 +101,7 @@ def test_process_request_dict_with_options(self):
             "prompt_token_ids": [1, 1, 1],
         }
         self.processor.process_request_dict(request_dict, 100)
-        self.assertEqual(request_dict["enable_thinking"], False)
+        self.assertEqual(request_dict["enable_thinking"], True)
 
         request_dict = {
             "messages": [{"role": "user", "content": "Hello"}],
@@ -111,6 +111,27 @@ def test_process_request_dict_with_options(self):
         }
         self.processor.process_request_dict(request_dict, 100)
         self.assertEqual(request_dict["enable_thinking"], True)
+        request_dict = {
+            "messages": [{"role": "user", "content": "Hello"}],
+            "chat_template_kwargs": {"options": {"thinking_mode": "close"}},
+        }
+        self.processor.process_request_dict(request_dict, 100)
+        self.assertEqual(request_dict["enable_thinking"], False)
+
+        request_dict = {
+            "messages": [{"role": "user", "content": "Hello"}],
+            "chat_template_kwargs": {"options": {"thinking_mode": "false"}},
+        }
+        self.processor.process_request_dict(request_dict, 100)
+        self.assertEqual(request_dict["enable_thinking"], False)
+
+        request_dict = {
+            "messages": [{"role": "user", "content": "Hello"}],
+            "chat_template_kwargs": {"enable_thinking": False},
+        }
+        self.processor.process_request_dict(request_dict, 100)
+        self.assertEqual(request_dict["enable_thinking"], False)
+
 
 if __name__ == "__main__":
     unittest.main()