From 4e6262145f897b5b0ab83753e3a4689e77452b7d Mon Sep 17 00:00:00 2001
From: Wuxun Zhang
Date: Tue, 23 Sep 2025 10:27:08 +0300
Subject: [PATCH 1/2] Support DP for unified atten

Signed-off-by: Wuxun Zhang
---
 vllm_gaudi/extension/unified.py          |  9 +++++--
 vllm_gaudi/v1/worker/hpu_model_runner.py | 31 +++++++++++++++---------
 2 files changed, 26 insertions(+), 14 deletions(-)

diff --git a/vllm_gaudi/extension/unified.py b/vllm_gaudi/extension/unified.py
index 9d1838122..58a03b1c0 100644
--- a/vllm_gaudi/extension/unified.py
+++ b/vllm_gaudi/extension/unified.py
@@ -403,8 +403,8 @@ def hpu_tensor(tensor: torch.tensor, shape: tuple, pad_value: Union[int, float])
 def create_unified_batch(req_ids: list[str], all_token_ids: torch.tensor, num_computed_tokens: torch.tensor,
                          num_scheduled_tokens: torch.tensor, num_prompt_tokens: torch.tensor, block_table: torch.tensor,
                          block_size: int, dtype: torch.dtype, bucketing_fn: Callable[[bool, int, int, int, int],
-                                                                                     tuple[int, int, int,
-                                                                                           int]]) -> UnifiedBatch:
+                                                                                     tuple[int, int, int, int]],
+                         get_dp_padding_fn: Callable[[int], int]) -> UnifiedBatch:
     """ Calculate all necessary tensors needed for batch scheduling """
     total_tokens = num_computed_tokens + num_scheduled_tokens
     query_len = num_scheduled_tokens.sum().item()
@@ -481,6 +481,11 @@ def first_dim(t: Optional[torch.tensor]) -> int:
                           first_dim(logits_indices))
     target_qlen, target_shared_blocks, target_unique_blocks, target_logits = bucket
 
+    target_qlen += get_dp_padding_fn(target_qlen)
+    target_shared_blocks += get_dp_padding_fn(target_shared_blocks)
+    target_unique_blocks += get_dp_padding_fn(target_unique_blocks)
+    target_logits += get_dp_padding_fn(target_logits)
+
     default_causal_width = 512
     fmin = torch.finfo(dtype).min
     feps = torch.finfo(dtype).tiny
diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py
index 1364c39db..6fb596362 100644
--- a/vllm_gaudi/v1/worker/hpu_model_runner.py
+++ b/vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -2,6 +2,7 @@
 import collections
 import contextlib
 import functools
+from functools import partial
 import itertools
 import math
 import os
@@ -645,6 +646,15 @@ def round_up(value: int, k: int):
     return (value + k - 1) // k * k
 
 
+def get_dp_padding(num_tokens: int, dp_size: int, dp_rank: int, DPMetaData: DPMetadata) -> int:
+    if dp_size == 1:
+        return 0
+
+    num_tokens_across_dp = DPMetaData.num_tokens_across_dp(num_tokens, dp_size, dp_rank)
+    max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
+    return max_tokens_across_dp_cpu - num_tokens
+
+
 class HPUModelRunner(KVConnectorModelRunnerMixin):
 
     def __init__(
@@ -846,6 +856,11 @@ def __init__(
         self.defragmenter = OnlineDefragmenter()
         self.debug_fwd = init_debug_logger('fwd')
 
+        self.get_dp_padding = partial(get_dp_padding,
+                                      dp_size=self.parallel_config.data_parallel_size,
+                                      dp_rank=self.parallel_config.data_parallel_rank,
+                                      DPMetaData=DPMetadata)
+
         assert not (self.unified_attn and not self.use_contiguous_pa), 'Unified attn requires contiguous_pa!'
         assert not (self.unified_attn and not self.use_merged_prefill), 'Unified attn requires merged_prefill!'
 
@@ -1779,6 +1794,7 @@ def _form_unified_prefill_batch(self, contents):
             dtype=self.dtype,
             contiguous_kv=self.use_contiguous_pa,
             bucketing_fn=self.unified_bucketing_fn,
+            get_dp_padding_fn=self.get_dp_padding,
         )
         (token_ids_t, token_positions_t, logits_indices_t, logits_groups, attn_metadata) = batch_data
         logits_requests = [req_ids[lg] for lg in logits_groups]
@@ -2182,6 +2198,7 @@ def _prepare_unified_decode_inputs(self, num_decodes, num_scheduled_tokens) -> D
             dtype=self.dtype,
             contiguous_kv=self.use_contiguous_pa,
             bucketing_fn=self.unified_bucketing_fn,
+            get_dp_padding_fn=self.get_dp_padding,
         )
         (token_ids_t, token_positions_t, logits_indices_t, logits_groups, attn_metadata) = batch_data
         return DecodeInputData(
@@ -2315,17 +2332,6 @@ def _check_config(self, batch_size, seq_len, num_blocks, attn_metadata, warmup_m
         if not seen and not warmup_mode:
             logger.warning("Configuration: %s was not warmed-up!", cfg)
 
-    def get_dp_padding(self, num_tokens: int) -> int:
-        dp_size = self.vllm_config.parallel_config.data_parallel_size
-        dp_rank = self.vllm_config.parallel_config.data_parallel_rank
-
-        if dp_size == 1:
-            return 0
-
-        num_tokens_across_dp = DPMetadata.num_tokens_across_dp(num_tokens, dp_size, dp_rank)
-        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
-        return max_tokens_across_dp_cpu - num_tokens
-
     def _check_unified_config(self, attn_metadata, logits_indices, warmup_mode):
         has_causal = 'c' if attn_metadata.causal_bias is not None else '-'
         has_shared = 's' if attn_metadata.shared_bias is not None else '-'
@@ -2716,9 +2722,10 @@ def prepare_unified_batch(self, scheduler_output):
         block_table = self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs, :max_blocks].clone()
         if self.defragmenter.enabled:
             block_table.apply_(self.defragmenter.resolve)
+
         return create_unified_batch(self.input_batch.req_ids, all_token_ids, num_computed_tokens,
                                     num_scheduled_tokens, num_prompt_tokens, block_table, self.block_size, self.dtype,
-                                    self.unified_bucketing_fn)
+                                    self.unified_bucketing_fn, self.get_dp_padding)
 
     @torch.inference_mode()
     def unified_execute_model(

From 60a8a9cfc39f0d62fcb86a79f6669c3095922de9 Mon Sep 17 00:00:00 2001
From: Wuxun Zhang
Date: Tue, 21 Oct 2025 05:26:49 +0300
Subject: [PATCH 2/2] remove unnecessary .cpu()

Signed-off-by: Wuxun Zhang
---
 vllm_gaudi/v1/worker/hpu_model_runner.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py
index 543654fb0..8b1d98bbd 100644
--- a/vllm_gaudi/v1/worker/hpu_model_runner.py
+++ b/vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -665,7 +665,7 @@ def get_dp_padding(num_tokens: int, dp_size: int, dp_rank: int) -> int:
 
     num_tokens_tensor = torch.tensor(num_tokens_across_dp, device=device, dtype=torch.int32)
     torch.distributed.all_reduce(num_tokens_tensor, group=group)
-    max_tokens_across_dp_cpu = torch.max(num_tokens_tensor.cpu()).item()
+    max_tokens_across_dp_cpu = torch.max(num_tokens_tensor).item()
    return max_tokens_across_dp_cpu - num_tokens
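
The padding rule both patches wire into `create_unified_batch` is: every data-parallel rank rounds each bucketed dimension (query length, shared blocks, unique blocks, logits) up to the maximum value across ranks, so all ranks select the same padded shape and the collectives inside the model stay in lock-step. The sketch below is a minimal, self-contained illustration of that rule only; the helper name `dp_padding` and the example token counts are hypothetical and not part of the diffs, which obtain the per-rank counts via `DPMetadata.num_tokens_across_dp` (patch 1) or an all_reduce (patch 2).

    def dp_padding(num_tokens: int, num_tokens_across_dp: list[int]) -> int:
        """Padding this rank must add to match the largest per-rank count."""
        if len(num_tokens_across_dp) == 1:
            # Single DP rank: shapes already agree, nothing to pad.
            return 0
        return max(num_tokens_across_dp) - num_tokens

    # Hypothetical scheduled-token counts for a 4-way DP setup.
    counts = [96, 128, 64, 80]
    for rank, local in enumerate(counts):
        pad = dp_padding(local, counts)
        print(f"rank {rank}: {local} tokens + {pad} padding = {local + pad}")

Dropping the intermediate `.cpu()` in the second patch does not change this result: `torch.max(num_tokens_tensor).item()` already performs the device-to-host transfer, so the extra copy was redundant.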