forked from alibaba/MNN
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathBatchToSpaceNDTf.cpp
More file actions
79 lines (62 loc) · 2.91 KB
/
Copy pathBatchToSpaceNDTf.cpp
File metadata and controls
79 lines (62 loc) · 2.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
//
// BatchToSpaceNDTf.cpp
// MNNConverter
//
// Created by MNN on 2019/01/31.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include "TfUtils.hpp"
#include "tfOpConverter.hpp"
#include "graph.pb.h"
// Declares the BatchToSpaceNDTf converter type; its opType(), type() and run()
// members are defined below.
DECLARE_OP_CONVERTER(BatchToSpaceNDTf);
// The MNN operator kind produced by this converter.
MNN::OpType BatchToSpaceNDTf::opType() {
    const MNN::OpType producedOp = MNN::OpType_BatchToSpaceND;
    return producedOp;
}
// The MNN parameter union member filled in by run(): a SpaceBatchT.
MNN::OpParameter BatchToSpaceNDTf::type() {
    const MNN::OpParameter paramKind = MNN::OpParameter_SpaceBatch;
    return paramKind;
}
void BatchToSpaceNDTf::run(MNN::OpT *dstOp, TmpNode *srcNode, TmpGraph *tempGraph) {
DCHECK(srcNode->inEdges.size() == 3) << "BatchToSpaceND Input Node's Num ERROR";
auto spacebatch = new MNN::SpaceBatchT;
auto block_shape = new MNN::BlobT;
block_shape->dataType = MNN::DataType_DT_INT32;
auto paddings = new MNN::BlobT;
paddings->dataType = MNN::DataType_DT_INT32;
tensorflow::AttrValue weightsValue;
if (find_attr_value(srcNode->tfNode, "Tblock_shape", weightsValue)) {
block_shape->dataType = static_cast<MNN::DataType>(weightsValue.type());
}
if (find_attr_value(srcNode->tfNode, "Tpaddings", weightsValue)) {
paddings->dataType = static_cast<MNN::DataType>(weightsValue.type());
}
DCHECK(block_shape->dataType == MNN::DataType_DT_INT32) << "BlockShape Data Type ERROR!";
DCHECK(paddings->dataType == MNN::DataType_DT_INT32) << "BlockShape Data Type ERROR!";
auto blockShapeTensor = tempGraph->_getTmpNode(srcNode->inEdges[1]);
find_attr_value(blockShapeTensor->tfNode, "value", weightsValue);
const auto dimSize = weightsValue.tensor().tensor_shape().dim_size();
block_shape->dims.resize(dimSize);
int dataSize = 1;
for (int i = 0; i < dimSize; ++i) {
dataSize *= weightsValue.tensor().tensor_shape().dim(i).size();
block_shape->dims[i] = weightsValue.tensor().tensor_shape().dim(i).size();
}
auto tensor_content = reinterpret_cast<const int *>(weightsValue.tensor().tensor_content().data());
block_shape->int32s.resize(dataSize);
::memcpy(block_shape->int32s.data(), tensor_content, sizeof(int) * dataSize);
auto paddingTensor = tempGraph->_getTmpNode(srcNode->inEdges[2]);
find_attr_value(paddingTensor->tfNode, "value", weightsValue);
const auto dim = weightsValue.tensor().tensor_shape().dim_size();
paddings->dims.resize(dim);
dataSize = 1;
for (int i = 0; i < dim; ++i) {
dataSize *= weightsValue.tensor().tensor_shape().dim(i).size();
paddings->dims[i] = weightsValue.tensor().tensor_shape().dim(i).size();
}
auto paddingData = reinterpret_cast<const int *>(weightsValue.tensor().tensor_content().data());
paddings->int32s.resize(dataSize);
::memcpy(paddings->int32s.data(), paddingData, sizeof(int) * dataSize);
spacebatch->blockShape = std::unique_ptr<MNN::BlobT>(block_shape);
spacebatch->padding = std::unique_ptr<MNN::BlobT>(paddings);
dstOp->main.value = spacebatch;
}
// Registers BatchToSpaceNDTf with the converter registry; presumably keyed by
// the TensorFlow op name "BatchToSpaceND" — confirm against REGISTER_CONVERTER.
REGISTER_CONVERTER(BatchToSpaceNDTf, BatchToSpaceND);