IPEX: support FP8 KV cache, softcap and sliding window #3144


Draft · wants to merge 8 commits into base: main
8 changes: 4 additions & 4 deletions Dockerfile_intel
@@ -87,7 +87,7 @@ RUN echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https:/

RUN mv /tmp/intel-for-pytorch-gpu-dev.list /etc/apt/sources.list.d

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-ocloc
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-ocloc libnl-genl-3-200

# Text Generation Inference base env
ENV HF_HOME=/data \
@@ -100,8 +100,6 @@ ENV HF_HOME=/data \
WORKDIR /usr/src
RUN pip install torch==2.6.0 --index-url https://download.pytorch.org/whl/test/xpu

RUN pip install triton-xpu==3.2.0b1 --no-cache-dir

# Install server
COPY proto proto
COPY server server
@@ -119,7 +117,9 @@ ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
ENV TORCH_DEVICE_BACKEND_AUTOLOAD=0

RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/oneccl_bind_pt-2.6.0%2Bxpu-cp311-cp311-linux_x86_64.whl
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/intel_extension_for_pytorch-2.6.10%2Bxpu-cp311-cp311-linux_x86_64.whl
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout d5a7036316a01ea8220eb4da78a2207c423a1166
RUN sed -i 's/VERSION_MINOR 7/VERSION_MINOR 6/' intel-extension-for-pytorch/version.txt
RUN cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc,ats-m150' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
49 changes: 42 additions & 7 deletions server/text_generation_server/layers/attention/ipex.py
@@ -8,7 +8,10 @@
BLOCK_SIZE,
)

SUPPORTS_WINDOWING = False
if ATTENTION == "flashdecoding-ipex":
SUPPORTS_WINDOWING = True
else:
SUPPORTS_WINDOWING = False


def attention(
@@ -25,13 +28,19 @@ def attention(
causal: bool = True,
softcap: Optional[float] = None,
):
if softcap is not None:
raise NotImplementedError("softcap is not available in IPEX")

out = torch.empty_like(query)
kv_cache_dtype = "auto"
if kv_cache.key.dtype == torch.float8_e5m2:
kv_cache_dtype = "fp8_e5m2"
if kv_cache.key.dtype == torch.float8_e4m3fn:
kv_cache_dtype = "fp8_e4m3"

# We do not need to check window_size_left here, as it is already checked ahead of time at model load.
if ATTENTION == "flashdecoding-ipex":
window_size_right = -1 if window_size_left == -1 else 0
if softcap is None:
softcap = -1.0
ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
out,
query.contiguous() if query.device.type == "xpu" else query,
@@ -45,8 +54,18 @@ def attention(
causal,
block_tables,
None,
window_size_left=window_size_left,
window_size_right=window_size_right,
kv_cache_dtype=kv_cache_dtype,
k_scale=kv_scales.key_scale_cpu,
v_scale=kv_scales.value_scale_cpu,
softcap=softcap,
)
else:
if softcap is not None:
raise NotImplementedError(
"softcap is not available in IPEX paged attention"
)
ipex.llm.functional.varlen_attention(
query.contiguous() if query.device.type == "xpu" else query,
key.contiguous() if key.device.type == "xpu" else key,
@@ -80,12 +99,16 @@ def paged_attention(
softcap: Optional[float] = None,
window_size_left: Optional[int] = -1,
):
if softcap is not None:
raise NotImplementedError("softcap is not available in IPEX")

out = torch.empty_like(query)

kv_cache_dtype = "auto"
if kv_cache.key.dtype == torch.float8_e5m2:
kv_cache_dtype = "fp8_e5m2"
if kv_cache.key.dtype == torch.float8_e4m3fn:
kv_cache_dtype = "fp8_e4m3"
if ATTENTION == "flashdecoding-ipex":
window_size_right = -1 if window_size_left == -1 else 0
if softcap is None:
softcap = -1.0
ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
out,
query.contiguous() if query.device.type == "xpu" else query,
@@ -99,9 +122,19 @@ def paged_attention(
True,
block_tables,
None,
window_size_left=window_size_left,
window_size_right=window_size_right,
kv_cache_dtype=kv_cache_dtype,
k_scale=kv_scales.key_scale_cpu,
v_scale=kv_scales.value_scale_cpu,
softcap=softcap,
)
else:
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
if softcap is not None:
raise NotImplementedError(
"softcap is not available in IPEX paged attention"
)
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
out,
query,
@@ -114,6 +147,8 @@ def paged_attention(
BLOCK_SIZE,
max_s,
None,
k_scale=kv_scales.key_scale_cpu,
v_scale=kv_scales.value_scale_cpu,
)
return out
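
For readers tracing the two call paths above, here is a small standalone sketch (not part of the PR; the helper names are hypothetical) of the conventions the diff relies on: the string tags IPEX expects for FP8 KV-cache dtypes, and the defaults passed when sliding-window attention or softcap is disabled.

```python
import torch

# Hypothetical helpers restating conventions used inline in the diff above; the
# real values are passed directly to the IPEX flash_attn_varlen_func calls.

def ipex_kv_cache_dtype(cache_dtype: torch.dtype) -> str:
    # Map the torch KV-cache dtype to the string tag the IPEX kernels expect.
    if cache_dtype == torch.float8_e5m2:
        return "fp8_e5m2"
    if cache_dtype == torch.float8_e4m3fn:
        return "fp8_e4m3"
    return "auto"  # non-FP8 caches keep the default behaviour

def windowing_and_softcap_defaults(window_size_left: int, softcap):
    # Causal sliding window: only the left window is bounded, so the right
    # window is 0 when windowing is active and -1 (unbounded) otherwise.
    window_size_right = -1 if window_size_left == -1 else 0
    # The flashdecoding kernel takes -1.0 to mean "no softcap".
    softcap = -1.0 if softcap is None else softcap
    return window_size_right, softcap

assert ipex_kv_cache_dtype(torch.float8_e5m2) == "fp8_e5m2"
assert ipex_kv_cache_dtype(torch.bfloat16) == "auto"
assert windowing_and_softcap_defaults(-1, None) == (-1, -1.0)
assert windowing_and_softcap_defaults(4096, 50.0) == (0, 50.0)
```

If the installed IPEX build expects different tags or defaults, the kernel calls in the diff, not this sketch, are authoritative.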

43 changes: 37 additions & 6 deletions server/text_generation_server/layers/attention/kv_cache.py
@@ -68,15 +68,20 @@ def __init__(
if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
if not (
(ATTENTION == "flashinfer" and SYSTEM == "cuda")
or (ATTENTION == "paged" and SYSTEM in ("cuda", "rocm"))
or (ATTENTION == "paged" and SYSTEM in ("cuda", "rocm", "ipex"))
or (ATTENTION == "flashdecoding-ipex")
):
raise ValueError(
"FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on CUDA and ROCm. "
"FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on CUDA, ROCm and INTEL IPEX and flashdecoding in Intel IPEX "
)
if SYSTEM == "rocm" and dtype == torch.float8_e5m2:
raise ValueError(
"float8_e5m2 FP8 KV cache is not supported on AMD ROCm"
)
if device.type == "cpu" and dtype == torch.float8_e4m3fn:
raise ValueError(
"float8_e4m3fn FP8 KV cache is not supported on Intel IPEX CPU"
)

element_size = torch.tensor([], dtype=dtype).element_size()
if SYSTEM == "ipex" and device.type == "xpu":
@@ -133,15 +138,16 @@ def can_scale(self, kv_scales: KVScales) -> bool:
return False
elif self.dtype == torch.float8_e4m3fn and (
(ATTENTION in ("paged", "flashinfer") and SYSTEM == "cuda")
or (ATTENTION == "paged" and SYSTEM == "rocm")
or (ATTENTION == "paged" and SYSTEM in ["rocm", "ipex"])
or (ATTENTION == "flashdecoding-ipex")
):
log_once(logger.info, "Using FP8 KV cache scales")
return True
else:
# We have scales, but not the correct FP8 cache type, so warn once.
log_once(
logger.info,
"Ignoring FP8 KV cache scales, supported only for float8_e4m3fn KV cache with flashinfer on CUDA and paged attention on ROCm",
"Ignoring FP8 KV cache scales, supported only for float8_e4m3fn KV cache with flashinfer on CUDA and paged attention on ROCm/IPEX and flashdecoding on IPEX",
)
return False
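
As a hedged cross-check of the dtype/backend gate extended in can_scale() above, here is a hypothetical standalone restatement (not part of the PR) of when FP8 KV-cache scales are actually honoured:

```python
import torch

# Hypothetical restatement of the dtype/backend condition in can_scale():
# scales are only applied for a float8_e4m3fn cache on the listed backends.
def fp8_scales_supported(cache_dtype: torch.dtype, attention: str, system: str) -> bool:
    if cache_dtype != torch.float8_e4m3fn:
        return False
    return (
        (attention in ("paged", "flashinfer") and system == "cuda")
        or (attention == "paged" and system in ("rocm", "ipex"))
        or attention == "flashdecoding-ipex"
    )

assert fp8_scales_supported(torch.float8_e4m3fn, "flashdecoding-ipex", "ipex")
assert fp8_scales_supported(torch.float8_e4m3fn, "paged", "ipex")
assert not fp8_scales_supported(torch.float8_e5m2, "paged", "ipex")  # e5m2 never uses scales
```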

@@ -207,8 +213,20 @@ def store(
elif ATTENTION == "flashdecoding-ipex" and key.device.type == "xpu":
import intel_extension_for_pytorch as ipex

kv_cache_dtype = "auto"
if key_cache.dtype == torch.float8_e5m2:
kv_cache_dtype = "fp8_e5m2"
if key_cache.dtype == torch.float8_e4m3fn:
kv_cache_dtype = "fp8_e4m3"
ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
key, value, key_cache, value_cache, slots
key,
value,
key_cache,
value_cache,
slots,
kv_cache_dtype=kv_cache_dtype,
k_scale=kv_scales.key_scale_cpu,
v_scale=kv_scales.value_scale_cpu,
)
else:
paged_reshape_and_cache(
@@ -267,8 +285,21 @@ def paged_reshape_and_cache(
elif SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex

kv_cache_dtype = "auto"
if key_cache.dtype == torch.float8_e5m2:
kv_cache_dtype = "fp8_e5m2"
if key_cache.dtype == torch.float8_e4m3fn:
kv_cache_dtype = "fp8_e4m3"

ipex.llm.modules.PagedAttention.reshape_and_cache(
key, value, key_cache, value_cache, slots
key,
value,
key_cache,
value_cache,
slots,
kv_cache_dtype=kv_cache_dtype,
k_scale=k_scale,
v_scale=v_scale,
)
else:
raise NotImplementedError(