2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -16,7 +16,7 @@ endif()
set(DEVICE "" CACHE STRING "device string, default empty string")
string(TOLOWER "${DEVICE}" DEVICE)

-list(APPEND SUPPORTED_DEVICE "ascend" "maca" "camb")
+list(APPEND SUPPORTED_DEVICE "ascend" "maca" "camb" "ix")

if(NOT DEVICE)
message(FATAL_ERROR "Please specify variable DEVICE of dlinfer!")
2 changes: 2 additions & 0 deletions dlinfer/vendor/ix/CMakeLists.txt
@@ -0,0 +1,2 @@
# Empty install target for ix device
install(TARGETS)
3 changes: 3 additions & 0 deletions dlinfer/vendor/ix/__init__.py
@@ -0,0 +1,3 @@
from .ix_ops import *

device_str = "cuda"
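Since `device_str` is set to `"cuda"`, the ix backend consumes ordinary CUDA tensors. A minimal sketch of a hypothetical call site (illustrative only, not part of this PR):

# Illustrative sketch; assumes an ix build of dlinfer where the vendor
# package imports cleanly (ixformer available) and exposes device_str.
import torch
from dlinfer.vendor.ix import device_str

# Inputs for the registered ix ops are plain CUDA tensors.
x = torch.randn(4, 2 * 128, dtype=torch.float16, device=device_str)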
333 changes: 333 additions & 0 deletions dlinfer/vendor/ix/ix_ops.py
@@ -0,0 +1,333 @@
import os
import math
import torch
import lmdeploy.pytorch.distributed as dist

from dlinfer.vendor import vendor_ops_registry
from dlinfer.utils.registry import register_ops
from dlinfer.utils.type_annotation import Tensor, Optional, Sequence, Tuple

import ixformer.inference.functions as ops
import ixformer.functions as ix_func
from ixformer.contrib.vllm_flash_attn import (
    flash_attn_varlen_func as _flash_attn_varlen_func,
)
from ixformer.contrib.vllm_flash_attn import (
    flash_attn_with_kvcache as _flash_attn_with_kvcache,
)

__all__ = [
    "add_rms_norm",
    "apply_rotary_pos_emb",
    "prefill_attention",
    "fused_moe",
    "fill_kv_cache",
    "paged_decode_attention",
    "paged_prefill_attention",
    "rms_norm",
    "silu_and_mul",
    "moe_gating_topk_softmax",
    "linear",
    "weight_quant_matmul",
    "dynamic_quant",
    "linear_w8a8",
    "rms_norm_w8a8",
    "add_rms_norm_w8a8",
]


@register_ops(vendor_ops_registry)
def add_rms_norm(
    hidden_states: Tensor,
    residual: Tensor,
    weight: Tensor,
    epsilon: float,
) -> Tuple[Tensor, Tensor]:
    return ix_func.residual_rms_norm(
        input=hidden_states,
        residual=residual,
        weight=weight,
        eps=epsilon,
        residual_alpha=1,
    )


@register_ops(vendor_ops_registry)
def apply_rotary_pos_emb(
    query: Tensor,
    key: Tensor,
    cos: Optional[Tensor],
    sin: Optional[Tensor],
) -> Tuple[Tensor, Tensor]:
    query = query.contiguous().unsqueeze(0)
    key = key.contiguous().unsqueeze(0)
    position_ids_1d = torch.arange(0, query.size(1), device=query.device)
    query = query.flatten(-2, -1)
    key = key.flatten(-2, -1)
    cos = cos[..., : cos.shape[-1] // 2]
    sin = sin[..., : sin.shape[-1] // 2]
    cos_sin_cache = torch.cat((cos, sin), dim=-1)

    ops.vllm_rotary_embedding(
        position_ids_1d, query, key, cos_sin_cache.size(-1), cos_sin_cache, True
    )
    return query, key


@register_ops(vendor_ops_registry)
def prefill_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    q_start_loc: Tensor,
    q_seq_len: Tensor,
    max_q_seq_len: int,
    num_q_heads: int,
    num_kv_heads: int,
    attn_mask: Sequence[Optional[Tensor]],
    softmax_scale: Optional[float],
    alibi_slopes: Optional[Sequence[float]],
    attn_output: Optional[Tensor],
) -> Tensor:
    if q_seq_len is None:
        q_seq_len = max_q_seq_len
    kv_seq_len = q_seq_len
    max_kv_seq_len = max_q_seq_len

    causal = True
    if softmax_scale is None:
        softmax_scale = float(1 / math.sqrt(key.size(-1)))
    _flash_attn_varlen_func(
        q=query,
        k=key,
        v=value,
        cu_seqlens_q=q_start_loc,
        cu_seqlens_k=q_start_loc,
        max_seqlen_q=max_q_seq_len,
        max_seqlen_k=max_kv_seq_len,
        softmax_scale=softmax_scale,
        causal=causal,
        out=attn_output,
    )

    return attn_output


@register_ops(vendor_ops_registry)
def fill_kv_cache(
    key: Tensor,
    value: Tensor,
    key_cache: Tensor,
    value_cache: Tensor,
    kv_indices: Tensor,
    k_scales_zeros: Sequence[Optional[Tensor]],
    v_scales_zeros: Sequence[Optional[Tensor]],
    quant_bits: int,
) -> Tuple[Tensor, Tensor]:
    kv_indices = kv_indices.squeeze(-1)
    ops.reshape_and_cache_flash(
        key, value, key_cache, value_cache, kv_indices, "auto", 1.0, 1.0
    )
    return key_cache, value_cache


@register_ops(vendor_ops_registry)
def paged_decode_attention(
    query: Tensor,
    key_cache: Tensor,
    value_cache: Tensor,
    block_table: Optional[Tensor],
    block_size: int,
    kv_seq_len: Tensor,
    max_kv_seq_len: int,
    num_q_heads: int,
    num_kv_heads: int,
    softmax_scale: Optional[float],
    alibi_slopes: Optional[Sequence[float]],
    attn_output: Optional[Tensor],
    kv_scales: Optional[Tensor],
    kv_zeros: Optional[Tensor],
    quant_bits: Optional[int],
) -> Tensor:
    if alibi_slopes is not None:
        raise RuntimeError("paged_decode_attention does not support alibi_slopes yet")

    dim = query.size(-1)
    num_kv_heads = value_cache.size(1)
    block_size = value_cache.size(2)
    batch_size = block_table.size(0)

    if softmax_scale is None:
        softmax_scale = float(1 / math.sqrt(query.size(-1)))

    block_table = block_table.to(torch.int32)
    kv_seq_len = kv_seq_len.to(torch.int32)

    output = torch.empty_like(query)

    ix_func.vllm_paged_attention(
        output,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        softmax_scale,
        block_table,
        kv_seq_len.cpu(),
        kv_seq_len,
        block_size,
        max_kv_seq_len,
        None,
        False,
        need_view=False,
    )
    return output


@register_ops(vendor_ops_registry)
def paged_prefill_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    key_cache: Tensor,
    value_cache: Tensor,
    block_table: Tensor,
    block_size: int,
    q_start_loc: Tensor,
    q_seq_len: Tensor,
    kv_seq_len: Tensor,
    cu_seq_lens_kv: Tensor,
    max_q_seq_len: int,
    max_kv_seq_len: int,
    num_q_heads: int,
    num_kv_heads: int,
    attn_mask: Sequence[Optional[Tensor]],
    softmax_scale: Optional[float],
    alibi_slopes: Optional[Sequence[float]],
    attn_output: Optional[Tensor],
    kv_scales: Optional[Tensor],
    kv_zeros: Optional[Tensor],
    quant_bits: Optional[int],
) -> Tensor:
    raise NotImplementedError("Not implemented on ix.")


@register_ops(vendor_ops_registry)
def rms_norm(
    hidden_states: Tensor,
    weight: Tensor,
    epsilon: float,
) -> Tensor:
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    weight = weight.to(torch.float32)
    output = torch.empty_like(hidden_states)

    ops.rms_norm(hidden_states, weight, epsilon, output)

    return output.to(input_dtype)


@register_ops(vendor_ops_registry)
def moe_gating_topk_softmax(
    router_logits: Tensor, topk: int, renormalize: bool = False
) -> Tuple[Tensor, Tensor]:
    raise NotImplementedError("Not implemented on ix.")


@register_ops(vendor_ops_registry)
def silu_and_mul(x: Tensor, dim: int = -1) -> Tensor:
    d = x.shape[-1] // 2
    output_shape = x.shape[:-1] + (d,)
    out = torch.empty(output_shape, dtype=x.dtype, device=x.device)

    ops.silu_and_mul(x, out)
    return out


@register_ops(vendor_ops_registry)
def fused_moe(
    hidden_states: Tensor,
    gate_up_weights: Tensor,
    down_weights: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    top_k: int,
    renormalize: bool,
) -> Tensor:
    raise NotImplementedError("Not implemented on ix.")


@register_ops(vendor_ops_registry)
def linear(
    x: Tensor,
    weight: Tensor,
    bias: Optional[Tensor],
    all_reduce: Optional[bool],
    group: Optional[str],
) -> Tensor:
    if os.getenv("DLINER_LINEAR_USE_NN_LAYOUT", "0") == "1":
        out = torch.matmul(x, weight)
        if bias is not None:
            out += bias
    else:
        out = torch.nn.functional.linear(x, weight, bias)
    if all_reduce:
        dist.all_reduce(out)
    return out


# Weight-quantized matmul (e.g. W4A16) is not implemented for the ix backend yet.
@register_ops(vendor_ops_registry)
def weight_quant_matmul(
    x: Tensor,
    qweight: Tensor,
    scale: Tensor,
    offset: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    all_reduce: Optional[bool] = False,
    group_size: Optional[int] = 0,
):
    raise NotImplementedError("Not implemented on ix.")


@register_ops(vendor_ops_registry)
def dynamic_quant(
    x: Tensor, quant_dtype: torch.dtype, quant_granularity: str = "PER_TOKEN"
):
    raise NotImplementedError("Not implemented on ix.")


@register_ops(vendor_ops_registry)
def linear_w8a8(
    a: Tensor,
    b: Tensor,
    rms_scale: float,
    linear_scale: float,
    out_dtype: torch.dtype,
    quant_dtype: torch.dtype = torch.int8,
    bias: Tensor = None,
):
    raise NotImplementedError("Not implemented on ix.")


@register_ops(vendor_ops_registry)
def rms_norm_w8a8(
    hidden_states: Tensor,
    weight: Tensor,
    epsilon: float,
    quant_dtype: torch.dtype = torch.int8,
):
    raise NotImplementedError("Not implemented on ix.")


@register_ops(vendor_ops_registry)
def add_rms_norm_w8a8(
    hidden_states: Tensor,
    residual: Tensor,
    weight: Tensor,
    epsilon: float,
    quant_dtype: torch.dtype = torch.int8,
):
    raise NotImplementedError("Not implemented on ix.")
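A quick smoke-test sketch for the newly registered ops, assuming `vendor_ops_registry` behaves like a name-keyed mapping of the registered callables (hypothetical usage, not code from this PR):

# Hypothetical smoke test; assumes an ix build of dlinfer, a CUDA-visible ix
# device, and that vendor_ops_registry maps op names to the functions above.
import torch
from dlinfer.vendor import vendor_ops_registry
import dlinfer.vendor.ix  # noqa: F401  # importing triggers op registration

x = torch.randn(4, 2 * 128, dtype=torch.float16, device="cuda")
w = torch.ones(128, dtype=torch.float16, device="cuda")

silu_and_mul = vendor_ops_registry["silu_and_mul"]
rms_norm = vendor_ops_registry["rms_norm"]

y = silu_and_mul(x)       # gated half: shape (4, 128)
z = rms_norm(y, w, 1e-6)  # normalized output, same shape as y
print(y.shape, z.shape)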
6 changes: 6 additions & 0 deletions requirements/ix/build.txt
@@ -0,0 +1,6 @@
ninja
setuptools
wheel
scikit-build
cmake>=3.18
-r torch.txt
2 changes: 2 additions & 0 deletions requirements/ix/full.txt
@@ -0,0 +1,2 @@
-r build.txt
-r runtime.txt
2 changes: 2 additions & 0 deletions requirements/ix/runtime.txt
@@ -0,0 +1,2 @@
transformers
-r torch.txt
2 changes: 2 additions & 0 deletions requirements/ix/torch.txt
@@ -0,0 +1,2 @@
torch
torchvision
1 change: 1 addition & 0 deletions setup.py
@@ -11,6 +11,7 @@
"ascend": "PrivateUse1",
"maca": "CUDA",
"camb": "PrivateUse1",
"ix": "CUDA",
}


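The `"ix": "CUDA"` entry mirrors the maca backend: the ix device presents itself as CUDA, so its kernels go under PyTorch's CUDA dispatch key rather than `PrivateUse1`. A rough illustration of registering an op under the `CUDA` key with `torch.library` (the namespace and op below are made up for this sketch):

# Hypothetical illustration of dispatch-key registration; "dlinfer_demo" and
# the op schema are invented for this sketch, not taken from the PR.
import torch

lib = torch.library.Library("dlinfer_demo", "DEF")
lib.define("silu_and_mul(Tensor x) -> Tensor")


def _silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    return torch.nn.functional.silu(x[..., :d]) * x[..., d:]


# Because the ix device is exposed as a CUDA device, implementations are
# registered under the "CUDA" dispatch key instead of "PrivateUse1".
lib.impl("silu_and_mul", _silu_and_mul, "CUDA")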