From 9f3823682a42def62e5c2a5339c7a6bf3b799a65 Mon Sep 17 00:00:00 2001
From: tangzhiyi11
Date: Wed, 4 Jun 2025 08:48:29 +0000
Subject: [PATCH] [ix] support ix device

---
 CMakeLists.txt                   |   2 +-
 dlinfer/vendor/ix/CMakeLists.txt |   2 +
 dlinfer/vendor/ix/__init__.py    |   3 +
 dlinfer/vendor/ix/ix_ops.py      | 333 +++++++++++++++++++++++++++++++
 requirements/ix/build.txt        |   6 +
 requirements/ix/full.txt         |   2 +
 requirements/ix/runtime.txt      |   2 +
 requirements/ix/torch.txt        |   2 +
 setup.py                         |   1 +
 9 files changed, 352 insertions(+), 1 deletion(-)
 create mode 100644 dlinfer/vendor/ix/CMakeLists.txt
 create mode 100644 dlinfer/vendor/ix/__init__.py
 create mode 100644 dlinfer/vendor/ix/ix_ops.py
 create mode 100644 requirements/ix/build.txt
 create mode 100644 requirements/ix/full.txt
 create mode 100644 requirements/ix/runtime.txt
 create mode 100644 requirements/ix/torch.txt

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 723cc844..154d2d3f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -16,7 +16,7 @@ endif()
 
 set(DEVICE "" CACHE STRING "device string, default empty string")
 string(TOLOWER "${DEVICE}" DEVICE)
-list(APPEND SUPPORTED_DEVICE "ascend" "maca" "camb")
+list(APPEND SUPPORTED_DEVICE "ascend" "maca" "camb" "ix")
 
 if(NOT DEVICE)
     message(FATAL_ERROR "Please specify variable DEVICE of dlinfer!")
diff --git a/dlinfer/vendor/ix/CMakeLists.txt b/dlinfer/vendor/ix/CMakeLists.txt
new file mode 100644
index 00000000..7cf657d8
--- /dev/null
+++ b/dlinfer/vendor/ix/CMakeLists.txt
@@ -0,0 +1,2 @@
+# Empty install target for ix device
+install(TARGETS)
diff --git a/dlinfer/vendor/ix/__init__.py b/dlinfer/vendor/ix/__init__.py
new file mode 100644
index 00000000..b83cd52f
--- /dev/null
+++ b/dlinfer/vendor/ix/__init__.py
@@ -0,0 +1,3 @@
+from .ix_ops import *
+
+device_str = "cuda"
diff --git a/dlinfer/vendor/ix/ix_ops.py b/dlinfer/vendor/ix/ix_ops.py
new file mode 100644
index 00000000..a5bb9e4b
--- /dev/null
+++ b/dlinfer/vendor/ix/ix_ops.py
@@ -0,0 +1,333 @@
+import os
+import math
+import torch
+import lmdeploy.pytorch.distributed as dist
+
+from dlinfer.vendor import vendor_ops_registry
+from dlinfer.utils.registry import register_ops
+from dlinfer.utils.type_annotation import Tensor, Optional, Sequence, Tuple
+
+import ixformer.inference.functions as ops
+import ixformer.functions as ix_func
+from ixformer.contrib.vllm_flash_attn import (
+    flash_attn_varlen_func as _flash_attn_varlen_func,
+)
+from ixformer.contrib.vllm_flash_attn import (
+    flash_attn_with_kvcache as _flash_attn_with_kvcache,
+)
+
+__all__ = [
+    "add_rms_norm",
+    "apply_rotary_pos_emb",
+    "prefill_attention",
+    "fused_moe",
+    "fill_kv_cache",
+    "paged_decode_attention",
+    "paged_prefill_attention",
+    "rms_norm",
+    "silu_and_mul",
+    "moe_gating_topk_softmax",
+    "linear",
+    "weight_quant_matmul",
+    "dynamic_quant",
+    "linear_w8a8",
+    "rms_norm_w8a8",
+    "add_rms_norm_w8a8",
+]
+
+
+@register_ops(vendor_ops_registry)
+def add_rms_norm(
+    hidden_states: Tensor,
+    residual: Tensor,
+    weight: Tensor,
+    epsilon: float,
+) -> Tuple[Tensor, Tensor]:
+    return ix_func.residual_rms_norm(
+        input=hidden_states,
+        residual=residual,
+        weight=weight,
+        eps=epsilon,
+        residual_alpha=1,
+    )
+
+
+@register_ops(vendor_ops_registry)
+def apply_rotary_pos_emb(
+    query: Tensor,
+    key: Tensor,
+    cos: Optional[Tensor],
+    sin: Optional[Tensor],
+) -> Tuple[Tensor, Tensor]:
+    query = query.contiguous().unsqueeze(0)
+    key = key.contiguous().unsqueeze(0)
+    position_ids_1d = torch.arange(0, query.size(1), device=query.device)
+    query = query.flatten(-2, -1)
+    key = key.flatten(-2, -1)
+    cos = cos[..., : cos.shape[-1] // 2]
+    sin = sin[..., : sin.shape[-1] // 2]
+    cos_sin_cache = torch.cat((cos, sin), dim=-1)
+
+    ops.vllm_rotary_embedding(
+        position_ids_1d, query, key, cos_sin_cache.size(-1), cos_sin_cache, True
+    )
+    return query, key
+
+
+@register_ops(vendor_ops_registry)
+def prefill_attention(
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    q_start_loc: Tensor,
+    q_seq_len: Tensor,
+    max_q_seq_len: int,
+    num_q_heads: int,
+    num_kv_heads: int,
+    attn_mask: Sequence[Optional[Tensor]],
+    softmax_scale: Optional[float],
+    alibi_slopes: Optional[Sequence[float]],
+    attn_output: Optional[Tensor],
+) -> Tensor:
+
+    if q_seq_len is None:
+        q_seq_len = max_q_seq_len
+    kv_seq_len = q_seq_len
+    max_kv_seq_len = max_q_seq_len
+
+    causal = True
+    if softmax_scale is None:
+        softmax_scale = float(1 / math.sqrt(key.size(-1)))
+    _flash_attn_varlen_func(
+        q=query,
+        k=key,
+        v=value,
+        cu_seqlens_q=q_start_loc,
+        cu_seqlens_k=q_start_loc,
+        max_seqlen_q=max_q_seq_len,
+        max_seqlen_k=max_kv_seq_len,
+        softmax_scale=softmax_scale,
+        causal=causal,
+        out=attn_output,
+    )
+
+    return attn_output
+
+
+@register_ops(vendor_ops_registry)
+def fill_kv_cache(
+    key: Tensor,
+    value: Tensor,
+    key_cache: Tensor,
+    value_cache: Tensor,
+    kv_indices: Tensor,
+    k_scales_zeros: Sequence[Optional[Tensor]],
+    v_scales_zeros: Sequence[Optional[Tensor]],
+    quant_bits: int,
+) -> Tuple[Tensor, Tensor]:
+    kv_indices = kv_indices.squeeze(-1)
+    ops.reshape_and_cache_flash(
+        key, value, key_cache, value_cache, kv_indices, "auto", 1.0, 1.0
+    )
+    return key_cache, value_cache
+
+
+@register_ops(vendor_ops_registry)
+def paged_decode_attention(
+    query: Tensor,
+    key_cache: Tensor,
+    value_cache: Tensor,
+    block_table: Optional[Tensor],
+    block_size: int,
+    kv_seq_len: Tensor,
+    max_kv_seq_len: int,
+    num_q_heads: int,
+    num_kv_heads: int,
+    softmax_scale: Optional[float],
+    alibi_slopes: Optional[Sequence[float]],
+    attn_output: Optional[Tensor],
+    kv_scales: Optional[Tensor],
+    kv_zeros: Optional[Tensor],
+    quant_bits: Optional[int],
+) -> Tensor:
+    if alibi_slopes is not None:
+        raise RuntimeError("paged_decode_attention does not support alibi_slopes yet")
+
+    dim = query.size(-1)
+    num_kv_heads = value_cache.size(1)
+    block_size = value_cache.size(2)
+    batch_size = block_table.size(0)
+
+    if softmax_scale is None:
+        softmax_scale = float(1 / math.sqrt(query.size(-1)))
+
+    block_table = block_table.to(torch.int32)
+    kv_seq_len = kv_seq_len.to(torch.int32)
+
+    output = torch.empty_like(query)
+
+    ix_func.vllm_paged_attention(
+        output,
+        query,
+        key_cache,
+        value_cache,
+        num_kv_heads,
+        softmax_scale,
+        block_table,
+        kv_seq_len.cpu(),
+        kv_seq_len,
+        block_size,
+        max_kv_seq_len,
+        None,
+        False,
+        need_view=False,
+    )
+    return output
+
+
+@register_ops(vendor_ops_registry)
+def paged_prefill_attention(
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    key_cache: Tensor,
+    value_cache: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    q_start_loc: Tensor,
+    q_seq_len: Tensor,
+    kv_seq_len: Tensor,
+    cu_seq_lens_kv: Tensor,
+    max_q_seq_len: int,
+    max_kv_seq_len: int,
+    num_q_heads: int,
+    num_kv_heads: int,
+    attn_mask: Sequence[Optional[Tensor]],
+    softmax_scale: Optional[float],
+    alibi_slopes: Optional[Sequence[float]],
+    attn_output: Optional[Tensor],
+    kv_scales: Optional[Tensor],
+    kv_zeros: Optional[Tensor],
+    quant_bits: Optional[int],
+) -> Tensor:
+    raise NotImplementedError("Not implemented on ix.")
+
+
+@register_ops(vendor_ops_registry)
+def rms_norm(
+    hidden_states: Tensor,
+    weight: Tensor,
+    epsilon: float,
+) -> Tensor:
+    input_dtype = hidden_states.dtype
+    hidden_states = hidden_states.to(torch.float32)
+    weight = weight.to(torch.float32)
+    output = torch.empty_like(hidden_states)
+
+    ops.rms_norm(hidden_states, weight, epsilon, output)
+
+    return output.to(input_dtype)
+
+
+@register_ops(vendor_ops_registry)
+def moe_gating_topk_softmax(
+    router_logits: Tensor, topk: int, renormalize: bool = False
+) -> Tuple[Tensor, Tensor]:
+    raise NotImplementedError("Not implemented on ix.")
+
+
+@register_ops(vendor_ops_registry)
+def silu_and_mul(x: Tensor, dim: int = -1) -> Tensor:
+    d = x.shape[-1] // 2
+    output_shape = x.shape[:-1] + (d,)
+    out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+
+    ops.silu_and_mul(x, out)
+    return out
+
+
+@register_ops(vendor_ops_registry)
+def fused_moe(
+    hidden_states: Tensor,
+    gate_up_weights: Tensor,
+    down_weights: Tensor,
+    topk_weights: Tensor,
+    topk_ids: Tensor,
+    top_k: int,
+    renormalize: bool,
+) -> Tensor:
+    raise NotImplementedError("Not implemented on ix.")
+
+
+@register_ops(vendor_ops_registry)
+def linear(
+    x: Tensor,
+    weight: Tensor,
+    bias: Optional[Tensor],
+    all_reduce: Optional[bool],
+    group: Optional[str],
+) -> Tensor:
+    if os.getenv("DLINER_LINEAR_USE_NN_LAYOUT", "0") == "1":
+        out = torch.matmul(x, weight)
+        if bias is not None:
+            out += bias
+    else:
+        out = torch.nn.functional.linear(x, weight, bias)
+    if all_reduce:
+        dist.all_reduce(out)
+    return out
+
+
+# W4A16 weight-quantized matmul is not yet implemented on the ix device.
+@register_ops(vendor_ops_registry)
+def weight_quant_matmul(
+    x: Tensor,
+    qweight: Tensor,
+    scale: Tensor,
+    offset: Optional[Tensor] = None,
+    bias: Optional[Tensor] = None,
+    all_reduce: Optional[bool] = False,
+    group_size: Optional[int] = 0,
+):
+    raise NotImplementedError("Not implemented on ix.")
+
+
+@register_ops(vendor_ops_registry)
+def dynamic_quant(
+    x: Tensor, quant_dtype: torch.dtype, quant_granularity: str = "PER_TOKEN"
+):
+    raise NotImplementedError("Not implemented on ix.")
+
+
+@register_ops(vendor_ops_registry)
+def linear_w8a8(
+    a: Tensor,
+    b: Tensor,
+    rms_scale: float,
+    linear_scale: float,
+    out_dtype: torch.dtype,
+    quant_dtype: torch.dtype = torch.int8,
+    bias: Tensor = None,
+):
+    raise NotImplementedError("Not implemented on ix.")
+
+
+@register_ops(vendor_ops_registry)
+def rms_norm_w8a8(
+    hidden_states: Tensor,
+    weight: Tensor,
+    epsilon: float,
+    quant_dtype: torch.dtype = torch.int8,
+):
+    raise NotImplementedError("Not implemented on ix.")
+
+
+@register_ops(vendor_ops_registry)
+def add_rms_norm_w8a8(
+    hidden_states: Tensor,
+    residual: Tensor,
+    weight: Tensor,
+    epsilon: float,
+    quant_dtype: torch.dtype = torch.int8,
+):
+    raise NotImplementedError("Not implemented on ix.")
diff --git a/requirements/ix/build.txt b/requirements/ix/build.txt
new file mode 100644
index 00000000..1a572028
--- /dev/null
+++ b/requirements/ix/build.txt
@@ -0,0 +1,6 @@
+ninja
+setuptools
+wheel
+scikit-build
+cmake>=3.18
+-r torch.txt
diff --git a/requirements/ix/full.txt b/requirements/ix/full.txt
new file mode 100644
index 00000000..00e009da
--- /dev/null
+++ b/requirements/ix/full.txt
@@ -0,0 +1,2 @@
+-r build.txt
+-r runtime.txt
diff --git a/requirements/ix/runtime.txt b/requirements/ix/runtime.txt
new file mode 100644
index 00000000..2646956e
--- /dev/null
+++ b/requirements/ix/runtime.txt
@@ -0,0 +1,2 @@
+transformers
+-r torch.txt
diff --git a/requirements/ix/torch.txt b/requirements/ix/torch.txt
new file mode 100644
index 00000000..ac988bdf
--- /dev/null
+++ b/requirements/ix/torch.txt
@@ -0,0 +1,2 @@
+torch
+torchvision
diff --git a/setup.py b/setup.py
index 70f5b4bc..a5086c80 100644
--- a/setup.py
+++ b/setup.py
@@ -11,6 +11,7 @@
     "ascend": "PrivateUse1",
     "maca": "CUDA",
     "camb": "PrivateUse1",
+    "ix": "CUDA",
 }
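
A minimal smoke-test sketch for the new backend (not part of the patch). It assumes
dlinfer has been built with -DDEVICE=ix on a machine that provides the ixformer
runtime; the shapes, dtype, and epsilon below are illustrative only.

    import torch
    from dlinfer.vendor.ix import ix_ops

    # rms_norm upcasts to float32 internally and casts back to the input dtype.
    x = torch.randn(4, 128, dtype=torch.float16, device="cuda")
    w = torch.ones(128, dtype=torch.float16, device="cuda")
    y = ix_ops.rms_norm(x, w, 1e-6)

    # silu_and_mul halves the last dimension: (4, 256) -> (4, 128).
    z = ix_ops.silu_and_mul(torch.randn(4, 256, dtype=torch.float16, device="cuda"))

    print(y.shape, z.shape)  # torch.Size([4, 128]) torch.Size([4, 128])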