Gaudi: clean cuda/rocm code in hpu backend, enable flat_hpu #3113
Merged
Commits (27 total, all authored by sywangyi; changes from all commits are shown below):
201dc62  clean cuda/rocm code in hpu backend, enable flat_hpu
b7fea6f  fix TP in pageattn
5d36539  adjust block table in hpu to improve performance
a07e743  enable all the model. not testet yet
6bbe24d  use tensor cache in hpu graph to avoid replay issue
5cd1c93  add moe support, fix qwen/mistral/mixtral crash
073f793  fix phimoe issue
2cde30d  gpt_bigcode could also go pageattn
2074d05  enable dbrx remove some unused code
d5b78ba  Merge branch 'main' into gaudi_backend_pa
f95aa42  multi-modality initial PR
36b6612  adjust warmup and enable vlm
fdf0733  fix incorrect output in qwen2 idefics if hpu graph is used
9914ffe  remove unused quantization code and enable awq/gptq int4
8d221b7  fix gptq issue
6977376  enable fp8
fd70ad7  warmup prefill
ba7a131  add warmup_decode
7900be5  warmup decode
1508ee8  remove block_tables and prefill_cache_indices which will lead to dyna…
7914e98  Merge branch 'main' into gaudi_backend_pa
787dbe9  fix comment
376e050  missing gptj change...
f0e5fae  fix some issue
c55a8ca  remove torch.where to fix incorrect output in hpu graph model
610dd20  Merge branch 'main' into gaudi_backend_pa
4cdc34e  match the latest vllm_extension ops
backends/gaudi/server/text_generation_server/layers/attention/__init__.py (51 changes: 18 additions & 33 deletions)
In both hunks below, the PR's removed and added lines are rendered interleaved, without +/- markers.
@@ -1,43 +1,28 @@
from text_generation_server.utils.import_utils import SYSTEM
import os
from .common import (
    Seqlen,
    HPUPagedAttentionMetadata,
    trim_attn_metadata,
    trim_seqlen_metadata,
)

from .common import Seqlen
from .hpu import (
    SUPPORTS_WINDOWING,
    attention,
    paged_attention,
)

if os.getenv("USE_FLASH_ATTENTION", "true").lower() == "false":
    raise ImportError("`USE_FLASH_ATTENTION` is false.")
if SYSTEM == "cuda":
    from .cuda import (
        attention,
        paged_attention,
        reshape_and_cache,
        SUPPORTS_WINDOWING,
        PREFILL_IN_KV_CACHE,
    )
elif SYSTEM == "rocm":
    from .rocm import (
        attention,
        paged_attention,
        reshape_and_cache,
        PREFILL_IN_KV_CACHE,
        SUPPORTS_WINDOWING,
    )
elif SYSTEM == "ipex":
    from .ipex import (
        attention,
        paged_attention,
        reshape_and_cache,
        PREFILL_IN_KV_CACHE,
        SUPPORTS_WINDOWING,
    )
else:
    raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")

# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
from .kv_cache import KVCache, get_kv_scales

__all__ = [
    "attention",
    "get_kv_scales",
    "paged_attention",
    "reshape_and_cache",
    "PREFILL_IN_KV_CACHE",
    "SUPPORTS_WINDOWING",
    "KVCache",
    "Seqlen",
    "HPUPagedAttentionMetadata",
    "trim_seqlen_metadata",
    "trim_attn_metadata",
]
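The hunk above replaces the per-SYSTEM dispatch (and the USE_FLASH_ATTENTION check) with direct imports from the HPU backend. A minimal, illustrative import sketch of the resulting surface follows; it is not part of the diff, it assumes a Gaudi install of this backend, and the per-symbol comments are assumptions rather than documented behaviour.

# Illustrative only: the names come from the __init__.py above; comments are assumptions.
from text_generation_server.layers.attention import (
    attention,                  # prefill/flash attention kernel
    paged_attention,            # decode-time attention over the paged KV cache
    Seqlen,                     # per-batch sequence-length metadata
    HPUPagedAttentionMetadata,  # block tables and related paged-attention metadata
    trim_seqlen_metadata,       # reduce Seqlen to hashable tensor fields for HPU graphs
    trim_attn_metadata,         # same for HPUPagedAttentionMetadata
    SUPPORTS_WINDOWING,
    KVCache,
    get_kv_scales,
)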
backends/gaudi/server/text_generation_server/layers/attention/common.py (215 changes: 145 additions & 70 deletions)
@@ -1,72 +1,147 @@
from dataclasses import dataclass
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import ATTENTION
import torch
from typing import Optional


if ATTENTION in {"flashinfer", "flashdecoding"}:

    @dataclass
    class Seqlen:
        input_lengths: torch.Tensor
        prefix_lengths: torch.Tensor
        cu_seqlen_q: Optional[torch.Tensor]
        cu_seqlen_k: Optional[torch.Tensor]
        max_q: int
        max_k: int

        def __init__(
            self,
            input_lengths,
            prefix_lengths,
            cu_seqlen_q=None,
            max_q=None,
            max_k=None,
        ):
            self.input_lengths = input_lengths
            self.prefix_lengths = prefix_lengths
            device = self.input_lengths.device
            shape = self.input_lengths.shape
            if cu_seqlen_q is None:
                cu_seqlen_q = torch.arange(
                    shape[0] + 1,
                    device=device,
                    dtype=torch.int32,
                )
                max_q = 1
            else:
                assert max_q is not None
                assert max_k is not None
            cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)

            # cuda graphs don't like this and this is necessary to clamp within mistral
            # Although FA2 might not want the clamping
            # cu_seqlen_k[0] = 0
            total = self.input_lengths + self.prefix_lengths
            torch.cumsum(total, -1, out=cu_seqlen_k[1:])

            self.cu_seqlen_q = cu_seqlen_q
            self.cu_seqlen_k = cu_seqlen_k
            self.max_q = max_q
            self.max_k = max_k

        def clamp(self, max):
            # Flash decoding doesn't need to clamp
            return self

else:

    @dataclass
    class Seqlen:
        input_lengths: torch.Tensor
        prefix_lengths: torch.Tensor
        cu_seqlen_q: torch.Tensor
        max_q: int
        max_k: int

        def clamp(self, max):
            if SYSTEM == "rocm":
                return self
            raise NotImplementedError("Not implemented seqlen for paged")
            return Seqlen(torch.clamp(self.input_lengths, max=max))
from typing import Optional, List, Dict
import collections

_TYPE_CACHE = {}


@dataclass
class HPUPagedAttentionMetadata:
    """Metadata for PagedAttention."""

    block_list: Optional[torch.Tensor]
    block_mapping: Optional[torch.Tensor]
    block_usage: Optional[torch.Tensor]
    block_scales: Optional[torch.Tensor]
    block_groups: Optional[torch.Tensor]
    attn_bias: Optional[torch.Tensor]


def subtuple(
    obj: object,
    typename: str,
    to_copy: List[str],
    to_override: Optional[Dict[str, object]] = None,
):
    if obj is None:
        return None
    if to_override is None:
        to_override = {}
    fields = set(to_copy) | set(to_override.keys())
    if isinstance(obj, dict):
        values = {key: obj[key] for key in fields if key in obj}
    else:
        values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
    if typename not in _TYPE_CACHE:
        _TYPE_CACHE[typename] = collections.namedtuple(typename, " ".join(fields))
    return _TYPE_CACHE[typename](**values)


def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object:
    # NOTE(kzawora): To anyone working on this in the future:
    # Trimming metadata is required when using HPUGraphs.
    # Attention metadata is going to be hashed by PT bridge, and
    # appropriate HPUGraphs will be matched based on all inputs' hash.

    # Before you put more keys in here, make sure you know their
    # value type and make sure you know how it's going to be hashed.
    # You can find that information in input_hash function
    # in habana_frameworks/torch/hpu/graphs.py. You can also hash
    # it manually with torch.hpu.graphs.input_hash(attention_metadata)

    # If you use primitive types here - they will get hashed based
    # on their value. You *will* get lots of excessive graph captures
    # (and an OOM eventually) if you decide to put something like
    # seq_len int here.
    # If you absolutely need a scalar, put it in a tensor. Tensors
    # get hashed using their metadata, not their values:
    # input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
    # input_hash(123) != input_hash(321)
    # input_hash("abc") != input_hash("cba")
    attention_metadata = subtuple(
        metadata,
        "TrimmedAttentionMetadata",
        [
            "block_list",
            "block_mapping",
            "block_usage",
            "block_scales",
            "block_groups",
            "attn_bias",
        ],
    )
    return attention_metadata


@dataclass
class Seqlen:
    input_lengths: torch.Tensor
    cache_lengths: torch.Tensor
    cu_seqlen_q: Optional[torch.Tensor]
    cu_seqlen_k: Optional[torch.Tensor]

    def __init__(
        self,
        input_lengths,
        cache_lengths,
        cu_seqlen_q=None,
    ):
        self.input_lengths = input_lengths
        self.cache_lengths = cache_lengths
        device = self.input_lengths.device
        shape = self.input_lengths.shape
        if cu_seqlen_q is None:
            cu_seqlen_q = torch.arange(
                shape[0] + 1,
                device=device,
                dtype=torch.int32,
            )
        cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)

        # cuda graphs don't like this and this is necessary to clamp within mistral
        # Although FA2 might not want the clamping
        # cu_seqlen_k[0] = 0
        total = self.input_lengths + self.cache_lengths
        torch.cumsum(total, -1, out=cu_seqlen_k[1:])

        self.cu_seqlen_q = cu_seqlen_q
        self.cu_seqlen_k = cu_seqlen_k

    def clamp(self, max):
        # Flash decoding doesn't need to clamp
        return self


def trim_seqlen_metadata(metadata: Seqlen) -> object:
    # NOTE(kzawora): To anyone working on this in the future:
    # Trimming metadata is required when using HPUGraphs.
    # Attention metadata is going to be hashed by PT bridge, and
    # appropriate HPUGraphs will be matched based on all inputs' hash.

    # Before you put more keys in here, make sure you know their
    # value type and make sure you know how it's going to be hashed.
    # You can find that information in input_hash function
    # in habana_frameworks/torch/hpu/graphs.py. You can also hash
    # it manually with torch.hpu.graphs.input_hash(attention_metadata)

    # If you use primitive types here - they will get hashed based
    # on their value. You *will* get lots of excessive graph captures
    # (and an OOM eventually) if you decide to put something like
    # seq_len int here.
    # If you absolutely need a scalar, put it in a tensor. Tensors
    # get hashed using their metadata, not their values:
    # input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
    # input_hash(123) != input_hash(321)
    # input_hash("abc") != input_hash("cba")
    attention_metadata = subtuple(
        metadata,
        "TrimmedSeqlen",
        [
            "input_lengths",
            "cache_lengths",
            "cu_seqlen_q",
            "cu_seqlen_k",
        ],
    )
    return attention_metadata
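The long comment blocks above explain why trimming matters: HPU graphs are matched by hashing their inputs, and tensors are hashed by metadata while Python scalars and strings are hashed by value. Below is a standalone sketch (plain CPU PyTorch, runnable outside TGI) that mirrors the subtuple helper from this file (dataclass path only) and adds one hypothetical scalar field, purely to illustrate that trimming keeps only the whitelisted tensor fields:

# Standalone illustration; not from the PR. `num_decode_tokens` is a hypothetical
# scalar added only to show what trimming deliberately leaves out.
import collections
from dataclasses import dataclass
from typing import Optional

import torch

_TYPE_CACHE = {}


def subtuple(obj, typename, to_copy, to_override=None):
    # Same idea as the helper in common.py: copy whitelisted fields into a
    # namedtuple whose type is created once and cached.
    if obj is None:
        return None
    to_override = to_override or {}
    fields = set(to_copy) | set(to_override.keys())
    values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
    if typename not in _TYPE_CACHE:
        _TYPE_CACHE[typename] = collections.namedtuple(typename, " ".join(fields))
    return _TYPE_CACHE[typename](**values)


@dataclass
class HPUPagedAttentionMetadata:
    block_list: Optional[torch.Tensor]
    block_mapping: Optional[torch.Tensor]
    block_usage: Optional[torch.Tensor]
    block_scales: Optional[torch.Tensor]
    block_groups: Optional[torch.Tensor]
    attn_bias: Optional[torch.Tensor]
    num_decode_tokens: int = 0  # hypothetical scalar: hashed by value, so it must not reach the graph


meta = HPUPagedAttentionMetadata(
    block_list=torch.tensor([0, 1, 2]),
    block_mapping=torch.tensor([0, 0, 1]),
    block_usage=torch.tensor([128, 128, 64]),
    block_scales=None,
    block_groups=torch.tensor([0, 0, 1]),
    attn_bias=None,
    num_decode_tokens=3,
)

trimmed = subtuple(
    meta,
    "TrimmedAttentionMetadata",
    ["block_list", "block_mapping", "block_usage",
     "block_scales", "block_groups", "attn_bias"],
)
print(sorted(trimmed._fields))  # the six tensor fields only; the scalar is dropped

Because the namedtuple type is cached in _TYPE_CACHE, repeated trims reuse the same type, so the PT bridge sees a stable input structure from one step to the next.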