23 changes: 23 additions & 0 deletions vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -1635,6 +1635,29 @@ def _get_prompt_bucketing_fn(self):
return self._bucketize_2d_prompt

def _can_merge_prefill_contents(self, lhs, rhs):
# --- Logic to handle chunked prefill/prefix caching for HPU ---
# 1. Check basic states of LHS (accumulated batch) and RHS (incoming request).
# lhs_is_not_empty: Check if the accumulated batch actually contains any requests.
# lhs_has_history: Check if any request in the accumulated batch has a non-zero context (history).
lhs_is_not_empty = len(lhs.context_lens) > 0
lhs_has_history = any(length > 0 for length in lhs.context_lens)

    # 2. Check if any request in RHS (the incoming contents) has context_len > 0 (history).
    rhs_has_history = any(length > 0 for length in rhs.context_lens)

# 3. Apply merging restrictions based on history states:

# Condition A: If the accumulated batch is not empty, we cannot append a request that has history.
# This implies that a request with history (e.g., prefix caching hit) must start as a new batch.
if lhs_is_not_empty and rhs_has_history:
return False

# Condition B: If the accumulated batch already contains requests with history,
# we cannot append *any* new request (regardless of whether RHS has history or not).
# This locks the batch once it contains history (likely for decode phase or chunked prefill).
if lhs_has_history:
return False

combined_num_tokens = lhs.get_num_tokens() + rhs.get_num_tokens()
bucketing_fn = self._get_prompt_bucketing_fn()
try:
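As a quick illustration (not part of the PR), the history-based merge rules added above can be reproduced in isolation. This is a minimal sketch: FakePrefillContents and can_merge_by_history are hypothetical stand-ins that only assume the context_lens attribute and the get_num_tokens() accessor used in the diff, and they ignore the token/bucketing checks that follow in the real method.

# Minimal sketch of Conditions A and B from _can_merge_prefill_contents,
# applied to hypothetical lightweight stand-ins for prefill contents.
from dataclasses import dataclass, field


@dataclass
class FakePrefillContents:
    # Hypothetical stand-in, not the actual vLLM class.
    context_lens: list[int] = field(default_factory=list)

    def get_num_tokens(self) -> int:
        # Token accounting is irrelevant for the history checks sketched here.
        return 0


def can_merge_by_history(lhs: FakePrefillContents,
                         rhs: FakePrefillContents) -> bool:
    lhs_is_not_empty = len(lhs.context_lens) > 0
    lhs_has_history = any(length > 0 for length in lhs.context_lens)
    rhs_has_history = any(length > 0 for length in rhs.context_lens)
    if lhs_is_not_empty and rhs_has_history:  # Condition A
        return False
    if lhs_has_history:  # Condition B
        return False
    return True


# A contents object with history (e.g., a prefix-cache hit) can only seed an
# empty batch; once the batch holds history, no further merging is allowed.
assert can_merge_by_history(FakePrefillContents([]), FakePrefillContents([128]))
assert not can_merge_by_history(FakePrefillContents([0]), FakePrefillContents([128]))   # Condition A
assert not can_merge_by_history(FakePrefillContents([128]), FakePrefillContents([0]))   # Condition B

In effect, a request with history can only start a fresh batch, and a batch that already contains history accepts no further merges before the usual token-count and bucketing checks run.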