
Commit 5ffd00e

feat: implement GEMM with MUBLAS and MUDNN backends in moore gpu
1 parent c7373fe

6 files changed: +320, -7 lines
src/infiniop/ops/gemm/moore/gemm_moore.h

Lines changed: 98 additions & 2 deletions

@@ -1,8 +1,104 @@
  #ifndef __GEMM_MOORE_H__
  #define __GEMM_MOORE_H__

- #include "../gemm.h"
+ #include "mublas/gemm_mublas.h"
+ #include "mudnn/gemm_mudnn.h"

- DESCRIPTOR(moore)
+ namespace op::gemm::moore {
+
+ // Descriptor class for GEMM operations on Moore devices.
+ // This class is a wrapper that selects either the muBLAS or the muDNN backend.
+ // It encapsulates the backend-specific Descriptor implementation and provides
+ // a unified interface for workspace queries and GEMM computation.
+ class Descriptor final : public InfiniopDescriptor {
+ public:
+     // Destructor: deletes the backend-specific descriptor.
+     ~Descriptor() {
+         if (_backend == Backend::MUBLAS) {
+             delete reinterpret_cast<mublas::Descriptor *>(_impl);
+         } else {
+             delete reinterpret_cast<mudnn::Descriptor *>(_impl);
+         }
+     }
+
+     // Returns the required workspace size for the GEMM operation.
+     size_t workspaceSize() const {
+         if (_backend == Backend::MUBLAS) {
+             return reinterpret_cast<mublas::Descriptor *>(_impl)->workspaceSize();
+         } else {
+             return reinterpret_cast<mudnn::Descriptor *>(_impl)->workspaceSize();
+         }
+     }
+
+     // Static factory method to create a Descriptor instance.
+     // Chooses the backend (muBLAS or muDNN) and constructs the
+     // corresponding implementation internally.
+     static infiniStatus_t create(
+         infiniopHandle_t handle,
+         Descriptor **desc_ptr,
+         infiniopTensorDescriptor_t c_desc,
+         infiniopTensorDescriptor_t a_desc,
+         infiniopTensorDescriptor_t b_desc) {
+         auto desc = new Descriptor(handle->device, handle->device_id);
+
+         // Backend selection strategy: currently defaults to MUDNN.
+         // Could be extended to choose based on environment variables
+         // or runtime parameters.
+         desc->_backend = Backend::MUDNN;
+
+         if (desc->_backend == Backend::MUBLAS) {
+             mublas::Descriptor *impl;
+             auto status = mublas::Descriptor::create(handle, &impl, c_desc, a_desc, b_desc);
+             if (status != INFINI_STATUS_SUCCESS) {
+                 delete desc;
+                 return status;
+             }
+             desc->_impl = impl;
+         } else {
+             mudnn::Descriptor *impl;
+             auto status = mudnn::Descriptor::create(handle, &impl, c_desc, a_desc, b_desc);
+             if (status != INFINI_STATUS_SUCCESS) {
+                 delete desc;
+                 return status;
+             }
+             desc->_impl = impl;
+         }
+
+         *desc_ptr = desc;
+         return INFINI_STATUS_SUCCESS;
+     }
+
+     // Unified GEMM calculation interface.
+     // Dispatches to the selected backend's calculate function.
+     infiniStatus_t calculate(
+         void *workspace, size_t workspace_size,
+         void *c, float beta,
+         const void *a, const void *b,
+         float alpha,
+         void *stream) const {
+         if (_backend == Backend::MUBLAS) {
+             return reinterpret_cast<mublas::Descriptor *>(_impl)
+                 ->calculate(workspace, workspace_size, c, beta, a, b, alpha, stream);
+         } else {
+             return reinterpret_cast<mudnn::Descriptor *>(_impl)
+                 ->calculate(workspace, workspace_size, c, beta, a, b, alpha, stream);
+         }
+     }
+
+ private:
+     // Private constructor: users cannot instantiate Descriptor directly.
+     // Instances must be created via the static create() factory method.
+     Descriptor(infiniDevice_t device_type, int device_id)
+         : InfiniopDescriptor{device_type, device_id}, _impl(nullptr) {}
+
+     // Indicates which backend is in use internally.
+     enum class Backend { MUBLAS, MUDNN };
+
+     Backend _backend; // Currently selected backend (MUBLAS or MUDNN)
+     void *_impl;      // Backend-specific descriptor (mublas::Descriptor* or mudnn::Descriptor*)
+ };
+
+ } // namespace op::gemm::moore

  #endif // __GEMM_MOORE_H__
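
The create() factory hard-codes desc->_backend = Backend::MUDNN, but its comment leaves room for environment-driven selection. A minimal sketch of that idea, e.g. as a private static helper of Descriptor, assuming a hypothetical INFINIOP_GEMM_BACKEND variable (the name is not defined anywhere in this commit):

    #include <cstdlib>
    #include <cstring>

    // Hypothetical backend selection helper; INFINIOP_GEMM_BACKEND is an
    // assumed variable name, not part of this commit.
    static Backend selectBackend() {
        const char *env = std::getenv("INFINIOP_GEMM_BACKEND");
        if (env && std::strcmp(env, "MUBLAS") == 0) {
            return Backend::MUBLAS; // opt into the muBLAS path
        }
        return Backend::MUDNN; // default, matching the committed behavior
    }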
src/infiniop/ops/gemm/moore/mublas/gemm_mublas.h

Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
+ #ifndef __GEMM_MUBLAS_H__
+ #define __GEMM_MUBLAS_H__
+
+ #include "../../gemm.h"
+
+ DESCRIPTOR(mublas)
+
+ #endif // __GEMM_MUBLAS_H__
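
DESCRIPTOR(mublas) is a macro defined in ../../gemm.h, which this commit does not show. Judging from the members the implementation files below reference (_opaque, _dtype, _info, a six-argument constructor, workspaceSize(), create(), calculate()), it plausibly expands to a per-backend Descriptor class along these lines — an inferred sketch, not the macro's actual definition:

    // Inferred shape of the class generated by DESCRIPTOR(NS); the real macro
    // lives in src/infiniop/ops/gemm/gemm.h and may differ in detail.
    namespace op::gemm::mublas {
    class Descriptor final : public InfiniopDescriptor {
        struct Opaque;        // backend-private state, defined in the .mu file
        Opaque *_opaque;
        infiniDtype_t _dtype; // element type of C
        MatmulInfo _info;     // shapes, strides, and batch layout
        size_t _workspace_size;

        Descriptor(infiniDtype_t dtype, MatmulInfo info, size_t workspace_size,
                   Opaque *opaque, infiniDevice_t device, int device_id)
            : InfiniopDescriptor{device, device_id}, _opaque(opaque),
              _dtype(dtype), _info(std::move(info)),
              _workspace_size(workspace_size) {}

    public:
        ~Descriptor();
        size_t workspaceSize() const { return _workspace_size; }
        static infiniStatus_t create(infiniopHandle_t handle, Descriptor **desc_ptr,
                                     infiniopTensorDescriptor_t c_desc,
                                     infiniopTensorDescriptor_t a_desc,
                                     infiniopTensorDescriptor_t b_desc);
        infiniStatus_t calculate(void *workspace, size_t workspace_size,
                                 void *c, float beta, const void *a, const void *b,
                                 float alpha, void *stream) const;
    };
    } // namespace op::gemm::mublas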

src/infiniop/ops/gemm/moore/gemm_moore.mu renamed to src/infiniop/ops/gemm/moore/mublas/gemm_mublas.mu

Lines changed: 5 additions & 5 deletions

@@ -1,8 +1,8 @@
- #include "../../../devices/moore/moore_common.h"
- #include "../../../devices/moore/moore_handle.h"
- #include "gemm_moore.h"
+ #include "../../../../devices/moore/moore_common.h"
+ #include "../../../../devices/moore/moore_handle.h"
+ #include "gemm_mublas.h"

- namespace op::gemm::moore {
+ namespace op::gemm::mublas {

  struct Descriptor::Opaque {
      std::shared_ptr<device::moore::Handle::Internal> internal;
@@ -122,4 +122,4 @@ infiniStatus_t Descriptor::calculate(
      return INFINI_STATUS_SUCCESS;
  }

- } // namespace op::gemm::moore
+ } // namespace op::gemm::mublas
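
Aside from the one extra ../ in each include path (the file moved one directory deeper, from moore/ into moore/mublas/) and the namespace rename, the muBLAS implementation is carried over unchanged.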
src/infiniop/ops/gemm/moore/mudnn/gemm_mudnn.h

Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
+ #ifndef __GEMM_MUDNN_H__
+ #define __GEMM_MUDNN_H__
+
+ #include "../../gemm.h"
+
+ DESCRIPTOR(mudnn)
+
+ #endif // __GEMM_MUDNN_H__
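
With both backend headers in place, callers interact only with the moore wrapper. A minimal usage sketch — handle, the tensor descriptors, the device buffers a/b/c, and stream are assumed to come from the surrounding infiniop API and are not shown:

    // Hypothetical caller; error handling abbreviated.
    op::gemm::moore::Descriptor *desc = nullptr;
    auto status = op::gemm::moore::Descriptor::create(handle, &desc, c_desc, a_desc, b_desc);
    if (status == INFINI_STATUS_SUCCESS) {
        size_t ws_size = desc->workspaceSize(); // 0 on the muDNN path, since create() passes 0
        void *workspace = nullptr;              // allocate ws_size device bytes if nonzero
        status = desc->calculate(workspace, ws_size, c, /*beta=*/0.0f, a, b, /*alpha=*/1.0f, stream);
        delete desc;
    }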
src/infiniop/ops/gemm/moore/mudnn/gemm_mudnn.mu

Lines changed: 198 additions & 0 deletions

@@ -0,0 +1,198 @@
+ #include "../../../../devices/moore/moore_common.h"
+ #include "../../../../devices/moore/moore_handle.h"
+ #include "gemm_mudnn.h"
+
+ #include <musa_bf16.h>
+
+ namespace op::gemm::mudnn {
+
+ struct Descriptor::Opaque {
+     std::shared_ptr<device::moore::Handle::Internal> internal;
+ };
+
+ Descriptor::~Descriptor() {
+     delete _opaque;
+ }
+
+ infiniStatus_t Descriptor::create(
+     infiniopHandle_t handle_,
+     Descriptor **desc_ptr,
+     infiniopTensorDescriptor_t c_desc,
+     infiniopTensorDescriptor_t a_desc,
+     infiniopTensorDescriptor_t b_desc) {
+     auto handle = reinterpret_cast<device::moore::Handle *>(handle_);
+     auto dtype = c_desc->dtype();
+
+     CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
+
+     auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::ROW_MAJOR);
+     CHECK_RESULT(result);
+
+     *desc_ptr = new Descriptor(
+         dtype, result.take(), 0,
+         new Opaque{handle->internal()},
+         handle->device, handle->device_id);
+     return INFINI_STATUS_SUCCESS;
+ }
+
+ template <typename Tdata>
+ infiniStatus_t calculate(
+     const MatmulInfo &info,
+     std::shared_ptr<device::moore::Handle::Internal> &_internal,
+     void *c,
+     float beta,
+     const void *a,
+     const void *b,
+     float alpha,
+     void *stream) {
+     // 0. For muDNN development, refer to the official documentation and these headers:
+     //    - /usr/local/musa/include/mudnn_base.h
+     //    - /usr/local/musa/include/mudnn_math.h
+     //    - /usr/local/musa/include/mudnn.h
+
+     // 1. Create the BatchMatMul operator
+     auto matmul_operator = std::make_unique<::musa::dnn::BatchMatMul>();
+     matmul_operator->SetComputeMode(::musa::dnn::BatchMatMul::ComputeMode::TENSOR);
+
+     // 2. Use _internal->useMudnn to manage the muDNN handle
+     return _internal->useMudnn((musaStream_t)stream, [&](::musa::dnn::Handle &mudnn_handle) -> infiniStatus_t {
+         // 3. Create the BatchMatMul tensors
+         ::musa::dnn::Tensor out, left, right;
+
+         if constexpr (std::is_same<Tdata, half>::value) {
+             out.SetType(::musa::dnn::Tensor::Type::HALF);
+             left.SetType(::musa::dnn::Tensor::Type::HALF);
+             right.SetType(::musa::dnn::Tensor::Type::HALF);
+         } else if constexpr (std::is_same<Tdata, __mt_bfloat16>::value) {
+             out.SetType(::musa::dnn::Tensor::Type::BFLOAT16);
+             left.SetType(::musa::dnn::Tensor::Type::BFLOAT16);
+             right.SetType(::musa::dnn::Tensor::Type::BFLOAT16);
+         } else {
+             out.SetType(::musa::dnn::Tensor::Type::FLOAT);
+             left.SetType(::musa::dnn::Tensor::Type::FLOAT);
+             right.SetType(::musa::dnn::Tensor::Type::FLOAT);
+         }
+
+         // 4. Bind the tensor device addresses
+         out.SetAddr(c);
+         left.SetAddr(a);
+         right.SetAddr(b);
+
+         // 5. Configure tensor left
+         std::array<int64_t, 3> a_dims_array;
+         std::array<int64_t, 3> a_stride_array;
+         if (info.a_matrix.col_stride != 1) {
+             a_dims_array = {static_cast<int64_t>(info.batch),
+                             static_cast<int64_t>(info.k),
+                             static_cast<int64_t>(info.m)};
+         } else {
+             a_dims_array = {static_cast<int64_t>(info.batch),
+                             static_cast<int64_t>(info.m),
+                             static_cast<int64_t>(info.k)};
+         }
+         a_stride_array = {static_cast<int64_t>(info.a_matrix.stride),
+                           static_cast<int64_t>(info.a_matrix.ld()),
+                           1};
+         left.SetNdInfo(static_cast<int>(a_dims_array.size()), a_dims_array.data(), a_stride_array.data());
+
+         // 6. Configure tensor right
+         std::array<int64_t, 3> b_dims_array;
+         std::array<int64_t, 3> b_stride_array;
+         if (info.b_matrix.col_stride != 1) {
+             b_dims_array = {static_cast<int64_t>(info.batch),
+                             static_cast<int64_t>(info.n),
+                             static_cast<int64_t>(info.k)};
+         } else {
+             b_dims_array = {static_cast<int64_t>(info.batch),
+                             static_cast<int64_t>(info.k),
+                             static_cast<int64_t>(info.n)};
+         }
+         b_stride_array = {static_cast<int64_t>(info.b_matrix.stride),
+                           static_cast<int64_t>(info.b_matrix.ld()),
+                           1};
+         right.SetNdInfo(static_cast<int>(b_dims_array.size()), b_dims_array.data(), b_stride_array.data());
+
+         // 7. Configure tensor out; muDNN BatchMatMul only supports row-major output tensors
+         std::array<int64_t, 3> c_dims_array = {static_cast<int64_t>(info.batch),
+                                                static_cast<int64_t>(info.m),
+                                                static_cast<int64_t>(info.n)};
+         std::array<int64_t, 3> c_stride_array = {static_cast<int64_t>(info.c_matrix.stride),
+                                                  static_cast<int64_t>(info.c_matrix.ld()),
+                                                  1};
+         out.SetNdInfo(static_cast<int>(c_dims_array.size()), c_dims_array.data(), c_stride_array.data());
+
+         // 8. Workspace memory handler
+         ::musa::dnn::MemoryMaintainer maintainer = [](size_t size) -> ::musa::dnn::MemoryHandler {
+             void *ptr = nullptr;
+             musaMalloc(&ptr, size);
+             return ::musa::dnn::MemoryHandler(ptr, [](void *p) { if (p) musaFree(p); });
+         };
+
+         // 9. Configure transpose flags for tensors left and right
+         if (info.a_matrix.col_stride == 1 && info.b_matrix.col_stride != 1)
+             matmul_operator->SetTranspose(false, true);
+         else if (info.a_matrix.col_stride != 1 && info.b_matrix.col_stride == 1)
+             matmul_operator->SetTranspose(true, false);
+         else if (info.a_matrix.col_stride != 1 && info.b_matrix.col_stride != 1)
+             matmul_operator->SetTranspose(true, true);
+         else
+             matmul_operator->SetTranspose(false, false);
+
+         // 10. Query the BatchMatMul workspace size
+         size_t workspace_size_in_bytes = 0;
+         matmul_operator->GetWorkspaceSize(mudnn_handle, workspace_size_in_bytes, out, left, right);
+
+         // 11. Alpha / Beta / Gamma
+         matmul_operator->SetAlpha(static_cast<double>(alpha));
+         matmul_operator->SetBeta(static_cast<double>(beta));
+         matmul_operator->SetGamma(0.0);
+
+         // 12. Run
+         matmul_operator->Run(
+             mudnn_handle,
+             out,
+             left,
+             right,
+             static_cast<int64_t>(info.batch),
+             static_cast<int64_t>(info.m),
+             static_cast<int64_t>(info.n),
+             static_cast<int64_t>(info.k),
+             static_cast<int64_t>(info.a_matrix.ld()),
+             static_cast<int64_t>(info.b_matrix.ld()),
+             static_cast<int64_t>(info.c_matrix.ld()),
+             static_cast<int64_t>(info.a_matrix.stride),
+             static_cast<int64_t>(info.b_matrix.stride),
+             static_cast<int64_t>(info.c_matrix.stride),
+             maintainer);
+
+         return INFINI_STATUS_SUCCESS;
+     });
+ }
+
+ infiniStatus_t Descriptor::calculate(void *workspace,
+                                      size_t workspace_size,
+                                      void *c,
+                                      float beta,
+                                      const void *a,
+                                      const void *b,
+                                      float alpha,
+                                      void *stream) const {
+     switch (_dtype) {
+     case INFINI_DTYPE_F16:
+         return mudnn::calculate<half>(_info, _opaque->internal, c, beta, a, b, alpha, stream);
+     case INFINI_DTYPE_F32:
+         return mudnn::calculate<float>(_info, _opaque->internal, c, beta, a, b, alpha, stream);
+     case INFINI_DTYPE_BF16:
+         return mudnn::calculate<__mt_bfloat16>(_info, _opaque->internal, c, beta, a, b, alpha, stream);
+     default:
+         return INFINI_STATUS_BAD_TENSOR_DTYPE;
+     }
+ }
+
+ } // namespace op::gemm::mudnn
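
To make the layout handling in steps 5, 6, and 9 concrete, here is a hand-worked instance with illustrative shapes (not taken from the commit): let batch = 2, m = 4, k = 8, n = 2, with A contiguous row-major (col_stride == 1) and B column-major (col_stride != 1). Assuming contiguous storage:

    // left  (A): dims {2, 4, 8}, strides {32, 8, 1}  -- batch stride m*k = 32, ld() = k = 8
    // right (B): dims {2, 2, 8}, strides {16, 8, 1}  -- batch stride n*k = 16, ld() = k = 8
    // out   (C): dims {2, 4, 2}, strides { 8, 2, 1}  -- batch stride m*n = 8,  ld() = n = 2
    // Step 9 therefore calls matmul_operator->SetTranspose(false, true).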

xmake/moore.lua

Lines changed: 3 additions & 0 deletions

@@ -44,6 +44,9 @@ target("infiniop-moore")
      add_cxflags("-lstdc++", "-fPIC", "-Wno-comment")
      add_files("../src/infiniop/devices/moore/*.cc")
      add_files("../src/infiniop/ops/*/moore/*.mu", {rule = "mu"})
+
+     -- Add source files for Moore muBLAS/muDNN GEMM backends.
+     add_files("../src/infiniop/ops/gemm/moore/*/*.mu", {rule = "mu"})
  target_end()

  target("infinirt-moore")
