vllm_gaudi/v1/spec_decode/hpu_eagle.py: 12 changes (7 additions, 5 deletions)
@@ -186,7 +186,9 @@ def prepare_attn_metadata(
 # block_tables_list is a nested list of shape [num_seq, num_blocks]
 # num_blocks should include the slots needed for the current token
 # positions are the context lengths, and we need +1 for num_blocks
-num_blocks = torch.ceil((positions + 1) / self.block_size).int()
+block_size = self.attn_metadata_builder.kv_cache_spec.block_size
+
+num_blocks = torch.ceil((positions + 1) / block_size).int()
 num_blocks = num_blocks[:num_seq].tolist()
 block_tables_list = []
 for i, n in enumerate(num_blocks):
@@ -198,7 +200,7 @@ def prepare_attn_metadata(

 # Compute slot mapping in [batch_size, 1] shape
 clamped_positions = clamped_positions.view(-1, 1)
-block_numbers = clamped_positions // self.block_size
+block_numbers = clamped_positions // block_size

 # Limit with num_seq because block_table_cpu_tensor is in the shape [num_seq, x]
 block_numbers = block_numbers.to(torch.int64)[:num_seq]
@@ -208,8 +210,8 @@ def prepare_attn_metadata(
 block_ids.apply_(model_runner.defragmenter.resolve)

 # Calculate the slot mapping and fill with padding
-slot_mapping = block_ids * self.block_size + clamped_positions % self.block_size
-dummy_slots = itertools.cycle(range(model_runner._PAD_SLOT_ID, model_runner._PAD_SLOT_ID + self.block_size))
+slot_mapping = block_ids * block_size + clamped_positions % block_size
+dummy_slots = itertools.cycle(range(model_runner._PAD_SLOT_ID, model_runner._PAD_SLOT_ID + block_size))
 slot_mapping[num_seq:].apply_(lambda _, ds=dummy_slots: next(ds))
 # Slot mapping needs to be int64 (long) type
 slot_mapping = slot_mapping.to(torch.int64)
@@ -232,7 +234,7 @@ def prepare_attn_metadata(
 block_groups=block_groups_device,
 input_positions=None,
 slot_mapping=slot_mapping_device,
-block_size=self.block_size,
+block_size=block_size,
 window_block_list=None,
 window_block_usage=None,
 window_block_groups=None,
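Note: the hunks above stop using self.block_size and instead take the block size from the attention metadata builder's kv_cache_spec. The snippet below is a minimal, self-contained sketch of the same block/slot arithmetic, not the runner's actual code; the toy block_size, positions, and block table values are made up for illustration.

import torch

block_size = 4  # toy value; the diff reads it from kv_cache_spec.block_size
positions = torch.tensor([0, 5, 9])  # per-sequence context lengths (synthetic)

# +1 because the token being generated also needs a slot
num_blocks = torch.ceil((positions + 1) / block_size).int()  # -> [1, 2, 3]

# Toy block table: row = sequence, column = logical block, value = physical block id.
block_table = torch.tensor([[7, 0, 0], [3, 8, 0], [2, 5, 11]])

clamped_positions = positions.view(-1, 1)
block_numbers = clamped_positions // block_size               # logical block of each position
block_ids = block_table.gather(1, block_numbers)              # physical block of each position
slot_mapping = block_ids * block_size + clamped_positions % block_size

print(num_blocks.tolist())               # [1, 2, 3]
print(slot_mapping.squeeze(1).tolist())  # [28, 33, 45]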
vllm_gaudi/v1/worker/hpu_model_runner.py: 39 changes (20 additions, 19 deletions)
@@ -2841,9 +2841,7 @@ def _pool(

 pooling_metadata = self.input_batch.get_pooling_metadata()
 seq_lens_cpu = self.seq_lens.cpu[:self.input_batch.num_reqs]
-pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(),
-                                      seq_lens_cpu,
-                                      device=hidden_states.device)
+pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np, seq_lens_cpu, device=hidden_states.device)

 num_reqs = self.input_batch.num_reqs

@@ -3158,6 +3156,25 @@ def execute_model(
 return EMPTY_MODEL_RUNNER_OUTPUT
 # For D case, wait until kv finish load here
 return self.kv_connector_no_forward(scheduler_output, self.vllm_config)
+
+if self.input_batch.pooling_params:
+    (input_ids, position_ids, num_scheduled_tokens, attn_metadata,
+     total_scheduled_tokens) = self._prepare_inputs_for_pooling(scheduler_output)
+
+    with set_forward_context(attn_metadata, self.vllm_config):
+        hidden_states = self.model.forward(
+            input_ids=input_ids,
+            positions=position_ids,
+        )
+
+    flattened = hidden_states.view(-1, hidden_states.shape[-1])
+    pooled_output = self._pool(
+        flattened,
+        total_scheduled_tokens,
+        np.array(num_scheduled_tokens, dtype=np.int32),
+    )
+    return pooled_output
+
 self.scheduler_output = scheduler_output
 self.warmup_mode = warmup_mode
 self.batch_changed = batch_changed
@@ -3233,23 +3250,7 @@ def sample_tokens(self, grammar_output: "GrammarOutput | None") -> ModelRunnerOu
 # Return [tokD0, tokD1, tokD2, tokP0, tokP1, tokP2]

 batch_changed = self.batch_changed
-if self.input_batch.pooling_params:
-    (input_ids, position_ids, num_scheduled_tokens, attn_metadata,
-     total_scheduled_tokens) = self._prepare_inputs_for_pooling(scheduler_output)
-
-    with set_forward_context(attn_metadata, self.vllm_config):
-        hidden_states = self.model.forward(
-            input_ids=input_ids,
-            positions=position_ids,
-        )
-
-    flattened = hidden_states.view(-1, hidden_states.shape[-1])
-    pooled_output = self._pool(
-        flattened,
-        total_scheduled_tokens,
-        np.array(num_scheduled_tokens, dtype=np.int32),
-    )
-    return pooled_output
 # If necessary, swap decodes/prompts to have all decodes on the start

 ensure_decodes_first(self.input_batch)
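Note: with the two hunks above, the pooling fast path moves from sample_tokens into execute_model, so pooling requests run the forward pass and return before any sampling state is touched, and sample_tokens only ever handles generation. The snippet below is a minimal, self-contained sketch of the flatten-then-pool shape handling that the moved block relies on; the token counts are synthetic and mean pooling is only an illustrative stand-in for the model-specific pooler.

import numpy as np
import torch

# Synthetic stand-ins: three requests with 4, 2, and 3 scheduled tokens, hidden size 8.
num_scheduled_tokens = np.array([4, 2, 3], dtype=np.int32)
hidden_states = torch.randn(1, int(num_scheduled_tokens.sum()), 8)

# Same reshape as in the moved block: collapse leading dims so rows line up with tokens.
flattened = hidden_states.view(-1, hidden_states.shape[-1])  # [9, 8]

# Split the token rows back per request and pool each slice.
chunks = torch.split(flattened, num_scheduled_tokens.tolist(), dim=0)
pooled = torch.stack([chunk.mean(dim=0) for chunk in chunks])  # [3, 8]
print(pooled.shape)  # torch.Size([3, 8])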