@@ -521,7 +521,28 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
             if hasattr(request, "pooling_params") and request.pooling_params is not None:
                 batch_pooling_params.append(request.pooling_params)

+            logits_info = None
+            prefill_tokens = []
             if request.task_type.value == RequestType.PREFILL.value:  # prefill task
+                # guided decoding
+                if (
+                    request.guided_json is not None
+                    or request.guided_regex is not None
+                    or request.structural_tag is not None
+                    or request.guided_grammar is not None
+                ):
+                    logits_info, schemata_key = self._init_logits_processor(request)
+                    request.schemata_key = schemata_key
+
+                if self.scheduler_config.splitwise_role == "decode":
+                    if (
+                        hasattr(request, "prefill_end_index")
+                        and hasattr(request, "prompt_token_ids")
+                        and request.prefill_end_index > len(request.prompt_token_ids)
+                    ):
+                        if hasattr(request, "output_token_ids"):
+                            prefill_tokens.extend(request.output_token_ids)
+
                 prefill_start_index = request.prefill_start_index
                 prefill_end_index = request.prefill_end_index
                 length = prefill_end_index - prefill_start_index
@@ -657,6 +678,8 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
             # For logits processors
             self.share_inputs["logits_processors_args"][idx] = request.get("logits_processors_args") or {}

+            self.sampler.apply_logits_processor(idx, logits_info, prefill_tokens)
+
         if len(multi_vision_inputs["images_lst"]) > 0:
             self.share_inputs["image_features"] = self.extract_vision_features(multi_vision_inputs)
660683 if len (multi_vision_inputs ["images_lst" ]) > 0 :
661684 self .share_inputs ["image_features" ] = self .extract_vision_features (multi_vision_inputs )
662685
@@ -2059,6 +2082,21 @@ def _get_p_done_idxs_gd(self, model_forward_batch: Optional[List[Request]], num_
                 if self.share_inputs["step_idx"][idx] == 0:
                     prefill_done_idxs.append(idx)

+        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+            if model_forward_batch is None:
+                return prefill_done_idxs
+
+            for task in model_forward_batch:
+                if task.task_type.value != RequestType.PREFILL.value:
+                    continue
+                # in chunk prefill
+                if self.cache_config.enable_chunked_prefill:
+                    if hasattr(task, "prefill_end_index") and hasattr(task, "prompt_token_ids"):
+                        if len(task.prompt_token_ids) > task.prefill_end_index and task.idx in prefill_done_idxs:
+                            prefill_done_idxs.remove(task.idx)
+
+            return prefill_done_idxs
+
         if self.cache_config.enable_chunked_prefill:
             if model_forward_batch is not None:
                 for task in model_forward_batch:
0 commit comments