From 689431e825b707c8b2dafabf97ff7442a214b939 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 7 Jan 2026 12:35:47 +0200 Subject: [PATCH 1/9] Handle spec decode optionals in unified batch Signed-off-by: Konrad Zawora --- vllm_gaudi/extension/unified_batch.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/vllm_gaudi/extension/unified_batch.py b/vllm_gaudi/extension/unified_batch.py index ec9147eee..07ccc3911 100644 --- a/vllm_gaudi/extension/unified_batch.py +++ b/vllm_gaudi/extension/unified_batch.py @@ -535,12 +535,13 @@ def create_unified_batch( num_scheduled_tokens = num_scheduled_tokens.tolist() # NOTE(Chendi): In spec decode case, we will return -1 as dummy draft token # while we need to exclude them when counting num_scheduled_tokens - for idx, req_id in enumerate(req_ids): - spec_tokens = scheduled_spec_decode_tokens.get(req_id, None) - if spec_tokens is None: - continue - num_spec_tokens = len([i for i in spec_tokens if i != -1]) - num_scheduled_tokens[idx] = num_spec_tokens + 1 + if scheduled_spec_decode_tokens is not None: + for idx, req_id in enumerate(req_ids): + spec_tokens = scheduled_spec_decode_tokens.get(req_id, None) + if spec_tokens is None: + continue + num_spec_tokens = len([i for i in spec_tokens if i != -1]) + num_scheduled_tokens[idx] = num_spec_tokens + 1 num_scheduled_tokens = np.asarray(num_scheduled_tokens, dtype=np.int32) # Convert torch dtype to numpy dtype for internal operations @@ -582,11 +583,12 @@ def create_unified_batch( # Used by spec decode draft model num_reqs = len(req_ids) - cu_num_tokens, _ = get_cumsum_and_arange(num_scheduled_tokens) query_start_loc_cpu = torch.zeros((num_reqs + 1, ), dtype=torch.int32) - query_start_loc_np = query_start_loc_cpu.numpy() - query_start_loc_np[0] = 0 - query_start_loc_np[1:num_reqs + 1] = cu_num_tokens + if get_cumsum_and_arange is not None: + cu_num_tokens, _ = get_cumsum_and_arange(num_scheduled_tokens) + query_start_loc_np = query_start_loc_cpu.numpy() + query_start_loc_np[0] = 0 + query_start_loc_np[1:num_reqs + 1] = cu_num_tokens def first_dim(t: Optional[np.ndarray]) -> int: """ Takes first dim size or 0 if tensor is None""" @@ -767,7 +769,7 @@ def first_dim(t: Optional[np.ndarray]) -> int: invalid_req_indices.append(len(req_ids) - 1) # call prepare_spec_decode_inputs to prepare spec decode inputs - if max(num_output_tokens) > 1: + if max(num_output_tokens) > 1 and prepare_spec_decode_inputs_fn is not None: with persistent_ctx.profiler.record_event('internal', 'spec_decode_metadata_prep'): _, spec_decode_metadata = prepare_spec_decode_inputs_fn(all_token_ids.shape[0], scheduled_spec_decode_tokens, From 0068dd0be5d781cc7ff74962a6a5d8758a4162bc Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 7 Jan 2026 18:11:30 +0200 Subject: [PATCH 2/9] [WIP] Chunked Shared Attention with Dense Biases Signed-off-by: Konrad Zawora --- vllm_gaudi/extension/unified.py | 331 ++++++++++++++++++++--- vllm_gaudi/extension/unified_batch.py | 209 +++++++++++--- vllm_gaudi/v1/worker/hpu_model_runner.py | 5 +- 3 files changed, 453 insertions(+), 92 deletions(-) diff --git a/vllm_gaudi/extension/unified.py b/vllm_gaudi/extension/unified.py index ee3e4a74a..37298e2b2 100644 --- a/vllm_gaudi/extension/unified.py +++ b/vllm_gaudi/extension/unified.py @@ -149,14 +149,18 @@ def create_softmax_fa2_input_tensors( retained_shape[-1] = get_last_dim_size(retained_shape[-1], vec_size, pack_size) t_retained_shape = tuple(retained_shape) + # Convert fmin to scalar once + fmin_val = 
fmin.item() if isinstance(fmin, torch.Tensor) else fmin + if t_retained_shape not in inputM_hpu_tensors: print("Allocating new input tensors for shape:", t_retained_shape, "for attn shape:", attn.shape) - return torch.full(retained_shape, fmin, dtype=attn.dtype, device='hpu'), torch.zeros(retained_shape, - dtype=attn.dtype, - device="hpu") + inputM_hpu_tensors[t_retained_shape] = torch.full(retained_shape, fmin_val, dtype=attn.dtype, device='hpu') + inputL_hpu_tensors[t_retained_shape] = torch.zeros(retained_shape, dtype=attn.dtype, device="hpu") + return inputM_hpu_tensors[t_retained_shape], inputL_hpu_tensors[t_retained_shape] + torch.hpu.synchronize() inputL_hpu_tensors[t_retained_shape].zero_() - inputM_hpu_tensors[t_retained_shape].fill_(fmin) + inputM_hpu_tensors[t_retained_shape].fill_(fmin_val) return inputM_hpu_tensors[t_retained_shape], inputL_hpu_tensors[t_retained_shape] @@ -261,6 +265,81 @@ def combine(slices): return combine(attn_slices), combine(max_slices), combine(sum_slices) +@dataclass +class SharedBlockChunkedBiasData: + """Data needed to compute shared block bias per-chunk during chunked attention. + + This avoids materializing the full [query_len, num_shared_blocks, block_size] + bias tensor which can be prohibitively large with many shared blocks. + + Contains dense block_usages of shape (num_query_tokens, num_shared_blocks). + During chunked attention, we slice block_usages[:, chunk_start:chunk_end] and + generate bias for each chunk on-the-fly. + """ + block_usages: torch.tensor # Dense: [num_query_tokens, num_shared_blocks] + num_query_tokens: int # Total query length (padded) + num_shared_blocks: int # Total number of shared blocks (padded) + + +def _partial_attn_shared_core(query: torch.tensor, + key: torch.tensor, + value: torch.tensor, + bias: torch.tensor, + fmin: torch.tensor, + inputL_hpu_tensors: Dict[tuple, torch.Tensor], + inputM_hpu_tensors: Dict[tuple, torch.Tensor], + kv_heads: int, + is_mla: bool, + w_uv: Optional[torch.tensor] = None) -> tuple[torch.tensor, torch.tensor, torch.tensor]: + """Core shared attention computation. + + This is the inner loop extracted for reuse between full and chunked paths. 
+ + Args: + query: Query tensor, already transposed [kv_heads, q_heads_per_kv, tokens, head_dim] or similar + key: Key tensor from cache [kv_heads, q_heads_per_kv, kv_len, head_dim] + value: Value tensor from cache + bias: Attention bias [1, kv_len] (already flattened from [num_blocks, block_size]) + fmin: Minimum float for softmax stability + inputL_hpu_tensors: Cache for FA2 tensors + inputM_hpu_tensors: Cache for FA2 tensors + kv_heads: Number of KV heads + is_mla: Whether using MLA attention + w_uv: Optional MLA projection matrix + + Returns: + Tuple of (unnormalized_weighted_V, local_max, local_sum) + """ + num_heads = query.size(0) * query.size(1) if not is_mla else query.size(0) + + attn = torch.matmul(query, key.transpose(-1, -2)) + attn = attn.flatten(0, 1) + attn = attn + bias + + # TODO: remove dtype check once full support is added for fp8 in unified attention + if get_config().unified_attn_softmax_fa2 and attn.dtype == torch.bfloat16: + inputM_hpu, inputL_hpu = create_softmax_fa2_input_tensors(attn, fmin, inputL_hpu_tensors, inputM_hpu_tensors) + attn, local_max, local_sum, _exp_max_fixup_hpu = torch.ops.hpu.softmax_fa2(attn, + inputM=inputM_hpu, + inputL=inputL_hpu) + local_max = convert_cl_aligned_tensor(local_max, list(attn.shape[:-1])) + local_sum = convert_cl_aligned_tensor(local_sum, list(attn.shape[:-1])) + else: + local_max = torch.maximum(attn.amax(-1), fmin) + attn = torch.exp(attn - local_max.unsqueeze(-1)) + local_sum = attn.sum(-1) + + attn = torch.matmul(attn.unflatten(0, (kv_heads if not is_mla else num_heads, -1)), value).flatten(0, 1) + + # MLA: Extract latent part and project to full V + if is_mla and w_uv is not None: + latent_dim = w_uv.size(1) + attn_latent = attn[..., :latent_dim] + attn = torch.bmm(attn_latent, w_uv) + + return attn.transpose(0, 1), local_max.transpose(0, 1), local_sum.transpose(0, 1) + + def partial_attn_shared(query: torch.tensor, blocks: torch.tensor, bias: Optional[torch.tensor], @@ -268,59 +347,204 @@ def partial_attn_shared(query: torch.tensor, inputL_hpu_tensors: Dict[tuple, torch.Tensor], inputM_hpu_tensors: Dict[tuple, torch.Tensor], cache_utils: CacheUtils, - w_uv: Optional[torch.tensor] = None) -> tuple[torch.tensor, torch.tensor, torch.tensor]: - """Partial attention where all shared blocks are compared with whole query + dtype: torch.dtype, + w_uv: Optional[torch.tensor] = None, + chunked_data: Optional[SharedBlockChunkedBiasData] = None, + chunk_size: int = 0) -> tuple[torch.tensor, torch.tensor, torch.tensor]: + """Partial attention where all shared blocks are compared with whole query. + + Supports two modes: + 1. Full bias mode (default): bias tensor is provided, process all blocks at once + 2. Chunked mode: chunk_size > 0, process blocks in chunks + - If bias is provided, slice from it + - If bias is None but chunked_data is provided, generate bias per chunk from dense block_usages Args: - w_uv: Optional MLA projection matrix [num_heads, latent_dim, v_head_dim]. - If provided, assumes MLA mode where query/key/value are in latent space. + query: Query tensor [tokens, num_heads, head_dim] + blocks: Shared block indices [num_shared_blocks] + bias: Pre-computed bias tensor [query_len, num_blocks, block_size]. Can be None for chunked generation. 
+ fmin: Minimum float value for softmax stability + inputL_hpu_tensors: Cache for softmax input tensors + inputM_hpu_tensors: Cache for softmax input tensors + cache_utils: Cache utilities for fetching KV + dtype: Output dtype for bias generation + w_uv: Optional MLA projection matrix [num_heads, latent_dim, v_head_dim] + chunked_data: Metadata for chunked processing (contains dense block_usages) + chunk_size: Number of blocks per chunk (0 = full mode, >0 = chunked mode) + + Returns: + Tuple of (unnormalized_weighted_V, local_max, local_sum) """ - if bias is None: - return (None, None, None) - + # Determine mode + use_chunked = chunk_size > 0 and chunked_data is not None + + if not use_chunked: + # Full bias mode - original implementation + if bias is None: + return (None, None, None) + return _partial_attn_shared_full(query, blocks, bias, fmin, inputL_hpu_tensors, inputM_hpu_tensors, cache_utils, + w_uv) + else: + # Chunked mode - process blocks in chunks + # bias can be None for chunked generation (will generate from chunked_data.block_usages per chunk) + if blocks is None: + return (None, None, None) + return _partial_attn_shared_chunked(query, blocks, bias, chunked_data, chunk_size, fmin, inputL_hpu_tensors, + inputM_hpu_tensors, cache_utils, dtype, w_uv) + + +def _partial_attn_shared_full(query: torch.tensor, + blocks: torch.tensor, + bias: torch.tensor, + fmin: torch.tensor, + inputL_hpu_tensors: Dict[tuple, torch.Tensor], + inputM_hpu_tensors: Dict[tuple, torch.Tensor], + cache_utils: CacheUtils, + w_uv: Optional[torch.tensor] = None) -> tuple[torch.tensor, torch.tensor, torch.tensor]: + """Full bias implementation of partial_attn_shared.""" is_mla = w_uv is not None if is_mla: # MLA: Single latent cache contains both K and V latent_kv = cache_utils.fetch_shared(blocks) num_heads = query.size(1) - query = query.transpose(0, 1).unsqueeze(1) # [num_heads, 1, tokens, latent_dim + rope_dim] + query_t = query.transpose(0, 1).unsqueeze(1) # [num_heads, 1, tokens, latent_dim + rope_dim] key = latent_kv.unsqueeze(0).unsqueeze(0).expand(num_heads, 1, -1, -1) value = latent_kv.unsqueeze(0).unsqueeze(0).expand(num_heads, 1, -1, -1) kv_heads = 1 else: # Standard attention: Separate K and V caches kv_heads = cache_utils.kv_heads - query = query.transpose(0, 1).unflatten(0, (kv_heads, -1)) + query_t = query.transpose(0, 1).unflatten(0, (kv_heads, -1)) key, value = cache_utils.fetch_shared(blocks) - bias = bias.flatten(-2, -1).unsqueeze(0) + bias_flat = bias.flatten(-2, -1).unsqueeze(0) + + return _partial_attn_shared_core(query_t, key, value, bias_flat, fmin, inputL_hpu_tensors, inputM_hpu_tensors, + kv_heads, is_mla, w_uv) + + +def _partial_attn_shared_chunked( + query: torch.tensor, + blocks: torch.tensor, + bias: Optional[torch.tensor], + chunked_data: SharedBlockChunkedBiasData, + chunk_size: int, + fmin: torch.tensor, + inputL_hpu_tensors: Dict[tuple, torch.Tensor], + inputM_hpu_tensors: Dict[tuple, torch.Tensor], + cache_utils: CacheUtils, + dtype: torch.dtype, + w_uv: Optional[torch.tensor] = None) -> tuple[torch.tensor, torch.tensor, torch.tensor]: + """Chunked implementation of partial_attn_shared with per-chunk bias generation. + + Generates bias per chunk from dense block_usages to save memory. + Avoids materializing the full (query_len, num_blocks, block_size) bias tensor. + + Strategy: + 1. Process blocks in chunks of chunk_size + 2. For each chunk, slice block_usages and generate chunk bias on-the-fly + 3. Compute attention for the chunk using _partial_attn_shared_core + 4. 
Merge chunk results using flash-attention style online softmax + """ + num_blocks = chunked_data.num_shared_blocks + block_size = cache_utils.block_size + num_query_tokens = chunked_data.num_query_tokens - attn = torch.matmul(query, key.transpose(-1, -2)) - attn = attn.flatten(0, 1) - attn = attn + bias - # TODO: remove dtype check once full support is added for fp8 in unified attention - if get_config().unified_attn_softmax_fa2 and attn.dtype == torch.bfloat16: - inputM_hpu, inputL_hpu = create_softmax_fa2_input_tensors(attn, fmin, inputL_hpu_tensors, inputM_hpu_tensors) - attn, local_max, local_sum, _exp_max_fixup_hpu = torch.ops.hpu.softmax_fa2(attn, - inputM=inputM_hpu, - inputL=inputL_hpu) - local_max = convert_cl_aligned_tensor(local_max, list(attn.shape[:-1])) - local_sum = convert_cl_aligned_tensor(local_sum, list(attn.shape[:-1])) - else: - local_max = torch.maximum(attn.amax(-1), fmin) - attn = torch.exp(attn - local_max.unsqueeze(-1)) - local_sum = attn.sum(-1) + is_mla = w_uv is not None + kv_heads = 1 if is_mla else cache_utils.kv_heads + + # Calculate number of chunks + num_chunks = math.ceil(num_blocks / chunk_size) + + # Check if we have pre-computed bias or need to generate per-chunk + generate_bias_per_chunk = (bias is None) + + # Pre-allocate reusable tensors outside the loop (avoid allocations per iteration) + if generate_bias_per_chunk: + block_len_range = torch.arange(1, + block_size + 1, + dtype=chunked_data.block_usages.dtype, + device=chunked_data.block_usages.device) + # Pre-allocate chunk_bias buffer - will be overwritten each iteration + chunk_bias_buffer = torch.empty((num_query_tokens, chunk_size, block_size), + dtype=dtype, + device=chunked_data.block_usages.device) + + # Accumulators for online softmax-style merging + accumulated_attn = None + global_max = None + global_sum = None + + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * chunk_size + chunk_end = min(chunk_start + chunk_size, num_blocks) + actual_chunk_len = chunk_end - chunk_start + + # Slice blocks for this chunk + chunk_blocks = blocks[chunk_start:chunk_end] + + if generate_bias_per_chunk: + # Generate bias for this chunk from dense block_usages + # chunked_data.block_usages is (num_query_tokens, num_shared_blocks) + # Slice to get (num_query_tokens, actual_chunk_len) for this chunk + chunk_block_usages = chunked_data.block_usages[:, chunk_start:chunk_end] + + # Generate chunk bias using dense broadcast into pre-allocated buffer + # chunk_block_usages.unsqueeze(-1): (num_query_tokens, actual_chunk_len, 1) + # broadcast comparison: (num_query_tokens, actual_chunk_len, block_size) + chunk_mask = block_len_range > chunk_block_usages.unsqueeze(-1) + + # Use view of pre-allocated buffer for actual chunk size + chunk_bias = chunk_bias_buffer[:, :actual_chunk_len, :] + chunk_bias.zero_() + chunk_bias.masked_fill_(chunk_mask, -math.inf) + else: + # Pre-computed: slice from full bias tensor + chunk_bias = bias[:, chunk_start:chunk_end, :] - attn = torch.matmul(attn.unflatten(0, (kv_heads if not is_mla else num_heads, -1)), value).flatten(0, 1) + # Fetch KV for this chunk + if is_mla: + latent_kv = cache_utils.fetch_shared(chunk_blocks) + num_heads = query.size(1) + query_t = query.transpose(0, 1).unsqueeze(1) + key = latent_kv.unsqueeze(0).unsqueeze(0).expand(num_heads, 1, -1, -1) + value = latent_kv.unsqueeze(0).unsqueeze(0).expand(num_heads, 1, -1, -1) + else: + query_t = query.transpose(0, 1).unflatten(0, (kv_heads, -1)) + key, value = cache_utils.fetch_shared(chunk_blocks) + + # 
Flatten bias for attention: [1, query_len, chunk_len * block_size] + chunk_bias_flat = chunk_bias.flatten(-2, -1).unsqueeze(0) + + # Compute attention for this chunk + chunk_attn, chunk_max, chunk_sum = _partial_attn_shared_core(query_t, key, value, chunk_bias_flat, fmin, + inputL_hpu_tensors, inputM_hpu_tensors, kv_heads, + is_mla, w_uv) + + # Online merge: combine this chunk with accumulated results + if accumulated_attn is None: + # First chunk - just store + accumulated_attn = chunk_attn + global_max = chunk_max + global_sum = chunk_sum + else: + # Merge with existing - use flash-attention style rescaling + new_max = torch.maximum(global_max, chunk_max) - # MLA: Extract latent part and project to full V - if is_mla: - latent_dim = w_uv.size(1) - attn_latent = attn[..., :latent_dim] # Extract only latent dimension (exclude rope_dim) - attn = torch.bmm(attn_latent, w_uv) # [num_heads, tokens, v_head_dim] + # Rescale factors + old_scale = torch.exp(global_max - new_max) + new_scale = torch.exp(chunk_max - new_max) - return attn.transpose(0, 1), local_max.transpose(0, 1), local_sum.transpose(0, 1) + # Rescale accumulated values and sums + accumulated_attn = accumulated_attn * old_scale.unsqueeze(-1) + chunk_attn * new_scale.unsqueeze(-1) + global_sum = global_sum * old_scale + chunk_sum * new_scale + global_max = new_max + + if accumulated_attn is None: + return (None, None, None) + + return accumulated_attn, global_max, global_sum def partial_attn_unique(query: torch.tensor, @@ -391,6 +615,9 @@ class HPUUnifiedAttentionMetadata: causal_width: int shared_blocks: Optional[torch.tensor] shared_bias: Optional[torch.tensor] + # Chunked bias data for chunk-wise computation (used when shared_bias is None but shared_blocks exists) + shared_bias_chunked: Optional[SharedBlockChunkedBiasData] + shared_chunk_size: int # Number of blocks to process per chunk (0 = use full bias) unique_blocks: Optional[torch.tensor] | Optional[int] unique_block_mapping: Optional[torch.tensor] unique_bias: Optional[torch.tensor] @@ -435,6 +662,8 @@ def unified_attn(query: torch.tensor, key: torch.tensor, value: torch.tensor, ke inputL_hpu_tensors=metadata.inputL_hpu_tensors, inputM_hpu_tensors=metadata.inputM_hpu_tensors, w_uv=None) + + # Single call handles both full and chunked modes shared = partial_attn_shared(query=scaled_query, blocks=metadata.shared_blocks, bias=metadata.shared_bias, @@ -442,7 +671,11 @@ def unified_attn(query: torch.tensor, key: torch.tensor, value: torch.tensor, ke inputL_hpu_tensors=metadata.inputL_hpu_tensors, inputM_hpu_tensors=metadata.inputM_hpu_tensors, cache_utils=cache_utils, - w_uv=None) + dtype=query.dtype, + w_uv=None, + chunked_data=metadata.shared_bias_chunked, + chunk_size=metadata.shared_chunk_size) + unique = partial_attn_unique(query=scaled_query, blocks=metadata.unique_blocks, block_mapping=metadata.unique_block_mapping, @@ -512,14 +745,22 @@ def unified_mla(query: Optional[torch.tensor], # Shared/Unique: memory-friendly path (Q in latent space, fetch cached latent KV) # query_latent is already transformed to latent space by caller # For these paths, we need to expand K/V from cached latent vectors - shared = partial_attn_shared(query=scaled_query_latent, - blocks=metadata.shared_blocks, - bias=metadata.shared_bias, - fmin=metadata.fmin, - inputL_hpu_tensors=metadata.inputL_hpu_tensors, - inputM_hpu_tensors=metadata.inputM_hpu_tensors, - cache_utils=cache_utils, - w_uv=w_uv) if scaled_query_latent is not None else (None, None, None) + + # Single call handles both full and 
chunked modes + if scaled_query_latent is not None: + shared = partial_attn_shared(query=scaled_query_latent, + blocks=metadata.shared_blocks, + bias=metadata.shared_bias, + fmin=metadata.fmin, + inputL_hpu_tensors=metadata.inputL_hpu_tensors, + inputM_hpu_tensors=metadata.inputM_hpu_tensors, + cache_utils=cache_utils, + dtype=scaled_query_latent.dtype, + w_uv=w_uv, + chunked_data=metadata.shared_bias_chunked, + chunk_size=metadata.shared_chunk_size) + else: + shared = (None, None, None) unique = partial_attn_unique(query=scaled_query_latent, blocks=metadata.unique_blocks, diff --git a/vllm_gaudi/extension/unified_batch.py b/vllm_gaudi/extension/unified_batch.py index 07ccc3911..d10aff685 100644 --- a/vllm_gaudi/extension/unified_batch.py +++ b/vllm_gaudi/extension/unified_batch.py @@ -2,7 +2,7 @@ import numpy as np import habana_frameworks.torch as htorch from dataclasses import dataclass -from vllm_gaudi.extension.unified import HPUUnifiedAttentionMetadata, get_vecsize_packsize, get_last_dim_size +from vllm_gaudi.extension.unified import HPUUnifiedAttentionMetadata, SharedBlockChunkedBiasData, get_vecsize_packsize, get_last_dim_size from vllm.v1.spec_decode.metadata import SpecDecodeMetadata import math from typing import Optional, Callable, Union @@ -334,6 +334,7 @@ def __init__(self, max_num_batched_tokens, max_shared_blocks, max_unique_blocks, self.causal_bias_generator = HPUCausalBiasGenerator() self.shared_bias_generator = HPUSharedBiasGenerator() + self.shared_bias_generator_dense = HPUSharedBiasGeneratorDense() self.graphed = True if self.graphed: config = get_config() @@ -342,6 +343,8 @@ def __init__(self, max_num_batched_tokens, max_shared_blocks, max_unique_blocks, disable_tensor_cache=True) self.shared_bias_generator = htorch.hpu.wrap_in_hpu_graph(self.shared_bias_generator, disable_tensor_cache=True) + self.shared_bias_generator_dense = htorch.hpu.wrap_in_hpu_graph(self.shared_bias_generator_dense, + disable_tensor_cache=True) elif config.bridge_mode == 'eager': self.causal_bias_generator = torch.compile(self.causal_bias_generator, backend='hpu_backend', @@ -351,6 +354,10 @@ def __init__(self, max_num_batched_tokens, max_shared_blocks, max_unique_blocks, backend='hpu_backend', fullgraph=True, dynamic=False) + self.shared_bias_generator_dense = torch.compile(self.shared_bias_generator_dense, + backend='hpu_backend', + fullgraph=True, + dynamic=False) self.hpu_tensor_online_padding = False if not self.hpu_tensor_online_padding: # NOTE(kzawora): Dynamic mempool caches - store largest placeholders needed for each (pad_value, dtype) @@ -486,7 +493,7 @@ class HPUSharedBiasGenerator(HPUBiasGenerator): def forward(self, block_usages: torch.tensor, hpu_shared_token_idx: torch.tensor, hpu_shared_block_idx: torch.tensor, block_size: torch.tensor, dtype: torch.dtype, target_qlen, target_shared_blocks) -> torch.tensor: - """ Generate block bias based on block_usage """ + """ Generate block bias based on block_usage (sparse scatter version) """ block_len_range = torch.arange(1, block_size + 1, dtype=block_usages.dtype, device=block_usages.device) block_mask = block_len_range.unsqueeze(0) > block_usages.unsqueeze(-1) hpu_shared_block_bias = self.mask_to_bias_torch(block_mask, dtype=dtype) @@ -498,6 +505,40 @@ def forward(self, block_usages: torch.tensor, hpu_shared_token_idx: torch.tensor return hpu_shared_bias +class HPUSharedBiasGeneratorDense(HPUBiasGenerator): + """ + Dense version of shared bias generator - takes pre-scattered block_usages + of shape (target_qlen, target_shared_blocks) 
instead of sparse coordinates. + + This avoids dynamic-length coordinate arrays on HPU by doing the scatter on CPU. + """ + + def forward(self, block_usages_dense: torch.tensor, block_size: int, dtype: torch.dtype) -> torch.tensor: + """ + Generate block bias from dense block_usages. + + Args: + block_usages_dense: Shape (target_qlen, target_shared_blocks), values are block usage counts (0 = masked out) + block_size: Size of each block + dtype: Output dtype + + Returns: + Shape (target_qlen, target_shared_blocks, block_size) bias tensor + """ + # block_usages_dense: (target_qlen, target_shared_blocks) + # We want: block_mask[q, b, k] = True if k >= block_usages_dense[q, b] + # Which means: mask out positions k where k+1 > block_usages_dense[q, b] + block_len_range = torch.arange(1, + block_size + 1, + dtype=block_usages_dense.dtype, + device=block_usages_dense.device) + # block_len_range: (block_size,) + # block_usages_dense.unsqueeze(-1): (target_qlen, target_shared_blocks, 1) + # broadcast comparison: (target_qlen, target_shared_blocks, block_size) + block_mask = block_len_range > block_usages_dense.unsqueeze(-1) + return self.mask_to_bias_torch(block_mask, dtype=dtype) + + def create_unified_batch( req_ids: list[str], all_token_ids: torch.Tensor, @@ -665,6 +706,19 @@ def first_dim(t: Optional[np.ndarray]) -> int: fmin = torch.finfo(dtype).min feps = torch.finfo(dtype).tiny + # Determine if we should use chunked computation for shared blocks + # NOTE(kzawora): Chunked processing computes attention in chunks to save memory. + # With chunked dense generation, we only allocate (target_qlen, target_shared_blocks) for block_usages + # instead of the full (target_qlen, target_shared_blocks, block_size) bias tensor. + # Bias is generated per chunk: (target_qlen, chunk_size, block_size) + default_chunk_size = 32 # Process up to 64 blocks at a time for shared attention + use_chunked_processing = bool(target_shared_blocks + > default_chunk_size) # Chunked dense processing - generates bias per chunk + + # Dense bias generation: scatter on CPU (any shape), then broadcast on HPU (static shape) + # This avoids dynamic-length coordinate arrays on HPU entirely + use_dense_bias_generation = True # Use dense scatter approach for shared bias (non-chunked fallback) + with persistent_ctx.profiler.record_event('internal', 'attn_metadata_prep'): attn_metadata = HPUUnifiedAttentionMetadata( block_size=block_size, @@ -674,9 +728,11 @@ def first_dim(t: Optional[np.ndarray]) -> int: target_qlen), -math.inf, dtype) if causal_bias is not None else None, causal_width=default_causal_width, shared_blocks=persistent_ctx.hpu_tensor(shared_blocks, (target_shared_blocks, ), -1, slot_mapping_dtype), - shared_bias=persistent_ctx.hpu_tensor(shared_bias, - (target_qlen, target_shared_blocks, - block_size), -math.inf, dtype) if shared_bias is not None else None, + # For chunked processing: still allocate full bias for now (stepping stone to verify correctness) + # shared_bias will be set below after HPU acceleration + shared_bias=None, # Will be set below + shared_bias_chunked=None, # Will be set below if chunked processing is enabled + shared_chunk_size=default_chunk_size if use_chunked_processing else 0, unique_blocks=target_unique_blocks, unique_block_mapping=persistent_ctx.hpu_tensor(unique_block_mapping, (target_unique_blocks, ), -1, slot_mapping_dtype), @@ -700,47 +756,110 @@ def first_dim(t: Optional[np.ndarray]) -> int: attn_metadata.causal_bias = persistent_ctx.causal_bias_generator(hpu_token_groups, 
hpu_token_positions, hpu_padding_mask, dtype) if do_shared: - # NOTE(kzawora): this is kinda janky, but for a good reason - the number of shared tokens can vary significantly, - # and it impacts whether it's even worth running on HPU. - # On HPU, we need to avoid dynamic shapes == we need to pad number of shared tokens. - # It's currently padded to target_shared_blocks * block_size - # We set some simple heuristics to decide whether to run on HPU or fallback to CPU: - # 1. Check the number of shared tokens. If it's greater than the padded number of shared tokens, we fallback to CPU. - # 2. Check the padding ratio - if the padding exceeds 50% of actual size, we fallback to CPU. - - actual_num_shared_tokens = shared_block_usage.shape[0] - padded_num_shared_tokens = target_shared_blocks * block_size - # NOTE(kzawora): Initially we checked for padding ratio as well, but ultimately I've found no cases in - # which generating mask on padded_num_shared_tokens was slower than CPU + copying to HPU - # in case we have too many or too little shared tokens, we fall back to cpu generation - if padded_num_shared_tokens < actual_num_shared_tokens: - with persistent_ctx.profiler.record_event('internal', 'shared_bias_cpu_fallback'): - shared_block_bias = generate_bias(shared_block_usage, block_size, np_dtype, - persistent_ctx.block_len_range, persistent_ctx.shared_block_bias) - if persistent_ctx.use_persistent_shared_biases: - shared_bias = persistent_ctx.shared_bias[:query_len, :shared_blocks.shape[0], :block_size] + # Check if we should use chunked processing to save memory + if use_chunked_processing: + with persistent_ctx.profiler.record_event('internal', 'shared_bias_chunked_prep'): + # CHUNKED DENSE APPROACH: + # - Scatter block_usages into dense (target_qlen, target_shared_blocks) on CPU + # - Don't generate full bias - just pass the dense block_usages + # - Attention code will generate bias per chunk by slicing block_usages + + # Create dense block_usages on CPU with shape (target_qlen, target_shared_blocks) + # Value 0 means "masked out entirely" (will produce all -inf bias) + block_usages_dense = np.zeros((target_qlen, target_shared_blocks), dtype=np.int32) + + # Scatter: block_usages_dense[token_idx, block_idx] = block_usage value + block_usages_dense[shared_token_idx, shared_block_idx] = shared_block_usage + + # Transfer dense tensor to HPU - shape is fully static (target_qlen, target_shared_blocks) + hpu_block_usages_dense = persistent_ctx.hpu_tensor( + block_usages_dense, + (target_qlen, target_shared_blocks), + 0, # pad with 0 (fully masked) + torch.int32) + + # DON'T generate full bias - set shared_bias to None + # Attention code will generate bias per chunk using shared_bias_chunked + attn_metadata.shared_bias = None + + # Set up chunked data for chunk-wise bias generation + attn_metadata.shared_bias_chunked = SharedBlockChunkedBiasData( + block_usages=hpu_block_usages_dense, # Dense (target_qlen, target_shared_blocks) + num_query_tokens=target_qlen, + num_shared_blocks=target_shared_blocks, + ) + + # Handle non-chunked case + if not use_chunked_processing: + if use_dense_bias_generation: + # DENSE APPROACH: Scatter on CPU (any shape), broadcast on HPU (static shape) + # This avoids dynamic-length coordinate arrays on HPU entirely + with persistent_ctx.profiler.record_event('internal', 'shared_bias_dense_prep'): + # Create dense block_usages on CPU with shape (target_qlen, target_shared_blocks) + # Value 0 means "masked out entirely" (will produce all -inf bias) + block_usages_dense = 
np.zeros((target_qlen, target_shared_blocks), dtype=np.int32) + + # Scatter: block_usages_dense[token_idx, block_idx] = block_usage value + block_usages_dense[shared_token_idx, shared_block_idx] = shared_block_usage + + # Transfer dense tensor to HPU - shape is fully static (target_qlen, target_shared_blocks) + hpu_block_usages_dense = persistent_ctx.hpu_tensor( + block_usages_dense, + (target_qlen, target_shared_blocks), + 0, # pad with 0 (fully masked) + torch.int32) + + # Generate bias using dense generator - no dynamic scatter on HPU + hpu_shared_bias = persistent_ctx.shared_bias_generator_dense( + hpu_block_usages_dense, block_size, dtype) + attn_metadata.shared_bias = hpu_shared_bias + else: + # SPARSE APPROACH: Original implementation with dynamic-length coordinate arrays + # NOTE(kzawora): this is kinda janky, but for a good reason - the number of shared tokens can vary significantly, + # and it impacts whether it's even worth running on HPU. + # On HPU, we need to avoid dynamic shapes == we need to pad number of shared tokens. + # It's currently padded to target_shared_blocks * block_size + # We set some simple heuristics to decide whether to run on HPU or fallback to CPU: + # 1. Check the number of shared tokens. If it's greater than the padded number of shared tokens, we fallback to CPU. + # 2. Check the padding ratio - if the padding exceeds 50% of actual size, we fallback to CPU. + + actual_num_shared_tokens = shared_block_usage.shape[0] + padded_num_shared_tokens = target_shared_blocks * block_size + # NOTE(kzawora): Initially we checked for padding ratio as well, but ultimately I've found no cases in + # which generating mask on padded_num_shared_tokens was slower than CPU + copying to HPU + # in case we have too many or too little shared tokens, we fall back to cpu generation + if padded_num_shared_tokens < actual_num_shared_tokens: + with persistent_ctx.profiler.record_event('internal', 'shared_bias_cpu_fallback'): + shared_block_bias = generate_bias(shared_block_usage, block_size, np_dtype, + persistent_ctx.block_len_range, + persistent_ctx.shared_block_bias) + if persistent_ctx.use_persistent_shared_biases: + shared_bias = persistent_ctx.shared_bias[:query_len, :shared_blocks. 
+ shape[0], :block_size] + else: + shared_bias = np.full((query_len, shared_blocks.shape[0], block_size), + -math.inf, + dtype=np_dtype) + shared_bias.fill(-math.inf) + shared_bias[shared_token_idx, shared_block_idx] = shared_block_bias + attn_metadata.shared_bias = persistent_ctx.hpu_tensor( + shared_bias, (target_qlen, target_shared_blocks, block_size), -math.inf, dtype) else: - shared_bias = np.full((query_len, shared_blocks.shape[0], block_size), - -math.inf, - dtype=np_dtype) - shared_bias.fill(-math.inf) - shared_bias[shared_token_idx, shared_block_idx] = shared_block_bias - attn_metadata.shared_bias = persistent_ctx.hpu_tensor( - shared_bias, (target_qlen, target_shared_blocks, block_size), -math.inf, dtype) - else: - with persistent_ctx.profiler.record_event('internal', 'shared_bias_hpu_prep'): - # do HPU-accelerated shared mask generation - shared_tokens_shape = (padded_num_shared_tokens, ) - hpu_shared_block_usage = persistent_ctx.hpu_tensor(shared_block_usage, shared_tokens_shape, -1, - slot_mapping_dtype) - hpu_shared_token_idx = persistent_ctx.hpu_tensor(shared_token_idx, shared_tokens_shape, -1, - slot_mapping_dtype) - hpu_shared_block_idx = persistent_ctx.hpu_tensor(shared_block_idx, shared_tokens_shape, -1, - slot_mapping_dtype) - hpu_shared_bias = persistent_ctx.shared_bias_generator(hpu_shared_block_usage, hpu_shared_token_idx, - hpu_shared_block_idx, block_size, dtype, - target_qlen, target_shared_blocks) - attn_metadata.shared_bias = hpu_shared_bias + with persistent_ctx.profiler.record_event('internal', 'shared_bias_hpu_prep'): + # do HPU-accelerated shared mask generation + shared_tokens_shape = (padded_num_shared_tokens, ) + hpu_shared_block_usage = persistent_ctx.hpu_tensor(shared_block_usage, shared_tokens_shape, + -1, slot_mapping_dtype) + hpu_shared_token_idx = persistent_ctx.hpu_tensor(shared_token_idx, shared_tokens_shape, -1, + slot_mapping_dtype) + hpu_shared_block_idx = persistent_ctx.hpu_tensor(shared_block_idx, shared_tokens_shape, -1, + slot_mapping_dtype) + hpu_shared_bias = persistent_ctx.shared_bias_generator(hpu_shared_block_usage, + hpu_shared_token_idx, + hpu_shared_block_idx, block_size, + dtype, target_qlen, + target_shared_blocks) + attn_metadata.shared_bias = hpu_shared_bias token_ids_device = persistent_ctx.hpu_tensor(token_ids, (target_qlen, ), -1, token_ids_dtype) logits_indices_device = persistent_ctx.hpu_tensor(logits_indices, (target_logits, ), -1, logits_indices_dtype) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index 7d3d392d3..b105525d2 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -2532,7 +2532,9 @@ def _check_config(self, batch_size, seq_len, num_blocks, attn_metadata, warmup_m def _get_unified_config(self, attn_metadata, logits_indices): has_causal = 'c' if attn_metadata.causal_bias is not None else '-' - has_shared = 's' if attn_metadata.shared_bias is not None else '-' + has_shared_bias = attn_metadata.shared_bias is not None + has_chunked_bias = attn_metadata.shared_bias_chunked is not None + has_shared = 's' if has_shared_bias or has_chunked_bias else '-' has_unique = 'u' if attn_metadata.unique_bias is not None else '-' phase = has_causal + has_shared + has_unique qlen = attn_metadata.slot_mapping.size(0) @@ -4365,7 +4367,6 @@ def _prepare_dummy_unified_scenario(self, unified_cfg): for request_blocks in split_shared_blocks_ids: self._add_dummy_unified_request(requests, False, False, request_blocks, num_computed_tokens, 
1, scheduled_tokens) - self._execute_dummy_scenario(requests, scheduled_tokens) def _prepare_dummy_scenario(self, prompt_cfg, decode_cfg): From b4751ceb7b170a5a50f6256c93bfbf6f82521192 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 7 Jan 2026 19:00:23 +0200 Subject: [PATCH 3/9] cleanup Signed-off-by: Konrad Zawora --- vllm_gaudi/extension/features.py | 1 + vllm_gaudi/extension/unified_batch.py | 264 ++++++++++++++------------ 2 files changed, 147 insertions(+), 118 deletions(-) diff --git a/vllm_gaudi/extension/features.py b/vllm_gaudi/extension/features.py index 719e9abd0..380586bd0 100644 --- a/vllm_gaudi/extension/features.py +++ b/vllm_gaudi/extension/features.py @@ -90,6 +90,7 @@ def get_features(): Value('dynamic_shapes_compilation', True, env_var='VLLM_T_COMPILE_DYNAMIC_SHAPES', env_var_type=boolean), Value('fullgraph_compilation', False, env_var='VLLM_T_COMPILE_FULLGRAPH', env_var_type=boolean), Value('unified_attn', False), + Value('unified_attn_dense_shared_bias', True), Value('unified_attn_softmax_fa2', All(VersionRange(">=1.24.0.279"), Enabled('unified_attn'), Kernel(softmax_fa2), Hardware('gaudi3'))), Value('scale_adjustment', True, env_var='VLLM_SCALE_ADJUSTMENT', env_var_type=boolean), diff --git a/vllm_gaudi/extension/unified_batch.py b/vllm_gaudi/extension/unified_batch.py index d10aff685..d024b6fb9 100644 --- a/vllm_gaudi/extension/unified_batch.py +++ b/vllm_gaudi/extension/unified_batch.py @@ -312,19 +312,28 @@ def __init__(self, max_num_batched_tokens, max_shared_blocks, max_unique_blocks, estimated_shared_bias_mem = (max_num_batched_tokens * max_shared_blocks * block_size * np.dtype(np_dtype).itemsize) + (max_shared_blocks * block_size * block_size * np.dtype(np_dtype).itemsize) - # NOTE(kzawora): 64GiB is an arbitrary threshold to avoid OOMs when allocating large shared bias buffers - shared_bias_mem_threshold = 64 * 2**30 - self.use_persistent_shared_biases = estimated_shared_bias_mem <= shared_bias_mem_threshold - if self.use_persistent_shared_biases: - self.shared_bias = np.full((max_num_batched_tokens, max_shared_blocks, block_size), - -math.inf, - dtype=np_dtype) - # NOTE(kzawora): shared block bias is a weird entity - it maps block usage to each individual token in the context - - # so the upper bound should be max_shared_blocks*block_size (max_num_shared_tokens) by block_size - self.shared_block_bias = np.full((max_shared_blocks * block_size, block_size), -math.inf, dtype=np_dtype) + + self.use_dense_shared_bias = get_config().unified_attn_dense_shared_bias + if self.use_dense_shared_bias: + # Dense block_usages for chunked shared attention - shape (max_qlen, max_shared_blocks) + # Value 0 means "masked out entirely" (will produce all -inf bias) + self.block_usages_dense = np.zeros((max_num_batched_tokens, max_shared_blocks), dtype=np.int32) else: - self.shared_bias = None - self.shared_block_bias = None + # NOTE(kzawora): 64GiB is an arbitrary threshold to avoid OOMs when allocating large shared bias buffers + shared_bias_mem_threshold = 64 * 2**30 + self.use_persistent_shared_biases = estimated_shared_bias_mem <= shared_bias_mem_threshold + if self.use_persistent_shared_biases: + self.shared_bias = np.full((max_num_batched_tokens, max_shared_blocks, block_size), + -math.inf, + dtype=np_dtype) + # NOTE(kzawora): shared block bias is a weird entity - it maps block usage to each individual token in the context - + # so the upper bound should be max_shared_blocks*block_size (max_num_shared_tokens) by block_size + self.shared_block_bias = 
np.full((max_shared_blocks * block_size, block_size), + -math.inf, + dtype=np_dtype) + else: + self.shared_bias = None + self.shared_block_bias = None self.unique_bias = np.full((max_unique_blocks, block_size), -math.inf, dtype=np_dtype) self.unique_block_bias = np.full((max_unique_blocks, block_size), -math.inf, dtype=np_dtype) @@ -539,6 +548,112 @@ def forward(self, block_usages_dense: torch.tensor, block_size: int, dtype: torc return self.mask_to_bias_torch(block_mask, dtype=dtype) +def _prepare_shared_bias_hpu( + persistent_ctx: UnifiedBatchPersistentContext, + attn_metadata: 'HPUUnifiedAttentionMetadata', + shared_token_idx: np.ndarray, + shared_block_idx: np.ndarray, + shared_block_usage: np.ndarray, + shared_blocks: np.ndarray, + target_qlen: int, + target_shared_blocks: int, + query_len: int, + block_size: int, + dtype: torch.dtype, + np_dtype: np.dtype, + slot_mapping_dtype: torch.dtype, + use_chunked_processing: bool, + use_dense_bias_generation: bool, +) -> None: + """ + Prepare shared bias tensors on HPU. + + This function handles three approaches for shared bias generation: + 1. Chunked dense: For large shared blocks, generate bias per-chunk during attention + 2. Non-chunked dense: Scatter on CPU, broadcast on HPU (static shapes) + 3. Sparse (legacy): Dynamic scatter on HPU with fallback to CPU in case of too many shared tokens. + + Modifies attn_metadata.shared_bias and attn_metadata.shared_bias_chunked in place. + """ + if use_chunked_processing: + with persistent_ctx.profiler.record_event('internal', 'shared_bias_chunked_prep'): + # CHUNKED DENSE APPROACH: + # - Scatter block_usages into dense (target_qlen, target_shared_blocks) on CPU + # - Don't generate full bias - just pass the dense block_usages + # - Attention code will generate bias per chunk by slicing block_usages + + # Use persistent buffer - get view of required size and zero it + block_usages_dense = persistent_ctx.block_usages_dense[:target_qlen, :target_shared_blocks] + block_usages_dense.fill(0) # Reset to 0 (fully masked) + + # Scatter: block_usages_dense[token_idx, block_idx] = block_usage value + block_usages_dense[shared_token_idx, shared_block_idx] = shared_block_usage + + # Transfer dense tensor to HPU - shape is fully static (target_qlen, target_shared_blocks) + hpu_block_usages_dense = persistent_ctx.hpu_tensor(block_usages_dense, (target_qlen, target_shared_blocks), + 0, torch.int32) + + # DON'T generate full bias - attention code will generate per chunk + attn_metadata.shared_bias = None + attn_metadata.shared_bias_chunked = SharedBlockChunkedBiasData( + block_usages=hpu_block_usages_dense, + num_query_tokens=target_qlen, + num_shared_blocks=target_shared_blocks, + ) + return + + # Non-chunked paths + if use_dense_bias_generation: + with persistent_ctx.profiler.record_event('internal', 'shared_bias_dense_prep'): + # DENSE APPROACH: Scatter on CPU (any shape), broadcast on HPU (static shape) + block_usages_dense = persistent_ctx.block_usages_dense[:target_qlen, :target_shared_blocks] + block_usages_dense.fill(0) + block_usages_dense[shared_token_idx, shared_block_idx] = shared_block_usage + + hpu_block_usages_dense = persistent_ctx.hpu_tensor(block_usages_dense, (target_qlen, target_shared_blocks), + 0, torch.int32) + + attn_metadata.shared_bias = persistent_ctx.shared_bias_generator_dense(hpu_block_usages_dense, block_size, + dtype) + return + + # SPARSE APPROACH (legacy): Dynamic scatter on HPU with CPU fallback + actual_num_shared_tokens = shared_block_usage.shape[0] + padded_num_shared_tokens = 
target_shared_blocks * block_size + + if padded_num_shared_tokens < actual_num_shared_tokens: + # Too many shared tokens - fall back to CPU generation + with persistent_ctx.profiler.record_event('internal', 'shared_bias_cpu_fallback'): + shared_block_bias = generate_bias(shared_block_usage, block_size, np_dtype, persistent_ctx.block_len_range, + persistent_ctx.shared_block_bias) + + if persistent_ctx.use_persistent_shared_biases: + shared_bias = persistent_ctx.shared_bias[:query_len, :shared_blocks.shape[0], :block_size] + else: + shared_bias = np.full((query_len, shared_blocks.shape[0], block_size), -math.inf, dtype=np_dtype) + + shared_bias.fill(-math.inf) + shared_bias[shared_token_idx, shared_block_idx] = shared_block_bias + attn_metadata.shared_bias = persistent_ctx.hpu_tensor(shared_bias, + (target_qlen, target_shared_blocks, block_size), + -math.inf, dtype) + else: + # HPU-accelerated sparse generation + with persistent_ctx.profiler.record_event('internal', 'shared_bias_hpu_prep'): + shared_tokens_shape = (padded_num_shared_tokens, ) + hpu_shared_block_usage = persistent_ctx.hpu_tensor(shared_block_usage, shared_tokens_shape, -1, + slot_mapping_dtype) + hpu_shared_token_idx = persistent_ctx.hpu_tensor(shared_token_idx, shared_tokens_shape, -1, + slot_mapping_dtype) + hpu_shared_block_idx = persistent_ctx.hpu_tensor(shared_block_idx, shared_tokens_shape, -1, + slot_mapping_dtype) + + attn_metadata.shared_bias = persistent_ctx.shared_bias_generator(hpu_shared_block_usage, + hpu_shared_token_idx, hpu_shared_block_idx, + block_size, dtype, target_qlen, + target_shared_blocks) + + def create_unified_batch( req_ids: list[str], all_token_ids: torch.Tensor, @@ -711,13 +826,13 @@ def first_dim(t: Optional[np.ndarray]) -> int: # With chunked dense generation, we only allocate (target_qlen, target_shared_blocks) for block_usages # instead of the full (target_qlen, target_shared_blocks, block_size) bias tensor. 
# Bias is generated per chunk: (target_qlen, chunk_size, block_size) - default_chunk_size = 32 # Process up to 64 blocks at a time for shared attention + default_chunk_size = 64 # Process up to 64 blocks at a time for shared attention use_chunked_processing = bool(target_shared_blocks > default_chunk_size) # Chunked dense processing - generates bias per chunk # Dense bias generation: scatter on CPU (any shape), then broadcast on HPU (static shape) # This avoids dynamic-length coordinate arrays on HPU entirely - use_dense_bias_generation = True # Use dense scatter approach for shared bias (non-chunked fallback) + use_dense_bias_generation = persistent_ctx.use_dense_shared_bias with persistent_ctx.profiler.record_event('internal', 'attn_metadata_prep'): attn_metadata = HPUUnifiedAttentionMetadata( @@ -756,110 +871,23 @@ def first_dim(t: Optional[np.ndarray]) -> int: attn_metadata.causal_bias = persistent_ctx.causal_bias_generator(hpu_token_groups, hpu_token_positions, hpu_padding_mask, dtype) if do_shared: - # Check if we should use chunked processing to save memory - if use_chunked_processing: - with persistent_ctx.profiler.record_event('internal', 'shared_bias_chunked_prep'): - # CHUNKED DENSE APPROACH: - # - Scatter block_usages into dense (target_qlen, target_shared_blocks) on CPU - # - Don't generate full bias - just pass the dense block_usages - # - Attention code will generate bias per chunk by slicing block_usages - - # Create dense block_usages on CPU with shape (target_qlen, target_shared_blocks) - # Value 0 means "masked out entirely" (will produce all -inf bias) - block_usages_dense = np.zeros((target_qlen, target_shared_blocks), dtype=np.int32) - - # Scatter: block_usages_dense[token_idx, block_idx] = block_usage value - block_usages_dense[shared_token_idx, shared_block_idx] = shared_block_usage - - # Transfer dense tensor to HPU - shape is fully static (target_qlen, target_shared_blocks) - hpu_block_usages_dense = persistent_ctx.hpu_tensor( - block_usages_dense, - (target_qlen, target_shared_blocks), - 0, # pad with 0 (fully masked) - torch.int32) - - # DON'T generate full bias - set shared_bias to None - # Attention code will generate bias per chunk using shared_bias_chunked - attn_metadata.shared_bias = None - - # Set up chunked data for chunk-wise bias generation - attn_metadata.shared_bias_chunked = SharedBlockChunkedBiasData( - block_usages=hpu_block_usages_dense, # Dense (target_qlen, target_shared_blocks) - num_query_tokens=target_qlen, - num_shared_blocks=target_shared_blocks, - ) - - # Handle non-chunked case - if not use_chunked_processing: - if use_dense_bias_generation: - # DENSE APPROACH: Scatter on CPU (any shape), broadcast on HPU (static shape) - # This avoids dynamic-length coordinate arrays on HPU entirely - with persistent_ctx.profiler.record_event('internal', 'shared_bias_dense_prep'): - # Create dense block_usages on CPU with shape (target_qlen, target_shared_blocks) - # Value 0 means "masked out entirely" (will produce all -inf bias) - block_usages_dense = np.zeros((target_qlen, target_shared_blocks), dtype=np.int32) - - # Scatter: block_usages_dense[token_idx, block_idx] = block_usage value - block_usages_dense[shared_token_idx, shared_block_idx] = shared_block_usage - - # Transfer dense tensor to HPU - shape is fully static (target_qlen, target_shared_blocks) - hpu_block_usages_dense = persistent_ctx.hpu_tensor( - block_usages_dense, - (target_qlen, target_shared_blocks), - 0, # pad with 0 (fully masked) - torch.int32) - - # Generate bias using dense 
generator - no dynamic scatter on HPU - hpu_shared_bias = persistent_ctx.shared_bias_generator_dense( - hpu_block_usages_dense, block_size, dtype) - attn_metadata.shared_bias = hpu_shared_bias - else: - # SPARSE APPROACH: Original implementation with dynamic-length coordinate arrays - # NOTE(kzawora): this is kinda janky, but for a good reason - the number of shared tokens can vary significantly, - # and it impacts whether it's even worth running on HPU. - # On HPU, we need to avoid dynamic shapes == we need to pad number of shared tokens. - # It's currently padded to target_shared_blocks * block_size - # We set some simple heuristics to decide whether to run on HPU or fallback to CPU: - # 1. Check the number of shared tokens. If it's greater than the padded number of shared tokens, we fallback to CPU. - # 2. Check the padding ratio - if the padding exceeds 50% of actual size, we fallback to CPU. - - actual_num_shared_tokens = shared_block_usage.shape[0] - padded_num_shared_tokens = target_shared_blocks * block_size - # NOTE(kzawora): Initially we checked for padding ratio as well, but ultimately I've found no cases in - # which generating mask on padded_num_shared_tokens was slower than CPU + copying to HPU - # in case we have too many or too little shared tokens, we fall back to cpu generation - if padded_num_shared_tokens < actual_num_shared_tokens: - with persistent_ctx.profiler.record_event('internal', 'shared_bias_cpu_fallback'): - shared_block_bias = generate_bias(shared_block_usage, block_size, np_dtype, - persistent_ctx.block_len_range, - persistent_ctx.shared_block_bias) - if persistent_ctx.use_persistent_shared_biases: - shared_bias = persistent_ctx.shared_bias[:query_len, :shared_blocks. - shape[0], :block_size] - else: - shared_bias = np.full((query_len, shared_blocks.shape[0], block_size), - -math.inf, - dtype=np_dtype) - shared_bias.fill(-math.inf) - shared_bias[shared_token_idx, shared_block_idx] = shared_block_bias - attn_metadata.shared_bias = persistent_ctx.hpu_tensor( - shared_bias, (target_qlen, target_shared_blocks, block_size), -math.inf, dtype) - else: - with persistent_ctx.profiler.record_event('internal', 'shared_bias_hpu_prep'): - # do HPU-accelerated shared mask generation - shared_tokens_shape = (padded_num_shared_tokens, ) - hpu_shared_block_usage = persistent_ctx.hpu_tensor(shared_block_usage, shared_tokens_shape, - -1, slot_mapping_dtype) - hpu_shared_token_idx = persistent_ctx.hpu_tensor(shared_token_idx, shared_tokens_shape, -1, - slot_mapping_dtype) - hpu_shared_block_idx = persistent_ctx.hpu_tensor(shared_block_idx, shared_tokens_shape, -1, - slot_mapping_dtype) - hpu_shared_bias = persistent_ctx.shared_bias_generator(hpu_shared_block_usage, - hpu_shared_token_idx, - hpu_shared_block_idx, block_size, - dtype, target_qlen, - target_shared_blocks) - attn_metadata.shared_bias = hpu_shared_bias + _prepare_shared_bias_hpu( + persistent_ctx=persistent_ctx, + attn_metadata=attn_metadata, + shared_token_idx=shared_token_idx, + shared_block_idx=shared_block_idx, + shared_block_usage=shared_block_usage, + shared_blocks=shared_blocks, + target_qlen=target_qlen, + target_shared_blocks=target_shared_blocks, + query_len=query_len, + block_size=block_size, + dtype=dtype, + np_dtype=np_dtype, + slot_mapping_dtype=slot_mapping_dtype, + use_chunked_processing=use_chunked_processing, + use_dense_bias_generation=use_dense_bias_generation, + ) token_ids_device = persistent_ctx.hpu_tensor(token_ids, (target_qlen, ), -1, token_ids_dtype) logits_indices_device = 
persistent_ctx.hpu_tensor(logits_indices, (target_logits, ), -1, logits_indices_dtype) From 1f0f17a43335f8d618b1454649ecbcbf6fd7a188 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 7 Jan 2026 19:03:24 +0200 Subject: [PATCH 4/9] remove FA changes Signed-off-by: Konrad Zawora --- vllm_gaudi/extension/unified.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm_gaudi/extension/unified.py b/vllm_gaudi/extension/unified.py index 37298e2b2..9d09168e5 100644 --- a/vllm_gaudi/extension/unified.py +++ b/vllm_gaudi/extension/unified.py @@ -154,13 +154,12 @@ def create_softmax_fa2_input_tensors( if t_retained_shape not in inputM_hpu_tensors: print("Allocating new input tensors for shape:", t_retained_shape, "for attn shape:", attn.shape) - inputM_hpu_tensors[t_retained_shape] = torch.full(retained_shape, fmin_val, dtype=attn.dtype, device='hpu') - inputL_hpu_tensors[t_retained_shape] = torch.zeros(retained_shape, dtype=attn.dtype, device="hpu") - return inputM_hpu_tensors[t_retained_shape], inputL_hpu_tensors[t_retained_shape] - + return torch.full(retained_shape, fmin, dtype=attn.dtype, device='hpu'), torch.zeros(retained_shape, + dtype=attn.dtype, + device="hpu") torch.hpu.synchronize() inputL_hpu_tensors[t_retained_shape].zero_() - inputM_hpu_tensors[t_retained_shape].fill_(fmin_val) + inputM_hpu_tensors[t_retained_shape].fill_(fmin) return inputM_hpu_tensors[t_retained_shape], inputL_hpu_tensors[t_retained_shape] From 019a1cf5da6e1626508bab4c5c59be5920ccd7ec Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 7 Jan 2026 19:13:22 +0200 Subject: [PATCH 5/9] pad shared blocks to chunk size Signed-off-by: Konrad Zawora --- vllm_gaudi/extension/unified_batch.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm_gaudi/extension/unified_batch.py b/vllm_gaudi/extension/unified_batch.py index d024b6fb9..d102362e1 100644 --- a/vllm_gaudi/extension/unified_batch.py +++ b/vllm_gaudi/extension/unified_batch.py @@ -830,6 +830,12 @@ def first_dim(t: Optional[np.ndarray]) -> int: use_chunked_processing = bool(target_shared_blocks > default_chunk_size) # Chunked dense processing - generates bias per chunk + # Pad target_shared_blocks to be a multiple of chunk_size for chunked processing + # This ensures all chunks have exactly chunk_size blocks (static shapes in the kernel) + if use_chunked_processing and target_shared_blocks % default_chunk_size != 0: + target_shared_blocks = ( + (target_shared_blocks + default_chunk_size - 1) // default_chunk_size) * default_chunk_size + # Dense bias generation: scatter on CPU (any shape), then broadcast on HPU (static shape) # This avoids dynamic-length coordinate arrays on HPU entirely use_dense_bias_generation = persistent_ctx.use_dense_shared_bias From 2d708e489d8c025d359ed5e269b01d0ced883a47 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 7 Jan 2026 19:35:00 +0200 Subject: [PATCH 6/9] Add FA online merge to UA Signed-off-by: Konrad Zawora --- vllm_gaudi/extension/unified.py | 160 ++++++++++++++++++++++++++++---- 1 file changed, 142 insertions(+), 18 deletions(-) diff --git a/vllm_gaudi/extension/unified.py b/vllm_gaudi/extension/unified.py index 9d09168e5..a8d8218d8 100644 --- a/vllm_gaudi/extension/unified.py +++ b/vllm_gaudi/extension/unified.py @@ -176,6 +176,77 @@ def convert_cl_aligned_tensor(input_hpu, reference_size) -> torch.tensor: return input_hpu +def online_merge_step( + acc_attn: Optional[torch.tensor], + acc_max: Optional[torch.tensor], + acc_sum: Optional[torch.tensor], + new_attn: 
Optional[torch.tensor], + new_max: Optional[torch.tensor], + new_sum: Optional[torch.tensor], +) -> tuple[Optional[torch.tensor], Optional[torch.tensor], Optional[torch.tensor]]: + """Incrementally merge attention results using flash-attention style rescaling. + + This implements the online softmax algorithm where we maintain running + unnormalized weighted values, max, and sum. The final normalization + (dividing by sum) is done at the end. + + Args: + acc_attn: Accumulated unnormalized weighted V [tokens, heads, head_dim] or None + acc_max: Accumulated max values [tokens, heads] or None + acc_sum: Accumulated sum of exp values [tokens, heads] or None + new_attn: New unnormalized weighted V to merge + new_max: New max values to merge + new_sum: New sum of exp values to merge + + Returns: + Tuple of (merged_attn, merged_max, merged_sum) + """ + if new_attn is None: + return acc_attn, acc_max, acc_sum + if acc_attn is None: + return new_attn, new_max, new_sum + + # Flash-attention style merge + merged_max = torch.maximum(acc_max, new_max) + old_scale = torch.exp(acc_max - merged_max) + new_scale = torch.exp(new_max - merged_max) + + merged_attn = acc_attn * old_scale.unsqueeze(-1) + new_attn * new_scale.unsqueeze(-1) + merged_sum = acc_sum * old_scale + new_sum * new_scale + + return merged_attn, merged_max, merged_sum + + +def online_merge(*attn_results: tuple[torch.tensor, torch.tensor, torch.tensor], + feps: torch.tensor) -> Optional[torch.tensor]: + """Merge partial attention values using online (incremental) algorithm. + + Alternative to merge() that uses online_merge_step for incremental merging. + This approach is more memory efficient as it doesn't need to keep all + intermediate results simultaneously. + + Args: + attn_results: Variable number of (attn, max, sum) tuples + feps: Small epsilon for numerical stability + + Returns: + Final normalized attention output, or None if all inputs are None + """ + acc_attn = None + acc_max = None + acc_sum = None + + for attn, max_val, sum_val in attn_results: + acc_attn, acc_max, acc_sum = online_merge_step(acc_attn, acc_max, acc_sum, attn, max_val, sum_val) + + if acc_attn is None: + return None + + # Final normalization + acc_sum = torch.maximum(acc_sum, feps) + return acc_attn / acc_sum.unsqueeze(-1) + + def merge(*attn_results: torch.tensor, feps: torch.tensor) -> torch.tensor: """Merge partial attention values into final attn score""" all_attn, all_max, all_sum = zip(*attn_results) @@ -645,13 +716,30 @@ def is_prompt(self): return self.causal_bias is not None -def unified_attn(query: torch.tensor, key: torch.tensor, value: torch.tensor, key_cache: torch.tensor, - value_cache: torch.tensor, scale: float, metadata: HPUUnifiedAttentionMetadata) -> torch.tensor: - """Main entry point for unified attention""" +def unified_attn(query: torch.tensor, + key: torch.tensor, + value: torch.tensor, + key_cache: torch.tensor, + value_cache: torch.tensor, + scale: float, + metadata: HPUUnifiedAttentionMetadata, + use_online_merge: bool = True) -> torch.tensor: + """Main entry point for unified attention + + Args: + use_online_merge: If True, use online (incremental) merge algorithm. + Merges after each partial attention to avoid large intermediate buffers. + If False, use offline (single-pass) merge algorithm. 
+ """ scaled_query = query * scale cache_utils = CacheUtils(key_cache, value_cache, metadata.block_size) + if use_online_merge: + # Online merge: compute and merge incrementally to avoid large intermediate buffers + acc_attn, acc_max, acc_sum = None, None, None + + # 1. Causal attention causal = partial_attn_causal(query=scaled_query, key=key, value=value, @@ -661,8 +749,10 @@ def unified_attn(query: torch.tensor, key: torch.tensor, value: torch.tensor, ke inputL_hpu_tensors=metadata.inputL_hpu_tensors, inputM_hpu_tensors=metadata.inputM_hpu_tensors, w_uv=None) + if use_online_merge: + acc_attn, acc_max, acc_sum = online_merge_step(acc_attn, acc_max, acc_sum, *causal) - # Single call handles both full and chunked modes + # 2. Shared attention shared = partial_attn_shared(query=scaled_query, blocks=metadata.shared_blocks, bias=metadata.shared_bias, @@ -674,7 +764,10 @@ def unified_attn(query: torch.tensor, key: torch.tensor, value: torch.tensor, ke w_uv=None, chunked_data=metadata.shared_bias_chunked, chunk_size=metadata.shared_chunk_size) + if use_online_merge: + acc_attn, acc_max, acc_sum = online_merge_step(acc_attn, acc_max, acc_sum, *shared) + # 3. Unique attention unique = partial_attn_unique(query=scaled_query, blocks=metadata.unique_blocks, block_mapping=metadata.unique_block_mapping, @@ -682,9 +775,19 @@ def unified_attn(query: torch.tensor, key: torch.tensor, value: torch.tensor, ke fmin=metadata.fmin, cache_utils=cache_utils, w_uv=None) - attn = merge(causal, shared, unique, feps=metadata.feps) - if attn is None: - return query + if use_online_merge: + acc_attn, acc_max, acc_sum = online_merge_step(acc_attn, acc_max, acc_sum, *unique) + + # Final normalization + if use_online_merge: + if acc_attn is None: + return query + acc_sum = torch.maximum(acc_sum, metadata.feps) + attn = acc_attn / acc_sum.unsqueeze(-1) + else: + attn = merge(causal, shared, unique, feps=metadata.feps) + if attn is None: + return query return attn @@ -695,7 +798,8 @@ def unified_mla(query: Optional[torch.tensor], scale: float, metadata: HPUUnifiedAttentionMetadata, w_uv: torch.tensor, - query_latent: Optional[torch.tensor] = None) -> torch.tensor: + query_latent: Optional[torch.tensor] = None, + use_online_merge: bool = False) -> torch.tensor: """Main entry point for Unified MLA Args: @@ -709,6 +813,9 @@ def unified_mla(query: Optional[torch.tensor], w_uv: Projection matrix from latent to full V [num_heads, latent_dim, v_head_dim] query_latent: Query tensor for cached path (in latent space) [tokens, num_heads, latent_dim + rope_dim] None if only causal attention is needed. + use_online_merge: If True, use online (incremental) merge algorithm. + Merges after each partial attention to avoid large intermediate buffers. + If False, use offline (single-pass) merge algorithm. 
Returns: Attention output [tokens, num_heads * v_head_dim] @@ -727,6 +834,9 @@ def unified_mla(query: Optional[torch.tensor], # MLA: latent cache has no head dimension, value_cache is None (stored in same cache) cache_utils = CacheUtils(latent_cache, value_cache=None, block_size=metadata.block_size, is_mla=True) + if use_online_merge: + # Online merge: compute and merge incrementally to avoid large intermediate buffers + acc_attn, acc_max, acc_sum = None, None, None # Causal: compute-friendly path (expand K/V from latent) # key and value already expanded by caller @@ -740,6 +850,8 @@ def unified_mla(query: Optional[torch.tensor], inputL_hpu_tensors=metadata.inputL_hpu_tensors, inputM_hpu_tensors=metadata.inputM_hpu_tensors, w_uv=w_uv) if scaled_query_causal is not None else (None, None, None) + if use_online_merge: + acc_attn, acc_max, acc_sum = online_merge_step(acc_attn, acc_max, acc_sum, *causal) # Shared/Unique: memory-friendly path (Q in latent space, fetch cached latent KV) # query_latent is already transformed to latent space by caller @@ -758,6 +870,8 @@ def unified_mla(query: Optional[torch.tensor], w_uv=w_uv, chunked_data=metadata.shared_bias_chunked, chunk_size=metadata.shared_chunk_size) + if use_online_merge: + acc_attn, acc_max, acc_sum = online_merge_step(acc_attn, acc_max, acc_sum, *shared) else: shared = (None, None, None) @@ -768,14 +882,24 @@ def unified_mla(query: Optional[torch.tensor], fmin=metadata.fmin, cache_utils=cache_utils, w_uv=w_uv) if scaled_query_latent is not None else (None, None, None) - - attn = merge(causal, shared, unique, feps=metadata.feps) - if attn is None: - # No attention computed, return original query - # Use whichever query was provided - # FIXME(kzawora): I'm not quite sure if that's correct, needs verification - if query is not None: - return query.flatten(-2, -1) # [tokens, num_heads * head_dim] - else: - return query_latent.flatten(-2, -1) # [tokens, num_heads * head_dim] + if use_online_merge: + acc_attn, acc_max, acc_sum = online_merge_step(acc_attn, acc_max, acc_sum, *unique) + if use_online_merge: + if acc_attn is None: + if query is not None: + return query.flatten(-2, -1) # [tokens, num_heads * head_dim] + else: + return query_latent.flatten(-2, -1) # [tokens, num_heads * head_dim] + acc_sum = torch.maximum(acc_sum, metadata.feps) + attn = acc_attn / acc_sum.unsqueeze(-1) + else: + attn = merge(causal, shared, unique, feps=metadata.feps) + if attn is None: + # No attention computed, return original query + # Use whichever query was provided + # FIXME(kzawora): I'm not quite sure if that's correct, needs verification + if query is not None: + return query.flatten(-2, -1) # [tokens, num_heads * head_dim] + else: + return query_latent.flatten(-2, -1) # [tokens, num_heads * head_dim] return attn From ef3c73fe51cf68e813950b1e6a91cb51947e084c Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Thu, 8 Jan 2026 17:20:15 +0200 Subject: [PATCH 7/9] Unified Attention - multi-step low-level profiling Signed-off-by: Konrad Zawora --- vllm_gaudi/v1/worker/hpu_model_runner.py | 53 +++++++++++++++++++++--- 1 file changed, 48 insertions(+), 5 deletions(-) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index b105525d2..0ebfba039 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -4450,6 +4450,22 @@ def _execute_dummy_scenario(self, requests, scheduled_tokens): self.sample_tokens(None) self.execute_model(cleanup, warmup_mode=True) + def 
_generate_unified_profiling(self, unified_cfgs): + steps = 3 + profiler = setup_profiler(warmup=steps - 1, active=1) + profiler.start() + for _ in range(steps): + for unified_cfg in unified_cfgs: + if unified_cfg not in self.bucketing_manager.unified_buckets: + self.bucketing_manager.unified_buckets.insert(0, unified_cfg) + torch.hpu.synchronize() + with torch.autograd.profiler.record_function(str(unified_cfg)): + self._prepare_dummy_unified_scenario(unified_cfg) + torch.hpu.synchronize() + profiler.step() + profiler.stop() + return profiler + def _generate_profiling(self, prompt_cfg, decode_cfg): steps = 3 profiler = setup_profiler(warmup=steps - 1, active=1) @@ -4595,6 +4611,30 @@ def warmup_multimodal_graphs(self, buckets): self.graphed_buckets.add(img_arg) self.log_warmup_multimodal(phase, idx, num_candidates, 1, 0, img_arg) + def _maybe_profile_unified_attn(self): + unified_cfg_str = os.environ.get('VLLM_PROFILE_UNIFIED', None) + if unified_cfg_str: + # NOTE(kzawora): VLLM_PROFILE_UNIFIED can pass either a single tuple + # or a list of tuples. Examples: + # VLLM_PROFILE_UNIFIED="(8,16,16,1)" or VLLM_PROFILE_UNIFIED=8,16,16,1 + # VLLM_PROFILE_UNIFIED="[(8,16,16,0), (4,8,8,1)]" + # If a list of tuples is passed, we profile each one sequentially. + # We're using ast.literal_eval to safely parse the string representation of the tuple/list + import ast + cfg = ast.literal_eval(unified_cfg_str) + cfg_list = [] + if isinstance(cfg, tuple): + # Single cfg passed as tuple, e.g. 512,32,128,1 or (512,32,128,1) + cfg_list = [cfg] + elif isinstance(cfg, list): + # Multiple cfgs passed as a list of tuples, e.g. [(512,32,128,0),(512,32,128,1)] + cfg_list = cfg + else: + raise AssertionError("VLLM_PROFILE_UNIFIED value must be a tuple or a list of tuples") + prof = self._generate_unified_profiling(cfg_list) + msg = f"Finished profiling. 
Key averages:\n{prof.key_averages()}" + raise AssertionError(msg) + @torch.inference_mode() def warmup_model(self) -> None: if not self.enable_bucketing: @@ -4640,11 +4680,14 @@ def warmup_model(self) -> None: ) self.defragmenter.initialize(self.kv_caches, self.block_size) - - prompt_profile_cfg, decode_profile_cfg = self._read_profiling_cfg() - if prompt_profile_cfg or decode_profile_cfg: - self._generate_profiling(prompt_profile_cfg, decode_profile_cfg) - raise AssertionError("Finished profiling") + # Profiling + if self.unified_attn: + self._maybe_profile_unified_attn() + else: + prompt_profile_cfg, decode_profile_cfg = self._read_profiling_cfg() + if prompt_profile_cfg or decode_profile_cfg: + self._generate_profiling(prompt_profile_cfg, decode_profile_cfg, None) + raise AssertionError("Finished profiling") kv_caches = self.kv_caches if not htorch.utils.internal.is_lazy() and not self.model_config.enforce_eager: From 527304e6fd9122b9f3fde67982f8480390e44c48 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 14 Jan 2026 15:33:48 +0200 Subject: [PATCH 8/9] cleanup flags + add graph splitting Signed-off-by: Konrad Zawora --- vllm_gaudi/attention/backends/hpu_attn.py | 2 +- vllm_gaudi/extension/features.py | 21 +++++++++-- vllm_gaudi/extension/unified.py | 46 ++++++++++++++--------- vllm_gaudi/extension/unified_batch.py | 10 +++-- 4 files changed, 53 insertions(+), 26 deletions(-) diff --git a/vllm_gaudi/attention/backends/hpu_attn.py b/vllm_gaudi/attention/backends/hpu_attn.py index f619be467..1093af548 100644 --- a/vllm_gaudi/attention/backends/hpu_attn.py +++ b/vllm_gaudi/attention/backends/hpu_attn.py @@ -1030,7 +1030,7 @@ def __init__( self.qk_head_dim = qk_head_dim self.v_head_dim = v_head_dim self.kv_b_proj = kv_b_proj # Used to expand latent → full KV in causal path - + self.use_online_merge = get_config().unified_attn_online_merge assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.latent_cache_k = VLLMKVCache() if not self.enable_fp8_attn \ diff --git a/vllm_gaudi/extension/features.py b/vllm_gaudi/extension/features.py index fab4184b7..df0f4098b 100644 --- a/vllm_gaudi/extension/features.py +++ b/vllm_gaudi/extension/features.py @@ -61,6 +61,21 @@ def get_experimental_flags(): return to_dict(flags) +def unified_attn_dev_flags(): + flags = [ + Value('unified_attn_dense_shared_bias', True), + Value('unified_attn_chunked_shared_attn', True), + Value('unified_attn_online_merge', True), + Value('unified_attn_shared_attn_chunk_size', 128), + Value('unified_attn_split_graphs', Enabled('unified_attn_online_merge')), + Value( + 'unified_attn_softmax_fa2', + All(VersionRange(">=1.24.0.279"), Enabled('unified_attn'), Kernel(softmax_fa2), Hardware('gaudi3'), + Not(Enabled('unified_attn_chunked_shared_attn')))), + ] + return flags + + def get_features(): supported_attn_impls = ['flex_impl', 'fsdpa_impl', 'naive_impl'] bucketing_strategies = ['exponential_bucketing', 'linear_bucketing'] @@ -90,13 +105,11 @@ def get_features(): Value('dynamic_shapes_compilation', True, env_var='VLLM_T_COMPILE_DYNAMIC_SHAPES', env_var_type=boolean), Value('fullgraph_compilation', False, env_var='VLLM_T_COMPILE_FULLGRAPH', env_var_type=boolean), Value('unified_attn', False), - Value('unified_attn_dense_shared_bias', True), - Value('unified_attn_softmax_fa2', - All(VersionRange(">=1.24.0.279"), Enabled('unified_attn'), Kernel(softmax_fa2), Hardware('gaudi3'))), + *unified_attn_dev_flags(), Value('scale_adjustment', True, 
env_var='VLLM_SCALE_ADJUSTMENT', env_var_type=boolean), Value('flatten_input', Any(ModelType('qwen3_moe'), ModelType('granitemoe'), ModelType('glm4_moe'))), Value('unified_attn_shared_cache_ratio', - 0.8, + 1, env_var='VLLM_UNIFIED_ATTENTION_SHARED_CACHE_RATIO', env_var_type=float), Value('high_level_profiler_enabled', False, env_var='VLLM_PROFILER_ENABLED', env_var_type=boolean), diff --git a/vllm_gaudi/extension/unified.py b/vllm_gaudi/extension/unified.py index a8d8218d8..7f2f67abe 100644 --- a/vllm_gaudi/extension/unified.py +++ b/vllm_gaudi/extension/unified.py @@ -349,6 +349,7 @@ class SharedBlockChunkedBiasData: block_usages: torch.tensor # Dense: [num_query_tokens, num_shared_blocks] num_query_tokens: int # Total query length (padded) num_shared_blocks: int # Total number of shared blocks (padded) + split_chunked_graphs: bool def _partial_attn_shared_core(query: torch.tensor, @@ -545,8 +546,10 @@ def _partial_attn_shared_chunked( accumulated_attn = None global_max = None global_sum = None - + split_graphs = chunked_data.split_chunked_graphs for chunk_idx in range(num_chunks): + if split_graphs: + htorch.core.mark_step() chunk_start = chunk_idx * chunk_size chunk_end = min(chunk_start + chunk_size, num_blocks) actual_chunk_len = chunk_end - chunk_start @@ -611,6 +614,9 @@ def _partial_attn_shared_chunked( global_sum = global_sum * old_scale + chunk_sum * new_scale global_max = new_max + if split_graphs: + htorch.core.mark_step() + if accumulated_attn is None: return (None, None, None) @@ -695,6 +701,8 @@ class HPUUnifiedAttentionMetadata: feps: torch.tensor inputL_hpu_tensors: Optional[Dict[tuple, torch.Tensor]] inputM_hpu_tensors: Optional[Dict[tuple, torch.Tensor]] + online_merge: bool + split_graphs: bool def seq_len(self): # TODO: This needs to be changed in case of mixed batches @@ -716,29 +724,23 @@ def is_prompt(self): return self.causal_bias is not None -def unified_attn(query: torch.tensor, - key: torch.tensor, - value: torch.tensor, - key_cache: torch.tensor, - value_cache: torch.tensor, - scale: float, - metadata: HPUUnifiedAttentionMetadata, - use_online_merge: bool = True) -> torch.tensor: - """Main entry point for unified attention - - Args: - use_online_merge: If True, use online (incremental) merge algorithm. - Merges after each partial attention to avoid large intermediate buffers. - If False, use offline (single-pass) merge algorithm. - """ +def unified_attn(query: torch.tensor, key: torch.tensor, value: torch.tensor, key_cache: torch.tensor, + value_cache: torch.tensor, scale: float, metadata: HPUUnifiedAttentionMetadata) -> torch.tensor: + """Main entry point for unified attention""" scaled_query = query * scale cache_utils = CacheUtils(key_cache, value_cache, metadata.block_size) + use_online_merge = metadata.online_merge + split_graphs = metadata.split_graphs + if use_online_merge: # Online merge: compute and merge incrementally to avoid large intermediate buffers acc_attn, acc_max, acc_sum = None, None, None + if split_graphs: + htorch.core.mark_step() + # 1. Causal attention causal = partial_attn_causal(query=scaled_query, key=key, @@ -752,6 +754,9 @@ def unified_attn(query: torch.tensor, if use_online_merge: acc_attn, acc_max, acc_sum = online_merge_step(acc_attn, acc_max, acc_sum, *causal) + if split_graphs: + htorch.core.mark_step() + # 2. 
Shared attention shared = partial_attn_shared(query=scaled_query, blocks=metadata.shared_blocks, @@ -767,6 +772,9 @@ def unified_attn(query: torch.tensor, if use_online_merge: acc_attn, acc_max, acc_sum = online_merge_step(acc_attn, acc_max, acc_sum, *shared) + if split_graphs: + htorch.core.mark_step() + # 3. Unique attention unique = partial_attn_unique(query=scaled_query, blocks=metadata.unique_blocks, @@ -778,6 +786,9 @@ def unified_attn(query: torch.tensor, if use_online_merge: acc_attn, acc_max, acc_sum = online_merge_step(acc_attn, acc_max, acc_sum, *unique) + if split_graphs: + htorch.core.mark_step() + # Final normalization if use_online_merge: if acc_attn is None: @@ -798,8 +809,7 @@ def unified_mla(query: Optional[torch.tensor], scale: float, metadata: HPUUnifiedAttentionMetadata, w_uv: torch.tensor, - query_latent: Optional[torch.tensor] = None, - use_online_merge: bool = False) -> torch.tensor: + query_latent: Optional[torch.tensor] = None) -> torch.tensor: """Main entry point for Unified MLA Args: diff --git a/vllm_gaudi/extension/unified_batch.py b/vllm_gaudi/extension/unified_batch.py index d102362e1..e1b8ade5a 100644 --- a/vllm_gaudi/extension/unified_batch.py +++ b/vllm_gaudi/extension/unified_batch.py @@ -599,6 +599,7 @@ def _prepare_shared_bias_hpu( block_usages=hpu_block_usages_dense, num_query_tokens=target_qlen, num_shared_blocks=target_shared_blocks, + split_chunked_graphs=get_config().unified_attn_split_graphs, ) return @@ -826,9 +827,10 @@ def first_dim(t: Optional[np.ndarray]) -> int: # With chunked dense generation, we only allocate (target_qlen, target_shared_blocks) for block_usages # instead of the full (target_qlen, target_shared_blocks, block_size) bias tensor. # Bias is generated per chunk: (target_qlen, chunk_size, block_size) - default_chunk_size = 64 # Process up to 64 blocks at a time for shared attention - use_chunked_processing = bool(target_shared_blocks - > default_chunk_size) # Chunked dense processing - generates bias per chunk + default_chunk_size = get_config( + ).unified_attn_shared_attn_chunk_size # Process up to 64 blocks at a time for shared attention + use_chunked_processing = get_config().unified_attn_chunked_shared_attn and bool( + target_shared_blocks > default_chunk_size) # Chunked dense processing - generates bias per chunk # Pad target_shared_blocks to be a multiple of chunk_size for chunked processing # This ensures all chunks have exactly chunk_size blocks (static shapes in the kernel) @@ -862,6 +864,8 @@ def first_dim(t: Optional[np.ndarray]) -> int: feps=to_hpu(feps, dtype=dtype), inputL_hpu_tensors=dict(), inputM_hpu_tensors=dict(), + split_graphs=get_config().unified_attn_split_graphs, + online_merge=get_config().unified_attn_online_merge, ) if hpu_bias_acceleration: From 8a5eac06286df0ae9c553c6f45841ef4ee807c75 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 14 Jan 2026 16:06:55 +0200 Subject: [PATCH 9/9] reduce chunk size, fix mla Signed-off-by: Konrad Zawora --- vllm_gaudi/extension/features.py | 2 +- vllm_gaudi/extension/unified.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/vllm_gaudi/extension/features.py b/vllm_gaudi/extension/features.py index df0f4098b..3e2e71ebb 100644 --- a/vllm_gaudi/extension/features.py +++ b/vllm_gaudi/extension/features.py @@ -66,7 +66,7 @@ def unified_attn_dev_flags(): Value('unified_attn_dense_shared_bias', True), Value('unified_attn_chunked_shared_attn', True), Value('unified_attn_online_merge', True), - 
Value('unified_attn_shared_attn_chunk_size', 128), + Value('unified_attn_shared_attn_chunk_size', 64), Value('unified_attn_split_graphs', Enabled('unified_attn_online_merge')), Value( 'unified_attn_softmax_fa2', diff --git a/vllm_gaudi/extension/unified.py b/vllm_gaudi/extension/unified.py index 7f2f67abe..d7525ce88 100644 --- a/vllm_gaudi/extension/unified.py +++ b/vllm_gaudi/extension/unified.py @@ -844,10 +844,16 @@ def unified_mla(query: Optional[torch.tensor], # MLA: latent cache has no head dimension, value_cache is None (stored in same cache) cache_utils = CacheUtils(latent_cache, value_cache=None, block_size=metadata.block_size, is_mla=True) + use_online_merge = metadata.online_merge + split_graphs = metadata.split_graphs + if use_online_merge: # Online merge: compute and merge incrementally to avoid large intermediate buffers acc_attn, acc_max, acc_sum = None, None, None + if split_graphs: + htorch.core.mark_step() + # Causal: compute-friendly path (expand K/V from latent) # key and value already expanded by caller # w_uv projection applied by unified function @@ -863,6 +869,9 @@ def unified_mla(query: Optional[torch.tensor], if use_online_merge: acc_attn, acc_max, acc_sum = online_merge_step(acc_attn, acc_max, acc_sum, *causal) + if split_graphs: + htorch.core.mark_step() + # Shared/Unique: memory-friendly path (Q in latent space, fetch cached latent KV) # query_latent is already transformed to latent space by caller # For these paths, we need to expand K/V from cached latent vectors @@ -885,6 +894,9 @@ def unified_mla(query: Optional[torch.tensor], else: shared = (None, None, None) + if split_graphs: + htorch.core.mark_step() + unique = partial_attn_unique(query=scaled_query_latent, blocks=metadata.unique_blocks, block_mapping=metadata.unique_block_mapping, @@ -894,6 +906,10 @@ def unified_mla(query: Optional[torch.tensor], w_uv=w_uv) if scaled_query_latent is not None else (None, None, None) if use_online_merge: acc_attn, acc_max, acc_sum = online_merge_step(acc_attn, acc_max, acc_sum, *unique) + + if split_graphs: + htorch.core.mark_step() + if use_online_merge: if acc_attn is None: if query is not None:
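
For reference, the online (flash-attention style) merge that patches 6-9 build on can be
exercised in isolation. The sketch below is illustrative only and is not part of the patches:
the helper names (partial_attn, merge_step) are hypothetical, it uses a single head with no
scaling or masking, and it runs on plain CPU torch rather than HPU. It mirrors the rescaling
done by online_merge_step() and the final normalization in unified_attn(): merging two partial
(attn, max, sum) triples and normalizing at the end matches a single softmax over the
concatenated key/value segments.

    import torch

    def partial_attn(q, k, v):
        # Unnormalized weighted values plus running max/sum -- the (attn, max, sum)
        # triple that the unified kernels hand to the merge step.
        scores = q @ k.transpose(-1, -2)            # [tokens, kv_len]
        m = scores.max(dim=-1).values               # [tokens]
        p = torch.exp(scores - m.unsqueeze(-1))     # [tokens, kv_len]
        return p @ v, m, p.sum(dim=-1)

    def merge_step(acc, new):
        # Rescale both accumulators to a shared max before adding, as in online_merge_step().
        (a_attn, a_max, a_sum), (n_attn, n_max, n_sum) = acc, new
        m = torch.maximum(a_max, n_max)
        a_scale, n_scale = torch.exp(a_max - m), torch.exp(n_max - m)
        attn = a_attn * a_scale.unsqueeze(-1) + n_attn * n_scale.unsqueeze(-1)
        return attn, m, a_sum * a_scale + n_sum * n_scale

    torch.manual_seed(0)
    q = torch.randn(4, 8)                           # [tokens, head_dim], single head for brevity
    k1, v1 = torch.randn(16, 8), torch.randn(16, 8)
    k2, v2 = torch.randn(16, 8), torch.randn(16, 8)

    acc = merge_step(partial_attn(q, k1, v1), partial_attn(q, k2, v2))
    online = acc[0] / acc[2].unsqueeze(-1)          # final normalization, as in unified_attn()

    reference = torch.softmax(q @ torch.cat([k1, k2]).T, dim=-1) @ torch.cat([v1, v2])
    assert torch.allclose(online, reference, atol=1e-5)

Because each partial result is folded into the accumulator as soon as it is produced, the
causal, shared and unique outputs never have to coexist as full unnormalized tensors, which is
what lets patch 8 place htorch.core.mark_step() boundaries between the partial-attention graphs
without keeping large intermediates alive across them.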