Merged
9 changes: 7 additions & 2 deletions vllm_gaudi/extension/unified.py
@@ -407,8 +407,8 @@ def hpu_tensor(tensor: torch.tensor, shape: tuple, pad_value: Union[int, float])
def create_unified_batch(req_ids: list[str], all_token_ids: torch.tensor, num_computed_tokens: torch.tensor,
num_scheduled_tokens: torch.tensor, num_prompt_tokens: torch.tensor, block_table: torch.tensor,
block_size: int, dtype: torch.dtype, bucketing_fn: Callable[[bool, int, int, int, int],
tuple[int, int, int,
int]]) -> UnifiedBatch:
tuple[int, int, int, int]],
get_dp_padding_fn: Callable[[int], int]) -> UnifiedBatch:
""" Calculate all necessary tensors needed for batch scheduling """
total_tokens = num_computed_tokens + num_scheduled_tokens
query_len = num_scheduled_tokens.sum().item()
@@ -487,6 +487,11 @@ def first_dim(t: Optional[torch.tensor]) -> int:
first_dim(logits_indices))
target_qlen, target_shared_blocks, target_unique_blocks, target_logits = bucket

target_qlen += get_dp_padding_fn(target_qlen)
target_shared_blocks += get_dp_padding_fn(target_shared_blocks)
target_unique_blocks += get_dp_padding_fn(target_unique_blocks)
target_logits += get_dp_padding_fn(target_logits)

default_causal_width = 512
fmin = torch.finfo(dtype).min
feps = torch.finfo(dtype).tiny
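Note on the unified.py hunk above: each of the four bucketed dimensions returned by bucketing_fn (target_qlen, target_shared_blocks, target_unique_blocks, target_logits) is grown by the amount reported by the injected get_dp_padding_fn, so every data-parallel rank ends up with the same padded shapes. The sketch below is a minimal standalone illustration of that step, not part of this PR; apply_dp_padding is a hypothetical helper, and the lambda stands in for the partial bound in hpu_model_runner.py.

from typing import Callable

def apply_dp_padding(bucket: tuple[int, ...],
                     get_dp_padding_fn: Callable[[int], int]) -> tuple[int, ...]:
    # Grow each bucketed dimension by the gap between the local value and the
    # maximum value reported across the data-parallel ranks.
    return tuple(dim + get_dp_padding_fn(dim) for dim in bucket)

# With dp_size == 1 the padding callback returns 0, so shapes are unchanged.
assert apply_dp_padding((128, 4, 8, 16), lambda n: 0) == (128, 4, 8, 16)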
44 changes: 25 additions & 19 deletions vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -2,6 +2,7 @@
import collections
import contextlib
import functools
from functools import partial
import itertools
import math
import os
@@ -653,6 +654,22 @@ def round_up(value: int, k: int):
return (value + k - 1) // k * k


def get_dp_padding(num_tokens: int, dp_size: int, dp_rank: int) -> int:
if dp_size == 1:
return 0

device = current_platform.device_type
group = get_dp_group().device_group

num_tokens_across_dp = [0] * dp_size
num_tokens_across_dp[dp_rank] = num_tokens
num_tokens_tensor = torch.tensor(num_tokens_across_dp, device=device, dtype=torch.int32)
torch.distributed.all_reduce(num_tokens_tensor, group=group)

max_tokens_across_dp_cpu = torch.max(num_tokens_tensor).item()
return max_tokens_across_dp_cpu - num_tokens


class HPUModelRunner(KVConnectorModelRunnerMixin):

def __init__(
@@ -856,6 +873,10 @@ def __init__(
self.defragmenter = OnlineDefragmenter()
self.debug_fwd = init_debug_logger('fwd')

self.get_dp_padding = partial(get_dp_padding,
dp_size=self.parallel_config.data_parallel_size,
dp_rank=self.parallel_config.data_parallel_rank)

assert not (self.unified_attn and not self.use_contiguous_pa), 'Unified attn requires contiguous_pa!'
assert not (self.unified_attn and not self.use_merged_prefill), 'Unified attn requires merged_prefill!'

@@ -1803,6 +1824,7 @@ def _form_unified_prefill_batch(self, contents):
dtype=self.dtype,
contiguous_kv=self.use_contiguous_pa,
bucketing_fn=self.unified_bucketing_fn,
get_dp_padding_fn=self.get_dp_padding,
)

(token_ids_t, token_positions_t, logits_indices_t, logits_groups, attn_metadata) = batch_data
@@ -2213,6 +2235,7 @@ def _prepare_unified_decode_inputs(self, num_decodes, num_scheduled_tokens, warm
dtype=self.dtype,
contiguous_kv=self.use_contiguous_pa,
bucketing_fn=self.unified_bucketing_fn,
get_dp_padding_fn=self.get_dp_padding,
)
(token_ids_t, token_positions_t, logits_indices_t, logits_groups, attn_metadata) = batch_data
decode_input_data = DecodeInputData(
@@ -2361,6 +2384,24 @@ def _check_config(self, batch_size, seq_len, num_blocks, attn_metadata, warmup_m
if not seen and not warmup_mode:
logger.warning("Configuration: %s was not warmed-up!", cfg)

def get_dp_padding(self, num_tokens: int) -> int:
dp_size = self.vllm_config.parallel_config.data_parallel_size
dp_rank = self.vllm_config.parallel_config.data_parallel_rank

if dp_size == 1:
return 0

device = current_platform.device_type
group = get_dp_group().device_group

num_tokens_across_dp = [0] * dp_size
num_tokens_across_dp[dp_rank] = num_tokens
num_tokens_tensor = torch.tensor(num_tokens_across_dp, device=device, dtype=torch.int32)
torch.distributed.all_reduce(num_tokens_tensor, group=group)

max_tokens_across_dp_cpu = torch.max(num_tokens_tensor.cpu()).item()
return max_tokens_across_dp_cpu - num_tokens

def _check_unified_config(self, attn_metadata, logits_indices, warmup_mode):
has_causal = 'c' if attn_metadata.causal_bias is not None else '-'
has_shared = 's' if attn_metadata.shared_bias is not None else '-'
@@ -2772,9 +2777,10 @@ def prepare_unified_batch(self, scheduler_output):
block_table = self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs, :max_blocks].clone()
if self.defragmenter.enabled:
block_table.apply_(self.defragmenter.resolve)

return create_unified_batch(self.input_batch.req_ids, all_token_ids, num_computed_tokens, num_scheduled_tokens,
num_prompt_tokens, block_table, self.block_size, self.dtype,
self.unified_bucketing_fn)
self.unified_bucketing_fn, self.get_dp_padding)

@torch.inference_mode()
def unified_execute_model(
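For reference, the padding amount itself comes from get_dp_padding, which this PR moves from an HPUModelRunner method to a module-level function so it can be bound with functools.partial and passed into create_unified_batch. Each rank writes its local value at its rank index of a zero vector, an all_reduce sums the contributions so every rank sees all values, and the local value is padded up to the global maximum. Below is a self-contained sketch under simplified assumptions (a plain torch.distributed group argument and a CPU tensor instead of get_dp_group().device_group and the HPU device); dp_padding is an illustrative stand-in, not the PR's code.

import torch
import torch.distributed as dist
from functools import partial

def dp_padding(num_tokens: int, dp_size: int, dp_rank: int, group=None) -> int:
    if dp_size == 1:
        return 0  # single rank: nothing to align against
    # Each rank contributes its own count at its rank index; the SUM all_reduce
    # gathers every rank's count into the same vector on all ranks.
    counts = torch.zeros(dp_size, dtype=torch.int32)
    counts[dp_rank] = num_tokens
    dist.all_reduce(counts, group=group)
    # Pad the local count up to the maximum across ranks.
    return int(counts.max().item()) - num_tokens

# Binding dp_size and dp_rank with functools.partial, as __init__ now does,
# yields the single-argument Callable[[int], int] that create_unified_batch expects.
get_dp_padding_fn = partial(dp_padding, dp_size=2, dp_rank=0)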