Commit beda029

Chamberlain0w0 authored and YdrMaster committed
feat: add erf/mod/cast/clip/gather/scatternd operators for the Cambricon platform
1 parent 9171130 · commit beda029

20 files changed, +951 -35 lines

src/04kernel/src/collectors/cast.cc

Lines changed: 6 additions & 0 deletions
@@ -1,6 +1,7 @@
 #include "kernel/collectors/cast.h"
 #include "../kernels/cast/cpu_kernel.hh"
 #include "../kernels/cast/cuda_kernel.hh"
+#include "../kernels/cast/cnnl_kernel.hh"
 
 namespace refactor::kernel {
 
@@ -24,6 +25,11 @@ namespace refactor::kernel {
                     ans.emplace_back(std::move(ptr));
                 }
                 break;
+            case decltype(_target)::Mlu:
+                if (auto ptr = CastCnnl::build(from, to); ptr) {
+                    ans.emplace_back(std::move(ptr));
+                }
+                break;
             default:
                 UNREACHABLEX(void, "Unknown target");
         }
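
Note: this collector and the three below (clip, gather, scatter_nd) gate the new Mlu branch identically: the backend-specific build factory returns a null KernelBox when it cannot serve the request, and the collector keeps only non-null candidates via a C++17 if-with-initializer. A minimal, self-contained sketch of that idiom (standard library only, not code from this commit):

#include <memory>
#include <utility>
#include <vector>

// Stand-in for a backend-specific build(): returns nullptr when the
// requested configuration is unsupported, just like CastCnnl::build.
static std::unique_ptr<int> build(bool supported) {
    return supported ? std::make_unique<int>(42) : nullptr;
}

int main() {
    std::vector<std::unique_ptr<int>> ans;
    // `ptr` is scoped to the if; only non-null candidates are kept.
    if (auto ptr = build(true); ptr) {
        ans.emplace_back(std::move(ptr));
    }
    return ans.size() == 1 ? 0 : 1;
}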

src/04kernel/src/collectors/clip.cc

Lines changed: 6 additions & 0 deletions
@@ -1,6 +1,7 @@
 #include "kernel/collectors/clip.h"
 #include "../kernels/clip/cpu_kernel.hh"
 #include "../kernels/clip/cuda_kernel.hh"
+#include "../kernels/clip/cnnl_kernel.hh"
 
 namespace refactor::kernel {
 
@@ -24,6 +25,11 @@ namespace refactor::kernel {
                     ans.emplace_back(std::move(ptr));
                 }
                 break;
+            case decltype(_target)::Mlu:
+                if (auto ptr = ClipCnnl::build(data, hasMax); ptr) {
+                    ans.emplace_back(std::move(ptr));
+                }
+                break;
             default:
                 UNREACHABLEX(void, "Unknown target");
         }

src/04kernel/src/collectors/gather.cc

Lines changed: 6 additions & 0 deletions
@@ -1,4 +1,5 @@
 #include "kernel/collectors/gather.h"
+#include "../kernels/gather/cnnl_kernel.hh"
 #include "../kernels/gather/cpu_kernel.hh"
 #include "../kernels/gather/cuda_kernel.hh"
 
@@ -20,6 +21,11 @@ namespace refactor::kernel {
                     ans.emplace_back(std::move(ptr));
                 }
                 break;
+            case decltype(_target)::Mlu:
+                if (auto ptr = GatherCnnl::build(axis, inputs[0].get(), inputs[1].get(), outputs[0].get()); ptr != nullptr) {
+                    ans.emplace_back(std::move(ptr));
+                }
+                break;
             default:
                 UNREACHABLEX(void, "Unknown target");
         }
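
Note: of the collectors in this commit, only the gather case hands the builder the concrete tensors (data, indices, output) plus the axis. As a reminder of the semantics being lowered, a host-side gather along axis 0 (plain C++, illustrative only, not the CNNL call):

#include <vector>

// Host-side gather along axis 0 for 2-D row-major data:
// out collects data[indices[i]] row by row.
std::vector<float> gatherAxis0(std::vector<float> const &data,
                               size_t rowLen,
                               std::vector<int> const &indices) {
    std::vector<float> out;
    out.reserve(indices.size() * rowLen);
    for (int idx : indices) {
        auto row = data.begin() + idx * rowLen;
        out.insert(out.end(), row, row + rowLen);
    }
    return out;
}

int main() {
    // rows 2 and 0 of a 3x2 matrix -> {5, 6, 1, 2}
    auto out = gatherAxis0({1, 2, 3, 4, 5, 6}, 2, {2, 0});
    return out[0] == 5 ? 0 : 1;
}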

src/04kernel/src/collectors/scatter_nd.cc

Lines changed: 6 additions & 0 deletions
@@ -1,6 +1,7 @@
 #include "kernel/collectors/scatter_nd.h"
 #include "../kernels/scatter_nd/cpu_kernel.hh"
 #include "../kernels/scatter_nd/cuda_kernel.hh"
+#include "../kernels/scatter_nd/cnnl_kernel.hh"
 
 namespace refactor::kernel {
 
@@ -23,6 +24,11 @@ namespace refactor::kernel {
                     ans.emplace_back(std::move(ptr));
                 }
                 break;
+            case decltype(_target)::Mlu:
+                if (auto ptr = ScatterNDCnnl::build(inputs, outputs); ptr) {
+                    ans.emplace_back(std::move(ptr));
+                }
+                break;
             default:
                 UNREACHABLEX(void, "Unknown target");
         }
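
Note: ScatterNDCnnl::build takes the whole input/output tensor ranges rather than unpacked arguments. For reference, ONNX-style ScatterND semantics in host C++ for the simplest case of [n, 1] index tuples over a 1-D tensor (illustrative only):

#include <vector>

// ScatterND, 1-D case: out starts as a copy of data, then
// out[indices[i]] = updates[i] for each index tuple.
std::vector<float> scatterNd1D(std::vector<float> data,
                               std::vector<int> const &indices,
                               std::vector<float> const &updates) {
    for (size_t i = 0; i < indices.size(); ++i) {
        data[indices[i]] = updates[i];
    }
    return data;
}

int main() {
    auto out = scatterNd1D({0, 0, 0, 0}, {3, 1}, {9, 8});
    return out[3] == 9 && out[1] == 8 ? 0 : 1;// {0, 8, 0, 9}
}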

src/04kernel/src/kernels/cast/cnnl_kernel.cc (new file)

Lines changed: 226 additions & 0 deletions

#include "cnnl_kernel.hh"

#ifdef USE_BANG
#include "../../utilities/bang/cnnl_context.hh"
#include "../../utilities/bang/cnnl_functions.h"
#endif

namespace refactor::kernel {
    using K = CastCnnl;
    using DT = DataType;

    K::CastCnnl(decltype(from) from_,
                decltype(to) to_,
                decltype(shape) shape_) noexcept
        : from(from_), to(to_), shape(shape_) {}

    auto K::build(Tensor const &from, Tensor const &to) noexcept -> KernelBox {
#ifndef USE_BANG
        return nullptr;
#endif

        return std::make_unique<K>(from.dataType, to.dataType,
                                   std::vector<int>(from.shape.begin(), from.shape.end()));
    }
    auto K::typeId() noexcept -> size_t {
        static uint8_t ID = 1;
        return reinterpret_cast<size_t>(&ID);
    }

    auto K::kernelTypeId() const noexcept -> size_t {
        return typeId();
    }
    auto K::description() const noexcept -> std::string_view {
        return "Performing cast operation using CNNL";
    }

#ifdef USE_BANG

    static cnnlCastDataType_t castType(DataType from, DataType to);

    auto K::lower(Resources &res) const -> RoutineWorkspace {
        using namespace cnnl;
        using namespace runtime;

        struct Descriptors {
            cnnlTensorDescriptor_t inDesc, outDesc;
            cnnlCastDataType_t cast;

            Descriptors() : inDesc(nullptr), outDesc(nullptr) {
                CNNL_ASSERT(cnnlCreateTensorDescriptor(&inDesc));
                CNNL_ASSERT(cnnlCreateTensorDescriptor(&outDesc));
            }
            ~Descriptors() noexcept(false) {
                CNNL_ASSERT(cnnlDestroyTensorDescriptor(inDesc));
                CNNL_ASSERT(cnnlDestroyTensorDescriptor(outDesc));
            }
        };
        auto d = std::make_shared<Descriptors>();
        d->cast = castType(from, to);
        setCnnlTensor(d->inDesc, from, slice(shape.data(), shape.size()));
        setCnnlTensor(d->outDesc, to, slice(shape.data(), shape.size()));

        res.fetchOrStore<CnnlContext>();
        return [d = std::move(d)](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
            CNNL_ASSERT(cnnlCastDataType(res.fetchOrStore<CnnlContext>()->handle,
                                         d->inDesc, inputs[0], d->cast, d->outDesc, outputs[0]));
            // BANG_ASSERT(cnrtQueueSync(res.fetchOrStore<CnnlContext>()->queue));
        };
    }

    static cnnlCastDataType_t castType(DataType from, DataType to) {
        switch (from) {
            case DT::F32:
                switch (to) {
                    case DT::F64:
                        return CNNL_CAST_FLOAT_TO_DOUBLE;
                    case DT::FP16:
                        return CNNL_CAST_FLOAT_TO_HALF;
                    case DT::I64:
                        return CNNL_CAST_FLOAT_TO_INT64;
                    case DT::I32:
                        return CNNL_CAST_FLOAT_TO_INT32;
                    case DT::I16:
                        return CNNL_CAST_FLOAT_TO_INT16;
                    case DT::I8:
                        return CNNL_CAST_FLOAT_TO_INT8;
                    case DT::U8:
                        return CNNL_CAST_FLOAT_TO_UINT8;
                    // case DT::BF16:
                    //     return CNNL_CAST_FLOAT_TO_BFLOAT16;
                    case DT::Bool:
                        return CNNL_CAST_FLOAT_TO_BOOL;
                    default:
                        UNREACHABLE();
                }
            case DT::FP16:
                switch (to) {
                    case DT::F32:
                        return CNNL_CAST_HALF_TO_FLOAT;
                    case DT::I64:
                        return CNNL_CAST_HALF_TO_INT64;
                    case DT::I32:
                        return CNNL_CAST_HALF_TO_INT32;
                    case DT::I16:
                        return CNNL_CAST_HALF_TO_INT16;
                    case DT::I8:
                        return CNNL_CAST_HALF_TO_INT8;
                    case DT::U8:
                        return CNNL_CAST_HALF_TO_UINT8;
                    case DT::Bool:
                        return CNNL_CAST_HALF_TO_BOOL;
                    default:
                        UNREACHABLE();
                }
            case DT::I32:
                switch (to) {
                    case DT::F32:
                        return CNNL_CAST_INT32_TO_FLOAT;
                    case DT::FP16:
                        return CNNL_CAST_INT32_TO_HALF;
                    case DT::I64:
                        return CNNL_CAST_INT32_TO_INT64;
                    case DT::I16:
                        return CNNL_CAST_INT32_TO_INT16;
                    case DT::I8:
                        return CNNL_CAST_INT32_TO_INT8;
                    case DT::Bool:
                        return CNNL_CAST_INT32_TO_BOOL;
                    default:
                        UNREACHABLE();
                }
            case DT::I16:
                switch (to) {
                    case DT::F32:
                        return CNNL_CAST_INT16_TO_FLOAT;
                    case DT::FP16:
                        return CNNL_CAST_INT16_TO_HALF;
                    case DT::I32:
                        return CNNL_CAST_INT16_TO_INT32;
                    // case DT::I8:
                    //     return CNNL_CAST_INT16_TO_INT8;
                    default:
                        UNREACHABLE();
                }
            case DT::I8:
                switch (to) {
                    case DT::F32:
                        return CNNL_CAST_INT8_TO_FLOAT;
                    case DT::FP16:
                        return CNNL_CAST_INT8_TO_HALF;
                    case DT::I32:
                        return CNNL_CAST_INT8_TO_INT32;
                    case DT::I16:
                        return CNNL_CAST_INT8_TO_INT16;
                    default:
                        UNREACHABLE();
                }
            case DT::U8:
                switch (to) {
                    case DT::F32:
                        return CNNL_CAST_UINT8_TO_FLOAT;
                    case DT::FP16:
                        return CNNL_CAST_UINT8_TO_HALF;
                    case DT::I64:
                        return CNNL_CAST_UINT8_TO_INT64;
                    case DT::I32:
                        return CNNL_CAST_UINT8_TO_INT32;
                    default:
                        UNREACHABLE();
                }
            case DT::Bool:
                switch (to) {
                    case DT::F32:
                        return CNNL_CAST_BOOL_TO_FLOAT;
                    case DT::FP16:
                        return CNNL_CAST_BOOL_TO_HALF;
                    case DT::I32:
                        return CNNL_CAST_BOOL_TO_INT32;
                    default:
                        UNREACHABLE();
                }
            case DT::I64:
                switch (to) {
                    case DT::F32:
                        return CNNL_CAST_INT64_TO_FLOAT;
                    case DT::FP16:
                        return CNNL_CAST_INT64_TO_HALF;
                    case DT::I32:
                        return CNNL_CAST_INT64_TO_INT32;
                    case DT::U32:
                        return CNNL_CAST_INT64_TO_UINT32;
                    default:
                        UNREACHABLE();
                }
            case DT::U32:
                switch (to) {
                    case DT::I64:
                        return CNNL_CAST_UINT32_TO_INT64;
                    case DT::U64:
                        return CNNL_CAST_UINT32_TO_UINT64;
                    default:
                        UNREACHABLE();
                }
            case DT::F64:
                switch (to) {
                    case DT::F32:
                        return CNNL_CAST_DOUBLE_TO_FLOAT;
                    default:
                        UNREACHABLE();
                }
            case DT::BF16:
                switch (to) {
                    // case DT::F32:
                    //     return CNNL_CAST_BF16_TO_FLOAT;
                    default:
                        UNREACHABLE();
                }
            default:
                UNREACHABLE();
        }
    }

#endif

}// namespace refactor::kernel
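
Two notes on the file above. First, castType is only consulted inside lower(), so an unsupported (from, to) pair survives build() and aborts at lowering time via UNREACHABLE(). Second, typeId() uses the project's address-as-identifier pattern: each kernel class owns a distinct static byte, and that byte's address, reinterpreted as size_t, is a program-wide unique id with no central registry. A standalone illustration of the trick (hypothetical classes, not from this commit):

#include <cstdint>
#include <cstdio>

// Each class gets its own static byte; distinct objects have distinct
// addresses, so the reinterpreted address is a stable, collision-free id.
struct KernelA {
    static size_t typeId() noexcept {
        static uint8_t ID = 1;
        return reinterpret_cast<size_t>(&ID);
    }
};
struct KernelB {
    static size_t typeId() noexcept {
        static uint8_t ID = 1;
        return reinterpret_cast<size_t>(&ID);
    }
};

int main() {
    // Prints 0: the two classes never share an id within one run.
    std::printf("%d\n", KernelA::typeId() == KernelB::typeId());
    return 0;
}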

src/04kernel/src/kernels/cast/cnnl_kernel.hh (new file)

Lines changed: 27 additions & 0 deletions

#ifndef KERNEL_CAST_CNNL_KERNEL_HH
#define KERNEL_CAST_CNNL_KERNEL_HH

#include "kernel/kernel.h"
#include "kernel/tensor.h"

namespace refactor::kernel {

    struct CastCnnl final : public Kernel {
        DataType from, to;
        std::vector<int> shape;

        CastCnnl(decltype(from), decltype(to), decltype(shape)) noexcept;

        static KernelBox build(Tensor const &, Tensor const &) noexcept;
        static size_t typeId() noexcept;

        size_t kernelTypeId() const noexcept final;
        std::string_view description() const noexcept final;
#ifdef USE_BANG
        RoutineWorkspace lower(Resources &) const final;
#endif
    };

}// namespace refactor::kernel

#endif// KERNEL_CAST_CNNL_KERNEL_HH
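
The header declares lower only under USE_BANG, and the matching build in the .cc bails out before constructing anything when the macro is absent, so non-MLU builds simply never produce a CNNL candidate. A compilable sketch of that guard pattern (illustrative only, not code from the commit):

#include <memory>

// Mirrors CastCnnl::build's guard: without USE_BANG the factory
// returns nullptr before the real construction is reached.
static std::unique_ptr<int> buildSketch() {
#ifndef USE_BANG
    return nullptr;
#endif
    return std::make_unique<int>(1);
}

int main() {
    // Compiled without -DUSE_BANG this returns 1 (no kernel built);
    // with -DUSE_BANG it returns 0.
    return buildSketch() ? 0 : 1;
}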

src/04kernel/src/kernels/clip/cnnl_kernel.cc (new file)

Lines changed: 66 additions & 0 deletions

#include "cnnl_kernel.hh"

#ifdef USE_BANG
#include "../../utilities/bang/cnnl_context.hh"
#include "../../utilities/bang/cnnl_functions.h"
#endif

namespace refactor::kernel {
    using K = ClipCnnl;

    K::ClipCnnl(decltype(dataType) dt,
                decltype(shape) shape_,
                decltype(hasMax) hasMax_) noexcept
        : dataType(dt), shape(shape_), hasMax(hasMax_) {
    }

    auto K::build(Tensor const &data, bool hasMax) noexcept -> KernelBox {
        return data.dataType.isCpuNumberic()
                   ? std::make_unique<K>(data.dataType,
                                         std::vector<int>(data.shape.begin(), data.shape.end()),
                                         hasMax)
                   : nullptr;
    }
    auto K::typeId() noexcept -> size_t {
        static uint8_t ID = 1;
        return reinterpret_cast<size_t>(&ID);
    }

    auto K::kernelTypeId() const noexcept -> size_t {
        return typeId();
    }
    auto K::description() const noexcept -> std::string_view {
        return "Performing clip operation using CNNL";
    }

#ifdef USE_BANG
    auto K::lower(Resources &res) const noexcept -> RoutineWorkspace {
        using namespace cnnl;
        using namespace runtime;

        struct Descriptors {
            cnnlTensorDescriptor_t t;

            Descriptors() : t(nullptr) {
                CNNL_ASSERT(cnnlCreateTensorDescriptor(&t));
            }
            ~Descriptors() noexcept(false) {
                CNNL_ASSERT(cnnlDestroyTensorDescriptor(t));
            }
        };
        auto d = std::make_shared<Descriptors>();
        setCnnlTensor(d->t, dataType, slice(shape.data(), shape.size()));

        res.fetchOrStore<CnnlContext>();
        return [d = std::move(d), hasMax = this->hasMax](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
            CNNL_ASSERT(cnnlClip_v2(res.fetchOrStore<CnnlContext>()->handle,
                                    CNNL_POINTER_MODE_DEVICE, d->t,
                                    inputs[0], inputs[1], hasMax ? inputs[2] : nullptr,
                                    d->t, outputs[0]));
            BANG_ASSERT(cnrtQueueSync(res.fetchOrStore<CnnlContext>()->queue));
        };
    }

#endif

}// namespace refactor::kernel
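
With CNNL_POINTER_MODE_DEVICE, the min bound (inputs[1]) and the optional max bound (inputs[2], passed as nullptr when hasMax is false) are read from device memory; unlike the cast routine, this one synchronizes the queue before returning. For reference, a host-side sketch of the intended clip semantics, with the same min-required/max-optional split (plain C++, not MLU code):

#include <algorithm>
#include <optional>
#include <vector>

// Reference semantics: clamp each element to [min, max]; an absent
// max (hasMax == false) leaves values unbounded above.
std::vector<float> clipRef(std::vector<float> x,
                           float min,
                           std::optional<float> max) {
    for (auto &v : x) {
        v = std::max(v, min);
        if (max) {
            v = std::min(v, *max);
        }
    }
    return x;
}

int main() {
    auto y = clipRef({-2.f, 0.5f, 9.f}, 0.f, 1.f);// -> {0, 0.5, 1}
    return y[2] == 1.f ? 0 : 1;
}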
