
Commit 5aa7a1e

Chamberlain0w0 authored and YdrMaster committed
feat: llama runs end-to-end on MLU, but does not yet produce correct results
1 parent beda029 · commit 5aa7a1e

14 files changed: +47 −42 lines changed

src/04kernel/src/collectors/concat.cc

Lines changed: 3 additions & 1 deletion
@@ -9,6 +9,8 @@ namespace refactor::kernel {
     ConcatCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
         SplitInfo info(axis, inputs);
 
+        auto const &b = outputs[0];
+
         std::vector<KernelBox> ans;
         switch (_target) {
             case decltype(_target)::Cpu:
@@ -22,7 +24,7 @@ namespace refactor::kernel {
                 }
                 break;
             case decltype(_target)::Mlu:
-                if (auto ptr = ConcatCnnl::build(axis, inputs, outputs[0].get()); ptr) {
+                if (auto ptr = ConcatCnnl::build(axis, inputs, b); ptr) {
                     ans.emplace_back(std::move(ptr));
                 }
                 break;
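
This hoisting pattern repeats in gather.cc and split.cc below: a TensorRefs element is bound to a named const reference and handed straight to build, whose MLU overloads now accept Tensor const &. A minimal sketch of why no .get() is needed anymore, assuming TensorRefs is a sequence of std::reference_wrapper<Tensor const> (the types here are illustrative stand-ins, not the project's real definitions):

    #include <cstdio>
    #include <functional>
    #include <vector>

    // Hypothetical stand-ins for the real refactor::kernel types.
    struct Tensor { int rank; };
    using TensorRefs = std::vector<std::reference_wrapper<Tensor const>>;

    // Mirrors the shape of the new ConcatCnnl::build parameter: Tensor const &.
    void build(Tensor const &output) { std::printf("rank = %d\n", output.rank); }

    int main() {
        Tensor t{4};
        TensorRefs outputs{std::cref(t)};
        auto const &b = outputs[0]; // a reference_wrapper, not a Tensor copy
        build(b);                   // implicit conversion to Tensor const &; no .get()
    }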

src/04kernel/src/collectors/gather.cc

Lines changed: 7 additions & 2 deletions
@@ -9,7 +9,12 @@ namespace refactor::kernel {
     GatherCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
         GatherInfo info(axis, inputs[0], inputs[1]);
 
-        std::vector<KernelBox> ans;
+        auto const &a = inputs[0];
+        auto const &b = inputs[1];
+        auto const &c = outputs[0];
+
+        std::vector<KernelBox>
+            ans;
         switch (_target) {
             case decltype(_target)::Cpu:
                 if (auto ptr = GatherCpu::build(info); ptr != nullptr) {
@@ -22,7 +27,7 @@ namespace refactor::kernel {
                 }
                 break;
             case decltype(_target)::Mlu:
-                if (auto ptr = GatherCnnl::build(axis, inputs[0].get(), inputs[1].get(), outputs[0].get()); ptr != nullptr) {
+                if (auto ptr = GatherCnnl::build(axis, a, b, c); ptr != nullptr) {
                     ans.emplace_back(std::move(ptr));
                 }
                 break;

src/04kernel/src/collectors/split.cc

Lines changed: 3 additions & 1 deletion
@@ -9,6 +9,8 @@ namespace refactor::kernel {
     SplitCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
         SplitInfo info(axis, outputs);
 
+        auto const &a = inputs[0];
+
         std::vector<KernelBox> ans;
         switch (_target) {
             case decltype(_target)::Cpu:
@@ -22,7 +24,7 @@ namespace refactor::kernel {
                 }
                 break;
             case decltype(_target)::Mlu:
-                if (auto ptr = SplitCnnl::build(axis, inputs[0].get(), outputs); ptr) {
+                if (auto ptr = SplitCnnl::build(axis, a, outputs); ptr) {
                     ans.emplace_back(std::move(ptr));
                 }
                 break;

src/04kernel/src/kernels/concat/cnnl_kernel.cc

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ namespace refactor::kernel {
     K::ConcatCnnl(SplitInfoCnnl info_) noexcept
         : Kernel(), info(std::move(info_)) {}
 
-    auto K::build(int axis, TensorRefs inputs, Tensor output) noexcept -> KernelBox {
+    auto K::build(int axis, TensorRefs inputs, Tensor const &output) noexcept -> KernelBox {
 #ifndef USE_BANG
         return nullptr;
 #endif

src/04kernel/src/kernels/concat/cnnl_kernel.hh

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ namespace refactor::kernel {
 
         explicit ConcatCnnl(SplitInfoCnnl) noexcept;
 
-        static KernelBox build(int, TensorRefs, Tensor) noexcept;
+        static KernelBox build(int, TensorRefs, Tensor const &) noexcept;
         static size_t typeId() noexcept;
 
         size_t kernelTypeId() const noexcept final;
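
Together with the cnnl_kernel.cc change above, the header now takes the output tensor by const reference instead of by value. A minimal sketch of what that saves, with a hypothetical Tensor whose copy constructor is instrumented (the real refactor::kernel::Tensor carries a data type, a shape, and more):

    #include <cstdio>
    #include <vector>

    // Hypothetical stand-in: the real Tensor holds more state than this.
    struct Tensor {
        std::vector<int> shape;
        explicit Tensor(std::vector<int> s) : shape(std::move(s)) {}
        Tensor(Tensor const &other) : shape(other.shape) {
            std::puts("Tensor copied"); // fires on every pass-by-value call
        }
    };

    int rankByValue(Tensor t) { return (int) t.shape.size(); }      // old signature style
    int rankByRef(Tensor const &t) { return (int) t.shape.size(); } // new signature style

    int main() {
        Tensor t({1, 3, 224, 224});
        rankByValue(t); // prints "Tensor copied"
        rankByRef(t);   // no copy
        return 0;
    }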

src/04kernel/src/kernels/gather/cnnl_kernel.cc

Lines changed: 3 additions & 2 deletions
@@ -11,13 +11,14 @@ namespace refactor::kernel {
     K::GatherCnnl(decltype(info) info_) noexcept
         : Kernel(), info(std::move(info_)) {}
 
-    auto K::build(int axis, Tensor input, Tensor index, Tensor output) noexcept -> KernelBox {
+    auto K::build(int axis, Tensor const &input, Tensor const &index, Tensor const &output) noexcept -> KernelBox {
 #ifndef USE_BANG
         return nullptr;
 #endif
+
         return std::make_unique<K>(decltype(info){
             input.dataType,
-            index.dataType,
+            DataType::I32,
             axis,
             std::vector<int>(input.shape.begin(), input.shape.end()),
             std::vector<int>(index.shape.begin(), index.shape.end()),
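
Besides the const-reference signature, this hunk pins the kernel's index data type to DataType::I32 instead of propagating index.dataType. A plausible reading is that the CNNL gather path consumes 32-bit indices while ONNX-style Gather indices are often int64; where (or whether) the elements themselves get converted is not shown in this diff. A minimal host-side sketch of such a narrowing step (illustrative name, not a function from this codebase):

    #include <cassert>
    #include <cstdint>
    #include <limits>
    #include <vector>

    // Narrow 64-bit indices to the 32-bit form a kernel might expect,
    // range-checking each value before the cast.
    std::vector<int32_t> narrowIndices(std::vector<int64_t> const &idx) {
        std::vector<int32_t> out;
        out.reserve(idx.size());
        for (auto i : idx) {
            assert(i >= std::numeric_limits<int32_t>::min() &&
                   i <= std::numeric_limits<int32_t>::max());
            out.push_back(static_cast<int32_t>(i));
        }
        return out;
    }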

src/04kernel/src/kernels/gather/cnnl_kernel.hh

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ namespace refactor::kernel {
 
         explicit GatherCnnl(decltype(info)) noexcept;
 
-        static KernelBox build(int, Tensor, Tensor, Tensor) noexcept;
+        static KernelBox build(int, Tensor const &, Tensor const &, Tensor const &) noexcept;
         static size_t typeId() noexcept;
 
         size_t kernelTypeId() const noexcept final;

src/04kernel/src/kernels/reduce/cnnl_kernel.cc

Lines changed: 4 additions & 2 deletions
@@ -75,8 +75,10 @@ namespace refactor::kernel {
         for (auto axis : axes) {
             dimsO[axis] = 1;
         }
-        setCnnlTensor(d->x, dataType, slice(dimsI.data(), dimsI.size()));
-        setCnnlTensor(d->y, dataType, slice(dimsO.data(), dimsO.size()));
+        // setCnnlTensor(d->x, dataType, slice(dimsI.data(), dimsI.size()));
+        // setCnnlTensor(d->y, dataType, slice(dimsO.data(), dimsO.size()));
+        CNNL_ASSERT(cnnlSetTensorDescriptor(d->x, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(dataType), dimsI.size(), dimsI.data()));
+        CNNL_ASSERT(cnnlSetTensorDescriptor(d->y, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(dataType), dimsO.size(), dimsO.data()));
 
         // clang-format off
         auto reduceOp = reduceType == ReduceType::Mean ? CNNL_REDUCE_AVG
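
The commented-out setCnnlTensor calls are replaced with explicit cnnlSetTensorDescriptor calls that name CNNL_LAYOUT_NCHW outright; presumably the helper picked a different default layout than the reduce kernel needs. A minimal sketch of the explicit form, assuming the CNNL toolkit headers and a float tensor (the project's CNNL_ASSERT and cnnlDataTypeConvert helpers are replaced by plain equivalents here):

    #include <cnnl.h>
    #include <vector>

    // Describe `dims` as an NCHW-layout float tensor on descriptor `desc`.
    void describeNCHW(cnnlTensorDescriptor_t desc, std::vector<int> const &dims) {
        cnnlStatus_t st = cnnlSetTensorDescriptor(
            desc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT,
            static_cast<int>(dims.size()), dims.data());
        if (st != CNNL_STATUS_SUCCESS) { /* report or abort */ }
    }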

src/04kernel/src/kernels/simple_binary/binary_cnnl.cc

Lines changed: 6 additions & 19 deletions
@@ -26,9 +26,10 @@ namespace refactor::kernel {
             // !a.dataType.isFloat() ||
             !ARTHIMETIC.contains(op) ||
             // At least one of a,b should have the same shape as c
-            (a.shape != c.shape && b.shape != c.shape) ||
+            (a.shape != c.shape && b.shape != c.shape)
             // Sub only supports brocasting b
-            (a.shape != c.shape && op == Op::Sub)) {
+            // (a.shape != c.shape && op == Op::Sub)
+        ) {
             return nullptr;
         }
 
@@ -122,18 +123,13 @@ namespace refactor::kernel {
 
         auto handle = res.fetchOrStore<CnnlContext>()->handle;
         size_t workspaceSize;
-        if (aDims != cDims) {
-            CNNL_ASSERT(cnnlGetBinaryWorkspaceSize(handle, d->bDesc,
-                                                   d->aDesc, d->cDesc,
-                                                   &workspaceSize));
-        } else {
-            CNNL_ASSERT(cnnlGetBinaryWorkspaceSize(handle, d->aDesc,
+        CNNL_ASSERT(cnnlGetBinaryWorkspaceSize(handle, d->aDesc,
                                                d->bDesc, d->cDesc,
                                                &workspaceSize));
-        }
+
 
         res.fetchOrStore<CnnlContext>();
-        auto routine = [swap = aDims != cDims, d,
+        auto routine = [d = std::move(d),
                         workspaceSize, cnnlLogicOP,
                         op = this->opType](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
             auto handle = res.fetchOrStore<CnnlContext>()->handle;
@@ -151,20 +147,11 @@ namespace refactor::kernel {
             beta = d->f32
                        ? factor<fp32_t>(0)
                        : factor<fp64_t>(0);
-
-            if (swap) {
-                CNNL_ASSERT(cnnlOpTensor(handle, d->opDesc,
-                                         &alphaB, d->bDesc, b,
-                                         &alphaA, d->aDesc, a,
-                                         workspace, workspaceSize,
-                                         &beta, d->cDesc, c));
-            } else {
                 CNNL_ASSERT(cnnlOpTensor(handle, d->opDesc,
                                          &alphaA, d->aDesc, a,
                                          &alphaB, d->bDesc, b,
                                          workspace, workspaceSize,
                                          &beta, d->cDesc, c));
-            }
         } else if (op == SimpleBinaryType::Div) {
             CNNL_ASSERT(cnnlDiv_v2(handle,
                                    CNNL_COMPUTATION_HIGH_PRECISION,
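
With the operand-swap branch gone, the workspace query and the cnnlOpTensor call always use the (a, b, c) order, and the routine lambda now move-captures the descriptor bundle d instead of copying it. A minimal sketch of that d = std::move(d) init-capture, with a hypothetical Descriptors stand-in: ownership transfers into the lambda, so the descriptors live exactly as long as the stored routine does.

    #include <cstdio>
    #include <functional>
    #include <memory>

    struct Descriptors { int opDesc = 42; }; // hypothetical stand-in

    std::function<void()> makeRoutine() {
        auto d = std::make_shared<Descriptors>();
        // Move `d` into the capture: no extra refcount bump, and the lambda owns it.
        return [d = std::move(d)] { std::printf("opDesc = %d\n", d->opDesc); };
    }

    int main() {
        auto routine = makeRoutine(); // safe to call after makeRoutine returns
        routine();
    }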

src/04kernel/src/kernels/simple_unary/cnnl_simple_unary_kernel.cc

Lines changed: 7 additions & 6 deletions
@@ -62,12 +62,12 @@ namespace refactor::kernel {
 
         setCnnlTensor(d->tensor, dataType, slice(&size, 1));
 
-        auto cnnlUnaryForward = [this](cnnlHandle_t handle,
-                                       const cnnlTensorDescriptor_t x_desc,
-                                       const void *x,
-                                       const cnnlTensorDescriptor_t y_desc,
-                                       void *y) -> cnnlStatus_t {
-            switch (this->type) {
+        auto cnnlUnaryForward = [t = this->type](cnnlHandle_t handle,
+                                                 const cnnlTensorDescriptor_t x_desc,
+                                                 const void *x,
+                                                 const cnnlTensorDescriptor_t y_desc,
+                                                 void *y) -> cnnlStatus_t {
+            switch (t) {
             case Ty::Abs:
                 return cnnlAbs(handle, x_desc, x, y_desc, y);
             case Ty::Neg:
@@ -77,6 +77,7 @@ namespace refactor::kernel {
             case Ty::Erf:
                 return cnnlErf_v2(handle, CNNL_COMPUTATION_HIGH_PRECISION, x_desc, x, y_desc, y);
             default:
+                // fmt::println("{}", unaryName(t));
                 UNREACHABLE();
             }
         };
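
Replacing [this] with the init-capture [t = this->type] copies the enum value into the lambda, so the callable no longer depends on the kernel object staying alive. A minimal sketch of the difference (hypothetical Kernel stand-in): with [this] the lambda stores a pointer that dangles once the object is destroyed, while the by-value capture keeps its own copy.

    #include <cstdio>
    #include <functional>

    struct Kernel { // hypothetical stand-in
        int type = 7;
        std::function<int()> byThis()  { return [this] { return type; }; }  // dangles if Kernel dies first
        std::function<int()> byValue() { return [t = type] { return t; }; } // owns a copy of `type`
    };

    int main() {
        std::function<int()> f;
        {
            Kernel k;
            f = k.byValue();
        }                         // k destroyed here
        std::printf("%d\n", f()); // fine: prints 7
        return 0;
    }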
