diff --git a/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate_byte.cu b/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate_byte.cu
index 2f2b1e9eae..286c26aa34 100644
--- a/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate_byte.cu
+++ b/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate_byte.cu
@@ -94,7 +94,7 @@ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_byte_kernel(
     const uint32_t sorted_slot = slots[0];
     const int64_t sorted_lfu_cost = costs[0];
 
-    for (int32_t l = 0; l < min(SL, kWarpSize); ++l) {
+    for (auto l = 0; l < min(SL, kWarpSize); ++l) {
       const int32_t insert_slot = shfl_sync(sorted_slot, l);
       const int64_t insert_current_lfu_cost = shfl_sync(sorted_lfu_cost, l);
       const index_t insert_idx = cache_set_sorted_indices[n + l];
diff --git a/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate_byte.cu b/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate_byte.cu
index e52af82bba..000c105d11 100644
--- a/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate_byte.cu
+++ b/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate_byte.cu
@@ -115,28 +115,24 @@ Tensor direct_mapped_lru_cache_find_uncached_cuda(
       linear_cache_indices.scalar_type(),
       "direct_mapped_lru_cache_find_uncached_cuda",
       [&] {
-#ifdef FBGEMM_GPU_MEMCHECK
-        const char* func_name = "direct_mapped_lru_cache_find_uncached_kernel";
-#endif
         // Find uncached indices
-        direct_mapped_lru_cache_find_uncached_kernel<<<
+        FBGEMM_LAUNCH_KERNEL(
+            (direct_mapped_lru_cache_find_uncached_kernel<index_t>),
             std::min(
                 div_round_up(N, kMaxThreads),
                 get_max_thread_blocks_for_cache_kernels_()),
             kMaxThreads,
             0,
-            at::cuda::getCurrentCUDAStream()>>>(
-            MAKE_PTA_WITH_NAME(func_name, linear_cache_indices, index_t, 1, 32),
-            MAKE_PTA_WITH_NAME(func_name, cache_sets, int32_t, 1, 32),
+            at::cuda::getCurrentCUDAStream(),
+            PTA_B(linear_cache_indices, index_t, 1, 32),
+            PTA_B(cache_sets, int32_t, 1, 32),
             max_indices,
-            MAKE_PTA_WITH_NAME(func_name, lxu_cache_state, int64_t, 2, 32),
+            PTA_B(lxu_cache_state, int64_t, 2, 32),
             time_stamp,
-            MAKE_PTA_WITH_NAME(func_name, lru_state, int64_t, 2, 32),
+            PTA_B(lru_state, int64_t, 2, 32),
             gather_cache_stats,
-            MAKE_PTA_WITH_NAME(func_name, uvm_cache_stats, int32_t, 1, 32),
-            MAKE_PTA_WITH_NAME(
-                func_name, lxu_cache_miss_timestamp, int64_t, 2, 32));
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
+            PTA_B(uvm_cache_stats, int32_t, 1, 32),
+            PTA_B(lxu_cache_miss_timestamp, int64_t, 2, 32));
       });
 
   return cache_sets;
@@ -172,7 +168,7 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_byte_kernel(
     const int64_t row_alignment) {
   const int32_t C = lxu_cache_state.size(0);
   int64_t n_conflict_misses = 0;
-  for (int32_t n = blockIdx.x * blockDim.y + threadIdx.y; n < *N_unique;
+  for (auto n = blockIdx.x * blockDim.y + threadIdx.y; n < *N_unique;
        n += gridDim.x * blockDim.y) {
     // check if this warp is responsible for this whole segment.
     const bool segment_start =
@@ -197,16 +193,16 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_byte_kernel(
 
     // now, we need to insert the (unique!) values in indices[n:n + SL] into
    // our slots.
-    const int32_t slot = threadIdx.x;
+    const auto slot = threadIdx.x;
     const int64_t slot_time = lru_state[cache_set][slot];
     int64_t costs[1] = {slot_time};
-    int32_t slots[1] = {slot};
+    uint32_t slots[1] = {slot};
 
-    BitonicSort<int64_t, int32_t, 1, Comparator<int64_t>>::sort(costs, slots);
-    const int32_t sorted_slot = slots[0];
-    const int64_t sorted_lru_cost = costs[0];
+    BitonicSort<int64_t, uint32_t, 1, Comparator<int64_t>>::sort(costs, slots);
+    const auto sorted_slot = slots[0];
+    const auto sorted_lru_cost = costs[0];
 
-    for (int32_t l = 0; l < min(SL, kWarpSize); ++l) {
+    for (auto l = 0; l < min(SL, kWarpSize); ++l) {
       const int32_t insert_slot = shfl_sync(sorted_slot, l);
       const int64_t insert_current_lru_cost = shfl_sync(sorted_lru_cost, l);
       if (insert_current_lru_cost == time_stamp) {
@@ -232,7 +228,7 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_byte_kernel(
             &weights[weights_offset_insert + idx_insert * D_insert_bytes + 0]);
         auto cache_row = reinterpret_cast<uint4*>(
             &lxu_cache_weights[cache_set * kWarpSize + insert_slot][0]);
-        for (int32_t d = threadIdx.x; d * sizeof(uint4) < D_insert_bytes;
+        for (auto d = threadIdx.x; d * sizeof(uint4) < D_insert_bytes;
              d += blockDim.x) {
           cache_row[d] = row[d];
         }
@@ -285,7 +281,7 @@ __launch_bounds__(kMaxThreads) void direct_mapped_lru_cache_insert_byte_kernel(
 
   // one warp for each set (multiple times)
   // (no divergence for each control branch)
-  for (int32_t pos = blockIdx.x * blockDim.y + threadIdx.y; pos < N;
+  for (auto pos = blockIdx.x * blockDim.y + threadIdx.y; pos < N;
       pos += gridDim.x * blockDim.y) {
     auto cache_set = cache_sets[pos];
 
@@ -346,7 +342,7 @@ __launch_bounds__(kMaxThreads) void direct_mapped_lru_cache_insert_byte_kernel(
     auto row = reinterpret_cast<const uint4*>(
         &weights[weights_offset_insert + idx_insert * D_insert_bytes + 0]);
     auto cache_row = reinterpret_cast<uint4*>(&lxu_cache_weights[cache_set][0]);
-    for (int32_t d = threadIdx.x; d * sizeof(uint4) < D_insert_bytes;
+    for (auto d = threadIdx.x; d * sizeof(uint4) < D_insert_bytes;
          d += blockDim.x) {
       cache_row[d] = row[d];
     }
@@ -398,36 +394,30 @@ void lru_cache_insert_byte_cuda(
       cache_set_sorted_unique_indices.scalar_type(),
       "lru_cache_insert_byte_cuda",
       [&] {
-#ifdef FBGEMM_GPU_MEMCHECK
-        const char* func_name = "lru_cache_insert_byte_kernel";
-#endif
-        lru_cache_insert_byte_kernel<<<
+        FBGEMM_LAUNCH_KERNEL(
+            (lru_cache_insert_byte_kernel<index_t>),
             std::min(
                 div_round_up(N, kMaxThreads / kWarpSize),
                 get_max_thread_blocks_for_cache_kernels_()),
             dim3(kWarpSize, kMaxThreads / kWarpSize),
             0,
-            at::cuda::getCurrentCUDAStream()>>>(
-            MAKE_PTA_WITH_NAME(func_name, weights, uint8_t, 1, 64),
-            MAKE_PTA_WITH_NAME(
-                func_name, cache_hash_size_cumsum, int64_t, 1, 32),
-            MAKE_PTA_WITH_NAME(
-                func_name, cache_index_table_map, int32_t, 1, 64),
-            MAKE_PTA_WITH_NAME(func_name, weights_offsets, int64_t, 1, 32),
-            MAKE_PTA_WITH_NAME(func_name, weights_tys, uint8_t, 1, 32),
-            MAKE_PTA_WITH_NAME(func_name, D_offsets, int32_t, 1, 32),
-            MAKE_PTA_WITH_NAME(func_name, sorted_cache_sets, int32_t, 1, 32),
-            MAKE_PTA_WITH_NAME(
-                func_name, cache_set_sorted_unique_indices, index_t, 1, 32),
+            at::cuda::getCurrentCUDAStream(),
+            PTA_B(weights, uint8_t, 1, 64),
+            PTA_B(cache_hash_size_cumsum, int64_t, 1, 32),
+            PTA_B(cache_index_table_map, int32_t, 1, 64),
+            PTA_B(weights_offsets, int64_t, 1, 32),
+            PTA_B(weights_tys, uint8_t, 1, 32),
+            PTA_B(D_offsets, int32_t, 1, 32),
+            PTA_B(sorted_cache_sets, int32_t, 1, 32),
+            PTA_B(cache_set_sorted_unique_indices, index_t, 1, 32),
             unique_indices_length.data_ptr<int32_t>(),
-            MAKE_PTA_WITH_NAME(func_name, lxu_cache_state, int64_t, 2, 32),
-            MAKE_PTA_WITH_NAME(func_name, lxu_cache_weights, uint8_t, 2, 64),
+            PTA_B(lxu_cache_state, int64_t, 2, 32),
+            PTA_B(lxu_cache_weights, uint8_t, 2, 64),
             time_stamp,
-            MAKE_PTA_WITH_NAME(func_name, lru_state, int64_t, 2, 32),
+            PTA_B(lru_state, int64_t, 2, 32),
             gather_cache_stats,
-            MAKE_PTA_WITH_NAME(func_name, uvm_cache_stats, int32_t, 1, 32),
+            PTA_B(uvm_cache_stats, int32_t, 1, 32),
             row_alignment);
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
       });
 }
 
@@ -469,36 +459,30 @@ void direct_mapped_lru_cache_insert_byte_cuda(
       linear_cache_indices.scalar_type(),
       "direct_mapped_lru_cache_insert_byte_cuda",
       [&] {
-#ifdef FBGEMM_GPU_MEMCHECK
-        const char* func_name = "direct_mapped_lru_cache_insert_byte_kernel";
-#endif
-        direct_mapped_lru_cache_insert_byte_kernel<<<
+        FBGEMM_LAUNCH_KERNEL(
+            (direct_mapped_lru_cache_insert_byte_kernel<index_t>),
             std::min(
                 div_round_up(N, kMaxThreads / kWarpSize),
                 get_max_thread_blocks_for_cache_kernels_()),
             dim3(kWarpSize, kMaxThreads / kWarpSize),
             0,
-            at::cuda::getCurrentCUDAStream()>>>(
-            MAKE_PTA_WITH_NAME(func_name, weights, uint8_t, 1, 64),
-            MAKE_PTA_WITH_NAME(
-                func_name, cache_hash_size_cumsum, int64_t, 1, 32),
-            MAKE_PTA_WITH_NAME(
-                func_name, cache_index_table_map, int32_t, 1, 64),
-            MAKE_PTA_WITH_NAME(func_name, weights_offsets, int64_t, 1, 32),
-            MAKE_PTA_WITH_NAME(func_name, weights_tys, uint8_t, 1, 32),
-            MAKE_PTA_WITH_NAME(func_name, D_offsets, int32_t, 1, 32),
-            MAKE_PTA_WITH_NAME(func_name, lxu_cache_state, int64_t, 2, 32),
-            MAKE_PTA_WITH_NAME(func_name, lxu_cache_weights, uint8_t, 2, 64),
+            at::cuda::getCurrentCUDAStream(),
+            PTA_B(weights, uint8_t, 1, 64),
+            PTA_B(cache_hash_size_cumsum, int64_t, 1, 32),
+            PTA_B(cache_index_table_map, int32_t, 1, 64),
+            PTA_B(weights_offsets, int64_t, 1, 32),
+            PTA_B(weights_tys, uint8_t, 1, 32),
+            PTA_B(D_offsets, int32_t, 1, 32),
+            PTA_B(lxu_cache_state, int64_t, 2, 32),
+            PTA_B(lxu_cache_weights, uint8_t, 2, 64),
             time_stamp,
-            MAKE_PTA_WITH_NAME(func_name, lru_state, int64_t, 2, 32),
-            MAKE_PTA_WITH_NAME(func_name, linear_cache_indices, index_t, 1, 32),
-            MAKE_PTA_WITH_NAME(
-                func_name, lxu_cache_miss_timestamp, int64_t, 2, 32),
-            MAKE_PTA_WITH_NAME(func_name, cache_sets, int32_t, 1, 32),
+            PTA_B(lru_state, int64_t, 2, 32),
+            PTA_B(linear_cache_indices, index_t, 1, 32),
+            PTA_B(lxu_cache_miss_timestamp, int64_t, 2, 32),
+            PTA_B(cache_sets, int32_t, 1, 32),
             gather_cache_stats,
-            MAKE_PTA_WITH_NAME(func_name, uvm_cache_stats, int32_t, 1, 32),
+            PTA_B(uvm_cache_stats, int32_t, 1, 32),
             row_alignment);
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
       });
 }
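
Reviewer note (not part of the patch): the sketch below is a minimal before/after illustration of the launch-macro migration the hunks above apply. `example_kernel`, `blocks`, `threads`, `src`, and `dst` are hypothetical names; the sketch assumes, as the hunks imply, that FBGEMM_LAUNCH_KERNEL performs the post-launch error check previously done by C10_CUDA_KERNEL_LAUNCH_CHECK() and that PTA_B builds the same packed tensor accessor as MAKE_PTA_WITH_NAME without the hand-maintained func_name string.

// Before: raw triple-chevron launch, a manually maintained func_name for
// FBGEMM_GPU_MEMCHECK, and an explicit post-launch check.
#ifdef FBGEMM_GPU_MEMCHECK
const char* func_name = "example_kernel";
#endif
example_kernel<<<
    blocks,
    threads,
    0,
    at::cuda::getCurrentCUDAStream()>>>(
    MAKE_PTA_WITH_NAME(func_name, src, index_t, 1, 32),
    MAKE_PTA_WITH_NAME(func_name, dst, int32_t, 1, 32));
C10_CUDA_KERNEL_LAUNCH_CHECK();

// After: FBGEMM_LAUNCH_KERNEL takes the fully instantiated kernel (wrapped in
// parentheses so template commas survive macro expansion), the same
// grid/block/shared-memory/stream arguments, and then the kernel arguments
// built with PTA_B; the separate func_name and launch check go away.
FBGEMM_LAUNCH_KERNEL(
    (example_kernel<index_t>),
    blocks,
    threads,
    0,
    at::cuda::getCurrentCUDAStream(),
    PTA_B(src, index_t, 1, 32),
    PTA_B(dst, int32_t, 1, 32));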