forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathPixelShuffle.cpp
More file actions
161 lines (128 loc) · 6.02 KB
/
PixelShuffle.cpp
File metadata and controls
161 lines (128 loc) · 6.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/TensorTransformations.h>
#include <ATen/native/cpu/PixelShuffleKernel.h>
#include <ATen/native/PixelShuffle.h>
#include <c10/util/Exception.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/pixel_shuffle_native.h>
#include <ATen/ops/pixel_unshuffle_native.h>
#endif
#include <algorithm>
#include <numeric>
#include <vector>
namespace at {
namespace native {
Tensor pixel_shuffle_cpu(const Tensor& self, int64_t upscale_factor) {
check_pixel_shuffle_shapes(self, upscale_factor);
// Format: (B1, ..., Bn), C, H, W
std::vector<int64_t> output_sizes(self.sizes().begin(), self.sizes().end() - 3);
output_sizes.insert(output_sizes.end(),
{self.size(-3) / upscale_factor / upscale_factor,
self.size(-2) * upscale_factor,
self.size(-1) * upscale_factor});
auto output = at::empty({0}, self.options());
auto memory_format = self.suggest_memory_format();
output.resize_(output_sizes, memory_format);
if (output.numel() == 0) {
return output;
}
auto input = self.contiguous(memory_format);
pixel_shuffle_kernel(kCPU, output, input, upscale_factor);
return output;
}
Tensor pixel_unshuffle_cpu(const Tensor& self, int64_t downscale_factor) {
check_pixel_unshuffle_shapes(self, downscale_factor);
if (self.numel() == 0) {
return self.clone();
}
// Format: (B1, ..., Bn), C, H, W
std::vector<int64_t> output_sizes(self.sizes().begin(), self.sizes().end() - 3);
output_sizes.insert(output_sizes.end(),
{self.size(-3) * downscale_factor * downscale_factor,
self.size(-2) / downscale_factor,
self.size(-1) / downscale_factor});
auto output = at::empty({0}, self.options());
auto memory_format = self.suggest_memory_format();
output.resize_(output_sizes, memory_format);
if (output.numel() == 0) {
return output;
}
auto input = self.contiguous(memory_format);
pixel_unshuffle_kernel(kCPU, output, input, downscale_factor);
return output;
}
Tensor math_pixel_shuffle(const Tensor& self, int64_t upscale_factor) {
check_pixel_shuffle_shapes(self, upscale_factor);
// Format: (B1, ..., Bn), C, H, W
int64_t c = self.size(-3);
int64_t h = self.size(-2);
int64_t w = self.size(-1);
const auto NUM_NON_BATCH_DIMS = 3;
const auto self_sizes_batch_end = self.sizes().end() - NUM_NON_BATCH_DIMS;
int64_t upscale_factor_squared = upscale_factor * upscale_factor;
int64_t oc = c / upscale_factor_squared;
int64_t oh = h * upscale_factor;
int64_t ow = w * upscale_factor;
// First, reshape to split the channels dim from c into 3 separate dims: (oc,
// upscale_factor, upscale_factor). This allows shuffling to be done next by
// permuting dims.
std::vector<int64_t> added_dims_shape(
self.sizes().begin(), self_sizes_batch_end);
added_dims_shape.insert(
added_dims_shape.end(), {oc, upscale_factor, upscale_factor, h, w});
const auto input_reshaped = self.reshape(added_dims_shape);
// Next, shuffle by permuting the new upscale_factor dims alongside the height and width dims.
std::vector<int64_t> permutation(self.sizes().begin(), self_sizes_batch_end);
// std::iota is used to maintain the batch dims within the permutation.
std::iota(permutation.begin(), permutation.end(), 0);
permutation.insert(permutation.end(), {-5 /* oc */, -2 /* h */, -4 /* 1st upscale_factor */, -1 /* w */,
-3 /* 2nd upscale_factor */});
const auto input_permuted = input_reshaped.permute(permutation);
// Finally, upscale by collapsing (h, upscale_factor) -> a single dim (oh)
// and (w, upscale_factor) -> a single dim (ow).
std::vector<int64_t> final_shape(self.sizes().begin(), self_sizes_batch_end);
final_shape.insert(final_shape.end(), {oc, oh, ow});
// pixel_shuffle expects to *never* return an alias of the input.
return input_permuted.clone(at::MemoryFormat::Contiguous).view(final_shape);
}
Tensor math_pixel_unshuffle(const Tensor& self, int64_t downscale_factor) {
check_pixel_unshuffle_shapes(self, downscale_factor);
// Format: (B1, ..., Bn), C, H, W
int64_t c = self.size(-3);
int64_t h = self.size(-2);
int64_t w = self.size(-1);
constexpr auto NUM_NON_BATCH_DIMS = 3;
const auto self_sizes_batch_end = self.sizes().end() - NUM_NON_BATCH_DIMS;
int64_t downscale_factor_squared = downscale_factor * downscale_factor;
int64_t oc = c * downscale_factor_squared;
int64_t oh = h / downscale_factor;
int64_t ow = w / downscale_factor;
// First, reshape to split height dim into (oh, downscale_factor) dims and
// width dim into (ow, downscale_factor) dims. This allows unshuffling to be
// done next by permuting dims.
std::vector<int64_t> added_dims_shape(
self.sizes().begin(), self_sizes_batch_end);
added_dims_shape.insert(
added_dims_shape.end(), {c, oh, downscale_factor, ow, downscale_factor});
const auto input_reshaped = self.reshape(added_dims_shape);
// Next, unshuffle by permuting the downscale_factor dims alongside the channel dim.
std::vector<int64_t> permutation(self.sizes().begin(), self_sizes_batch_end);
// std::iota is used to maintain the batch dims within the permutation.
std::iota(permutation.begin(), permutation.end(), 0);
permutation.insert(permutation.end(), {-5 /* c */, -3 /* 1st downscale_factor */, -1 /*2nd downscale_factor */,
-4 /* oh */, -2 /* ow */});
const auto input_permuted = input_reshaped.permute(permutation);
// Finally, downscale by collapsing (c, downscale_factor, downscale_factor) -> a single dim (oc),
// resulting in height=oh and width=ow.
std::vector<int64_t> final_shape(self.sizes().begin(), self_sizes_batch_end);
final_shape.insert(final_shape.end(), {oc, oh, ow});
// pixel_unshuffle expects to *never* return an alias of the input.
return input_permuted.clone(at::MemoryFormat::Contiguous).view(final_shape);
}
DEFINE_DISPATCH(pixel_shuffle_kernel);
DEFINE_DISPATCH(pixel_unshuffle_kernel);
}} // namespace at::native