// Extracted from a repository forked from alibaba/MNN.
// File: ReshapeTf.cpp, 58 lines (49 loc), 1.71 KB.
//
// ReshapeTf.cpp
// MNNConverter
//
// Created by MNN on 2019/01/31.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include <string.h>
#include "TfUtils.hpp"
#include "tfOpConverter.hpp"
#include "graph.pb.h"
// Invokes the project macro DECLARE_OP_CONVERTER for ReshapeTf. The macro is defined outside this
// file; presumably it declares the ReshapeTf class whose opType(), type() and run() members are
// defined below -- confirm against the macro definition.
DECLARE_OP_CONVERTER(ReshapeTf);
// Reports the MNN operator type associated with this converter.
// @return always MNN::OpType_Reshape.
MNN::OpType ReshapeTf::opType() {
    const MNN::OpType convertedType = MNN::OpType_Reshape;
    return convertedType;
}
// Reports which MNN::OpParameter kind this converter uses for the op's main parameter.
// @return always MNN::OpParameter_Reshape.
MNN::OpParameter ReshapeTf::type() {
    const MNN::OpParameter parameterKind = MNN::OpParameter_Reshape;
    return parameterKind;
}
// Builds the MNN::ReshapeT parameter for a TensorFlow Reshape node and stores it in
// dstOp->main.value.
//
// The node named by srcNode->inEdges[1] supplies the target shape:
//   - If that node's opType is not "Const", a default-constructed ReshapeT is stored and nothing
//     else is filled in.
//   - If it is "Const" and carries a "value" attribute, the tensor must be DT_INT32 (enforced with
//     CHECK). dimType is set to NHWC and dims is filled from either the raw `tensor_content` bytes
//     or, when that is empty, from every entry of the repeated `int_val` field.
//
// @param dstOp     destination op; receives the raw ReshapeT pointer in main.value
//                  (presumably dstOp is responsible for it afterwards -- confirm in MNN::OpT).
// @param srcNode   the Reshape node being converted; its second input edge names the shape node.
// @param tempGraph graph used to look up the shape node by name.
void ReshapeTf::run(MNN::OpT *dstOp, TmpNode *srcNode, TmpGraph *tempGraph) {
    auto reshape = new MNN::ReshapeT;

    TmpNode *shapeNode = tempGraph->_getTmpNode(srcNode->inEdges[1]);
    if (shapeNode->opType != "Const") {
        // Shape is not a compile-time constant: leave dims empty.
        dstOp->main.value = reshape;
        return;
    }

    // Const Shape
    tensorflow::AttrValue value;
    if (find_attr_value(shapeNode->tfNode, "value", value)) {
        const auto &shapeTensor = value.tensor();
        const MNN::DataType dataType = static_cast<MNN::DataType>(shapeTensor.dtype());
        CHECK(dataType == MNN::DataType_DT_INT32) << "Shape Dtype ERROR" << srcNode->opName;
        reshape->dimType = MNN::MNN_DATA_FORMAT_NHWC;

        const auto &content = shapeTensor.tensor_content();
        if (!content.empty()) {
            // Packed int32 values. A one-element shape is valid here, so only reject content that
            // is too short to hold a single int.
            const int size = static_cast<int>(content.size() / sizeof(int));
            CHECK(size >= 1) << "Shape Data ERROR!!! ===> " << srcNode->opName;
            reshape->dims.resize(size);
            for (int i = 0; i < size; ++i) {
                // memcpy instead of reinterpret_cast<const int *>: the string buffer is not
                // guaranteed to be int-aligned, and reading it through an int pointer would also
                // violate strict aliasing.
                int dim = 0;
                memcpy(&dim, content.data() + i * sizeof(int), sizeof(int));
                reshape->dims[i] = dim;
            }
        } else {
            // Values stored in the repeated int_val field. Copy every entry instead of blindly
            // reading index 0: with an empty field that read is out of range, and with several
            // entries it would drop all but the first. An empty field leaves dims empty, the same
            // state as the non-Const path above.
            // NOTE(review): a compact encoding with fewer int_val entries than tensor elements is
            // not expanded here -- confirm whether such shape constants occur in practice.
            const int count = shapeTensor.int_val_size();
            reshape->dims.resize(count);
            for (int i = 0; i < count; ++i) {
                reshape->dims[i] = shapeTensor.int_val(i);
            }
        }
    }
    dstOp->main.value = reshape;
}
// Invokes the project macro REGISTER_CONVERTER with ReshapeTf and the name Reshape. Presumably this
// registers ReshapeTf as the converter for TensorFlow nodes of op type "Reshape" -- confirm against
// the macro definition, which is outside this file.
REGISTER_CONVERTER(ReshapeTf, Reshape);