diff --git a/tests/unit_tests/worker/test_hpu_model_runner.py b/tests/unit_tests/worker/test_hpu_model_runner.py
index e10cff8b5..9c6da2a1d 100644
--- a/tests/unit_tests/worker/test_hpu_model_runner.py
+++ b/tests/unit_tests/worker/test_hpu_model_runner.py
@@ -390,7 +390,6 @@ def test_reload_weights_before_load_model(model_runner):
     model_runner.reload_weights()
 
 
-@pytest.mark.xfail(reason="KV sharing doesn't currently work on HPU")
 def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(default_vllm_config: None):
     torch.set_default_dtype(torch.bfloat16)
     layer_0 = "model.layers.0.self_attn.attn"
@@ -418,7 +417,6 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(default_vllm_c
     assert fwd_context is not None
 
 
-@pytest.mark.xfail(reason="KV sharing doesn't currently work on HPU")
 def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(default_vllm_config: None):
     torch.set_default_dtype(torch.bfloat16)
     layer_0 = "model.layers.0.self_attn.attn"
@@ -448,7 +446,6 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(default_vllm_confi
     assert fwd_context is not None
 
 
-@pytest.mark.xfail(reason="KV sharing doesn't currently work on HPU")
 def test_init_kv_cache_with_kv_sharing_target_same_as_current(default_vllm_config: None):
     torch.set_default_dtype(torch.bfloat16)
     layer_0 = "model.layers.0.self_attn.attn"
@@ -544,7 +541,6 @@ def test_init_kv_cache_without_kv_sharing(default_vllm_config: None):
     assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
 
 
-@pytest.mark.xfail(reason="KV sharing doesn't currently work on HPU")
 def test_init_kv_cache_with_kv_sharing_valid(default_vllm_config: None):
     torch.set_default_dtype(torch.bfloat16)
     layer_0 = "model.layers.0.self_attn.attn"
@@ -580,10 +576,11 @@ def test_init_kv_cache_with_kv_sharing_valid(default_vllm_config: None):
     assert runner.shared_kv_cache_layers[layer_1] == layer_0
 
     available_memory = 20 * GiB_bytes
-    # page size for layer 0's kv_cache_spec is 32KB
+    # page size for layer 0's kv_cache_spec is 256KB
     # with KV sharing, we can allocate (available_mem//page_size//1) blocks
     # which is twice as many as without KV sharing
-    num_expected_blocks = 655360  # 20GB / 32KB
+    page_size = 128 * 8 * 64 * 2 * 2  # 128 for block_size, 2 for K+V, 2 for bfloat16
+    num_expected_blocks = available_memory / page_size  # 20GB / 256KB
     kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec], [available_memory])[0]
     assert kv_cache_config.num_blocks == num_expected_blocks
     assert len(kv_cache_config.kv_cache_tensors) == 1
diff --git a/vllm_gaudi/attention/backends/hpu_attn.py b/vllm_gaudi/attention/backends/hpu_attn.py
index 5bd6460ed..12923d82e 100644
--- a/vllm_gaudi/attention/backends/hpu_attn.py
+++ b/vllm_gaudi/attention/backends/hpu_attn.py
@@ -430,8 +430,10 @@ def __init__(
         use_irope: bool = False,
     ) -> None:
         super(AttentionImpl, self).__init__()
+        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
         if kv_sharing_target_layer_name is not None:
-            raise NotImplementedError("KV sharing is not currently supported on HPU.")
+            logger.info("[KV sharing] HPUAttentionImpl initialized with kv_sharing_target_layer_name: %s",
+                        self.kv_sharing_target_layer_name)
         if use_irope:
             logger.warning_once("Using irope in HPU is not supported yet, it will fall back "
                                 "to global attention for long context.")
@@ -572,22 +574,22 @@ def forward(
         if kv_cache is not None and isinstance(kv_cache, tuple):
             key_cache, value_cache, k_scales, v_scales = \
                 HPUPagedAttention.split_kv_cache(kv_cache, self.num_kv_heads, self.head_size)
-
-            # Reshape the input keys and values and store them in the cache.
-            # If kv_cache is not provided, the new key and value tensors are
-            # not cached. This happens during the initial memory profiling run.
-            key_cache = self.k_cache(key,
-                                     key_cache,
-                                     slot_mapping,
-                                     scales=k_scales,
-                                     block_size=attn_metadata.block_size,
-                                     is_prompt=attn_metadata.is_prompt)
-            value_cache = self.v_cache(value,
-                                       value_cache,
-                                       slot_mapping,
-                                       scales=v_scales,
-                                       block_size=attn_metadata.block_size,
-                                       is_prompt=attn_metadata.is_prompt)
+            if self.kv_sharing_target_layer_name is None:
+                # Reshape the input keys and values and store them in the cache.
+                # If kv_cache is not provided, the new key and value tensors are
+                # not cached. This happens during the initial memory profiling run.
+                key_cache = self.k_cache(key,
+                                         key_cache,
+                                         slot_mapping,
+                                         scales=k_scales,
+                                         block_size=attn_metadata.block_size,
+                                         is_prompt=attn_metadata.is_prompt)
+                value_cache = self.v_cache(value,
+                                           value_cache,
+                                           slot_mapping,
+                                           scales=v_scales,
+                                           block_size=attn_metadata.block_size,
+                                           is_prompt=attn_metadata.is_prompt)
 
         if attn_metadata.is_prompt:
             # Prompt run.
diff --git a/vllm_gaudi/extension/ops.py b/vllm_gaudi/extension/ops.py
index 96e38f3bc..cd0ef7d1b 100644
--- a/vllm_gaudi/extension/ops.py
+++ b/vllm_gaudi/extension/ops.py
@@ -366,6 +366,7 @@ def _fsdpa_prompt_attention(query: torch.Tensor,
            query, key, value, attn_bias, 0.0, is_causal, scale, softmax_mode, recompute_mode,
            valid_seq_lengths, padding_side
        ]
+        args += [window_size] if window_size else []
        attn_weights = fsdpa_op(*args)
 
 
diff --git a/vllm_gaudi/models/interfaces.py b/vllm_gaudi/models/interfaces.py
index a4ea2327e..26ead7f83 100644
--- a/vllm_gaudi/models/interfaces.py
+++ b/vllm_gaudi/models/interfaces.py
@@ -1,7 +1,6 @@
 from collections.abc import Callable
 import torch
 from torch import Tensor
-from vllm.model_executor.models.interfaces import SupportsMultiModal
 
 
 def _embed_text_input_ids(
@@ -38,4 +37,4 @@ def _embed_text_input_ids(
     return embed_input_ids(input_ids)
 
 
-SupportsMultiModal._embed_text_input_ids = _embed_text_input_ids
+#SupportsMultiModal._embed_text_input_ids = _embed_text_input_ids
diff --git a/vllm_gaudi/models/qwen3_vl.py b/vllm_gaudi/models/qwen3_vl.py
index 9dd551155..82a72d3eb 100644
--- a/vllm_gaudi/models/qwen3_vl.py
+++ b/vllm_gaudi/models/qwen3_vl.py
@@ -1,6 +1,9 @@
-from torch import nn
-from vllm.model_executor.layers.activation import get_act_fn
+import torch
+from .utils import _merge_multimodal_embeddings
 from vllm.config import VllmConfig
+from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.models.interfaces import MultiModalEmbeddings
+from vllm.model_executor.models.interfaces import _require_is_multimodal
 from vllm.model_executor.models.qwen3_vl import (
     Qwen3VLForConditionalGeneration,
     Qwen3_VisionTransformer,
@@ -65,9 +68,9 @@ def __init__(
         )
 
         depth = vision_config.depth
-        norm_layer = lambda d: nn.LayerNorm(d, eps=norm_eps)
+        norm_layer = lambda d: torch.nn.LayerNorm(d, eps=norm_eps)
 
-        self.blocks = nn.ModuleList([
+        self.blocks = torch.nn.ModuleList([
             HPUQwen3_VisionBlock(
                 dim=self.hidden_size,
                 num_heads=self.num_heads,
@@ -97,3 +100,81 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             multimodal_config=multimodal_config,
             prefix=maybe_prefix(prefix, "visual"),
         )
+
+    def _compute_deepstack_embeds(
+        self,
+        inputs_embeds: torch.Tensor,
+        multimodal_embeddings: MultiModalEmbeddings,
+        is_multimodal: torch.Tensor,
+    ) -> tuple[torch.Tensor, MultiModalEmbeddings]:
+        visual_lens = [len(x) for x in multimodal_embeddings]
+        multimodal_embeddings_cat = torch.cat(multimodal_embeddings, dim=0)
+
+        (
+            multimodal_embeddings_main,
+            multimodal_embeddings_multiscale,
+        ) = torch.split(
+            multimodal_embeddings_cat,
+            [self.visual_dim, self.multiscale_dim],
+            dim=-1,
+        )
+
+        multimodal_embeddings = torch.split(multimodal_embeddings_main, visual_lens, dim=0)
+        multimodal_embeddings_multiscale = torch.split(multimodal_embeddings_multiscale, visual_lens, dim=0)
+
+        deepstack_input_embeds = inputs_embeds.new_zeros(inputs_embeds.size(0),
+                                                         self.deepstack_num_level * inputs_embeds.size(1))
+
+        deepstack_input_embeds = _merge_multimodal_embeddings(
+            inputs_embeds=deepstack_input_embeds,
+            multimodal_embeddings=multimodal_embeddings_multiscale,
+            is_multimodal=is_multimodal,
+        )
+        deepstack_input_embeds = deepstack_input_embeds.view(inputs_embeds.shape[0], self.deepstack_num_level,
+                                                             self.visual_dim)
+        deepstack_input_embeds = deepstack_input_embeds.permute(1, 0, 2)
+
+        return deepstack_input_embeds, multimodal_embeddings
+
+    def embed_input_ids(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: MultiModalEmbeddings | None = None,
+        *,
+        is_multimodal: torch.Tensor | None = None,
+        handle_oov_mm_token: bool = False,
+    ) -> torch.Tensor:
+        inputs_embeds = self._embed_text_input_ids(
+            input_ids,
+            self.language_model.embed_input_ids,
+            is_multimodal=is_multimodal,
+            handle_oov_mm_token=handle_oov_mm_token,
+        )
+
+        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
+            return inputs_embeds
+
+        is_multimodal = _require_is_multimodal(is_multimodal)
+
+        if self.use_deepstack:
+            (
+                deepstack_input_embeds,
+                multimodal_embeddings,
+            ) = self._compute_deepstack_embeds(
+                inputs_embeds=inputs_embeds,
+                multimodal_embeddings=multimodal_embeddings,
+                is_multimodal=is_multimodal,
+            )
+        else:
+            deepstack_input_embeds = None
+
+        inputs_embeds = _merge_multimodal_embeddings(
+            inputs_embeds=inputs_embeds,
+            multimodal_embeddings=multimodal_embeddings,
+            is_multimodal=is_multimodal,
+        )
+
+        if deepstack_input_embeds is not None:
+            self._set_deepstack_input_embeds(deepstack_input_embeds)
+
+        return inputs_embeds
diff --git a/vllm_gaudi/models/utils.py b/vllm_gaudi/models/utils.py
index c1bbfd0e4..e7c090040 100644
--- a/vllm_gaudi/models/utils.py
+++ b/vllm_gaudi/models/utils.py
@@ -29,18 +29,15 @@ def _merge_multimodal_embeddings(
     mm_embeds_flat = _flatten_embeddings(multimodal_embeddings)
 
     input_dtype = inputs_embeds.dtype
+    if is_multimodal.dtype == torch.int64:
+        return inputs_embeds.index_copy_(0, is_multimodal, mm_embeds_flat)
     try:
         # For debugging
         # inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype)
-        # htcore.mark_step()
+
         # NOTE: This can avoid D2H sync (#22105), but fails to
         # raise an error if is_multimodal.sum() < len(mm_embeds_flat)
-        # inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1),
-        #                               mm_embeds_flat.to(dtype=input_dtype))
-
-        multimodal_positions = torch.where(is_multimodal)[0][:mm_embeds_flat.shape[0]]
-        inputs_embeds[0, multimodal_positions] = mm_embeds_flat.to(dtype=input_dtype)
-
+        inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), mm_embeds_flat.to(dtype=input_dtype))
     except RuntimeError as e:
         num_actual_tokens = len(mm_embeds_flat)
         num_expected_tokens = is_multimodal.sum().item()
diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py
index 541334cd9..0e913bc7b 100644
--- a/vllm_gaudi/v1/worker/hpu_model_runner.py
+++ b/vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 import collections
+import copy
 import contextlib
 import functools
 from functools import partial
@@ -42,13 +43,14 @@
 from vllm.v1.attention.selector import get_attn_backend
 from vllm.config import (VllmConfig, update_config)
+from vllm.config.multimodal import ImageDummyOptions
 from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group)
 from vllm.forward_context import set_forward_context
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.vocab_parallel_embedding import (VocabParallelEmbedding)
 from vllm.model_executor.model_loader import get_model, get_model_loader
-from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem)
 from vllm.multimodal.utils import group_mm_kwargs_by_modality
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
@@ -67,7 +69,7 @@
                              AsyncModelRunnerOutput, KVConnectorOutput)
 from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.worker.utils import bind_kv_cache
+from vllm.v1.worker.utils import bind_kv_cache, add_kv_sharing_layers_to_kv_cache_groups
 from vllm.v1.utils import CpuGpuBuffer
 from vllm_gaudi.v1.worker.hpu_input_batch import InputBatch, CachedRequestState
 from vllm.distributed.parallel_state import get_pp_group, get_dp_group
@@ -75,7 +77,7 @@
 from vllm.model_executor.models.interfaces_base import (VllmModelForPooling, is_pooling_model, is_text_generation_model)
 from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
 from vllm.transformers_utils.config import is_interleaved
-from vllm.v1.worker.utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
+from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs
 from vllm.v1.sample.rejection_sampler import RejectionSampler
 from vllm.v1.spec_decode.eagle import EagleProposer
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -691,7 +693,7 @@ def __init__(
         self.head_size = self.model_config.get_head_size()
         self.hidden_size = self.model_config.get_hidden_size()
         self.is_pooling_model = (model_config.runner_type == 'pooling')
-        logger.debug("model config: ", self.model_config)
+        logger.debug("model config: %s", self.model_config)
 
         self.attn_backend = get_attn_backend(
             self.head_size,
@@ -704,15 +706,18 @@
         # Mult-modal-related.
         self.mm_registry = MULTIMODAL_REGISTRY
         self.uses_mrope = model_config.uses_mrope
+        self.model_config_copy = None
         self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(model_config)
         if self.supports_mm_inputs:
             self.is_mm_embed = self._make_buffer(self.max_num_tokens, dtype=torch.bool)
+            self.model_config_copy = copy.deepcopy(self.model_config)
 
         self.is_multimodal_raw_input_supported = (model_config.is_multimodal_raw_input_only_model)
 
         # Lazy initialization
         # self.model: nn.Module  # set after load_model
         self.kv_caches: list[torch.Tensor] = []
+        self.shared_kv_cache_layers: dict[str, str] = {}
 
         self.inc_initialized_successfully = False
         self._is_inc_finalized = False
@@ -950,6 +955,20 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         kv_cache_spec: dict[str, KVCacheSpec] = {}
         cache_dtype_str = self.vllm_config.cache_config.cache_dtype
         for layer_name, attn_module in forward_ctx.items():
+            kv_sharing_target_layer_name = getattr(attn_module, 'kv_sharing_target_layer_name', None)
+            if kv_sharing_target_layer_name is not None:
+                from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
+                try:
+                    validate_kv_sharing_target(
+                        layer_name,
+                        kv_sharing_target_layer_name,
+                        forward_ctx,
+                    )
+                    self.shared_kv_cache_layers[layer_name] = kv_sharing_target_layer_name
+                except Exception as e:
+                    logger.error("KV sharing validation failed for %s -> %s: %s", layer_name,
+                                 kv_sharing_target_layer_name, e)
+                continue
             if isinstance(attn_module, FusedMoE):
                 continue
 
@@ -1265,11 +1284,7 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput", req_ids: list
             if req_id not in self.encoder_cache:
                 self.encoder_cache[req_id] = {}
 
-            self.encoder_cache[mm_hash] = scatter_mm_placeholders(
-                output,
-                is_embed=pos_info.is_embed.to(
-                    device=output.device) if pos_info.is_embed is not None else pos_info.is_embed,
-            )
+            self.encoder_cache[mm_hash] = output
 
     # modified from: vllm/v1/worker/gpu_model_runner.py
     def _gather_mm_embeddings(
@@ -1287,6 +1302,7 @@ def _gather_mm_embeddings(
         req_start_idx = 0
 
         for req_id in req_ids:
+            mm_embeds_req: list[torch.Tensor] = []
            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
            req_state = self.requests[req_id]
            num_computed_tokens = \
@@ -1311,6 +1327,11 @@
                 start_idx = max(num_computed_tokens - start_pos, 0)
                 end_idx = min(num_computed_tokens - start_pos + num_scheduled_tokens, num_encoder_tokens)
                 assert start_idx < end_idx
+                curr_embeds_start, curr_embeds_end = (pos_info.get_embeds_indices_in_range(start_idx, end_idx))
+                # If there are no embeddings in the current range, we skip
+                # gathering the embeddings.
+                if curr_embeds_start == curr_embeds_end:
+                    continue
                 mm_hash = mm_feature.identifier
                 encoder_output = self.encoder_cache.get(mm_hash, None)
                 assert encoder_output is not None,\
@@ -1319,21 +1340,26 @@
                 if (is_embed := pos_info.is_embed) is not None:
                     is_embed = is_embed[start_idx:end_idx]
+                    mm_embeds_item = encoder_output[curr_embeds_start:curr_embeds_end]
+                else:
+                    mm_embeds_item = encoder_output[start_idx:end_idx]
 
-                mm_embeds_item = gather_mm_placeholders(
-                    encoder_output[start_idx:end_idx],
-                    is_embed=is_embed,
-                )
                 req_start_pos = req_start_idx + start_pos - num_computed_tokens
-                is_mm_embed[req_start_pos+start_idx:req_start_pos + end_idx] \
-                    = True
-
-                # Only whole mm items are processed
-                mm_embeds.append(mm_embeds_item)
+                is_mm_embed[req_start_pos + start_idx:req_start_pos +
+                            end_idx] = (True if is_embed is None else is_embed)
+                mm_embeds_req.append(mm_embeds_item)
 
+            mm_embeds.extend(mm_embeds_req)
             req_start_idx += num_scheduled_tokens
 
-        is_mm_embed = self.is_mm_embed.copy_to_gpu(total_num_scheduled_tokens)
-
+        # Convert bool tensor to index tensor for merge embedding statically if optimized mm
+        if self.uses_mrope:
+            is_mm_embed_index = torch.nonzero(is_mm_embed[:total_num_scheduled_tokens], as_tuple=True)[0]
+            # Bounds validation on CPU
+            if len(is_mm_embed_index) > 0 and is_mm_embed_index.max() >= total_num_scheduled_tokens:
+                raise ValueError(f"Index {is_mm_embed_index.max()} exceeds tensor size {total_num_scheduled_tokens}")
+            is_mm_embed = is_mm_embed_index.to(self.device)
+        else:
+            is_mm_embed = self.is_mm_embed.copy_to_gpu(total_num_scheduled_tokens)
         return mm_embeds, is_mm_embed
 
     def _get_model_mm_inputs(
@@ -3752,6 +3778,7 @@ def load_model(self) -> None:
             self._maybe_compile(self.model)
         self.model_memory_usage = m.consumed_device_memory
         logger.info("Compilation took %.4f GB", self.model_memory_usage / float(2**30))
+        self.is_mm_optimized = is_mm_optimized(self.model)
 
     def _maybe_compile(self, *args, **kwargs):
         """Entrypoint for a torch.compilation of the model"""
@@ -3864,12 +3891,13 @@ def log_warmup(self, phase, i, max_i, first_dim, second_dim, third_dim, causal=F
               f"free_mem:{free_mem}")
         tqdm.write(msg)
 
-    def log_warmup_multimodal(self, phase, i, max_i, batch_size, seq_len, img_args):
+    def log_warmup_multimodal(self, phase, i, max_i, batch_size, seq_len, w, h, f):
         free_mem = format_bytes(HabanaMemoryProfiler.current_free_device_memory())
         msg = (f"[Warmup][{phase}][{i+1}/{max_i}] "
               f"batch_size:{batch_size} "
              f"seq_len:{seq_len} "
-              f"img_args:{img_args} "
+              f"resolution:{w}X{h} "
+              f"frames:{f} "
               f"free_mem:{free_mem}")
         logger.info(msg)
 
@@ -4551,51 +4579,44 @@ def _get_mm_dummy_batch(
     ) -> BatchedTensorInputs:
         """Dummy data for profiling and precompiling multimodal models."""
         assert self.mm_budget is not None
-        img_count = 1
-        batch = image_args if self.get_model().vision_bucket_manager.is_batch_based else img_count
-        '''if self.get_model().vision_bucket_manager.is_batch_based:
+        count = 1
+        num_frames = 0
+        batch = image_args if self.get_model().vision_bucket_manager.is_batch_based else count
+        if self.get_model().vision_bucket_manager.is_batch_based:
             # Create ImageDummyOptions for Gemma3
-            #image_options = ImageDummyOptions(
-            #    width=896,  # pixels as in gemma3 config
-            #    height=896  # pixels as in gemma3 config
-            #)
+            w = 896  # pixels as in gemma3 config
+            h = 896  # pixels as in gemma3 config
             batch = image_args
         else:
-            #patch_size = int(self.get_patch_size_from_model())
+            patch_size = int(self.get_patch_size_from_model())
             # Calculate width and height to maintain aspect ratio and patch count
             # Total patches = (width/patch_size) * (height/patch_size)
            # We want: (w/ps) * (h/ps) = num_patch where num_patch is image_args
             # And: w/h = ratio_w/ratio_h
-            #grid_w = int(math.sqrt(image_args * ratio_w / ratio_h))
-            #grid_h = int(image_args / grid_w)
-            #w = grid_w * patch_size
-            #h = grid_h * patch_size
-            #image_options = ImageDummyOptions(
-            #    width=w,  # Custom width in pixels
-            #    height=h  # Custom height in pixels
-            #)
-            batch = img_count
-
-        processor = self.mm_registry.create_processor(model_config=self.model_config, cache=self.mm_budget.cache)
-        dummy_data = processor.dummy_inputs.get_decoder_dummy_data(processor,
-                                                                   seq_len=4096,
-                                                                   mm_counts={"image": img_count},
-                                                                   mm_options={"image": image_options}),
-
-        dummy_mm_data = processor.dummy_inputs.get_dummy_processor_inputs(
-            seq_len=4096,
-            mm_counts={"image": img_count},
-        )
-        '''
-
-        assert modality == 'image'
-        # Result in the maximum GPU consumption of the model
-        dummy_mm_inputs = self.mm_registry.get_dummy_mm_inputs(
-            self.model_config,
-            mm_counts={modality: 1},
-            cache=self.mm_budget.cache,
-        )
+            grid_w = int(math.sqrt(image_args * ratio_w / ratio_h))
+            grid_h = int(image_args / grid_w)
+            w = grid_w * patch_size
+            h = grid_h * patch_size
+            batch = count
+        self.model_config_copy.max_model_len = 4096
+        if modality == 'image':
+            self.model_config_copy.limit_mm_per_prompt = {
+                "image": {"count": count, "width": w, "height": h}
+            }
+        elif modality == 'video':
+            video_options = self.model_config_copy.get_multimodal_config().get_dummy_options("video")
+            num_frames = video_options.num_frames if video_options and hasattr(video_options, 'num_frames') else 100
+            w = video_options.width if video_options and hasattr(video_options, 'width') else w
+            h = video_options.height if video_options and hasattr(video_options, 'height') else h
+            count = video_options.count if video_options and hasattr(video_options, 'count') else 1
+            self.model_config_copy.limit_mm_per_prompt = {
+                "video": {"count": count, "num_frames": num_frames, "width": w, "height": h}
+            }
+        else:
+            raise NotImplementedError(f"Modality '{modality}' is not supported")
+        dummy_mm_inputs = MultiModalRegistry().get_dummy_mm_inputs(self.model_config_copy,
+                                                                   mm_counts={modality: count})
         dummy_mm_item = dummy_mm_inputs["mm_kwargs"][modality][0]
         # We use the cache so that the item is saved to the cache,
         # but not read from the cache
@@ -4607,15 +4628,14 @@ def _get_mm_dummy_batch(
             dummy_mm_items,
             device=self.device,
             pin_memory=self.pin_memory,
-        ))
+        )), w, h, num_frames
 
     def warmup_multimodal_graphs(self, buckets):
         phase = 'Graph/Multimodal'
         from vllm.v1.worker.utils import MultiModalBudget
         self.mm_budget = MultiModalBudget(
-            self.model_config,
-            self.scheduler_config,
+            self.vllm_config,
             self.mm_registry,
         ) if self.supports_mm_inputs else None
 
         aspect_ratios = [(1, 1)]  # 1:1 square
@@ -4629,15 +4649,22 @@ def warmup_multimodal_graphs(self, buckets):
            (9, 16),  # 9:16 portrait
        ]
        aspect_ratios.extend(aspect_ratio_ext)
+        is_video_warmup = True if self.model_config.get_multimodal_config() is not None and \
+            self.model_config.get_multimodal_config().get_dummy_options("video") is not None \
+            and self.mm_budget.mm_limits['video'] != 999 else False
+
+        is_image_warmup = True if self.model_config.get_multimodal_config() is not None and \
+            self.model_config.get_multimodal_config().get_dummy_options("image") is not None \
+            and self.mm_budget.mm_limits['image'] != 0 else False
        for modality, max_items in self.mm_budget.mm_limits.items():
-            if modality == 'video':
-                logger.warning_once("Warming up for video is not implemented")
+            if modality == 'image' and is_image_warmup == False or modality == 'video' \
+                and is_video_warmup == False:
                continue
            phase = f'Graph/Multimodal({modality})'
            num_candidates = len(buckets)
            for idx, img_arg in enumerate(buckets):
                for (ratio_w, ratio_h) in aspect_ratios:
-                    batched_dummy_mm_inputs = self._get_mm_dummy_batch(modality, img_arg, ratio_w, ratio_h)
+                    batched_dummy_mm_inputs, w, h, f = self._get_mm_dummy_batch(modality, img_arg, ratio_w, ratio_h)
                    dummy_encoder_outputs = \
                        self.model.embed_multimodal(
                            **batched_dummy_mm_inputs)
@@ -4648,7 +4675,7 @@ def warmup_multimodal_graphs(self, buckets):
                    )
 
                    self.graphed_buckets.add(img_arg)
-                    self.log_warmup_multimodal(phase, idx, num_candidates, 1, 0, img_arg)
+                    self.log_warmup_multimodal(phase, idx, num_candidates, 1, 0, w, h, f)
 
     def _maybe_profile_unified_attn(self):
         unified_cfg_str = os.environ.get('VLLM_PROFILE_UNIFIED', None)
@@ -4872,6 +4899,20 @@ def _dummy_run(self, max_num_batched_tokens: int) -> None:
             self._prepare_dummy_scenario(prompt_cfg, decode_cfg)
         return
 
+    def maybe_add_kv_sharing_layers_to_kv_cache_groups(self, kv_cache_config: KVCacheConfig) -> None:
+        """
+        Add layers that re-use KV cache to KV cache group of its target layer.
+        Mapping of KV cache tensors happens in the KV cache initialization.
+        """
+        if not self.shared_kv_cache_layers:
+            # No cross-layer KV sharing, return
+            return
+
+        add_kv_sharing_layers_to_kv_cache_groups(
+            self.shared_kv_cache_layers,
+            kv_cache_config.kv_cache_groups,
+        )
+
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize KV cache based on `kv_cache_config`.
@@ -4879,6 +4920,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             kv_cache_config: Configuration for the KV cache, including the KV
             cache size of each layer
         """
+        self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
         if len(kv_cache_config.kv_cache_groups) > 1:
             raise NotImplementedError("Hybrid models with more than one KV cache type are not "
                                       "supported yet.")
@@ -4950,6 +4992,13 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         layer_names = set()
         for group in kv_cache_config.kv_cache_groups:
             layer_names.update(group.layer_names)
+
+        # Set up cross-layer KV cache sharing
+        if self.shared_kv_cache_layers:
+            logger.info("[KV sharing] Setting up tensor sharing for %s layers", len(self.shared_kv_cache_layers))
+            for layer_name, target_layer_name in self.shared_kv_cache_layers.items():
+                kv_caches[layer_name] = kv_caches[target_layer_name]
+
         assert layer_names == set(kv_caches.keys()), "Some layers are not correctly initialized"
         bind_kv_cache(kv_caches, self.vllm_config.compilation_config.static_forward_context, self.kv_caches)
 
@@ -5452,7 +5501,7 @@ def __init__(
         if self.interleaved_sliding_window:
             self.use_window_sdpa = with_default(get_config().PT_HPU_SDPA_QKV_SLICE_MODE_FWD, False)
             #os.getenv("PT_HPU_SDPA_QKV_SLICE_MODE_FWD", "false").strip().lower() in ("1", "true")
-            self.slice_size = with_default(get_config().PT_HPU_SDPA_BC_FACTOR, False)
+            self.slice_size = int(with_default(get_config().PT_HPU_SDPA_BC_FACTOR, "1024"))
             # int(os.getenv("PT_HPU_SDPA_BC_FACTOR", "1024"))
             self.slice_thld = int(os.environ.get('VLLM_FUSEDSDPA_SLIDE_THLD', '8192'))
 
@@ -5654,12 +5703,12 @@ def process_metadata(self, attn_metadata: HPUAttentionMetadataV1, batch_size: in
         """
         if attn_metadata.is_prompt:
             attn_metadata = self._set_attn_bias(attn_metadata, batch_size, seq_len, device, dtype)
-            if self.interleaved_sliding_window:
+            if self.interleaved_sliding_window and self.sliding_window is not None:
                 attn_metadata = self._set_attn_bias_for_sliding_window(attn_metadata, batch_size, seq_len,
                                                                        self.sliding_window, device, dtype)
         else:
             attn_metadata = self._set_block_mapping(attn_metadata, batch_size, device, dtype)
-            if self.interleaved_sliding_window:
+            if self.interleaved_sliding_window and self.sliding_window is not None:
                 attn_metadata = self._set_block_mapping(attn_metadata, batch_size, device, dtype, True)
 
         return attn_metadata