Skip to content

Commit db6b385

Browse files
committed
undo kvp, add constrains and alignments to tile export
1 parent 6742853 commit db6b385

25 files changed

Lines changed: 950 additions & 751 deletions

cpp/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -980,9 +980,9 @@ if(NOT BUILD_CPU_ONLY)
980980
MATRIX_JSON_FILE
981981
"${fused_1nn_cutile_dir}/fused_1nn_cutile_matrix.json"
982982
FRAGMENT_TAG_FORMAT_CUBIN
983-
"cuvs::distance::detail::fragment_tag_fused_1nn_cubin<cuvs::neighbors::detail::tag_@data_abbrev@, cuvs::distance::detail::metric_tag_@metric_abbrev@, cuvs::distance::detail::cutile_tile_config<@tile_m@, @tile_n@, @tile_k@>, cuvs::detail::jit_lto::@arch_tag@>"
983+
"cuvs::distance::detail::fragment_tag_fused_1nn_cubin<cuvs::neighbors::detail::tag_@data_abbrev@, cuvs::distance::detail::metric_tag_@metric_abbrev@, cuvs::neighbors::detail::tag_index_@index_abbrev@, cuvs::distance::detail::cutile_tile_config<@tile_m@, @tile_n@, @tile_k@>, cuvs::detail::jit_lto::@arch_tag@>"
984984
FRAGMENT_TAG_FORMAT_TILEIR
985-
"cuvs::distance::detail::fragment_tag_fused_1nn_tileir<cuvs::neighbors::detail::tag_@data_abbrev@, cuvs::distance::detail::metric_tag_@metric_abbrev@, cuvs::distance::detail::cutile_tile_config<@tile_m@, @tile_n@, @tile_k@>>"
985+
"cuvs::distance::detail::fragment_tag_fused_1nn_tileir<cuvs::neighbors::detail::tag_@data_abbrev@, cuvs::distance::detail::metric_tag_@metric_abbrev@, cuvs::neighbors::detail::tag_index_@index_abbrev@, cuvs::distance::detail::cutile_tile_config<@tile_m@, @tile_n@, @tile_k@>>"
986986
FRAGMENT_TAG_HEADER_FILES
987987
"<cuvs/detail/jit_lto/fused_distance_nn/fused_1nn_fragments.hpp>"
988988
"<cuvs/detail/jit_lto/cutile_arch_tags.hpp>"

cpp/cmake/modules/generate_cutile_kernels.cmake

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,22 @@ function(process_cutile_matrix_entry source_list_var)
150150
set(embedded_header_file "${_artifact_stem}_${register}.h")
151151

152152
set(_python_args
153-
--format "${output_format}" --data-type "${data_type}" --metric "${metric}" --tile-m
154-
"${tile_m}" --tile-n "${tile_n}" --tile-k "${tile_k}" --gpu-code "${gpu_code}"
153+
--format
154+
"${output_format}"
155+
--data-type
156+
"${data_type}"
157+
--metric
158+
"${metric}"
159+
--index-type
160+
"${index_type}"
161+
--tile-m
162+
"${tile_m}"
163+
--tile-n
164+
"${tile_n}"
165+
--tile-k
166+
"${tile_k}"
167+
--gpu-code
168+
"${gpu_code}"
155169
)
156170
if(DEFINED bytecode_version AND NOT "${bytecode_version}" STREQUAL "")
157171
list(APPEND _python_args --bytecode-version "${bytecode_version}")

cpp/include/cuvs/detail/jit_lto/fused_distance_nn/fused_1nn_fragments.hpp

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
#pragma once
77

8+
#include <cstdint>
9+
810
#include <cuvs/detail/jit_lto/common_fragments.hpp>
911
#include <cuvs/distance/distance.hpp>
1012

@@ -75,13 +77,33 @@ struct fused_1nn_data_tag<half> {
7577
template <typename DataT>
7678
using fused_1nn_data_tag_t = typename fused_1nn_data_tag<DataT>::type;
7779

78-
template <typename DataTag, typename MetricTag, typename TileTag, typename ArchTag>
80+
template <typename IdxT>
81+
struct fused_1nn_index_tag;
82+
83+
template <>
84+
struct fused_1nn_index_tag<int32_t> {
85+
using type = cuvs::neighbors::detail::tag_index_i32;
86+
};
87+
88+
template <>
89+
struct fused_1nn_index_tag<int64_t> {
90+
using type = cuvs::neighbors::detail::tag_index_i64;
91+
};
92+
93+
template <typename IdxT>
94+
using fused_1nn_index_tag_t = typename fused_1nn_index_tag<IdxT>::type;
95+
96+
template <typename DataTag,
97+
typename MetricTag,
98+
typename IndexTag,
99+
typename TileTag,
100+
typename ArchTag>
79101
struct fragment_tag_fused_1nn_cubin {
80102
static constexpr int cc_major = ArchTag::cc_major;
81103
static constexpr int cc_minor = ArchTag::cc_minor;
82104
};
83105

84-
template <typename DataTag, typename MetricTag, typename TileTag>
106+
template <typename DataTag, typename MetricTag, typename IndexTag, typename TileTag>
85107
struct fragment_tag_fused_1nn_tileir {};
86108

87109
} // namespace cuvs::distance::detail

cpp/src/cluster/detail/kmeans.cuh

Lines changed: 27 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -682,8 +682,8 @@ void kmeans_fit(
682682
DataT* cur_centroids_ptr = cur_centroids_buf.data();
683683
DataT* new_centroids_ptr = new_centroids_buf.data();
684684

685-
auto minClusterAndDistance = raft::make_device_vector<raft::KeyValuePair<IndexT, DataT>, IndexT>(
686-
handle, streaming_batch_size);
685+
auto nearest_idx = raft::make_device_vector<IndexT, IndexT>(handle, streaming_batch_size);
686+
auto nearest_dist = raft::make_device_vector<DataT, IndexT>(handle, streaming_batch_size);
687687
auto L2NormBatch = raft::make_device_vector<DataT, IndexT>(handle, streaming_batch_size);
688688
auto batch_weights_buf = raft::make_device_vector<DataT, IndexT>(handle, streaming_batch_size);
689689
rmm::device_uvector<DataT> L2NormBuf_OR_DistBuf(0, stream);
@@ -853,8 +853,10 @@ void kmeans_fit(
853853
auto batch_weights_view =
854854
cur_batch_weights(static_cast<IndexT>(data_batch.offset()), wt_data, cur_batch_size);
855855

856-
auto minCAD_view = raft::make_device_vector_view<raft::KeyValuePair<IndexT, DataT>, IndexT>(
857-
minClusterAndDistance.data_handle(), cur_batch_size);
856+
auto nearest_idx_view =
857+
raft::make_device_vector_view<IndexT, IndexT>(nearest_idx.data_handle(), cur_batch_size);
858+
auto nearest_dist_view =
859+
raft::make_device_vector_view<DataT, IndexT>(nearest_dist.data_handle(), cur_batch_size);
858860

859861
if constexpr (!data_on_device) {
860862
if (need_compute_norms) {
@@ -883,7 +885,8 @@ void kmeans_fit(
883885
metric,
884886
iter_params.batch_samples,
885887
iter_params.batch_centroids,
886-
minCAD_view,
888+
nearest_idx_view,
889+
nearest_dist_view,
887890
l2_const_view,
888891
L2NormBuf_OR_DistBuf,
889892
ws,
@@ -1071,8 +1074,7 @@ void kmeans_predict(raft::resources const& handle,
10711074
raft::make_const_mdspan(weight.view()));
10721075
}
10731076

1074-
auto minClusterAndDistance =
1075-
raft::make_device_vector<raft::KeyValuePair<IndexT, DataT>, IndexT>(handle, n_samples);
1077+
auto nearest_dist = raft::make_device_vector<DataT, IndexT>(handle, n_samples);
10761078
rmm::device_uvector<DataT> L2NormBuf_OR_DistBuf(0, stream);
10771079

10781080
// L2 norm of X: ||x||^2
@@ -1082,50 +1084,35 @@ void kmeans_predict(raft::resources const& handle,
10821084
raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>(handle, X, L2NormX.view());
10831085
}
10841086

1085-
// computes minClusterAndDistance[0:n_samples) where minClusterAndDistance[i]
1086-
// is a <key, value> pair where
1087-
// 'key' is index to a sample in 'centroids' (index of the nearest
1088-
// centroid) and 'value' is the distance between the sample 'X[i]' and the
1089-
// 'centroid[key]'
10901087
auto l2normx_view =
10911088
raft::make_device_vector_view<const DataT, IndexT>(L2NormX.data_handle(), n_samples);
1092-
cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute<DataT, IndexT>(
1093-
handle,
1094-
X,
1095-
centroids,
1096-
minClusterAndDistance.view(),
1097-
l2normx_view,
1098-
L2NormBuf_OR_DistBuf,
1099-
pams.metric,
1100-
pams.batch_samples,
1101-
pams.batch_centroids,
1102-
workspace);
1089+
cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute<DataT, IndexT>(handle,
1090+
X,
1091+
centroids,
1092+
labels,
1093+
nearest_dist.view(),
1094+
l2normx_view,
1095+
L2NormBuf_OR_DistBuf,
1096+
pams.metric,
1097+
pams.batch_samples,
1098+
pams.batch_centroids,
1099+
workspace);
11031100

1104-
// calculate cluster cost phi_x(C)
11051101
rmm::device_scalar<DataT> clusterCostD(stream);
1106-
raft::linalg::map(
1107-
handle,
1108-
minClusterAndDistance.view(),
1109-
[=] __device__(const raft::KeyValuePair<IndexT, DataT> kvp, DataT wt) {
1110-
raft::KeyValuePair<IndexT, DataT> res;
1111-
res.value = kvp.value * wt;
1112-
res.key = kvp.key;
1113-
return res;
1114-
},
1115-
raft::make_const_mdspan(minClusterAndDistance.view()),
1116-
raft::make_const_mdspan(weight.view()));
1102+
raft::linalg::map(handle,
1103+
nearest_dist.view(),
1104+
raft::mul_op{},
1105+
raft::make_const_mdspan(nearest_dist.view()),
1106+
raft::make_const_mdspan(weight.view()));
11171107

11181108
cuvs::cluster::kmeans::detail::computeClusterCost(
11191109
handle,
1120-
minClusterAndDistance.view(),
1110+
nearest_dist.view(),
11211111
workspace,
11221112
raft::make_device_scalar_view(clusterCostD.data()),
1123-
raft::value_op{},
1113+
raft::identity_op{},
11241114
raft::add_op{});
11251115

1126-
raft::linalg::map(
1127-
handle, labels, raft::key_op{}, raft::make_const_mdspan(minClusterAndDistance.view()));
1128-
11291116
inertia[0] = clusterCostD.value(stream);
11301117
}
11311118

cpp/src/cluster/detail/kmeans_balanced.cuh

Lines changed: 84 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -98,58 +98,90 @@ inline std::enable_if_t<std::is_floating_point_v<MathT>> predict_core(
9898
raft::make_device_matrix_view<const MathT, IdxT>(centers, n_clusters, dim);
9999
auto X_norm_view = raft::make_device_vector_view<const MathT, IdxT>(dataset_norm, n_rows);
100100

101-
auto minClusterAndDistance = raft::make_device_mdarray<raft::KeyValuePair<IdxT, MathT>, IdxT>(
102-
handle, mr, raft::make_extents<IdxT>(n_rows));
103-
104-
cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute<MathT, IdxT>(
105-
handle,
106-
X_view,
107-
centroids_view,
108-
minClusterAndDistance.view(),
109-
X_norm_view,
110-
L2NormBuf_OR_DistBuf,
111-
params.metric,
112-
0, // batch_samples (unused for fused reduction)
113-
0, // batch_centroids (unused for fused reduction)
114-
workspace);
115-
116-
// Copy keys to output labels
117-
raft::linalg::map(handle,
118-
raft::make_const_mdspan(minClusterAndDistance.view()),
119-
raft::make_device_vector_view<LabelT, IdxT>(labels, n_rows),
120-
raft::compose_op<raft::cast_op<LabelT>, raft::key_op>());
121-
break;
122-
}
123-
case cuvs::distance::DistanceType::InnerProduct: {
124-
if (use_cutile_fused_nn<MathT, IdxT, IdxT>(handle, n_rows, n_clusters, dim)) {
125-
rmm::device_uvector<MathT> L2NormBuf_OR_DistBuf(0, stream, mr);
126-
rmm::device_uvector<char> workspace(0, stream, mr);
127-
128-
auto X_view = raft::make_device_matrix_view<const MathT, IdxT>(dataset, n_rows, dim);
129-
auto centroids_view =
130-
raft::make_device_matrix_view<const MathT, IdxT>(centers, n_clusters, dim);
131-
auto X_norm_view = raft::make_device_vector_view<const MathT, IdxT>(dataset_norm, n_rows);
132-
133-
auto minClusterAndDistance =
134-
raft::make_device_mdarray<raft::KeyValuePair<IdxT, MathT>, IdxT>(
135-
handle, mr, raft::make_extents<IdxT>(n_rows));
101+
auto nearest_dist =
102+
raft::make_device_mdarray<MathT, IdxT>(handle, mr, raft::make_extents<IdxT>(n_rows));
136103

104+
if constexpr (std::is_same_v<LabelT, IdxT>) {
105+
auto labels_view = raft::make_device_vector_view<IdxT, IdxT>(labels, n_rows);
137106
cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute<MathT, IdxT>(
138107
handle,
139108
X_view,
140109
centroids_view,
141-
minClusterAndDistance.view(),
110+
labels_view,
111+
nearest_dist.view(),
112+
X_norm_view,
113+
L2NormBuf_OR_DistBuf,
114+
params.metric,
115+
0, // batch_samples (unused for fused reduction)
116+
0, // batch_centroids (unused for fused reduction)
117+
workspace);
118+
} else {
119+
auto nearest_idx =
120+
raft::make_device_mdarray<IdxT, IdxT>(handle, mr, raft::make_extents<IdxT>(n_rows));
121+
cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute<MathT, IdxT>(
122+
handle,
123+
X_view,
124+
centroids_view,
125+
nearest_idx.view(),
126+
nearest_dist.view(),
142127
X_norm_view,
143128
L2NormBuf_OR_DistBuf,
144129
params.metric,
145130
0,
146131
0,
147132
workspace);
133+
raft::copy(
134+
handle, raft::make_device_vector_view<LabelT, IdxT>(labels, n_rows), nearest_idx.view());
135+
}
136+
break;
137+
}
138+
case cuvs::distance::DistanceType::InnerProduct: {
139+
if (uses_fused_distance_nn(
140+
use_fused<MathT, IdxT, IdxT>(handle, n_rows, n_clusters, dim, params.metric))) {
141+
rmm::device_uvector<MathT> L2NormBuf_OR_DistBuf(0, stream, mr);
142+
rmm::device_uvector<char> workspace(0, stream, mr);
148143

149-
raft::linalg::map(handle,
150-
raft::make_const_mdspan(minClusterAndDistance.view()),
151-
raft::make_device_vector_view<LabelT, IdxT>(labels, n_rows),
152-
raft::compose_op<raft::cast_op<LabelT>, raft::key_op>());
144+
auto X_view = raft::make_device_matrix_view<const MathT, IdxT>(dataset, n_rows, dim);
145+
auto centroids_view =
146+
raft::make_device_matrix_view<const MathT, IdxT>(centers, n_clusters, dim);
147+
auto X_norm_view = raft::make_device_vector_view<const MathT, IdxT>(dataset_norm, n_rows);
148+
149+
auto nearest_dist =
150+
raft::make_device_mdarray<MathT, IdxT>(handle, mr, raft::make_extents<IdxT>(n_rows));
151+
152+
if constexpr (std::is_same_v<LabelT, IdxT>) {
153+
auto labels_view = raft::make_device_vector_view<IdxT, IdxT>(labels, n_rows);
154+
cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute<MathT, IdxT>(
155+
handle,
156+
X_view,
157+
centroids_view,
158+
labels_view,
159+
nearest_dist.view(),
160+
X_norm_view,
161+
L2NormBuf_OR_DistBuf,
162+
params.metric,
163+
0,
164+
0,
165+
workspace);
166+
} else {
167+
auto nearest_idx =
168+
raft::make_device_mdarray<IdxT, IdxT>(handle, mr, raft::make_extents<IdxT>(n_rows));
169+
cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute<MathT, IdxT>(
170+
handle,
171+
X_view,
172+
centroids_view,
173+
nearest_idx.view(),
174+
nearest_dist.view(),
175+
X_norm_view,
176+
L2NormBuf_OR_DistBuf,
177+
params.metric,
178+
0,
179+
0,
180+
workspace);
181+
raft::copy(handle,
182+
raft::make_device_vector_view<LabelT, IdxT>(labels, n_rows),
183+
nearest_idx.view());
184+
}
153185
} else {
154186
rmm::device_uvector<MathT> distances(n_rows * n_clusters, stream, mr);
155187

@@ -216,22 +248,19 @@ auto calc_minibatch_size(const raft::resources& handle,
216248
size_t mem_per_row = 0;
217249
switch (metric) {
218250
case distance::DistanceType::L2Expanded:
219-
case distance::DistanceType::L2SqrtExpanded: {
220-
if (use_fused<MathT, IdxT, IdxT>(handle, n_rows, n_clusters, dim)) {
221-
// fusedL2NN needs a mutex and a key-value pair for each row.
222-
mem_per_row += sizeof(int);
223-
mem_per_row += sizeof(raft::KeyValuePair<IdxT, MathT>);
224-
} else {
225-
// unfused path needs a full GEMM output (distance matrix row).
226-
mem_per_row += sizeof(MathT) * n_clusters;
227-
}
228-
} break;
251+
case distance::DistanceType::L2SqrtExpanded:
229252
case distance::DistanceType::InnerProduct: {
230-
if (use_cutile_fused_nn<MathT, IdxT, IdxT>(handle, n_rows, n_clusters, dim)) {
231-
mem_per_row += sizeof(int);
232-
mem_per_row += sizeof(raft::KeyValuePair<IdxT, MathT>);
233-
} else {
234-
mem_per_row += sizeof(MathT) * n_clusters;
253+
switch (use_fused<MathT, IdxT, IdxT>(handle, n_rows, n_clusters, dim, metric)) {
254+
case FusedDistancePath::FusedCutile: break;
255+
case FusedDistancePath::FusedCutlass:
256+
// fusedDistanceNNMinReduce CUTLASS fallback: mutex workspace + scratch KVP per row.
257+
mem_per_row += sizeof(int);
258+
mem_per_row += sizeof(raft::KeyValuePair<IdxT, MathT>);
259+
break;
260+
case FusedDistancePath::Unfused:
261+
// unfused / GEMM+argmin path needs a full distance matrix row.
262+
mem_per_row += sizeof(MathT) * n_clusters;
263+
break;
235264
}
236265
} break;
237266
// Other metrics require storing a distance matrix.

0 commit comments

Comments
 (0)