diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py
index 21c729e71..6b1f27c65 100644
--- a/vllm_gaudi/v1/worker/hpu_model_runner.py
+++ b/vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -1635,6 +1635,29 @@ def _get_prompt_bucketing_fn(self):
         return self._bucketize_2d_prompt
 
     def _can_merge_prefill_contents(self, lhs, rhs):
+        # --- Logic to handle chunked prefill/prefix caching for HPU ---
+        # 1. Check basic states of LHS (accumulated batch) and RHS (incoming request).
+        #    lhs_is_not_empty: Check if the accumulated batch actually contains any requests.
+        #    lhs_has_history: Check if any request in the accumulated batch has a non-zero context (history).
+        lhs_is_not_empty = len(lhs.context_lens) > 0
+        lhs_has_history = any(length > 0 for length in lhs.context_lens)
+
+        # 2. Check if RHS (the incoming request) has context_len > 0 (history).
+        rhs_has_history = any(length > 0 for length in rhs.context_lens)
+
+        # 3. Apply merging restrictions based on history states:
+
+        # Condition A: If the accumulated batch is not empty, we cannot append a request that has history.
+        # This implies that a request with history (e.g., prefix caching hit) must start as a new batch.
+        if lhs_is_not_empty and rhs_has_history:
+            return False
+
+        # Condition B: If the accumulated batch already contains requests with history,
+        # we cannot append *any* new request (regardless of whether RHS has history or not).
+        # This locks the batch once it contains history (likely for decode phase or chunked prefill).
+        if lhs_has_history:
+            return False
+
         combined_num_tokens = lhs.get_num_tokens() + rhs.get_num_tokens()
         bucketing_fn = self._get_prompt_bucketing_fn()
         try:
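
To illustrate how Conditions A and B gate batch merging, here is a minimal, self-contained sketch. It is not the actual prefill-contents class from hpu_model_runner.py; StubPrefillContents and can_merge_history_checks are hypothetical stand-ins that expose only the context_lens attribute the new checks read, and the token-count/bucketing check that follows in the real method is omitted.

    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class StubPrefillContents:
        # Hypothetical stand-in: only `context_lens` matters for the history checks.
        context_lens: List[int] = field(default_factory=list)

    def can_merge_history_checks(lhs: StubPrefillContents,
                                 rhs: StubPrefillContents) -> bool:
        # Mirrors only Conditions A and B from the patch.
        lhs_is_not_empty = len(lhs.context_lens) > 0
        lhs_has_history = any(length > 0 for length in lhs.context_lens)
        rhs_has_history = any(length > 0 for length in rhs.context_lens)

        if lhs_is_not_empty and rhs_has_history:
            return False  # Condition A: a request with history must start a new batch.
        if lhs_has_history:
            return False  # Condition B: a batch holding history accepts nothing more.
        return True

    # A request with history (e.g. a 128-token prefix-cache hit) cannot join a
    # batch that already contains a request...
    assert not can_merge_history_checks(StubPrefillContents([0]), StubPrefillContents([128]))
    # ...it may only seed an empty batch, which is then closed to further merging.
    assert can_merge_history_checks(StubPrefillContents([]), StubPrefillContents([128]))
    assert not can_merge_history_checks(StubPrefillContents([128]), StubPrefillContents([0]))
    # Requests without history still merge freely (subject to the later bucketing check).
    assert can_merge_history_checks(StubPrefillContents([0]), StubPrefillContents([0]))

The net effect, under these assumptions, is that any request carrying context either starts its own batch or ends up alone in one, while history-free prefills continue to be packed together as before.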