From fbd02684d7a95faab51fb6d98006abc6832d281a Mon Sep 17 00:00:00 2001
From: Michal Adamczyk
Date: Mon, 22 Sep 2025 15:07:56 +0300
Subject: [PATCH] Fix calculating used blocks

Signed-off-by: Michal Adamczyk
---
 vllm_gaudi/extension/unified.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/vllm_gaudi/extension/unified.py b/vllm_gaudi/extension/unified.py
index 9d183812..bb699b7d 100644
--- a/vllm_gaudi/extension/unified.py
+++ b/vllm_gaudi/extension/unified.py
@@ -349,16 +349,16 @@ class Context:
     block_usages: torch.tensor
 
     @staticmethod
-    def create(num_computed_tokens: torch.tensor, block_table: torch.tensor, block_size: int) -> Self:
+    def create(total_tokens: torch.tensor, block_table: torch.tensor, block_size: int) -> Self:
         """ Create a new Context obj """
-        num_ctx_blocks = (num_computed_tokens + block_size - 1) // block_size
+        num_ctx_blocks = (total_tokens + block_size - 1) // block_size
         if num_ctx_blocks.sum() <= 0:
             return None
         group_ids, group_offsets = indices_and_offsets(num_ctx_blocks)
         block_ids = fetch_2d(block_table, group_ids, group_offsets)
         block_usages = torch.clamp(
-            num_computed_tokens.index_select(0, group_ids) - group_offsets * block_size + 1, 1, block_size)
+            total_tokens.index_select(0, group_ids) - group_offsets * block_size + 1, 1, block_size)
         ctx = Context(group_ids, group_offsets, block_ids, block_usages)
 
         all_shapes = [v.shape for v in ctx._values() if torch.is_tensor(v)]
@@ -408,7 +408,9 @@ def create_unified_batch(req_ids: list[str], all_token_ids: torch.tensor, num_co
     """ Calculate all necessary tensors needed for batch scheduling """
     total_tokens = num_computed_tokens + num_scheduled_tokens
     query_len = num_scheduled_tokens.sum().item()
-    contains_prompts = torch.any(total_tokens <= num_prompt_tokens).item()
+    is_prompt = total_tokens <= num_prompt_tokens
+    cached_tokens = num_computed_tokens + torch.where(is_prompt, 0, num_scheduled_tokens)
+    contains_prompts = torch.any(is_prompt).item()
     num_output_tokens = total_tokens - num_prompt_tokens + 1
     num_output_tokens = torch.clamp(num_output_tokens, torch.zeros_like(num_scheduled_tokens), num_scheduled_tokens)
     group_starts = torch.cumsum(num_scheduled_tokens, dim=0) - num_scheduled_tokens
@@ -440,7 +442,7 @@ def first_dim(t: Optional[torch.tensor]) -> int:
 
     if contains_prompts:
         causal_bias = create_causal_bias(token_groups, token_positions, dtype)
-    ctx = Context.create(num_computed_tokens, block_table, block_size)
+    ctx = Context.create(cached_tokens, block_table, block_size)
    if ctx:
         shared_ctx, unique_ctx = ctx.split(num_scheduled_tokens)
         if shared_ctx:
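
---

A minimal standalone sketch of the corrected block accounting, not part of the
commit itself: the batch values below are invented for illustration, and the
tensor names follow the patch. The apparent intent of the change is that for a
request already past its prompt (a decode step), the tokens being scheduled
this step also occupy KV-cache blocks, so the context block count should be
derived from num_computed_tokens plus those scheduled tokens rather than from
num_computed_tokens alone:

    import torch

    block_size = 4

    # Invented batch: req 0 is a chunked prefill, reqs 1-2 are decode steps.
    num_prompt_tokens    = torch.tensor([12, 6, 5])
    num_computed_tokens  = torch.tensor([ 4, 8, 7])
    num_scheduled_tokens = torch.tensor([ 4, 1, 1])

    total_tokens = num_computed_tokens + num_scheduled_tokens
    # A request is still in prefill while its total stays within the prompt.
    is_prompt = total_tokens <= num_prompt_tokens   # [True, False, False]

    # As in the patch: decode requests count their in-flight scheduled
    # tokens as cached, prefill requests do not.
    cached_tokens = num_computed_tokens + torch.where(is_prompt, 0, num_scheduled_tokens)

    # Context blocks before and after the fix (ceil division by block_size).
    before = (num_computed_tokens + block_size - 1) // block_size
    after  = (cached_tokens + block_size - 1) // block_size
    print(before.tolist())  # [1, 2, 2]
    print(after.tolist())   # [1, 3, 2]

With these numbers, request 1 spills into a third block once its in-flight
token is counted; the previous Context.create(num_computed_tokens, ...) call
undercounted it as two blocks, while request 2 (still within block two) is
unaffected.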