From f53408f4f93edd362843f3fa9a286a70db0fab8b Mon Sep 17 00:00:00 2001
From: Tianmu Li
Date: Fri, 19 Dec 2025 01:02:13 +0200
Subject: [PATCH] cherry-pick #740

Signed-off-by: Tianmu Li
---
 vllm_gaudi/v1/worker/hpu_model_runner.py | 53 +++++++++++-------------
 1 file changed, 24 insertions(+), 29 deletions(-)

diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py
index 95b55d034..a5dae093b 100644
--- a/vllm_gaudi/v1/worker/hpu_model_runner.py
+++ b/vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -862,6 +862,7 @@ def __init__(
         )
         self.use_async_scheduling = self.scheduler_config.async_scheduling
+        self.use_structured_output: bool = False  # Defaults to False; set to True when needed during a run
         # Cache token ids on device to avoid h2d copies
         self.input_ids_hpu = torch.zeros(
             self.max_num_tokens,
             dtype=torch.int32, device=self.device,
@@ -1788,7 +1789,19 @@ def _extract_prefill_batch_contents(self, num_prefills, num_decodes, num_schedul
         # self.input_batch.num_prompt_tokens[batch_idx] == self.input_batch.num_tokens[batch_idx].
         # In preemption scenario num_tokens will also include the tokens emitted before preemption
         num_prompt_tokens = self.input_batch.num_prompt_tokens[batch_idx]
-        num_output_logits = max(0, seq_num_computed_tokens + seq_num_scheduled_tokens - num_prompt_tokens + 1)
+        if self.use_async_scheduling or self.use_structured_output:
+            # NOTE(tianmu-li): Align behavior of incomplete prompt with gpu_model_runner
+            # Always have at least 1 logit when using async scheduling
+            # or structured output
+            if seq_num_computed_tokens + seq_num_scheduled_tokens - num_prompt_tokens + 1 < 1:
+                num_output_logits = 1
+                if self.use_async_scheduling:
+                    # Discard partial prefill logit for async scheduling
+                    self.invalid_req_indices.append(batch_idx)
+            else:
+                num_output_logits = seq_num_computed_tokens + seq_num_scheduled_tokens - num_prompt_tokens + 1
+        else:
+            num_output_logits = max(0, seq_num_computed_tokens + seq_num_scheduled_tokens - num_prompt_tokens + 1)
         logits_positions = list(range(seq_num_scheduled_tokens - num_output_logits, seq_num_scheduled_tokens))

         new_batch_contents = BatchContents(
@@ -3201,6 +3214,8 @@ def sample_tokens(self, grammar_output: "GrammarOutput | None") -> ModelRunnerOut
         num_decodes = len(pd_info.decode_req_ids)
         num_prefills = len(pd_info.prompt_req_ids)
         num_reqs = num_decodes + num_prefills
+        if self.use_async_scheduling:
+            self.invalid_req_indices: list[int] = []
         with self.profiler.record_event('internal', 'prepare_input_tensors'):
             prefill_input_data, decode_input_data = self._prepare_inputs(scheduler_output, num_prefills, num_decodes,
                                                                          warmup_mode)
@@ -3232,14 +3247,13 @@ def sample_tokens(self, grammar_output: "GrammarOutput | None") -> ModelRunnerOu
         decode_sampled_token_ids_device = None
         # NOTE(tianmu-li): For structured output, combine logits before
         # postprocessing. Should it be done for all requests?
-        structured_output = False
+        self.use_structured_output = False
         spec_decode_num_tokens = None
         if grammar_output is not None:
             logits_prompt = []
             logits_decode = []
-            structured_output = True
-        if self.use_async_scheduling:
-            invalid_req_indices = []
+            self.use_structured_output = True
+
         ######################### PREFILLS #########################
         if num_prefills > 0:
             htorch.core.mark_step()
@@ -3258,25 +3272,6 @@ def sample_tokens(self, grammar_output: "GrammarOutput | None") -> ModelRunnerOu
                 self.event_start = self.profiler.get_timestamp_us()
                 self.profiler.start("internal", "prefill")
-                # NOTE(tianmu-li): Align behavior of incomplete prompt with gpu_model_runner
-                # If logits_indices is smaller than req_id, the last request is a chunked prompt request that
-                # hasn't finished in this step. We add the last token position to logits_indices to ensure
-                # the last token of the chunk is sampled. This sampled token will be discarded later
-                if logits_indices.shape[0] < len(req_id):
-                    if structured_output or self.use_async_scheduling:
-                        # When there are multiple requests in the batch (e.g. self.use_merged_prefill=True),
-                        # the last token position is the sum of all prompt lengths - 1
-                        # This logic also holds when there is only one request in the batch
-                        logits_indices_append = torch.full((1, ),
-                                                           torch.sum(prompt_len) - 1,
-                                                           device=logits_indices.device,
-                                                           dtype=logits_indices.dtype)
-                        logits_indices = torch.cat([logits_indices, logits_indices_append])
-                        if self.use_async_scheduling:
-                            # Discard partial prefill logit for async scheduling
-                            # Depends on 1 decode token/batch
-                            prefill_start_idx = num_decodes
-                            invalid_req_indices.append(prefill_start_idx + idx)
                 htorch.core.mark_step()

                 non_flattened_hidden_states, aux_hidden_states, \
@@ -3295,7 +3290,7 @@ def sample_tokens(self, grammar_output: "GrammarOutput | None") -> ModelRunnerOu
                 aux_hidden_states_prefills.append(aux_hidden_states)
                 sample_hidden_states_prefills.append(sample_hidden_states)
                 # Skip separate sampling for structured output
-                if structured_output:
+                if self.use_structured_output:
                     logits_prompt.append(logits_device)
                     prefill_sampled_requests.extend(logits_requests)
                 else:
@@ -3365,7 +3360,7 @@ def sample_tokens(self, grammar_output: "GrammarOutput | None") -> ModelRunnerOu
                                                          warmup_mode=warmup_mode)
             htorch.core.mark_step()

-            if structured_output:
+            if self.use_structured_output:
                 logits_decode.append(logits_device[:num_decodes])
                 decode_sampled_requests.extend(self.input_batch.req_ids[:num_decodes])
             else:
@@ -3434,7 +3429,7 @@ def sample_tokens(self, grammar_output: "GrammarOutput | None") -> ModelRunnerOu
                                                        warmup_mode=warmup_mode)
         htorch.core.mark_step()

-        if structured_output:
+        if self.use_structured_output:
             # Scheduler places cached before prompt
             logits_combined = logits_decode + logits_prompt
             logits = torch.cat(logits_combined, dim=0)
@@ -3469,7 +3464,7 @@ def sample_tokens(self, grammar_output: "GrammarOutput | None") -> ModelRunnerOu
         if self.use_async_scheduling:
             self.input_batch.prev_sampled_token_ids = sampled_token_ids.flatten()
             # self.input_batch.prev_sampled_token_ids_invalid_indices
-            invalid_req_indices_set = set(invalid_req_indices)
+            invalid_req_indices_set = set(self.invalid_req_indices)
             self.input_batch.prev_sampled_token_ids_invalid_indices = \
                 invalid_req_indices_set
             self.input_batch.prev_req_id_to_index = {
@@ -3584,7 +3579,7 @@ def sample_tokens(self, grammar_output: "GrammarOutput | None") -> ModelRunnerOu
             return AsyncHPUModelRunnerOutput(
                 model_runner_output=model_runner_output,
                 sampled_token_ids=sampled_token_ids,
-                invalid_req_indices=invalid_req_indices,
+                invalid_req_indices=self.invalid_req_indices,
                 async_output_copy_stream=self.async_output_copy_stream,
             )
         model_runner_output = ModelRunnerOutput(
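
Note for reviewers (not part of the patch): the sketch below restates the new num_output_logits rule from the
_extract_prefill_batch_contents hunk as a standalone function. The helper name count_output_logits and the example
numbers are illustrative assumptions, not code from the repository.

def count_output_logits(num_computed: int, num_scheduled: int, num_prompt: int,
                        async_scheduling: bool, structured_output: bool) -> tuple[int, bool]:
    """Return (num_output_logits, discard_logit) for one prefill chunk."""
    remaining = num_computed + num_scheduled - num_prompt + 1
    if async_scheduling or structured_output:
        if remaining < 1:
            # Incomplete prompt chunk: keep exactly one logit so downstream sampling /
            # grammar code sees a logit for every request; under async scheduling that
            # logit is later discarded via invalid_req_indices.
            return 1, async_scheduling
        return remaining, False
    # Previous behavior: an incomplete chunk produces no logits at all.
    return max(0, remaining), False


if __name__ == "__main__":
    # A 100-token prompt scheduled as a 64-token chunk: the chunk is incomplete.
    print(count_output_logits(0, 64, 100, async_scheduling=True, structured_output=False))   # (1, True)
    print(count_output_logits(0, 64, 100, async_scheduling=False, structured_output=False))  # (0, False)
    # The final chunk that completes the prompt yields exactly one logit either way.
    print(count_output_logits(64, 36, 100, async_scheduling=True, structured_output=False))  # (1, False)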