Implement profile_run method in HPU model runner #775
Changes from all commits
```diff
@@ -179,13 +179,34 @@ def determine_available_memory(self) -> int:
         for layer_name, layer_spec in kv_cache_spec.items():
             if isinstance(layer_spec, FullAttentionSpec):
                 dtype = layer_spec.dtype
-
-                # Use an empty tensor instead of ``None`` to force Dynamo to pass
-                # it by reference, rather by specializing on the value ``None``.
-                hpu_k_cache = torch.tensor([], dtype=dtype, device='hpu')
-                hpu_v_cache = torch.tensor([], dtype=dtype, device='hpu')
-                hpu_k_scales = torch.tensor([], dtype=dtype, device='hpu')
-                hpu_v_scales = torch.tensor([], dtype=dtype, device='hpu')
+                if dtype == torch.float8_e4m3fn and os.environ.get('QUANT_CONFIG', None) is not None and \
+                        os.environ.get('VLLM_DYNAMIC_KV_QUANT', None) is not None and not self.model_config.use_mla:
+                    create_dynamic_scales = True
+                else:
+                    create_dynamic_scales = False
+
+                # Create dummy KV cache tensors with proper shapes for profiling
+                num_blocks = 1  # Use single block for profiling
+                block_size = layer_spec.block_size
+                num_kv_heads = layer_spec.num_kv_heads
+                head_size = layer_spec.head_size
+
+                kv_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(num_blocks, block_size, num_kv_heads,
+                                                                                    head_size)
+                kv_scales_shape = kv_cache_shape[:-1] + (1, )
+
+                hpu_k_cache = torch.zeros(kv_cache_shape, dtype=dtype, device='hpu')
+                hpu_v_cache = None if self.model_config.use_mla else torch.zeros(
+                    kv_cache_shape, dtype=dtype, device='hpu')
```
xwu-intel marked this conversation as resolved.
```diff
+
+                hpu_k_scales = torch.ones(kv_scales_shape, dtype=torch.bfloat16,
+                                          device='hpu') if create_dynamic_scales else None
+                if hpu_v_cache is None:
+                    hpu_v_scales = None
+                elif create_dynamic_scales:
+                    hpu_v_scales = torch.ones(kv_scales_shape, dtype=torch.bfloat16, device='hpu')
+                else:
+                    hpu_v_scales = None
```
Comment on lines +204 to +209

Contributor

Suggested change

Contributor (Author)

@wuxun-zhang Copilot asked me to change your previously suggested style of code to the current one ...

Contributor

Okay, anyway I don't think these
```diff
 
                 kv_caches[layer_name] = (hpu_k_cache, hpu_v_cache, hpu_k_scales, hpu_v_scales)
```
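The hunk above swaps the old empty-tensor placeholders for single-block dummy KV caches, so the profiling pass sees realistically shaped tensors. Below is a minimal standalone sketch of the same idea, runnable on CPU: `make_dummy_kv_entry` is a hypothetical helper, and the flat `(num_blocks, block_size, num_kv_heads, head_size)` layout is an assumption (the real shape comes from the attention backend's `get_kv_cache_shape`). Note that the `hpu_v_scales` branching from the diff collapses here into a single conditional expression.

```python
import torch

def make_dummy_kv_entry(block_size: int, num_kv_heads: int, head_size: int,
                        dtype: torch.dtype = torch.bfloat16,
                        device: str = "cpu",
                        use_mla: bool = False,
                        create_dynamic_scales: bool = False):
    """Build a single-block (K, V, K-scales, V-scales) tuple for profiling."""
    num_blocks = 1  # one block is enough to exercise the kernels while profiling
    # Assumed layout; the real runner asks attn_backend.get_kv_cache_shape(...) instead.
    kv_cache_shape = (num_blocks, block_size, num_kv_heads, head_size)
    kv_scales_shape = kv_cache_shape[:-1] + (1,)

    k_cache = torch.zeros(kv_cache_shape, dtype=dtype, device=device)
    # MLA keeps K/V fused, so no separate V cache is created in that case.
    v_cache = None if use_mla else torch.zeros(kv_cache_shape, dtype=dtype, device=device)

    # Scales exist only when dynamic KV quantization is requested; V scales
    # additionally require a V cache -- the if/elif/else from the diff, folded
    # into conditional expressions.
    k_scales = torch.ones(kv_scales_shape, dtype=torch.bfloat16, device=device) \
        if create_dynamic_scales else None
    v_scales = torch.ones(kv_scales_shape, dtype=torch.bfloat16, device=device) \
        if (v_cache is not None and create_dynamic_scales) else None
    return k_cache, v_cache, k_scales, v_scales

# One dummy entry per attention layer, keyed like kv_caches in the diff.
kv_caches = {f"model.layers.{i}.attn": make_dummy_kv_entry(block_size=128, num_kv_heads=8, head_size=128)
             for i in range(2)}
```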
```diff
@@ -196,6 +217,14 @@ def determine_available_memory(self) -> int:
         runner_kv_caches: list[torch.Tensor] = []
         bind_kv_cache(kv_caches, self.vllm_config.compilation_config.static_forward_context, runner_kv_caches)
 
+        if self.model_runner.unified_attn:
+            # Create unified attention persistent context for profiling
+            from vllm_gaudi.extension.unified_batch import UnifiedBatchPersistentContext
+            self.model_runner.unified_attn_persistent_ctx = UnifiedBatchPersistentContext(
+                self.model_runner.max_num_batched_tokens, 0, 0, self.model_runner.block_size, dtype,
+                self.model_runner.profiler)
+
         if is_fake_hpu():
             fake_hpu_cache_alloc = 4 * 2**30  # take 4 GiB flat on fake hpu
             return fake_hpu_cache_alloc
```
```diff
@@ -229,7 +258,15 @@ def determine_available_memory(self) -> int:
                "reserved for usable KV cache")
 
         logger.info(msg)
 
+        # Clear the dummy KV cache to free up memory
+        kv_caches = {}
+        forward_context = self.vllm_config.compilation_config.static_forward_context
+        for layer_name in forward_context:
+            forward_context[layer_name].kv_cache = None
+        runner_kv_caches = []
+        gc.collect()
+
         return cache_size_bytes - dummy_block_headroom
 
     def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
```
```diff
@@ -372,7 +409,7 @@ def sleep(self, level: int = 1) -> None:
     def wake_up(self, tags: list[str] | None = None) -> None:
         """Wake up the worker from sleep mode.
         It can move the model back to HPU and/or reinitialize KV cache.
 
         Args:
             tags: Optional list of tags (kept for interface compatibility)
         """
```