Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions cpp/include/cuvs/cluster/kmeans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,33 @@ struct balanced_params : base_params {
* Number of training iterations
*/
uint32_t n_iters = 20;

/**
* Lower balance tolerance used during hierarchical training. Clusters smaller than
* `average_cluster_size * balance_lower_tolerance` are underfull. The default value of `0.333`
* targets clusters smaller than roughly one third of the average size.
*
* Valid range: (0, 1).
*/
float balance_lower_tolerance = 0.333f;

/**
* Upper balance tolerance used during hierarchical training. Clusters larger than
* `average_cluster_size * balance_upper_tolerance` are overfull donors. The default value of
* `3.0` targets clusters larger than roughly three times the average size. Very strict upper
* values around `1.4` or lower can be difficult for this heuristic rebalancing method to satisfy.
*
* Valid range: (1, infinity).
*/
float balance_upper_tolerance = 3.0f;

/**
* Offset used when reinitializing an underfull cluster near an overfull cluster. The new center
* is placed at `donor_center + centroid_offset * (donor_point - donor_center)`.
*
* Valid range: (0, 1].
*/
float centroid_offset = 0.01f;
};

/**
Expand Down
194 changes: 125 additions & 69 deletions cpp/src/cluster/detail/kmeans_balanced.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/

Expand Down Expand Up @@ -42,15 +42,16 @@
#include <thrust/iterator/transform_iterator.h>
#include <thrust/transform.h>

#include <algorithm>
#include <limits>
#include <optional>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

namespace cuvs::cluster::kmeans::detail {

constexpr static inline float kAdjustCentersWeight = 7.0f;

/**
* @brief Predict labels for the dataset; floating-point types only.
*
Expand Down Expand Up @@ -459,61 +460,60 @@ template <uint32_t BlockDimY,
typename MathT,
typename IdxT,
typename LabelT,
typename CounterT,
typename MappingOpT>
__launch_bounds__((raft::WarpSize * BlockDimY)) RAFT_KERNEL
adjust_centers_kernel(MathT* centers, // [n_clusters, dim]
IdxT n_clusters,
IdxT n_pairs,
IdxT dim,
const T* dataset, // [n_rows, dim]
IdxT n_rows,
const LabelT* labels, // [n_rows]
const CounterT* cluster_sizes, // [n_clusters]
MathT threshold,
IdxT average,
const LabelT* labels, // [n_rows]
const IdxT* receiver_clusters,
const IdxT* donor_clusters,
MathT centroid_offset,
IdxT seed,
IdxT* count,
IdxT* update_count,
MappingOpT mapping_op)
{
IdxT l = threadIdx.y + BlockDimY * static_cast<IdxT>(blockIdx.x);
if (l >= n_clusters) return;
auto csize = static_cast<IdxT>(cluster_sizes[l]);
// skip big clusters
if (csize > static_cast<IdxT>(average * threshold)) return;

// choose a "random" i that belongs to a rather large cluster
IdxT i;
IdxT j = raft::laneId();
if (j == 0) {
do {
auto old = atomicAdd(count, IdxT{1});
i = (seed * (old + 1)) % n_rows;
} while (static_cast<IdxT>(cluster_sizes[labels[i]]) < average);
IdxT pair_id = threadIdx.y + BlockDimY * static_cast<IdxT>(blockIdx.x);
if (pair_id >= n_pairs) return;

auto receiver_cluster = receiver_clusters[pair_id];
auto donor_cluster = donor_clusters[pair_id];
IdxT i = n_rows;
IdxT j = raft::laneId();
for (IdxT attempt = 0; attempt < n_rows; attempt += raft::WarpSize) {
auto candidate =
static_cast<IdxT>((static_cast<int64_t>(seed) * static_cast<int64_t>(attempt + j + 1) +
static_cast<int64_t>(pair_id)) %
static_cast<int64_t>(n_rows));
auto found = static_cast<IdxT>(labels[candidate]) == donor_cluster;
auto mask = __ballot_sync(raft::warp_full_mask(), found);
if (mask != 0) {
auto source_lane = __ffs(mask) - 1;
i = raft::shfl(found ? candidate : n_rows, source_lane);
if (j == source_lane) { atomicAdd(update_count, IdxT{1}); }
break;
}
}
i = raft::shfl(i, 0);

// Adjust the center of the selected smaller cluster to gravitate towards
// a sample from the selected larger cluster.
const IdxT li = static_cast<IdxT>(labels[i]);
// Weight of the current center for the weighted average.
// We dump it for anomalously small clusters, but keep constant otherwise.
const MathT wc = min(static_cast<MathT>(csize), static_cast<MathT>(kAdjustCentersWeight));
// Weight for the datapoint used to shift the center.
const MathT wd = 1.0;
if (i >= n_rows) return;

// Reinitialize the small cluster close to the large cluster centroid, with a small offset towards
// a random donor point so it can split the large partition in the next prediction step.
for (; j < dim; j += raft::WarpSize) {
MathT val = 0;
val += wc * centers[j + dim * li];
val += wd * mapping_op(dataset[j + dim * i]);
val /= wc + wd;
centers[j + dim * l] = val;
auto donor_center = centers[j + dim * donor_cluster];
auto donor_point = mapping_op(dataset[j + dim * i]);
auto val = donor_center + centroid_offset * (donor_point - donor_center);
centers[j + dim * receiver_cluster] = val;
}
}

/**
* @brief Adjust centers for clusters that have small number of entries.
*
* For each cluster, where the cluster size is not bigger than a threshold, the center is moved
* towards a data point that belongs to a large cluster.
* Cluster sizes are sorted, then the smallest clusters are paired with the largest clusters. For
* each pair where the small cluster is underfull or the large cluster is overfull, the small
* cluster center is moved towards a data point from the large cluster.
*
* NB: if this function returns `true`, you should update the labels.
*
Expand All @@ -526,18 +526,22 @@ __launch_bounds__((raft::WarpSize * BlockDimY)) RAFT_KERNEL
* @tparam CounterT counter type supported by CUDA's native atomicAdd
* @tparam MappingOpT type of the mapping operation
*
* @param[in] handle The raft handle
* @param[inout] centers cluster centers [n_clusters, dim]
* @param[in] n_clusters number of rows in `centers`
* @param[in] dim number of columns in `centers` and `dataset`
* @param[in] dataset a host pointer to the row-major data matrix [n_rows, dim]
* @param[in] n_rows number of rows in `dataset`
* @param[in] labels a host pointer to the cluster indices [n_rows]
* @param[in] cluster_sizes number of rows in each cluster [n_clusters]
* @param[in] threshold defines a criterion for adjusting a cluster
* (cluster_sizes <= average_size * threshold)
* 0 <= threshold < 1
* @param[in] balance_lower_tolerance defines the underfull cluster criterion:
* min_cluster_size < average_size * balance_lower_tolerance
* 0 < balance_lower_tolerance < 1
* @param[in] balance_upper_tolerance defines the overfull donor cluster criterion:
* max_cluster_size > average_size * balance_upper_tolerance
* balance_upper_tolerance > 1
* @param[in] centroid_offset offset from the donor cluster centroid towards a donor point
* @param[in] mapping_op Mapping operation from T to MathT
* @param[in] stream CUDA stream
* @param[inout] device_memory memory resource to use for temporary allocations
*
* @return whether any of the centers has been updated (and thus, `labels` need to be recalculated).
Expand All @@ -548,55 +552,93 @@ template <typename T,
typename LabelT,
typename CounterT,
typename MappingOpT>
auto adjust_centers(MathT* centers,
auto adjust_centers(const raft::resources& handle,
MathT* centers,
IdxT n_clusters,
IdxT dim,
const T* dataset,
IdxT n_rows,
const LabelT* labels,
const CounterT* cluster_sizes,
MathT threshold,
MathT balance_lower_tolerance,
MathT balance_upper_tolerance,
MathT centroid_offset,
MappingOpT mapping_op,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref device_memory) -> bool
{
raft::common::nvtx::range<cuvs::common::nvtx::domain::cuvs> fun_scope(
"adjust_centers(%zu, %u)", static_cast<size_t>(n_rows), n_clusters);
if (n_clusters == 0) { return false; }
auto stream = raft::resource::get_cuda_stream(handle);
constexpr static std::array kPrimes{29, 71, 113, 173, 229, 281, 349, 409, 463, 541,
601, 659, 733, 809, 863, 941, 1013, 1069, 1151, 1223,
1291, 1373, 1451, 1511, 1583, 1657, 1733, 1811, 1889, 1987,
2053, 2129, 2213, 2287, 2357, 2423, 2531, 2617, 2687, 2741};
static IdxT i = 0;
static IdxT i_primes = 0;

bool adjusted = false;
IdxT average = n_rows / n_clusters;
auto average = static_cast<MathT>(n_rows) / static_cast<MathT>(n_clusters);
auto lower_threshold = average * balance_lower_tolerance;
auto upper_threshold = average * balance_upper_tolerance;
std::vector<CounterT> host_cluster_sizes(n_clusters);
raft::update_host(host_cluster_sizes.data(), cluster_sizes, n_clusters, stream);
raft::resource::sync_stream(handle, stream);
Comment thread
anaruse marked this conversation as resolved.

std::vector<std::pair<CounterT, IdxT>> sorted_clusters;
sorted_clusters.reserve(n_clusters);
for (IdxT cluster = 0; cluster < n_clusters; ++cluster) {
sorted_clusters.emplace_back(host_cluster_sizes[cluster], cluster);
}
std::sort(sorted_clusters.begin(), sorted_clusters.end());

std::vector<IdxT> host_receiver_clusters;
std::vector<IdxT> host_donor_clusters;
host_receiver_clusters.reserve(n_clusters / 2);
host_donor_clusters.reserve(n_clusters / 2);
for (IdxT pair_id = 0; pair_id < n_clusters / 2; ++pair_id) {
auto const& [small_size, small_cluster] = sorted_clusters[pair_id];
auto const& [large_size, large_cluster] = sorted_clusters[n_clusters - 1 - pair_id];
if (small_cluster == large_cluster) { break; }
if (large_size == 0) { break; }
if (static_cast<MathT>(small_size) >= lower_threshold &&
static_cast<MathT>(large_size) <= upper_threshold) {
break;
}
host_receiver_clusters.push_back(small_cluster);
host_donor_clusters.push_back(large_cluster);
}
auto n_pairs = static_cast<IdxT>(host_receiver_clusters.size());
if (n_pairs == 0) { return false; }

IdxT ofst;
do {
i_primes = (i_primes + 1) % kPrimes.size();
ofst = kPrimes[i_primes];
} while (n_rows % ofst == 0);

rmm::device_uvector<IdxT> receiver_clusters(n_pairs, stream, device_memory);
rmm::device_uvector<IdxT> donor_clusters(n_pairs, stream, device_memory);
raft::update_device(receiver_clusters.data(), host_receiver_clusters.data(), n_pairs, stream);
raft::update_device(donor_clusters.data(), host_donor_clusters.data(), n_pairs, stream);

constexpr uint32_t kBlockDimY = 4;
const dim3 block_dim(raft::WarpSize, kBlockDimY, 1);
const dim3 grid_dim(raft::ceildiv(n_clusters, static_cast<IdxT>(kBlockDimY)), 1, 1);
const dim3 grid_dim(raft::ceildiv(n_pairs, static_cast<IdxT>(kBlockDimY)), 1, 1);
rmm::device_scalar<IdxT> update_count(0, stream, device_memory);
adjust_centers_kernel<kBlockDimY><<<grid_dim, block_dim, 0, stream>>>(centers,
n_clusters,
n_pairs,
dim,
dataset,
n_rows,
labels,
cluster_sizes,
threshold,
average,
receiver_clusters.data(),
donor_clusters.data(),
centroid_offset,
ofst,
update_count.data(),
mapping_op);
adjusted = update_count.value(stream) > 0; // NB: rmm scalar performs the sync

return adjusted;
auto n_updates = update_count.value(stream); // NB: rmm scalar performs the sync
RAFT_EXPECTS(n_updates == n_pairs, "Balanced k-means failed to update all adjusted centers");
Comment thread
anaruse marked this conversation as resolved.
return n_updates > 0;
}

/**
Expand Down Expand Up @@ -629,9 +671,12 @@ auto adjust_centers(MathT* centers,
* one extra iteration is performed (this could happen several times) (default should be `2`).
* In other words, the first and then every `ballancing_pullback`-th rebalancing operation adds
* one more iteration to the main cycle.
* @param[in] balancing_threshold
* the rebalancing takes place if any cluster is smaller than `avg_size * balancing_threshold`
* on a given iteration (default should be `~ 0.25`).
* @param[in] balance_lower_tolerance
* Small clusters are rebalanced when their paired small cluster is smaller than
* `avg_size * balance_lower_tolerance`.
* @param[in] balance_upper_tolerance
* If the paired large cluster is larger than `avg_size * balance_upper_tolerance`, the small
* cluster is rebalanced towards it.
* @param[in] mapping_op Mapping operation from T to MathT
* @param[inout] device_memory
* A memory resource for device allocations (makes sense to provide a memory pool here)
Expand All @@ -654,25 +699,34 @@ void balancing_em_iters(const raft::resources& handle,
LabelT* cluster_labels,
CounterT* cluster_sizes,
uint32_t balancing_pullback,
MathT balancing_threshold,
MathT balance_lower_tolerance,
MathT balance_upper_tolerance,
MappingOpT mapping_op,
rmm::device_async_resource_ref device_memory)
{
auto stream = raft::resource::get_cuda_stream(handle);
RAFT_EXPECTS(balance_lower_tolerance > MathT{0} && balance_lower_tolerance < MathT{1},
"Balanced k-means lower balance tolerance must be in the range (0, 1)");
RAFT_EXPECTS(balance_upper_tolerance > MathT{1},
"Balanced k-means upper balance tolerance must be greater than 1");
RAFT_EXPECTS(params.centroid_offset > 0.0f && params.centroid_offset <= 1.0f,
"Balanced k-means centroid offset must be in the range (0, 1]");

uint32_t balancing_counter = balancing_pullback;
for (uint32_t iter = 0; iter < n_iters; iter++) {
// Balancing step - move the centers around to equalize cluster sizes
// (but not on the first iteration)
if (iter > 0 && adjust_centers(cluster_centers,
if (iter > 0 && adjust_centers(handle,
cluster_centers,
n_clusters,
dim,
dataset,
n_rows,
cluster_labels,
cluster_sizes,
balancing_threshold,
balance_lower_tolerance,
balance_upper_tolerance,
static_cast<MathT>(params.centroid_offset),
mapping_op,
stream,
device_memory)) {
if (balancing_counter++ >= balancing_pullback) {
balancing_counter -= balancing_pullback;
Expand Down Expand Up @@ -776,7 +830,8 @@ void build_clusters(const raft::resources& handle,
cluster_labels,
cluster_sizes,
2,
MathT{0.25},
static_cast<MathT>(params.balance_lower_tolerance),
static_cast<MathT>(params.balance_upper_tolerance),
mapping_op,
device_memory);
}
Expand Down Expand Up @@ -1128,7 +1183,8 @@ void build_hierarchical(const raft::resources& handle,
labels.data(),
cluster_sizes.data(),
5,
MathT{0.2},
static_cast<MathT>(params.balance_lower_tolerance),
static_cast<MathT>(params.balance_upper_tolerance),
mapping_op,
device_memory);

Expand Down
5 changes: 3 additions & 2 deletions cpp/src/cluster/kmeans_balanced.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ namespace cuvs::cluster::kmeans_balanced {
* iterations over the whole dataset and with all the centroids to obtain the final clusters.
*
* Each k-means iteration applies expectation-maximization-balancing:
* - Balancing: adjust centers for clusters that have a small number of entries. If the size of a
* cluster is below a threshold, the center is moved towards a bigger cluster.
* - Balancing: adjust centers to reduce underfull and overfull clusters. Small clusters are moved
* towards larger clusters; when overfull clusters exist, below-average clusters are moved
* towards those overfull clusters.
* - Expectation: predict the labels (i.e find closest cluster centroid to each point)
* - Maximization: calculate optimal centroids (i.e find the center of gravity of each cluster)
*
Expand Down
Loading
Loading