forked from alibaba/MNN
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathSpatialProductExecution.cpp
More file actions
92 lines (77 loc) · 3.09 KB
/
Copy pathSpatialProductExecution.cpp
File metadata and controls
92 lines (77 loc) · 3.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
//
// SpatialProductExecution.cpp
// MNN
//
// Created by MNN on 2019/02/28.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include "SpatialProductExecution.hpp"
#include <Macro.h>
#include "TensorUtils.hpp"
namespace MNN {
namespace OpenCL {
// Constructs the OpenCL SpatialProduct execution.
// Only the backend is captured here: the kernel itself is built lazily in
// onResize, and kernel arguments are bound lazily in onExecute (tracked by
// mAreadySetArg). `inputs` and `op` are accepted for the creator interface but
// are not read by this constructor.
SpatialProductExecution::SpatialProductExecution(const std::vector<Tensor *> &inputs, const MNN::Op *op,
                                                 Backend *backend)
    : Execution(backend) {
#ifdef LOG_VERBOSE
    MNN_PRINT("start SpatialProductExecution init !\n");
#endif
    // No kernel arguments have been bound yet.
    mAreadySetArg = false;
    // Keep a typed handle so later stages can reach the OpenCL runtime.
    mOpenCLBackend = static_cast<OpenCLBackend *>(backend);
#ifdef LOG_VERBOSE
    MNN_PRINT("end SpatialProductExecution init !\n");
#endif
}
// Prepares the execution for the current input/output shapes.
// - Builds the "spatial_product" OpenCL kernel on first call and caches the
//   maximum work-group size the runtime reports for it.
// - Invalidates the kernel arguments cached by onExecute. onExecute binds the
//   global work size and the input/output images only while mAreadySetArg is
//   false, so without this reset a reshape would reuse a stale NDRange and
//   stale image handles.
// Returns NO_ERROR unconditionally.
ErrorCode SpatialProductExecution::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    auto runtime = mOpenCLBackend->getOpenCLRuntime();
    if (mKernel.get() == nullptr) {
        std::set<std::string> buildOptions;
        std::string kernelName = "spatial_product";
        mKernel = runtime->buildKernel("spatial_product", kernelName, buildOptions);
        mMaxWorkGroupSize = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(mKernel));
    }
    // Shapes (and possibly the backing images) may have changed since the last
    // execute; force onExecute to recompute and rebind its kernel arguments.
    mAreadySetArg = false;
    return NO_ERROR;
}
// Launches the spatial_product kernel.
// On the first call after construction, and after each onResize if onResize
// clears mAreadySetArg, the kernel arguments are bound from the current
// tensors:
//   global size = {ceil(channels / 4), outputWidth, batch * outputHeight}
//   args        = global sizes, inputs[0] image, inputs[1] image,
//                 outputHeight, outputs[0] image
// Later calls reuse those bindings and only enqueue the kernel.
// Returns NO_ERROR unconditionally.
ErrorCode SpatialProductExecution::onExecute(const std::vector<Tensor *> &inputs,
                                             const std::vector<Tensor *> &outputs) {
#ifdef LOG_VERBOSE
    MNN_PRINT("start SpatialProductExecution onExecute !\n");
#endif
    auto runtime = mOpenCLBackend->getOpenCLRuntime();
    if (!mAreadySetArg) {
        Tensor *input  = inputs[0];
        Tensor *input1 = inputs[1];
        Tensor *output = outputs[0];
        // Only the output shape determines the launch geometry; indices 0..3 are
        // read as batch, height, width, channels.
        const std::vector<int> outputShape = tensorShapeFormat(output);
        const int batch         = outputShape.at(0);
        const int outputHeight  = outputShape.at(1);
        const int outputWidth   = outputShape.at(2);
        const int channels      = outputShape.at(3);
        const int channelBlocks = (channels + 3) / 4; // channels handled in groups of 4
        mGlobalWorkSize = {
            static_cast<uint32_t>(channelBlocks),
            static_cast<uint32_t>(outputWidth),
            static_cast<uint32_t>(batch * outputHeight),
        };
        uint32_t idx = 0;
        mKernel.setArg(idx++, mGlobalWorkSize[0]);
        mKernel.setArg(idx++, mGlobalWorkSize[1]);
        mKernel.setArg(idx++, mGlobalWorkSize[2]);
        mKernel.setArg(idx++, openCLImage(input));
        mKernel.setArg(idx++, openCLImage(input1));
        mKernel.setArg(idx++, static_cast<int>(outputHeight));
        mKernel.setArg(idx++, openCLImage(output));
        mAreadySetArg = true;
    }
    const std::vector<uint32_t> lws = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, runtime);
    run3DKernelDefault(mKernel, mGlobalWorkSize, lws, runtime);
#ifdef LOG_VERBOSE
    MNN_PRINT("end SpatialProductExecution onExecute !\n");
#endif
    return NO_ERROR;
}
// Namespace-scope registration object. It is constructed during static
// initialization and presumably registers SpatialProductExecution as the OpenCL
// creator for OpType_SpatialProduct (via OpenCLCreatorRegister's constructor;
// verify against its definition).
// NOTE(review): identifiers containing "__" are reserved in C++; consider
// renaming if nothing else references this symbol.
OpenCLCreatorRegister<TypedCreator<SpatialProductExecution>> __spatial_product_op(OpType_SpatialProduct);
} // namespace OpenCL
} // namespace MNN