
Commit 9171130

Chamberlain0w0 authored and YdrMaster committed
feat: add split/concat/slice/matmul operators for the Cambricon platform, and merge from master
1 parent e329552 commit 9171130

17 files changed (+1020 -1 lines changed)


src/04kernel/src/collectors/concat.cc

Lines changed: 6 additions & 0 deletions
@@ -1,6 +1,7 @@
 #include "kernel/collectors/concat.h"
 #include "../kernels/concat/cpu_kernel.hh"
 #include "../kernels/concat/cuda_kernel.hh"
+#include "../kernels/concat/cnnl_kernel.hh"

 namespace refactor::kernel {

@@ -20,6 +21,11 @@ namespace refactor::kernel {
                 ans.emplace_back(std::move(ptr));
             }
             break;
+        case decltype(_target)::Mlu:
+            if (auto ptr = ConcatCnnl::build(axis, inputs, outputs[0].get()); ptr) {
+                ans.emplace_back(std::move(ptr));
+            }
+            break;
         default:
             UNREACHABLEX(void, "Unknown target");
     }

src/04kernel/src/collectors/mat_mul.cc

Lines changed: 6 additions & 0 deletions
@@ -1,4 +1,5 @@
 #include "kernel/collectors/mat_mul.h"
+#include "../kernels/mat_mul/cnnl_kernel.hh"
 #include "../kernels/mat_mul/cpu_kernel.hh"
 #include "../kernels/mat_mul/cublas_kernel.hh"
 #include "kernel/attributes/mat_mul_info.h"
@@ -26,6 +27,11 @@ namespace refactor::kernel {
        case decltype(_target)::Nvidia:
            REGISTER(MatMulCublas)
            break;
+        case decltype(_target)::Mlu:
+            if (auto ptr = MatMulCnnl::build(inputs, outputs, transA, transB, alpha, beta); ptr) {
+                ans.emplace_back(std::move(ptr));
+            }
+            break;
        default:
            UNREACHABLEX(void, "Unknown target");
    }

src/04kernel/src/collectors/slice.cc

Lines changed: 6 additions & 0 deletions
@@ -1,6 +1,7 @@
 #include "kernel/collectors/slice.h"
 #include "../kernels/slice/cpu_kernel.hh"
 #include "../kernels/slice/cuda_kernel.hh"
+#include "../kernels/slice/cnnl_kernel.hh"

 namespace refactor::kernel {

@@ -26,6 +27,11 @@ namespace refactor::kernel {
                 ans.emplace_back(std::move(ptr));
             }
             break;
+        case decltype(_target)::Mlu:
+            if (auto ptr = SliceCnnl::build(inputs[0].get().dataType, dimentions, inputs[0].get().shape, outputs[0].get().shape); ptr) {
+                ans.emplace_back(std::move(ptr));
+            }
+            break;
         default:
             UNREACHABLEX(void, "Unknown target");
     }

src/04kernel/src/collectors/split.cc

Lines changed: 6 additions & 0 deletions
@@ -1,4 +1,5 @@
 #include "kernel/collectors/split.h"
+#include "../kernels/split/cnnl_kernel.hh"
 #include "../kernels/split/cpu_kernel.hh"
 #include "../kernels/split/cuda_kernel.hh"

@@ -20,6 +21,11 @@ namespace refactor::kernel {
                 ans.emplace_back(std::move(ptr));
             }
             break;
+        case decltype(_target)::Mlu:
+            if (auto ptr = SplitCnnl::build(axis, inputs[0].get(), outputs); ptr) {
+                ans.emplace_back(std::move(ptr));
+            }
+            break;
         default:
             UNREACHABLEX(void, "Unknown target");
     }
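
All four collector diffs above add the same dispatch shape: a new Mlu case in the collector's switch over _target that tries to build the CNNL kernel and registers it only when build returns a non-null box (each build returns nullptr when the project is compiled without USE_BANG, and MatMulCnnl::build additionally returns nullptr for unsupported data types). A minimal sketch of the shared pattern, with XxxCnnl standing in for any of the four kernels:

    case decltype(_target)::Mlu:
        // build() yields nullptr when CNNL cannot serve this node,
        // so registration happens only for a usable kernel box.
        if (auto ptr = XxxCnnl::build(/* node-specific arguments */); ptr) {
            ans.emplace_back(std::move(ptr));
        }
        break;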
src/04kernel/src/kernels/concat/cnnl_kernel.cc

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
+#include "cnnl_kernel.hh"
+
+#ifdef USE_BANG
+#include "../../utilities/bang/cnnl_context.hh"
+#include "../../utilities/bang/cnnl_functions.h"
+#include <cnnl.h>
+#endif
+
+namespace refactor::kernel {
+    using K = ConcatCnnl;
+
+    K::ConcatCnnl(SplitInfoCnnl info_) noexcept
+        : Kernel(), info(std::move(info_)) {}
+
+    auto K::build(int axis, TensorRefs inputs, Tensor output) noexcept -> KernelBox {
+#ifndef USE_BANG
+        return nullptr;
+#endif
+        return std::make_unique<K>(SplitInfoCnnl(axis, output, inputs));
+    }
+    auto K::typeId() noexcept -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast<size_t>(&ID);
+    }
+
+    auto K::kernelTypeId() const noexcept -> size_t {
+        return typeId();
+    }
+    auto K::description() const noexcept -> std::string_view {
+        return "Performing concat operation using CNNL";
+    }
+
+#ifdef USE_BANG
+    auto ConcatCnnl::lower(Resources &res) const -> RoutineWorkspace {
+        using namespace cnnl;
+        using namespace runtime;
+        using DT = DataType;
+
+        struct Descriptors {
+            cnnlTensorDescriptor_t in;
+            std::vector<cnnlTensorDescriptor_t> out;
+            bool f32;
+
+            explicit Descriptors(int n, decltype(f32) f32_)
+                : in(nullptr),
+                  out(std::vector<cnnlTensorDescriptor_t>(n, nullptr)),
+                  f32(f32_) {
+                CNNL_ASSERT(cnnlCreateTensorDescriptor(&in));
+                for (auto i = 0; i < n; i++) {
+                    CNNL_ASSERT(cnnlCreateTensorDescriptor(&out[i]));
+                }
+            }
+            ~Descriptors() noexcept(false) {
+                CNNL_ASSERT(cnnlDestroyTensorDescriptor(in));
+                for (size_t i = 0; i < out.size(); i++) {
+                    CNNL_ASSERT(cnnlDestroyTensorDescriptor(out[i]));
+                }
+            }
+
+            Descriptors(const Descriptors &) = delete;
+            Descriptors(Descriptors &&) = delete;
+        };
+        auto d = std::make_shared<Descriptors>(info.num, info.dataType != DT::F64);
+        setCnnlTensor(d->in, info.dataType, slice(info.inDim.data(), info.inDim.size()));
+        for (size_t i = 0; i < info.outDims.size(); i++) {
+            setCnnlTensor(d->out[i], info.dataType, slice(info.outDims[i].data(), info.outDims[i].size()));
+        }
+
+        auto handle = res.fetchOrStore<CnnlContext>()->handle;
+        size_t workspaceSize;
+        CNNL_ASSERT(cnnlGetConcatWorkspaceSize(handle, info.num, &workspaceSize));
+
+        res.fetchOrStore<CnnlContext>();
+        auto routine = [d = std::move(d), n = info.num, axis = info.axis, workspaceSize](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
+            // fetch cnnl handle from resources
+            auto handle = res.fetchOrStore<CnnlContext>()->handle;
+
+            // gather the raw input pointers for cnnlConcat (std::vector
+            // instead of a variable-length array, which is non-standard C++)
+            std::vector<void const *> argv(inputs, inputs + n);
+
+            CNNL_ASSERT(cnnlConcat(
+                handle, n, axis, d->out.data(), argv.data(),
+                workspace, workspaceSize, d->in, outputs[0]));
+        };
+
+        return {std::move(routine), workspaceSize};
+    }
+
+#endif
+
+}// namespace refactor::kernel
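
Note that ConcatCnnl reuses SplitInfoCnnl from the split kernel with the tensor roles reversed: concatenation along an axis is the inverse of a split, so the same metadata (one large tensor plus a list of parts) describes both, with inDim acting as the concat output and outDims as the concat inputs. A sketch of the fields this lower() reads, inferred from their usage above; the authoritative definition lives in src/04kernel/src/kernels/split/cnnl_kernel.hh, and the exact field types and order here are assumptions:

    struct SplitInfoCnnl {                     // assumed shape, inferred from usage
        DataType dataType;                     // element type shared by all tensors
        int axis;                              // the split/concat axis
        int num;                               // number of parts (= concat inputs)
        std::vector<int> inDim;                // the one large tensor (here: concat output)
        std::vector<std::vector<int>> outDims; // the parts (here: concat inputs)
    };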
src/04kernel/src/kernels/concat/cnnl_kernel.hh

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+#ifndef KERNEL_CONCAT_CNNL_KERNEL_HH
+#define KERNEL_CONCAT_CNNL_KERNEL_HH
+
+#include "../../kernels/split/cnnl_kernel.hh"
+#include "kernel/kernel.h"
+
+namespace refactor::kernel {
+
+    struct ConcatCnnl final : public Kernel {
+        SplitInfoCnnl info;
+
+        explicit ConcatCnnl(SplitInfoCnnl) noexcept;
+
+        static KernelBox build(int, TensorRefs, Tensor) noexcept;
+        static size_t typeId() noexcept;
+
+        size_t kernelTypeId() const noexcept final;
+        std::string_view description() const noexcept final;
+#ifdef USE_BANG
+        RoutineWorkspace lower(Resources &) const final;
+#endif
+    };
+
+}// namespace refactor::kernel
+
+#endif// KERNEL_CONCAT_CNNL_KERNEL_HH
src/04kernel/src/kernels/mat_mul/cnnl_kernel.cc

Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
+#include "cnnl_kernel.hh"
+#include <numeric>
+
+#ifdef USE_BANG
+#include "../../utilities/bang/cnnl_context.hh"
+#include "../../utilities/bang/cnnl_functions.h"
+#include <cnnl.h>
+#endif
+
+namespace refactor::kernel {
+    using K = MatMulCnnl;
+    using DT = DataType;
+
+    K::MatMulCnnl(decltype(info) info_) noexcept
+        : Kernel(), info(std::move(info_)) {}
+
+    auto K::build(TensorRefs inputs_, TensorRefs outputs_, bool transA_, bool transB_, float alpha_, float beta_) noexcept -> KernelBox {
+#ifndef USE_BANG
+        return nullptr;
+#endif
+        auto dt = inputs_[0].get().dataType;
+        return dt.isIeee754() || dt == DT::I8
+                   ? std::make_unique<K>(decltype(info){
+                         dt,
+                         transA_,
+                         transB_,
+                         alpha_,
+                         beta_,
+                         std::vector<int>(inputs_[0].get().shape.begin(), inputs_[0].get().shape.end()),
+                         std::vector<int>(inputs_[1].get().shape.begin(), inputs_[1].get().shape.end()),
+                         std::vector<int>(outputs_[0].get().shape.begin(), outputs_[0].get().shape.end()),
+                         inputs_.size() == 3
+                             ? inputs_[2].get().shape.size() == 0
+                                   ? std::make_optional(std::vector<int>(1, 1))
+                                   : std::make_optional(std::vector<int>(
+                                         inputs_[2].get().shape.begin(),
+                                         inputs_[2].get().shape.end()))
+                             : std::nullopt,
+                     })
+                   : nullptr;
+    }
+
+    auto K::typeId() noexcept -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast<size_t>(&ID);
+    }
+
+    auto K::kernelTypeId() const noexcept -> size_t { return typeId(); }
+    auto K::description() const noexcept -> std::string_view {
+        return "Performing MatMul using CNNL";
+    }
+
+#ifdef USE_BANG
+    auto K::lower(Resources &res) const noexcept -> RoutineWorkspace {
+        using namespace cnnl;
+        using namespace runtime;
+        using DT = DataType;
+
+        // RAII for closure
+        struct Descriptors {
+            cnnlTensorDescriptor_t a, b, c;
+            cnnlMatMulDescriptor_t bmm;
+            cnnlMatMulAlgo_t algo;
+            cnnlMatMulHeuristicResult_t heuristic;
+            cnnlTensorDescriptor_t bias;
+            bool addBias, f32;
+
+            explicit Descriptors(bool addBias_, bool f32_)
+                : a(nullptr), b(nullptr), c(nullptr),
+                  bmm(nullptr), algo(nullptr), heuristic(nullptr),
+                  bias(nullptr), addBias(addBias_), f32(f32_) {
+                CNNL_ASSERT(cnnlCreateTensorDescriptor(&a));
+                CNNL_ASSERT(cnnlCreateTensorDescriptor(&b));
+                CNNL_ASSERT(cnnlCreateTensorDescriptor(&c));
+                if (addBias) {
+                    CNNL_ASSERT(cnnlCreateTensorDescriptor(&bias));
+                }
+                CNNL_ASSERT(cnnlMatMulDescCreate(&bmm));
+                CNNL_ASSERT(cnnlMatMulAlgoCreate(&algo));
+                CNNL_ASSERT(cnnlCreateMatMulHeuristicResult(&heuristic));
+            }
+            ~Descriptors() noexcept(false) {
+                CNNL_ASSERT(cnnlDestroyTensorDescriptor(a));
+                CNNL_ASSERT(cnnlDestroyTensorDescriptor(b));
+                CNNL_ASSERT(cnnlDestroyTensorDescriptor(c));
+                if (addBias) {
+                    CNNL_ASSERT(cnnlDestroyTensorDescriptor(bias));
+                }
+                CNNL_ASSERT(cnnlMatMulDescDestroy(bmm));
+                CNNL_ASSERT(cnnlMatMulAlgoDestroy(algo));
+                CNNL_ASSERT(cnnlDestroyMatMulHeuristicResult(heuristic));
+            }
+
+            Descriptors(const Descriptors &) = delete;
+            Descriptors(Descriptors &&) = delete;
+        };
+        auto d = std::make_shared<Descriptors>(info.biasDim.has_value(), info.dataType != DT::F64);
+        setCnnlTensor(d->a, info.dataType, slice(info.aDim.data(), info.aDim.size()));
+        setCnnlTensor(d->b, info.dataType, slice(info.bDim.data(), info.bDim.size()));
+        setCnnlTensor(d->c, info.dataType, slice(info.cDim.data(), info.cDim.size()));
+        if (d->addBias) {
+            CNNL_ASSERT(cnnlSetTensorDescriptor(
+                d->bias, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(info.dataType),
+                info.biasDim.value().size(), info.biasDim.value().data()));
+        }
+        int32_t tA = info.transA, tB = info.transB;
+        CNNL_ASSERT(cnnlSetMatMulDescAttr(d->bmm, CNNL_MATMUL_DESC_TRANSA,
+                                          &tA, sizeof(int32_t)));
+        CNNL_ASSERT(cnnlSetMatMulDescAttr(d->bmm, CNNL_MATMUL_DESC_TRANSB,
+                                          &tB, sizeof(int32_t)));
+        auto handle = res.fetchOrStore<CnnlContext>()->handle;
+        int returnedAlgoCount = 0;
+        CNNL_ASSERT(cnnlGetBatchMatMulAlgoHeuristic(
+            handle, d->bmm, d->a, d->b, d->c,
+            NULL, 1, &(d->heuristic), &returnedAlgoCount));
+
+        size_t algoWorkspaceSize;
+        CNNL_ASSERT(cnnlGetBatchMatMulHeuristicResult(d->heuristic, d->algo, &algoWorkspaceSize));
+
+        res.fetchOrStore<CnnlContext>();
+        auto routine = [d = std::move(d), algoWorkspaceSize,
+                        aa = info.alpha, bb = info.beta](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
+            // fetch cnnl handle from resources
+            auto handle = res.fetchOrStore<CnnlContext>()->handle;
+
+            // build alpha/beta for double
+            auto alpha = d->f32 ? factor<fp32_t>(aa) : factor<fp64_t>(aa),
+                 beta = d->f32 ? factor<fp32_t>(bb) : factor<fp64_t>(bb),
+                 // one = d->f32 ? factor<fp32_t>(1) : factor<fp64_t>(1),
+                 zero = d->f32 ? factor<fp32_t>(0) : factor<fp64_t>(0);
+
+            if (d->addBias) {
+                CNNL_ASSERT(cnnlExpand(handle, d->bias, inputs[2], d->c, outputs[0]));
+            }
+
+            if (alpha != 0) {
+                CNNL_ASSERT(cnnlBatchMatMulBCast_v2(
+                    handle, d->bmm, d->algo, &alpha,
+                    d->a, inputs[0], d->b, inputs[1],
+                    d->addBias ? &beta : &zero, d->c, outputs[0],
+                    workspace, algoWorkspaceSize));
+            }
+
+            BANG_ASSERT(cnrtQueueSync(res.fetchOrStore<CnnlContext>()->queue));
+        };
+
+        return {std::move(routine), algoWorkspaceSize};
+    }
+#endif
+
+}// namespace refactor::kernel
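
A note on the bias handling in the routine above: assuming cnnlBatchMatMulBCast_v2 follows the usual GEMM convention

    C = alpha * op(A) * op(B) + beta * C

the code first broadcasts the bias into the output buffer with cnnlExpand, then passes &beta as the accumulation scale only when a bias exists, which yields C = alpha * A * B + beta * bias; without a bias it passes &zero so stale output memory cannot leak into the result.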
src/04kernel/src/kernels/mat_mul/cnnl_kernel.hh

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+#ifndef KERNEL_MATMUL_CNNL_KERNEL_HH
+#define KERNEL_MATMUL_CNNL_KERNEL_HH
+
+#include "kernel/kernel.h"
+#include "kernel/tensor.h"
+
+namespace refactor::kernel {
+
+    struct MatMulCnnl final : public Kernel {
+        struct {
+            DataType dataType;
+            bool transA, transB;
+            float alpha, beta;
+            std::vector<int> aDim, bDim, cDim;
+            std::optional<std::vector<int>> biasDim;
+        } info;
+
+        explicit MatMulCnnl(decltype(info)) noexcept;
+
+        static KernelBox build(TensorRefs, TensorRefs, bool, bool, float, float) noexcept;
+        static size_t typeId() noexcept;
+
+        size_t kernelTypeId() const noexcept final;
+        std::string_view description() const noexcept final;
+#ifdef USE_BANG
+        RoutineWorkspace lower(Resources &) const noexcept final;
+#endif
+    };
+
+}// namespace refactor::kernel
+
+#endif// KERNEL_MATMUL_CNNL_KERNEL_HH
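
For orientation, a hypothetical call sequence showing how the collector-side registration and the kernel's lower() fit together; this is not part of the commit, and the names res, inputs, and outputs stand for objects the surrounding runtime normally provides:

    // Hypothetical usage sketch; mirrors the Mlu case in mat_mul.cc above.
    if (auto kernel = MatMulCnnl::build(inputs, outputs,
                                        /*transA=*/false, /*transB=*/false,
                                        /*alpha=*/1.0f, /*beta=*/1.0f);
        kernel) {
        auto rw = kernel->lower(res);// descriptors and algo are chosen here
        // rw.routine(res, workspace, rawInputs, rawOutputs) then launches
        // cnnlBatchMatMulBCast_v2 on the CNNL queue.
    }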
