9 changes: 7 additions & 2 deletions vllm_gaudi/extension/unified.py
@@ -407,8 +407,8 @@ def hpu_tensor(tensor: torch.tensor, shape: tuple, pad_value: Union[int, float])
 def create_unified_batch(req_ids: list[str], all_token_ids: torch.tensor, num_computed_tokens: torch.tensor,
                          num_scheduled_tokens: torch.tensor, num_prompt_tokens: torch.tensor, block_table: torch.tensor,
                          block_size: int, dtype: torch.dtype, bucketing_fn: Callable[[bool, int, int, int, int],
-                                                                                      tuple[int, int, int,
-                                                                                            int]]) -> UnifiedBatch:
+                                                                                      tuple[int, int, int, int]],
+                         get_dp_padding_fn: Callable[[int], int]) -> UnifiedBatch:
     """ Calculate all necessary tensors needed for batch scheduling """
     total_tokens = num_computed_tokens + num_scheduled_tokens
     query_len = num_scheduled_tokens.sum().item()
@@ -487,6 +487,11 @@ def first_dim(t: Optional[torch.tensor]) -> int:
                           first_dim(logits_indices))
     target_qlen, target_shared_blocks, target_unique_blocks, target_logits = bucket
 
+    target_qlen += get_dp_padding_fn(target_qlen)
+    target_shared_blocks += get_dp_padding_fn(target_shared_blocks)
+    target_unique_blocks += get_dp_padding_fn(target_unique_blocks)
+    target_logits += get_dp_padding_fn(target_logits)
+
     default_causal_width = 512
     fmin = torch.finfo(dtype).min
     feps = torch.finfo(dtype).tiny
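The new get_dp_padding_fn hook has a simple contract: given a bucketed value, it returns how much must be added so the value matches the largest one chosen by any data-parallel rank, keeping all ranks on identical shapes. Below is a minimal sketch of that contract, with the all-reduce replaced by a plain max over hypothetical per-rank values (make_dp_padding_fn and the numbers are illustrative only, not part of the change):

# Hypothetical stand-in for the DP all-reduce: the other ranks' bucket values
# are supplied up front instead of being gathered via torch.distributed.
def make_dp_padding_fn(value_on_each_rank: list[int]):
    def get_dp_padding(value: int) -> int:
        # Pad up to the maximum any rank selected so every rank runs the
        # same bucketed shape.
        return max(max(value_on_each_rank), value) - value
    return get_dp_padding

# This rank bucketed its query length to 128 while another rank needs 256.
get_dp_padding_fn = make_dp_padding_fn([256, 128])
target_qlen = 128
target_qlen += get_dp_padding_fn(target_qlen)
assert target_qlen == 256  # both ranks now execute a 256-token query bucket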
44 changes: 25 additions & 19 deletions vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -2,6 +2,7 @@
 import collections
 import contextlib
 import functools
+from functools import partial
 import itertools
 import math
 import os
@@ -652,6 +653,22 @@ def round_up(value: int, k: int):
     return (value + k - 1) // k * k
 
 
+def get_dp_padding(num_tokens: int, dp_size: int, dp_rank: int) -> int:
+    if dp_size == 1:
+        return 0
+
+    device = current_platform.device_type
+    group = get_dp_group().device_group
+
+    num_tokens_across_dp = [0] * dp_size
+    num_tokens_across_dp[dp_rank] = num_tokens
+    num_tokens_tensor = torch.tensor(num_tokens_across_dp, device=device, dtype=torch.int32)
+    torch.distributed.all_reduce(num_tokens_tensor, group=group)
+
+    max_tokens_across_dp_cpu = torch.max(num_tokens_tensor.cpu()).item()
+    return max_tokens_across_dp_cpu - num_tokens
+
+
 class HPUModelRunner(KVConnectorModelRunnerMixin):
 
     def __init__(
@@ -855,6 +872,10 @@ def __init__(
         self.defragmenter = OnlineDefragmenter()
         self.debug_fwd = init_debug_logger('fwd')
 
+        self.get_dp_padding = partial(get_dp_padding,
+                                      dp_size=self.parallel_config.data_parallel_size,
+                                      dp_rank=self.parallel_config.data_parallel_rank)
+
         assert not (self.unified_attn and not self.use_contiguous_pa), 'Unified attn requires contiguous_pa!'
         assert not (self.unified_attn and not self.use_merged_prefill), 'Unified attn requires merged_prefill!'
 
@@ -1802,6 +1823,7 @@ def _form_unified_prefill_batch(self, contents):
             dtype=self.dtype,
             contiguous_kv=self.use_contiguous_pa,
             bucketing_fn=self.unified_bucketing_fn,
+            get_dp_padding_fn=self.get_dp_padding,
         )
 
         (token_ids_t, token_positions_t, logits_indices_t, logits_groups, attn_metadata) = batch_data
@@ -2212,6 +2234,7 @@ def _prepare_unified_decode_inputs(self, num_decodes, num_scheduled_tokens, warm
             dtype=self.dtype,
             contiguous_kv=self.use_contiguous_pa,
             bucketing_fn=self.unified_bucketing_fn,
+            get_dp_padding_fn=self.get_dp_padding,
         )
         (token_ids_t, token_positions_t, logits_indices_t, logits_groups, attn_metadata) = batch_data
         decode_input_data = DecodeInputData(
@@ -2360,24 +2383,6 @@ def _check_config(self, batch_size, seq_len, num_blocks, attn_metadata, warmup_m
         if not seen and not warmup_mode:
             logger.warning("Configuration: %s was not warmed-up!", cfg)
 
-    def get_dp_padding(self, num_tokens: int) -> int:
-        dp_size = self.vllm_config.parallel_config.data_parallel_size
-        dp_rank = self.vllm_config.parallel_config.data_parallel_rank
-
-        if dp_size == 1:
-            return 0
-
-        device = current_platform.device_type
-        group = get_dp_group().device_group
-
-        num_tokens_across_dp = [0] * dp_size
-        num_tokens_across_dp[dp_rank] = num_tokens
-        num_tokens_tensor = torch.tensor(num_tokens_across_dp, device=device, dtype=torch.int32)
-        torch.distributed.all_reduce(num_tokens_tensor, group=group)
-
-        max_tokens_across_dp_cpu = torch.max(num_tokens_tensor.cpu()).item()
-        return max_tokens_across_dp_cpu - num_tokens
-
     def _check_unified_config(self, attn_metadata, logits_indices, warmup_mode):
         has_causal = 'c' if attn_metadata.causal_bias is not None else '-'
         has_shared = 's' if attn_metadata.shared_bias is not None else '-'
@@ -2771,9 +2776,10 @@ def prepare_unified_batch(self, scheduler_output):
         block_table = self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs, :max_blocks].clone()
         if self.defragmenter.enabled:
             block_table.apply_(self.defragmenter.resolve)
+
         return create_unified_batch(self.input_batch.req_ids, all_token_ids, num_computed_tokens, num_scheduled_tokens,
                                     num_prompt_tokens, block_table, self.block_size, self.dtype,
-                                    self.unified_bucketing_fn)
+                                    self.unified_bucketing_fn, self.get_dp_padding)
 
     @torch.inference_mode()
     def unified_execute_model(
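For reference, a runnable sketch of the functools.partial wiring added in __init__, with a stub standing in for the distributed helper (get_dp_padding_stub and all values below are made up; the real get_dp_padding all-reduces num_tokens over get_dp_group().device_group):

from functools import partial

def get_dp_padding_stub(num_tokens: int, dp_size: int, dp_rank: int) -> int:
    # Same early exit as the real helper: a single-rank DP world needs no padding.
    if dp_size == 1:
        return 0
    # The real helper all-reduces num_tokens across the DP group; here we simply
    # pretend the maximum across ranks came back as 2048.
    pretend_dp_max = 2048
    return max(pretend_dp_max, num_tokens) - num_tokens

# Bind the parallel-config fields once, mirroring self.get_dp_padding above.
get_dp_padding = partial(get_dp_padding_stub, dp_size=2, dp_rank=0)
print(get_dp_padding(1536))  # -> 512 extra tokens of padding on this rank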