diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py
index 71cb93f46..24488d078 100644
--- a/vllm_gaudi/v1/worker/hpu_model_runner.py
+++ b/vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -159,7 +159,7 @@ def __init__(
 
     def get_output(self) -> ModelRunnerOutput:
         """Copy the device tensors to the host and return a ModelRunnerOutput.
-        
+
         This function blocks until the copy is finished.
         """
 
@@ -2411,7 +2411,7 @@ def _prepare_input_ids(self, scheduler_output: "SchedulerOutput",
                            return_index: bool = False) -> Optional[torch.Tensor]:
         """Prepare the input IDs for the current batch.
-        
+
         Carefully handles the `prev_sampled_token_ids` which can be cached
         from the previous engine iteration, in which case those tokens on the
         GPU need to be copied into the corresponding slots into input_ids."""
 
@@ -4293,7 +4293,7 @@ def distribute_sum_evenly(self, total_sum, max_length):
 
     def get_merged_prefill_seq_lens(self, query_len, ctx_blocks):
         '''
-        Get seperate sequence lengths from merged layout to individual 
+        Get separate sequence lengths from merged layout to individual
        samples.
        Returns list of sequence length (including query and context)
        and context lengths.
@@ -4827,7 +4827,30 @@ def __del__(self):
 
     @torch.inference_mode()
     def profile_run(self) -> None:
-        return
+        # Skip profile run on decode instances
+        if (self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_consumer):
+            return
+
+        max_batch_size = max(1, min(self.max_num_seqs, self.max_num_tokens // self.max_model_len))
+        if self.supports_mm_inputs:
+            # Using batch_size 1 for profiling multimodal models
+            max_batch_size = 1
+
+        # Run a simple profile scenario using the existing dummy run infrastructure
+        if self.unified_attn:
+            # (query_len, shared_ctx_len, unique_ctx_len, is_causal) for unified attention
+            unified_cfg = (self.max_model_len * max_batch_size, 0, 0, True)
+            self._prepare_dummy_unified_scenario(unified_cfg)
+        else:
+            if self.max_model_len < self.max_num_batched_tokens:
+                prompt_cfg = (max_batch_size, self.max_model_len, 0)
+            else:
+                # Assume bs=1 with max context for profile run
+                prompt_cfg = (1, self.max_num_batched_tokens,
+                              (self.max_model_len - self.max_num_batched_tokens + self.block_size - 1) //
+                              self.block_size)
+            decode_cfg = None
+            self._prepare_dummy_scenario(prompt_cfg, decode_cfg)
 
     def _dummy_run(self, max_num_batched_tokens: int) -> None:
         assert max_num_batched_tokens == 1
 
@@ -5344,7 +5367,7 @@ class TensorTuple(tuple):
     """
     A tuple subclass designed to hold nested torch.Tensors,
     providing .shape and .device properties.
-    
+
     It ensures that the nested structure is not ragged and that all
     contained tensors reside on the same device.
     """
 
@@ -5388,7 +5411,7 @@ def dtype(self):
 class HPUAttentionMetadataProcessor:
     """
     Processor class for post-processing HPU attention metadata.
-    
+
     This class takes already-built attention metadata and augments it with
     additional tensors such as attention bias masks and block mappings that
     are required for efficient attention computation on HPU. It does NOT build
diff --git a/vllm_gaudi/v1/worker/hpu_worker.py b/vllm_gaudi/v1/worker/hpu_worker.py
index aaa9bb9ea..1d37289af 100644
--- a/vllm_gaudi/v1/worker/hpu_worker.py
+++ b/vllm_gaudi/v1/worker/hpu_worker.py
@@ -179,13 +179,34 @@ def determine_available_memory(self) -> int:
         for layer_name, layer_spec in kv_cache_spec.items():
             if isinstance(layer_spec, FullAttentionSpec):
                 dtype = layer_spec.dtype
-
-                # Use an empty tensor instead of `None`` to force Dynamo to pass
-                # it by reference, rather by specializing on the value ``None``.
-                hpu_k_cache = torch.tensor([], dtype=dtype, device='hpu')
-                hpu_v_cache = torch.tensor([], dtype=dtype, device='hpu')
-                hpu_k_scales = torch.tensor([], dtype=dtype, device='hpu')
-                hpu_v_scales = torch.tensor([], dtype=dtype, device='hpu')
+                if dtype == torch.float8_e4m3fn and os.environ.get('QUANT_CONFIG', None) is not None and \
+                        os.environ.get('VLLM_DYNAMIC_KV_QUANT', None) is not None and not self.model_config.use_mla:
+                    create_dynamic_scales = True
+                else:
+                    create_dynamic_scales = False
+
+                # Create dummy KV cache tensors with proper shapes for profiling
+                num_blocks = 1  # Use single block for profiling
+                block_size = layer_spec.block_size
+                num_kv_heads = layer_spec.num_kv_heads
+                head_size = layer_spec.head_size
+
+                kv_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(num_blocks, block_size, num_kv_heads,
+                                                                                   head_size)
+                kv_scales_shape = kv_cache_shape[:-1] + (1, )
+
+                hpu_k_cache = torch.zeros(kv_cache_shape, dtype=dtype, device='hpu')
+                hpu_v_cache = None if self.model_config.use_mla else torch.zeros(
+                    kv_cache_shape, dtype=dtype, device='hpu')
+
+                hpu_k_scales = torch.ones(kv_scales_shape, dtype=torch.bfloat16,
+                                          device='hpu') if create_dynamic_scales else None
+                if hpu_v_cache is None:
+                    hpu_v_scales = None
+                elif create_dynamic_scales:
+                    hpu_v_scales = torch.ones(kv_scales_shape, dtype=torch.bfloat16, device='hpu')
+                else:
+                    hpu_v_scales = None
 
                 kv_caches[layer_name] = (hpu_k_cache, hpu_v_cache, hpu_k_scales, hpu_v_scales)
 
@@ -196,6 +217,14 @@ def determine_available_memory(self) -> int:
 
         runner_kv_caches: list[torch.Tensor] = []
         bind_kv_cache(kv_caches, self.vllm_config.compilation_config.static_forward_context, runner_kv_caches)
+
+        if self.model_runner.unified_attn:
+            # Create unified attention persistent context for profiling
+            from vllm_gaudi.extension.unified_batch import UnifiedBatchPersistentContext
+            self.model_runner.unified_attn_persistent_ctx = UnifiedBatchPersistentContext(
+                self.model_runner.max_num_batched_tokens, 0, 0, self.model_runner.block_size, dtype,
+                self.model_runner.profiler)
+
         if is_fake_hpu():
             fake_hpu_cache_alloc = 4 * 2**30  # take 4 GiB flat on fake hpu
             return fake_hpu_cache_alloc
 
@@ -229,7 +258,15 @@ def determine_available_memory(self) -> int:
               "reserved for usable KV cache")
         logger.info(msg)
+
+        # Clear the dummy KV cache to free up memory
+        kv_caches = {}
+        forward_context = self.vllm_config.compilation_config.static_forward_context
+        for layer_name in forward_context:
+            forward_context[layer_name].kv_cache = None
+        runner_kv_caches = []
         gc.collect()
+
         return cache_size_bytes - dummy_block_headroom
 
     def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
 
@@ -372,7 +409,7 @@ def sleep(self, level: int = 1) -> None:
     def wake_up(self, tags: list[str] | None = None) -> None:
         """Wake up the worker from sleep mode. It can move the model
         back to HPU and/or reinitialize KV cache.
-        
+
         Args:
            tags: Optional list of tags (kept for interface compatibility)
         """
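
A note on the new `profile_run` path in `hpu_model_runner.py`: the prompt configuration it hands to `_prepare_dummy_scenario` depends on whether a full-length sequence fits into a single scheduling step. The sketch below reproduces only that selection logic in isolation; the helper name `choose_profile_cfg`, the example numbers, and the `(batch_size, query_len, ctx_blocks)` reading of `prompt_cfg` are assumptions for illustration, not taken from the patch.

```python
# Standalone sketch of the profile-run configuration choice; names and numbers
# here are illustrative assumptions rather than values from any real deployment.

def choose_profile_cfg(max_num_seqs: int, max_num_tokens: int,
                       max_num_batched_tokens: int, max_model_len: int,
                       block_size: int) -> tuple[int, int, int]:
    """Return an assumed (batch_size, query_len, ctx_blocks) tuple for the dummy prompt."""
    max_batch_size = max(1, min(max_num_seqs, max_num_tokens // max_model_len))
    if max_model_len < max_num_batched_tokens:
        # Full-length sequences fit into one scheduling step: no extra context blocks.
        return (max_batch_size, max_model_len, 0)
    # Otherwise profile bs=1 with the longest schedulable query, supplying the
    # remaining model context as ceil-divided context blocks.
    remaining = max_model_len - max_num_batched_tokens
    ctx_blocks = (remaining + block_size - 1) // block_size
    return (1, max_num_batched_tokens, ctx_blocks)


if __name__ == "__main__":
    # e.g. a 32k-context model scheduled 8k tokens at a time with 128-token blocks
    print(choose_profile_cfg(max_num_seqs=128, max_num_tokens=8192,
                             max_num_batched_tokens=8192, max_model_len=32768,
                             block_size=128))
    # -> (1, 8192, 192): 24576 leftover context tokens / 128-token blocks = 192
```

The ceil division mirrors the expression in the patch, so the dummy scenario always covers the full `max_model_len` worth of context blocks even when the length is not a multiple of `block_size`.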
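
Similarly, the dummy KV cache built in `determine_available_memory` now allocates one real block per layer (plus optional bfloat16 scale tensors for dynamic FP8 KV quantization) instead of empty placeholder tensors. The following is a minimal, CPU-runnable sketch of the intended shapes, assuming a `(num_blocks, block_size, num_kv_heads, head_size)` layout and made-up default sizes; in the patch the shape comes from `attn_backend.get_kv_cache_shape()` and the tensors live on the `hpu` device.

```python
# Minimal sketch of the single-block dummy KV cache entry; the layout, the helper
# name make_dummy_kv_entry, and the default sizes are assumptions for illustration.
import torch


def make_dummy_kv_entry(block_size: int = 128, num_kv_heads: int = 8,
                        head_size: int = 128, dtype: torch.dtype = torch.bfloat16,
                        dynamic_scales: bool = False, use_mla: bool = False,
                        device: str = "cpu"):
    num_blocks = 1  # one real block is enough to trace the graphs during profiling
    kv_cache_shape = (num_blocks, block_size, num_kv_heads, head_size)
    kv_scales_shape = kv_cache_shape[:-1] + (1,)  # one scale per (block, slot, head)

    k_cache = torch.zeros(kv_cache_shape, dtype=dtype, device=device)
    # With MLA there is no separate V cache, so neither a V tensor nor a V scale is made.
    v_cache = None if use_mla else torch.zeros(kv_cache_shape, dtype=dtype, device=device)

    k_scales = (torch.ones(kv_scales_shape, dtype=torch.bfloat16, device=device)
                if dynamic_scales else None)
    v_scales = (torch.ones(kv_scales_shape, dtype=torch.bfloat16, device=device)
                if dynamic_scales and v_cache is not None else None)
    return k_cache, v_cache, k_scales, v_scales


k, v, ks, vs = make_dummy_kv_entry(dynamic_scales=True)
print(k.shape, v.shape, ks.shape, vs.shape)
# torch.Size([1, 128, 8, 128]) for K/V and torch.Size([1, 128, 8, 1]) for the scales
```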