diff --git a/vllm_gaudi/attention/backends/hpu_attn.py b/vllm_gaudi/attention/backends/hpu_attn.py index 1aa7faab4..ddaaf42ad 100644 --- a/vllm_gaudi/attention/backends/hpu_attn.py +++ b/vllm_gaudi/attention/backends/hpu_attn.py @@ -1060,7 +1060,7 @@ def __init__( self.qk_head_dim = qk_head_dim self.v_head_dim = v_head_dim self.kv_b_proj = kv_b_proj # Used to expand latent → full KV in causal path - + self.use_online_merge = get_config().unified_attn_online_merge assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.latent_cache_k = VLLMKVCache() if not self.enable_fp8_attn \ diff --git a/vllm_gaudi/extension/features.py b/vllm_gaudi/extension/features.py index 004ffff5f..5f0480c2c 100644 --- a/vllm_gaudi/extension/features.py +++ b/vllm_gaudi/extension/features.py @@ -61,6 +61,21 @@ def get_experimental_flags(): return to_dict(flags) +def unified_attn_dev_flags(): + flags = [ + Value('unified_attn_dense_shared_bias', True), + Value('unified_attn_chunked_shared_attn', True), + Value('unified_attn_online_merge', True), + Value('unified_attn_shared_attn_chunk_size', 64), + Value('unified_attn_split_graphs', Enabled('unified_attn_online_merge')), + Value( + 'unified_attn_softmax_fa2', + All(VersionRange(">=1.24.0.279"), Enabled('unified_attn'), Kernel(softmax_fa2), Hardware('gaudi3'), + Not(Enabled('unified_attn_chunked_shared_attn')))), + ] + return flags + + def get_features(): supported_attn_impls = ['flex_impl', 'fsdpa_impl', 'naive_impl'] bucketing_strategies = ['exponential_bucketing', 'linear_bucketing'] @@ -90,12 +105,11 @@ def get_features(): Value('dynamic_shapes_compilation', True, env_var='VLLM_T_COMPILE_DYNAMIC_SHAPES', env_var_type=boolean), Value('fullgraph_compilation', False, env_var='VLLM_T_COMPILE_FULLGRAPH', env_var_type=boolean), Value('unified_attn', False), - Value('unified_attn_softmax_fa2', - All(VersionRange(">=1.24.0.279"), Enabled('unified_attn'), Kernel(softmax_fa2), Hardware('gaudi3'))), + *unified_attn_dev_flags(), Value('scale_adjustment', True, env_var='VLLM_SCALE_ADJUSTMENT', env_var_type=boolean), Value('flatten_input', Any(ModelType('qwen3_moe'), ModelType('granitemoe'), ModelType('glm4_moe'))), Value('unified_attn_shared_cache_ratio', - 0.8, + 1, env_var='VLLM_UNIFIED_ATTENTION_SHARED_CACHE_RATIO', env_var_type=float), Value('high_level_profiler_enabled', False, env_var='VLLM_PROFILER_ENABLED', env_var_type=boolean), diff --git a/vllm_gaudi/extension/unified.py b/vllm_gaudi/extension/unified.py index ee3e4a74a..d7525ce88 100644 --- a/vllm_gaudi/extension/unified.py +++ b/vllm_gaudi/extension/unified.py @@ -149,6 +149,9 @@ def create_softmax_fa2_input_tensors( retained_shape[-1] = get_last_dim_size(retained_shape[-1], vec_size, pack_size) t_retained_shape = tuple(retained_shape) + # Convert fmin to scalar once + fmin_val = fmin.item() if isinstance(fmin, torch.Tensor) else fmin + if t_retained_shape not in inputM_hpu_tensors: print("Allocating new input tensors for shape:", t_retained_shape, "for attn shape:", attn.shape) return torch.full(retained_shape, fmin, dtype=attn.dtype, device='hpu'), torch.zeros(retained_shape, @@ -173,6 +176,77 @@ def convert_cl_aligned_tensor(input_hpu, reference_size) -> torch.tensor: return input_hpu +def online_merge_step( + acc_attn: Optional[torch.tensor], + acc_max: Optional[torch.tensor], + acc_sum: Optional[torch.tensor], + new_attn: Optional[torch.tensor], + new_max: Optional[torch.tensor], + new_sum: Optional[torch.tensor], +) -> tuple[Optional[torch.tensor], 
Optional[torch.tensor], Optional[torch.tensor]]: + """Incrementally merge attention results using flash-attention style rescaling. + + This implements the online softmax algorithm where we maintain running + unnormalized weighted values, max, and sum. The final normalization + (dividing by sum) is done at the end. + + Args: + acc_attn: Accumulated unnormalized weighted V [tokens, heads, head_dim] or None + acc_max: Accumulated max values [tokens, heads] or None + acc_sum: Accumulated sum of exp values [tokens, heads] or None + new_attn: New unnormalized weighted V to merge + new_max: New max values to merge + new_sum: New sum of exp values to merge + + Returns: + Tuple of (merged_attn, merged_max, merged_sum) + """ + if new_attn is None: + return acc_attn, acc_max, acc_sum + if acc_attn is None: + return new_attn, new_max, new_sum + + # Flash-attention style merge + merged_max = torch.maximum(acc_max, new_max) + old_scale = torch.exp(acc_max - merged_max) + new_scale = torch.exp(new_max - merged_max) + + merged_attn = acc_attn * old_scale.unsqueeze(-1) + new_attn * new_scale.unsqueeze(-1) + merged_sum = acc_sum * old_scale + new_sum * new_scale + + return merged_attn, merged_max, merged_sum + + +def online_merge(*attn_results: tuple[torch.tensor, torch.tensor, torch.tensor], + feps: torch.tensor) -> Optional[torch.tensor]: + """Merge partial attention values using online (incremental) algorithm. + + Alternative to merge() that uses online_merge_step for incremental merging. + This approach is more memory efficient as it doesn't need to keep all + intermediate results simultaneously. + + Args: + attn_results: Variable number of (attn, max, sum) tuples + feps: Small epsilon for numerical stability + + Returns: + Final normalized attention output, or None if all inputs are None + """ + acc_attn = None + acc_max = None + acc_sum = None + + for attn, max_val, sum_val in attn_results: + acc_attn, acc_max, acc_sum = online_merge_step(acc_attn, acc_max, acc_sum, attn, max_val, sum_val) + + if acc_attn is None: + return None + + # Final normalization + acc_sum = torch.maximum(acc_sum, feps) + return acc_attn / acc_sum.unsqueeze(-1) + + def merge(*attn_results: torch.tensor, feps: torch.tensor) -> torch.tensor: """Merge partial attention values into final attn score""" all_attn, all_max, all_sum = zip(*attn_results) @@ -261,6 +335,82 @@ def combine(slices): return combine(attn_slices), combine(max_slices), combine(sum_slices) +@dataclass +class SharedBlockChunkedBiasData: + """Data needed to compute shared block bias per-chunk during chunked attention. + + This avoids materializing the full [query_len, num_shared_blocks, block_size] + bias tensor which can be prohibitively large with many shared blocks. + + Contains dense block_usages of shape (num_query_tokens, num_shared_blocks). + During chunked attention, we slice block_usages[:, chunk_start:chunk_end] and + generate bias for each chunk on-the-fly. 
+ """ + block_usages: torch.tensor # Dense: [num_query_tokens, num_shared_blocks] + num_query_tokens: int # Total query length (padded) + num_shared_blocks: int # Total number of shared blocks (padded) + split_chunked_graphs: bool + + +def _partial_attn_shared_core(query: torch.tensor, + key: torch.tensor, + value: torch.tensor, + bias: torch.tensor, + fmin: torch.tensor, + inputL_hpu_tensors: Dict[tuple, torch.Tensor], + inputM_hpu_tensors: Dict[tuple, torch.Tensor], + kv_heads: int, + is_mla: bool, + w_uv: Optional[torch.tensor] = None) -> tuple[torch.tensor, torch.tensor, torch.tensor]: + """Core shared attention computation. + + This is the inner loop extracted for reuse between full and chunked paths. + + Args: + query: Query tensor, already transposed [kv_heads, q_heads_per_kv, tokens, head_dim] or similar + key: Key tensor from cache [kv_heads, q_heads_per_kv, kv_len, head_dim] + value: Value tensor from cache + bias: Attention bias [1, kv_len] (already flattened from [num_blocks, block_size]) + fmin: Minimum float for softmax stability + inputL_hpu_tensors: Cache for FA2 tensors + inputM_hpu_tensors: Cache for FA2 tensors + kv_heads: Number of KV heads + is_mla: Whether using MLA attention + w_uv: Optional MLA projection matrix + + Returns: + Tuple of (unnormalized_weighted_V, local_max, local_sum) + """ + num_heads = query.size(0) * query.size(1) if not is_mla else query.size(0) + + attn = torch.matmul(query, key.transpose(-1, -2)) + attn = attn.flatten(0, 1) + attn = attn + bias + + # TODO: remove dtype check once full support is added for fp8 in unified attention + if get_config().unified_attn_softmax_fa2 and attn.dtype == torch.bfloat16: + inputM_hpu, inputL_hpu = create_softmax_fa2_input_tensors(attn, fmin, inputL_hpu_tensors, inputM_hpu_tensors) + attn, local_max, local_sum, _exp_max_fixup_hpu = torch.ops.hpu.softmax_fa2(attn, + inputM=inputM_hpu, + inputL=inputL_hpu) + local_max = convert_cl_aligned_tensor(local_max, list(attn.shape[:-1])) + local_sum = convert_cl_aligned_tensor(local_sum, list(attn.shape[:-1])) + else: + local_max = torch.maximum(attn.amax(-1), fmin) + attn = torch.exp(attn - local_max.unsqueeze(-1)) + local_sum = attn.sum(-1) + + attn = torch.matmul(attn.unflatten(0, (kv_heads if not is_mla else num_heads, -1)), value).flatten(0, 1) + + # MLA: Extract latent part and project to full V + if is_mla and w_uv is not None: + latent_dim = w_uv.size(1) + attn_latent = attn[..., :latent_dim] + attn = torch.bmm(attn_latent, w_uv) + + return attn.transpose(0, 1), local_max.transpose(0, 1), local_sum.transpose(0, 1) + + def partial_attn_shared(query: torch.tensor, blocks: torch.tensor, bias: Optional[torch.tensor], @@ -268,59 +418,209 @@ def partial_attn_shared(query: torch.tensor, inputL_hpu_tensors: Dict[tuple, torch.Tensor], inputM_hpu_tensors: Dict[tuple, torch.Tensor], cache_utils: CacheUtils, - w_uv: Optional[torch.tensor] = None) -> tuple[torch.tensor, torch.tensor, torch.tensor]: - """Partial attention where all shared blocks are compared with whole query + dtype: torch.dtype, + w_uv: Optional[torch.tensor] = None, + chunked_data: Optional[SharedBlockChunkedBiasData] = None, + chunk_size: int = 0) -> tuple[torch.tensor, torch.tensor, torch.tensor]: + """Partial attention where all shared blocks are compared with whole query. + + Supports two modes: + 1. Full bias mode (default): bias tensor is provided, process all blocks at once + 2. 
Chunked mode: chunk_size > 0, process blocks in chunks + - If bias is provided, slice from it + - If bias is None but chunked_data is provided, generate bias per chunk from dense block_usages Args: - w_uv: Optional MLA projection matrix [num_heads, latent_dim, v_head_dim]. - If provided, assumes MLA mode where query/key/value are in latent space. + query: Query tensor [tokens, num_heads, head_dim] + blocks: Shared block indices [num_shared_blocks] + bias: Pre-computed bias tensor [query_len, num_blocks, block_size]. Can be None for chunked generation. + fmin: Minimum float value for softmax stability + inputL_hpu_tensors: Cache for softmax input tensors + inputM_hpu_tensors: Cache for softmax input tensors + cache_utils: Cache utilities for fetching KV + dtype: Output dtype for bias generation + w_uv: Optional MLA projection matrix [num_heads, latent_dim, v_head_dim] + chunked_data: Metadata for chunked processing (contains dense block_usages) + chunk_size: Number of blocks per chunk (0 = full mode, >0 = chunked mode) + + Returns: + Tuple of (unnormalized_weighted_V, local_max, local_sum) """ - if bias is None: - return (None, None, None) - + # Determine mode + use_chunked = chunk_size > 0 and chunked_data is not None + + if not use_chunked: + # Full bias mode - original implementation + if bias is None: + return (None, None, None) + return _partial_attn_shared_full(query, blocks, bias, fmin, inputL_hpu_tensors, inputM_hpu_tensors, cache_utils, + w_uv) + else: + # Chunked mode - process blocks in chunks + # bias can be None for chunked generation (will generate from chunked_data.block_usages per chunk) + if blocks is None: + return (None, None, None) + return _partial_attn_shared_chunked(query, blocks, bias, chunked_data, chunk_size, fmin, inputL_hpu_tensors, + inputM_hpu_tensors, cache_utils, dtype, w_uv) + + +def _partial_attn_shared_full(query: torch.tensor, + blocks: torch.tensor, + bias: torch.tensor, + fmin: torch.tensor, + inputL_hpu_tensors: Dict[tuple, torch.Tensor], + inputM_hpu_tensors: Dict[tuple, torch.Tensor], + cache_utils: CacheUtils, + w_uv: Optional[torch.tensor] = None) -> tuple[torch.tensor, torch.tensor, torch.tensor]: + """Full bias implementation of partial_attn_shared.""" is_mla = w_uv is not None if is_mla: # MLA: Single latent cache contains both K and V latent_kv = cache_utils.fetch_shared(blocks) num_heads = query.size(1) - query = query.transpose(0, 1).unsqueeze(1) # [num_heads, 1, tokens, latent_dim + rope_dim] + query_t = query.transpose(0, 1).unsqueeze(1) # [num_heads, 1, tokens, latent_dim + rope_dim] key = latent_kv.unsqueeze(0).unsqueeze(0).expand(num_heads, 1, -1, -1) value = latent_kv.unsqueeze(0).unsqueeze(0).expand(num_heads, 1, -1, -1) kv_heads = 1 else: # Standard attention: Separate K and V caches kv_heads = cache_utils.kv_heads - query = query.transpose(0, 1).unflatten(0, (kv_heads, -1)) + query_t = query.transpose(0, 1).unflatten(0, (kv_heads, -1)) key, value = cache_utils.fetch_shared(blocks) - bias = bias.flatten(-2, -1).unsqueeze(0) + bias_flat = bias.flatten(-2, -1).unsqueeze(0) + + return _partial_attn_shared_core(query_t, key, value, bias_flat, fmin, inputL_hpu_tensors, inputM_hpu_tensors, + kv_heads, is_mla, w_uv) + + +def _partial_attn_shared_chunked( + query: torch.tensor, + blocks: torch.tensor, + bias: Optional[torch.tensor], + chunked_data: SharedBlockChunkedBiasData, + chunk_size: int, + fmin: torch.tensor, + inputL_hpu_tensors: Dict[tuple, torch.Tensor], + inputM_hpu_tensors: Dict[tuple, torch.Tensor], + cache_utils: CacheUtils, + 
dtype: torch.dtype, + w_uv: Optional[torch.tensor] = None) -> tuple[torch.tensor, torch.tensor, torch.tensor]: + """Chunked implementation of partial_attn_shared with per-chunk bias generation. + + Generates bias per chunk from dense block_usages to save memory. + Avoids materializing the full (query_len, num_blocks, block_size) bias tensor. + + Strategy: + 1. Process blocks in chunks of chunk_size + 2. For each chunk, slice block_usages and generate chunk bias on-the-fly + 3. Compute attention for the chunk using _partial_attn_shared_core + 4. Merge chunk results using flash-attention style online softmax + """ + num_blocks = chunked_data.num_shared_blocks + block_size = cache_utils.block_size + num_query_tokens = chunked_data.num_query_tokens - attn = torch.matmul(query, key.transpose(-1, -2)) - attn = attn.flatten(0, 1) - attn = attn + bias - # TODO: remove dtype check once full support is added for fp8 in unified attention - if get_config().unified_attn_softmax_fa2 and attn.dtype == torch.bfloat16: - inputM_hpu, inputL_hpu = create_softmax_fa2_input_tensors(attn, fmin, inputL_hpu_tensors, inputM_hpu_tensors) - attn, local_max, local_sum, _exp_max_fixup_hpu = torch.ops.hpu.softmax_fa2(attn, - inputM=inputM_hpu, - inputL=inputL_hpu) - local_max = convert_cl_aligned_tensor(local_max, list(attn.shape[:-1])) - local_sum = convert_cl_aligned_tensor(local_sum, list(attn.shape[:-1])) - else: - local_max = torch.maximum(attn.amax(-1), fmin) - attn = torch.exp(attn - local_max.unsqueeze(-1)) - local_sum = attn.sum(-1) + is_mla = w_uv is not None + kv_heads = 1 if is_mla else cache_utils.kv_heads + + # Calculate number of chunks + num_chunks = math.ceil(num_blocks / chunk_size) + + # Check if we have pre-computed bias or need to generate per-chunk + generate_bias_per_chunk = (bias is None) + + # Pre-allocate reusable tensors outside the loop (avoid allocations per iteration) + if generate_bias_per_chunk: + block_len_range = torch.arange(1, + block_size + 1, + dtype=chunked_data.block_usages.dtype, + device=chunked_data.block_usages.device) + # Pre-allocate chunk_bias buffer - will be overwritten each iteration + chunk_bias_buffer = torch.empty((num_query_tokens, chunk_size, block_size), + dtype=dtype, + device=chunked_data.block_usages.device) + + # Accumulators for online softmax-style merging + accumulated_attn = None + global_max = None + global_sum = None + split_graphs = chunked_data.split_chunked_graphs + for chunk_idx in range(num_chunks): + if split_graphs: + htorch.core.mark_step() + chunk_start = chunk_idx * chunk_size + chunk_end = min(chunk_start + chunk_size, num_blocks) + actual_chunk_len = chunk_end - chunk_start + + # Slice blocks for this chunk + chunk_blocks = blocks[chunk_start:chunk_end] + + if generate_bias_per_chunk: + # Generate bias for this chunk from dense block_usages + # chunked_data.block_usages is (num_query_tokens, num_shared_blocks) + # Slice to get (num_query_tokens, actual_chunk_len) for this chunk + chunk_block_usages = chunked_data.block_usages[:, chunk_start:chunk_end] + + # Generate chunk bias using dense broadcast into pre-allocated buffer + # chunk_block_usages.unsqueeze(-1): (num_query_tokens, actual_chunk_len, 1) + # broadcast comparison: (num_query_tokens, actual_chunk_len, block_size) + chunk_mask = block_len_range > chunk_block_usages.unsqueeze(-1) + + # Use view of pre-allocated buffer for actual chunk size + chunk_bias = chunk_bias_buffer[:, :actual_chunk_len, :] + chunk_bias.zero_() + chunk_bias.masked_fill_(chunk_mask, -math.inf) + else: + # 
Pre-computed: slice from full bias tensor + chunk_bias = bias[:, chunk_start:chunk_end, :] - attn = torch.matmul(attn.unflatten(0, (kv_heads if not is_mla else num_heads, -1)), value).flatten(0, 1) + # Fetch KV for this chunk + if is_mla: + latent_kv = cache_utils.fetch_shared(chunk_blocks) + num_heads = query.size(1) + query_t = query.transpose(0, 1).unsqueeze(1) + key = latent_kv.unsqueeze(0).unsqueeze(0).expand(num_heads, 1, -1, -1) + value = latent_kv.unsqueeze(0).unsqueeze(0).expand(num_heads, 1, -1, -1) + else: + query_t = query.transpose(0, 1).unflatten(0, (kv_heads, -1)) + key, value = cache_utils.fetch_shared(chunk_blocks) + + # Flatten bias for attention: [1, query_len, chunk_len * block_size] + chunk_bias_flat = chunk_bias.flatten(-2, -1).unsqueeze(0) + + # Compute attention for this chunk + chunk_attn, chunk_max, chunk_sum = _partial_attn_shared_core(query_t, key, value, chunk_bias_flat, fmin, + inputL_hpu_tensors, inputM_hpu_tensors, kv_heads, + is_mla, w_uv) + + # Online merge: combine this chunk with accumulated results + if accumulated_attn is None: + # First chunk - just store + accumulated_attn = chunk_attn + global_max = chunk_max + global_sum = chunk_sum + else: + # Merge with existing - use flash-attention style rescaling + new_max = torch.maximum(global_max, chunk_max) - # MLA: Extract latent part and project to full V - if is_mla: - latent_dim = w_uv.size(1) - attn_latent = attn[..., :latent_dim] # Extract only latent dimension (exclude rope_dim) - attn = torch.bmm(attn_latent, w_uv) # [num_heads, tokens, v_head_dim] + # Rescale factors + old_scale = torch.exp(global_max - new_max) + new_scale = torch.exp(chunk_max - new_max) - return attn.transpose(0, 1), local_max.transpose(0, 1), local_sum.transpose(0, 1) + # Rescale accumulated values and sums + accumulated_attn = accumulated_attn * old_scale.unsqueeze(-1) + chunk_attn * new_scale.unsqueeze(-1) + global_sum = global_sum * old_scale + chunk_sum * new_scale + global_max = new_max + + if split_graphs: + htorch.core.mark_step() + + if accumulated_attn is None: + return (None, None, None) + + return accumulated_attn, global_max, global_sum def partial_attn_unique(query: torch.tensor, @@ -391,6 +691,9 @@ class HPUUnifiedAttentionMetadata: causal_width: int shared_blocks: Optional[torch.tensor] shared_bias: Optional[torch.tensor] + # Chunked bias data for chunk-wise computation (used when shared_bias is None but shared_blocks exists) + shared_bias_chunked: Optional[SharedBlockChunkedBiasData] + shared_chunk_size: int # Number of blocks to process per chunk (0 = use full bias) unique_blocks: Optional[torch.tensor] | Optional[int] unique_block_mapping: Optional[torch.tensor] unique_bias: Optional[torch.tensor] @@ -398,6 +701,8 @@ class HPUUnifiedAttentionMetadata: feps: torch.tensor inputL_hpu_tensors: Optional[Dict[tuple, torch.Tensor]] inputM_hpu_tensors: Optional[Dict[tuple, torch.Tensor]] + online_merge: bool + split_graphs: bool def seq_len(self): # TODO: This needs to be changed in case of mixed batches @@ -426,6 +731,17 @@ def unified_attn(query: torch.tensor, key: torch.tensor, value: torch.tensor, ke scaled_query = query * scale cache_utils = CacheUtils(key_cache, value_cache, metadata.block_size) + use_online_merge = metadata.online_merge + split_graphs = metadata.split_graphs + + if use_online_merge: + # Online merge: compute and merge incrementally to avoid large intermediate buffers + acc_attn, acc_max, acc_sum = None, None, None + + if split_graphs: + htorch.core.mark_step() + + # 1. 
Causal attention causal = partial_attn_causal(query=scaled_query, key=key, value=value, @@ -435,6 +751,13 @@ def unified_attn(query: torch.tensor, key: torch.tensor, value: torch.tensor, ke inputL_hpu_tensors=metadata.inputL_hpu_tensors, inputM_hpu_tensors=metadata.inputM_hpu_tensors, w_uv=None) + if use_online_merge: + acc_attn, acc_max, acc_sum = online_merge_step(acc_attn, acc_max, acc_sum, *causal) + + if split_graphs: + htorch.core.mark_step() + + # 2. Shared attention shared = partial_attn_shared(query=scaled_query, blocks=metadata.shared_blocks, bias=metadata.shared_bias, @@ -442,7 +765,17 @@ def unified_attn(query: torch.tensor, key: torch.tensor, value: torch.tensor, ke inputL_hpu_tensors=metadata.inputL_hpu_tensors, inputM_hpu_tensors=metadata.inputM_hpu_tensors, cache_utils=cache_utils, - w_uv=None) + dtype=query.dtype, + w_uv=None, + chunked_data=metadata.shared_bias_chunked, + chunk_size=metadata.shared_chunk_size) + if use_online_merge: + acc_attn, acc_max, acc_sum = online_merge_step(acc_attn, acc_max, acc_sum, *shared) + + if split_graphs: + htorch.core.mark_step() + + # 3. Unique attention unique = partial_attn_unique(query=scaled_query, blocks=metadata.unique_blocks, block_mapping=metadata.unique_block_mapping, @@ -450,9 +783,22 @@ def unified_attn(query: torch.tensor, key: torch.tensor, value: torch.tensor, ke fmin=metadata.fmin, cache_utils=cache_utils, w_uv=None) - attn = merge(causal, shared, unique, feps=metadata.feps) - if attn is None: - return query + if use_online_merge: + acc_attn, acc_max, acc_sum = online_merge_step(acc_attn, acc_max, acc_sum, *unique) + + if split_graphs: + htorch.core.mark_step() + + # Final normalization + if use_online_merge: + if acc_attn is None: + return query + acc_sum = torch.maximum(acc_sum, metadata.feps) + attn = acc_attn / acc_sum.unsqueeze(-1) + else: + attn = merge(causal, shared, unique, feps=metadata.feps) + if attn is None: + return query return attn @@ -477,6 +823,9 @@ def unified_mla(query: Optional[torch.tensor], w_uv: Projection matrix from latent to full V [num_heads, latent_dim, v_head_dim] query_latent: Query tensor for cached path (in latent space) [tokens, num_heads, latent_dim + rope_dim] None if only causal attention is needed. + use_online_merge: If True, use online (incremental) merge algorithm. + Merges after each partial attention to avoid large intermediate buffers. + If False, use offline (single-pass) merge algorithm. 
Returns: Attention output [tokens, num_heads * v_head_dim] @@ -495,6 +844,15 @@ def unified_mla(query: Optional[torch.tensor], # MLA: latent cache has no head dimension, value_cache is None (stored in same cache) cache_utils = CacheUtils(latent_cache, value_cache=None, block_size=metadata.block_size, is_mla=True) + use_online_merge = metadata.online_merge + split_graphs = metadata.split_graphs + + if use_online_merge: + # Online merge: compute and merge incrementally to avoid large intermediate buffers + acc_attn, acc_max, acc_sum = None, None, None + + if split_graphs: + htorch.core.mark_step() # Causal: compute-friendly path (expand K/V from latent) # key and value already expanded by caller @@ -508,18 +866,36 @@ def unified_mla(query: Optional[torch.tensor], inputL_hpu_tensors=metadata.inputL_hpu_tensors, inputM_hpu_tensors=metadata.inputM_hpu_tensors, w_uv=w_uv) if scaled_query_causal is not None else (None, None, None) + if use_online_merge: + acc_attn, acc_max, acc_sum = online_merge_step(acc_attn, acc_max, acc_sum, *causal) + + if split_graphs: + htorch.core.mark_step() # Shared/Unique: memory-friendly path (Q in latent space, fetch cached latent KV) # query_latent is already transformed to latent space by caller # For these paths, we need to expand K/V from cached latent vectors - shared = partial_attn_shared(query=scaled_query_latent, - blocks=metadata.shared_blocks, - bias=metadata.shared_bias, - fmin=metadata.fmin, - inputL_hpu_tensors=metadata.inputL_hpu_tensors, - inputM_hpu_tensors=metadata.inputM_hpu_tensors, - cache_utils=cache_utils, - w_uv=w_uv) if scaled_query_latent is not None else (None, None, None) + + # Single call handles both full and chunked modes + if scaled_query_latent is not None: + shared = partial_attn_shared(query=scaled_query_latent, + blocks=metadata.shared_blocks, + bias=metadata.shared_bias, + fmin=metadata.fmin, + inputL_hpu_tensors=metadata.inputL_hpu_tensors, + inputM_hpu_tensors=metadata.inputM_hpu_tensors, + cache_utils=cache_utils, + dtype=scaled_query_latent.dtype, + w_uv=w_uv, + chunked_data=metadata.shared_bias_chunked, + chunk_size=metadata.shared_chunk_size) + if use_online_merge: + acc_attn, acc_max, acc_sum = online_merge_step(acc_attn, acc_max, acc_sum, *shared) + else: + shared = (None, None, None) + + if split_graphs: + htorch.core.mark_step() unique = partial_attn_unique(query=scaled_query_latent, blocks=metadata.unique_blocks, @@ -528,14 +904,28 @@ def unified_mla(query: Optional[torch.tensor], fmin=metadata.fmin, cache_utils=cache_utils, w_uv=w_uv) if scaled_query_latent is not None else (None, None, None) + if use_online_merge: + acc_attn, acc_max, acc_sum = online_merge_step(acc_attn, acc_max, acc_sum, *unique) - attn = merge(causal, shared, unique, feps=metadata.feps) - if attn is None: - # No attention computed, return original query - # Use whichever query was provided - # FIXME(kzawora): I'm not quite sure if that's correct, needs verification - if query is not None: - return query.flatten(-2, -1) # [tokens, num_heads * head_dim] - else: - return query_latent.flatten(-2, -1) # [tokens, num_heads * head_dim] + if split_graphs: + htorch.core.mark_step() + + if use_online_merge: + if acc_attn is None: + if query is not None: + return query.flatten(-2, -1) # [tokens, num_heads * head_dim] + else: + return query_latent.flatten(-2, -1) # [tokens, num_heads * head_dim] + acc_sum = torch.maximum(acc_sum, metadata.feps) + attn = acc_attn / acc_sum.unsqueeze(-1) + else: + attn = merge(causal, shared, unique, feps=metadata.feps) + if attn 
is None: + # No attention computed, return original query + # Use whichever query was provided + # FIXME(kzawora): I'm not quite sure if that's correct, needs verification + if query is not None: + return query.flatten(-2, -1) # [tokens, num_heads * head_dim] + else: + return query_latent.flatten(-2, -1) # [tokens, num_heads * head_dim] return attn diff --git a/vllm_gaudi/extension/unified_batch.py b/vllm_gaudi/extension/unified_batch.py index 07ccc3911..e1b8ade5a 100644 --- a/vllm_gaudi/extension/unified_batch.py +++ b/vllm_gaudi/extension/unified_batch.py @@ -2,7 +2,7 @@ import numpy as np import habana_frameworks.torch as htorch from dataclasses import dataclass -from vllm_gaudi.extension.unified import HPUUnifiedAttentionMetadata, get_vecsize_packsize, get_last_dim_size +from vllm_gaudi.extension.unified import HPUUnifiedAttentionMetadata, SharedBlockChunkedBiasData, get_vecsize_packsize, get_last_dim_size from vllm.v1.spec_decode.metadata import SpecDecodeMetadata import math from typing import Optional, Callable, Union @@ -312,19 +312,28 @@ def __init__(self, max_num_batched_tokens, max_shared_blocks, max_unique_blocks, estimated_shared_bias_mem = (max_num_batched_tokens * max_shared_blocks * block_size * np.dtype(np_dtype).itemsize) + (max_shared_blocks * block_size * block_size * np.dtype(np_dtype).itemsize) - # NOTE(kzawora): 64GiB is an arbitrary threshold to avoid OOMs when allocating large shared bias buffers - shared_bias_mem_threshold = 64 * 2**30 - self.use_persistent_shared_biases = estimated_shared_bias_mem <= shared_bias_mem_threshold - if self.use_persistent_shared_biases: - self.shared_bias = np.full((max_num_batched_tokens, max_shared_blocks, block_size), - -math.inf, - dtype=np_dtype) - # NOTE(kzawora): shared block bias is a weird entity - it maps block usage to each individual token in the context - - # so the upper bound should be max_shared_blocks*block_size (max_num_shared_tokens) by block_size - self.shared_block_bias = np.full((max_shared_blocks * block_size, block_size), -math.inf, dtype=np_dtype) + + self.use_dense_shared_bias = get_config().unified_attn_dense_shared_bias + if self.use_dense_shared_bias: + # Dense block_usages for chunked shared attention - shape (max_qlen, max_shared_blocks) + # Value 0 means "masked out entirely" (will produce all -inf bias) + self.block_usages_dense = np.zeros((max_num_batched_tokens, max_shared_blocks), dtype=np.int32) else: - self.shared_bias = None - self.shared_block_bias = None + # NOTE(kzawora): 64GiB is an arbitrary threshold to avoid OOMs when allocating large shared bias buffers + shared_bias_mem_threshold = 64 * 2**30 + self.use_persistent_shared_biases = estimated_shared_bias_mem <= shared_bias_mem_threshold + if self.use_persistent_shared_biases: + self.shared_bias = np.full((max_num_batched_tokens, max_shared_blocks, block_size), + -math.inf, + dtype=np_dtype) + # NOTE(kzawora): shared block bias is a weird entity - it maps block usage to each individual token in the context - + # so the upper bound should be max_shared_blocks*block_size (max_num_shared_tokens) by block_size + self.shared_block_bias = np.full((max_shared_blocks * block_size, block_size), + -math.inf, + dtype=np_dtype) + else: + self.shared_bias = None + self.shared_block_bias = None self.unique_bias = np.full((max_unique_blocks, block_size), -math.inf, dtype=np_dtype) self.unique_block_bias = np.full((max_unique_blocks, block_size), -math.inf, dtype=np_dtype) @@ -334,6 +343,7 @@ def __init__(self, max_num_batched_tokens, max_shared_blocks, 
max_unique_blocks, self.causal_bias_generator = HPUCausalBiasGenerator() self.shared_bias_generator = HPUSharedBiasGenerator() + self.shared_bias_generator_dense = HPUSharedBiasGeneratorDense() self.graphed = True if self.graphed: config = get_config() @@ -342,6 +352,8 @@ def __init__(self, max_num_batched_tokens, max_shared_blocks, max_unique_blocks, disable_tensor_cache=True) self.shared_bias_generator = htorch.hpu.wrap_in_hpu_graph(self.shared_bias_generator, disable_tensor_cache=True) + self.shared_bias_generator_dense = htorch.hpu.wrap_in_hpu_graph(self.shared_bias_generator_dense, + disable_tensor_cache=True) elif config.bridge_mode == 'eager': self.causal_bias_generator = torch.compile(self.causal_bias_generator, backend='hpu_backend', @@ -351,6 +363,10 @@ def __init__(self, max_num_batched_tokens, max_shared_blocks, max_unique_blocks, backend='hpu_backend', fullgraph=True, dynamic=False) + self.shared_bias_generator_dense = torch.compile(self.shared_bias_generator_dense, + backend='hpu_backend', + fullgraph=True, + dynamic=False) self.hpu_tensor_online_padding = False if not self.hpu_tensor_online_padding: # NOTE(kzawora): Dynamic mempool caches - store largest placeholders needed for each (pad_value, dtype) @@ -486,7 +502,7 @@ class HPUSharedBiasGenerator(HPUBiasGenerator): def forward(self, block_usages: torch.tensor, hpu_shared_token_idx: torch.tensor, hpu_shared_block_idx: torch.tensor, block_size: torch.tensor, dtype: torch.dtype, target_qlen, target_shared_blocks) -> torch.tensor: - """ Generate block bias based on block_usage """ + """ Generate block bias based on block_usage (sparse scatter version) """ block_len_range = torch.arange(1, block_size + 1, dtype=block_usages.dtype, device=block_usages.device) block_mask = block_len_range.unsqueeze(0) > block_usages.unsqueeze(-1) hpu_shared_block_bias = self.mask_to_bias_torch(block_mask, dtype=dtype) @@ -498,6 +514,147 @@ def forward(self, block_usages: torch.tensor, hpu_shared_token_idx: torch.tensor return hpu_shared_bias +class HPUSharedBiasGeneratorDense(HPUBiasGenerator): + """ + Dense version of shared bias generator - takes pre-scattered block_usages + of shape (target_qlen, target_shared_blocks) instead of sparse coordinates. + + This avoids dynamic-length coordinate arrays on HPU by doing the scatter on CPU. + """ + + def forward(self, block_usages_dense: torch.tensor, block_size: int, dtype: torch.dtype) -> torch.tensor: + """ + Generate block bias from dense block_usages. 
+ + Args: + block_usages_dense: Shape (target_qlen, target_shared_blocks), values are block usage counts (0 = masked out) + block_size: Size of each block + dtype: Output dtype + + Returns: + Shape (target_qlen, target_shared_blocks, block_size) bias tensor + """ + # block_usages_dense: (target_qlen, target_shared_blocks) + # We want: block_mask[q, b, k] = True if k >= block_usages_dense[q, b] + # Which means: mask out positions k where k+1 > block_usages_dense[q, b] + block_len_range = torch.arange(1, + block_size + 1, + dtype=block_usages_dense.dtype, + device=block_usages_dense.device) + # block_len_range: (block_size,) + # block_usages_dense.unsqueeze(-1): (target_qlen, target_shared_blocks, 1) + # broadcast comparison: (target_qlen, target_shared_blocks, block_size) + block_mask = block_len_range > block_usages_dense.unsqueeze(-1) + return self.mask_to_bias_torch(block_mask, dtype=dtype) + + +def _prepare_shared_bias_hpu( + persistent_ctx: UnifiedBatchPersistentContext, + attn_metadata: 'HPUUnifiedAttentionMetadata', + shared_token_idx: np.ndarray, + shared_block_idx: np.ndarray, + shared_block_usage: np.ndarray, + shared_blocks: np.ndarray, + target_qlen: int, + target_shared_blocks: int, + query_len: int, + block_size: int, + dtype: torch.dtype, + np_dtype: np.dtype, + slot_mapping_dtype: torch.dtype, + use_chunked_processing: bool, + use_dense_bias_generation: bool, +) -> None: + """ + Prepare shared bias tensors on HPU. + + This function handles three approaches for shared bias generation: + 1. Chunked dense: For large shared blocks, generate bias per-chunk during attention + 2. Non-chunked dense: Scatter on CPU, broadcast on HPU (static shapes) + 3. Sparse (legacy): Dynamic scatter on HPU with fallback to CPU in case of too many shared tokens. + + Modifies attn_metadata.shared_bias and attn_metadata.shared_bias_chunked in place. 
+ """ + if use_chunked_processing: + with persistent_ctx.profiler.record_event('internal', 'shared_bias_chunked_prep'): + # CHUNKED DENSE APPROACH: + # - Scatter block_usages into dense (target_qlen, target_shared_blocks) on CPU + # - Don't generate full bias - just pass the dense block_usages + # - Attention code will generate bias per chunk by slicing block_usages + + # Use persistent buffer - get view of required size and zero it + block_usages_dense = persistent_ctx.block_usages_dense[:target_qlen, :target_shared_blocks] + block_usages_dense.fill(0) # Reset to 0 (fully masked) + + # Scatter: block_usages_dense[token_idx, block_idx] = block_usage value + block_usages_dense[shared_token_idx, shared_block_idx] = shared_block_usage + + # Transfer dense tensor to HPU - shape is fully static (target_qlen, target_shared_blocks) + hpu_block_usages_dense = persistent_ctx.hpu_tensor(block_usages_dense, (target_qlen, target_shared_blocks), + 0, torch.int32) + + # DON'T generate full bias - attention code will generate per chunk + attn_metadata.shared_bias = None + attn_metadata.shared_bias_chunked = SharedBlockChunkedBiasData( + block_usages=hpu_block_usages_dense, + num_query_tokens=target_qlen, + num_shared_blocks=target_shared_blocks, + split_chunked_graphs=get_config().unified_attn_split_graphs, + ) + return + + # Non-chunked paths + if use_dense_bias_generation: + with persistent_ctx.profiler.record_event('internal', 'shared_bias_dense_prep'): + # DENSE APPROACH: Scatter on CPU (any shape), broadcast on HPU (static shape) + block_usages_dense = persistent_ctx.block_usages_dense[:target_qlen, :target_shared_blocks] + block_usages_dense.fill(0) + block_usages_dense[shared_token_idx, shared_block_idx] = shared_block_usage + + hpu_block_usages_dense = persistent_ctx.hpu_tensor(block_usages_dense, (target_qlen, target_shared_blocks), + 0, torch.int32) + + attn_metadata.shared_bias = persistent_ctx.shared_bias_generator_dense(hpu_block_usages_dense, block_size, + dtype) + return + + # SPARSE APPROACH (legacy): Dynamic scatter on HPU with CPU fallback + actual_num_shared_tokens = shared_block_usage.shape[0] + padded_num_shared_tokens = target_shared_blocks * block_size + + if padded_num_shared_tokens < actual_num_shared_tokens: + # Too many shared tokens - fall back to CPU generation + with persistent_ctx.profiler.record_event('internal', 'shared_bias_cpu_fallback'): + shared_block_bias = generate_bias(shared_block_usage, block_size, np_dtype, persistent_ctx.block_len_range, + persistent_ctx.shared_block_bias) + + if persistent_ctx.use_persistent_shared_biases: + shared_bias = persistent_ctx.shared_bias[:query_len, :shared_blocks.shape[0], :block_size] + else: + shared_bias = np.full((query_len, shared_blocks.shape[0], block_size), -math.inf, dtype=np_dtype) + + shared_bias.fill(-math.inf) + shared_bias[shared_token_idx, shared_block_idx] = shared_block_bias + attn_metadata.shared_bias = persistent_ctx.hpu_tensor(shared_bias, + (target_qlen, target_shared_blocks, block_size), + -math.inf, dtype) + else: + # HPU-accelerated sparse generation + with persistent_ctx.profiler.record_event('internal', 'shared_bias_hpu_prep'): + shared_tokens_shape = (padded_num_shared_tokens, ) + hpu_shared_block_usage = persistent_ctx.hpu_tensor(shared_block_usage, shared_tokens_shape, -1, + slot_mapping_dtype) + hpu_shared_token_idx = persistent_ctx.hpu_tensor(shared_token_idx, shared_tokens_shape, -1, + slot_mapping_dtype) + hpu_shared_block_idx = persistent_ctx.hpu_tensor(shared_block_idx, shared_tokens_shape, -1, + 
slot_mapping_dtype) + + attn_metadata.shared_bias = persistent_ctx.shared_bias_generator(hpu_shared_block_usage, + hpu_shared_token_idx, hpu_shared_block_idx, + block_size, dtype, target_qlen, + target_shared_blocks) + + def create_unified_batch( req_ids: list[str], all_token_ids: torch.Tensor, @@ -665,6 +822,26 @@ def first_dim(t: Optional[np.ndarray]) -> int: fmin = torch.finfo(dtype).min feps = torch.finfo(dtype).tiny + # Determine if we should use chunked computation for shared blocks + # NOTE(kzawora): Chunked processing computes attention in chunks to save memory. + # With chunked dense generation, we only allocate (target_qlen, target_shared_blocks) for block_usages + # instead of the full (target_qlen, target_shared_blocks, block_size) bias tensor. + # Bias is generated per chunk: (target_qlen, chunk_size, block_size) + default_chunk_size = get_config( + ).unified_attn_shared_attn_chunk_size # Process up to 64 blocks at a time for shared attention + use_chunked_processing = get_config().unified_attn_chunked_shared_attn and bool( + target_shared_blocks > default_chunk_size) # Chunked dense processing - generates bias per chunk + + # Pad target_shared_blocks to be a multiple of chunk_size for chunked processing + # This ensures all chunks have exactly chunk_size blocks (static shapes in the kernel) + if use_chunked_processing and target_shared_blocks % default_chunk_size != 0: + target_shared_blocks = ( + (target_shared_blocks + default_chunk_size - 1) // default_chunk_size) * default_chunk_size + + # Dense bias generation: scatter on CPU (any shape), then broadcast on HPU (static shape) + # This avoids dynamic-length coordinate arrays on HPU entirely + use_dense_bias_generation = persistent_ctx.use_dense_shared_bias + with persistent_ctx.profiler.record_event('internal', 'attn_metadata_prep'): attn_metadata = HPUUnifiedAttentionMetadata( block_size=block_size, @@ -674,9 +851,11 @@ def first_dim(t: Optional[np.ndarray]) -> int: target_qlen), -math.inf, dtype) if causal_bias is not None else None, causal_width=default_causal_width, shared_blocks=persistent_ctx.hpu_tensor(shared_blocks, (target_shared_blocks, ), -1, slot_mapping_dtype), - shared_bias=persistent_ctx.hpu_tensor(shared_bias, - (target_qlen, target_shared_blocks, - block_size), -math.inf, dtype) if shared_bias is not None else None, + # For chunked processing: still allocate full bias for now (stepping stone to verify correctness) + # shared_bias will be set below after HPU acceleration + shared_bias=None, # Will be set below + shared_bias_chunked=None, # Will be set below if chunked processing is enabled + shared_chunk_size=default_chunk_size if use_chunked_processing else 0, unique_blocks=target_unique_blocks, unique_block_mapping=persistent_ctx.hpu_tensor(unique_block_mapping, (target_unique_blocks, ), -1, slot_mapping_dtype), @@ -685,6 +864,8 @@ def first_dim(t: Optional[np.ndarray]) -> int: feps=to_hpu(feps, dtype=dtype), inputL_hpu_tensors=dict(), inputM_hpu_tensors=dict(), + split_graphs=get_config().unified_attn_split_graphs, + online_merge=get_config().unified_attn_online_merge, ) if hpu_bias_acceleration: @@ -700,47 +881,23 @@ def first_dim(t: Optional[np.ndarray]) -> int: attn_metadata.causal_bias = persistent_ctx.causal_bias_generator(hpu_token_groups, hpu_token_positions, hpu_padding_mask, dtype) if do_shared: - # NOTE(kzawora): this is kinda janky, but for a good reason - the number of shared tokens can vary significantly, - # and it impacts whether it's even worth running on HPU. 
- # On HPU, we need to avoid dynamic shapes == we need to pad number of shared tokens. - # It's currently padded to target_shared_blocks * block_size - # We set some simple heuristics to decide whether to run on HPU or fallback to CPU: - # 1. Check the number of shared tokens. If it's greater than the padded number of shared tokens, we fallback to CPU. - # 2. Check the padding ratio - if the padding exceeds 50% of actual size, we fallback to CPU. - - actual_num_shared_tokens = shared_block_usage.shape[0] - padded_num_shared_tokens = target_shared_blocks * block_size - # NOTE(kzawora): Initially we checked for padding ratio as well, but ultimately I've found no cases in - # which generating mask on padded_num_shared_tokens was slower than CPU + copying to HPU - # in case we have too many or too little shared tokens, we fall back to cpu generation - if padded_num_shared_tokens < actual_num_shared_tokens: - with persistent_ctx.profiler.record_event('internal', 'shared_bias_cpu_fallback'): - shared_block_bias = generate_bias(shared_block_usage, block_size, np_dtype, - persistent_ctx.block_len_range, persistent_ctx.shared_block_bias) - if persistent_ctx.use_persistent_shared_biases: - shared_bias = persistent_ctx.shared_bias[:query_len, :shared_blocks.shape[0], :block_size] - else: - shared_bias = np.full((query_len, shared_blocks.shape[0], block_size), - -math.inf, - dtype=np_dtype) - shared_bias.fill(-math.inf) - shared_bias[shared_token_idx, shared_block_idx] = shared_block_bias - attn_metadata.shared_bias = persistent_ctx.hpu_tensor( - shared_bias, (target_qlen, target_shared_blocks, block_size), -math.inf, dtype) - else: - with persistent_ctx.profiler.record_event('internal', 'shared_bias_hpu_prep'): - # do HPU-accelerated shared mask generation - shared_tokens_shape = (padded_num_shared_tokens, ) - hpu_shared_block_usage = persistent_ctx.hpu_tensor(shared_block_usage, shared_tokens_shape, -1, - slot_mapping_dtype) - hpu_shared_token_idx = persistent_ctx.hpu_tensor(shared_token_idx, shared_tokens_shape, -1, - slot_mapping_dtype) - hpu_shared_block_idx = persistent_ctx.hpu_tensor(shared_block_idx, shared_tokens_shape, -1, - slot_mapping_dtype) - hpu_shared_bias = persistent_ctx.shared_bias_generator(hpu_shared_block_usage, hpu_shared_token_idx, - hpu_shared_block_idx, block_size, dtype, - target_qlen, target_shared_blocks) - attn_metadata.shared_bias = hpu_shared_bias + _prepare_shared_bias_hpu( + persistent_ctx=persistent_ctx, + attn_metadata=attn_metadata, + shared_token_idx=shared_token_idx, + shared_block_idx=shared_block_idx, + shared_block_usage=shared_block_usage, + shared_blocks=shared_blocks, + target_qlen=target_qlen, + target_shared_blocks=target_shared_blocks, + query_len=query_len, + block_size=block_size, + dtype=dtype, + np_dtype=np_dtype, + slot_mapping_dtype=slot_mapping_dtype, + use_chunked_processing=use_chunked_processing, + use_dense_bias_generation=use_dense_bias_generation, + ) token_ids_device = persistent_ctx.hpu_tensor(token_ids, (target_qlen, ), -1, token_ids_dtype) logits_indices_device = persistent_ctx.hpu_tensor(logits_indices, (target_logits, ), -1, logits_indices_dtype) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index 1683ac426..400319205 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -2599,7 +2599,9 @@ def _check_config(self, batch_size, seq_len, num_blocks, attn_metadata, warmup_m def _get_unified_config(self, attn_metadata, logits_indices): 
has_causal = 'c' if attn_metadata.causal_bias is not None else '-' - has_shared = 's' if attn_metadata.shared_bias is not None else '-' + has_shared_bias = attn_metadata.shared_bias is not None + has_chunked_bias = attn_metadata.shared_bias_chunked is not None + has_shared = 's' if has_shared_bias or has_chunked_bias else '-' has_unique = 'u' if attn_metadata.unique_bias is not None else '-' phase = has_causal + has_shared + has_unique qlen = attn_metadata.slot_mapping.size(0) @@ -4436,7 +4438,6 @@ def _prepare_dummy_unified_scenario(self, unified_cfg): for request_blocks in split_shared_blocks_ids: self._add_dummy_unified_request(requests, False, False, request_blocks, num_computed_tokens, 1, scheduled_tokens) - self._execute_dummy_scenario(requests, scheduled_tokens) def _prepare_dummy_scenario(self, prompt_cfg, decode_cfg):
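
The incremental path added in unified.py rests on the flash-attention identity used by online_merge_step: partial results carrying (unnormalized exp-weighted V, row max, row sum) can be folded together by rescaling each side with exp(max - merged_max), and only the final result is divided by the accumulated sum. Below is a minimal CPU sketch of that identity, with illustrative shapes and helper names (partial, merge_step) that are not part of this patch; it omits the fmin/feps clamping and the HPU cache plumbing used above.

import torch


def partial(q, k, v):
    # One KV slice: unnormalized exp-weighted V, per-row max, per-row sum,
    # the same triple the partial_attn_* helpers return.
    s = q @ k.transpose(-1, -2)
    m = s.amax(-1)
    p = torch.exp(s - m.unsqueeze(-1))
    return p @ v, m, p.sum(-1)


def merge_step(acc, new):
    # Flash-attention style rescale-and-accumulate (the algebra in online_merge_step).
    a_attn, a_max, a_sum = acc
    n_attn, n_max, n_sum = new
    m = torch.maximum(a_max, n_max)
    a_scale = torch.exp(a_max - m)
    n_scale = torch.exp(n_max - m)
    attn = a_attn * a_scale.unsqueeze(-1) + n_attn * n_scale.unsqueeze(-1)
    return attn, m, a_sum * a_scale + n_sum * n_scale


torch.manual_seed(0)
q = torch.randn(4, 8)    # [tokens, head_dim], heads omitted for brevity
k = torch.randn(16, 8)
v = torch.randn(16, 8)

# Reference: softmax over the whole KV range at once.
ref = torch.softmax(q @ k.T, dim=-1) @ v

# Online: process the KV range in two slices and merge incrementally.
acc = partial(q, k[:10], v[:10])
acc = merge_step(acc, partial(q, k[10:], v[10:]))
attn, _, row_sum = acc
out = attn / row_sum.unsqueeze(-1)    # final normalization, as in online_merge()

print(torch.allclose(out, ref, atol=1e-5))    # True

The same algebra is what lets _partial_attn_shared_chunked fold each chunk into accumulated_attn / global_max / global_sum, and what lets unified_attn and unified_mla merge the causal, shared, and unique partials one at a time before the single division at the end.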
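
The chunked shared-attention path avoids materializing the full [query_len, num_shared_blocks, block_size] bias: it keeps only the dense block_usages matrix and rebuilds the mask for each slice of blocks as arange(1, block_size + 1) > usage, the same comparison HPUSharedBiasGeneratorDense performs. A self-contained sketch (made-up sizes, CPU tensors, no HPU graph wrapping) showing that slicing usages per chunk reproduces the full mask:

import torch

block_size, num_blocks, qlen, chunk = 4, 6, 5, 2

# Dense usages, as stored in block_usages_dense: 0 means "block fully masked out".
usages = torch.randint(0, block_size + 1, (qlen, num_blocks))

rng = torch.arange(1, block_size + 1)

# Full mask, the shape the non-chunked generator materializes: [qlen, num_blocks, block_size]
full_mask = rng > usages.unsqueeze(-1)

# Per-chunk masks, rebuilt from a slice of usages the way _partial_attn_shared_chunked does.
chunks = [rng > usages[:, s:s + chunk].unsqueeze(-1) for s in range(0, num_blocks, chunk)]

print(torch.equal(torch.cat(chunks, dim=1), full_mask))    # True

With block-level chunking like this, the live bias buffer is bounded by query_len * chunk_size * block_size, which is what unified_attn_shared_attn_chunk_size (default 64 in features.py) controls.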