@@ -859,7 +859,7 @@ __launch_bounds__(BLOCK_SIZE)
859859// MAX_RESIDENT_THREAD_PER_SM = BLOCK_SIZE * BLOCKS_PER_SM = 2048
860860// For architectures 750 and 860 (890), the values for MAX_RESIDENT_THREAD_PER_SM
861861// is 1024 and 1536 respectively, which means the bounds don't work anymore
862- // Used for fp32 data compressed to fp16, and all types using non-L1 distance metric.
862+ // Used for fp32 data downcast to fp16, and all types using non-L1 distance metric.
863863template <typename Data_t,
864864 typename Index_t,
865865 typename ID_t = InternalID_t<Index_t>,
@@ -1373,11 +1373,11 @@ GNND<Data_t, Index_t>::GNND(raft::resources const& res, const BuildConfig& build
13731373 static_assert (NUM_SAMPLES <= 32 );
13741374
13751375 using input_t = typename std::remove_const<Data_t>::type;
1376- if (build_config.use_fp16_dist_comp && build_config.dataset_dim <= 16 &&
1376+ if (build_config.internal_distance_dtype == CUDA_R_16F && build_config.dataset_dim <= 16 &&
13771377 std::is_same_v<input_t , float >) {
13781378 RAFT_LOG_WARN (
13791379 " Using fp16 for distance computation for data in fp32 with small dimensions (%zu) <= 16 may "
1380- " result in low quality results. Consider setting use_fp16_dist_comp = false ." ,
1380+ " result in low quality results. Consider setting internal_distance_dtype = CUDA_R_32F ." ,
13811381 build_config.dataset_dim );
13821382 }
13831383
@@ -1431,14 +1431,17 @@ void GNND<Data_t, Index_t>::local_join(cudaStream_t stream, DistEpilogue_t dist_
14311431{
14321432 raft::matrix::fill (res, dists_buffer_.view (), std::numeric_limits<float >::max ());
14331433 // Kernel dispatch logic:
1434- // fp32 data can have an effective type of fp32 OR fp16 (when use_fp16_dist_comp flag = True for
1435- // wmma usage) Based on EFFECTIVE dtype:
1434+ // fp32 data can have an effective type of fp32 OR fp16 (when internal_distance_dtype is
1435+ // CUDA_R_16F, fp32 host data is downcast into a device-side fp16 buffer at copy-in time so the
1436+ // WMMA kernel reads it in fp16). Based on EFFECTIVE dtype:
14361437 // fp32 data || L1 distance -> SIMT: internally converted to fp32 for distance computation
1437- // on-the-fly dypte <= fp16 && non-L1 metrics -> WMMA (tensor-core accelerated dot product):
1438- // internally converted to fp16 for distance computation on-the-fly
1438+ // on-the-fly
1439+ // dtype <= fp16 && non-L1 metrics -> WMMA (tensor-core accelerated dot product):
1440+ // internally converted to fp16 for distance computation on-the-fly
14391441
1440- bool use_simt = (std::is_same_v<input_t , float > && !build_config_.use_fp16_dist_comp ) ||
1441- build_config_.metric == cuvs::distance::DistanceType::L1 ;
1442+ bool use_simt =
1443+ (std::is_same_v<input_t , float > && build_config_.internal_distance_dtype != CUDA_R_16F ) ||
1444+ build_config_.metric == cuvs::distance::DistanceType::L1 ;
14421445
14431446 auto launch_kernel = [&](auto * typed_ptr) {
14441447 if (use_simt) {
@@ -1479,7 +1482,8 @@ void GNND<Data_t, Index_t>::local_join(cudaStream_t stream, DistEpilogue_t dist_
14791482 };
14801483
14811484 if (d_data_half_.has_value ()) {
1482- // Host fp32 input compressed to fp16 via use_fp16_dist_comp.
1485+ // Host fp32 input was downcast to a device-side fp16 buffer via internal_distance_dtype =
1486+ // CUDA_R_16F.
14831487 launch_kernel (static_cast <const half*>(d_data_ptr_));
14841488 } else {
14851489 // Data stored as input_t: device data used directly, or host data copied as-is.
@@ -1521,17 +1525,18 @@ void GNND<Data_t, Index_t>::build(Data_t* data,
15211525 build_config_.metric == cuvs::distance::DistanceType::L2SqrtExpanded ||
15221526 build_config_.metric == cuvs::distance::DistanceType::CosineExpanded;
15231527
1524- bool compress_host_data =
1525- !data_on_device && std::is_same_v< input_t , float > && build_config_.use_fp16_dist_comp ;
1528+ bool downcast_host_data = !data_on_device && std::is_same_v< input_t , float > &&
1529+ build_config_.internal_distance_dtype == CUDA_R_16F ;
15261530
15271531 if (data_on_device) {
15281532 // When user-given data is on device, we use it directly. This can be any type (fp32, fp16,
15291533 // int8, uint8)
15301534 d_data_ptr_ = data;
1531- } else if (compress_host_data) {
1532- // When user-given data is fp32 host data, and use_fp16_dist_comp is true, we allocate fp16
1533- // buffer to copy the data. This allows the wmma kernel to be used for distance computation
1534- // instead of simt kernel.
1535+ } else if (downcast_host_data) {
1536+ // When user-given data is fp32 host data, and internal_distance_dtype is CUDA_R_16F, we
1537+ // allocate an fp16 device buffer and downcast at copy-in time. Storing the dataset on device
1538+ // in fp16 (instead of fp32) for this path halves both the device memory footprint and the
1539+ // per-iteration read bandwidth of the WMMA kernel.
15351540 if (!d_data_half_.has_value ()) {
15361541 d_data_half_.emplace (raft::make_device_matrix<half, size_t , raft::row_major>(
15371542 res, build_config_.max_dataset_size , build_config_.dataset_dim ));
@@ -1545,7 +1550,7 @@ void GNND<Data_t, Index_t>::build(Data_t* data,
15451550 int num_blocks = raft::ceildiv (n_elems, static_cast <size_t >(TPB ));
15461551 size_t dst_offset = batch.offset () * build_config_.dataset_dim ;
15471552 if (needs_l2_norms) {
1548- // we compute l2 norms on the fp32 data directly .
1553+ // Compute l2 norms on the fp32 batches before they're downcast to fp16 .
15491554 compute_l2_norms_kernel<<<batch.size(),
15501555 raft::warp_size (),
15511556 sizeof(float ) *
@@ -1560,8 +1565,8 @@ void GNND<Data_t, Index_t>::build(Data_t* data,
15601565 }
15611566 d_data_ptr_ = d_data_half_.value().data_handle();
15621567 } else {
1563- // In other cases where user-given data is not device-accessible, we allocate a device buffer to
1564- // copy the data. The input type is kept as-is (fp32, fp16, int8, uint8) .
1568+ // Other cases: user-given data is not device-accessible, but we don't need a precision
1569+ // conversion. Allocate a device buffer in input_t and copy as-is.
15651570 if (!d_data_direct_.has_value ()) {
15661571 d_data_direct_.emplace (raft::make_device_matrix<input_t , size_t , raft::row_major>(
15671572 res, build_config_.max_dataset_size , build_config_.dataset_dim ));
@@ -1573,7 +1578,7 @@ void GNND<Data_t, Index_t>::build(Data_t* data,
15731578 d_data_ptr_ = d_data_direct_.value ().data_handle ();
15741579 }
15751580
1576- if (needs_l2_norms && !compress_host_data ) {
1581+ if (needs_l2_norms && !downcast_host_data ) {
15771582 compute_l2_norms_kernel<<<
15781583 nrow_,
15791584 raft::warp_size (),
0 commit comments