Skip to content

Commit b2b161e

Browse files
authored
[ROCm] Adjust grid size for non-unit stride backwards indexing (#2710)
cherry-pick of pytorch@01a2812
1 parent fe1f5d7 commit b2b161e

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

aten/src/ATen/native/cuda/Indexing.cu

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -710,6 +710,9 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
710 710   dim3 block(warp_size, indices_per_block);
711 711
712 712   #ifdef USE_ROCM
    713 + dim3 new_grid_many_indices(ceil_div(num_indices, (int64_t) (indices_per_block * warp_size)),
    714 + grid.y == 1 ? std::min<int>(at::cuda::getCurrentDeviceProperties()->maxGridSize[1], ceil_div(sliceSize, (int64_t) (warp_size))) : grid.y,
    715 + grid.z);
713 716   dim3 new_grid(ceil_div(num_indices, (int64_t) (indices_per_block * warp_size)), grid.y, grid.z);
714 717   size_t smem_dups_size = indices_per_block * warp_size * sizeof(int64_t);
715 718   #define KERNEL_GRID new_grid
@@ -788,7 +791,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
788 791   expandedValue.scalar_type(),
789 792   "indexing_backward_many_indices",
790 793   AT_WRAP([&] {
791     - indexing_backward_kernel_many_indices<scalar_t, UNROLL><<<new_grid, block, smem_dups_size, stream>>>(
    794 + indexing_backward_kernel_many_indices<scalar_t, UNROLL><<<new_grid_many_indices, block, smem_dups_size, stream>>>(
792 795   sorted_indices.const_data_ptr<int64_t>(),
793 796   orig_indices.const_data_ptr<int64_t>(),
794 797   expandedValue.const_data_ptr<scalar_t>(),

0 commit comments

Comments
 (0)