
Gaudi: clean cuda/rocm code in hpu backend, enable flat_hpu #3113


Merged
merged 27 commits on Apr 14, 2025

Commits
201dc62
clean cuda/rocm code in hpu backend, enable flat_hpu
sywangyi Mar 14, 2025
b7fea6f
fix TP in pageattn
sywangyi Mar 15, 2025
5d36539
adjust block table in hpu to improve performance
sywangyi Mar 17, 2025
a07e743
enable all the models, not tested yet
sywangyi Mar 17, 2025
6bbe24d
use tensor cache in hpu graph to avoid replay issue
sywangyi Mar 17, 2025
5cd1c93
add moe support, fix qwen/mistral/mixtral crash
sywangyi Mar 18, 2025
073f793
fix phimoe issue
sywangyi Mar 19, 2025
2cde30d
gpt_bigcode can also use pageattn
sywangyi Mar 19, 2025
2074d05
enable dbrx, remove some unused code
sywangyi Mar 19, 2025
d5b78ba
Merge branch 'main' into gaudi_backend_pa
sywangyi Mar 20, 2025
f95aa42
multi-modality initial PR
sywangyi Mar 20, 2025
36b6612
adjust warmup and enable vlm
sywangyi Mar 20, 2025
fdf0733
fix incorrect output in qwen2 idefics if hpu graph is used
sywangyi Mar 21, 2025
9914ffe
remove unused quantization code and enable awq/gptq int4
sywangyi Mar 22, 2025
8d221b7
fix gptq issue
sywangyi Mar 23, 2025
6977376
enable fp8
sywangyi Mar 25, 2025
fd70ad7
warmup prefill
sywangyi Mar 26, 2025
ba7a131
add warmup_decode
sywangyi Mar 27, 2025
7900be5
warmup decode
sywangyi Mar 27, 2025
1508ee8
remove block_tables and prefill_cache_indices which will lead to dyna…
sywangyi Mar 28, 2025
7914e98
Merge branch 'main' into gaudi_backend_pa
sywangyi Mar 28, 2025
787dbe9
fix comment
sywangyi Mar 28, 2025
376e050
missing gptj change...
sywangyi Mar 28, 2025
f0e5fae
fix some issues
sywangyi Mar 28, 2025
c55a8ca
remove torch.where to fix incorrect output in hpu graph model
sywangyi Apr 1, 2025
610dd20
Merge branch 'main' into gaudi_backend_pa
sywangyi Apr 11, 2025
4cdc34e
match the latest vllm_extension ops
sywangyi Apr 11, 2025
2 changes: 1 addition & 1 deletion Dockerfile_gaudi
@@ -95,7 +95,7 @@ RUN cd server && \
pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
pip install . --no-cache-dir

RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
11 changes: 4 additions & 7 deletions backends/gaudi/server/text_generation_server/cli.py
@@ -16,15 +16,9 @@


class Quantization(str, Enum):
bitsandbytes = "bitsandbytes"
bitsandbytes_nf4 = "bitsandbytes-nf4"
bitsandbytes_fp4 = "bitsandbytes-fp4"
gptq = "gptq"
awq = "awq"
eetq = "eetq"
exl2 = "exl2"
fp8 = "fp8"
marlin = "marlin"


class Dtype(str, Enum):
@@ -105,14 +99,17 @@ def serve(
"bitsandbytes",
"bitsandbytes-nf4",
"bitsandbytes-fp4",
"gptq",
"awq",
"fp8",
}:
raise RuntimeError(
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
)

logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype))

if sharded:
if sharded and os.getenv("ATTENTION", "default") not in {"paged"}:
tgi_file = Path(__file__).resolve().parent / "tgi_service.py"
num_shard = int(os.getenv("WORLD_SIZE", "1"))
logger.info("CLI SHARDED = {}".format(num_shard))
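
The hunk above narrows the sharded launch: the dedicated `tgi_service.py` path now runs only when `ATTENTION` is not set to `paged`. A minimal standalone sketch of the new condition (values here are illustrative, not part of the diff):

import os

# Hypothetical setting for illustration: select paged attention.
os.environ["ATTENTION"] = "paged"
sharded = True  # i.e. WORLD_SIZE > 1

if sharded and os.getenv("ATTENTION", "default") not in {"paged"}:
    print("legacy sharded launch via tgi_service.py")
else:
    print("fall through to the paged-attention serving path")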
backends/gaudi/server/text_generation_server/layers/attention/__init__.py
@@ -1,43 +1,28 @@
from text_generation_server.utils.import_utils import SYSTEM
import os
from .common import (
Seqlen,
HPUPagedAttentionMetadata,
trim_attn_metadata,
trim_seqlen_metadata,
)

from .common import Seqlen
from .hpu import (
SUPPORTS_WINDOWING,
attention,
paged_attention,
)

if os.getenv("USE_FLASH_ATTENTION", "true").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.")
if SYSTEM == "cuda":
from .cuda import (
attention,
paged_attention,
reshape_and_cache,
SUPPORTS_WINDOWING,
PREFILL_IN_KV_CACHE,
)
elif SYSTEM == "rocm":
from .rocm import (
attention,
paged_attention,
reshape_and_cache,
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
)
elif SYSTEM == "ipex":
from .ipex import (
attention,
paged_attention,
reshape_and_cache,
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
)
else:
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")

# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
from .kv_cache import KVCache, get_kv_scales

__all__ = [
"attention",
"get_kv_scales",
"paged_attention",
"reshape_and_cache",
"PREFILL_IN_KV_CACHE",
"SUPPORTS_WINDOWING",
"KVCache",
"Seqlen",
"HPUPagedAttentionMetadata",
"trim_seqlen_metadata",
"trim_attn_metadata",
]
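
For orientation, a minimal sketch (not part of the diff) of how a model file can now import the HPU-only primitives through this package, using only names re-exported in `__all__` above:

from text_generation_server.layers.attention import (
    attention,
    paged_attention,
    KVCache,
    Seqlen,
    HPUPagedAttentionMetadata,
    trim_seqlen_metadata,
    trim_attn_metadata,
)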
215 changes: 145 additions & 70 deletions backends/gaudi/server/text_generation_server/layers/attention/common.py
@@ -1,72 +1,147 @@
from dataclasses import dataclass
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import ATTENTION
import torch
from typing import Optional


if ATTENTION in {"flashinfer", "flashdecoding"}:

@dataclass
class Seqlen:
input_lengths: torch.Tensor
prefix_lengths: torch.Tensor
cu_seqlen_q: Optional[torch.Tensor]
cu_seqlen_k: Optional[torch.Tensor]
max_q: int
max_k: int

def __init__(
self,
input_lengths,
prefix_lengths,
cu_seqlen_q=None,
max_q=None,
max_k=None,
):
self.input_lengths = input_lengths
self.prefix_lengths = prefix_lengths
device = self.input_lengths.device
shape = self.input_lengths.shape
if cu_seqlen_q is None:
cu_seqlen_q = torch.arange(
shape[0] + 1,
device=device,
dtype=torch.int32,
)
max_q = 1
else:
assert max_q is not None
assert max_k is not None
cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)

# cuda graphs don't like this and this is necessary to clamp within mistral
# Although FA2 might not want the clamping
# cu_seqlen_k[0] = 0
total = self.input_lengths + self.prefix_lengths
torch.cumsum(total, -1, out=cu_seqlen_k[1:])

self.cu_seqlen_q = cu_seqlen_q
self.cu_seqlen_k = cu_seqlen_k
self.max_q = max_q
self.max_k = max_k

def clamp(self, max):
# Flash decoding doesn't need to clamp
return self

else:

@dataclass
class Seqlen:
input_lengths: torch.Tensor
prefix_lengths: torch.Tensor
cu_seqlen_q: torch.Tensor
max_q: int
max_k: int

def clamp(self, max):
if SYSTEM == "rocm":
return self
raise NotImplementedError("Not implemented seqlen for paged")
return Seqlen(torch.clamp(self.input_lengths, max=max))
from typing import Optional, List, Dict
import collections

_TYPE_CACHE = {}


@dataclass
class HPUPagedAttentionMetadata:
"""Metadata for PagedAttention."""

block_list: Optional[torch.Tensor]
block_mapping: Optional[torch.Tensor]
block_usage: Optional[torch.Tensor]
block_scales: Optional[torch.Tensor]
block_groups: Optional[torch.Tensor]
attn_bias: Optional[torch.Tensor]


def subtuple(
obj: object,
typename: str,
to_copy: List[str],
to_override: Optional[Dict[str, object]] = None,
):
if obj is None:
return None
if to_override is None:
to_override = {}
fields = set(to_copy) | set(to_override.keys())
if isinstance(obj, dict):
values = {key: obj[key] for key in fields if key in obj}
else:
values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
if typename not in _TYPE_CACHE:
_TYPE_CACHE[typename] = collections.namedtuple(typename, " ".join(fields))
return _TYPE_CACHE[typename](**values)
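
As an illustration of what `subtuple` produces, here is a hedged sketch (the `Meta` class and its fields are invented for this example): only the requested fields survive, overrides win, and the resulting namedtuple type is cached by `typename`:

from dataclasses import dataclass

@dataclass
class Meta:  # hypothetical input object
    a: int
    b: int
    c: int

m = Meta(a=1, b=2, c=3)
t = subtuple(m, "TrimmedMeta", ["a", "c"], to_override={"c": 30})
assert t.a == 1 and t.c == 30  # `b` is dropped entirely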


def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object:
# NOTE(kzawora): To anyone working on this in the future:
# Trimming metadata is required when using HPUGraphs.
# Attention metadata is going to be hashed by PT bridge, and
# appropriate HPUGraphs will be matched based on all inputs' hash.

# Before you put more keys in here, make sure you know their
# value type and make sure you know how it's going to be hashed.
# You can find that information in input_hash function
# in habana_frameworks/torch/hpu/graphs.py. You can also hash
# it manually with torch.hpu.graphs.input_hash(attention_metadata)

# If you use primitive types here - they will get hashed based
# on their value. You *will* get lots of excessive graph captures
# (and an OOM eventually) if you decide to put something like
# seq_len int here.
# If you absolutely need a scalar, put it in a tensor. Tensors
# get hashed using their metadata, not their values:
# input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
# input_hash(123) != input_hash(321)
# input_hash("abc") != input_hash("cba")
attention_metadata = subtuple(
metadata,
"TrimmedAttentionMetadata",
[
"block_list",
"block_mapping",
"block_usage",
"block_scales",
"block_groups",
"attn_bias",
],
)
return attention_metadata
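
A hedged usage sketch (tensor shapes and values are illustrative only): build the metadata, then trim it before HPU-graph capture/replay, so the PT bridge hashes a flat namedtuple of tensors by their metadata rather than by Python object identity:

import torch

meta = HPUPagedAttentionMetadata(
    block_list=torch.tensor([0, 1, 2], dtype=torch.int32),
    block_mapping=torch.zeros(3, 4),
    block_usage=torch.tensor([128, 128, 64], dtype=torch.int32),
    block_scales=None,
    block_groups=torch.tensor([0, 0, 0], dtype=torch.int32),
    attn_bias=None,
)
trimmed = trim_attn_metadata(meta)
# trimmed.block_list, trimmed.block_usage, ... are the only attributes kept.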


@dataclass
class Seqlen:
input_lengths: torch.Tensor
cache_lengths: torch.Tensor
cu_seqlen_q: Optional[torch.Tensor]
cu_seqlen_k: Optional[torch.Tensor]

def __init__(
self,
input_lengths,
cache_lengths,
cu_seqlen_q=None,
):
self.input_lengths = input_lengths
self.cache_lengths = cache_lengths
device = self.input_lengths.device
shape = self.input_lengths.shape
if cu_seqlen_q is None:
cu_seqlen_q = torch.arange(
shape[0] + 1,
device=device,
dtype=torch.int32,
)
cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)

# cuda graphs don't like this and this is necessary to clamp within mistral
# Although FA2 might not want the clamping
# cu_seqlen_k[0] = 0
total = self.input_lengths + self.cache_lengths
torch.cumsum(total, -1, out=cu_seqlen_k[1:])

self.cu_seqlen_q = cu_seqlen_q
self.cu_seqlen_k = cu_seqlen_k

def clamp(self, max):
# Flash decoding doesn't need to clamp
return self


def trim_seqlen_metadata(metadata: Seqlen) -> object:
# NOTE(kzawora): To anyone working on this in the future:
# Trimming metadata is required when using HPUGraphs.
# Attention metadata is going to be hashed by PT bridge, and
# appropriate HPUGraphs will be matched based on all inputs' hash.

# Before you put more keys in here, make sure you know their
# value type and make sure you know how it's going to be hashed.
# You can find that information in input_hash function
# in habana_frameworks/torch/hpu/graphs.py. You can also hash
# it manually with torch.hpu.graphs.input_hash(attention_metadata)

# If you use primitive types here - they will get hashed based
# on their value. You *will* get lots of excessive graph captures
# (and an OOM eventually) if you decide to put something like
# seq_len int here.
# If you absolutely need a scalar, put it in a tensor. Tensors
# get hashed using their metadata, not their values:
# input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
# input_hash(123) != input_hash(321)
# input_hash("abc") != input_hash("cba")
attention_metadata = subtuple(
metadata,
"TrimmedSeqlen",
[
"input_lengths",
"cache_lengths",
"cu_seqlen_q",
"cu_seqlen_k",
],
)
return attention_metadata
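
A hedged end-to-end sketch of the new `Seqlen` (an illustrative batch of two sequences with no prefix cache): `cu_seqlen_q` becomes an arange over the batch and `cu_seqlen_k` the cumulative key lengths, after which the object can be trimmed for HPU-graph matching:

import torch

seqlen = Seqlen(
    input_lengths=torch.tensor([5, 7], dtype=torch.int32),
    cache_lengths=torch.tensor([0, 0], dtype=torch.int32),
)
# cu_seqlen_q == tensor([0, 1, 2]), cu_seqlen_k == tensor([0, 5, 12])
trimmed = trim_seqlen_metadata(seqlen)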