vllm_gaudi/v1/worker/hpu_model_runner.py (53 changes: 24 additions & 29 deletions)
@@ -862,6 +862,7 @@ def __init__(
)

self.use_async_scheduling = self.scheduler_config.async_scheduling
self.use_structured_output: bool = False # Default to false. Set to true when needed during a run

Copilot AI (Dec 18, 2025):

Capitalized 'false' and 'true' to match Python boolean literals 'False' and 'True'.

Suggested change:
- self.use_structured_output: bool = False # Default to false. Set to true when needed during a run
+ self.use_structured_output: bool = False # Default to False. Set to True when needed during a run

# Cache token ids on device to avoid h2d copies
self.input_ids_hpu = torch.zeros(
self.max_num_tokens, dtype=torch.int32, device=self.device,
@@ -1788,7 +1789,19 @@ def _extract_prefill_batch_contents(self, num_prefills, num_decodes, num_schedul
# self.input_batch.num_prompt_tokens[batch_idx] == self.input_batch.num_tokens[batch_idx].
# In preemption scenario num_tokens will also include the tokens emitted before preemption
num_prompt_tokens = self.input_batch.num_prompt_tokens[batch_idx]
num_output_logits = max(0, seq_num_computed_tokens + seq_num_scheduled_tokens - num_prompt_tokens + 1)
if self.use_async_scheduling or self.use_structured_output:
    # NOTE(tianmu-li): align behavior of incomplete prompt with gpu_model_runner
    # Always have at least 1 logit when using async scheduling
    # or structured output
    if seq_num_computed_tokens + seq_num_scheduled_tokens - num_prompt_tokens + 1 < 1:

Copilot AI (Dec 18, 2025):

The condition seq_num_computed_tokens + seq_num_scheduled_tokens - num_prompt_tokens + 1 < 1 is complex and unclear. Consider simplifying to seq_num_computed_tokens + seq_num_scheduled_tokens < num_prompt_tokens for better readability.

Suggested change:
- if seq_num_computed_tokens + seq_num_scheduled_tokens - num_prompt_tokens + 1 < 1:
+ if seq_num_computed_tokens + seq_num_scheduled_tokens < num_prompt_tokens:

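
A quick aside on why the suggested simplification is safe: for integer token counts, "x - num_prompt_tokens + 1 < 1" reduces to "x < num_prompt_tokens". A minimal standalone check (hypothetical helper names, not part of the diff):

# Sanity check that the two condition forms agree (illustrative only).
def needs_placeholder_logit_original(computed: int, scheduled: int, prompt: int) -> bool:
    return computed + scheduled - prompt + 1 < 1

def needs_placeholder_logit_simplified(computed: int, scheduled: int, prompt: int) -> bool:
    return computed + scheduled < prompt

for case in [(0, 128, 512), (384, 127, 512), (384, 128, 512), (384, 129, 512)]:
    assert needs_placeholder_logit_original(*case) == needs_placeholder_logit_simplified(*case)
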
        num_output_logits = 1
        if self.use_async_scheduling:
            # Discard partial prefill logit for async scheduling
            self.invalid_req_indices.append(batch_idx)
    else:
        num_output_logits = seq_num_computed_tokens + seq_num_scheduled_tokens - num_prompt_tokens + 1
else:
    num_output_logits = max(0, seq_num_computed_tokens + seq_num_scheduled_tokens - num_prompt_tokens + 1)
logits_positions = list(range(seq_num_scheduled_tokens - num_output_logits, seq_num_scheduled_tokens))

new_batch_contents = BatchContents(
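
To make the effect of the new branch above concrete, here is a small self-contained sketch (made-up token counts, mirroring but not importing the runner code) of what num_output_logits and logits_positions become for a chunked prefill when async scheduling or structured output is enabled:

# Illustrative recomputation of the branch above with made-up numbers:
# a 512-token prompt, 256 tokens already computed, 128 scheduled this step.
seq_num_computed_tokens = 256
seq_num_scheduled_tokens = 128
num_prompt_tokens = 512

if seq_num_computed_tokens + seq_num_scheduled_tokens - num_prompt_tokens + 1 < 1:
    # Chunk ends mid-prompt: keep exactly one placeholder logit
    # (discarded later under async scheduling).
    num_output_logits = 1
else:
    num_output_logits = seq_num_computed_tokens + seq_num_scheduled_tokens - num_prompt_tokens + 1

logits_positions = list(range(seq_num_scheduled_tokens - num_output_logits, seq_num_scheduled_tokens))
print(num_output_logits, logits_positions)  # 1 [127]: the last token position of this chunk
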
@@ -3201,6 +3214,8 @@ def sample_tokens(self, grammar_output: "GrammarOutput | None") -> ModelRunnerOu
num_decodes = len(pd_info.decode_req_ids)
num_prefills = len(pd_info.prompt_req_ids)
num_reqs = num_decodes + num_prefills
if self.use_async_scheduling:
self.invalid_req_indices: list[int] = []
with self.profiler.record_event('internal', 'prepare_input_tensors'):
prefill_input_data, decode_input_data = self._prepare_inputs(scheduler_output, num_prefills, num_decodes,
warmup_mode)
@@ -3232,14 +3247,13 @@ def sample_tokens(self, grammar_output: "GrammarOutput | None") -> ModelRunnerOu
decode_sampled_token_ids_device = None
# NOTE(tianmu-li): For structured output, combine logits before
# postprocessing. Should it be done for all requests?
structured_output = False
self.use_structured_output = False
spec_decode_num_tokens = None
if grammar_output is not None:
logits_prompt = []
logits_decode = []
structured_output = True
if self.use_async_scheduling:
invalid_req_indices = []
self.use_structured_output = True

######################### PREFILLS #########################
if num_prefills > 0:
htorch.core.mark_step()
@@ -3258,25 +3272,6 @@ def sample_tokens(self, grammar_output: "GrammarOutput | None") -> ModelRunnerOu

self.event_start = self.profiler.get_timestamp_us()
self.profiler.start("internal", "prefill")
# NOTE(tianmu-li): Align behavior of incomplete prompt with gpu_model_runner
# If logits_indices is smaller than req_id, the last request is a chunked prompt request that
# hasn't finished in this step. We add the last token position to logits_indices to ensure
# the last token of the chunk is sampled. This sampled token will be discarded later
if logits_indices.shape[0] < len(req_id):
    if structured_output or self.use_async_scheduling:
        # When there are multiple requests in the batch (e.g. self.use_merged_prefill=True),
        # the last token position is the sum of all prompt lengths - 1
        # This logic also holds when there is only one request in the batch
        logits_indices_append = torch.full((1, ),
                                           torch.sum(prompt_len) - 1,
                                           device=logits_indices.device,
                                           dtype=logits_indices.dtype)
        logits_indices = torch.cat([logits_indices, logits_indices_append])
    if self.use_async_scheduling:
        # Discard partial prefill logit for async scheduling
        # Depends on 1 decode token/batch
        prefill_start_idx = num_decodes
        invalid_req_indices.append(prefill_start_idx + idx)
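
For reference, a toy reproduction (made-up lengths, CPU tensors, simplified from the removed lines above) of how the removed path appended the last token position of a merged prefill chunk:

# Toy reproduction of the removed logits_indices append (illustrative only).
import torch

prompt_len = torch.tensor([512, 384])   # lengths of the prompts packed into this forward pass
logits_indices = torch.tensor([511])    # fewer indices than requests: last request is an unfinished chunk

if logits_indices.shape[0] < 2:         # len(req_id) == 2 in this toy batch
    # Last token position of the merged chunk: sum of all prompt lengths - 1
    last_pos = int(torch.sum(prompt_len)) - 1
    logits_indices = torch.cat([logits_indices,
                                torch.tensor([last_pos], dtype=logits_indices.dtype)])

print(logits_indices)  # tensor([511, 895])
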

htorch.core.mark_step()
non_flattened_hidden_states, aux_hidden_states, \
Expand All @@ -3295,7 +3290,7 @@ def sample_tokens(self, grammar_output: "GrammarOutput | None") -> ModelRunnerOu
aux_hidden_states_prefills.append(aux_hidden_states)
sample_hidden_states_prefills.append(sample_hidden_states)
# Skip separate sampling for structured output
if structured_output:
if self.use_structured_output:
logits_prompt.append(logits_device)
prefill_sampled_requests.extend(logits_requests)
else:
@@ -3365,7 +3360,7 @@ def sample_tokens(self, grammar_output: "GrammarOutput | None") -> ModelRunnerOu
warmup_mode=warmup_mode)
htorch.core.mark_step()

if structured_output:
if self.use_structured_output:
logits_decode.append(logits_device[:num_decodes])
decode_sampled_requests.extend(self.input_batch.req_ids[:num_decodes])
else:
@@ -3434,7 +3429,7 @@ def sample_tokens(self, grammar_output: "GrammarOutput | None") -> ModelRunnerOu
warmup_mode=warmup_mode)
htorch.core.mark_step()

if structured_output:
if self.use_structured_output:
# Scheduler places cached before prompt
logits_combined = logits_decode + logits_prompt
logits = torch.cat(logits_combined, dim=0)
@@ -3469,7 +3464,7 @@ def sample_tokens(self, grammar_output: "GrammarOutput | None") -> ModelRunnerOu
if self.use_async_scheduling:
self.input_batch.prev_sampled_token_ids = sampled_token_ids.flatten()
# self.input_batch.prev_sampled_token_ids_invalid_indices
invalid_req_indices_set = set(invalid_req_indices)
invalid_req_indices_set = set(self.invalid_req_indices)
self.input_batch.prev_sampled_token_ids_invalid_indices = \
invalid_req_indices_set
self.input_batch.prev_req_id_to_index = {
@@ -3584,7 +3579,7 @@ def sample_tokens(self, grammar_output: "GrammarOutput | None") -> ModelRunnerOu
return AsyncHPUModelRunnerOutput(
model_runner_output=model_runner_output,
sampled_token_ids=sampled_token_ids,
invalid_req_indices=invalid_req_indices,
invalid_req_indices=self.invalid_req_indices,
async_output_copy_stream=self.async_output_copy_stream,
)
model_runner_output = ModelRunnerOutput(