Skip to content
2 changes: 1 addition & 1 deletion vllm_gaudi/attention/backends/hpu_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,7 +1030,7 @@ def __init__(
self.qk_head_dim = qk_head_dim
self.v_head_dim = v_head_dim
self.kv_b_proj = kv_b_proj # Used to expand latent → full KV in causal path

self.use_online_merge = get_config().unified_attn_online_merge
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.latent_cache_k = VLLMKVCache() if not self.enable_fp8_attn \
Expand Down
20 changes: 17 additions & 3 deletions vllm_gaudi/extension/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,21 @@ def get_experimental_flags():
return to_dict(flags)


def unified_attn_dev_flags():
    """Return the development-stage feature flags for unified attention.

    Each entry is a ``Value`` describing one tunable; ``unified_attn_split_graphs``
    follows ``unified_attn_online_merge``, and ``unified_attn_softmax_fa2`` is
    gated on version, kernel availability, hardware, and the chunked shared-attn
    flag being disabled.
    """
    # Compound enablement condition for the FA2 softmax kernel path.
    softmax_fa2_condition = All(
        VersionRange(">=1.24.0.279"),
        Enabled('unified_attn'),
        Kernel(softmax_fa2),
        Hardware('gaudi3'),
        Not(Enabled('unified_attn_chunked_shared_attn')),
    )
    return [
        Value('unified_attn_dense_shared_bias', True),
        Value('unified_attn_chunked_shared_attn', True),
        Value('unified_attn_online_merge', True),
        Value('unified_attn_shared_attn_chunk_size', 64),
        Value('unified_attn_split_graphs', Enabled('unified_attn_online_merge')),
        Value('unified_attn_softmax_fa2', softmax_fa2_condition),
    ]


def get_features():
supported_attn_impls = ['flex_impl', 'fsdpa_impl', 'naive_impl']
bucketing_strategies = ['exponential_bucketing', 'linear_bucketing']
Expand Down Expand Up @@ -90,12 +105,11 @@ def get_features():
Value('dynamic_shapes_compilation', True, env_var='VLLM_T_COMPILE_DYNAMIC_SHAPES', env_var_type=boolean),
Value('fullgraph_compilation', False, env_var='VLLM_T_COMPILE_FULLGRAPH', env_var_type=boolean),
Value('unified_attn', False),
Value('unified_attn_softmax_fa2',
All(VersionRange(">=1.24.0.279"), Enabled('unified_attn'), Kernel(softmax_fa2), Hardware('gaudi3'))),
*unified_attn_dev_flags(),
Value('scale_adjustment', True, env_var='VLLM_SCALE_ADJUSTMENT', env_var_type=boolean),
Value('flatten_input', Any(ModelType('qwen3_moe'), ModelType('granitemoe'), ModelType('glm4_moe'))),
Value('unified_attn_shared_cache_ratio',
0.8,
1,
env_var='VLLM_UNIFIED_ATTENTION_SHARED_CACHE_RATIO',
env_var_type=float),
Value('high_level_profiler_enabled', False, env_var='VLLM_PROFILER_ENABLED', env_var_type=boolean),
Expand Down
Loading
Loading