
Commit 8bde8c1

Chamberlain0w0 authored and YdrMaster committed
feat: llama/gpt2 now run end-to-end on MLU with correct results
1 parent 5aa7a1e commit 8bde8c1

File tree

13 files changed: +181 -61 lines changed

src/04kernel/src/collectors/global_pool.cc

Lines changed: 6 additions & 0 deletions

@@ -1,5 +1,6 @@
 #include "kernel/collectors/global_pool.h"
 #include "../kernels/pool/cudnn_kernel.hh"
+#include "../kernels/pool/cnnl_kernel.hh"

 namespace refactor::kernel {

@@ -28,6 +29,11 @@ namespace refactor::kernel {
                     ans.emplace_back(std::move(ptr));
                 }
                 break;
+            case decltype(_target)::Mlu:
+                if (auto ptr = PoolCnnl::build(type, false, kernelShape, attributes, x, y); ptr) {
+                    ans.emplace_back(std::move(ptr));
+                }
+                break;
            default:
                UNREACHABLEX(void, "Unknown target");
        }

src/04kernel/src/kernels/gather/cnnl_kernel.cc

Lines changed: 8 additions & 6 deletions

@@ -4,6 +4,7 @@
 #include "../../utilities/bang/cnnl_context.hh"
 #include "../../utilities/bang/cnnl_functions.h"
 #endif
+#include <iostream>

 namespace refactor::kernel {
     using K = GatherCnnl;
@@ -15,11 +16,11 @@ namespace refactor::kernel {
 #ifndef USE_BANG
         return nullptr;
 #endif
-
+
         return std::make_unique<K>(decltype(info){
             input.dataType,
             DataType::I32,
-            axis,
+            axis ? axis : 0,
             std::vector<int>(input.shape.begin(), input.shape.end()),
             std::vector<int>(index.shape.begin(), index.shape.end()),
             std::vector<int>(output.shape.begin(), output.shape.end()),
@@ -70,15 +71,16 @@ namespace refactor::kernel {

         res.fetchOrStore<CnnlContext>();
         auto routine = [d = std::move(d),
-                        shape = info.inDim.data(), workspaceSize,
+                        shape = std::vector<int>(info.inDim.begin(), info.inDim.end()),
+                        workspaceSize,
                         dim = info.axis](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
-            BANG_ASSERT(cnrtMemcpy(workspace, (void*) shape, workspaceSize, CNRT_MEM_TRANS_DIR_HOST2DEV));
+            res.fetchOrStore<CnnlContext>()->copyFromCPU(workspace, shape.data(), workspaceSize);
             CNNL_ASSERT(cnnlGatherV2(res.fetchOrStore<CnnlContext>()->handle, dim,
                                      d->inDesc, inputs[0], reinterpret_cast<const int *>(workspace),
-                                     d->indexDesc, reinterpret_cast<const int *>(inputs[1]),
+                                     d->indexDesc, reinterpret_cast<const int *>(inputs[1]),
                                      d->outDesc, outputs[0]));
             BANG_ASSERT(cnrtQueueSync(res.fetchOrStore<CnnlContext>()->queue));
-        };
+        };

         return {std::move(routine), workspaceSize};
     }
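
Note: the substantive fix in this hunk is the lambda capture. The routine previously captured info.inDim.data(), a raw pointer into a vector the lambda does not own, which dangles if the kernel object is destroyed before the routine runs; capturing a copy of the vector makes the routine self-contained. A minimal standalone sketch of the pitfall (names are illustrative, not from the repo):

    #include <functional>
    #include <vector>

    std::function<int()> makeRoutine() {
        std::vector<int> shape{2, 3, 4};
        // BAD: captures a raw pointer that dangles once `shape` is destroyed.
        // auto routine = [p = shape.data()] { return p[0]; };
        // GOOD: captures a copy the lambda owns, as the gather routine now does.
        auto routine = [copy = shape] { return copy[0]; };
        return routine;
    }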

src/04kernel/src/kernels/reduce/cnnl_kernel.cc

Lines changed: 6 additions & 5 deletions

@@ -71,14 +71,15 @@

         std::vector<int>
             dimsI(shape.begin(), shape.end()),
-            dimsO(shape.begin(), shape.end());
+            dimsO(shape.begin(), shape.end()),
+            indices(axes.begin(), axes.end());
         for (auto axis : axes) {
             dimsO[axis] = 1;
         }
         // setCnnlTensor(d->x, dataType, slice(dimsI.data(), dimsI.size()));
         // setCnnlTensor(d->y, dataType, slice(dimsO.data(), dimsO.size()));
-        CNNL_ASSERT(cnnlSetTensorDescriptor(d->x, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(dataType), dimsI.size(), dimsI.data()));
-        CNNL_ASSERT(cnnlSetTensorDescriptor(d->y, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(dataType), dimsO.size(), dimsO.data()));
+        CNNL_ASSERT(cnnlSetTensorDescriptor(d->x, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(dataType), dimsI.size(), dimsI.data()));
+        CNNL_ASSERT(cnnlSetTensorDescriptor(d->y, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(dataType), dimsO.size(), dimsO.data()));

         // clang-format off
         auto reduceOp = reduceType == ReduceType::Mean ? CNNL_REDUCE_AVG
@@ -91,12 +92,12 @@
                       : UNREACHABLEX(cnnlReduceOp_t, "");
         // clang-format on
         CNNL_ASSERT(cnnlSetReduceDescriptor_v2(
-            d->reduce, (int *) (axes.data()), axes.size(), reduceOp,
+            d->reduce, indices.data(), indices.size(), reduceOp,
             cnnlDataTypeConvert(d->f32 ? DataType::F32 : DataType::F64),
             CNNL_NOT_PROPAGATE_NAN, CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES, 0.0));

         auto handler = res.fetchOrStore<CnnlContext>()->handle;
-        size_t idxWorkspaceSize = axes.size() * sizeof(int);
+        size_t idxWorkspaceSize = indices.size() * sizeof(int);
         // idxWorkspaceSize = hardware::alignBytes(idxWorkspaceSize, 256);
         size_t workspaceSize;
         // get workspace
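
Note: two things change here. The tensor layout moves from CNNL_LAYOUT_NCHW to CNNL_LAYOUT_ARRAY, so ranks other than 4 are described honestly, and the axes are copied into a signed `indices` vector instead of casting the framework's axis buffer to int *. A sketch of that value-converting copy, assuming the axes are stored as unsigned integers (which the removed cast suggests):

    #include <cstdint>
    #include <vector>

    // A value-converting copy is well-defined for any in-range axis values,
    // unlike reinterpreting an unsigned buffer through an int pointer.
    std::vector<int> toSignedAxes(std::vector<uint32_t> const &axes) {
        return std::vector<int>(axes.begin(), axes.end());
    }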

src/04kernel/src/kernels/softmax/cnnl_kernel.cc

Lines changed: 6 additions & 3 deletions

@@ -59,9 +59,11 @@
             static_cast<cnnlSoftmaxAlgorithm_t>(algo),
             dataType != DataType::F64);
         int dims[]{pre, mid, post};
-        cnnlSoftmaxMode_t mode = (post == 1) ? CNNL_SOFTMAX_MODE_HIGH_DIMENSION
-                                 : (pre == 1) ? CNNL_SOFTMAX_MODE_LOW_DIMENSION
-                                              : CNNL_SOFTMAX_MODE_MEDIUM_DIMENSION;
+        // cnnlSoftmaxMode_t mode = (pre == 1)    ? CNNL_SOFTMAX_MODE_HIGH_DIMENSION
+        //                          : (post == 1) ? CNNL_SOFTMAX_MODE_LOW_DIMENSION
+        //                                        : CNNL_SOFTMAX_MODE_MEDIUM_DIMENSION;
+        // FIXME(bolun): CNNL Softmax mode
+        cnnlSoftmaxMode_t mode = CNNL_SOFTMAX_MODE_MEDIUM_DIMENSION;

         // cnnlSoftmaxForward_v2 is applied to a 3D input tensor only
         CNNL_ASSERT(cnnlSetTensorDescriptor(d->t, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(dataType), 3, dims));
@@ -78,6 +80,7 @@
                 CNNL_COMPUTATION_ULTRAHIGH_PRECISION,
                 &a, d->t, inputs[0],
                 &b, d->t, outputs[0]));
+            res.fetchOrStore<CnnlContext>()->queueSync();
         };
     }
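
Note: the commit pins the mode to CNNL_SOFTMAX_MODE_MEDIUM_DIMENSION and leaves the per-shape selection behind a FIXME. For context, cnnlSoftmaxForward_v2 only accepts a 3D tensor, so the kernel flattens the input around the softmax axis into (pre, mid, post); for shape [2, 3, 4, 5] with axis 2, that is pre = 2*3 = 6, mid = 4, post = 5. A self-contained sketch of that flattening (flattenForSoftmax is an illustrative name, not from the repo):

    #include <cstddef>
    #include <vector>

    struct Dims3 { int pre, mid, post; };

    // Flatten an N-D shape around `axis` into the 3D (pre, mid, post)
    // form passed to the CNNL tensor descriptor above.
    Dims3 flattenForSoftmax(std::vector<int> const &shape, size_t axis) {
        Dims3 d{1, shape[axis], 1};
        for (size_t i = 0; i < axis; ++i) { d.pre *= shape[i]; }
        for (size_t i = axis + 1; i < shape.size(); ++i) { d.post *= shape[i]; }
        return d;
    }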

src/04kernel/src/kernels/where/cnnl_kernel.cc

Lines changed: 39 additions & 43 deletions

@@ -16,13 +16,24 @@ namespace refactor::kernel {
 #ifndef USE_BANG
         return nullptr;
 #endif
-        return std::make_unique<K>(decltype(info) {
-            inputs[1].get().dataType,
-            inputs[0].get().shape,
-            inputs[1].get().shape,
-            inputs[2].get().shape,
-            outputs[0].get().shape,
-        });
+        std::vector<int> cDim(inputs[0].get().shape.begin(), inputs[0].get().shape.end()),
+            xDim(inputs[1].get().shape.begin(), inputs[1].get().shape.end()),
+            yDim(inputs[2].get().shape.begin(), inputs[2].get().shape.end()),
+            ansDim(outputs[0].get().shape.begin(), outputs[0].get().shape.end());
+        if (ansDim.size() == 0) {
+            ansDim.push_back(1);
+        }
+        if (xDim.size() == 0) {
+            xDim.push_back(1);
+        }
+        if (yDim.size() == 0) {
+            yDim.push_back(1);
+        }
+        if (cDim.size() == 0) {
+            cDim.push_back(1);
+        }
+        return std::make_unique<K>(decltype(info){
+            inputs[1].get().dataType, cDim, xDim, yDim, ansDim});
     }
     auto K::typeId() noexcept -> size_t {
         static uint8_t ID = 1;
@@ -44,11 +55,10 @@ namespace refactor::kernel {

         struct Descriptors {
             cnnlTensorDescriptor_t cond, x, y, ans;
-            bool f32;

-            explicit Descriptors(decltype(f32) f32_)
+            explicit Descriptors()
                 : cond(nullptr), x(nullptr), y(nullptr),
-                  ans(nullptr), f32(f32_) {
+                  ans(nullptr) {
                 CNNL_ASSERT(cnnlCreateTensorDescriptor(&cond));
                 CNNL_ASSERT(cnnlCreateTensorDescriptor(&x));
                 CNNL_ASSERT(cnnlCreateTensorDescriptor(&y));
@@ -64,49 +74,35 @@ namespace refactor::kernel {
             Descriptors(const Descriptors &) = delete;
             Descriptors(Descriptors &&) = delete;
         };
-        auto d = std::make_shared<Descriptors>(info.dataType != DT::F64);
-
-        std::vector<int> cDim(info.condDim.begin(), info.condDim.end()),
-            xDim(info.thenDim.begin(), info.thenDim.end()),
-            yDim(info.elseDim.begin(), info.elseDim.end()),
-            ansDim(info.outputDim.begin(), info.outputDim.end());
-
-        auto rightAlign = [](std::vector<int> &dim, uint32_t targetLength) {
-            if (dim.size() < targetLength) {
-                dim.insert(dim.begin(), targetLength - dim.size(), 1);
-            }
-        };
-        if (ansDim.size() == 0) {
-            ansDim.push_back(1);
-        }
-        rightAlign(cDim, ansDim.size());
-        rightAlign(xDim, ansDim.size());
-        rightAlign(yDim, ansDim.size());
-
-        CNNL_ASSERT(cnnlSetTensorDescriptor(d->cond, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(DT::Bool), cDim.size(), cDim.data()));
-        CNNL_ASSERT(cnnlSetTensorDescriptor(d->x, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(info.dataType), xDim.size(), xDim.data()));
-        CNNL_ASSERT(cnnlSetTensorDescriptor(d->y, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(info.dataType), yDim.size(), yDim.data()));
-        CNNL_ASSERT(cnnlSetTensorDescriptor(d->ans, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(info.dataType), ansDim.size(), ansDim.data()));
+        auto d = std::make_shared<Descriptors>();
+
+        CNNL_ASSERT(cnnlSetTensorDescriptor(
+            d->cond, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(DT::Bool),
+            info.condDim.size(), info.condDim.data()));
+        CNNL_ASSERT(cnnlSetTensorDescriptor(
+            d->x, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(info.dataType),
+            info.thenDim.size(), info.thenDim.data()));
+        CNNL_ASSERT(cnnlSetTensorDescriptor(
+            d->y, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(info.dataType),
+            info.elseDim.size(), info.elseDim.data()));
+        CNNL_ASSERT(cnnlSetTensorDescriptor(
+            d->ans, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(info.dataType),
+            info.outputDim.size(), info.outputDim.data()));

         auto handle = res.fetchOrStore<CnnlContext>()->handle;
         size_t workspaceSize;
         CNNL_ASSERT(cnnlGetSelectV2WorkspaceSize(handle, d->cond, d->x, d->y, &workspaceSize));

         res.fetchOrStore<CnnlContext>();
         auto routine = [d = std::move(d), workspaceSize](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
-            // fetch cnnl handle from resources
-            auto handle = res.fetchOrStore<CnnlContext>()->handle;
-            auto cond = inputs[0],
-                 x = inputs[1],
-                 y = inputs[2];
-            auto ans = outputs[0];

             CNNL_ASSERT(cnnlSelectV2(
-                handle, d->cond, cond, d->x, x,
-                d->y, y, workspace, workspaceSize,
-                d->ans, ans));
+                res.fetchOrStore<CnnlContext>()->handle,
+                d->cond, inputs[0], d->x, inputs[1],
+                d->y, inputs[2], workspace, workspaceSize,
+                d->ans, outputs[0]));

-            cnrtQueueSync(res.fetchOrStore<CnnlContext>()->queue);
+            res.fetchOrStore<CnnlContext>()->queueSync();
         };

         return {std::move(routine), workspaceSize};
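
Note: the shape normalization moves from lower() to build(), and rank-0 (scalar) tensors are widened to shape {1}, since a CNNL tensor descriptor needs at least one dimension; the old rightAlign padding is dropped, presumably because cnnlSelectV2 broadcasts mismatched ranks itself. A minimal sketch of the scalar handling (normalizeScalar is an illustrative name):

    #include <vector>

    // Widen a rank-0 (scalar) shape to {1} so it can be described to CNNL.
    void normalizeScalar(std::vector<int> &dim) {
        if (dim.empty()) {
            dim.push_back(1);
        }
    }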

src/04kernel/src/kernels/where/cnnl_kernel.hh

Lines changed: 1 addition & 3 deletions

@@ -7,12 +7,10 @@

 namespace refactor::kernel {

-    using Shape = absl::InlinedVector<dim_t, 4>;
-
     struct WhereCnnl final : public Kernel {
         struct {
             DataType dataType;
-            Shape condDim, thenDim, elseDim, outputDim;
+            std::vector<int> condDim, thenDim, elseDim, outputDim;
         } info;

         WhereCnnl(decltype(info)) noexcept;

src/04kernel/src/utilities/bang/cnnl_context.cc

Lines changed: 9 additions & 0 deletions

@@ -30,6 +30,15 @@ namespace refactor::kernel::cnnl {
         return "CnnlContext";
     }

+    void CnnlContext::copyFromCPU(void *dst, const void *src, size_t size) {
+        BANG_ASSERT(cnrtMemcpy(dst, const_cast<void *>(src), size,
+                               CNRT_MEM_TRANS_DIR_HOST2DEV));
+    }
+
+    void CnnlContext::queueSync() {
+        BANG_ASSERT(cnrtQueueSync(queue));
+    }
+
 }// namespace refactor::kernel::cnnl

 #endif

src/04kernel/src/utilities/bang/cnnl_context.hh

Lines changed: 2 additions & 0 deletions

@@ -22,6 +22,8 @@ namespace refactor::kernel::cnnl {
         size_t resourceTypeId() const noexcept final;
         std::string_view description() const noexcept final;

+        void copyFromCPU(void *dst, const void *src, size_t size);
+        void queueSync();
     };

 }// namespace refactor::kernel::cnnl
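
Note: these two methods wrap the raw CNRT calls the kernels were making inline, so the routines above no longer touch cnrt directly. A hypothetical call-site fragment (assumes a Resources instance `res`, a device pointer `deviceWorkspace`, and a host vector `hostShape`, as in the gather kernel):

    auto ctx = res.fetchOrStore<CnnlContext>();
    // Blocking host-to-device copy of the shape metadata.
    ctx->copyFromCPU(deviceWorkspace, hostShape.data(),
                     hostShape.size() * sizeof(int));
    // ... enqueue CNNL ops on ctx->handle ...
    ctx->queueSync();// wait for everything enqueued on this context's queue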
src/04kernel/src/utilities/bang/cnrt_functions.cc

Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
+#ifdef USE_BANG
+#include "cnrt_functions.h"
+#include "cnnl_functions.h"
+#include <cnrt.h>
+#include <cstdio>
+
+namespace refactor::kernel::cnnl {
+
+    int currentDevice() {
+        int device;
+        BANG_ASSERT(cnrtGetDevice(&device));
+        return device;
+    }
+
+    void sync() {
+        BANG_ASSERT(cnrtSyncDevice());
+    }
+
+    void copyOut(void *dst, const void *src, size_t size) {
+        sync();
+        BANG_ASSERT(cnrtMemcpy(dst, const_cast<void *>(src), size,
+                               CNRT_MEM_TRANS_DIR_DEV2HOST));
+    }
+
+}// namespace refactor::kernel::cnnl
+
+#endif
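
Note: copyOut() calls sync() before the device-to-host memcpy, so any asynchronously enqueued kernel has finished writing before the host reads the buffer. A hypothetical debugging helper built on these functions (requires an MLU device; `devPtr` stands in for a real device allocation):

    #include <cstddef>
    #include <cstdio>
    #include <vector>
    #include "cnrt_functions.h"

    // Pull the first n floats of a device buffer back to the host and print them.
    void dumpDeviceFloats(const void *devPtr, size_t n) {
        std::vector<float> host(n);
        refactor::kernel::cnnl::copyOut(host.data(), devPtr, n * sizeof(float));
        std::printf("device %d:", refactor::kernel::cnnl::currentDevice());
        for (auto v : host) { std::printf(" %f", v); }
        std::printf("\n");
    }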
src/04kernel/src/utilities/bang/cnrt_functions.h

Lines changed: 16 additions & 0 deletions

@@ -0,0 +1,16 @@
+#ifndef KERNEL_CNRT_FUNCTIONS_H
+#define KERNEL_CNRT_FUNCTIONS_H
+
+#include "common.h"
+
+namespace refactor::kernel::cnnl {
+
+    int currentDevice();
+
+    void sync();
+
+    void copyOut(void *dst, const void *src, size_t size);
+
+}// namespace refactor::kernel::cnnl
+
+#endif// KERNEL_CNRT_FUNCTIONS_H
