35 changes: 29 additions & 6 deletions vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -159,7 +159,7 @@ def __init__(

def get_output(self) -> ModelRunnerOutput:
"""Copy the device tensors to the host and return a ModelRunnerOutput.

This function blocks until the copy is finished.
"""

@@ -2411,7 +2411,7 @@ def _prepare_input_ids(self,
scheduler_output: "SchedulerOutput",
return_index: bool = False) -> Optional[torch.Tensor]:
"""Prepare the input IDs for the current batch.

Carefully handles the `prev_sampled_token_ids` which can be cached
from the previous engine iteration, in which case those tokens on the
GPU need to be copied into the corresponding slots of input_ids."""
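# Minimal sketch of the idea described above, with illustrative names and
# shapes rather than the runner's actual implementation: token ids sampled in
# the previous iteration and still resident on the device are scattered into
# the slots of input_ids that belong to their requests.
import torch

input_ids = torch.zeros(8, dtype=torch.long)            # flattened token buffer for the batch
prev_sampled_token_ids = torch.tensor([101, 102, 103])  # cached from the previous engine iteration
target_slots = torch.tensor([0, 3, 5])                  # slots of those requests in this batch

input_ids.index_copy_(0, target_slots, prev_sampled_token_ids)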
@@ -4293,7 +4293,7 @@ def distribute_sum_evenly(self, total_sum, max_length):

def get_merged_prefill_seq_lens(self, query_len, ctx_blocks):
'''
Get separate sequence lengths from merged layout to individual
samples.
Returns list of sequence lengths (including query and context) and
context lengths.
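# Toy illustration of the bookkeeping described above. It assumes ctx_blocks
# counts KV-cache blocks of block_size tokens per merged sample; that unit is
# an assumption for the example, not taken from this diff.
block_size = 128
query_lens = [32, 64]                                      # query tokens per individual sample
ctx_blocks = [2, 1]                                        # cached context blocks per sample

ctx_lens = [b * block_size for b in ctx_blocks]            # -> [256, 128]
seq_lens = [q + c for q, c in zip(query_lens, ctx_lens)]   # -> [288, 192] (query + context)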
@@ -4827,7 +4827,30 @@ def __del__(self):

@torch.inference_mode()
def profile_run(self) -> None:
# Skip profile run on decode instances
if (self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_consumer):
return

max_batch_size = max(1, min(self.max_num_seqs, self.max_num_tokens // self.max_model_len))
if self.supports_mm_inputs:
# Using batch_size 1 for profiling multimodal models
max_batch_size = 1

# Run a simple profile scenario using the existing dummy run infrastructure
if self.unified_attn:
# (query_len, shared_ctx_len, unique_ctx_len, is_causal) for unified attention
unified_cfg = (self.max_model_len * max_batch_size, 0, 0, True)
self._prepare_dummy_unified_scenario(unified_cfg)
else:
if self.max_model_len < self.max_num_batched_tokens:
prompt_cfg = (max_batch_size, self.max_model_len, 0)
else:
# Assume bs=1 with max context for profile run
prompt_cfg = (1, self.max_num_batched_tokens,
(self.max_model_len - self.max_num_batched_tokens + self.block_size - 1) //
self.block_size)
decode_cfg = None
self._prepare_dummy_scenario(prompt_cfg, decode_cfg)
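# Worked example of the ceiling division used to build prompt_cfg above;
# the numbers are illustrative, not defaults taken from this repository.
max_model_len, max_num_batched_tokens, block_size = 4352, 4096, 128
ctx_blocks = (max_model_len - max_num_batched_tokens + block_size - 1) // block_size
assert ctx_blocks == 2  # the 256 leftover context tokens span two 128-token blocks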

def _dummy_run(self, max_num_batched_tokens: int) -> None:
assert max_num_batched_tokens == 1
@@ -5344,7 +5367,7 @@ class TensorTuple(tuple):
"""
A tuple subclass designed to hold nested torch.Tensors, providing
.shape and .device properties.

It ensures that the nested structure is not ragged and that all
contained tensors reside on the same device.
"""
@@ -5388,7 +5411,7 @@ def dtype(self):
class HPUAttentionMetadataProcessor:
"""
Processor class for post-processing HPU attention metadata.

This class takes already-built attention metadata and augments it with
additional tensors such as attention bias masks and block mappings that
are required for efficient attention computation on HPU. It does NOT build
53 changes: 45 additions & 8 deletions vllm_gaudi/v1/worker/hpu_worker.py
@@ -179,13 +179,34 @@ def determine_available_memory(self) -> int:
for layer_name, layer_spec in kv_cache_spec.items():
if isinstance(layer_spec, FullAttentionSpec):
dtype = layer_spec.dtype

# Use an empty tensor instead of `None` to force Dynamo to pass
# it by reference, rather than by specializing on the value `None`.
hpu_k_cache = torch.tensor([], dtype=dtype, device='hpu')
hpu_v_cache = torch.tensor([], dtype=dtype, device='hpu')
hpu_k_scales = torch.tensor([], dtype=dtype, device='hpu')
hpu_v_scales = torch.tensor([], dtype=dtype, device='hpu')
if dtype == torch.float8_e4m3fn and os.environ.get('QUANT_CONFIG', None) is not None and \
os.environ.get('VLLM_DYNAMIC_KV_QUANT', None) is not None and not self.model_config.use_mla:
create_dynamic_scales = True
else:
create_dynamic_scales = False

# Create dummy KV cache tensors with proper shapes for profiling
num_blocks = 1 # Use single block for profiling
block_size = layer_spec.block_size
num_kv_heads = layer_spec.num_kv_heads
head_size = layer_spec.head_size

kv_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(num_blocks, block_size, num_kv_heads,
head_size)
kv_scales_shape = kv_cache_shape[:-1] + (1, )

hpu_k_cache = torch.zeros(kv_cache_shape, dtype=dtype, device='hpu')
hpu_v_cache = None if self.model_config.use_mla else torch.zeros(
kv_cache_shape, dtype=dtype, device='hpu')

hpu_k_scales = torch.ones(kv_scales_shape, dtype=torch.bfloat16,
device='hpu') if create_dynamic_scales else None
if hpu_v_cache is None:
hpu_v_scales = None
elif create_dynamic_scales:
hpu_v_scales = torch.ones(kv_scales_shape, dtype=torch.bfloat16, device='hpu')
else:
hpu_v_scales = None
Comment on lines +204 to +209
Contributor:
Suggested change (replace the if/else above with):
hpu_v_scales = torch.ones(kv_scales_shape, dtype=torch.bfloat16, device='hpu') if (not self.model_config.use_mla and create_dynamic_scales) else None

Contributor (author):
@wuxun-zhang Copilot asked me to change your previously suggested code style to the current one ...

Contributor:
Okay, anyway I don't think these if-else blocks are good practice...


kv_caches[layer_name] = (hpu_k_cache, hpu_v_cache, hpu_k_scales, hpu_v_scales)

@@ -196,6 +217,14 @@ def determine_available_memory(self) -> int:

runner_kv_caches: list[torch.Tensor] = []
bind_kv_cache(kv_caches, self.vllm_config.compilation_config.static_forward_context, runner_kv_caches)

if self.model_runner.unified_attn:
# Create unified attention persistent context for profiling
from vllm_gaudi.extension.unified_batch import UnifiedBatchPersistentContext
self.model_runner.unified_attn_persistent_ctx = UnifiedBatchPersistentContext(
self.model_runner.max_num_batched_tokens, 0, 0, self.model_runner.block_size, dtype,
self.model_runner.profiler)

if is_fake_hpu():
fake_hpu_cache_alloc = 4 * 2**30 # take 4 GiB flat on fake hpu
return fake_hpu_cache_alloc
@@ -229,7 +258,15 @@ def determine_available_memory(self) -> int:
"reserved for usable KV cache")

logger.info(msg)

# Clear the dummy KV cache to free up memory
kv_caches = {}
forward_context = self.vllm_config.compilation_config.static_forward_context
for layer_name in forward_context:
forward_context[layer_name].kv_cache = None
runner_kv_caches = []
gc.collect()

return cache_size_bytes - dummy_block_headroom

def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
@@ -372,7 +409,7 @@ def sleep(self, level: int = 1) -> None:
def wake_up(self, tags: list[str] | None = None) -> None:
"""Wake up the worker from sleep mode.
It can move the model back to HPU and/or reinitialize KV cache.

Args:
tags: Optional list of tags (kept for interface compatibility)
"""