@@ -113,31 +113,46 @@ __global__ void upsample_trilinear3d_out_frame(
 template <typename scalar_t, typename accscalar_t>
 C10_LAUNCH_BOUNDS_1(1024)
 __global__ void upsample_trilinear3d_backward_out_frame(
-    const size_t nc_,
-    const int depth1,
-    const int height1,
-    const int width1,
-    const int depth2,
-    const int height2,
-    const int width2,
+    const int num_kernels,
     const accscalar_t rdepth,
     const accscalar_t rheight,
     const accscalar_t rwidth,
     const bool align_corners,
-    scalar_t* __restrict__ idata,
-    const scalar_t* __restrict__ odata) {
-  const size_t i_numel = nc_ * depth1 * height1 * width1;
-  const size_t o_numel = nc_ * depth2 * height2 * width2;
-
-  for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel; index += blockDim.x * gridDim.x) {
-    size_t index_temp = index;
-    const int w2 = index_temp % width2; // 0:width2-1
-    index_temp /= width2;
-    const int h2 = index_temp % height2; // 0:height2-1
-    index_temp /= height2;
-    const int t2 = index_temp % depth2; // 0:depth2-1
-    const int nc = index_temp / depth2;
+    PackedTensorAccessor64<scalar_t, 5> idata,
+    const PackedTensorAccessor64<scalar_t, 5> odata,
+    scalar_t* idata_ptr) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+
+  const int batchsize = idata.size(0);
+  const int channels = idata.size(1);
+  const int depth1 = idata.size(2);
+  const int height1 = idata.size(3);
+  const int width1 = idata.size(4);
+  const int depth2 = odata.size(2);
+  const int height2 = odata.size(3);
+  const int width2 = odata.size(4);
+
+  const size_t i_numel = batchsize * channels * depth1 * height1 * width1;
+
+  if (index < num_kernels) {
+    const int w2 = (index % (height2 * width2)) % width2; // 0:width2-1
+    const int h2 = (index % (height2 * width2)) / width2; // 0:height2-1
+    const int t2 = index / (height2 * width2);            // 0:depth2-1
+    // special case: just copy
+    if (depth1 == depth2 && height1 == height2 && width1 == width2) {
+      const int t1 = t2;
+      const int h1 = h2;
+      const int w1 = w2;
 
+      for (int n = 0; n < batchsize; n++) {
+        for (int c = 0; c < channels; ++c) {
+          const scalar_t val = odata[n][c][t1][h1][w1];
+          idata[n][c][t2][h2][w2] = val;
+        }
+      }
+      return;
+    }
+    //
     const accscalar_t t1r = area_pixel_compute_source_index<accscalar_t>(
         rdepth, t2, align_corners, /*cubic=*/false);
     const int t1 = t1r;
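Aside (not part of the patch): the new kernel flattens only the output's spatial dimensions into num_kernels and recovers (t2, h2, w2) with the modulo/division pattern above. A minimal standalone C++ sketch of that decomposition, with made-up sizes rather than values from the patch, confirms it round-trips:

// Sketch only: verifies the kernel's index -> (t2, h2, w2) decomposition.
// depth2/height2/width2 are hypothetical sizes chosen for illustration.
#include <cassert>

int main() {
  const int depth2 = 4, height2 = 3, width2 = 5;
  const int num_kernels = depth2 * height2 * width2; // one index per output voxel
  for (int index = 0; index < num_kernels; ++index) {
    const int w2 = (index % (height2 * width2)) % width2; // 0:width2-1
    const int h2 = (index % (height2 * width2)) / width2; // 0:height2-1
    const int t2 = index / (height2 * width2);            // 0:depth2-1
    assert(index == (t2 * height2 + h2) * width2 + w2);   // round-trips exactly
  }
  return 0;
}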
@@ -159,55 +174,60 @@ __global__ void upsample_trilinear3d_backward_out_frame(
     const accscalar_t w1lambda = w1r - w1;
     const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;
     //
-    const scalar_t d2val = odata[index];
-    fastAtomicAdd(
-        idata,
-        idx_3d(nc, depth1, height1, width1, t1, h1, w1),
-        i_numel,
-        static_cast<scalar_t>(t0lambda * h0lambda * w0lambda * d2val),
-        true);
-    fastAtomicAdd(
-        idata,
-        idx_3d(nc, depth1, height1, width1, t1, h1, w1 + w1p),
-        i_numel,
-        static_cast<scalar_t>(t0lambda * h0lambda * w1lambda * d2val),
-        true);
-    fastAtomicAdd(
-        idata,
-        idx_3d(nc, depth1, height1, width1, t1, h1 + h1p, w1),
-        i_numel,
-        static_cast<scalar_t>(t0lambda * h1lambda * w0lambda * d2val),
-        true);
-    fastAtomicAdd(
-        idata,
-        idx_3d(nc, depth1, height1, width1, t1, h1 + h1p, w1 + w1p),
-        i_numel,
-        static_cast<scalar_t>(t0lambda * h1lambda * w1lambda * d2val),
-        true);
-    fastAtomicAdd(
-        idata,
-        idx_3d(nc, depth1, height1, width1, t1 + t1p, h1, w1),
-        i_numel,
-        static_cast<scalar_t>(t1lambda * h0lambda * w0lambda * d2val),
-        true);
-    fastAtomicAdd(
-        idata,
-        idx_3d(nc, depth1, height1, width1, t1 + t1p, h1, w1 + w1p),
-        i_numel,
-        static_cast<scalar_t>(t1lambda * h0lambda * w1lambda * d2val),
-        true);
-    fastAtomicAdd(
-        idata,
-        idx_3d(nc, depth1, height1, width1, t1 + t1p, h1 + h1p, w1),
-        i_numel,
-        static_cast<scalar_t>(t1lambda * h1lambda * w0lambda * d2val),
-        true);
-    fastAtomicAdd(
-        idata,
-        idx_3d(nc, depth1, height1, width1, t1 + t1p, h1 + h1p, w1 + w1p),
-        i_numel,
-        static_cast<scalar_t>(t1lambda * h1lambda * w1lambda * d2val),
-        true);
+    for (int n = 0; n < batchsize; n++) {
+      for (int c = 0; c < channels; ++c) {
+        const scalar_t d2val = odata[n][c][t2][h2][w2];
+        const size_t nc = n * channels + c;
+        fastAtomicAdd(
+            idata_ptr,
+            idx_3d(nc, depth1, height1, width1, t1, h1, w1),
+            i_numel,
+            static_cast<scalar_t>(t0lambda * h0lambda * w0lambda * d2val),
+            true);
+        fastAtomicAdd(
+            idata_ptr,
+            idx_3d(nc, depth1, height1, width1, t1, h1, w1 + w1p),
+            i_numel,
+            static_cast<scalar_t>(t0lambda * h0lambda * w1lambda * d2val),
+            true);
+        fastAtomicAdd(
+            idata_ptr,
+            idx_3d(nc, depth1, height1, width1, t1, h1 + h1p, w1),
+            i_numel,
+            static_cast<scalar_t>(t0lambda * h1lambda * w0lambda * d2val),
+            true);
+        fastAtomicAdd(
+            idata_ptr,
+            idx_3d(nc, depth1, height1, width1, t1, h1 + h1p, w1 + w1p),
+            i_numel,
+            static_cast<scalar_t>(t0lambda * h1lambda * w1lambda * d2val),
+            true);
+        fastAtomicAdd(
+            idata_ptr,
+            idx_3d(nc, depth1, height1, width1, t1 + t1p, h1, w1),
+            i_numel,
+            static_cast<scalar_t>(t1lambda * h0lambda * w0lambda * d2val),
+            true);
+        fastAtomicAdd(
+            idata_ptr,
+            idx_3d(nc, depth1, height1, width1, t1 + t1p, h1, w1 + w1p),
+            i_numel,
+            static_cast<scalar_t>(t1lambda * h0lambda * w1lambda * d2val),
+            true);
+        fastAtomicAdd(
+            idata_ptr,
+            idx_3d(nc, depth1, height1, width1, t1 + t1p, h1 + h1p, w1),
+            i_numel,
+            static_cast<scalar_t>(t1lambda * h1lambda * w0lambda * d2val),
+            true);
+        fastAtomicAdd(
+            idata_ptr,
+            idx_3d(nc, depth1, height1, width1, t1 + t1p, h1 + h1p, w1 + w1p),
+            i_numel,
+            static_cast<scalar_t>(t1lambda * h1lambda * w1lambda * d2val),
+            true);
+      }
+    }
   }
 }
 
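Aside (not part of the patch): each output gradient d2val is scattered to the eight surrounding input corners with weights of the form t{0,1}lambda * h{0,1}lambda * w{0,1}lambda, and those eight weights always sum to 1, so the full gradient mass is deposited exactly once; fastAtomicAdd is required because neighbouring output voxels can target the same input corner. A small standalone check of the weight identity, assuming nothing beyond plain C++:

// Sketch only: the eight trilinear corner weights sum to 1 for any
// fractional offsets, mirroring the eight fastAtomicAdd calls above.
#include <cassert>
#include <cmath>

int main() {
  const double t1lambda = 0.25, h1lambda = 0.5, w1lambda = 0.75; // arbitrary offsets
  const double t0lambda = 1 - t1lambda, h0lambda = 1 - h1lambda,
               w0lambda = 1 - w1lambda;
  double sum = 0;
  for (int dt = 0; dt < 2; ++dt)      // low/high corner along depth
    for (int dh = 0; dh < 2; ++dh)    // ... along height
      for (int dw = 0; dw < 2; ++dw)  // ... along width
        sum += (dt ? t1lambda : t0lambda) * (dh ? h1lambda : h0lambda) *
               (dw ? w1lambda : w0lambda);
  assert(std::abs(sum - 1.0) < 1e-12); // weights partition the gradient
  return 0;
}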
@@ -350,21 +370,20 @@ static void upsample_trilinear3d_backward_out_cuda_template(
   // so it has to be initialized to zero.
   grad_input.zero_();
 
-  // const size_t num_kernels = nbatch * channels * output_depth * output_height * output_width;
-  const size_t num_kernels = grad_output.numel();
+  const int num_kernels = output_depth * output_height * output_width;
 
   const int num_threads = std::min(
       at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
-  if (num_kernels > 0) {
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       grad_output.scalar_type(),
       "upsample_trilinear3d_backward_out_frame",
       [&] {
         using accscalar_t = at::acc_type<scalar_t, true>;
 
-        auto idata = grad_input.data_ptr<scalar_t>();
-        auto odata = grad_output.data_ptr<scalar_t>();
+        auto idata = grad_input.packed_accessor64<scalar_t, 5>();
+        auto odata = grad_output.packed_accessor64<scalar_t, 5>();
+        scalar_t* idata_ptr = grad_input.data_ptr<scalar_t>();
 
         const accscalar_t rdepth = area_pixel_compute_scale<accscalar_t>(
             input_depth, output_depth, align_corners, scales_d);
@@ -374,26 +393,20 @@ static void upsample_trilinear3d_backward_out_cuda_template(
             input_width, output_width, align_corners, scales_w);
 
         upsample_trilinear3d_backward_out_frame<scalar_t, accscalar_t>
-            <<<cuda::ATenCeilDiv(num_kernels, static_cast<size_t>(num_threads)),
+            <<<cuda::ATenCeilDiv(num_kernels, num_threads),
                num_threads,
                0,
                stream>>>(
-                nbatch * channels,
-                input_depth,
-                input_height,
-                input_width,
-                output_depth,
-                output_height,
-                output_width,
+                num_kernels,
                 rdepth,
                 rheight,
                 rwidth,
                 align_corners,
                 idata,
-                odata);
+                odata,
+                idata_ptr);
         C10_CUDA_KERNEL_LAUNCH_CHECK();
       });
-  }
 }
 
 } // namespace
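Aside (not part of the patch): with num_kernels now covering only the output's spatial voxels, the grid size is the usual ceiling division by the block size, and the `if (index < num_kernels)` guard in the kernel discards the tail of the last block. A standalone sketch of that arithmetic, where ceil_div is a stand-in for cuda::ATenCeilDiv and the sizes are hypothetical:

// Sketch only: launch-configuration arithmetic for one thread per output voxel.
#include <cstdio>

static int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
  const int output_depth = 16, output_height = 32, output_width = 33; // made up
  const int num_kernels = output_depth * output_height * output_width;
  const int num_threads = 1024; // capped at maxThreadsPerBlock in the real code
  const int num_blocks = ceil_div(num_kernels, num_threads);
  std::printf("%d voxels -> %d blocks x %d threads (%d idle in last block)\n",
              num_kernels, num_blocks, num_threads,
              num_blocks * num_threads - num_kernels);
  return 0;
}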