@@ -710,6 +710,9 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
710710 dim3 block (warp_size, indices_per_block);
711711
712712#ifdef USE_ROCM
713+ dim3 new_grid_many_indices (ceil_div (num_indices, (int64_t ) (indices_per_block * warp_size)),
714+ grid.y == 1 ? std::min<int >(at::cuda::getCurrentDeviceProperties ()->maxGridSize [1 ], ceil_div (sliceSize, (int64_t ) (warp_size))) : grid.y ,
715+ grid.z );
713716 dim3 new_grid (ceil_div (num_indices, (int64_t ) (indices_per_block * warp_size)), grid.y , grid.z );
714717 size_t smem_dups_size = indices_per_block * warp_size * sizeof (int64_t );
715718#define KERNEL_GRID new_grid
@@ -788,7 +791,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
788791 expandedValue.scalar_type (),
789792 " indexing_backward_many_indices" ,
790793 AT_WRAP ([&] {
791- indexing_backward_kernel_many_indices<scalar_t , UNROLL><<<new_grid , block, smem_dups_size, stream>>> (
794+ indexing_backward_kernel_many_indices<scalar_t , UNROLL><<<new_grid_many_indices , block, smem_dups_size, stream>>> (
792795 sorted_indices.const_data_ptr <int64_t >(),
793796 orig_indices.const_data_ptr <int64_t >(),
794797 expandedValue.const_data_ptr <scalar_t >(),
0 commit comments