47 changes: 32 additions & 15 deletions vllm/_ipex_ops.py
@@ -84,10 +84,10 @@ def paged_attention_v1(
# todo: ipex will refactor namespace
import vllm._C.ops
vllm._C.ops.paged_attention_v1(out, query,
key_cache.view_as(value_cache),
value_cache, num_kv_heads, scale,
block_tables, context_lens, block_size,
max_context_len, alibi_slopes, kv_cache_dtype, k_scale, logits_soft_cap)
key_cache.view_as(value_cache),
value_cache, num_kv_heads, scale,
block_tables, context_lens, block_size,
max_context_len, alibi_slopes, kv_cache_dtype, k_scale, logits_soft_cap)

@staticmethod
def paged_attention_v2(
@@ -365,29 +365,46 @@ def chunked_prefill(
p_dropout: float,
softmax_scale: float,
zero_tensors: bool,
is_caual: bool,
is_casual: bool,
return_softmax: bool,
gen_: Optional[torch.Generator],
):
return torch.ops.torch_ipex.chunked_prefill(
return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
output,
query.contiguous(),
key_cache,
value_cache,
output,
cu_seqlens_q,
cu_seqlens_k,
seq_used_k,
block_table,
alibi_slopes,
max_seqlen_q,
max_seqlen_k,
p_dropout,
softmax_scale,
zero_tensors,
is_caual,
return_softmax,
gen_,
is_casual,
block_table,
alibi_slopes,
k_scale=1.0,
v_scale=1.0,
)
# return torch.ops.torch_ipex.chunked_prefill(
# query.contiguous(),
# key_cache,
# value_cache,
# output,
# cu_seqlens_q,
# cu_seqlens_k,
# seq_used_k,
# block_table,
# alibi_slopes,
# max_seqlen_q,
# max_seqlen_k,
# p_dropout,
# softmax_scale,
# zero_tensors,
# is_caual,
# return_softmax,
# gen_,
# )


@staticmethod
def copy_blocks(key_caches: List[torch.Tensor],
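Note (illustration, not part of the diff): the rewritten chunked_prefill wrapper keeps its previous positional signature and now forwards to IPEX's paged-attention varlen kernel instead of torch.ops.torch_ipex.chunked_prefill. The sketch below mirrors the argument order used by the new call site in ipex_attn.py; the shapes, the dummy cache contents, the device guard, and the kernel's behaviour on any particular IPEX build are assumptions, not verified results.

```python
# Toy call into the rewritten wrapper. Illustrative only; nothing here comes
# from a verified run on XPU hardware.
import torch

from vllm._ipex_ops import ipex_ops

def toy_chunked_prefill(device: str = "xpu") -> torch.Tensor:
    num_tokens, num_heads, num_kv_heads, head_size = 8, 8, 2, 64
    num_blocks, block_size = 4, 16

    query = torch.randn(num_tokens, num_heads, head_size,
                        device=device, dtype=torch.float16)
    output = torch.empty_like(query)
    # Flash-style cache layout: the (num_blocks, block_size, heads, head_size)
    # planes obtained after dropping the leading K/V dimension of the cache.
    key_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_size,
                            device=device, dtype=torch.float16)
    value_cache = torch.randn_like(key_cache)

    cu_seqlens_q = torch.tensor([0, num_tokens], device=device, dtype=torch.int32)
    cu_seqlens_k = torch.tensor([0, num_tokens], device=device, dtype=torch.int32)
    block_table = torch.arange(num_blocks, device=device,
                               dtype=torch.int32).unsqueeze(0)

    ipex_ops.chunked_prefill(
        query, key_cache, value_cache, output,
        cu_seqlens_q, cu_seqlens_k,
        None,                    # seq_used_k
        block_table,
        None,                    # alibi_slopes
        num_tokens, num_tokens,  # max_seqlen_q, max_seqlen_k
        0.0,                     # p_dropout
        head_size ** -0.5,       # softmax_scale
        False,                   # zero_tensors
        True,                    # is_casual (spelling as in the diff)
        False,                   # return_softmax
        None,                    # gen_
    )
    return output

if __name__ == "__main__" and torch.xpu.is_available():
    print(toy_chunked_prefill().shape)
```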
13 changes: 13 additions & 0 deletions vllm/utils.py
@@ -128,6 +128,8 @@
"STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER,
}

BMG_TARGET_IDS = ["0xe20b", "0xe210"]

# Constants related to forcing the attention backend selection

# String name of register which may be set in order to
@@ -2564,3 +2566,14 @@ def sha256(input) -> int:
input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
return int.from_bytes(hashlib.sha256(input_bytes).digest(),
byteorder="big")

@cache
def is_bmg_platform():
if not torch.xpu.is_available():
raise ValueError("Cannot detect the usage of XPU!")
device_index = torch.xpu.current_device()
device_name = torch.xpu.get_device_name(device_index)
for target_id in BMG_TARGET_IDS:
if target_id in device_name:
return True
return False
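Note (illustration, not part of the diff): is_bmg_platform() is a cached substring match of the current XPU device name against BMG_TARGET_IDS, and it raises ValueError when no XPU is visible. A hypothetical caller-side sketch follows; pick_prefill_path and its return labels are invented here and only mirror how the attention backend uses the flag.

```python
# Hypothetical caller of the new helper; the "flash"/"split"/"no-xpu" labels
# are invented for illustration.
from vllm.utils import is_bmg_platform

def pick_prefill_path() -> str:
    try:
        bmg = is_bmg_platform()   # result is cached after the first call (@cache)
    except ValueError:
        # Raised when torch.xpu.is_available() is False.
        return "no-xpu"
    # BMG devices (device name contains 0xe20b / 0xe210) take the flash-layout
    # cache path; other XPUs keep the split key/value cache path.
    return "flash" if bmg else "split"

if __name__ == "__main__":
    print(pick_prefill_path())
```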
112 changes: 73 additions & 39 deletions vllm/v1/attention/backends/ipex_attn.py
@@ -11,6 +11,7 @@
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.attention.backends.ipex_attn import use_gqa_kernel
from vllm.utils import is_bmg_platform
import os

@dataclass
@@ -46,9 +47,9 @@ def get_kv_cache_shape(
# if block_size % 16 != 0:
# raise ValueError("Block size must be a multiple of 16.")
# This needs to be changed...
# return (2, num_blocks, block_size, num_kv_heads, head_size)
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)
return (2, num_blocks, block_size, num_kv_heads, head_size)
# return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
# num_kv_heads, head_size)
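Note (illustration, not part of the diff): returning (2, num_blocks, block_size, num_kv_heads, head_size) keeps the key and value planes in a single tensor, which is what lets the BMG branch further down split the cache into views with kv_cache.unbind(0). A minimal sketch with made-up sizes:

```python
# Minimal sketch of the new cache layout; sizes are arbitrary, the real values
# come from the cache config.
import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 64

# Layout now returned by IPEXAttentionBackend.get_kv_cache_shape():
kv_cache = torch.zeros(2, num_blocks, block_size, num_kv_heads, head_size)

key_cache, value_cache = kv_cache.unbind(0)          # no copy, just views
assert key_cache.shape == (num_blocks, block_size, num_kv_heads, head_size)
assert key_cache.data_ptr() == kv_cache.data_ptr()   # shares the cache storage
```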



@@ -94,6 +95,8 @@ def __init__(
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

support_head_sizes = IPEXAttentionBackend.get_supported_head_sizes()
self.using_gqa_kernel = use_gqa_kernel(num_heads, num_kv_heads, head_size, logits_soft_cap)
self.is_bmg_platform = is_bmg_platform()
if head_size not in support_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by FlashAttention. "
@@ -104,7 +107,6 @@
"are not implemented for "
"IpexAttnBackendImpl")

# TODO(gc): Refine this logic..., because of bad performance...
def forward(
self,
layer: AttentionLayer,
@@ -147,6 +149,8 @@ def forward(
k_scale,
v_scale,
self.scale,
self.using_gqa_kernel,
self.is_bmg_platform,
self.sliding_window,
self.alibi_slopes,
self.logits_soft_cap,
@@ -219,6 +223,8 @@ def ipex_llm_chunked_prefill(
k_scale: float,
v_scale: float,
scale: float,
using_gqa_kernel: bool,
is_bmg_platform: bool,
sliding_window: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
@@ -237,54 +243,82 @@ def ipex_llm_chunked_prefill(
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)

using_gqa_kernel = use_gqa_kernel(num_heads, num_kv_heads, head_size, logits_soft_cap)


if using_gqa_kernel:
key_cache, value_cache = split_kv_cache_ipexllm(
kv_cache, num_kv_heads, head_size)
ipex_ops.reshape_and_cache_ipexllm(
if is_bmg_platform:
key_cache, value_cache = kv_cache.unbind(0)
ipex_ops.reshape_and_cache_flash(
key[:num_actual_tokens],
value[:num_actual_tokens],
key_cache,
value_cache,
attn_metadata.slot_mapping.flatten(),
attn_metadata.slot_mapping,
kv_cache_dtype,
k_scale,
v_scale,
)
else:
key_cache, value_cache = split_kv_cache(
kv_cache, num_kv_heads, head_size)
ipex_ops.reshape_and_cache(
key[:num_actual_tokens],
value[:num_actual_tokens],
ipex_ops.chunked_prefill(
query[:num_actual_tokens].contiguous(),
key_cache,
value_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype,
k_scale,
v_scale,
output[:num_actual_tokens],
attn_metadata.query_start_loc,
attn_metadata.seq_start_loc,
None,
attn_metadata.block_table,
alibi_slopes,
attn_metadata.max_query_len,
attn_metadata.max_seq_len,
0.0,
scale,
False,
True,
False,
None,
)
# Invoke chunked prefill method...
import vllm._C.ops
assert head_size == 128 or head_size == 64
value = os.environ.get('USE_CONTEXT_V1')
query_len = attn_metadata.query_start_loc[1:] - attn_metadata.query_start_loc[:-1]
seq_len = attn_metadata.seq_start_loc[1:] - attn_metadata.seq_start_loc[:-1]
context_len = seq_len - query_len
if using_gqa_kernel:
# if using_gqa_kernel, then only the v1 kernel can be used
out = vllm._C.ops.context_attention_forward_v1(query[:num_actual_tokens], key_cache, value_cache, attn_metadata.block_table, attn_metadata.query_start_loc, seq_len, context_len, attn_metadata.max_seq_len, torch.amax(context_len).item())
elif value is None:
# Otherwise, by default use v2 attention forward kernel...
out = vllm._C.ops.context_attention_forward_v2(query[:num_actual_tokens], key_cache, value_cache, attn_metadata.block_table, attn_metadata.query_start_loc, seq_len, context_len, attn_metadata.max_seq_len, torch.amax(context_len).item(), torch.amax(query_len).item())
else:
out = vllm._C.ops.context_attention_forward_v1(query[:num_actual_tokens], key_cache, value_cache, attn_metadata.block_table, attn_metadata.query_start_loc, seq_len, context_len, attn_metadata.max_seq_len, torch.amax(context_len).item())

# output[:num_actual_tokens] = out
output[:num_actual_tokens] = out.view(out.shape[0], -1)
if using_gqa_kernel:
key_cache, value_cache = split_kv_cache_ipexllm(
kv_cache, num_kv_heads, head_size)
ipex_ops.reshape_and_cache_ipexllm(
key[:num_actual_tokens],
value[:num_actual_tokens],
key_cache,
value_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype,
k_scale,
v_scale,
)
else:
key_cache, value_cache = split_kv_cache(
kv_cache, num_kv_heads, head_size)
ipex_ops.reshape_and_cache(
key[:num_actual_tokens],
value[:num_actual_tokens],
key_cache,
value_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype,
k_scale,
v_scale,
)
# Invoke chunked prefill method...
import vllm._C.ops
assert head_size == 128 or head_size == 64
value = os.environ.get('USE_CONTEXT_V1')
query_len = attn_metadata.query_start_loc[1:] - attn_metadata.query_start_loc[:-1]
seq_len = attn_metadata.seq_start_loc[1:] - attn_metadata.seq_start_loc[:-1]
context_len = seq_len - query_len
if using_gqa_kernel:
# if using_gqa_kernel, then only the v1 kernel can be used
out = vllm._C.ops.context_attention_forward_v1(query[:num_actual_tokens], key_cache, value_cache, attn_metadata.block_table, attn_metadata.query_start_loc, seq_len, context_len, attn_metadata.max_seq_len, torch.amax(context_len).item())
elif value is None:
# Otherwise, by default use v2 attention forward kernel...
out = vllm._C.ops.context_attention_forward_v2(query[:num_actual_tokens], key_cache, value_cache, attn_metadata.block_table, attn_metadata.query_start_loc, seq_len, context_len, attn_metadata.max_seq_len, torch.amax(context_len).item(), torch.amax(query_len).item())
else:
out = vllm._C.ops.context_attention_forward_v1(query[:num_actual_tokens], key_cache, value_cache, attn_metadata.block_table, attn_metadata.query_start_loc, seq_len, context_len, attn_metadata.max_seq_len, torch.amax(context_len).item())

# output[:num_actual_tokens] = out
output[:num_actual_tokens] = out.view(out.shape[0], -1)



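Note (illustration, not part of the diff): a condensed reading aid for the dispatch logic introduced in ipex_llm_chunked_prefill. The kernel names come from the diff; chunked_prefill_dispatch, its arguments, and the return strings are invented for this sketch.

```python
# Reading aid only: condensed dispatch of the updated ipex_llm_chunked_prefill.
import os
from typing import Optional

def chunked_prefill_dispatch(is_bmg_platform: bool,
                             using_gqa_kernel: bool,
                             use_context_v1: Optional[str]) -> str:
    if is_bmg_platform:
        # BMG: flash cache layout + fused varlen prefill kernel.
        return "reshape_and_cache_flash + ipex_ops.chunked_prefill"
    if using_gqa_kernel:
        # Non-BMG with GQA-friendly shapes: only the v1 context kernel applies.
        return "reshape_and_cache_ipexllm + context_attention_forward_v1"
    if use_context_v1 is None:
        # Default non-GQA path.
        return "reshape_and_cache + context_attention_forward_v2"
    # USE_CONTEXT_V1 set in the environment: force the v1 kernel.
    return "reshape_and_cache + context_attention_forward_v1"

if __name__ == "__main__":
    print(chunked_prefill_dispatch(False, True, os.environ.get("USE_CONTEXT_V1")))
```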