Skip to content

Commit d6943ea

Browse files
authored
apply diff 52351 (pytorch#52649)
1 parent 02b61b4 commit d6943ea

File tree

1 file changed: 97 additions and 84 deletions.

aten/src/ATen/native/cuda/UpSampleTrilinear3d.cu

+97 -84
Original file line number | Diff line number | Diff line change
@@ -113,31 +113,46 @@ __global__ void upsample_trilinear3d_out_frame(
113113
template <typename scalar_t, typename accscalar_t>
114114
C10_LAUNCH_BOUNDS_1(1024)
115115
__global__ void upsample_trilinear3d_backward_out_frame(
116-
const size_t nc_,
117-
const int depth1,
118-
const int height1,
119-
const int width1,
120-
const int depth2,
121-
const int height2,
122-
const int width2,
116+
const int num_kernels,
123117
const accscalar_t rdepth,
124118
const accscalar_t rheight,
125119
const accscalar_t rwidth,
126120
const bool align_corners,
127-
scalar_t* __restrict__ idata,
128-
const scalar_t* __restrict__ odata) {
129-
const size_t i_numel = nc_ * depth1 * height1 * width1;
130-
const size_t o_numel = nc_ * depth2 * height2 * width2;
131-
132-
for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel; index += blockDim.x * gridDim.x) {
133-
size_t index_temp = index;
134-
const int w2 = index_temp % width2; // 0:width2-1
135-
index_temp /= width2;
136-
const int h2 = index_temp % height2; // 0:height2-1
137-
index_temp /= height2;
138-
const int t2 = index_temp % depth2; // 0:depth2-1
139-
const int nc = index_temp / depth2;
121+
PackedTensorAccessor64<scalar_t, 5> idata,
122+
const PackedTensorAccessor64<scalar_t, 5> odata,
123+
scalar_t* idata_ptr) {
124+
int index = threadIdx.x + blockIdx.x * blockDim.x;
125+
126+
const int batchsize = idata.size(0);
127+
const int channels = idata.size(1);
128+
const int depth1 = idata.size(2);
129+
const int height1 = idata.size(3);
130+
const int width1 = idata.size(4);
131+
const int depth2 = odata.size(2);
132+
const int height2 = odata.size(3);
133+
const int width2 = odata.size(4);
134+
135+
const size_t i_numel = batchsize * channels * depth1 * height1 * width1;
136+
137+
if (index < num_kernels) {
138+
const int w2 = (index % (height2 * width2)) % width2; // 0:width2-1
139+
const int h2 = (index % (height2 * width2)) / width2; // 0:height2-1
140+
const int t2 = index / (height2 * width2); // 0:depth2-1
141+
// special case: just copy
142+
if (depth1 == depth2 && height1 == height2 && width1 == width2) {
143+
const int t1 = t2;
144+
const int h1 = h2;
145+
const int w1 = w2;
140146

147+
for (int n = 0; n < batchsize; n++) {
148+
for (int c = 0; c < channels; ++c) {
149+
const scalar_t val = odata[n][c][t1][h1][w1];
150+
idata[n][c][t2][h2][w2] = val;
151+
}
152+
}
153+
return;
154+
}
155+
//
141156
const accscalar_t t1r = area_pixel_compute_source_index<accscalar_t>(
142157
rdepth, t2, align_corners, /*cubic=*/false);
143158
const int t1 = t1r;
@@ -159,55 +174,60 @@ __global__ void upsample_trilinear3d_backward_out_frame(
159174
const accscalar_t w1lambda = w1r - w1;
160175
const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;
161176
//
162-
const scalar_t d2val = odata[index];
163-
fastAtomicAdd(
164-
idata,
165-
idx_3d(nc, depth1, height1, width1, t1, h1, w1),
166-
i_numel,
167-
static_cast<scalar_t>(t0lambda * h0lambda * w0lambda * d2val),
168-
true);
169-
fastAtomicAdd(
170-
idata,
171-
idx_3d(nc, depth1, height1, width1, t1, h1, w1 + w1p),
172-
i_numel,
173-
static_cast<scalar_t>(t0lambda * h0lambda * w1lambda * d2val),
174-
true);
175-
fastAtomicAdd(
176-
idata,
177-
idx_3d(nc, depth1, height1, width1, t1, h1 + h1p, w1),
178-
i_numel,
179-
static_cast<scalar_t>(t0lambda * h1lambda * w0lambda * d2val),
180-
true);
181-
fastAtomicAdd(
182-
idata,
183-
idx_3d(nc, depth1, height1, width1, t1, h1 + h1p, w1 + w1p),
184-
i_numel,
185-
static_cast<scalar_t>(t0lambda * h1lambda * w1lambda * d2val),
186-
true);
187-
fastAtomicAdd(
188-
idata,
189-
idx_3d(nc, depth1, height1, width1, t1 + t1p, h1, w1),
190-
i_numel,
191-
static_cast<scalar_t>(t1lambda * h0lambda * w0lambda * d2val),
192-
true);
193-
fastAtomicAdd(
194-
idata,
195-
idx_3d(nc, depth1, height1, width1, t1 + t1p, h1, w1 + w1p),
196-
i_numel,
197-
static_cast<scalar_t>(t1lambda * h0lambda * w1lambda * d2val),
198-
true);
199-
fastAtomicAdd(
200-
idata,
201-
idx_3d(nc, depth1, height1, width1, t1 + t1p, h1 + h1p, w1),
202-
i_numel,
203-
static_cast<scalar_t>(t1lambda * h1lambda * w0lambda * d2val),
204-
true);
205-
fastAtomicAdd(
206-
idata,
207-
idx_3d(nc, depth1, height1, width1, t1 + t1p, h1 + h1p, w1 + w1p),
208-
i_numel,
209-
static_cast<scalar_t>(t1lambda * h1lambda * w1lambda * d2val),
210-
true);
177+
for (int n = 0; n < batchsize; n++) {
178+
for (int c = 0; c < channels; ++c) {
179+
const scalar_t d2val = odata[n][c][t2][h2][w2];
180+
const size_t nc = n * channels + c;
181+
fastAtomicAdd(
182+
idata_ptr,
183+
idx_3d(nc, depth1, height1, width1, t1, h1, w1),
184+
i_numel,
185+
static_cast<scalar_t>(t0lambda * h0lambda * w0lambda * d2val),
186+
true);
187+
fastAtomicAdd(
188+
idata_ptr,
189+
idx_3d(nc, depth1, height1, width1, t1, h1, w1 + w1p),
190+
i_numel,
191+
static_cast<scalar_t>(t0lambda * h0lambda * w1lambda * d2val),
192+
true);
193+
fastAtomicAdd(
194+
idata_ptr,
195+
idx_3d(nc, depth1, height1, width1, t1, h1 + h1p, w1),
196+
i_numel,
197+
static_cast<scalar_t>(t0lambda * h1lambda * w0lambda * d2val),
198+
true);
199+
fastAtomicAdd(
200+
idata_ptr,
201+
idx_3d(nc, depth1, height1, width1, t1, h1 + h1p, w1 + w1p),
202+
i_numel,
203+
static_cast<scalar_t>(t0lambda * h1lambda * w1lambda * d2val),
204+
true);
205+
fastAtomicAdd(
206+
idata_ptr,
207+
idx_3d(nc, depth1, height1, width1, t1 + t1p, h1, w1),
208+
i_numel,
209+
static_cast<scalar_t>(t1lambda * h0lambda * w0lambda * d2val),
210+
true);
211+
fastAtomicAdd(
212+
idata_ptr,
213+
idx_3d(nc, depth1, height1, width1, t1 + t1p, h1, w1 + w1p),
214+
i_numel,
215+
static_cast<scalar_t>(t1lambda * h0lambda * w1lambda * d2val),
216+
true);
217+
fastAtomicAdd(
218+
idata_ptr,
219+
idx_3d(nc, depth1, height1, width1, t1 + t1p, h1 + h1p, w1),
220+
i_numel,
221+
static_cast<scalar_t>(t1lambda * h1lambda * w0lambda * d2val),
222+
true);
223+
fastAtomicAdd(
224+
idata_ptr,
225+
idx_3d(nc, depth1, height1, width1, t1 + t1p, h1 + h1p, w1 + w1p),
226+
i_numel,
227+
static_cast<scalar_t>(t1lambda * h1lambda * w1lambda * d2val),
228+
true);
229+
}
230+
}
211231
}
212232
}
213233

@@ -350,21 +370,20 @@ static void upsample_trilinear3d_backward_out_cuda_template(
350370
// so it has to be initialized to zero.
351371
grad_input.zero_();
352372

353-
// const size_t num_kernels = nbatch * channels * output_depth * output_height * output_width;
354-
const size_t num_kernels = grad_output.numel();
373+
const int num_kernels = output_depth * output_height * output_width;
355374
const int num_threads = std::min(
356375
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
357376
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
358377

359-
if (num_kernels > 0) {
360378
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
361379
grad_output.scalar_type(),
362380
"upsample_trilinear3d_backward_out_frame",
363381
[&] {
364382
using accscalar_t = at::acc_type<scalar_t, true>;
365383

366-
auto idata = grad_input.data_ptr<scalar_t>();
367-
auto odata = grad_output.data_ptr<scalar_t>();
384+
auto idata = grad_input.packed_accessor64<scalar_t, 5>();
385+
auto odata = grad_output.packed_accessor64<scalar_t, 5>();
386+
scalar_t* idata_ptr = grad_input.data_ptr<scalar_t>();
368387

369388
const accscalar_t rdepth = area_pixel_compute_scale<accscalar_t>(
370389
input_depth, output_depth, align_corners, scales_d);
@@ -374,26 +393,20 @@ static void upsample_trilinear3d_backward_out_cuda_template(
374393
input_width, output_width, align_corners, scales_w);
375394

376395
upsample_trilinear3d_backward_out_frame<scalar_t, accscalar_t>
377-
<<<cuda::ATenCeilDiv(num_kernels, static_cast<size_t>(num_threads)),
396+
<<<cuda::ATenCeilDiv(num_kernels, num_threads),
378397
num_threads,
379398
0,
380399
stream>>>(
381-
nbatch * channels,
382-
input_depth,
383-
input_height,
384-
input_width,
385-
output_depth,
386-
output_height,
387-
output_width,
400+
num_kernels,
388401
rdepth,
389402
rheight,
390403
rwidth,
391404
align_corners,
392405
idata,
393-
odata);
406+
odata,
407+
idata_ptr);
394408
C10_CUDA_KERNEL_LAUNCH_CHECK();
395409
});
396-
}
397410
}
398411

399412
} // namespace

0 commit comments

Comments
 (0)