-
Notifications
You must be signed in to change notification settings - Fork 528
/
Copy pathtensor_accessor.h
218 lines (194 loc) · 6.34 KB
/
tensor_accessor.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <cstddef>
#include <type_traits>

#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/result.h>
namespace executorch {
namespace extension {
namespace internal {
/**
* Base class template storing the underlying data with size and stride helpers.
* Inherited by TensorAccessor<> which requires specialization on rank.
*/
template <typename T, ssize_t N>
class TensorAccessorBase {
public:
/// Returns the size of the underlying tensor at the given dimension.
executorch::aten::SizesType size(ssize_t i) const {
ET_CHECK_MSG(
i < dim_ && i >= 0,
"Dimension outside of [0, %zd], got %zd",
dim_ - 1,
i);
return sizes_[i];
}
/// Returns the stride of the underlying tensor at the given dimension.
executorch::aten::StridesType stride(ssize_t i) const {
ET_CHECK_MSG(
i < dim_ && i >= 0,
"Dimension outside of [0, %zd], got %zd",
dim_ - 1,
i);
return strides_[i];
}
protected:
TensorAccessorBase(
T* data,
const executorch::aten::SizesType* sizes,
const executorch::aten::StridesType* strides,
ssize_t dim)
: data_(data), sizes_(sizes), strides_(strides), dim_(dim) {}
T* data_;
const executorch::aten::SizesType* sizes_;
const executorch::aten::StridesType* strides_;
ssize_t dim_;
};
} // namespace internal
/**
 * Rank-N accessor over tensor data, parameterized on element type T and rank
 * N. All constructors are private: instances are produced either by
 * make_tensor_accessor() (a friend) or by indexing a higher-rank accessor.
 * operator[] strips the outermost dimension and yields a rank N-1 accessor;
 * the TensorAccessor<T, 1> specialization yields the scalar itself.
 */
template <typename T, ssize_t N>
class TensorAccessor : public internal::TensorAccessorBase<T, N> {
  using Base = internal::TensorAccessorBase<T, N>;
  using SubAccessor = TensorAccessor<T, N - 1>;

 public:
  /**
   * Indexes into the outermost dimension. No bounds check is performed on i.
   *
   * @param i Index along the outermost dimension.
   * @return A TensorAccessor with N-1 dimensions positioned at index i. For
   * N == 1 see the TensorAccessor<T, 1> specialization, which returns a
   * reference to the scalar instead.
   */
  SubAccessor operator[](ssize_t i) {
    return sub_accessor(i);
  }

  /**
   * Indexes into the outermost dimension of a constant accessor. No bounds
   * check is performed on i.
   *
   * @param i Index along the outermost dimension.
   * @return A constant TensorAccessor with N-1 dimensions positioned at index
   * i. For N == 1 see the TensorAccessor<T, 1> specialization, which returns
   * a constant reference to the scalar instead.
   */
  const SubAccessor operator[](ssize_t i) const {
    return sub_accessor(i);
  }

 private:
  TensorAccessor(
      T* data,
      const executorch::aten::SizesType* sizes,
      const executorch::aten::StridesType* strides,
      ssize_t dim)
      : Base(data, sizes, strides, dim) {}

  /// Builds the rank N-1 accessor for index i of the outermost dimension:
  /// the data pointer advances by i outermost strides, and the sizes/strides
  /// arrays are advanced past their first entry.
  SubAccessor sub_accessor(ssize_t i) const {
    T* const sub_data = this->data_ + this->strides_[0] * i;
    return SubAccessor(sub_data, this->sizes_ + 1, this->strides_ + 1, N - 1);
  }

  // Accessors of other ranks need the private constructor to build their
  // sub-accessors.
  template <typename T2, ssize_t N2>
  friend class TensorAccessor;

  // The factory function is the only external way to create an accessor.
  template <typename T2, ssize_t N2>
  friend executorch::runtime::Result<TensorAccessor<T2, N2>>
  make_tensor_accessor(const executorch::aten::Tensor& t);
};
/**
 * Rank-1 specialization of TensorAccessor. Indexing ends the recursion here:
 * operator[] hands back a reference to the scalar element rather than another
 * accessor.
 */
template <typename T>
class TensorAccessor<T, 1> : public internal::TensorAccessorBase<T, 1> {
 public:
  /**
   * Indexes into the single remaining dimension. No bounds check is performed
   * on i.
   *
   * @param i Element index.
   * @return Reference to the scalar located i strides past the data pointer.
   */
  T& operator[](ssize_t i) {
    return *(this->data_ + this->strides_[0] * i);
  }

  /**
   * Indexes into the single remaining dimension of a constant accessor. No
   * bounds check is performed on i.
   *
   * @param i Element index.
   * @return Constant reference to the scalar located i strides past the data
   * pointer.
   */
  const T& operator[](ssize_t i) const {
    return *(this->data_ + this->strides_[0] * i);
  }

 private:
  TensorAccessor(
      T* data,
      const executorch::aten::SizesType* sizes,
      const executorch::aten::StridesType* strides,
      ssize_t dim)
      : internal::TensorAccessorBase<T, 1>(data, sizes, strides, dim) {}

  // Higher-rank accessors construct this specialization when indexed.
  template <typename T2, ssize_t N2>
  friend class TensorAccessor;

  // The factory function is the only external way to create an accessor.
  template <typename T2, ssize_t N2>
  friend executorch::runtime::Result<TensorAccessor<T2, N2>>
  make_tensor_accessor(const executorch::aten::Tensor& t);
};
/**
 * Creates a TensorAccessor<T, N> from the given tensor. The number of
 * dimensions N and the size of the data type T must match those of the input
 * tensor. For ExecuTorch tensors, non-trivial dimension order is not supported.
 *
 * Only sizeof(T) is compared against the tensor's element size; the tensor's
 * scalar type itself is not inspected, so choosing a T of the right width but
 * the wrong type is not detected here. The dimension-order check is compiled
 * only when USE_ATEN_LIB is not defined.
 *
 * @param tensor Origin tensor. The TensorImpl inside must outlive the returned
 * TensorAccessor, which keeps raw pointers to the tensor's data, sizes and
 * strides.
 * @return TensorAccessor of the input tensor.
 * @retval Error::InvalidArgument Mismatch on data type size or number of
 * dimensions.
 * @retval Error::NotSupported Input tensor has non-trivial dimension order.
 */
template <typename T, ssize_t N>
executorch::runtime::Result<TensorAccessor<T, N>> make_tensor_accessor(
    const executorch::aten::Tensor& tensor) {
  static_assert(
      N > 0,
      "TensorAccessor is used for indexing tensors, for scalar use *_data_ptr<T>()");

  if (N != tensor.dim()) {
    ET_LOG(
        Error,
        "Expecting %zd dimensions but tensor has %zd.",
        static_cast<ssize_t>(N),
        static_cast<ssize_t>(tensor.dim()));
    return executorch::runtime::Error::InvalidArgument;
  }

  if (sizeof(T) != tensor.element_size()) {
    ET_LOG(
        Error,
        "Size of data type template argument (%zd) not equal to tensor element size (%zd)",
        static_cast<ssize_t>(sizeof(T)),
        static_cast<ssize_t>(tensor.element_size()));
    return executorch::runtime::Error::InvalidArgument;
  }

#ifndef USE_ATEN_LIB
  // A trivial dim_order is the identity permutation 0, 1, ..., dim - 1.
  // The index is unsigned to match dim_order.size() and avoid a
  // signed/unsigned comparison.
  const auto dim_order = tensor.dim_order();
  for (size_t i = 0; i < dim_order.size(); ++i) {
    if (static_cast<size_t>(dim_order[i]) != i) {
      ET_LOG(Error, "Non-trivial dim_order not supported.");
      return executorch::runtime::Error::NotSupported;
    }
  }
#endif

  // Pick the data-pointer getter that matches the constness of T so that a
  // const element type never requests mutable access to the tensor.
  T* ptr = nullptr;
  if constexpr (std::is_const_v<T>) {
    ptr = tensor.const_data_ptr<T>();
  } else {
    ptr = tensor.mutable_data_ptr<T>();
  }

  return TensorAccessor<T, N>(
      ptr, tensor.sizes().data(), tensor.strides().data(), N);
}
} // namespace extension
} // namespace executorch