|
16 | 16 | #include <raft/core/operators.hpp> |
17 | 17 | #include <raft/core/resource/cuda_stream.hpp> |
18 | 18 | #include <raft/core/resource/device_memory_resource.hpp> |
| 19 | +#include <raft/core/resource/device_properties.hpp> |
19 | 20 | #include <raft/core/resource/thrust_policy.hpp> |
20 | 21 | #include <raft/linalg/add.cuh> |
21 | 22 | #include <raft/linalg/gemm.cuh> |
@@ -171,22 +172,28 @@ inline std::enable_if_t<std::is_floating_point_v<MathT>> predict_core( |
171 | 172 | * @return A suggested minibatch size and the expected memory cost per-row (in bytes) |
172 | 173 | */ |
173 | 174 | template <typename MathT, typename IdxT> |
174 | | -constexpr auto calc_minibatch_size(IdxT n_clusters, |
175 | | - IdxT n_rows, |
176 | | - IdxT dim, |
177 | | - cuvs::distance::DistanceType metric, |
178 | | - bool needs_conversion) -> std::tuple<IdxT, size_t> |
| 175 | +auto calc_minibatch_size(const raft::resources& handle, |
| 176 | + IdxT n_clusters, |
| 177 | + IdxT n_rows, |
| 178 | + IdxT dim, |
| 179 | + cuvs::distance::DistanceType metric, |
| 180 | + bool needs_conversion) -> std::tuple<IdxT, size_t> |
179 | 181 | { |
180 | 182 | n_clusters = std::max<IdxT>(1, n_clusters); |
181 | 183 |
|
182 | 184 | // Estimate memory needs per row (i.e element of the batch). |
183 | 185 | size_t mem_per_row = 0; |
184 | 186 | switch (metric) { |
185 | | - // fusedL2NN needs a mutex and a key-value pair for each row. |
186 | 187 | case distance::DistanceType::L2Expanded: |
187 | 188 | case distance::DistanceType::L2SqrtExpanded: { |
188 | | - mem_per_row += sizeof(int); |
189 | | - mem_per_row += sizeof(raft::KeyValuePair<IdxT, MathT>); |
| 189 | + if (use_fused<MathT, IdxT, IdxT>(handle, n_rows, n_clusters, dim)) { |
| 190 | + // fusedL2NN needs a mutex and a key-value pair for each row. |
| 191 | + mem_per_row += sizeof(int); |
| 192 | + mem_per_row += sizeof(raft::KeyValuePair<IdxT, MathT>); |
| 193 | + } else { |
| 194 | + // unfused path needs a full GEMM output (distance matrix row). |
| 195 | + mem_per_row += sizeof(MathT) * n_clusters; |
| 196 | + } |
190 | 197 | } break; |
191 | 198 | // Other metrics require storing a distance matrix. |
192 | 199 | default: { |
@@ -377,8 +384,8 @@ void predict(const raft::resources& handle, |
377 | 384 | raft::common::nvtx::range<cuvs::common::nvtx::domain::cuvs> fun_scope( |
378 | 385 | "predict(%zu, %u)", static_cast<size_t>(n_rows), n_clusters); |
379 | 386 | auto mem_res = mr.value_or(raft::resource::get_workspace_resource_ref(handle)); |
380 | | - auto [max_minibatch_size, _mem_per_row] = |
381 | | - calc_minibatch_size<MathT>(n_clusters, n_rows, dim, params.metric, std::is_same_v<T, MathT>); |
| 387 | + auto [max_minibatch_size, _mem_per_row] = calc_minibatch_size<MathT>( |
| 388 | + handle, n_clusters, n_rows, dim, params.metric, std::is_same_v<T, MathT>); |
382 | 389 | rmm::device_uvector<MathT> cur_dataset( |
383 | 390 | std::is_same_v<T, MathT> ? 0 : max_minibatch_size * dim, stream, mem_res); |
384 | 391 | bool need_compute_norm = |
@@ -989,8 +996,8 @@ void build_hierarchical(const raft::resources& handle, |
989 | 996 | // TODO: Remove the explicit managed memory- we shouldn't be creating this on the user's behalf. |
990 | 997 | rmm::mr::managed_memory_resource managed_memory; |
991 | 998 | rmm::device_async_resource_ref device_memory = raft::resource::get_workspace_resource_ref(handle); |
992 | | - auto [max_minibatch_size, mem_per_row] = |
993 | | - calc_minibatch_size<MathT>(n_clusters, n_rows, dim, params.metric, std::is_same_v<T, MathT>); |
| 999 | + auto [max_minibatch_size, mem_per_row] = calc_minibatch_size<MathT>( |
| 1000 | + handle, n_clusters, n_rows, dim, params.metric, std::is_same_v<T, MathT>); |
994 | 1001 |
|
995 | 1002 | // Precompute the L2 norm of the dataset if relevant and not yet computed. |
996 | 1003 | rmm::device_uvector<MathT> dataset_norm_buf(0, stream, device_memory); |
|
0 commit comments