Skip to content

[metax-gpu] adapt metax-gpu maca platform #2884

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,15 @@ function copy_ops(){
echo -e "BASE and ROCM ops have been copy to fastdeploy"
return
fi
is_maca=`$python -c "import paddle; print(paddle.device.is_compiled_with_custom_device('metax_gpu'))"`
if [ "$is_maca" = "True" ]; then
DEVICE_TYPE="gpu"
mkdir -p ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu
echo -e "MACA ops have been copy to fastdeploy"
return
fi
mkdir -p ../fastdeploy/model_executor/ops/base
is_cuda=`$python -c "import paddle; print(paddle.is_compiled_with_cuda())"`
if [ "$is_cuda" = "True" ]; then
Expand Down
2 changes: 1 addition & 1 deletion custom_ops/gpu_ops/get_padding_offset.cu
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &token_num,
const paddle::Tensor &seq_len) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#if defined(PADDLE_WITH_CUSTOM_DEVICE) && !defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这些是必须的吗?

auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(input_ids.place()));
auto cu_stream = dev_ctx->stream();
#else
Expand Down
2 changes: 2 additions & 0 deletions custom_ops/gpu_ops/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,7 @@ static void PrintMatrix3(const T *mat_d, int num, std::string name) {
}

#ifndef PADDLE_WITH_HIP
#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
__forceinline__ __device__ uint32_t ld_flag_acquire(uint32_t *flag_addr,
int mode = 0) {
uint32_t flag;
Expand Down Expand Up @@ -541,6 +542,7 @@ __forceinline__ __device__ void st_flag_release(uint32_t *flag_addr,
"l"(flag_addr));
}
}
#endif

inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
int max_shared_mem_per_block_opt_in = 0;
Expand Down
2 changes: 1 addition & 1 deletion custom_ops/gpu_ops/rebuild_padding.cu
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ std::vector<paddle::Tensor> rebuild_padding(
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;

#ifdef PADDLE_WITH_CUSTOM_DEVICE
#if defined(PADDLE_WITH_CUSTOM_DEVICE) && !defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(tmp_out.place()));
auto cu_stream = dev_ctx->stream();
#else
Expand Down
2 changes: 1 addition & 1 deletion custom_ops/gpu_ops/set_value_by_flags.cu
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &step_idx,
const paddle::Tensor &stop_flags) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#if defined(PADDLE_WITH_CUSTOM_DEVICE) && !defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(stop_flags.place()));
auto cu_stream = dev_ctx->stream();
#else
Expand Down
2 changes: 1 addition & 1 deletion custom_ops/gpu_ops/step.cu
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ void StepPaddle(const paddle::Tensor &stop_flags,
const paddle::Tensor &first_token_ids,
const int block_size,
const int encoder_decoder_block_num) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#if defined(PADDLE_WITH_CUSTOM_DEVICE) && !defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(seq_lens_this_time.place()));
auto cu_stream = dev_ctx->stream();
#else
Expand Down
2 changes: 1 addition & 1 deletion custom_ops/gpu_ops/stop_generation_multi_ends.cu
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
}
}

#ifdef PADDLE_WITH_CUSTOM_DEVICE
#if defined(PADDLE_WITH_CUSTOM_DEVICE) && !defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(topk_ids.place()));
auto cu_stream = dev_ctx->stream();
#else
Expand Down
2 changes: 1 addition & 1 deletion custom_ops/gpu_ops/stop_generation_multi_stop_seqs.cu
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ void GetStopFlagsMultiSeqs(const paddle::Tensor &topk_ids,
PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64);
PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);

#ifdef PADDLE_WITH_CUSTOM_DEVICE
#if defined(PADDLE_WITH_CUSTOM_DEVICE) && !defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(topk_ids.place()));
auto cu_stream = dev_ctx->stream();
#else
Expand Down
2 changes: 1 addition & 1 deletion custom_ops/gpu_ops/token_penalty_multi_scores.cu
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#if defined(PADDLE_WITH_CUSTOM_DEVICE) && !defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(logits.place()));
auto cu_stream = dev_ctx->stream();
#else
Expand Down
2 changes: 1 addition & 1 deletion custom_ops/gpu_ops/update_inputs.cu
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ void UpdateInputes(const paddle::Tensor &stop_flags,
const paddle::Tensor &stop_nums,
const paddle::Tensor &next_tokens,
const paddle::Tensor &is_block_step) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#if defined(PADDLE_WITH_CUSTOM_DEVICE) && !defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(input_ids.place()));
auto cu_stream = dev_ctx->stream();
#else
Expand Down
63 changes: 63 additions & 0 deletions custom_ops/setup_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,69 @@ def find_end_files(directory, end_str):
},
),
)
elif paddle.device.is_compiled_with_custom_device('metax_gpu'):
maca_path = os.getenv("MACA_PATH", "/opt/maca")
json_dir = "third_party/nlohmann_json"
if not os.path.exists(json_dir) or not os.listdir(json_dir):
if not os.path.exists(json_dir):
os.makedirs(json_dir)
clone_git_repo("v3.11.3", "https://bgithub.xyz/nlohmann/json.git", json_dir)
if not os.listdir(json_dir):
raise ValueError("Git clone nlohmann_json failed!")
sources=[
"gpu_ops/save_with_output.cc",
"gpu_ops/set_mask_value.cu",
"gpu_ops/set_value_by_flags.cu",
"gpu_ops/ngram_mask.cu",
"gpu_ops/gather_idx.cu",
"gpu_ops/get_output_ep.cc",
"gpu_ops/token_penalty_multi_scores.cu",
"gpu_ops/token_penalty_only_once.cu",
"gpu_ops/stop_generation.cu",
"gpu_ops/stop_generation_multi_ends.cu",
"gpu_ops/stop_generation_multi_stop_seqs.cu",
"gpu_ops/set_flags.cu",
"gpu_ops/fused_get_rope.cu",
"gpu_ops/get_padding_offset.cu",
"gpu_ops/update_inputs.cu",
"gpu_ops/update_inputs_beam.cu",
"gpu_ops/beam_search_softmax.cu",
"gpu_ops/rebuild_padding.cu",
"gpu_ops/step.cu",
"gpu_ops/step_reschedule.cu",
"gpu_ops/step_system_cache.cu",
"gpu_ops/set_data_ipc.cu",
"gpu_ops/read_data_ipc.cu",
"gpu_ops/dequant_int8.cu",
"gpu_ops/share_external_data.cu",
"gpu_ops/extract_text_token_output.cu",
"gpu_ops/text_image_gather_scatter.cu",
"gpu_ops/text_image_index_out.cu",
"gpu_ops/tune_cublaslt_gemm.cu",
"gpu_ops/moe/tritonmoe_preprocess.cu",
]
sources += find_end_files("gpu_ops/speculate_decoding", ".cu")
sources += find_end_files("gpu_ops/speculate_decoding", ".cc")
setup(
name="fastdeploy_ops",
ext_modules=CUDAExtension(
sources=sources,
extra_compile_args={
"cxx": ["-O3"],
"nvcc": [
"-O3",
"-Ithird_party/nlohmann_json/include",
"-Igpu_ops",
"-DPADDLE_WITH_CUSTOM_DEVICE",
"-DPADDLE_WITH_CUSTOM_DEVICE_METAX_GPU",
"-DPADDLE_DEV",
],
},
library_dirs=[os.path.join(maca_path, "lib")],
extra_link_args=["-lruntime_cu"],
include_dirs=[os.path.join(maca_path, "include"), os.path.join(maca_path, "include/mcr"), os.path.join(maca_path, "include/common")],
),
)
elif paddle.is_compiled_with_cuda():
sources = [
"gpu_ops/set_mask_value.cu",
Expand Down
8 changes: 7 additions & 1 deletion fastdeploy/model_executor/forward_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class ForwardMode(IntEnum):
DECODE = auto()
# Mixed mode
MIXED = auto()
# Native mode
NATIVE = auto()

def is_prefill(self):
""" Is Extend mode """
Expand All @@ -49,6 +51,10 @@ def is_mixed(self):
""" Is Mixed mode """
return self == ForwardMode.MIXED

def is_native(self):
""" Is Native mode """
return self == ForwardMode.NATIVE


@dataclass
class ForwardMeta():
Expand All @@ -68,7 +74,7 @@ class ForwardMeta():
# Attention backend object
attn_backend: AttentionBackend = None
# Forward mode used during attention
forward_mode: ForwardMode = ForwardMode.MIXED
forward_mode: ForwardMode = ForwardMode.MIXED if not paddle.device.is_compiled_with_custom_device('metax_gpu') else ForwardMode.NATIVE
# Attention mask
attn_mask: Optional[paddle.Tensor] = None
# Decoder batch id. Used by attention backend.
Expand Down
2 changes: 1 addition & 1 deletion fastdeploy/model_executor/layers/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(
super().__init__()

if current_platform.is_cuda() or current_platform.is_xpu(
) or current_platform.is_iluvatar():
) or current_platform.is_iluvatar() or current_platform.is_maca():
self.forward = self.forward_cuda
elif current_platform.is_gcu():
self.forward = self.forward_gcu
Expand Down
4 changes: 3 additions & 1 deletion fastdeploy/model_executor/layers/attention/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@
from .xpu_attn_backend import XPUAttentionBackend
from .iluvatar_attn_backend import IluvatarAttnBackend
from .block_multihead_attn_backend import BlockAttentionBackend
from .flash_attention_interface import flash_attn_func, flash_attn_unpadded_func, flash_attn_kvcache_func

__all__ = [
"AttentionBackend", "PaddleNativeAttnBackend",
"get_attention_backend", "AppendAttentionBackend", "XPUAttentionBackend",
"MLAAttentionBackend", "FlashAttentionBackend", "IluvatarAttnBackend",
"BlockAttentionBackend"
"BlockAttentionBackend",
"flash_attn_func", "flash_attn_unpadded_func", "flash_attn_kvcache_func",
]
19 changes: 19 additions & 0 deletions fastdeploy/model_executor/layers/attention/append_attn_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int,

def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attntion metadata hence all layers in the forward pass can reuse it."""
if forward_meta.forward_mode.is_native():
return
metadata = AppendAttentionMetadata()
metadata.encoder_block_shape_q = 64
metadata.decoder_block_shape_q = 16
Expand Down Expand Up @@ -260,3 +262,20 @@ def forward_mixed(
self.speculative_method is not None,
)[0]
return res

def forward_native_backend(
self,
q,
k,
v,
qkv,
layer: Attention,
forward_meta: ForwardMeta,
):
"""
forward_mixed
TODO(vivienfanghuagood) WIP
"""
from .native_attention_util import forward_native_backend
out = forward_native_backend(qkv, layer, self, forward_meta)
return out
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,15 @@ def forward(
layer,
forward_meta,
)
elif forward_meta.forward_mode.is_native():
return self.forward_native_backend(
q,
k,
v,
qkv,
layer,
forward_meta,
)
else:
return self.forward_extend(
q,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""
# Copyright (c) 2025 MetaX-tech Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import paddle
from paddle import _C_ops
from typing import Optional, Union, Tuple
from paddle import Tensor

import os

for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
if lib.endswith(".so"):
paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(
lib
)

def flash_attn_func(
q: Tensor,
k: Tensor,
v: Tensor,
fixed_seed_offset: Optional[Tensor] = None,
attn_mask: Optional[Tensor] = None,
dropout_prob: float = 0.0,
causal: bool = False,
return_softmax: bool = False,
is_test: bool = True,
rng_name: str = ""
) -> Union[Tensor, Tuple[Tensor, ...]]:
return paddle._C_ops.flash_attn(
q, k, v,
fixed_seed_offset,
attn_mask,
dropout_prob,
causal,
return_softmax,
is_test,
rng_name
)

def flash_attn_unpadded_func(
q: Tensor,
k: Tensor,
v: Tensor,
cu_seqlens_q: Tensor,
cu_seqlens_k: Tensor,
max_seqlen_q: Union[int, float],
max_seqlen_k: Union[int, float],
fixed_seed_offset: Optional[Tensor] = None,
attn_mask: Optional[Tensor] = None,
softmax_scale: float = 1.0,
dropout: float = 0.0,
causal: bool = False,
return_softmax: bool = False,
is_test: bool = True,
rng_name: str = ""
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
max_seqlen_q_t = paddle.to_tensor(max_seqlen_q, dtype='int64')
max_seqlen_k_t = paddle.to_tensor(max_seqlen_k, dtype='int64')

outputs = paddle._C_ops.flash_attn_unpadded(
q, k, v,
cu_seqlens_q, cu_seqlens_k,
fixed_seed_offset, attn_mask,
max_seqlen_q_t, max_seqlen_k_t,
softmax_scale,
dropout,
causal,
return_softmax,
is_test,
rng_name
)
return outputs

def flash_attn_kvcache_func(
q: Tensor,
k_cache: Tensor,
v_cache: Tensor,
seqlens_k: Tensor,
block_table: Tensor,
k: Optional[Tensor] = None,
v: Optional[Tensor] = None,
rotary_cos: Optional[Tensor] = None,
rotary_sin: Optional[Tensor] = None,
cache_batch_idx: Optional[Tensor] = None,
causal: bool = True,
is_rotary_interleaved: bool = False,
num_splits: int = 1,
dropout: float = 0.0,
return_softmax: bool = False
) -> Tuple[Tensor, Tensor]:
out, softmax_lse = paddle._C_ops._run_custom_op(
"flash_attn_kvcache",
q,
k_cache,
v_cache,
k,
v,
seqlens_k,
rotary_cos,
rotary_sin,
cache_batch_idx,
block_table,
causal,
is_rotary_interleaved,
num_splits,
dropout,
return_softmax
)

return out, softmax_lse
Loading
Loading