forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathUpSample.h
501 lines (449 loc) · 18.5 KB
/
UpSample.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
#pragma once
#include <math.h>
#include <cmath>

#include <ATen/OpMathType.h>
#include <ATen/TensorUtils.h>
#include <ATen/core/Tensor.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/DispatchStub.h>
/**
* Note [compute_scales_value]
* Note [area_pixel_compute_scale]
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
* Interpolation with scale_factor can behave differently
* depending on the value of recompute_scale_factor:
*
* - With recompute_scale_factor = True (current default behavior):
* the scale_factor, when provided by the user, is used to calculate
* the output size. The input size and the computed output_size
* are then used to infer new values for the scales, which are
* used in the interpolation. Because floating-point math is not exact,
* this may be a different value from the user-supplied scales.
*
* - With recompute_scale_factor = False (which will be the default
* behavior starting 1.5.0):
* the behavior follows opencv logic, and the scales provided by
* the user are the ones used in the interpolation calculations.
*
* If the scales are not provided, or if they are provided but
* recompute_scale_factor is set to True (default behavior), the scales
* are computed from the input and the output size.
*
*
* When the scales are inferred from the input and output sizes,
* we view each pixel as an area, with idx + 0.5 as its center index.
* Here is an example formula in the 1D case.
* if align_corners: the centers of the two corner pixel areas are preserved,
*     0.5 -> 0.5,
*     (input_size - 0.5) -> (output_size - 0.5)
*     scale = (input_size - 0.5 - 0.5) / (output_size - 0.5 - 0.5)
*     src_index + 0.5 - 0.5 = scale * (dst_index + 0.5 - 0.5)
* if not align_corners: the whole range is scaled accordingly
*     scale = input_size / output_size
*     src_idx + 0.5 = scale * (dst_index + 0.5)
*/
namespace at::native {
namespace upsample {

// Computes the output size of an upsampling op from the full input tensor
// size and either an explicit output_size or a list of scale factors.
// Declared here, defined out of line (definition not visible in this header).
TORCH_API c10::SmallVector<int64_t, 3> compute_output_size(
    c10::IntArrayRef input_size,  // Full input tensor size.
    at::OptionalIntArrayRef output_size,
    c10::optional<c10::ArrayRef<double>> scale_factors);

// Returns the idx-th scale factor when a scale list was supplied, and
// nullopt otherwise. Element access goes through ArrayRef::at.
inline c10::optional<double> get_scale_value(c10::optional<c10::ArrayRef<double>> scales, int idx) {
  if (scales.has_value()) {
    return scales->at(idx);
  }
  return c10::nullopt;
}

} // namespace upsample
// Optional user-supplied scale factor for a single resized dimension.
using scale_t = c10::optional<double>;

// Kernel function-pointer signatures. Every kernel receives (output, input),
// the interpolating families additionally receive align_corners, and each
// resized dimension contributes one scale, ordered depth, height, width.

// Nearest / nearest-exact.
using upsampling_nearest1d =
    void (*)(const Tensor& output, const Tensor& input, scale_t scales_w);
using _upsampling_nearest_exact1d =
    void (*)(const Tensor& output, const Tensor& input, scale_t scales_w);
using upsampling_nearest2d =
    void (*)(const Tensor& output, const Tensor& input, scale_t scales_h, scale_t scales_w);
using _upsampling_nearest_exact2d =
    void (*)(const Tensor& output, const Tensor& input, scale_t scales_h, scale_t scales_w);
using upsampling_nearest3d =
    void (*)(const Tensor& output, const Tensor& input,
             scale_t scales_d, scale_t scales_h, scale_t scales_w);
using _upsampling_nearest_exact3d =
    void (*)(const Tensor& output, const Tensor& input,
             scale_t scales_d, scale_t scales_h, scale_t scales_w);

// Linear family (linear / bilinear / trilinear, plus bilinear anti-aliased).
using upsampling_linear1d =
    void (*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_w);
using upsampling_bilinear2d =
    void (*)(const Tensor& output, const Tensor& input, bool align_corners,
             scale_t scales_h, scale_t scales_w);
using _upsampling_bilinear2d_aa =
    void (*)(const Tensor& output, const Tensor& input, bool align_corners,
             scale_t scales_h, scale_t scales_w);
using upsampling_trilinear3d =
    void (*)(const Tensor& output, const Tensor& input, bool align_corners,
             scale_t scales_d, scale_t scales_h, scale_t scales_w);

// Bicubic (plain and anti-aliased).
using upsampling_bicubic2d =
    void (*)(const Tensor& output, const Tensor& input, bool align_corners,
             scale_t scales_h, scale_t scales_w);
using _upsampling_bicubic2d_aa =
    void (*)(const Tensor& output, const Tensor& input, bool align_corners,
             scale_t scales_h, scale_t scales_w);
// Dispatch stubs for the upsampling kernels (implementations are registered
// elsewhere). Each backward stub reuses the function-pointer type of its
// forward counterpart; what `output`/`input` denote for a backward kernel is
// defined by the kernel implementations, not visible here.

// Nearest / nearest-exact, forward.
DECLARE_DISPATCH(upsampling_nearest1d, upsample_nearest1d_kernel);
DECLARE_DISPATCH(_upsampling_nearest_exact1d, _upsample_nearest_exact1d_kernel);
DECLARE_DISPATCH(upsampling_nearest2d, upsample_nearest2d_kernel);
DECLARE_DISPATCH(_upsampling_nearest_exact2d, _upsample_nearest_exact2d_kernel);
DECLARE_DISPATCH(upsampling_nearest3d, upsample_nearest3d_kernel);
DECLARE_DISPATCH(_upsampling_nearest_exact3d, _upsample_nearest_exact3d_kernel);
// Nearest / nearest-exact, backward.
DECLARE_DISPATCH(upsampling_nearest1d, upsample_nearest1d_backward_kernel);
DECLARE_DISPATCH(_upsampling_nearest_exact1d, _upsample_nearest_exact1d_backward_kernel);
DECLARE_DISPATCH(upsampling_nearest2d, upsample_nearest2d_backward_kernel);
DECLARE_DISPATCH(_upsampling_nearest_exact2d, _upsample_nearest_exact2d_backward_kernel);
DECLARE_DISPATCH(upsampling_nearest3d, upsample_nearest3d_backward_kernel);
DECLARE_DISPATCH(_upsampling_nearest_exact3d, _upsample_nearest_exact3d_backward_kernel);
// Linear family, forward.
DECLARE_DISPATCH(upsampling_linear1d, upsample_linear1d_kernel);
DECLARE_DISPATCH(upsampling_bilinear2d, upsample_bilinear2d_kernel);
DECLARE_DISPATCH(_upsampling_bilinear2d_aa, _upsample_bilinear2d_aa_kernel);
DECLARE_DISPATCH(upsampling_trilinear3d, upsample_trilinear3d_kernel);
// Linear family, backward.
DECLARE_DISPATCH(upsampling_linear1d, upsample_linear1d_backward_kernel);
DECLARE_DISPATCH(upsampling_bilinear2d, upsample_bilinear2d_backward_kernel);
DECLARE_DISPATCH(_upsampling_bilinear2d_aa, _upsample_bilinear2d_aa_backward_kernel);
DECLARE_DISPATCH(upsampling_trilinear3d, upsample_trilinear3d_backward_kernel);
// Bicubic. Only the anti-aliased variant declares a backward stub here.
DECLARE_DISPATCH(upsampling_bicubic2d, upsample_bicubic2d_kernel);
DECLARE_DISPATCH(_upsampling_bicubic2d_aa, _upsample_bicubic2d_aa_kernel);
DECLARE_DISPATCH(_upsampling_bicubic2d_aa, _upsample_bicubic2d_aa_backward_kernel);
// Validates the sizes for a 1D upsample and returns the output shape.
//   input_size:  {nbatch, channels, input_width} (must have 3 entries)
//   output_size: {output_width}                  (must have 1 entry)
// Both widths must be positive. Returns {nbatch, channels, output_width}.
static C10_UNUSED std::array<int64_t, 3> upsample_1d_common_check(IntArrayRef input_size, IntArrayRef output_size) {
  // Rank checks come first so the indexing below is always in range.
  TORCH_CHECK(
      output_size.size() == 1,
      "It is expected output_size equals to 1, but got size ",
      output_size.size());
  TORCH_CHECK(
      input_size.size() == 3,
      "It is expected input_size equals to 3, but got size ",
      input_size.size());

  const int64_t batch = input_size[0];
  const int64_t chans = input_size[1];
  const int64_t in_w = input_size[2];
  const int64_t out_w = output_size[0];

  TORCH_CHECK(
      input_width_positive_and(in_w, out_w),
      "Input and output sizes should be greater than 0, but got input (W: ",
      in_w,
      ") and output (W: ",
      out_w,
      ")");

  std::array<int64_t, 3> out_shape{batch, chans, out_w};
  return out_shape;
}
// Validates the sizes for a 2D upsample and returns the output shape.
//   input_size:  {nbatch, channels, input_height, input_width} (4 entries)
//   output_size: {output_height, output_width}                 (2 entries)
// All spatial sizes must be positive.
// Returns {nbatch, channels, output_height, output_width}.
static C10_UNUSED std::array<int64_t, 4> upsample_2d_common_check(IntArrayRef input_size, IntArrayRef output_size) {
  // Rank checks come first so the indexing below is always in range.
  TORCH_CHECK(
      output_size.size() == 2,
      "It is expected output_size equals to 2, but got size ",
      output_size.size());
  TORCH_CHECK(
      input_size.size() == 4,
      "It is expected input_size equals to 4, but got size ",
      input_size.size());

  const int64_t out_h = output_size[0];
  const int64_t out_w = output_size[1];
  const int64_t batch = input_size[0];
  const int64_t chans = input_size[1];
  const int64_t in_h = input_size[2];
  const int64_t in_w = input_size[3];

  const bool all_positive = in_h > 0 && in_w > 0 && out_h > 0 && out_w > 0;
  TORCH_CHECK(
      all_positive,
      "Input and output sizes should be greater than 0,"
      " but got input (H: ",
      in_h,
      ", W: ",
      in_w,
      ") output (H: ",
      out_h,
      ", W: ",
      out_w,
      ")");

  std::array<int64_t, 4> out_shape{batch, chans, out_h, out_w};
  return out_shape;
}
// Validates the sizes for a 3D upsample and returns the output shape.
//   input_size:  {nbatch, channels, input_depth, input_height, input_width}
//   output_size: {output_depth, output_height, output_width}
// All spatial sizes must be positive.
// Returns {nbatch, channels, output_depth, output_height, output_width}.
static C10_UNUSED std::array<int64_t, 5> upsample_3d_common_check(IntArrayRef input_size, IntArrayRef output_size) {
  // Rank checks come first so the indexing below is always in range.
  TORCH_CHECK(
      output_size.size() == 3,
      "It is expected output_size equals to 3, but got size ",
      output_size.size());
  TORCH_CHECK(
      input_size.size() == 5,
      "It is expected input_size equals to 5, but got size ",
      input_size.size());

  const int64_t out_d = output_size[0];
  const int64_t out_h = output_size[1];
  const int64_t out_w = output_size[2];
  const int64_t batch = input_size[0];
  const int64_t chans = input_size[1];
  const int64_t in_d = input_size[2];
  const int64_t in_h = input_size[3];
  const int64_t in_w = input_size[4];

  const bool inputs_positive = in_d > 0 && in_h > 0 && in_w > 0;
  const bool outputs_positive = out_d > 0 && out_h > 0 && out_w > 0;
  TORCH_CHECK(
      inputs_positive && outputs_positive,
      "Input and output sizes should be greater than 0, but got input (D: ",
      in_d,
      ", H: ",
      in_h,
      ", W: ",
      in_w,
      ") output (D: ",
      out_d,
      ", H: ",
      out_h,
      ", W: ",
      out_w,
      ")");

  std::array<int64_t, 5> out_shape{batch, chans, out_d, out_h, out_w};
  return out_shape;
}
// Shape validation shared by 2D upsampling forward/backward.
// - All of input_height/width and output_height/width must be positive.
// - If `input` is defined, it must be 4D. A zero batch is allowed, but the
//   other dimensions must be non-zero when the tensor is empty.
// - Otherwise, if `grad_output` is defined, it must have shape
//   (nbatch, nchannels, output_height, output_width).
static inline void upsample_2d_shape_check(
    const Tensor& input,
    const Tensor& grad_output,
    int64_t nbatch,
    int64_t nchannels,
    int64_t input_height,
    int64_t input_width,
    int64_t output_height,
    int64_t output_width) {
  TORCH_CHECK(
      input_height > 0 && input_width > 0 && output_height > 0 &&
          output_width > 0,
      "Input and output sizes should be greater than 0,"
      " but got input (H: ",
      input_height,
      ", W: ",
      input_width,
      ") output (H: ",
      output_height,
      ", W: ",
      output_width,
      ")");

  if (input.defined()) {
    // Allow for empty batch size but not other dimensions.
    // The rank is tested first: `&&` short-circuits, so size(1..3) is only
    // queried on a 4D tensor. Testing it last let an empty tensor with fewer
    // than 4 dims reach input.size(3) and fail with a dimension-out-of-range
    // error instead of the message below.
    TORCH_CHECK(
        input.dim() == 4 &&
            (input.numel() != 0 ||
             (input.size(1) != 0 && input.size(2) != 0 &&
              input.size(3) != 0)),
        "Non-empty 4D data tensor expected but got a tensor with sizes ",
        input.sizes());
  } else if (grad_output.defined()) {
    check_dim_size(grad_output, 4, 0, nbatch);
    check_dim_size(grad_output, 4, 1, nchannels);
    check_dim_size(grad_output, 4, 2, output_height);
    check_dim_size(grad_output, 4, 3, output_width);
  }
}
// Returns the output->input coordinate scale for one dimension.
// See Note [compute_scales_value].
//   scale:       optional user-supplied scale factor (output / input).
//   input_size:  size of the dimension in the input.
//   output_size: size of the dimension in the output.
// A present, strictly positive `scale` is used through its reciprocal.
// Otherwise the scale is derived as input_size / output_size.
template <typename scalar_t>
static inline scalar_t compute_scales_value(
    const c10::optional<double> scale,
    int64_t input_size,
    int64_t output_size) {
  // FIXME: remove magic > 0 after we ensure no models were serialized with -1 defaults.
  const bool use_user_scale = scale.has_value() && scale.value() > 0.;
  if (!use_user_scale) {
    return static_cast<scalar_t>(input_size) / output_size;
  }
  return static_cast<scalar_t>(1.0 / scale.value());
}
// Returns the scale that maps output pixel indices to input pixel indices.
// See Note [area_pixel_compute_scale].
// - align_corners: ratio of the spans between the first and last pixel
//   centers, (input_size - 1) / (output_size - 1). A single-pixel (or empty)
//   output has no such span and gets scale 0.
// - otherwise: defers to compute_scales_value (user scale or size ratio).
template <typename scalar_t>
static inline scalar_t area_pixel_compute_scale(
    int64_t input_size,
    int64_t output_size,
    bool align_corners,
    const c10::optional<double> scale) {
  if (!align_corners) {
    return compute_scales_value<scalar_t>(scale, input_size, output_size);
  }
  if (output_size <= 1) {
    return static_cast<scalar_t>(0);
  }
  return static_cast<scalar_t>(input_size - 1) / (output_size - 1);
}
// Maps an output index to a (fractional) source index.
//   scale:         output->input scale (see area_pixel_compute_scale).
//   dst_index:     output index.
//   align_corners: if true, src = scale * dst.
//                  Otherwise pixel centers are mapped:
//                  src + 0.5 = scale * (dst + 0.5).
//   cubic:         if false, negative results are clamped to 0.
//                  If true, they are returned as-is.
template <typename scalar_t>
static inline scalar_t area_pixel_compute_source_index(
    scalar_t scale,
    int64_t dst_index,
    bool align_corners,
    bool cubic) {
  if (align_corners) {
    return scale * dst_index;
  }

  const scalar_t half = static_cast<scalar_t>(0.5);
  const scalar_t src_idx = scale * (dst_index + half) - half;

  // Follow OpenCV resize logic: negative src_idx values are allowed. Callers
  // later use dx = src_idx - floor(src_idx) to compute the interpolation
  // weights.
  // - Linear modes interpolate between 2 pixels. For an index in [-1, 0] both
  //   taps read pixel 0, so clamping to 0 does not change the result.
  // - Cubic mode may need taps [-1, 0, 1, 2], and clamping would change the
  //   weights, so the unclamped index is returned.
  // TODO: revisit how the linear-mode implementations handle unbounded
  // indices so that this `cubic` flag can be removed.
  if (!cubic && src_idx < static_cast<scalar_t>(0)) {
    return scalar_t(0);
  }
  return src_idx;
}
// Legacy "nearest" source index: floor(dst_index * scale), clamped to the
// last input element. This matches OpenCV INTER_NEAREST, which is known to
// be buggy; it is kept for backward compatibility.
static inline int64_t nearest_neighbor_compute_source_index(
    const float scale,
    int64_t dst_index,
    int64_t input_size) {
  const int64_t last = input_size - 1;
  const int64_t candidate = static_cast<int64_t>(floorf(dst_index * scale));
  return std::min(candidate, last);
}
// "nearest-exact" source index, the same mapping as Pillow and
// scikit-image / scipy ndi.zoom:
//   index_f = (dst_index + 0.5) * scale - 0.5,  src = round(index_f)
// It is evaluated as floor((dst_index + 0.5) * scale) and clamped to the last
// input element.
static inline int64_t nearest_neighbor_exact_compute_source_index(
    const float scale,
    int64_t dst_index,
    int64_t input_size) {
  const double center = (dst_index + 0.5) * scale;
  const int64_t candidate = static_cast<int64_t>(floorf(center));
  return std::min(candidate, input_size - 1);
}
// Legacy nearest-neighbour index mapping. Deprecated and kept for BC; prefer
// nearest_exact_idx.
// Two size ratios are special-cased, and we would like to drop these
// special cases:
//   output_size == input_size     -> identity
//   output_size == 2 * input_size -> output_index / 2
// Every other ratio goes through nearest_neighbor_compute_source_index.
static inline int64_t nearest_idx(
    int64_t output_index,
    int64_t input_size,
    int64_t output_size,
    c10::optional<double> scales) {
  if (output_size == input_size) {
    return output_index;  // scale_factor = 1: straight copy
  }
  if (output_size == 2 * input_size) {
    return output_index >> 1;  // scale_factor = 2: shift input index
  }
  const float scale = compute_scales_value<float>(scales, input_size, output_size);
  return nearest_neighbor_compute_source_index(scale, output_index, input_size);
}
// Replacement for nearest_idx. It applies the nearest-exact mapping for
// every size ratio, with no special cases.
static inline int64_t nearest_exact_idx(
    int64_t output_index,
    int64_t input_size,
    int64_t output_size,
    c10::optional<double> scales) {
  return nearest_neighbor_exact_compute_source_index(
      compute_scales_value<float>(scales, input_size, output_size),
      output_index,
      input_size);
}
// Function-pointer type shared by nearest_idx and nearest_exact_idx. Callers
// can hold either one and choose the index mapping at runtime.
// Parameters: (output_index, input_size, output_size, scales).
using nearest_idx_fn_t = int64_t (*)(int64_t, int64_t, int64_t, c10::optional<double>);
// Reads data[y][x] from a row-major plane with `width` columns and `height`
// rows. Out-of-range coordinates are clamped to the nearest edge, so the
// border values are replicated.
template <typename scalar_t>
static scalar_t upsample_get_value_bounded(
    scalar_t* data,
    int64_t width,
    int64_t height,
    int64_t x,
    int64_t y) {
  const int64_t zero = 0;
  const int64_t col = std::max(zero, std::min(x, width - 1));
  const int64_t row = std::max(zero, std::min(y, height - 1));
  return data[row * width + col];
}
// Adds `value` to data[y][x] in a row-major plane with `width` columns and
// `height` rows. Coordinates are clamped into range exactly as in
// upsample_get_value_bounded.
template <typename scalar_t>
static void upsample_increment_value_bounded(
    scalar_t* data,
    int64_t width,
    int64_t height,
    int64_t x,
    int64_t y,
    scalar_t value) {
  // Clamp v into [0, hi]; when hi < 0 the lower bound wins and gives 0.
  const auto clamp_to_edge = [](int64_t v, int64_t hi) {
    return std::max<int64_t>(std::min(v, hi), 0);
  };
  const int64_t row = clamp_to_edge(y, height - 1);
  const int64_t col = clamp_to_edge(x, width - 1);
  data[row * width + col] += value;
}
// Cubic convolution kernel, based on
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
// Inner piece, intended for distances 0 <= x <= 1:
//   (A + 2) x^3 - (A + 3) x^2 + 1
// It is evaluated in Horner form. The operation order is fixed on purpose so
// that floating-point results stay stable.
template <typename scalar_t>
static inline scalar_t cubic_convolution1(scalar_t x, scalar_t A) {
  const auto linear_term = (A + 2) * x - (A + 3);
  return linear_term * x * x + 1;
}
// Outer piece of the cubic convolution kernel, intended for distances
// 1 <= x <= 2:
//   A x^3 - 5A x^2 + 8A x - 4A
// It is evaluated in Horner form with the original operation order.
template <typename scalar_t>
static inline scalar_t cubic_convolution2(scalar_t x, scalar_t A) {
  const auto partial = (A * x - 5 * A) * x + 8 * A;
  return partial * x - 4 * A;
}
// Fills coeffs[0..3] with cubic-convolution weights for the four taps around
// a sample at fractional offset t. The taps lie at distances
// t + 1, t, 1 - t and 2 - t, and the kernel parameter is A = -0.75.
template <typename scalar_t>
static inline void get_cubic_upsample_coefficients(
    scalar_t coeffs[4],
    scalar_t t) {
  const scalar_t A = -0.75;

  // Taps to the left of (and at) the sample position.
  coeffs[0] = cubic_convolution2<scalar_t>(t + 1.0, A);
  coeffs[1] = cubic_convolution1<scalar_t>(t, A);

  // The right-hand taps mirror the left-hand ones around the sample.
  const scalar_t mirrored = 1.0 - t;
  coeffs[2] = cubic_convolution1<scalar_t>(mirrored, A);
  coeffs[3] = cubic_convolution2<scalar_t>(mirrored + 1.0, A);
}
// 1D cubic convolution over four consecutive samples x0..x3 at fractional
// offset t. It returns sum_i x_i * w_i, using the weights from
// get_cubic_upsample_coefficients. The sum runs left to right, as before.
template <typename scalar_t>
static inline scalar_t cubic_interp1d(
    scalar_t x0,
    scalar_t x1,
    scalar_t x2,
    scalar_t x3,
    scalar_t t) {
  scalar_t w[4];
  get_cubic_upsample_coefficients<scalar_t>(w, t);
  const auto left_pair = x0 * w[0] + x1 * w[1];
  return left_pair + x2 * w[2] + x3 * w[3];
}
// Splits a fractional source index into an integer index and an
// interpolation weight:
//   input_index = min(floor(real_input_index), input_size - 1)
//   lambda      = clamp(real_input_index - input_index, 0, 1)
// When `real_input_index` exceeds the range the floating-point type can
// represent accurately, the cast to `int64_t` might exceed `input_size`.
// `std::min` guards the index against this, and the clamp keeps lambda in
// [0, 1].
// std::floor is used instead of floorf. floorf would first round a double
// opmath index to float, which loses precision above 2^24 and yields an
// off-by-one index with lambda clamped to 0.
template<typename scalar_t, typename opmath_t>
static inline void guard_index_and_lambda(const opmath_t& real_input_index, const int64_t& input_size, int64_t& input_index, scalar_t& lambda) {
  input_index = std::min(static_cast<int64_t>(std::floor(real_input_index)), input_size - 1);
  lambda = std::min(
      std::max(real_input_index - input_index, static_cast<opmath_t>(0)),
      static_cast<opmath_t>(1));
}
// Computes the two input taps and their weights for linear interpolation of
// output element `output_index`:
//   out = lambda0 * in[input_index0] + lambda1 * in[input_index1]
// When output_size == input_size the mapping is an identity copy, and `ratio`
// and `align_corners` are ignored. Otherwise the source index comes from
// area_pixel_compute_source_index (non-cubic) and guard_index_and_lambda.
// The right tap is clamped so that the last input element pairs with itself.
template<typename scalar_t, typename opmath_t>
static inline void compute_source_index_and_lambda(
    int64_t& input_index0,
    int64_t& input_index1,
    scalar_t& lambda0,
    scalar_t& lambda1,
    opmath_t ratio,
    int64_t output_index,
    int64_t input_size,
    int64_t output_size,
    bool align_corners) {
  if (output_size == input_size) {
    // scale_factor = 1: read one element with full weight.
    input_index0 = output_index;
    input_index1 = output_index;
    lambda0 = static_cast<scalar_t>(1);
    lambda1 = static_cast<scalar_t>(0);
    return;
  }

  const auto src = area_pixel_compute_source_index<opmath_t>(
      ratio, output_index, align_corners, /*cubic=*/false);
  guard_index_and_lambda(src, input_size, input_index0, lambda1);

  const bool has_right_neighbor = input_index0 < input_size - 1;
  input_index1 = has_right_neighbor ? input_index0 + 1 : input_index0;
  lambda0 = static_cast<scalar_t>(1.) - lambda1;
}
// Generic fallback for apply_grad_input. Only the <float, BFloat16>
// specialization does real work. If any other instantiation is actually
// called, it fails at runtime with the TORCH_CHECK errors below.
// The checks are runtime rather than static_assert, presumably so that other
// instantiations can still compile as long as they are never called.
template <typename scalar_in, typename scalar_out>
inline void apply_grad_input(scalar_in* buffer_ptr, scalar_out* gin, int64_t size) {
  (void)buffer_ptr;
  (void)gin;
  (void)size;
  TORCH_CHECK((std::is_same_v<scalar_out, BFloat16>),
      "Upsample backward only support BFloat16 in the lower precision data types on CPU.");
  TORCH_CHECK((std::is_same_v<scalar_in, float>),
      "Upsample backward should use float as acc buffer for BFloat16 grad input on CPU.");
}
// BFloat16 gradient with a float accumulation buffer. For d in [0, size):
//   gin[d] += buffer_ptr[d];  buffer_ptr[d] = 0;
// The main loop handles whole Vectorized<BFloat16> chunks, each widened to
// two float vectors. A scalar loop handles the remaining tail.
template <>
void inline apply_grad_input(float* buffer_ptr, BFloat16* gin, int64_t size) {
  using bVec = vec::Vectorized<BFloat16>;
  using fVec = vec::Vectorized<float>;

  const int64_t vec_end = size - (size % bVec::size());
  int64_t d = 0;
  for (; d < vec_end; d += bVec::size()) {
    auto [lo, hi] = convert_bfloat16_float(bVec::loadu(gin + d));
    lo += fVec::loadu(buffer_ptr + d);
    hi += fVec::loadu(buffer_ptr + d + fVec::size());
    fVec(0).store(buffer_ptr + d);
    fVec(0).store(buffer_ptr + d + fVec::size());
    convert_float_bfloat16(lo, hi).store(gin + d);
  }

  // Scalar tail.
  for (; d < size; ++d) {
    gin[d] += buffer_ptr[d];
    buffer_ptr[d] = 0;
  }
}
} // namespace at::native