Skip to content

Commit d94fe86

Browse files
committed
Refactor for explicitness of supported distance types
1 parent 3ddd98a commit d94fe86

1 file changed

Lines changed: 10 additions & 6 deletions

File tree

cpp/src/neighbors/ivf_pq/ivf_pq_fp16_overflow.cuh

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,15 +100,19 @@ bool estimate_fp16_overflow(
100100
{
101101
if (dataset.extent(0) == 0) { return false; }
102102

103-
// Cosine similarity scores does normalization itself, so overflow won't happen
104-
if (metric == cuvs::distance::DistanceType::CosineExpanded) { return false; }
103+
float dist_factor = 1.0f;
104+
switch (metric) {
105+
case cuvs::distance::DistanceType::L2Expanded: dist_factor = 4.0f; break;
106+
case cuvs::distance::DistanceType::CosineExpanded:
107+
// Cosine similarity scores does normalization itself, so overflow won't happen
108+
return false;
109+
case cuvs::distance::DistanceType::InnerProduct: dist_factor = 1.0f; break;
110+
default: RAFT_FAIL("Unsupported distance type for IVF-PQ search %d.", int(metric));
111+
}
105112

106113
const float max_vector_sq_norm =
107114
cuvs::neighbors::ivf_pq::detail::estimate_max_squared_norm(handle, dataset);
108-
109-
const float max_distance_sq_norm = metric == cuvs::distance::DistanceType::L2Expanded
110-
? 4.0f * max_vector_sq_norm
111-
: max_vector_sq_norm;
115+
const float max_distance_sq_norm = dist_factor * max_vector_sq_norm;
112116

113117
constexpr float kFp16Max = 65504.0f;
114118
return max_distance_sq_norm > kFp16Max;

0 commit comments

Comments
 (0)