diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py index c3b6ca7ebe48..8dd101608f52 100644 --- a/vllm/_ipex_ops.py +++ b/vllm/_ipex_ops.py @@ -84,10 +84,10 @@ def paged_attention_v1( # todo: ipex will refactor namespace import vllm._C.ops vllm._C.ops.paged_attention_v1(out, query, - key_cache.view_as(value_cache), - value_cache, num_kv_heads, scale, - block_tables, context_lens, block_size, - max_context_len, alibi_slopes, kv_cache_dtype, k_scale, logits_soft_cap) + key_cache.view_as(value_cache), + value_cache, num_kv_heads, scale, + block_tables, context_lens, block_size, + max_context_len, alibi_slopes, kv_cache_dtype, k_scale, logits_soft_cap) @staticmethod def paged_attention_v2( @@ -365,29 +365,28 @@ def chunked_prefill( p_dropout: float, softmax_scale: float, zero_tensors: bool, - is_caual: bool, + is_causal: bool, return_softmax: bool, gen_: Optional[torch.Generator], ): - return torch.ops.torch_ipex.chunked_prefill( + # NOTE: seq_used_k, p_dropout, zero_tensors, return_softmax and gen_ are + # not forwarded to flash_attn_varlen_func; k_scale/v_scale are fixed at 1.0. + return ipex.llm.modules.PagedAttention.flash_attn_varlen_func( + output, query.contiguous(), key_cache, value_cache, - output, cu_seqlens_q, cu_seqlens_k, - seq_used_k, - block_table, - alibi_slopes, max_seqlen_q, max_seqlen_k, - p_dropout, softmax_scale, - zero_tensors, - is_caual, - return_softmax, - gen_, + is_causal, + block_table, + alibi_slopes, + k_scale=1.0, + v_scale=1.0, ) @staticmethod def copy_blocks(key_caches: List[torch.Tensor], diff --git a/vllm/utils.py b/vllm/utils.py index 5f32f8cb66a5..2ee0c1906790 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -128,6 +128,8 @@ "STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER, } +BMG_TARGET_IDS = ["0xe20b", "0xe210"] + # 
Constants related to forcing the attention backend selection # String name of register which may be set in order to @@ -2564,3 +2566,13 @@ def sha256(input) -> int: input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) return int.from_bytes(hashlib.sha256(input_bytes).digest(), byteorder="big") + + +@cache +def is_bmg_platform() -> bool: + """Return whether the current XPU's device name contains a BMG target ID.""" + if not torch.xpu.is_available(): + raise ValueError("Cannot detect the usage of XPU!") + device_index = torch.xpu.current_device() + device_name = torch.xpu.get_device_name(device_index) + return any(target_id in device_name for target_id in BMG_TARGET_IDS) diff --git a/vllm/v1/attention/backends/ipex_attn.py b/vllm/v1/attention/backends/ipex_attn.py index 29cde02f3007..f4a435eaa1a8 100644 --- a/vllm/v1/attention/backends/ipex_attn.py +++ b/vllm/v1/attention/backends/ipex_attn.py @@ -11,6 +11,7 @@ from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.attention.backends.ipex_attn import use_gqa_kernel +from vllm.utils import is_bmg_platform import os @dataclass @@ -46,9 +47,9 @@ def get_kv_cache_shape( # if block_size % 16 != 0: # raise ValueError("Block size must be a multiple of 16.") # This needs to be changed... - # return (2, num_blocks, block_size, num_kv_heads, head_size) - return PagedAttention.get_kv_cache_shape(num_blocks, block_size, - num_kv_heads, head_size) + return (2, num_blocks, block_size, num_kv_heads, head_size) + # return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + # num_kv_heads, head_size) @@ -94,6 +95,8 @@ def __init__( self.num_queries_per_kv = self.num_heads // self.num_kv_heads support_head_sizes = IPEXAttentionBackend.get_supported_head_sizes() + self.using_gqa_kernel = use_gqa_kernel(num_heads, num_kv_heads, head_size, logits_soft_cap) + self.is_bmg_platform = is_bmg_platform() if head_size not in support_head_sizes: raise ValueError( f"Head size {head_size} is not supported by FlashAttention. 
" @@ -104,7 +107,6 @@ def __init__( "are not implemented for " "IpexAttnBackendImpl") - # TODO(gc): Refine this logic..., because of bad performance... def forward( self, layer: AttentionLayer, @@ -147,6 +149,8 @@ def forward( k_scale, v_scale, self.scale, + self.using_gqa_kernel, + self.is_bmg_platform, self.sliding_window, self.alibi_slopes, self.logits_soft_cap, @@ -219,6 +223,8 @@ def ipex_llm_chunked_prefill( k_scale: float, v_scale: float, scale: float, + using_gqa_kernel: bool, + is_bmg_platform: bool, sliding_window: Optional[List[int]] = None, alibi_slopes: Optional[torch.Tensor] = None, logits_soft_cap: Optional[float] = None, @@ -237,54 +243,82 @@ def ipex_llm_chunked_prefill( key = key.view(-1, num_kv_heads, head_size) value = value.view(-1, num_kv_heads, head_size) - using_gqa_kernel = use_gqa_kernel(num_heads, num_kv_heads, head_size, logits_soft_cap) - - - if using_gqa_kernel: - key_cache, value_cache = split_kv_cache_ipexllm( - kv_cache, num_kv_heads, head_size) - ipex_ops.reshape_and_cache_ipexllm( + if is_bmg_platform: + key_cache, value_cache = kv_cache.unbind(0) + ipex_ops.reshape_and_cache_flash( key[:num_actual_tokens], value[:num_actual_tokens], key_cache, value_cache, - attn_metadata.slot_mapping.flatten(), + attn_metadata.slot_mapping, kv_cache_dtype, k_scale, v_scale, ) - else: - key_cache, value_cache = split_kv_cache( - kv_cache, num_kv_heads, head_size) - ipex_ops.reshape_and_cache( - key[:num_actual_tokens], - value[:num_actual_tokens], + ipex_ops.chunked_prefill( + query[:num_actual_tokens].contiguous(), key_cache, value_cache, - attn_metadata.slot_mapping.flatten(), - kv_cache_dtype, - k_scale, - v_scale, + output[:num_actual_tokens], + attn_metadata.query_start_loc, + attn_metadata.seq_start_loc, + None, + attn_metadata.block_table, + alibi_slopes, + attn_metadata.max_query_len, + attn_metadata.max_seq_len, + 0.0, + scale, + False, + True, + False, + None, ) - # Invoke chunked prefill method... 
- import vllm._C.ops - assert head_size == 128 or head_size == 64 - value = os.environ.get('USE_CONTEXT_V1') - query_len = attn_metadata.query_start_loc[1:] - attn_metadata.query_start_loc[:-1] - seq_len = attn_metadata.seq_start_loc[1:] - attn_metadata.seq_start_loc[:-1] - context_len = seq_len - query_len - if using_gqa_kernel: - # if using_gqa_kernel, then only the v1 kernel can be used - out = vllm._C.ops.context_attention_forward_v1(query[:num_actual_tokens], key_cache, value_cache, attn_metadata.block_table, attn_metadata.query_start_loc, seq_len, context_len, attn_metadata.max_seq_len, torch.amax(context_len).item()) - elif value is None: - # Otherwise, by default use v2 attention forward kernel... - out = vllm._C.ops.context_attention_forward_v2(query[:num_actual_tokens], key_cache, value_cache, attn_metadata.block_table, attn_metadata.query_start_loc, seq_len, context_len, attn_metadata.max_seq_len, torch.amax(context_len).item(), torch.amax(query_len).item()) else: - out = vllm._C.ops.context_attention_forward_v1(query[:num_actual_tokens], key_cache, value_cache, attn_metadata.block_table, attn_metadata.query_start_loc, seq_len, context_len, attn_metadata.max_seq_len, torch.amax(context_len).item()) - - # output[:num_actual_tokens] = out - output[:num_actual_tokens] = out.view(out.shape[0], -1) + if using_gqa_kernel: + key_cache, value_cache = split_kv_cache_ipexllm( + kv_cache, num_kv_heads, head_size) + ipex_ops.reshape_and_cache_ipexllm( + key[:num_actual_tokens], + value[:num_actual_tokens], + key_cache, + value_cache, + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype, + k_scale, + v_scale, + ) + else: + key_cache, value_cache = split_kv_cache( + kv_cache, num_kv_heads, head_size) + ipex_ops.reshape_and_cache( + key[:num_actual_tokens], + value[:num_actual_tokens], + key_cache, + value_cache, + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype, + k_scale, + v_scale, + ) + # Invoke chunked prefill method... 
+ import vllm._C.ops + assert head_size == 128 or head_size == 64 + use_context_v1 = os.environ.get('USE_CONTEXT_V1') + query_len = attn_metadata.query_start_loc[1:] - attn_metadata.query_start_loc[:-1] + seq_len = attn_metadata.seq_start_loc[1:] - attn_metadata.seq_start_loc[:-1] + context_len = seq_len - query_len + if using_gqa_kernel: + # if using_gqa_kernel, then only the v1 kernel can be used + out = vllm._C.ops.context_attention_forward_v1(query[:num_actual_tokens], key_cache, value_cache, attn_metadata.block_table, attn_metadata.query_start_loc, seq_len, context_len, attn_metadata.max_seq_len, torch.amax(context_len).item()) + elif use_context_v1 is None: + # Otherwise, by default use v2 attention forward kernel... + out = vllm._C.ops.context_attention_forward_v2(query[:num_actual_tokens], key_cache, value_cache, attn_metadata.block_table, attn_metadata.query_start_loc, seq_len, context_len, attn_metadata.max_seq_len, torch.amax(context_len).item(), torch.amax(query_len).item()) + else: + out = vllm._C.ops.context_attention_forward_v1(query[:num_actual_tokens], key_cache, value_cache, attn_metadata.block_table, attn_metadata.query_start_loc, seq_len, context_len, attn_metadata.max_seq_len, torch.amax(context_len).item()) + # output[:num_actual_tokens] = out + output[:num_actual_tokens] = out.view(out.shape[0], -1)