@@ -98,58 +98,90 @@ inline std::enable_if_t<std::is_floating_point_v<MathT>> predict_core(
9898 raft::make_device_matrix_view<const MathT, IdxT>(centers, n_clusters, dim);
9999 auto X_norm_view = raft::make_device_vector_view<const MathT, IdxT>(dataset_norm, n_rows);
100100
101- auto minClusterAndDistance = raft::make_device_mdarray<raft::KeyValuePair<IdxT, MathT>, IdxT>(
102- handle, mr, raft::make_extents<IdxT>(n_rows));
103-
104- cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute<MathT, IdxT>(
105- handle,
106- X_view,
107- centroids_view,
108- minClusterAndDistance.view (),
109- X_norm_view,
110- L2NormBuf_OR_DistBuf,
111- params.metric ,
112- 0 , // batch_samples (unused for fused reduction)
113- 0 , // batch_centroids (unused for fused reduction)
114- workspace);
115-
116- // Copy keys to output labels
117- raft::linalg::map (handle,
118- raft::make_const_mdspan (minClusterAndDistance.view ()),
119- raft::make_device_vector_view<LabelT, IdxT>(labels, n_rows),
120- raft::compose_op<raft::cast_op<LabelT>, raft::key_op>());
121- break ;
122- }
123- case cuvs::distance::DistanceType::InnerProduct: {
124- if (use_cutile_fused_nn<MathT, IdxT, IdxT>(handle, n_rows, n_clusters, dim)) {
125- rmm::device_uvector<MathT> L2NormBuf_OR_DistBuf (0 , stream, mr);
126- rmm::device_uvector<char > workspace (0 , stream, mr);
127-
128- auto X_view = raft::make_device_matrix_view<const MathT, IdxT>(dataset, n_rows, dim);
129- auto centroids_view =
130- raft::make_device_matrix_view<const MathT, IdxT>(centers, n_clusters, dim);
131- auto X_norm_view = raft::make_device_vector_view<const MathT, IdxT>(dataset_norm, n_rows);
132-
133- auto minClusterAndDistance =
134- raft::make_device_mdarray<raft::KeyValuePair<IdxT, MathT>, IdxT>(
135- handle, mr, raft::make_extents<IdxT>(n_rows));
101+ auto nearest_dist =
102+ raft::make_device_mdarray<MathT, IdxT>(handle, mr, raft::make_extents<IdxT>(n_rows));
136103
104+ if constexpr (std::is_same_v<LabelT, IdxT>) {
105+ auto labels_view = raft::make_device_vector_view<IdxT, IdxT>(labels, n_rows);
137106 cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute<MathT, IdxT>(
138107 handle,
139108 X_view,
140109 centroids_view,
141- minClusterAndDistance.view (),
110+ labels_view,
111+ nearest_dist.view (),
112+ X_norm_view,
113+ L2NormBuf_OR_DistBuf,
114+ params.metric ,
115+ 0 , // batch_samples (unused for fused reduction)
116+ 0 , // batch_centroids (unused for fused reduction)
117+ workspace);
118+ } else {
119+ auto nearest_idx =
120+ raft::make_device_mdarray<IdxT, IdxT>(handle, mr, raft::make_extents<IdxT>(n_rows));
121+ cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute<MathT, IdxT>(
122+ handle,
123+ X_view,
124+ centroids_view,
125+ nearest_idx.view (),
126+ nearest_dist.view (),
142127 X_norm_view,
143128 L2NormBuf_OR_DistBuf,
144129 params.metric ,
145130 0 ,
146131 0 ,
147132 workspace);
133+ raft::copy (
134+ handle, raft::make_device_vector_view<LabelT, IdxT>(labels, n_rows), nearest_idx.view ());
135+ }
136+ break ;
137+ }
138+ case cuvs::distance::DistanceType::InnerProduct: {
139+ if (uses_fused_distance_nn (
140+ use_fused<MathT, IdxT, IdxT>(handle, n_rows, n_clusters, dim, params.metric ))) {
141+ rmm::device_uvector<MathT> L2NormBuf_OR_DistBuf (0 , stream, mr);
142+ rmm::device_uvector<char > workspace (0 , stream, mr);
148143
149- raft::linalg::map (handle,
150- raft::make_const_mdspan (minClusterAndDistance.view ()),
151- raft::make_device_vector_view<LabelT, IdxT>(labels, n_rows),
152- raft::compose_op<raft::cast_op<LabelT>, raft::key_op>());
144+ auto X_view = raft::make_device_matrix_view<const MathT, IdxT>(dataset, n_rows, dim);
145+ auto centroids_view =
146+ raft::make_device_matrix_view<const MathT, IdxT>(centers, n_clusters, dim);
147+ auto X_norm_view = raft::make_device_vector_view<const MathT, IdxT>(dataset_norm, n_rows);
148+
149+ auto nearest_dist =
150+ raft::make_device_mdarray<MathT, IdxT>(handle, mr, raft::make_extents<IdxT>(n_rows));
151+
152+ if constexpr (std::is_same_v<LabelT, IdxT>) {
153+ auto labels_view = raft::make_device_vector_view<IdxT, IdxT>(labels, n_rows);
154+ cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute<MathT, IdxT>(
155+ handle,
156+ X_view,
157+ centroids_view,
158+ labels_view,
159+ nearest_dist.view (),
160+ X_norm_view,
161+ L2NormBuf_OR_DistBuf,
162+ params.metric ,
163+ 0 ,
164+ 0 ,
165+ workspace);
166+ } else {
167+ auto nearest_idx =
168+ raft::make_device_mdarray<IdxT, IdxT>(handle, mr, raft::make_extents<IdxT>(n_rows));
169+ cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute<MathT, IdxT>(
170+ handle,
171+ X_view,
172+ centroids_view,
173+ nearest_idx.view (),
174+ nearest_dist.view (),
175+ X_norm_view,
176+ L2NormBuf_OR_DistBuf,
177+ params.metric ,
178+ 0 ,
179+ 0 ,
180+ workspace);
181+ raft::copy (handle,
182+ raft::make_device_vector_view<LabelT, IdxT>(labels, n_rows),
183+ nearest_idx.view ());
184+ }
153185 } else {
154186 rmm::device_uvector<MathT> distances (n_rows * n_clusters, stream, mr);
155187
@@ -216,22 +248,19 @@ auto calc_minibatch_size(const raft::resources& handle,
216248 size_t mem_per_row = 0 ;
217249 switch (metric) {
218250 case distance::DistanceType::L2Expanded:
219- case distance::DistanceType::L2SqrtExpanded: {
220- if (use_fused<MathT, IdxT, IdxT>(handle, n_rows, n_clusters, dim)) {
221- // fusedL2NN needs a mutex and a key-value pair for each row.
222- mem_per_row += sizeof (int );
223- mem_per_row += sizeof (raft::KeyValuePair<IdxT, MathT>);
224- } else {
225- // unfused path needs a full GEMM output (distance matrix row).
226- mem_per_row += sizeof (MathT) * n_clusters;
227- }
228- } break ;
251+ case distance::DistanceType::L2SqrtExpanded:
229252 case distance::DistanceType::InnerProduct: {
230- if (use_cutile_fused_nn<MathT, IdxT, IdxT>(handle, n_rows, n_clusters, dim)) {
231- mem_per_row += sizeof (int );
232- mem_per_row += sizeof (raft::KeyValuePair<IdxT, MathT>);
233- } else {
234- mem_per_row += sizeof (MathT) * n_clusters;
253+ switch (use_fused<MathT, IdxT, IdxT>(handle, n_rows, n_clusters, dim, metric)) {
254+ case FusedDistancePath::FusedCutile: break ;
255+ case FusedDistancePath::FusedCutlass:
256+ // fusedDistanceNNMinReduce CUTLASS fallback: mutex workspace + scratch KVP per row.
257+ mem_per_row += sizeof (int );
258+ mem_per_row += sizeof (raft::KeyValuePair<IdxT, MathT>);
259+ break ;
260+ case FusedDistancePath::Unfused:
261+ // unfused / GEMM+argmin path needs a full distance matrix row.
262+ mem_per_row += sizeof (MathT) * n_clusters;
263+ break ;
235264 }
236265 } break ;
237266 // Other metrics require storing a distance matrix.
0 commit comments