
Commit f25ee3a

ST-XX and Copilot authored
[Feature] enable guided decoding ENABLE_V1_KVCACHE_SCHEDULER = 1 (#5140)
* enable guided decoding ENABLE_V1_KVCACHE_SCHEDULER = 1

* Apply suggestions from code review

Co-authored-by: Copilot <[email protected]>
1 parent 2d78759 commit f25ee3a

File tree: 3 files changed, +38 -5 lines

fastdeploy/engine/args_utils.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -535,8 +535,6 @@ def __post_init__(self):
 
         if not (current_platform.is_cuda() or current_platform.is_xpu() or current_platform.is_maca()):
             envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
-        if self.guided_decoding_backend != "off":
-            envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
 
         if "PaddleOCR" in get_model_architecture(self.model, self.model_config_name):
             envs.FD_ENABLE_MAX_PREFILL = 1
```
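The deleted branch was the blanket opt-out: configuring any guided decoding backend used to force the V1 KV-cache scheduler off. After this change only the platform check can disable it. A minimal standalone sketch of the resulting decision, with stand-in names (this is not the FastDeploy API, and the backend string is only an example):

```python
# Minimal sketch of the flag resolution after this commit (hypothetical
# helper; the real logic lives in __post_init__ above). Guided decoding
# no longer forces the V1 KV-cache scheduler off.
def resolve_v1_scheduler(platform_supported: bool, guided_backend: str) -> int:
    if not platform_supported:
        return 0  # unsupported platform still opts out
    # pre-change behavior, now removed:
    #   if guided_backend != "off": return 0
    return 1

assert resolve_v1_scheduler(True, "xgrammar") == 1   # guided decoding + V1 scheduler now coexist
assert resolve_v1_scheduler(False, "off") == 0
```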

fastdeploy/worker/gpu_model_runner.py

Lines changed: 38 additions & 0 deletions
```diff
@@ -521,7 +521,28 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
             if hasattr(request, "pooling_params") and request.pooling_params is not None:
                 batch_pooling_params.append(request.pooling_params)
 
+            logits_info = None
+            prefill_tokens = []
             if request.task_type.value == RequestType.PREFILL.value:  # prefill task
+                # guided decoding
+                if (
+                    request.guided_json is not None
+                    or request.guided_regex is not None
+                    or request.structural_tag is not None
+                    or request.guided_grammar is not None
+                ):
+                    logits_info, schemata_key = self._init_logits_processor(request)
+                    request.schemata_key = schemata_key
+
+                if self.scheduler_config.splitwise_role == "decode":
+                    if (
+                        hasattr(request, "prefill_end_index")
+                        and hasattr(request, "prompt_token_ids")
+                        and request.prefill_end_index > len(request.prompt_token_ids)
+                    ):
+                        if hasattr(request, "output_token_ids"):
+                            prefill_tokens.extend(request.output_token_ids)
+
                 prefill_start_index = request.prefill_start_index
                 prefill_end_index = request.prefill_end_index
                 length = prefill_end_index - prefill_start_index
```
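For a prefill task, any of the four structured-output fields (JSON schema, regex, structural tag, grammar) triggers logits-processor initialization, and the schema key is cached on the request. On a decode-role instance in splitwise serving, tokens the prefill node already generated (`prefill_end_index` past the prompt length) are collected as `prefill_tokens`, presumably so the constraint state can be replayed up to the current position. A self-contained sketch of the detection, using a hypothetical stand-in for FastDeploy's `Request`:

```python
# Self-contained sketch of the constraint detection above (hypothetical
# stand-in type; field names match the diff).
from dataclasses import dataclass
from typing import Optional

@dataclass
class Req:
    guided_json: Optional[str] = None
    guided_regex: Optional[str] = None
    structural_tag: Optional[str] = None
    guided_grammar: Optional[str] = None

def wants_guided_decoding(req: Req) -> bool:
    # Any one of the four constraint fields being set is enough.
    return any(
        field is not None
        for field in (req.guided_json, req.guided_regex, req.structural_tag, req.guided_grammar)
    )

assert wants_guided_decoding(Req(guided_regex=r"\d{4}-\d{2}-\d{2}"))
assert not wants_guided_decoding(Req())
```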
```diff
@@ -657,6 +678,8 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
             # For logits processors
             self.share_inputs["logits_processors_args"][idx] = request.get("logits_processors_args") or {}
 
+            self.sampler.apply_logits_processor(idx, logits_info, prefill_tokens)
+
         if len(multi_vision_inputs["images_lst"]) > 0:
             self.share_inputs["image_features"] = self.extract_vision_features(multi_vision_inputs)
```
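The new call hands the compiled constraint (and any replayed tokens) to the sampler at the request's batch slot. Conceptually, a processor of this kind masks every token the constraint automaton does not currently allow; a toy illustration of that idea, not FastDeploy's sampler API:

```python
# Toy illustration of what a constraint-driven logits processor does at each
# decoding step (hypothetical; the real sampler differs): logits of tokens
# outside the currently-allowed set are masked to -inf before sampling.
import math

def mask_logits(logits: list[float], allowed_token_ids: set[int]) -> list[float]:
    return [x if i in allowed_token_ids else -math.inf for i, x in enumerate(logits)]

print(mask_logits([0.1, 2.3, -0.5, 1.0], allowed_token_ids={1, 3}))
# [-inf, 2.3, -inf, 1.0]: only tokens 1 and 3 remain sampleable
```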

```diff
@@ -2059,6 +2082,21 @@ def _get_p_done_idxs_gd(self, model_forward_batch: Optional[List[Request]], num_
             if self.share_inputs["step_idx"][idx] == 0:
                 prefill_done_idxs.append(idx)
 
+        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+            if model_forward_batch is None:
+                return prefill_done_idxs
+
+            for task in model_forward_batch:
+                if task.task_type.value != RequestType.PREFILL.value:
+                    continue
+                # in chunk prefill
+                if self.cache_config.enable_chunked_prefill:
+                    if hasattr(task, "prefill_end_index") and hasattr(task, "prompt_token_ids"):
+                        if len(task.prompt_token_ids) > task.prefill_end_index and task.idx in prefill_done_idxs:
+                            prefill_done_idxs.remove(task.idx)
+
+            return prefill_done_idxs
+
         if self.cache_config.enable_chunked_prefill:
             if model_forward_batch is not None:
                 for task in model_forward_batch:
```
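Under the V1 scheduler with chunked prefill, a request can be scheduled for several prefill chunks before its prompt is fully consumed; the new branch drops such requests from the "prefill done" list so guided decoding does not start constraining output early. A standalone sketch of the filter, using a hypothetical stand-in task type:

```python
# Standalone sketch of the chunked-prefill filter above (hypothetical Task
# stand-in): a prefill task counts as done only once prefill_end_index has
# covered the whole prompt.
from dataclasses import dataclass

@dataclass
class Task:
    idx: int
    prompt_token_ids: list[int]
    prefill_end_index: int

def filter_prefill_done(done_idxs: list[int], batch: list[Task]) -> list[int]:
    for task in batch:
        if len(task.prompt_token_ids) > task.prefill_end_index and task.idx in done_idxs:
            done_idxs.remove(task.idx)  # still mid chunked prefill
    return done_idxs

batch = [Task(idx=0, prompt_token_ids=list(range(100)), prefill_end_index=64)]
assert filter_prefill_done([0, 1], batch) == [1]  # task 0 has 36 prompt tokens left
```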

fastdeploy/worker/worker_process.py

Lines changed: 0 additions & 3 deletions
```diff
@@ -968,9 +968,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     if not (current_platform.is_cuda() or current_platform.is_xpu() or current_platform.is_maca()):
         logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported.")
         envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
-    if structured_outputs_config.guided_decoding_backend != "off":
-        logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported guided_decoding.")
-        envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
 
     if envs.ENABLE_V1_KVCACHE_SCHEDULER and args.splitwise_role == "prefill":
         os.environ["PREFILL_NODE_ONE_STEP_STOP_V1"] = "1"
```
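This mirrors the args_utils.py change: the worker-side config no longer downgrades the scheduler when a guided decoding backend is configured, while prefill-role nodes still get the one-step-stop variable. A compact sketch of the surviving control flow (the two env names appear in the diff; the surrounding scaffolding is illustrative, not FastDeploy's real function):

```python
# Compact sketch of initialize_fd_config's surviving scheduler logic
# (hypothetical helper; ENABLE_V1_KVCACHE_SCHEDULER and
# PREFILL_NODE_ONE_STEP_STOP_V1 are the real names from the diff).
import os

def configure_scheduler(platform_supported: bool, splitwise_role: str, env: dict) -> dict:
    if not platform_supported:
        env["ENABLE_V1_KVCACHE_SCHEDULER"] = 0  # only remaining reason to disable
    if env.get("ENABLE_V1_KVCACHE_SCHEDULER") and splitwise_role == "prefill":
        os.environ["PREFILL_NODE_ONE_STEP_STOP_V1"] = "1"
    return env

env = configure_scheduler(True, "prefill", {"ENABLE_V1_KVCACHE_SCHEDULER": 1})
assert env["ENABLE_V1_KVCACHE_SCHEDULER"] == 1  # stays on even with guided decoding configured
```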
