diff --git a/cpp/src/neighbors/detail/nn_descent.cuh b/cpp/src/neighbors/detail/nn_descent.cuh index 61b8f80f10..f7b80fbf7f 100644 --- a/cpp/src/neighbors/detail/nn_descent.cuh +++ b/cpp/src/neighbors/detail/nn_descent.cuh @@ -141,6 +141,16 @@ struct dtype_traits<__half> { static __device__ __forceinline__ float to_float(__half v) { return __half2float(v); } }; +template +concept Byte = std::is_same_v or std::is_same_v; +template +struct dtype_traits { + static constexpr int APAD = 4; + static constexpr int BPAD = 4; + static constexpr int TILE_COL_WIDTH = 128; + static __device__ __forceinline__ float to_float(T v) { return static_cast(v); } +}; + template __device__ __forceinline__ ResultItem xor_swap(ResultItem x, int mask, int dir) { @@ -244,69 +254,64 @@ __device__ __forceinline__ void load_vec(Data_t* vec_buffer, } } -// TODO: Replace with RAFT utilities https://github.com/rapidsai/raft/issues/1827 -/** Calculate L2 norm, and cast data to Output_t */ -template -RAFT_KERNEL preprocess_data_kernel( - const Data_t* input_data, - Output_t* output_data, - int dim, - DistData_t* l2_norms, - size_t list_offset = 0, - cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Expanded) +/** Converting load: loads Data_t from global memory into __half shared memory buffer. */ +template + requires(!std::is_same_v) +__device__ __forceinline__ void load_vec(__half* vec_buffer, + const Data_t* d_vec, + const int load_dims, + const int padding_dims, + const int lane_id) +{ + constexpr int num_load_elems_per_warp = raft::warp_size(); + const __half half_0 = __float2half(0.0f); + for (int step = 0; step < raft::ceildiv(padding_dims, num_load_elems_per_warp); step++) { + int idx = step * num_load_elems_per_warp + lane_id; + if (idx < load_dims) { + vec_buffer[idx] = d_vec[idx]; + } else if (idx < padding_dims) { + vec_buffer[idx] = half_0; + } + } +} + +/** One warp per block. Computes squared L2 norm for each row. */ +template +RAFT_KERNEL compute_l2_norms_kernel(const Data_t* data, int dim, DistData_t* l2_norms) { extern __shared__ char buffer[]; __shared__ float l2_norm; Data_t* s_vec = (Data_t*)buffer; - size_t list_id = list_offset + blockIdx.x; + size_t list_id = blockIdx.x; + int lane_id = threadIdx.x % raft::warp_size(); - load_vec(s_vec, - input_data + static_cast(blockIdx.x) * dim, - dim, - dim, - threadIdx.x % raft::warp_size()); + load_vec(s_vec, data + static_cast(blockIdx.x) * dim, dim, dim, lane_id); if (threadIdx.x == 0) { l2_norm = 0; } __syncthreads(); - if (metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded || - metric == cuvs::distance::DistanceType::CosineExpanded) { - int lane_id = threadIdx.x % raft::warp_size(); - for (int step = 0; step < raft::ceildiv(dim, raft::warp_size()); step++) { - int idx = step * raft::warp_size() + lane_id; - float part_dist = 0; - if (idx < dim) { - part_dist = s_vec[idx]; - part_dist = part_dist * part_dist; - } - __syncwarp(); - for (int offset = raft::warp_size() >> 1; offset >= 1; offset >>= 1) { - part_dist += __shfl_down_sync(raft::warp_full_mask(), part_dist, offset); - } - if (lane_id == 0) { l2_norm += part_dist; } - __syncwarp(); - } - } - for (int step = 0; step < raft::ceildiv(dim, raft::warp_size()); step++) { - int idx = step * raft::warp_size() + threadIdx.x; + int idx = step * raft::warp_size() + lane_id; + float part_dist = 0; if (idx < dim) { - if (metric == cuvs::distance::DistanceType::InnerProduct || - metric == cuvs::distance::DistanceType::L1) { - output_data[list_id * dim + idx] = input_data[(size_t)blockIdx.x * dim + idx]; - } else if (metric == cuvs::distance::DistanceType::CosineExpanded) { - output_data[list_id * dim + idx] = - (float)input_data[(size_t)blockIdx.x * dim + idx] / sqrt(l2_norm); - } else if (metric == cuvs::distance::DistanceType::BitwiseHamming) { - int idx_for_byte = list_id * dim + idx; // uint8 or int8 data - uint8_t* output_bytes = reinterpret_cast(output_data); - output_bytes[idx_for_byte] = input_data[(size_t)blockIdx.x * dim + idx]; - } else { // L2Expanded or L2SqrtExpanded - output_data[list_id * dim + idx] = input_data[(size_t)blockIdx.x * dim + idx]; - if (idx == 0) { l2_norms[list_id] = l2_norm; } - } + part_dist = static_cast(s_vec[idx]); + part_dist = part_dist * part_dist; + } + __syncwarp(); + for (int offset = raft::warp_size() >> 1; offset >= 1; offset >>= 1) { + part_dist += __shfl_down_sync(raft::warp_full_mask(), part_dist, offset); } + if (lane_id == 0) { l2_norm += part_dist; } + __syncwarp(); } + + if (lane_id == 0) { l2_norms[list_id] = l2_norm; } +} + +template +RAFT_KERNEL convert_copy_kernel(const Src_t* src, Dst_t* dst, size_t n) +{ + size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx < n) { dst[idx] = static_cast(src[idx]); } } template @@ -527,7 +532,9 @@ __device__ __forceinline__ void calculate_metric(float* s_distances, if (metric == cuvs::distance::DistanceType::InnerProduct && can_postprocess_dist) { s_distances[i] = -s_distances[i]; } else if (metric == cuvs::distance::DistanceType::CosineExpanded) { - s_distances[i] = 1.0 - s_distances[i]; + float norm_product = l2_norms[row_neighbors[row_id]] * l2_norms[col_neighbors[col_id]]; + s_distances[i] = + (norm_product > 0.0f) ? (1.0f - s_distances[i] / sqrtf(norm_product)) : 0.0f; } else if (metric == cuvs::distance::DistanceType::BitwiseHamming) { s_distances[i] = 0.0; int n1 = row_neighbors[row_id]; @@ -573,12 +580,12 @@ struct DistAccumulator { // For architectures 750 and 860 (890), the values for MAX_RESIDENT_THREAD_PER_SM // is 1024 and 1536 respectively, which means the bounds don't work anymore // SIMT kernel: scalar element-wise distance computation. -// Used for fp32 data (all metrics) and fp16 data with L1 distance (which cannot use tensor cores). +// Used for fp32 data (all metrics) and L1 distance computation for all dtypes (which cannot use +// tensor cores). template , typename DistEpilogue_t> - requires(std::is_same_v || std::is_same_v) RAFT_KERNEL #ifdef __CUDA_ARCH__ // Use minBlocksPerMultiprocessor = 4 on specific arches @@ -689,6 +696,7 @@ __launch_bounds__(BLOCK_SIZE) if (idx < list_new_size) { size_t neighbor_id = new_neighbors[idx]; size_t idx_in_data = neighbor_id * data_dim; + // loaded to shared memory while keeping the original dtype load_vec(s_nv[idx], data + idx_in_data + step * TILE_COL_WIDTH, num_load_elems, @@ -706,6 +714,7 @@ __launch_bounds__(BLOCK_SIZE) if (tmp_row < list_new_size && tmp_col < list_new_size) { float acc = 0.0f; for (int d = 0; d < num_load_elems; d++) { + // converted to float for distance computation float a = dtype_traits::to_float(s_nv[tmp_row][d]); float b = dtype_traits::to_float(s_nv[tmp_col][d]); acc += dist_acc(a, b); @@ -844,7 +853,11 @@ __launch_bounds__(BLOCK_SIZE) // MAX_RESIDENT_THREAD_PER_SM = BLOCK_SIZE * BLOCKS_PER_SM = 2048 // For architectures 750 and 860 (890), the values for MAX_RESIDENT_THREAD_PER_SM // is 1024 and 1536 respectively, which means the bounds don't work anymore -template , typename DistEpilogue_t> +// Used for fp32 data downcast to fp16, and all types using non-L1 distance metric. +template , + typename DistEpilogue_t> RAFT_KERNEL #ifdef __CUDA_ARCH__ // Use minBlocksPerMultiprocessor = 4 on specific arches @@ -862,7 +875,7 @@ __launch_bounds__(BLOCK_SIZE) const Index_t* rev_graph_old, const int2* sizes_old, const int width, - const __half* data, + const Data_t* data, const int data_dim, ID_t* graph, DistData_t* dists, @@ -958,6 +971,7 @@ __launch_bounds__(BLOCK_SIZE) if (idx < list_new_size) { size_t neighbor_id = new_neighbors[idx]; size_t idx_in_data = neighbor_id * data_dim; + // converted to fp16 on-the-fly while loading load_vec(s_nv[idx], data + idx_in_data + step * TILE_COL_WIDTH, num_load_elems, @@ -1352,23 +1366,6 @@ GNND::GNND(raft::resources const& res, const BuildConfig& build { static_assert(NUM_SAMPLES <= 32); - using input_t = typename std::remove_const::type; - if (std::is_same_v && - (build_config.dist_comp_dtype == cuvs::neighbors::nn_descent::DIST_COMP_DTYPE::FP32 || - (build_config.dist_comp_dtype == cuvs::neighbors::nn_descent::DIST_COMP_DTYPE::AUTO && - build_config.dataset_dim <= 16))) { - // use fp32 distance computation for better precision with smaller dimension - d_data_float_.emplace( - raft::make_device_matrix(res, nrow_, ndim_)); - } else { - d_data_half_.emplace(raft::make_device_matrix( - res, - nrow_, - build_config.metric == cuvs::distance::DistanceType::BitwiseHamming - ? (build_config.dataset_dim + 1) / 2 - : build_config.dataset_dim)); - } - raft::matrix::fill(res, dists_buffer_.view(), std::numeric_limits::max()); auto graph_buffer_view = raft::make_device_matrix_view( reinterpret_cast(graph_buffer_.data_handle()), nrow_, DEGREE_ON_DEVICE); @@ -1376,7 +1373,11 @@ GNND::GNND(raft::resources const& res, const BuildConfig& build raft::matrix::fill(res, d_locks_.view(), 0); if (build_config.metric == cuvs::distance::DistanceType::L2Expanded || - build_config.metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + build_config.metric == cuvs::distance::DistanceType::L2SqrtExpanded || + build_config.metric == cuvs::distance::DistanceType::CosineExpanded) { + // for device memory efficiency, we do not allocate a separate array for the data + // to normalize the data when using CosineExpanded metric. Instead, we use the l2_norms_ vector + // and compute inside the calculate_metric kernel. l2_norms_ = raft::make_device_vector(res, nrow_); } }; @@ -1415,61 +1416,68 @@ void GNND::local_join(cudaStream_t stream, DistEpilogue_t dist_ { raft::matrix::fill(res, dists_buffer_.view(), std::numeric_limits::max()); - // Kernel dispatch logic: - // fp32 data -> SIMT (metric resolved at runtime inside the kernel) - // fp16 data + L1 distance -> SIMT (L1 needs element-wise ops, cannot use tensor cores) - // fp16 data + other metrics -> WMMA (tensor-core accelerated dot product) - if (d_data_float_.has_value()) { - local_join_kernel_simt<<>>(graph_.h_graph_new.data_handle(), - h_rev_graph_new_.data_handle(), - d_list_sizes_new_.data_handle(), - h_graph_old_.data_handle(), - h_rev_graph_old_.data_handle(), - d_list_sizes_old_.data_handle(), - NUM_SAMPLES, - d_data_float_->data_handle(), - ndim_, - graph_buffer_.data_handle(), - dists_buffer_.data_handle(), - DEGREE_ON_DEVICE, - d_locks_.data_handle(), - l2_norms_.data_handle(), - build_config_.metric, - dist_epilogue); - } else if (build_config_.metric == cuvs::distance::DistanceType::L1) { - local_join_kernel_simt<<>>(graph_.h_graph_new.data_handle(), - h_rev_graph_new_.data_handle(), - d_list_sizes_new_.data_handle(), - h_graph_old_.data_handle(), - h_rev_graph_old_.data_handle(), - d_list_sizes_old_.data_handle(), - NUM_SAMPLES, - d_data_half_.value().data_handle(), - ndim_, - graph_buffer_.data_handle(), - dists_buffer_.data_handle(), - DEGREE_ON_DEVICE, - d_locks_.data_handle(), - l2_norms_.data_handle(), - build_config_.metric, - dist_epilogue); + // Kernel dispatch logic, based on the effective distance-computation dtype (which depends on + // the input dtype and dist_comp_dtype): + // fp32 dist (only fp32 input, dist_comp_dtype == FP32 or AUTO with dim <= 16) -> SIMT: scalar + // element-wise distance computation in fp32. + // fp16 dist (everything else: fp16/int8/uint8 input, or fp32 input with dist_comp_dtype == + // FP16 or AUTO with dim > 16) -> WMMA (tensor-core accelerated dot product). Non-fp16 + // dtypes are converted to fp16 on-the-fly while loading into shared memory; for fp32 host + // input this conversion happens earlier at copy-in time (see d_data_half_). + // L1 distance for any input -> SIMT (L1 needs element-wise ops, can't use tensor cores). + using DCT = cuvs::neighbors::nn_descent::DIST_COMP_DTYPE; + bool use_fp16_dist = + std::is_same_v && (build_config_.dist_comp_dtype == DCT::FP16 || + (build_config_.dist_comp_dtype == DCT::AUTO && ndim_ > 16)); + bool use_simt = (std::is_same_v && !use_fp16_dist) || + build_config_.metric == cuvs::distance::DistanceType::L1; + + auto launch_kernel = [&](auto* typed_ptr) { + if (use_simt) { + local_join_kernel_simt<<>>(graph_.h_graph_new.data_handle(), + h_rev_graph_new_.data_handle(), + d_list_sizes_new_.data_handle(), + h_graph_old_.data_handle(), + h_rev_graph_old_.data_handle(), + d_list_sizes_old_.data_handle(), + NUM_SAMPLES, + typed_ptr, + ndim_, + graph_buffer_.data_handle(), + dists_buffer_.data_handle(), + DEGREE_ON_DEVICE, + d_locks_.data_handle(), + l2_norms_.data_handle(), + build_config_.metric, + dist_epilogue); + } else { + local_join_kernel_wmma<<>>(graph_.h_graph_new.data_handle(), + h_rev_graph_new_.data_handle(), + d_list_sizes_new_.data_handle(), + h_graph_old_.data_handle(), + h_rev_graph_old_.data_handle(), + d_list_sizes_old_.data_handle(), + NUM_SAMPLES, + typed_ptr, + ndim_, + graph_buffer_.data_handle(), + dists_buffer_.data_handle(), + DEGREE_ON_DEVICE, + d_locks_.data_handle(), + l2_norms_.data_handle(), + build_config_.metric, + dist_epilogue); + } + RAFT_CUDA_TRY(cudaPeekAtLastError()); + }; + + if (d_data_half_.has_value()) { + // Host fp32 input was downcast to a device-side fp16 buffer because distances are computed in + // fp16 (dist_comp_dtype == FP16, or AUTO with dim > 16). + launch_kernel(static_cast(d_data_ptr_)); } else { - local_join_kernel_wmma<<>>(graph_.h_graph_new.data_handle(), - h_rev_graph_new_.data_handle(), - d_list_sizes_new_.data_handle(), - h_graph_old_.data_handle(), - h_rev_graph_old_.data_handle(), - d_list_sizes_old_.data_handle(), - NUM_SAMPLES, - d_data_half_.value().data_handle(), - ndim_, - graph_buffer_.data_handle(), - dists_buffer_.data_handle(), - DEGREE_ON_DEVICE, - d_locks_.data_handle(), - l2_norms_.data_handle(), - build_config_.metric, - dist_epilogue); + // Data stored as input_t: device data used directly, or host data copied as-is. + launch_kernel(static_cast(d_data_ptr_)); } } @@ -1497,51 +1505,94 @@ void GNND::build(Data_t* data, update_counter_ = 0; graph_.h_graph = (InternalID_t*)output_graph; - if (d_data_float_.has_value()) { - raft::matrix::fill(res, d_data_float_.value().view(), static_cast(0)); - } else { - raft::matrix::fill(res, d_data_half_.value().view(), static_cast(0)); - } + d_data_ptr_ = nullptr; cudaPointerAttributes data_ptr_attr; RAFT_CUDA_TRY(cudaPointerGetAttributes(&data_ptr_attr, data)); - size_t batch_size = (data_ptr_attr.devicePointer == nullptr) ? 100000 : nrow_; - - auto vec_batches = cuvs::spatial::knn::detail::utils::make_batch_load_iterator( - res, - data, - static_cast(nrow_), - static_cast(build_config_.dataset_dim), - batch_size, - stream); - for (auto const& batch : vec_batches) { - if (d_data_float_.has_value()) { - preprocess_data_kernel<<(raft::warp_size())) * - raft::warp_size(), - stream>>>(batch.data(), - d_data_float_.value().data_handle(), - build_config_.dataset_dim, - l2_norms_.data_handle(), - batch.offset(), - build_config_.metric); - } else { - preprocess_data_kernel<<(raft::warp_size())) * - raft::warp_size(), - stream>>>(batch.data(), - d_data_half_.value().data_handle(), - build_config_.dataset_dim, - l2_norms_.data_handle(), - batch.offset(), - build_config_.metric); + bool data_on_device = (data_ptr_attr.type == cudaMemoryTypeDevice); + + bool needs_l2_norms = build_config_.metric == cuvs::distance::DistanceType::L2Expanded || + build_config_.metric == cuvs::distance::DistanceType::L2SqrtExpanded || + build_config_.metric == cuvs::distance::DistanceType::CosineExpanded; + + // For fp32 host input, downcast to a device-side fp16 buffer when distance computation will be + // done in fp16 anyway: dispatch matches the SIMT/WMMA decision in local_join() (FP16 explicit, or + // AUTO with dim > 16). + using DCT = cuvs::neighbors::nn_descent::DIST_COMP_DTYPE; + bool fp32_input_uses_fp16_dist = + std::is_same_v && + (build_config_.dist_comp_dtype == DCT::FP16 || + (build_config_.dist_comp_dtype == DCT::AUTO && build_config_.dataset_dim > 16)); + bool downcast_host_data = !data_on_device && fp32_input_uses_fp16_dist; + + if (data_on_device) { + // When user-given data is on device, we use it directly. This can be any type (fp32, fp16, + // int8, uint8) + d_data_ptr_ = data; + } else if (downcast_host_data) { + // When user-given data is fp32 host data and distances will be computed in fp16, we allocate + // an fp16 device buffer and downcast at copy-in time. Storing the dataset on device in fp16 + // (instead of fp32) for this path halves both the device memory footprint and the per- + // iteration read bandwidth of the WMMA kernel. + if (!d_data_half_.has_value()) { + d_data_half_.emplace(raft::make_device_matrix( + res, build_config_.max_dataset_size, build_config_.dataset_dim)); + } + size_t batch_size = 100000; + auto vec_batches = cuvs::spatial::knn::detail::utils::make_batch_load_iterator( + res, + data, + static_cast(nrow_), + static_cast(build_config_.dataset_dim), + batch_size, + stream); + constexpr int TPB = 256; + for (auto const& batch : vec_batches) { + size_t n_elems = batch.size() * build_config_.dataset_dim; + int num_blocks = raft::ceildiv(n_elems, static_cast(TPB)); + size_t dst_offset = batch.offset() * build_config_.dataset_dim; + if (needs_l2_norms) { + // Compute l2 norms on the fp32 batches before they're downcast to fp16. + compute_l2_norms_kernel<<(raft::warp_size())) * + raft::warp_size(), + stream>>>( + batch.data(), build_config_.dataset_dim, l2_norms_.data_handle() + batch.offset()); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } + convert_copy_kernel<<>>( + batch.data(), d_data_half_.value().data_handle() + dst_offset, n_elems); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } + d_data_ptr_ = d_data_half_.value().data_handle(); + } else { + // Other cases: user-given data is not device-accessible, but we don't need a precision + // conversion. Allocate a device buffer in input_t and copy as-is. + if (!d_data_direct_.has_value()) { + d_data_direct_.emplace(raft::make_device_matrix( + res, build_config_.max_dataset_size, build_config_.dataset_dim)); } + raft::copy(d_data_direct_.value().data_handle(), + data, + static_cast(nrow_) * build_config_.dataset_dim, + stream); + d_data_ptr_ = d_data_direct_.value().data_handle(); + } + + if (needs_l2_norms && !downcast_host_data) { + compute_l2_norms_kernel<<(raft::warp_size())) * + raft::warp_size(), + stream>>>( + static_cast(d_data_ptr_), build_config_.dataset_dim, l2_norms_.data_handle()); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + raft::resource::sync_stream(res); } graph_.clear(); @@ -1592,7 +1643,7 @@ void GNND::build(Data_t* data, // __CUDA_ARCH__ >= 700. Since RAFT supports compilation for ARCH 600, // we need to ensure that `local_join_kernel` (which uses tensor) operations // is not only not compiled, but also a runtime error is presented to the user - auto kernel = preprocess_data_kernel; + auto kernel = compute_l2_norms_kernel; void* kernel_ptr = reinterpret_cast(kernel); auto runtime_arch = raft::util::arch::kernel_virtual_arch(kernel_ptr); auto wmma_range = diff --git a/cpp/src/neighbors/detail/nn_descent_gnnd.hpp b/cpp/src/neighbors/detail/nn_descent_gnnd.hpp index a2639e4f43..7a242afe44 100644 --- a/cpp/src/neighbors/detail/nn_descent_gnnd.hpp +++ b/cpp/src/neighbors/detail/nn_descent_gnnd.hpp @@ -228,8 +228,18 @@ class GNND { size_t nrow_; size_t ndim_; - std::optional> d_data_float_; + using input_t = std::remove_const_t; + + // d_data_half_ is used for a special case when input data is fp32 on host and distances will be + // computed in fp16 (dist_comp_dtype == FP16, or AUTO with dim > 16): we store the dataset on + // device as fp16 (instead of fp32) to halve the device memory footprint and WMMA kernel read + // bandwidth. std::optional> d_data_half_; + // d_data_direct_ is used when input data is on host, and we need to copy it to device + std::optional> d_data_direct_; + + // d_data_ptr_ is used to store the general pointer to the input data + const void* d_data_ptr_{nullptr}; raft::device_vector l2_norms_; raft::device_matrix graph_buffer_;