IPEX support FP8 kvcache/softcap/slidingwindow #3144

Merged: 9 commits, May 6, 2025
15 changes: 8 additions & 7 deletions Dockerfile_intel
@@ -98,7 +98,8 @@ ENV HF_HOME=/data \


WORKDIR /usr/src
RUN pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/xpu

RUN pip install torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/xpu

# Install server
COPY proto proto
@@ -116,8 +117,8 @@ ENV TORCH_LLM_ALLREDUCE=1
ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
ENV TORCH_DEVICE_BACKEND_AUTOLOAD=0

RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/oneccl_bind_pt-2.6.0%2Bxpu-cp311-cp311-linux_x86_64.whl
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/intel_extension_for_pytorch-2.6.10%2Bxpu-cp311-cp311-linux_x86_64.whl
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/oneccl_bind_pt-2.7.0%2Bxpu-cp311-cp311-linux_x86_64.whl
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/intel_extension_for_pytorch-2.7.10%2Bxpu-cp311-cp311-linux_x86_64.whl
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
@@ -180,13 +181,13 @@ RUN case ${TARGETPLATFORM} in \

RUN conda install -c conda-forge gperftools mkl

RUN pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cpu
RUN pip install triton==3.1.0 py-libnuma
RUN pip install torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cpu
RUN pip install triton==3.2.0 py-libnuma

WORKDIR /usr/src

RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/cpu/intel_extension_for_pytorch-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/cpu/oneccl_bind_pt-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/cpu/intel_extension_for_pytorch-2.7.0%2Bcpu-cp311-cp311-linux_x86_64.whl
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/cpu/oneccl_bind_pt-2.7.0%2Bcpu-cp311-cp311-linux_x86_64.whl


ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so
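Both stages of Dockerfile_intel move from the torch 2.6 / IPEX 2.6 wheels to torch 2.7.0, torchvision 0.22.0 and the matching oneCCL / IPEX 2.7 wheels (plus triton 3.2.0 and torchaudio 2.7.0 in the CPU stage). A quick way to confirm a rebuilt image picked up a consistent pair is a check like the one below; this is only an illustrative sketch, not part of the PR, and assumes both packages are importable inside the container.

```python
# Sanity-check sketch (not part of this PR): torch and IPEX should come from the
# same 2.7 release line; mismatched pairs typically fail at import or kernel load.
import torch
import intel_extension_for_pytorch as ipex


def check_release_line(expected: str = "2.7") -> None:
    torch_mm = ".".join(torch.__version__.split(".")[:2])
    ipex_mm = ".".join(ipex.__version__.split(".")[:2])
    assert torch_mm == expected, f"unexpected torch version {torch.__version__}"
    assert ipex_mm == expected, f"unexpected IPEX version {ipex.__version__}"


if __name__ == "__main__":
    check_release_line()
    print("torch", torch.__version__, "| ipex", ipex.__version__)
```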
49 changes: 42 additions & 7 deletions server/text_generation_server/layers/attention/ipex.py
@@ -8,7 +8,10 @@
BLOCK_SIZE,
)

SUPPORTS_WINDOWING = False
if ATTENTION == "flashdecoding-ipex":
SUPPORTS_WINDOWING = True
else:
SUPPORTS_WINDOWING = False


def attention(
@@ -25,13 +28,19 @@ def attention(
causal: bool = True,
softcap: Optional[float] = None,
):
if softcap is not None:
raise NotImplementedError("softcap is not available in IPEX")

out = torch.empty_like(query)
kv_cache_dtype = "auto"
if kv_cache.key.dtype == torch.float8_e5m2:
kv_cache_dtype = "fp8_e5m2"
if kv_cache.key.dtype == torch.float8_e4m3fn:
kv_cache_dtype = "fp8_e4m3"

# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
if ATTENTION == "flashdecoding-ipex":
window_size_right = -1 if window_size_left == -1 else 0
if softcap is None:
softcap = -1.0
ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
out,
query.contiguous() if query.device.type == "xpu" else query,
@@ -45,8 +54,18 @@
causal,
block_tables,
None,
window_size_left=window_size_left,
window_size_right=window_size_right,
kv_cache_dtype=kv_cache_dtype,
k_scale=kv_scales.key_scale_cpu,
v_scale=kv_scales.value_scale_cpu,
softcap=softcap,
)
else:
if softcap is not None:
raise NotImplementedError(
"softcap is not available in IPEX paged attention"
)
ipex.llm.functional.varlen_attention(
query.contiguous() if query.device.type == "xpu" else query,
key.contiguous() if key.device.type == "xpu" else key,
@@ -80,12 +99,16 @@ def paged_attention(
softcap: Optional[float] = None,
window_size_left: Optional[int] = -1,
):
if softcap is not None:
raise NotImplementedError("softcap is not available in IPEX")

out = torch.empty_like(query)

kv_cache_dtype = "auto"
if kv_cache.key.dtype == torch.float8_e5m2:
kv_cache_dtype = "fp8_e5m2"
if kv_cache.key.dtype == torch.float8_e4m3fn:
kv_cache_dtype = "fp8_e4m3"
if ATTENTION == "flashdecoding-ipex":
window_size_right = -1 if window_size_left == -1 else 0
if softcap is None:
softcap = -1.0
ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
out,
query.contiguous() if query.device.type == "xpu" else query,
@@ -99,9 +122,19 @@
True,
block_tables,
None,
window_size_left=window_size_left,
window_size_right=window_size_right,
kv_cache_dtype=kv_cache_dtype,
k_scale=kv_scales.key_scale_cpu,
v_scale=kv_scales.value_scale_cpu,
softcap=softcap,
)
else:
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
if softcap is not None:
raise NotImplementedError(
"softcap is not available in IPEX paged attention"
)
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
out,
query,
@@ -114,6 +147,8 @@
BLOCK_SIZE,
max_s,
None,
k_scale=kv_scales.key_scale_cpu,
v_scale=kv_scales.value_scale_cpu,
)
return out

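Taken together, the ipex.py changes follow one pattern: map the KV cache's torch dtype to the string the IPEX kernels expect, forward the CPU-side key/value scales, and replace the former hard errors for softcap and sliding windows with the sentinel values used by the flashdecoding path (softcap defaults to -1.0 when unset, and a set left window implies a right window of 0). The standalone sketch below restates that glue logic; the helper names are illustrative and not part of the repository.

```python
import torch
from typing import Optional, Tuple


def kv_cache_dtype_str(dtype: torch.dtype) -> str:
    # Same mapping the diff inserts before each IPEX attention call.
    if dtype == torch.float8_e5m2:
        return "fp8_e5m2"
    if dtype == torch.float8_e4m3fn:
        return "fp8_e4m3"
    return "auto"


def flashdecoding_sentinels(
    window_size_left: int, softcap: Optional[float]
) -> Tuple[int, float]:
    # -1 disables windowing; a finite left window implies a right window of 0.
    window_size_right = -1 if window_size_left == -1 else 0
    # -1.0 appears to be the kernel's "softcap disabled" sentinel in this PR.
    return window_size_right, (-1.0 if softcap is None else softcap)


print(kv_cache_dtype_str(torch.float8_e5m2))   # fp8_e5m2
print(flashdecoding_sentinels(128, None))      # (0, -1.0)
```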
43 changes: 37 additions & 6 deletions server/text_generation_server/layers/attention/kv_cache.py
@@ -68,15 +68,20 @@ def __init__(
if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
if not (
(ATTENTION == "flashinfer" and SYSTEM == "cuda")
or (ATTENTION == "paged" and SYSTEM in ("cuda", "rocm"))
or (ATTENTION == "paged" and SYSTEM in ("cuda", "rocm", "ipex"))
or (ATTENTION == "flashdecoding-ipex")
):
raise ValueError(
"FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on CUDA and ROCm. "
"FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on CUDA, ROCm and INTEL IPEX and flashdecoding in Intel IPEX "
)
if SYSTEM == "rocm" and dtype == torch.float8_e5m2:
raise ValueError(
"float8_e5m2 FP8 KV cache is not supported on AMD ROCm"
)
if device.type == "cpu" and dtype == torch.float8_e4m3fn:
raise ValueError(
"float8_e4m3fn FP8 KV cache is not supported on Intel IPEX CPU"
)

element_size = torch.tensor([], dtype=dtype).element_size()
if SYSTEM == "ipex" and device.type == "xpu":
@@ -133,15 +138,16 @@ def can_scale(self, kv_scales: KVScales) -> bool:
return False
elif self.dtype == torch.float8_e4m3fn and (
(ATTENTION in ("paged", "flashinfer") and SYSTEM == "cuda")
or (ATTENTION == "paged" and SYSTEM == "rocm")
or (ATTENTION == "paged" and SYSTEM in ["rocm", "ipex"])
or (ATTENTION == "flashdecoding-ipex")
):
log_once(logger.info, "Using FP8 KV cache scales")
return True
else:
# We have scales, but not the correct FP8 cache type, so warn once.
log_once(
logger.info,
"Ignoring FP8 KV cache scales, supported only for float8_e4m3fn KV cache with flashinfer on CUDA and paged attention on ROCm",
"Ignoring FP8 KV cache scales, supported only for float8_e4m3fn KV cache with flashinfer on CUDA and paged attention on ROCm/IPEX and flashdecoding on IPEX",
)
return False
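The validation added in KVCache.__init__ together with the updated can_scale gate can be read as a small support matrix: which attention backend / system / dtype combinations accept an FP8 cache, and which of those also honor the e4m3 scales. The predicate below is a condensed restatement for reference only, not code from the repository.

```python
import torch


def fp8_kv_cache_allowed(
    attention: str, system: str, dtype: torch.dtype, device_type: str
) -> bool:
    """Condensed restatement of the KVCache.__init__ checks (sketch only)."""
    if dtype not in (torch.float8_e5m2, torch.float8_e4m3fn):
        return True  # non-FP8 caches are unaffected by these checks
    backend_ok = (
        (attention == "flashinfer" and system == "cuda")
        or (attention == "paged" and system in ("cuda", "rocm", "ipex"))
        or attention == "flashdecoding-ipex"
    )
    if not backend_ok:
        return False
    if system == "rocm" and dtype == torch.float8_e5m2:
        return False  # float8_e5m2 is rejected on ROCm
    if device_type == "cpu" and dtype == torch.float8_e4m3fn:
        return False  # float8_e4m3fn is rejected on IPEX CPU
    return True


# Scales are only applied for float8_e4m3fn with CUDA paged/flashinfer,
# ROCm or IPEX paged attention, or flashdecoding-ipex; otherwise they are ignored.
print(fp8_kv_cache_allowed("flashdecoding-ipex", "ipex", torch.float8_e5m2, "xpu"))  # True
```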

@@ -207,8 +213,20 @@ def store(
elif ATTENTION == "flashdecoding-ipex" and key.device.type == "xpu":
import intel_extension_for_pytorch as ipex

kv_cache_dtype = "auto"
if key_cache.dtype == torch.float8_e5m2:
kv_cache_dtype = "fp8_e5m2"
if key_cache.dtype == torch.float8_e4m3fn:
kv_cache_dtype = "fp8_e4m3"
ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
key, value, key_cache, value_cache, slots
key,
value,
key_cache,
value_cache,
slots,
kv_cache_dtype=kv_cache_dtype,
k_scale=kv_scales.key_scale_cpu,
v_scale=kv_scales.value_scale_cpu,
)
else:
paged_reshape_and_cache(
@@ -267,8 +285,21 @@ def paged_reshape_and_cache(
elif SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex

kv_cache_dtype = "auto"
if key_cache.dtype == torch.float8_e5m2:
kv_cache_dtype = "fp8_e5m2"
if key_cache.dtype == torch.float8_e4m3fn:
kv_cache_dtype = "fp8_e4m3"

ipex.llm.modules.PagedAttention.reshape_and_cache(
key, value, key_cache, value_cache, slots
key,
value,
key_cache,
value_cache,
slots,
kv_cache_dtype=kv_cache_dtype,
k_scale=k_scale,
v_scale=v_scale,
)
else:
raise NotImplementedError(
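To tie the cache-side changes together: the store path now forwards the dtype string and the per-tensor scales into IPEX's reshape_and_cache kernels instead of rejecting them. The sketch below allocates an FP8 cache and prepares those arguments; the shapes are illustrative, and the final call is shown only as a comment with the signature taken from the hunk above.

```python
import torch

# Illustrative shapes only; the real layout is chosen by the KVCache class.
num_blocks, block_size, num_kv_heads, head_dim = 16, 64, 8, 128

# An FP8 cache stores one byte per element, half the footprint of float16.
key_cache = torch.empty(
    num_blocks, block_size, num_kv_heads, head_dim, dtype=torch.float8_e5m2
)
value_cache = torch.empty_like(key_cache)
assert key_cache.element_size() == 1

kv_cache_dtype = "fp8_e5m2"  # derived from key_cache.dtype as in the diff
k_scale = v_scale = 1.0      # stands in for kv_scales.key_scale_cpu / value_scale_cpu

# On an IPEX system, store()/paged_reshape_and_cache() then call (per the hunk above):
# ipex.llm.modules.PagedAttention.reshape_and_cache(
#     key, value, key_cache, value_cache, slots,
#     kv_cache_dtype=kv_cache_dtype, k_scale=k_scale, v_scale=v_scale,
# )
print(key_cache.shape, key_cache.element_size(), kv_cache_dtype)
```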