From 8a0e86aedd91a7e6a4a465cc9e6815cd589028f7 Mon Sep 17 00:00:00 2001 From: Akira Naruse Date: Mon, 8 Jun 2026 21:04:22 +0900 Subject: [PATCH 1/5] Improve balanced k-means balancing behavior and example - Add balance tolerance and centroid offset parameters - Rework center adjustment to split oversized partitions more effectively - Document tolerance limits for heuristic rebalancing - Add a balanced k-means example with regular k-means comparison --- cpp/include/cuvs/cluster/kmeans.hpp | 19 + cpp/src/cluster/detail/kmeans_balanced.cuh | 168 ++++--- cpp/src/cluster/kmeans_balanced.cuh | 5 +- examples/README.md | 18 + examples/cpp/CMakeLists.txt | 2 + examples/cpp/src/balanced_kmeans_example.cu | 474 ++++++++++++++++++++ fern/pages/cluster/kmeans.md | 2 + 7 files changed, 624 insertions(+), 64 deletions(-) create mode 100644 examples/cpp/src/balanced_kmeans_example.cu diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index e2b4ea4a36..055c01be1e 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -161,6 +161,25 @@ struct balanced_params : base_params { * Number of training iterations */ uint32_t n_iters = 20; + + /** + * Balance tolerance used during hierarchical training. Clusters no larger than + * `average_cluster_size * balance_tolerance` are underfull. Clusters larger than + * `average_cluster_size / balance_tolerance` are overfull donors. The default value of `0.33` + * targets clusters outside roughly one third to three times the average size. Very strict values + * around `0.7` or higher can be difficult for this heuristic rebalancing method to satisfy. + * + * Valid range: (0, 1). + */ + float balance_tolerance = 0.33f; + + /** + * Offset used when reinitializing an underfull cluster near an overfull cluster. The new center + * is placed at `donor_center + centroid_offset * (donor_point - donor_center)`. + * + * Valid range: (0, 1]. + */ + float centroid_offset = 0.01f; }; /** diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index 7fac255810..52c5b09c39 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -42,15 +42,16 @@ #include #include +#include #include #include #include #include +#include +#include namespace cuvs::cluster::kmeans::detail { -constexpr static inline float kAdjustCentersWeight = 7.0f; - /** * @brief Predict labels for the dataset; floating-point types only. * @@ -459,61 +460,57 @@ template __launch_bounds__((raft::WarpSize * BlockDimY)) RAFT_KERNEL adjust_centers_kernel(MathT* centers, // [n_clusters, dim] - IdxT n_clusters, + IdxT n_pairs, IdxT dim, const T* dataset, // [n_rows, dim] IdxT n_rows, - const LabelT* labels, // [n_rows] - const CounterT* cluster_sizes, // [n_clusters] - MathT threshold, - IdxT average, + const LabelT* labels, // [n_rows] + const IdxT* receiver_clusters, + const IdxT* donor_clusters, + MathT centroid_offset, IdxT seed, - IdxT* count, + IdxT* update_count, MappingOpT mapping_op) { - IdxT l = threadIdx.y + BlockDimY * static_cast(blockIdx.x); - if (l >= n_clusters) return; - auto csize = static_cast(cluster_sizes[l]); - // skip big clusters - if (csize > static_cast(average * threshold)) return; - - // choose a "random" i that belongs to a rather large cluster - IdxT i; - IdxT j = raft::laneId(); - if (j == 0) { - do { - auto old = atomicAdd(count, IdxT{1}); - i = (seed * (old + 1)) % n_rows; - } while (static_cast(cluster_sizes[labels[i]]) < average); + IdxT pair_id = threadIdx.y + BlockDimY * static_cast(blockIdx.x); + if (pair_id >= n_pairs) return; + + auto receiver_cluster = receiver_clusters[pair_id]; + auto donor_cluster = donor_clusters[pair_id]; + IdxT i = n_rows; + IdxT j = raft::laneId(); + for (IdxT attempt = 0; attempt < n_rows; attempt += raft::WarpSize) { + auto candidate = (seed * (attempt + j + 1) + pair_id) % n_rows; + auto found = static_cast(labels[candidate]) == donor_cluster; + auto mask = __ballot_sync(raft::warp_full_mask(), found); + if (mask != 0) { + auto source_lane = __ffs(mask) - 1; + i = raft::shfl(found ? candidate : n_rows, source_lane); + if (j == source_lane) { atomicAdd(update_count, IdxT{1}); } + break; + } } - i = raft::shfl(i, 0); - - // Adjust the center of the selected smaller cluster to gravitate towards - // a sample from the selected larger cluster. - const IdxT li = static_cast(labels[i]); - // Weight of the current center for the weighted average. - // We dump it for anomalously small clusters, but keep constant otherwise. - const MathT wc = min(static_cast(csize), static_cast(kAdjustCentersWeight)); - // Weight for the datapoint used to shift the center. - const MathT wd = 1.0; + if (i >= n_rows) return; + + // Reinitialize the small cluster close to the large cluster centroid, with a small offset towards + // a random donor point so it can split the large partition in the next prediction step. for (; j < dim; j += raft::WarpSize) { - MathT val = 0; - val += wc * centers[j + dim * li]; - val += wd * mapping_op(dataset[j + dim * i]); - val /= wc + wd; - centers[j + dim * l] = val; + auto donor_center = centers[j + dim * donor_cluster]; + auto donor_point = mapping_op(dataset[j + dim * i]); + auto val = donor_center + centroid_offset * (donor_point - donor_center); + centers[j + dim * receiver_cluster] = val; } } /** * @brief Adjust centers for clusters that have small number of entries. * - * For each cluster, where the cluster size is not bigger than a threshold, the center is moved - * towards a data point that belongs to a large cluster. + * Cluster sizes are sorted, then the smallest clusters are paired with the largest clusters. For + * each pair where the small cluster is underfull or the large cluster is overfull, the small + * cluster center is moved towards a data point from the large cluster. * * NB: if this function returns `true`, you should update the labels. * @@ -533,9 +530,11 @@ __launch_bounds__((raft::WarpSize * BlockDimY)) RAFT_KERNEL * @param[in] n_rows number of rows in `dataset` * @param[in] labels a host pointer to the cluster indices [n_rows] * @param[in] cluster_sizes number of rows in each cluster [n_clusters] - * @param[in] threshold defines a criterion for adjusting a cluster - * (cluster_sizes <= average_size * threshold) - * 0 <= threshold < 1 + * @param[in] balance_tolerance defines criteria for adjusting clusters: + * min_cluster_size <= average_size * balance_tolerance, or + * max_cluster_size > average_size / balance_tolerance + * 0 < balance_tolerance < 1 + * @param[in] centroid_offset offset from the donor cluster centroid towards a donor point * @param[in] mapping_op Mapping operation from T to MathT * @param[in] stream CUDA stream * @param[inout] device_memory memory resource to use for temporary allocations @@ -549,13 +548,15 @@ template auto adjust_centers(MathT* centers, + const raft::resources& handle, IdxT n_clusters, IdxT dim, const T* dataset, IdxT n_rows, const LabelT* labels, const CounterT* cluster_sizes, - MathT threshold, + MathT balance_tolerance, + MathT centroid_offset, MappingOpT mapping_op, rmm::cuda_stream_view stream, rmm::device_async_resource_ref device_memory) -> bool @@ -567,36 +568,70 @@ auto adjust_centers(MathT* centers, 601, 659, 733, 809, 863, 941, 1013, 1069, 1151, 1223, 1291, 1373, 1451, 1511, 1583, 1657, 1733, 1811, 1889, 1987, 2053, 2129, 2213, 2287, 2357, 2423, 2531, 2617, 2687, 2741}; - static IdxT i = 0; static IdxT i_primes = 0; - bool adjusted = false; - IdxT average = n_rows / n_clusters; + IdxT average = n_rows / n_clusters; + auto lower_threshold = static_cast(average) * balance_tolerance; + auto upper_threshold = static_cast(average) / balance_tolerance; + std::vector host_cluster_sizes(n_clusters); + raft::update_host(host_cluster_sizes.data(), cluster_sizes, n_clusters, stream); + raft::resource::sync_stream(handle, stream); + + std::vector> sorted_clusters; + sorted_clusters.reserve(n_clusters); + for (IdxT cluster = 0; cluster < n_clusters; ++cluster) { + sorted_clusters.emplace_back(host_cluster_sizes[cluster], cluster); + } + std::sort(sorted_clusters.begin(), sorted_clusters.end()); + + std::vector host_receiver_clusters; + std::vector host_donor_clusters; + host_receiver_clusters.reserve(n_clusters / 2); + host_donor_clusters.reserve(n_clusters / 2); + for (IdxT pair_id = 0; pair_id < n_clusters / 2; ++pair_id) { + auto const& [small_size, small_cluster] = sorted_clusters[pair_id]; + auto const& [large_size, large_cluster] = sorted_clusters[n_clusters - 1 - pair_id]; + if (small_cluster == large_cluster) { break; } + if (static_cast(small_size) > lower_threshold && + static_cast(large_size) <= upper_threshold) { + break; + } + host_receiver_clusters.push_back(small_cluster); + host_donor_clusters.push_back(large_cluster); + } + auto n_pairs = static_cast(host_receiver_clusters.size()); + if (n_pairs == 0) { return false; } + IdxT ofst; do { i_primes = (i_primes + 1) % kPrimes.size(); ofst = kPrimes[i_primes]; } while (n_rows % ofst == 0); + rmm::device_uvector receiver_clusters(n_pairs, stream, device_memory); + rmm::device_uvector donor_clusters(n_pairs, stream, device_memory); + raft::update_device(receiver_clusters.data(), host_receiver_clusters.data(), n_pairs, stream); + raft::update_device(donor_clusters.data(), host_donor_clusters.data(), n_pairs, stream); + constexpr uint32_t kBlockDimY = 4; const dim3 block_dim(raft::WarpSize, kBlockDimY, 1); - const dim3 grid_dim(raft::ceildiv(n_clusters, static_cast(kBlockDimY)), 1, 1); + const dim3 grid_dim(raft::ceildiv(n_pairs, static_cast(kBlockDimY)), 1, 1); rmm::device_scalar update_count(0, stream, device_memory); adjust_centers_kernel<<>>(centers, - n_clusters, + n_pairs, dim, dataset, n_rows, labels, - cluster_sizes, - threshold, - average, + receiver_clusters.data(), + donor_clusters.data(), + centroid_offset, ofst, update_count.data(), mapping_op); - adjusted = update_count.value(stream) > 0; // NB: rmm scalar performs the sync - - return adjusted; + auto n_updates = update_count.value(stream); // NB: rmm scalar performs the sync + RAFT_EXPECTS(n_updates == n_pairs, "Balanced k-means failed to update all adjusted centers"); + return n_updates > 0; } /** @@ -629,9 +664,11 @@ auto adjust_centers(MathT* centers, * one extra iteration is performed (this could happen several times) (default should be `2`). * In other words, the first and then every `ballancing_pullback`-th rebalancing operation adds * one more iteration to the main cycle. - * @param[in] balancing_threshold - * the rebalancing takes place if any cluster is smaller than `avg_size * balancing_threshold` - * on a given iteration (default should be `~ 0.25`). + * @param[in] balance_tolerance + * the lower balance tolerance and the upper donor tolerance. Small clusters are rebalanced when + * their paired small cluster is smaller than `avg_size * balance_tolerance`; if their paired + * large cluster is larger than `avg_size / balance_tolerance`, the small cluster is rebalanced + * towards it. * @param[in] mapping_op Mapping operation from T to MathT * @param[inout] device_memory * A memory resource for device allocations (makes sense to provide a memory pool here) @@ -654,23 +691,30 @@ void balancing_em_iters(const raft::resources& handle, LabelT* cluster_labels, CounterT* cluster_sizes, uint32_t balancing_pullback, - MathT balancing_threshold, + MathT balance_tolerance, MappingOpT mapping_op, rmm::device_async_resource_ref device_memory) { + RAFT_EXPECTS(balance_tolerance > MathT{0} && balance_tolerance < MathT{1}, + "Balanced k-means balance tolerance must be in the range (0, 1)"); + RAFT_EXPECTS(params.centroid_offset > 0.0f && params.centroid_offset <= 1.0f, + "Balanced k-means centroid offset must be in the range (0, 1]"); + auto stream = raft::resource::get_cuda_stream(handle); uint32_t balancing_counter = balancing_pullback; for (uint32_t iter = 0; iter < n_iters; iter++) { // Balancing step - move the centers around to equalize cluster sizes // (but not on the first iteration) if (iter > 0 && adjust_centers(cluster_centers, + handle, n_clusters, dim, dataset, n_rows, cluster_labels, cluster_sizes, - balancing_threshold, + balance_tolerance, + static_cast(params.centroid_offset), mapping_op, stream, device_memory)) { @@ -777,7 +821,7 @@ void build_clusters(const raft::resources& handle, cluster_labels, cluster_sizes, 2, - MathT{0.25}, + static_cast(params.balance_tolerance), mapping_op, device_memory); } @@ -1126,7 +1170,7 @@ void build_hierarchical(const raft::resources& handle, labels.data(), cluster_sizes.data(), 5, - MathT{0.2}, + static_cast(params.balance_tolerance), mapping_op, device_memory); diff --git a/cpp/src/cluster/kmeans_balanced.cuh b/cpp/src/cluster/kmeans_balanced.cuh index f3f52c2d8f..d3fdd21a12 100644 --- a/cpp/src/cluster/kmeans_balanced.cuh +++ b/cpp/src/cluster/kmeans_balanced.cuh @@ -23,8 +23,9 @@ namespace cuvs::cluster::kmeans_balanced { * iterations over the whole dataset and with all the centroids to obtain the final clusters. * * Each k-means iteration applies expectation-maximization-balancing: - * - Balancing: adjust centers for clusters that have a small number of entries. If the size of a - * cluster is below a threshold, the center is moved towards a bigger cluster. + * - Balancing: adjust centers to reduce underfull and overfull clusters. Small clusters are moved + * towards larger clusters; when overfull clusters exist, below-average clusters are moved + * towards those overfull clusters. * - Expectation: predict the labels (i.e find closest cluster centroid to each point) * - Maximization: calculate optimal centroids (i.e find the center of gravity of each cluster) * diff --git a/examples/README.md b/examples/README.md index f5a606ee35..4c1cc9ac4b 100644 --- a/examples/README.md +++ b/examples/README.md @@ -15,3 +15,21 @@ Make sure to link against the appropriate CMake targets. Use `cuvs::c_api` and ` ```cmake target_link_libraries(your_app_target PRIVATE cuvs::cuvs) ``` + +## Balanced k-means example + +`BALANCED_KMEANS_EXAMPLE` partitions a vector database with cuVS balanced k-means. Specify the +dataset path with `-d`, its data type with `-t`, and the desired number of partitions with `-P`: + +```bash +./cpp/build/BALANCED_KMEANS_EXAMPLE -d vectors.bin -t float -P 256 -I 20 -B 0.33 -O 0.01 +``` + +The supported data types are `float`, `half`, `int8`, and `uint8`. The dataset can use the BIGANN +format (`uint32` vector count, `uint32` dimension count, then row-major vectors) or the xvec format. +Use `-I` to set the number of k-means iterations; it defaults to 20. Use `-B` to set the balance +tolerance and `-O` to set the centroid offset used when splitting large partitions; they default to +0.33 and 0.01. The default tolerance targets partitions outside roughly one third to three times the +average partition size. Very strict tolerance values around 0.7 or higher can be difficult for this +heuristic rebalancing method to satisfy. The example prints partition sizes, centroid prefixes, and +the partition assigned to each of the first vector IDs. diff --git a/examples/cpp/CMakeLists.txt b/examples/cpp/CMakeLists.txt index 034b0b3d96..b72e6616ac 100644 --- a/examples/cpp/CMakeLists.txt +++ b/examples/cpp/CMakeLists.txt @@ -30,6 +30,7 @@ include(../cmake/thirdparty/get_cuvs.cmake) # -------------- compile tasks ----------------- # add_executable(BRUTE_FORCE_EXAMPLE src/brute_force_bitmap.cu) +add_executable(BALANCED_KMEANS_EXAMPLE src/balanced_kmeans_example.cu) add_executable(CAGRA_EXAMPLE src/cagra_example.cu) add_executable(CAGRA_HNSW_ACE_EXAMPLE src/cagra_hnsw_ace_example.cu) add_executable(CAGRA_PERSISTENT_EXAMPLE src/cagra_persistent_example.cu) @@ -43,6 +44,7 @@ add_executable(SCANN_EXAMPLE src/scann_example.cu) # `$` is a generator expression that ensures that targets are # installed in a conda environment, if one exists target_link_libraries(BRUTE_FORCE_EXAMPLE PRIVATE cuvs::cuvs $) +target_link_libraries(BALANCED_KMEANS_EXAMPLE PRIVATE cuvs::cuvs $) target_link_libraries(CAGRA_EXAMPLE PRIVATE cuvs::cuvs $) target_link_libraries(CAGRA_HNSW_ACE_EXAMPLE PRIVATE cuvs::cuvs $) target_link_libraries( diff --git a/examples/cpp/src/balanced_kmeans_example.cu b/examples/cpp/src/balanced_kmeans_example.cu new file mode 100644 index 0000000000..8d0b61c30a --- /dev/null +++ b/examples/cpp/src/balanced_kmeans_example.cu @@ -0,0 +1,474 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +const char* argp_docs = "balanced_kmeans_example 0.1"; + +static struct argp_option options[] = { + {"dataset", 'd', "PATH", 0, "Path to dataset file"}, + {"dtype", 't', "TYPE", 0, "Data type [float/half/int8/uint8]"}, + {"partitions", 'P', "INT", 0, "Number of balanced partitions"}, + {"iterations", 'I', "INT", 0, "Number of k-means iterations (default: 20)"}, + {"balance-tolerance", 'B', "FLOAT", 0, "Balance tolerance (default: 0.33)"}, + {"centroid-offset", 'O', "FLOAT", 0, "Centroid offset when splitting partitions (default: 0.01)"}, + {0}}; + +struct arguments { + std::string dataset_path; + std::string dtype; + std::uint32_t n_partitions; + std::uint32_t n_iters; + float balance_tolerance; + float centroid_offset; +}; + +static error_t parse_opt(int key, char* arg, struct argp_state* state) +{ + struct arguments* arguments = reinterpret_cast(state->input); + + switch (key) { + case 'd': arguments->dataset_path = arg; break; + case 't': arguments->dtype = arg; break; + case 'P': arguments->n_partitions = std::stoul(arg); break; + case 'I': arguments->n_iters = std::stoul(arg); break; + case 'B': arguments->balance_tolerance = std::stof(arg); break; + case 'O': arguments->centroid_offset = std::stof(arg); break; + case ARGP_KEY_ARG: break; + case ARGP_KEY_END: break; + default: return ARGP_ERR_UNKNOWN; + } + return 0; +} + +static struct argp argp = {options, parse_opt, nullptr, argp_docs}; + +namespace { + +struct partition_size_stats { + std::vector sorted_sizes; + int64_t min_size; + int64_t max_size; + int64_t underflow_count; + int64_t overflow_count; + double median_size; + double mean_size; + double stddev_size; + double lower_threshold; + double upper_threshold; +}; + +enum dataset_file_format_t { XVECS, BIGANN, AUTO_DETECT }; + +template +struct dataset_descriptor_t { + std::size_t dim; + std::size_t size; + + std::unique_ptr data; + dataset_file_format_t file_format; +}; + +template +void get_dataset_info(dataset_descriptor_t& desc, + std::string const& file_path, + dataset_file_format_t file_format = AUTO_DETECT) +{ + std::ifstream ifs(file_path, std::ios::binary); + if (!ifs) { + throw std::runtime_error("File not exist : " + file_path + " (`" + __func__ + "` in " + + __FILE__ + ")"); + } + + ifs.seekg(0, std::ios::end); + auto const file_size_in_byte = static_cast(ifs.tellg()); + ifs.seekg(0, std::ios::beg); + + std::uint32_t tmp_val[2]; + ifs.read(reinterpret_cast(tmp_val), sizeof(std::uint32_t) * 2); + + desc.file_format = file_format; + if (desc.file_format == AUTO_DETECT) { + if (sizeof(std::uint32_t) * 2 + sizeof(DataT) * tmp_val[0] * tmp_val[1] == file_size_in_byte) { + desc.file_format = BIGANN; + } else { + desc.file_format = XVECS; + } + } + + if (desc.file_format == BIGANN) { + std::fprintf(stderr, "# BIGANN type file (%s)\n", file_path.c_str()); + desc.size = tmp_val[0]; + desc.dim = tmp_val[1]; + } else { + std::fprintf(stderr, "# Xvec type file (%s)\n", file_path.c_str()); + desc.dim = tmp_val[0]; + desc.size = (file_size_in_byte - sizeof(std::uint32_t)) / desc.dim / sizeof(DataT) - 1; + } +} + +template +void load_dataset(dataset_descriptor_t& desc, + std::string const& file_path, + dataset_file_format_t file_format = AUTO_DETECT) +{ + get_dataset_info(desc, file_path, file_format); + std::ifstream ifs(file_path, std::ios::binary); + if (!ifs) { + throw std::runtime_error("File not exist : " + file_path + " (`" + __func__ + "` in " + + __FILE__ + ")"); + } + + auto const array_size = sizeof(DataT) * desc.dim * desc.size; + desc.data = std::make_unique(desc.dim * desc.size); + + if (desc.file_format == BIGANN) { + ifs.seekg(sizeof(std::uint32_t) * 2, std::ios::beg); + ifs.read(reinterpret_cast(desc.data.get()), array_size); + } else { + ifs.seekg(sizeof(std::uint32_t), std::ios::beg); + for (std::size_t i = 0; i < desc.size; i++) { + ifs.seekg(sizeof(std::uint32_t), std::ios::cur); + ifs.read(reinterpret_cast(desc.data.get() + i * desc.dim), sizeof(DataT) * desc.dim); + } + } +} + +template +partition_size_stats compute_partition_size_stats( + raft::device_resources const& resources, + int64_t n_partitions, + raft::device_vector_view labels, + float balance_tolerance) +{ + auto host_labels = raft::make_host_vector(labels.extent(0)); + auto stream = raft::resource::get_cuda_stream(resources); + + raft::copy(host_labels.data_handle(), labels.data_handle(), labels.size(), stream); + raft::resource::sync_stream(resources, stream); + + std::vector partition_sizes(n_partitions, 0); + for (int64_t row = 0; row < labels.extent(0); ++row) { + ++partition_sizes.at(static_cast(host_labels(row))); + } + + std::sort(partition_sizes.begin(), partition_sizes.end()); + + auto minimum = partition_sizes.front(); + auto maximum = partition_sizes.back(); + auto median = + n_partitions % 2 == 0 + ? (partition_sizes[n_partitions / 2 - 1] + partition_sizes[n_partitions / 2]) / 2.0 + : static_cast(partition_sizes[n_partitions / 2]); + auto mean = static_cast(labels.extent(0)) / n_partitions; + auto lower_threshold = mean * balance_tolerance; + auto upper_threshold = mean / balance_tolerance; + auto underflow_count = static_cast( + std::count_if(partition_sizes.begin(), partition_sizes.end(), [lower_threshold](int64_t size) { + return size < lower_threshold; + })); + auto overflow_count = static_cast( + std::count_if(partition_sizes.begin(), partition_sizes.end(), [upper_threshold](int64_t size) { + return size > upper_threshold; + })); + auto variance = + std::accumulate(partition_sizes.begin(), + partition_sizes.end(), + 0.0, + [mean](double sum, int64_t size) { return sum + std::pow(size - mean, 2); }) / + n_partitions; + + return {std::move(partition_sizes), + minimum, + maximum, + underflow_count, + overflow_count, + median, + mean, + std::sqrt(variance), + lower_threshold, + upper_threshold}; +} + +void print_partition_size_stats(std::string const& label, partition_size_stats const& stats) +{ + std::cout << label << " partition size statistics: min=" << stats.min_size + << ", max=" << stats.max_size << ", median=" << stats.median_size + << ", mean=" << stats.mean_size << ", standard deviation=" << stats.stddev_size + << ", min/mean=" << stats.min_size / stats.mean_size + << ", max/mean=" << stats.max_size / stats.mean_size + << ", underflow=" << stats.underflow_count << " (< " << stats.lower_threshold << ")" + << ", overflow=" << stats.overflow_count << " (> " << stats.upper_threshold << ")" + << '\n'; +} + +void print_partition_size_histogram(std::string const& label, + partition_size_stats const& stats, + int64_t histogram_min, + int64_t histogram_max, + int64_t n_bins = 20) +{ + if (stats.sorted_sizes.empty()) { return; } + + std::vector bins(n_bins, 0); + auto const range = static_cast(histogram_max - histogram_min); + if (range == 0.0) { + bins.front() = static_cast(stats.sorted_sizes.size()); + } else { + for (auto size : stats.sorted_sizes) { + auto bin = static_cast((size - histogram_min) / range * n_bins); + bins[std::min(bin, n_bins - 1)]++; + } + } + + auto const max_bin_count = *std::max_element(bins.begin(), bins.end()); + auto const bar_width = int64_t{40}; + auto const bin_width = range / n_bins; + + std::cout << label << " partition size histogram:\n"; + for (int64_t bin = 0; bin < n_bins; ++bin) { + auto const lower = + range == 0.0 ? static_cast(histogram_min) : histogram_min + bin_width * bin; + auto const upper = + range == 0.0 ? static_cast(histogram_max) : histogram_min + bin_width * (bin + 1); + auto const count = bins[bin]; + auto const hashes = + max_bin_count == 0 ? int64_t{0} : std::max(1, count * bar_width / max_bin_count); + + std::cout << " [" << std::setw(8) << static_cast(std::floor(lower)) << ", " + << std::setw(8) << static_cast(std::ceil(upper)) << "] " << std::setw(4) + << count << " | "; + for (int64_t i = 0; i < hashes && count != 0; ++i) { + std::cout << '#'; + } + std::cout << '\n'; + } +} + +void print_balance_improvement(partition_size_stats const& regular_stats, + partition_size_stats const& balanced_stats) +{ + auto const regular_max_ratio = regular_stats.max_size / regular_stats.mean_size; + auto const balanced_max_ratio = balanced_stats.max_size / balanced_stats.mean_size; + auto const regular_stddev = regular_stats.stddev_size; + auto const balanced_stddev = balanced_stats.stddev_size; + + std::cout << "Balance improvement: max/mean " << regular_max_ratio << " -> " << balanced_max_ratio + << ", standard deviation " << regular_stddev << " -> " << balanced_stddev << '\n'; +} + +template +std::optional run_regular_kmeans_comparison( + raft::device_resources const& resources, + raft::device_matrix_view dataset, + int64_t n_partitions, + std::uint32_t n_iters, + float balance_tolerance) +{ + if constexpr (std::is_same_v) { + cuvs::cluster::kmeans::params params; + params.metric = cuvs::distance::DistanceType::L2Expanded; + params.n_clusters = static_cast(n_partitions); + params.max_iter = static_cast(n_iters); + + auto centroids = raft::make_device_matrix( + resources, n_partitions, static_cast(dataset.extent(1))); + auto labels = raft::make_device_vector(resources, dataset.extent(0)); + + float inertia = 0.0f; + int64_t n_iter = 0; + cuvs::cluster::kmeans::fit(resources, + params, + dataset, + std::nullopt, + centroids.view(), + raft::make_host_scalar_view(&inertia), + raft::make_host_scalar_view(&n_iter)); + cuvs::cluster::kmeans::predict(resources, + params, + dataset, + std::nullopt, + raft::make_const_mdspan(centroids.view()), + labels.view(), + false, + raft::make_host_scalar_view(&inertia)); + + return compute_partition_size_stats( + resources, n_partitions, raft::make_const_mdspan(labels.view()), balance_tolerance); + } else { + return std::nullopt; + } +} + +template +void partition_dataset(std::string const& dataset_path, + std::uint32_t n_partitions, + std::uint32_t n_iters, + float balance_tolerance, + float centroid_offset) +{ + raft::device_resources resources; + + dataset_descriptor_t dataset_desc; + load_dataset(dataset_desc, dataset_path); + + auto n_samples = static_cast(dataset_desc.size); + auto n_features = static_cast(dataset_desc.dim); + if (n_partitions > dataset_desc.size) { + throw std::invalid_argument("Number of partitions cannot exceed the number of vectors"); + } + + auto dataset = raft::make_device_matrix(resources, n_samples, n_features); + auto stream = raft::resource::get_cuda_stream(resources); + raft::copy(dataset.data_handle(), dataset_desc.data.get(), dataset.size(), stream); + raft::resource::sync_stream(resources, stream); + dataset_desc.data.reset(); + + std::cout << "Partitioning " << n_samples << " vectors with " << n_features << " dimensions into " + << n_partitions << " balanced partitions\n"; + + cuvs::cluster::kmeans::balanced_params params; + params.metric = cuvs::distance::DistanceType::L2Expanded; + params.n_iters = n_iters; + params.balance_tolerance = balance_tolerance; + params.centroid_offset = centroid_offset; + + auto centroids = raft::make_device_matrix(resources, n_partitions, n_features); + auto labels = raft::make_device_vector(resources, n_samples); + auto dataset_view = raft::make_const_mdspan(dataset.view()); + + auto regular_stats = run_regular_kmeans_comparison( + resources, dataset_view, n_partitions, n_iters, balance_tolerance); + + cuvs::cluster::kmeans::fit(resources, params, dataset_view, centroids.view()); + cuvs::cluster::kmeans::predict( + resources, params, dataset_view, raft::make_const_mdspan(centroids.view()), labels.view()); + + auto balanced_stats = compute_partition_size_stats( + resources, n_partitions, raft::make_const_mdspan(labels.view()), balance_tolerance); + + if (regular_stats.has_value()) { + auto const histogram_min = std::min(regular_stats->min_size, balanced_stats.min_size); + auto const histogram_max = std::max(regular_stats->max_size, balanced_stats.max_size); + print_partition_size_stats("Regular k-means", regular_stats.value()); + print_partition_size_histogram( + "Regular k-means", regular_stats.value(), histogram_min, histogram_max); + print_partition_size_stats("Balanced k-means", balanced_stats); + print_partition_size_histogram( + "Balanced k-means", balanced_stats, histogram_min, histogram_max); + } else { + std::cout << "Regular k-means comparison is only shown for float input in this example.\n"; + print_partition_size_stats("Balanced k-means", balanced_stats); + print_partition_size_histogram( + "Balanced k-means", balanced_stats, balanced_stats.min_size, balanced_stats.max_size); + } + if (regular_stats.has_value()) { + print_balance_improvement(regular_stats.value(), balanced_stats); + } +} + +} // namespace + +int main(int argc, char** argv) +{ + try { + struct arguments args = { + "", /* dataset_path */ + "", /* dtype */ + 0, /* n_partitions */ + 20, /* n_iters */ + 0.33f, /* balance_tolerance */ + 0.01f, /* centroid_offset */ + }; + + argp_parse(&argp, argc, argv, 0, 0, &args); + + std::string error_message; + if (args.dataset_path.empty()) { + error_message += "- Path to dataset file has not been provided (-d)\n"; + } + if (args.dtype.empty()) { error_message += "- Data type has not been provided (-t)\n"; } + if (args.n_partitions == 0) { + error_message += "- Number of partitions must be larger than 0 (-P)\n"; + } + if (args.n_iters == 0) { + error_message += "- Number of k-means iterations must be larger than 0 (-I)\n"; + } + if (!std::isfinite(args.balance_tolerance) || args.balance_tolerance <= 0.0f || + args.balance_tolerance >= 1.0f) { + error_message += "- Balance tolerance must be in the range (0, 1) (-B)\n"; + } + if (!std::isfinite(args.centroid_offset) || args.centroid_offset <= 0.0f || + args.centroid_offset > 1.0f) { + error_message += "- Centroid offset must be in the range (0, 1] (-O)\n"; + } + if (!error_message.empty()) { throw std::invalid_argument(error_message); } + + std::cout << "# dataset_path: " << args.dataset_path << '\n' + << "# dtype: " << args.dtype << '\n' + << "# partitions: " << args.n_partitions << '\n' + << "# iterations: " << args.n_iters << '\n' + << "# balance_tolerance: " << args.balance_tolerance << '\n' + << "# centroid_offset: " << args.centroid_offset << '\n'; + + if (args.dtype == "float") { + partition_dataset(args.dataset_path, + args.n_partitions, + args.n_iters, + args.balance_tolerance, + args.centroid_offset); + } else if (args.dtype == "half") { + partition_dataset(args.dataset_path, + args.n_partitions, + args.n_iters, + args.balance_tolerance, + args.centroid_offset); + } else if (args.dtype == "int8") { + partition_dataset(args.dataset_path, + args.n_partitions, + args.n_iters, + args.balance_tolerance, + args.centroid_offset); + } else if (args.dtype == "uint8") { + partition_dataset(args.dataset_path, + args.n_partitions, + args.n_iters, + args.balance_tolerance, + args.centroid_offset); + } else { + throw std::invalid_argument("Unknown data type: " + args.dtype); + } + } catch (std::exception const& error) { + std::cerr << "Error: " << error.what() << '\n'; + return 1; + } + + return 0; +} diff --git a/fern/pages/cluster/kmeans.md b/fern/pages/cluster/kmeans.md index 58c438f64a..64757f0c47 100644 --- a/fern/pages/cluster/kmeans.md +++ b/fern/pages/cluster/kmeans.md @@ -369,6 +369,8 @@ Balanced K-Means encourages more even cluster sizes. It is useful when clusters | `streaming_batch_size` | `0` | Number of host rows streamed to the GPU per batch. `0` processes all host rows at once. | | `hierarchical` | `false` | Enables hierarchical, balanced K-Means in C and Python. | | `hierarchical_n_iters` | implementation default | Number of training iterations for hierarchical K-Means. | +| `balance_tolerance` | `0.33` | C++ balanced K-Means tolerance for rebalancing clusters during hierarchical training and final global fine-tuning iterations. Small clusters are adjusted when their size is no larger than `average_cluster_size * balance_tolerance`; if overfull clusters exist, below-average clusters are adjusted towards donors larger than `average_cluster_size / balance_tolerance`. The default targets clusters outside roughly one third to three times the average size. Very strict values around `0.7` or higher can be difficult for this heuristic rebalancing method to satisfy. | +| `centroid_offset` | `0.01` | C++ balanced K-Means offset used when reinitializing a small cluster near a large cluster. The new center is placed at `donor_center + centroid_offset * (donor_point - donor_center)`. | ## Tuning From d11ff2644b0afcb99c6a0886962e36cb10052d7f Mon Sep 17 00:00:00 2001 From: Akira Naruse Date: Tue, 9 Jun 2026 10:51:56 +0900 Subject: [PATCH 2/5] Improve balanced k-means tolerance tuning and example reporting Split balanced k-means tolerance into lower and upper bounds so users can control underflow and overflow thresholds independently. Update the balanced k-means example to evaluate multiple tolerance combinations in one run and report clearer partition size statistics, including shared-range histograms. Also update the documentation for the new parameters and their defaults. --- cpp/include/cuvs/cluster/kmeans.hpp | 20 +- cpp/src/cluster/detail/kmeans_balanced.cuh | 48 ++-- examples/README.md | 16 +- examples/cpp/src/balanced_kmeans_example.cu | 245 ++++++++++++++------ fern/pages/cluster/kmeans.md | 3 +- 5 files changed, 226 insertions(+), 106 deletions(-) diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index 055c01be1e..f466967477 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -163,15 +163,23 @@ struct balanced_params : base_params { uint32_t n_iters = 20; /** - * Balance tolerance used during hierarchical training. Clusters no larger than - * `average_cluster_size * balance_tolerance` are underfull. Clusters larger than - * `average_cluster_size / balance_tolerance` are overfull donors. The default value of `0.33` - * targets clusters outside roughly one third to three times the average size. Very strict values - * around `0.7` or higher can be difficult for this heuristic rebalancing method to satisfy. + * Lower balance tolerance used during hierarchical training. Clusters smaller than + * `average_cluster_size * balance_lower_tolerance` are underfull. The default value of `0.333` + * targets clusters smaller than roughly one third of the average size. * * Valid range: (0, 1). */ - float balance_tolerance = 0.33f; + float balance_lower_tolerance = 0.333f; + + /** + * Upper balance tolerance used during hierarchical training. Clusters larger than + * `average_cluster_size * balance_upper_tolerance` are overfull donors. The default value of + * `3.0` targets clusters larger than roughly three times the average size. Very strict upper + * values around `1.4` or lower can be difficult for this heuristic rebalancing method to satisfy. + * + * Valid range: (1, infinity). + */ + float balance_upper_tolerance = 3.0f; /** * Offset used when reinitializing an underfull cluster near an overfull cluster. The new center diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index 52c5b09c39..9d771f1503 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -530,10 +530,12 @@ __launch_bounds__((raft::WarpSize * BlockDimY)) RAFT_KERNEL * @param[in] n_rows number of rows in `dataset` * @param[in] labels a host pointer to the cluster indices [n_rows] * @param[in] cluster_sizes number of rows in each cluster [n_clusters] - * @param[in] balance_tolerance defines criteria for adjusting clusters: - * min_cluster_size <= average_size * balance_tolerance, or - * max_cluster_size > average_size / balance_tolerance - * 0 < balance_tolerance < 1 + * @param[in] balance_lower_tolerance defines the underfull cluster criterion: + * min_cluster_size < average_size * balance_lower_tolerance + * 0 < balance_lower_tolerance < 1 + * @param[in] balance_upper_tolerance defines the overfull donor cluster criterion: + * max_cluster_size > average_size * balance_upper_tolerance + * balance_upper_tolerance > 1 * @param[in] centroid_offset offset from the donor cluster centroid towards a donor point * @param[in] mapping_op Mapping operation from T to MathT * @param[in] stream CUDA stream @@ -555,7 +557,8 @@ auto adjust_centers(MathT* centers, IdxT n_rows, const LabelT* labels, const CounterT* cluster_sizes, - MathT balance_tolerance, + MathT balance_lower_tolerance, + MathT balance_upper_tolerance, MathT centroid_offset, MappingOpT mapping_op, rmm::cuda_stream_view stream, @@ -571,8 +574,8 @@ auto adjust_centers(MathT* centers, static IdxT i_primes = 0; IdxT average = n_rows / n_clusters; - auto lower_threshold = static_cast(average) * balance_tolerance; - auto upper_threshold = static_cast(average) / balance_tolerance; + auto lower_threshold = static_cast(average) * balance_lower_tolerance; + auto upper_threshold = static_cast(average) * balance_upper_tolerance; std::vector host_cluster_sizes(n_clusters); raft::update_host(host_cluster_sizes.data(), cluster_sizes, n_clusters, stream); raft::resource::sync_stream(handle, stream); @@ -592,7 +595,7 @@ auto adjust_centers(MathT* centers, auto const& [small_size, small_cluster] = sorted_clusters[pair_id]; auto const& [large_size, large_cluster] = sorted_clusters[n_clusters - 1 - pair_id]; if (small_cluster == large_cluster) { break; } - if (static_cast(small_size) > lower_threshold && + if (static_cast(small_size) >= lower_threshold && static_cast(large_size) <= upper_threshold) { break; } @@ -664,11 +667,12 @@ auto adjust_centers(MathT* centers, * one extra iteration is performed (this could happen several times) (default should be `2`). * In other words, the first and then every `ballancing_pullback`-th rebalancing operation adds * one more iteration to the main cycle. - * @param[in] balance_tolerance - * the lower balance tolerance and the upper donor tolerance. Small clusters are rebalanced when - * their paired small cluster is smaller than `avg_size * balance_tolerance`; if their paired - * large cluster is larger than `avg_size / balance_tolerance`, the small cluster is rebalanced - * towards it. + * @param[in] balance_lower_tolerance + * Small clusters are rebalanced when their paired small cluster is smaller than + * `avg_size * balance_lower_tolerance`. + * @param[in] balance_upper_tolerance + * If the paired large cluster is larger than `avg_size * balance_upper_tolerance`, the small + * cluster is rebalanced towards it. * @param[in] mapping_op Mapping operation from T to MathT * @param[inout] device_memory * A memory resource for device allocations (makes sense to provide a memory pool here) @@ -691,12 +695,15 @@ void balancing_em_iters(const raft::resources& handle, LabelT* cluster_labels, CounterT* cluster_sizes, uint32_t balancing_pullback, - MathT balance_tolerance, + MathT balance_lower_tolerance, + MathT balance_upper_tolerance, MappingOpT mapping_op, rmm::device_async_resource_ref device_memory) { - RAFT_EXPECTS(balance_tolerance > MathT{0} && balance_tolerance < MathT{1}, - "Balanced k-means balance tolerance must be in the range (0, 1)"); + RAFT_EXPECTS(balance_lower_tolerance > MathT{0} && balance_lower_tolerance < MathT{1}, + "Balanced k-means lower balance tolerance must be in the range (0, 1)"); + RAFT_EXPECTS(balance_upper_tolerance > MathT{1}, + "Balanced k-means upper balance tolerance must be greater than 1"); RAFT_EXPECTS(params.centroid_offset > 0.0f && params.centroid_offset <= 1.0f, "Balanced k-means centroid offset must be in the range (0, 1]"); @@ -713,7 +720,8 @@ void balancing_em_iters(const raft::resources& handle, n_rows, cluster_labels, cluster_sizes, - balance_tolerance, + balance_lower_tolerance, + balance_upper_tolerance, static_cast(params.centroid_offset), mapping_op, stream, @@ -821,7 +829,8 @@ void build_clusters(const raft::resources& handle, cluster_labels, cluster_sizes, 2, - static_cast(params.balance_tolerance), + static_cast(params.balance_lower_tolerance), + static_cast(params.balance_upper_tolerance), mapping_op, device_memory); } @@ -1170,7 +1179,8 @@ void build_hierarchical(const raft::resources& handle, labels.data(), cluster_sizes.data(), 5, - static_cast(params.balance_tolerance), + static_cast(params.balance_lower_tolerance), + static_cast(params.balance_upper_tolerance), mapping_op, device_memory); diff --git a/examples/README.md b/examples/README.md index 4c1cc9ac4b..3d23209e55 100644 --- a/examples/README.md +++ b/examples/README.md @@ -22,14 +22,16 @@ target_link_libraries(your_app_target PRIVATE cuvs::cuvs) dataset path with `-d`, its data type with `-t`, and the desired number of partitions with `-P`: ```bash -./cpp/build/BALANCED_KMEANS_EXAMPLE -d vectors.bin -t float -P 256 -I 20 -B 0.33 -O 0.01 +./cpp/build/BALANCED_KMEANS_EXAMPLE -d vectors.bin -t float -P 256 -I 20 -L 0.333,0.5 -U 2.0,3.0 -O 0.01 ``` The supported data types are `float`, `half`, `int8`, and `uint8`. The dataset can use the BIGANN format (`uint32` vector count, `uint32` dimension count, then row-major vectors) or the xvec format. -Use `-I` to set the number of k-means iterations; it defaults to 20. Use `-B` to set the balance -tolerance and `-O` to set the centroid offset used when splitting large partitions; they default to -0.33 and 0.01. The default tolerance targets partitions outside roughly one third to three times the -average partition size. Very strict tolerance values around 0.7 or higher can be difficult for this -heuristic rebalancing method to satisfy. The example prints partition sizes, centroid prefixes, and -the partition assigned to each of the first vector IDs. +Use `-I` to set the number of k-means iterations; it defaults to 20. Use `-L` to set one or more +lower balance tolerances, `-U` to set one or more upper balance tolerances, and `-O` to set the +centroid offset used when splitting large partitions; they default to 0.333, 3.0, and 0.01. The +example runs balanced k-means for every `-L` and `-U` combination. The defaults target partitions +outside roughly one third to three times the average partition size. Very strict upper tolerance +values around 1.4 or lower can be difficult for this heuristic rebalancing method to satisfy. The +example prints partition size statistics, underflow/overflow counts, and histograms comparing +regular k-means and balanced k-means for `float` input. diff --git a/examples/cpp/src/balanced_kmeans_example.cu b/examples/cpp/src/balanced_kmeans_example.cu index 8d0b61c30a..4ad395482a 100644 --- a/examples/cpp/src/balanced_kmeans_example.cu +++ b/examples/cpp/src/balanced_kmeans_example.cu @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -37,7 +38,16 @@ static struct argp_option options[] = { {"dtype", 't', "TYPE", 0, "Data type [float/half/int8/uint8]"}, {"partitions", 'P', "INT", 0, "Number of balanced partitions"}, {"iterations", 'I', "INT", 0, "Number of k-means iterations (default: 20)"}, - {"balance-tolerance", 'B', "FLOAT", 0, "Balance tolerance (default: 0.33)"}, + {"balance-lower-tolerance", + 'L', + "FLOATS", + 0, + "Comma-separated lower balance tolerances (default: 0.333)"}, + {"balance-upper-tolerance", + 'U', + "FLOATS", + 0, + "Comma-separated upper balance tolerances (default: 3.0)"}, {"centroid-offset", 'O', "FLOAT", 0, "Centroid offset when splitting partitions (default: 0.01)"}, {0}}; @@ -46,10 +56,26 @@ struct arguments { std::string dtype; std::uint32_t n_partitions; std::uint32_t n_iters; - float balance_tolerance; + std::vector balance_lower_tolerances; + std::vector balance_upper_tolerances; float centroid_offset; }; +std::vector parse_float_list(std::string const& arg) +{ + std::vector values; + std::stringstream ss(arg); + std::string token; + while (std::getline(ss, token, ',')) { + if (token.empty()) { + throw std::invalid_argument("Empty value in comma-separated list: " + arg); + } + values.push_back(std::stof(token)); + } + if (values.empty()) { throw std::invalid_argument("Empty comma-separated list"); } + return values; +} + static error_t parse_opt(int key, char* arg, struct argp_state* state) { struct arguments* arguments = reinterpret_cast(state->input); @@ -59,7 +85,8 @@ static error_t parse_opt(int key, char* arg, struct argp_state* state) case 't': arguments->dtype = arg; break; case 'P': arguments->n_partitions = std::stoul(arg); break; case 'I': arguments->n_iters = std::stoul(arg); break; - case 'B': arguments->balance_tolerance = std::stof(arg); break; + case 'L': arguments->balance_lower_tolerances = parse_float_list(arg); break; + case 'U': arguments->balance_upper_tolerances = parse_float_list(arg); break; case 'O': arguments->centroid_offset = std::stof(arg); break; case ARGP_KEY_ARG: break; case ARGP_KEY_END: break; @@ -166,7 +193,8 @@ partition_size_stats compute_partition_size_stats( raft::device_resources const& resources, int64_t n_partitions, raft::device_vector_view labels, - float balance_tolerance) + float balance_lower_tolerance, + float balance_upper_tolerance) { auto host_labels = raft::make_host_vector(labels.extent(0)); auto stream = raft::resource::get_cuda_stream(resources); @@ -188,8 +216,8 @@ partition_size_stats compute_partition_size_stats( ? (partition_sizes[n_partitions / 2 - 1] + partition_sizes[n_partitions / 2]) / 2.0 : static_cast(partition_sizes[n_partitions / 2]); auto mean = static_cast(labels.extent(0)) / n_partitions; - auto lower_threshold = mean * balance_tolerance; - auto upper_threshold = mean / balance_tolerance; + auto lower_threshold = mean * balance_lower_tolerance; + auto upper_threshold = mean * balance_upper_tolerance; auto underflow_count = static_cast( std::count_if(partition_sizes.begin(), partition_sizes.end(), [lower_threshold](int64_t size) { return size < lower_threshold; @@ -229,22 +257,35 @@ void print_partition_size_stats(std::string const& label, partition_size_stats c << '\n'; } +void print_partition_size_summary(std::string const& label, partition_size_stats const& stats) +{ + std::cout << label << " partition size statistics: min=" << stats.min_size + << ", max=" << stats.max_size << ", median=" << stats.median_size + << ", mean=" << stats.mean_size << ", standard deviation=" << stats.stddev_size + << ", min/mean=" << stats.min_size / stats.mean_size + << ", max/mean=" << stats.max_size / stats.mean_size << '\n'; +} + void print_partition_size_histogram(std::string const& label, partition_size_stats const& stats, int64_t histogram_min, - int64_t histogram_max, + double histogram_upper, int64_t n_bins = 20) { if (stats.sorted_sizes.empty()) { return; } - std::vector bins(n_bins, 0); - auto const range = static_cast(histogram_max - histogram_min); + std::vector bins(n_bins + 1, 0); + auto const range = histogram_upper - histogram_min; if (range == 0.0) { bins.front() = static_cast(stats.sorted_sizes.size()); } else { for (auto size : stats.sorted_sizes) { - auto bin = static_cast((size - histogram_min) / range * n_bins); - bins[std::min(bin, n_bins - 1)]++; + if (static_cast(size) > histogram_upper) { + bins.back()++; + } else { + auto bin = static_cast((size - histogram_min) / range * n_bins); + bins[std::min(bin, n_bins - 1)]++; + } } } @@ -256,8 +297,7 @@ void print_partition_size_histogram(std::string const& label, for (int64_t bin = 0; bin < n_bins; ++bin) { auto const lower = range == 0.0 ? static_cast(histogram_min) : histogram_min + bin_width * bin; - auto const upper = - range == 0.0 ? static_cast(histogram_max) : histogram_min + bin_width * (bin + 1); + auto const upper = range == 0.0 ? histogram_upper : histogram_min + bin_width * (bin + 1); auto const count = bins[bin]; auto const hashes = max_bin_count == 0 ? int64_t{0} : std::max(1, count * bar_width / max_bin_count); @@ -270,6 +310,18 @@ void print_partition_size_histogram(std::string const& label, } std::cout << '\n'; } + + auto const overflow_count = bins.back(); + auto const overflow_hashes = max_bin_count == 0 + ? int64_t{0} + : std::max(1, overflow_count * bar_width / max_bin_count); + std::cout << " (" << std::setw(8) << static_cast(std::ceil(histogram_upper)) << ", " + << std::setw(8) << "inf" + << "] " << std::setw(4) << overflow_count << " | "; + for (int64_t i = 0; i < overflow_hashes && overflow_count != 0; ++i) { + std::cout << '#'; + } + std::cout << '\n'; } void print_balance_improvement(partition_size_stats const& regular_stats, @@ -285,12 +337,11 @@ void print_balance_improvement(partition_size_stats const& regular_stats, } template -std::optional run_regular_kmeans_comparison( - raft::device_resources const& resources, - raft::device_matrix_view dataset, - int64_t n_partitions, - std::uint32_t n_iters, - float balance_tolerance) +bool run_regular_kmeans(raft::device_resources const& resources, + raft::device_matrix_view dataset, + int64_t n_partitions, + std::uint32_t n_iters, + raft::device_vector_view labels) { if constexpr (std::is_same_v) { cuvs::cluster::kmeans::params params; @@ -300,7 +351,6 @@ std::optional run_regular_kmeans_comparison( auto centroids = raft::make_device_matrix( resources, n_partitions, static_cast(dataset.extent(1))); - auto labels = raft::make_device_vector(resources, dataset.extent(0)); float inertia = 0.0f; int64_t n_iter = 0; @@ -316,14 +366,13 @@ std::optional run_regular_kmeans_comparison( dataset, std::nullopt, raft::make_const_mdspan(centroids.view()), - labels.view(), + labels, false, raft::make_host_scalar_view(&inertia)); - return compute_partition_size_stats( - resources, n_partitions, raft::make_const_mdspan(labels.view()), balance_tolerance); + return true; } else { - return std::nullopt; + return false; } } @@ -331,7 +380,8 @@ template void partition_dataset(std::string const& dataset_path, std::uint32_t n_partitions, std::uint32_t n_iters, - float balance_tolerance, + std::vector const& balance_lower_tolerances, + std::vector const& balance_upper_tolerances, float centroid_offset) { raft::device_resources resources; @@ -354,43 +404,71 @@ void partition_dataset(std::string const& dataset_path, std::cout << "Partitioning " << n_samples << " vectors with " << n_features << " dimensions into " << n_partitions << " balanced partitions\n"; - cuvs::cluster::kmeans::balanced_params params; - params.metric = cuvs::distance::DistanceType::L2Expanded; - params.n_iters = n_iters; - params.balance_tolerance = balance_tolerance; - params.centroid_offset = centroid_offset; - - auto centroids = raft::make_device_matrix(resources, n_partitions, n_features); - auto labels = raft::make_device_vector(resources, n_samples); - auto dataset_view = raft::make_const_mdspan(dataset.view()); - - auto regular_stats = run_regular_kmeans_comparison( - resources, dataset_view, n_partitions, n_iters, balance_tolerance); - - cuvs::cluster::kmeans::fit(resources, params, dataset_view, centroids.view()); - cuvs::cluster::kmeans::predict( - resources, params, dataset_view, raft::make_const_mdspan(centroids.view()), labels.view()); - - auto balanced_stats = compute_partition_size_stats( - resources, n_partitions, raft::make_const_mdspan(labels.view()), balance_tolerance); - - if (regular_stats.has_value()) { - auto const histogram_min = std::min(regular_stats->min_size, balanced_stats.min_size); - auto const histogram_max = std::max(regular_stats->max_size, balanced_stats.max_size); - print_partition_size_stats("Regular k-means", regular_stats.value()); + auto centroids = raft::make_device_matrix(resources, n_partitions, n_features); + auto labels = raft::make_device_vector(resources, n_samples); + auto regular_labels = raft::make_device_vector(resources, n_samples); + auto dataset_view = raft::make_const_mdspan(dataset.view()); + + auto const has_regular_stats = run_regular_kmeans( + resources, dataset_view, n_partitions, n_iters, regular_labels.view()); + std::optional regular_reference_stats; + if (has_regular_stats) { + regular_reference_stats = + compute_partition_size_stats(resources, + n_partitions, + raft::make_const_mdspan(regular_labels.view()), + balance_lower_tolerances.front(), + balance_upper_tolerances.front()); + print_partition_size_summary("Regular k-means", regular_reference_stats.value()); print_partition_size_histogram( - "Regular k-means", regular_stats.value(), histogram_min, histogram_max); - print_partition_size_stats("Balanced k-means", balanced_stats); - print_partition_size_histogram( - "Balanced k-means", balanced_stats, histogram_min, histogram_max); + "Regular k-means", + regular_reference_stats.value(), + regular_reference_stats->min_size, + regular_reference_stats->mean_size + 2.0 * regular_reference_stats->stddev_size); } else { std::cout << "Regular k-means comparison is only shown for float input in this example.\n"; - print_partition_size_stats("Balanced k-means", balanced_stats); - print_partition_size_histogram( - "Balanced k-means", balanced_stats, balanced_stats.min_size, balanced_stats.max_size); } - if (regular_stats.has_value()) { - print_balance_improvement(regular_stats.value(), balanced_stats); + + for (auto balance_lower_tolerance : balance_lower_tolerances) { + for (auto balance_upper_tolerance : balance_upper_tolerances) { + std::cout << "\n# balance_lower_tolerance: " << balance_lower_tolerance << '\n' + << "# balance_upper_tolerance: " << balance_upper_tolerance << '\n'; + + cuvs::cluster::kmeans::balanced_params params; + params.metric = cuvs::distance::DistanceType::L2Expanded; + params.n_iters = n_iters; + params.balance_lower_tolerance = balance_lower_tolerance; + params.balance_upper_tolerance = balance_upper_tolerance; + params.centroid_offset = centroid_offset; + + cuvs::cluster::kmeans::fit(resources, params, dataset_view, centroids.view()); + cuvs::cluster::kmeans::predict( + resources, params, dataset_view, raft::make_const_mdspan(centroids.view()), labels.view()); + + auto balanced_stats = compute_partition_size_stats(resources, + n_partitions, + raft::make_const_mdspan(labels.view()), + balance_lower_tolerance, + balance_upper_tolerance); + + if (has_regular_stats) { + auto const& regular_stats = regular_reference_stats.value(); + auto const histogram_min = std::min(regular_stats.min_size, balanced_stats.min_size); + auto const histogram_upper = + std::max(regular_stats.mean_size + 2.0 * regular_stats.stddev_size, + balanced_stats.mean_size + 2.0 * balanced_stats.stddev_size); + print_partition_size_stats("Balanced k-means", balanced_stats); + print_partition_size_histogram( + "Balanced k-means", balanced_stats, histogram_min, histogram_upper); + print_balance_improvement(regular_stats, balanced_stats); + } else { + print_partition_size_stats("Balanced k-means", balanced_stats); + print_partition_size_histogram("Balanced k-means", + balanced_stats, + balanced_stats.min_size, + balanced_stats.mean_size + 2.0 * balanced_stats.stddev_size); + } + } } } @@ -400,12 +478,13 @@ int main(int argc, char** argv) { try { struct arguments args = { - "", /* dataset_path */ - "", /* dtype */ - 0, /* n_partitions */ - 20, /* n_iters */ - 0.33f, /* balance_tolerance */ - 0.01f, /* centroid_offset */ + "", /* dataset_path */ + "", /* dtype */ + 0, /* n_partitions */ + 20, /* n_iters */ + {0.333f}, /* balance_lower_tolerances */ + {3.0f}, /* balance_upper_tolerances */ + 0.01f, /* centroid_offset */ }; argp_parse(&argp, argc, argv, 0, 0, &args); @@ -421,9 +500,18 @@ int main(int argc, char** argv) if (args.n_iters == 0) { error_message += "- Number of k-means iterations must be larger than 0 (-I)\n"; } - if (!std::isfinite(args.balance_tolerance) || args.balance_tolerance <= 0.0f || - args.balance_tolerance >= 1.0f) { - error_message += "- Balance tolerance must be in the range (0, 1) (-B)\n"; + for (auto balance_lower_tolerance : args.balance_lower_tolerances) { + if (!std::isfinite(balance_lower_tolerance) || balance_lower_tolerance <= 0.0f || + balance_lower_tolerance >= 1.0f) { + error_message += "- Lower balance tolerances must be in the range (0, 1) (-L)\n"; + break; + } + } + for (auto balance_upper_tolerance : args.balance_upper_tolerances) { + if (!std::isfinite(balance_upper_tolerance) || balance_upper_tolerance <= 1.0f) { + error_message += "- Upper balance tolerances must be greater than 1 (-U)\n"; + break; + } } if (!std::isfinite(args.centroid_offset) || args.centroid_offset <= 0.0f || args.centroid_offset > 1.0f) { @@ -435,32 +523,43 @@ int main(int argc, char** argv) << "# dtype: " << args.dtype << '\n' << "# partitions: " << args.n_partitions << '\n' << "# iterations: " << args.n_iters << '\n' - << "# balance_tolerance: " << args.balance_tolerance << '\n' - << "# centroid_offset: " << args.centroid_offset << '\n'; + << "# balance_lower_tolerances:"; + for (auto value : args.balance_lower_tolerances) { + std::cout << ' ' << value; + } + std::cout << '\n' << "# balance_upper_tolerances:"; + for (auto value : args.balance_upper_tolerances) { + std::cout << ' ' << value; + } + std::cout << '\n' << "# centroid_offset: " << args.centroid_offset << '\n'; if (args.dtype == "float") { partition_dataset(args.dataset_path, args.n_partitions, args.n_iters, - args.balance_tolerance, + args.balance_lower_tolerances, + args.balance_upper_tolerances, args.centroid_offset); } else if (args.dtype == "half") { partition_dataset(args.dataset_path, args.n_partitions, args.n_iters, - args.balance_tolerance, + args.balance_lower_tolerances, + args.balance_upper_tolerances, args.centroid_offset); } else if (args.dtype == "int8") { partition_dataset(args.dataset_path, args.n_partitions, args.n_iters, - args.balance_tolerance, + args.balance_lower_tolerances, + args.balance_upper_tolerances, args.centroid_offset); } else if (args.dtype == "uint8") { partition_dataset(args.dataset_path, args.n_partitions, args.n_iters, - args.balance_tolerance, + args.balance_lower_tolerances, + args.balance_upper_tolerances, args.centroid_offset); } else { throw std::invalid_argument("Unknown data type: " + args.dtype); diff --git a/fern/pages/cluster/kmeans.md b/fern/pages/cluster/kmeans.md index 64757f0c47..245305f826 100644 --- a/fern/pages/cluster/kmeans.md +++ b/fern/pages/cluster/kmeans.md @@ -369,7 +369,8 @@ Balanced K-Means encourages more even cluster sizes. It is useful when clusters | `streaming_batch_size` | `0` | Number of host rows streamed to the GPU per batch. `0` processes all host rows at once. | | `hierarchical` | `false` | Enables hierarchical, balanced K-Means in C and Python. | | `hierarchical_n_iters` | implementation default | Number of training iterations for hierarchical K-Means. | -| `balance_tolerance` | `0.33` | C++ balanced K-Means tolerance for rebalancing clusters during hierarchical training and final global fine-tuning iterations. Small clusters are adjusted when their size is no larger than `average_cluster_size * balance_tolerance`; if overfull clusters exist, below-average clusters are adjusted towards donors larger than `average_cluster_size / balance_tolerance`. The default targets clusters outside roughly one third to three times the average size. Very strict values around `0.7` or higher can be difficult for this heuristic rebalancing method to satisfy. | +| `balance_lower_tolerance` | `0.333` | C++ balanced K-Means lower tolerance for rebalancing clusters during hierarchical training and final global fine-tuning iterations. Small clusters are adjusted when their size is smaller than `average_cluster_size * balance_lower_tolerance`. The default targets clusters smaller than roughly one third of the average size. | +| `balance_upper_tolerance` | `3.0` | C++ balanced K-Means upper tolerance for selecting overfull donor clusters during hierarchical training and final global fine-tuning iterations. Donor clusters are selected when their size is larger than `average_cluster_size * balance_upper_tolerance`. The default targets clusters larger than roughly three times the average size. Very strict upper tolerance values around `1.4` or lower can be difficult for this heuristic rebalancing method to satisfy. | | `centroid_offset` | `0.01` | C++ balanced K-Means offset used when reinitializing a small cluster near a large cluster. The new center is placed at `donor_center + centroid_offset * (donor_point - donor_center)`. | ## Tuning From 30f45025690ac388284a75470a5cfd0ee0229ef2 Mon Sep 17 00:00:00 2001 From: Akira Naruse Date: Wed, 10 Jun 2026 12:55:08 +0900 Subject: [PATCH 3/5] Fix balanced k-means example wording --- examples/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/README.md b/examples/README.md index 3d23209e55..60c1dde883 100644 --- a/examples/README.md +++ b/examples/README.md @@ -31,7 +31,7 @@ Use `-I` to set the number of k-means iterations; it defaults to 20. Use `-L` to lower balance tolerances, `-U` to set one or more upper balance tolerances, and `-O` to set the centroid offset used when splitting large partitions; they default to 0.333, 3.0, and 0.01. The example runs balanced k-means for every `-L` and `-U` combination. The defaults target partitions -outside roughly one third to three times the average partition size. Very strict upper tolerance +outside roughly one-third to three times the average partition size. Very strict upper tolerance values around 1.4 or lower can be difficult for this heuristic rebalancing method to satisfy. The example prints partition size statistics, underflow/overflow counts, and histograms comparing regular k-means and balanced k-means for `float` input. From 519316dd50b75bb91e4684b51719f7651e27a92b Mon Sep 17 00:00:00 2001 From: Akira Naruse Date: Mon, 15 Jun 2026 16:51:16 +0900 Subject: [PATCH 4/5] Address balanced k-means review feedback Compute balance thresholds from a floating-point average, avoid pairing against empty donor clusters, and perform candidate index arithmetic with int64_t intermediates. Fix xvec dataset handling in the balanced k-means example by accounting for per-row dimension headers and reading each row from the beginning. --- cpp/src/cluster/detail/kmeans_balanced.cuh | 16 ++++++++++------ examples/cpp/src/balanced_kmeans_example.cu | 17 +++++++++++++---- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index 9d771f1503..c8b790c4ea 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -483,9 +483,12 @@ __launch_bounds__((raft::WarpSize * BlockDimY)) RAFT_KERNEL IdxT i = n_rows; IdxT j = raft::laneId(); for (IdxT attempt = 0; attempt < n_rows; attempt += raft::WarpSize) { - auto candidate = (seed * (attempt + j + 1) + pair_id) % n_rows; - auto found = static_cast(labels[candidate]) == donor_cluster; - auto mask = __ballot_sync(raft::warp_full_mask(), found); + auto candidate = + static_cast((static_cast(seed) * static_cast(attempt + j + 1) + + static_cast(pair_id)) % + static_cast(n_rows)); + auto found = static_cast(labels[candidate]) == donor_cluster; + auto mask = __ballot_sync(raft::warp_full_mask(), found); if (mask != 0) { auto source_lane = __ffs(mask) - 1; i = raft::shfl(found ? candidate : n_rows, source_lane); @@ -573,9 +576,9 @@ auto adjust_centers(MathT* centers, 2053, 2129, 2213, 2287, 2357, 2423, 2531, 2617, 2687, 2741}; static IdxT i_primes = 0; - IdxT average = n_rows / n_clusters; - auto lower_threshold = static_cast(average) * balance_lower_tolerance; - auto upper_threshold = static_cast(average) * balance_upper_tolerance; + auto average = static_cast(n_rows) / static_cast(n_clusters); + auto lower_threshold = average * balance_lower_tolerance; + auto upper_threshold = average * balance_upper_tolerance; std::vector host_cluster_sizes(n_clusters); raft::update_host(host_cluster_sizes.data(), cluster_sizes, n_clusters, stream); raft::resource::sync_stream(handle, stream); @@ -595,6 +598,7 @@ auto adjust_centers(MathT* centers, auto const& [small_size, small_cluster] = sorted_clusters[pair_id]; auto const& [large_size, large_cluster] = sorted_clusters[n_clusters - 1 - pair_id]; if (small_cluster == large_cluster) { break; } + if (large_size == 0) { break; } if (static_cast(small_size) >= lower_threshold && static_cast(large_size) <= upper_threshold) { break; diff --git a/examples/cpp/src/balanced_kmeans_example.cu b/examples/cpp/src/balanced_kmeans_example.cu index 4ad395482a..3c55994f48 100644 --- a/examples/cpp/src/balanced_kmeans_example.cu +++ b/examples/cpp/src/balanced_kmeans_example.cu @@ -156,8 +156,13 @@ void get_dataset_info(dataset_descriptor_t& desc, desc.dim = tmp_val[1]; } else { std::fprintf(stderr, "# Xvec type file (%s)\n", file_path.c_str()); - desc.dim = tmp_val[0]; - desc.size = (file_size_in_byte - sizeof(std::uint32_t)) / desc.dim / sizeof(DataT) - 1; + desc.dim = tmp_val[0]; + auto const row_size = + sizeof(std::uint32_t) + sizeof(DataT) * static_cast(desc.dim); + if (row_size == 0 || file_size_in_byte % row_size != 0) { + throw std::runtime_error("Invalid Xvec file size : " + file_path); + } + desc.size = file_size_in_byte / row_size; } } @@ -180,12 +185,16 @@ void load_dataset(dataset_descriptor_t& desc, ifs.seekg(sizeof(std::uint32_t) * 2, std::ios::beg); ifs.read(reinterpret_cast(desc.data.get()), array_size); } else { - ifs.seekg(sizeof(std::uint32_t), std::ios::beg); for (std::size_t i = 0; i < desc.size; i++) { - ifs.seekg(sizeof(std::uint32_t), std::ios::cur); + std::uint32_t row_dim = 0; + ifs.read(reinterpret_cast(&row_dim), sizeof(row_dim)); + if (row_dim != desc.dim) { + throw std::runtime_error("Inconsistent Xvec dimension in : " + file_path); + } ifs.read(reinterpret_cast(desc.data.get() + i * desc.dim), sizeof(DataT) * desc.dim); } } + if (!ifs) { throw std::runtime_error("Failed to read dataset : " + file_path); } } template From 859d7259dfcc57b73ed34146dbc744d5f9903b90 Mon Sep 17 00:00:00 2001 From: Akira Naruse Date: Tue, 30 Jun 2026 19:33:59 +0900 Subject: [PATCH 5/5] Address review comment on adjust_centers args --- cpp/src/cluster/detail/kmeans_balanced.cuh | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index 7516d29af3..4ecc80bd4e 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 */ @@ -526,6 +526,7 @@ __launch_bounds__((raft::WarpSize * BlockDimY)) RAFT_KERNEL * @tparam CounterT counter type supported by CUDA's native atomicAdd * @tparam MappingOpT type of the mapping operation * + * @param[in] handle The raft handle * @param[inout] centers cluster centers [n_clusters, dim] * @param[in] n_clusters number of rows in `centers` * @param[in] dim number of columns in `centers` and `dataset` @@ -541,7 +542,6 @@ __launch_bounds__((raft::WarpSize * BlockDimY)) RAFT_KERNEL * balance_upper_tolerance > 1 * @param[in] centroid_offset offset from the donor cluster centroid towards a donor point * @param[in] mapping_op Mapping operation from T to MathT - * @param[in] stream CUDA stream * @param[inout] device_memory memory resource to use for temporary allocations * * @return whether any of the centers has been updated (and thus, `labels` need to be recalculated). @@ -552,8 +552,8 @@ template -auto adjust_centers(MathT* centers, - const raft::resources& handle, +auto adjust_centers(const raft::resources& handle, + MathT* centers, IdxT n_clusters, IdxT dim, const T* dataset, @@ -564,12 +564,12 @@ auto adjust_centers(MathT* centers, MathT balance_upper_tolerance, MathT centroid_offset, MappingOpT mapping_op, - rmm::cuda_stream_view stream, rmm::device_async_resource_ref device_memory) -> bool { raft::common::nvtx::range fun_scope( "adjust_centers(%zu, %u)", static_cast(n_rows), n_clusters); if (n_clusters == 0) { return false; } + auto stream = raft::resource::get_cuda_stream(handle); constexpr static std::array kPrimes{29, 71, 113, 173, 229, 281, 349, 409, 463, 541, 601, 659, 733, 809, 863, 941, 1013, 1069, 1151, 1223, 1291, 1373, 1451, 1511, 1583, 1657, 1733, 1811, 1889, 1987, @@ -711,13 +711,12 @@ void balancing_em_iters(const raft::resources& handle, RAFT_EXPECTS(params.centroid_offset > 0.0f && params.centroid_offset <= 1.0f, "Balanced k-means centroid offset must be in the range (0, 1]"); - auto stream = raft::resource::get_cuda_stream(handle); uint32_t balancing_counter = balancing_pullback; for (uint32_t iter = 0; iter < n_iters; iter++) { // Balancing step - move the centers around to equalize cluster sizes // (but not on the first iteration) - if (iter > 0 && adjust_centers(cluster_centers, - handle, + if (iter > 0 && adjust_centers(handle, + cluster_centers, n_clusters, dim, dataset, @@ -728,7 +727,6 @@ void balancing_em_iters(const raft::resources& handle, balance_upper_tolerance, static_cast(params.centroid_offset), mapping_op, - stream, device_memory)) { if (balancing_counter++ >= balancing_pullback) { balancing_counter -= balancing_pullback;