
Commit 3987612

r-barnes authored and facebook-github-bot committed
Fix CUDA kernel index data type in vision/fair/pytorch3d/pytorch3d/csrc/compositing/alpha_composite.cu +10
Summary: CUDA kernel variables matching the pattern `(thread|block|grid).(Idx|Dim).(x|y|z)` [have the data type `uint`](https://docs.nvidia.com/cuda/cuda-c-programming-guide/#built-in-variables). Many programmers mistakenly rely on implicit casts that turn these values into `int`. In fact, the [CUDA Programming Guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/) itself is inconsistent and incorrect in its use of data types in its programming examples. The result of these implicit casts is that our kernels may give unexpected results on large datasets, i.e., those with more than roughly 2 billion items. While we now have linters in place to prevent simple mistakes (D71236150), our codebase has many problematic instances. This diff fixes some of them.

Reviewed By: dtolnay

Differential Revision: D71355356

fbshipit-source-id: cea44891416d9efd2f466d6c45df4e36008fa036
1 parent 06a76ef commit 3987612
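For context, the hazard described in the summary comes from expressions like `blockIdx.x * blockDim.x + threadIdx.x` being computed in `unsigned int` and then silently narrowed when assigned to an `int`. Below is a minimal sketch, not taken from this commit, of the pattern before and after the fix; the kernel name, arguments, and launch configuration are hypothetical and only the indexing is the point.

```cuda
#include <cstdint>

// Hypothetical kernel showing the grid-stride indexing pattern this commit targets.
__global__ void ScaleKernel(const float* in, float* out, const int64_t n) {
  // threadIdx/blockIdx/blockDim/gridDim components are `uint`.
  // Writing `const int tid = ...` would implicitly convert the unsigned result
  // to a signed 32-bit int; `auto` preserves the unsigned type instead.
  const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
  const auto num_threads = gridDim.x * blockDim.x;

  // A 64-bit loop counter lets the loop cover datasets larger than ~2B items.
  for (int64_t i = tid; i < n; i += num_threads) {
    out[i] = 2.0f * in[i];
  }
}
```

A launch such as `ScaleKernel<<<1024, 256>>>(in, out, n);` would then rely on the grid-stride loop, rather than the grid size, to cover all `n` elements.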

File tree

10 files changed: +54 -54 lines changed

pytorch3d/csrc/compositing/alpha_composite.cu

+6 -6

@@ -33,11 +33,11 @@ __global__ void alphaCompositeCudaForwardKernel(
   const int64_t W = points_idx.size(3);

   // Get the batch and index
-  const int batch = blockIdx.x;
+  const auto batch = blockIdx.x;

   const int num_pixels = C * H * W;
-  const int num_threads = gridDim.y * blockDim.x;
-  const int tid = blockIdx.y * blockDim.x + threadIdx.x;
+  const auto num_threads = gridDim.y * blockDim.x;
+  const auto tid = blockIdx.y * blockDim.x + threadIdx.x;

   // Iterate over each feature in each pixel
   for (int pid = tid; pid < num_pixels; pid += num_threads) {

@@ -83,11 +83,11 @@ __global__ void alphaCompositeCudaBackwardKernel(
   const int64_t W = points_idx.size(3);

   // Get the batch and index
-  const int batch = blockIdx.x;
+  const auto batch = blockIdx.x;

   const int num_pixels = C * H * W;
-  const int num_threads = gridDim.y * blockDim.x;
-  const int tid = blockIdx.y * blockDim.x + threadIdx.x;
+  const auto num_threads = gridDim.y * blockDim.x;
+  const auto tid = blockIdx.y * blockDim.x + threadIdx.x;

   // Parallelize over each feature in each pixel in images of size H * W,
   // for each image in the batch of size batch_size

pytorch3d/csrc/compositing/norm_weighted_sum.cu

+6 -6

@@ -33,11 +33,11 @@ __global__ void weightedSumNormCudaForwardKernel(
   const int64_t W = points_idx.size(3);

   // Get the batch and index
-  const int batch = blockIdx.x;
+  const auto batch = blockIdx.x;

   const int num_pixels = C * H * W;
-  const int num_threads = gridDim.y * blockDim.x;
-  const int tid = blockIdx.y * blockDim.x + threadIdx.x;
+  const auto num_threads = gridDim.y * blockDim.x;
+  const auto tid = blockIdx.y * blockDim.x + threadIdx.x;

   // Parallelize over each feature in each pixel in images of size H * W,
   // for each image in the batch of size batch_size

@@ -96,11 +96,11 @@ __global__ void weightedSumNormCudaBackwardKernel(
   const int64_t W = points_idx.size(3);

   // Get the batch and index
-  const int batch = blockIdx.x;
+  const auto batch = blockIdx.x;

   const int num_pixels = C * W * H;
-  const int num_threads = gridDim.y * blockDim.x;
-  const int tid = blockIdx.y * blockDim.x + threadIdx.x;
+  const auto num_threads = gridDim.y * blockDim.x;
+  const auto tid = blockIdx.y * blockDim.x + threadIdx.x;

   // Parallelize over each feature in each pixel in images of size H * W,
   // for each image in the batch of size batch_size

pytorch3d/csrc/compositing/weighted_sum.cu

+6 -6

@@ -31,11 +31,11 @@ __global__ void weightedSumCudaForwardKernel(
   const int64_t W = points_idx.size(3);

   // Get the batch and index
-  const int batch = blockIdx.x;
+  const auto batch = blockIdx.x;

   const int num_pixels = C * H * W;
-  const int num_threads = gridDim.y * blockDim.x;
-  const int tid = blockIdx.y * blockDim.x + threadIdx.x;
+  const auto num_threads = gridDim.y * blockDim.x;
+  const auto tid = blockIdx.y * blockDim.x + threadIdx.x;

   // Parallelize over each feature in each pixel in images of size H * W,
   // for each image in the batch of size batch_size

@@ -78,11 +78,11 @@ __global__ void weightedSumCudaBackwardKernel(
   const int64_t W = points_idx.size(3);

   // Get the batch and index
-  const int batch = blockIdx.x;
+  const auto batch = blockIdx.x;

   const int num_pixels = C * H * W;
-  const int num_threads = gridDim.y * blockDim.x;
-  const int tid = blockIdx.y * blockDim.x + threadIdx.x;
+  const auto num_threads = gridDim.y * blockDim.x;
+  const auto tid = blockIdx.y * blockDim.x + threadIdx.x;

   // Iterate over each pixel to compute the contribution to the
   // gradient for the features and weights

pytorch3d/csrc/gather_scatter/gather_scatter.cu

+3 -3

@@ -20,22 +20,22 @@ __global__ void GatherScatterCudaKernel(
     const size_t V,
     const size_t D,
     const size_t E) {
-  const int tid = threadIdx.x;
+  const auto tid = threadIdx.x;

   // Reverse the vertex order if backward.
   const int v0_idx = backward ? 1 : 0;
   const int v1_idx = backward ? 0 : 1;

   // Edges are split evenly across the blocks.
-  for (int e = blockIdx.x; e < E; e += gridDim.x) {
+  for (auto e = blockIdx.x; e < E; e += gridDim.x) {
     // Get indices of vertices which form the edge.
     const int64_t v0 = edges[2 * e + v0_idx];
     const int64_t v1 = edges[2 * e + v1_idx];

     // Split vertex features evenly across threads.
     // This implementation will be quite wasteful when D<128 since there will be
     // a lot of threads doing nothing.
-    for (int d = tid; d < D; d += blockDim.x) {
+    for (auto d = tid; d < D; d += blockDim.x) {
       const float val = input[v1 * D + d];
       float* address = output + v0 * D + d;
       atomicAdd(address, val);

pytorch3d/csrc/interp_face_attrs/interp_face_attrs.cu

+4 -4

@@ -20,8 +20,8 @@ __global__ void InterpFaceAttrsForwardKernel(
     const size_t P,
     const size_t F,
     const size_t D) {
-  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
-  const int num_threads = blockDim.x * gridDim.x;
+  const auto tid = threadIdx.x + blockIdx.x * blockDim.x;
+  const auto num_threads = blockDim.x * gridDim.x;
   for (int pd = tid; pd < P * D; pd += num_threads) {
     const int p = pd / D;
     const int d = pd % D;

@@ -93,8 +93,8 @@ __global__ void InterpFaceAttrsBackwardKernel(
     const size_t P,
     const size_t F,
     const size_t D) {
-  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
-  const int num_threads = blockDim.x * gridDim.x;
+  const auto tid = threadIdx.x + blockIdx.x * blockDim.x;
+  const auto num_threads = blockDim.x * gridDim.x;
   for (int pd = tid; pd < P * D; pd += num_threads) {
     const int p = pd / D;
     const int d = pd % D;

pytorch3d/csrc/point_mesh/point_mesh_cuda.cu

+9 -9

@@ -110,7 +110,7 @@ __global__ void DistanceForwardKernel(
   __syncthreads();

   // Perform reduction in shared memory.
-  for (int s = blockDim.x / 2; s > 32; s >>= 1) {
+  for (auto s = blockDim.x / 2; s > 32; s >>= 1) {
     if (tid < s) {
       if (min_dists[tid] > min_dists[tid + s]) {
         min_dists[tid] = min_dists[tid + s];

@@ -502,8 +502,8 @@ __global__ void PointFaceArrayForwardKernel(
   const float3* tris_f3 = (float3*)tris;

   // Parallelize over P * S computations
-  const int num_threads = gridDim.x * blockDim.x;
-  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  const auto num_threads = gridDim.x * blockDim.x;
+  const auto tid = blockIdx.x * blockDim.x + threadIdx.x;

   for (int t_i = tid; t_i < P * T; t_i += num_threads) {
     const int t = t_i / P; // segment index.

@@ -576,8 +576,8 @@ __global__ void PointFaceArrayBackwardKernel(
   const float3* tris_f3 = (float3*)tris;

   // Parallelize over P * S computations
-  const int num_threads = gridDim.x * blockDim.x;
-  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  const auto num_threads = gridDim.x * blockDim.x;
+  const auto tid = blockIdx.x * blockDim.x + threadIdx.x;

   for (int t_i = tid; t_i < P * T; t_i += num_threads) {
     const int t = t_i / P; // triangle index.

@@ -683,8 +683,8 @@ __global__ void PointEdgeArrayForwardKernel(
   float3* segms_f3 = (float3*)segms;

   // Parallelize over P * S computations
-  const int num_threads = gridDim.x * blockDim.x;
-  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  const auto num_threads = gridDim.x * blockDim.x;
+  const auto tid = blockIdx.x * blockDim.x + threadIdx.x;

   for (int t_i = tid; t_i < P * S; t_i += num_threads) {
     const int s = t_i / P; // segment index.

@@ -752,8 +752,8 @@ __global__ void PointEdgeArrayBackwardKernel(
   float3* segms_f3 = (float3*)segms;

   // Parallelize over P * S computations
-  const int num_threads = gridDim.x * blockDim.x;
-  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  const auto num_threads = gridDim.x * blockDim.x;
+  const auto tid = blockIdx.x * blockDim.x + threadIdx.x;

   for (int t_i = tid; t_i < P * S; t_i += num_threads) {
     const int s = t_i / P; // segment index.

pytorch3d/csrc/rasterize_coarse/bitmask.cuh

+1 -1

@@ -25,7 +25,7 @@ class BitMask {

   // Use all threads in the current block to clear all bits of this BitMask
   __device__ void block_clear() {
-    for (int i = threadIdx.x; i < H * W * D; i += blockDim.x) {
+    for (auto i = threadIdx.x; i < H * W * D; i += blockDim.x) {
       data[i] = 0;
     }
     __syncthreads();

pytorch3d/csrc/rasterize_coarse/rasterize_coarse.cu

+7 -7

@@ -23,8 +23,8 @@ __global__ void TriangleBoundingBoxKernel(
     const float blur_radius,
     float* bboxes, // (4, F)
     bool* skip_face) { // (F,)
-  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
-  const int num_threads = blockDim.x * gridDim.x;
+  const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
+  const auto num_threads = blockDim.x * gridDim.x;
   const float sqrt_radius = sqrt(blur_radius);
   for (int f = tid; f < F; f += num_threads) {
     const float v0x = face_verts[f * 9 + 0 * 3 + 0];

@@ -56,8 +56,8 @@ __global__ void PointBoundingBoxKernel(
     const int P,
     float* bboxes, // (4, P)
     bool* skip_points) {
-  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
-  const int num_threads = blockDim.x * gridDim.x;
+  const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
+  const auto num_threads = blockDim.x * gridDim.x;
   for (int p = tid; p < P; p += num_threads) {
     const float x = points[p * 3 + 0];
     const float y = points[p * 3 + 1];

@@ -113,7 +113,7 @@ __global__ void RasterizeCoarseCudaKernel(
   const int chunks_per_batch = 1 + (E - 1) / chunk_size;
   const int num_chunks = N * chunks_per_batch;

-  for (int chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
+  for (auto chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
     const int batch_idx = chunk / chunks_per_batch; // batch index
     const int chunk_idx = chunk % chunks_per_batch;
     const int elem_chunk_start_idx = chunk_idx * chunk_size;

@@ -123,7 +123,7 @@ __global__ void RasterizeCoarseCudaKernel(
     const int64_t elem_stop_idx = elem_start_idx + elems_per_batch[batch_idx];

     // Have each thread handle a different face within the chunk
-    for (int e = threadIdx.x; e < chunk_size; e += blockDim.x) {
+    for (auto e = threadIdx.x; e < chunk_size; e += blockDim.x) {
       const int e_idx = elem_chunk_start_idx + e;

       // Check that we are still within the same element of the batch

@@ -170,7 +170,7 @@ __global__ void RasterizeCoarseCudaKernel(
     // Now we have processed every elem in the current chunk. We need to
     // count the number of elems in each bin so we can write the indices
     // out to global memory. We have each thread handle a different bin.
-    for (int byx = threadIdx.x; byx < num_bins_y * num_bins_x;
+    for (auto byx = threadIdx.x; byx < num_bins_y * num_bins_x;
          byx += blockDim.x) {
       const int by = byx / num_bins_x;
       const int bx = byx % num_bins_x;

pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu

+6 -6

@@ -260,8 +260,8 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
     float* pix_dists,
     float* bary) {
   // Simple version: One thread per output pixel
-  int num_threads = gridDim.x * blockDim.x;
-  int tid = blockDim.x * blockIdx.x + threadIdx.x;
+  auto num_threads = gridDim.x * blockDim.x;
+  auto tid = blockDim.x * blockIdx.x + threadIdx.x;

   for (int i = tid; i < N * H * W; i += num_threads) {
     // Convert linear index to 3D index

@@ -446,8 +446,8 @@ __global__ void RasterizeMeshesBackwardCudaKernel(

   // Parallelize over each pixel in images of
   // size H * W, for each image in the batch of size N.
-  const int num_threads = gridDim.x * blockDim.x;
-  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  const auto num_threads = gridDim.x * blockDim.x;
+  const auto tid = blockIdx.x * blockDim.x + threadIdx.x;

   for (int t_i = tid; t_i < N * H * W; t_i += num_threads) {
     // Convert linear index to 3D index

@@ -650,8 +650,8 @@ __global__ void RasterizeMeshesFineCudaKernel(
 ) {
   // This can be more than H * W if H or W are not divisible by bin_size.
   int num_pixels = N * BH * BW * bin_size * bin_size;
-  int num_threads = gridDim.x * blockDim.x;
-  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  auto num_threads = gridDim.x * blockDim.x;
+  auto tid = blockIdx.x * blockDim.x + threadIdx.x;

   for (int pid = tid; pid < num_pixels; pid += num_threads) {
     // Convert linear index into bin and pixel indices. We make the within

pytorch3d/csrc/rasterize_points/rasterize_points.cu

+6 -6

@@ -97,8 +97,8 @@ __global__ void RasterizePointsNaiveCudaKernel(
     float* zbuf, // (N, H, W, K)
     float* pix_dists) { // (N, H, W, K)
   // Simple version: One thread per output pixel
-  const int num_threads = gridDim.x * blockDim.x;
-  const int tid = blockDim.x * blockIdx.x + threadIdx.x;
+  const auto num_threads = gridDim.x * blockDim.x;
+  const auto tid = blockDim.x * blockIdx.x + threadIdx.x;
   for (int i = tid; i < N * H * W; i += num_threads) {
     // Convert linear index to 3D index
     const int n = i / (H * W); // Batch index

@@ -237,8 +237,8 @@ __global__ void RasterizePointsFineCudaKernel(
     float* pix_dists) { // (N, H, W, K)
   // This can be more than H * W if H or W are not divisible by bin_size.
   const int num_pixels = N * BH * BW * bin_size * bin_size;
-  const int num_threads = gridDim.x * blockDim.x;
-  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  const auto num_threads = gridDim.x * blockDim.x;
+  const auto tid = blockIdx.x * blockDim.x + threadIdx.x;

   for (int pid = tid; pid < num_pixels; pid += num_threads) {
     // Convert linear index into bin and pixel indices. We make the within

@@ -376,8 +376,8 @@ __global__ void RasterizePointsBackwardCudaKernel(
     float* grad_points) { // (P, 3)
   // Parallelized over each of K points per pixel, for each pixel in images of
   // size H * W, for each image in the batch of size N.
-  int num_threads = gridDim.x * blockDim.x;
-  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  auto num_threads = gridDim.x * blockDim.x;
+  auto tid = blockIdx.x * blockDim.x + threadIdx.x;
   for (int i = tid; i < N * H * W * K; i += num_threads) {
     // const int n = i / (H * W * K); // batch index (not needed).
     const int yxk = i % (H * W * K);
