Skip to content

Commit 7f796b1

Browse files
authored
Decouple C++ library from C library (#1488)
Currently, the C++ library needs C headers to build and be installed at runtime (as it was leaking those headers). This is unnecessary circular dependency and as the only C header needed by the C++ library was that of `cuvs/distance/distance.h`, it was easily resolved. As precedent has been set already, we declare an `enum` first in C++ and then create a duplicate of it in C, and cast between the two when calling the C++ API from the C API. Authors: - Divye Gala (https://github.com/divyegala) Approvers: - Robert Maynard (https://github.com/robertmaynard) - Tarang Jain (https://github.com/tarang-jain) URL: #1488
1 parent 1f6ff6f commit 7f796b1

20 files changed

Lines changed: 203 additions & 102 deletions

File tree

c/src/cluster/kmeans.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ namespace {
1818
cuvs::cluster::kmeans::params convert_params(const cuvsKMeansParams& params)
1919
{
2020
auto kmeans_params = cuvs::cluster::kmeans::params();
21-
kmeans_params.metric = params.metric;
21+
kmeans_params.metric = static_cast<cuvs::distance::DistanceType>(params.metric);
2222
kmeans_params.init = static_cast<cuvs::cluster::kmeans::params::InitMethod>(params.init);
2323
kmeans_params.n_clusters = params.n_clusters;
2424
kmeans_params.max_iter = params.max_iter;
@@ -33,7 +33,7 @@ cuvs::cluster::kmeans::params convert_params(const cuvsKMeansParams& params)
3333
cuvs::cluster::kmeans::balanced_params convert_balanced_params(const cuvsKMeansParams& params)
3434
{
3535
auto kmeans_params = cuvs::cluster::kmeans::balanced_params();
36-
kmeans_params.metric = params.metric;
36+
kmeans_params.metric = static_cast<cuvs::distance::DistanceType>(params.metric);
3737
kmeans_params.n_iters = params.hierarchical_n_iters;
3838
return kmeans_params;
3939
}
@@ -185,7 +185,7 @@ extern "C" cuvsError_t cuvsKMeansParamsCreate(cuvsKMeansParams_t* params)
185185
cuvs::cluster::kmeans::params cpp_params;
186186
cuvs::cluster::kmeans::balanced_params cpp_balanced_params;
187187
*params =
188-
new cuvsKMeansParams{.metric = cpp_params.metric,
188+
new cuvsKMeansParams{.metric = static_cast<cuvsDistanceType>(cpp_params.metric),
189189
.n_clusters = cpp_params.n_clusters,
190190
.init = static_cast<cuvsKMeansInitMethod>(cpp_params.init),
191191
.max_iter = cpp_params.max_iter,

c/src/distance/pairwise_distance.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <raft/core/resources.hpp>
1313

1414
#include <cuvs/core/c_api.h>
15+
#include <cuvs/distance/distance.h>
1516
#include <cuvs/distance/distance.hpp>
1617

1718
#include "../core/exceptions.hpp"
@@ -35,8 +36,9 @@ void _pairwise_distance(cuvsResources_t res,
3536
auto x_mds = cuvs::core::from_dlpack<mdspan_type>(x_tensor);
3637
auto y_mds = cuvs::core::from_dlpack<mdspan_type>(y_tensor);
3738
auto distances_mds = cuvs::core::from_dlpack<distances_mdspan_type>(distances_tensor);
39+
auto metric_type = static_cast<cuvs::distance::DistanceType>(metric);
3840

39-
cuvs::distance::pairwise_distance(*res_ptr, x_mds, y_mds, distances_mds, metric, metric_arg);
41+
cuvs::distance::pairwise_distance(*res_ptr, x_mds, y_mds, distances_mds, metric_type, metric_arg);
4042
}
4143
} // namespace
4244

c/src/neighbors/all_neighbors.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <raft/util/cudart_utils.hpp>
1212

1313
#include <cuvs/core/c_api.h>
14+
#include <cuvs/distance/distance.hpp>
1415
#include <cuvs/neighbors/all_neighbors.h>
1516
#include <cuvs/neighbors/ivf_pq.h>
1617
#include <cuvs/neighbors/nn_descent.h>

c/src/neighbors/brute_force.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#include <raft/core/serialize.hpp>
1515

1616
#include <cuvs/core/c_api.h>
17+
#include <cuvs/distance/distance.h>
18+
#include <cuvs/distance/distance.hpp>
1719
#include <cuvs/neighbors/brute_force.h>
1820
#include <cuvs/neighbors/common.h>
1921
#include <cuvs/neighbors/brute_force.hpp>
@@ -35,7 +37,7 @@ void* _build(cuvsResources_t res,
3537
auto mds = cuvs::core::from_dlpack<mdspan_type>(dataset_tensor);
3638

3739
cuvs::neighbors::brute_force::index_params params;
38-
params.metric = metric;
40+
params.metric = static_cast<cuvs::distance::DistanceType>((int)metric);
3941
params.metric_arg = metric_arg;
4042

4143
auto index_on_stack = cuvs::neighbors::brute_force::build(*res_ptr, params, mds);

c/src/neighbors/cagra.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
#include <raft/core/serialize.hpp>
1313

1414
#include <cuvs/core/c_api.h>
15+
#include <cuvs/distance/distance.h>
16+
#include <cuvs/distance/distance.hpp>
1517
#include <cuvs/neighbors/cagra.h>
1618
#include <cuvs/neighbors/common.h>
1719
#include <cuvs/neighbors/cagra.hpp>

c/src/neighbors/hnsw.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include <cuvs/core/c_api.h>
1616
#include <cuvs/distance/distance.h>
17+
#include <cuvs/distance/distance.hpp>
1718
#include <cuvs/neighbors/hnsw.h>
1819
#include <cuvs/neighbors/hnsw.hpp>
1920

@@ -110,8 +111,9 @@ void* _deserialize(cuvsResources_t res,
110111
cuvs::neighbors::hnsw::index<T>* index = nullptr;
111112
auto cpp_params = cuvs::neighbors::hnsw::index_params();
112113
cpp_params.hierarchy = static_cast<cuvs::neighbors::hnsw::HnswHierarchy>(params->hierarchy);
114+
auto metric_type = static_cast<cuvs::distance::DistanceType>(metric);
113115
cuvs::neighbors::hnsw::deserialize(
114-
*res_ptr, cpp_params, std::string(filename), dim, metric, &index);
116+
*res_ptr, cpp_params, std::string(filename), dim, metric_type, &index);
115117
return index;
116118
}
117119
} // namespace

c/src/neighbors/nn_descent.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <raft/util/cudart_utils.hpp>
1616

1717
#include <cuvs/core/c_api.h>
18+
#include <cuvs/distance/distance.h>
1819
#include <cuvs/neighbors/nn_descent.h>
1920
#include <cuvs/neighbors/nn_descent.hpp>
2021

@@ -170,7 +171,7 @@ extern "C" cuvsError_t cuvsNNDescentIndexParamsCreate(cuvsNNDescentIndexParams_t
170171
cuvs::neighbors::nn_descent::index_params cpp_params;
171172

172173
*params = new cuvsNNDescentIndexParams{
173-
.metric = cpp_params.metric,
174+
.metric = static_cast<cuvsDistanceType>((int)cpp_params.metric),
174175
.metric_arg = cpp_params.metric_arg,
175176
.graph_degree = cpp_params.graph_degree,
176177
.intermediate_graph_degree = cpp_params.intermediate_graph_degree,

c/src/neighbors/refine.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
#include <raft/core/resources.hpp>
1212

1313
#include <cuvs/core/c_api.h>
14+
#include <cuvs/distance/distance.h>
15+
#include <cuvs/distance/distance.hpp>
1416
#include <cuvs/neighbors/refine.h>
1517
#include <cuvs/neighbors/refine.hpp>
1618

@@ -41,7 +43,7 @@ void _refine(bool on_device,
4143
auto candidates = cuvs::core::from_dlpack<candidates_type>(candidates_tensor);
4244
auto indices = cuvs::core::from_dlpack<indices_type>(indices_tensor);
4345
auto distances = cuvs::core::from_dlpack<distances_type>(distances_tensor);
44-
cuvs::neighbors::refine(*res_ptr, dataset, queries, candidates, indices, distances, metric);
46+
cuvs::neighbors::refine(*res_ptr, dataset, queries, candidates, indices, distances, static_cast<cuvs::distance::DistanceType>((int)metric));
4547
} else {
4648
using queries_type = raft::host_matrix_view<const T, int64_t, raft::row_major>;
4749
using candidates_type = raft::host_matrix_view<const int64_t, int64_t, raft::row_major>;
@@ -52,7 +54,7 @@ void _refine(bool on_device,
5254
auto candidates = cuvs::core::from_dlpack<candidates_type>(candidates_tensor);
5355
auto indices = cuvs::core::from_dlpack<indices_type>(indices_tensor);
5456
auto distances = cuvs::core::from_dlpack<distances_type>(distances_tensor);
55-
cuvs::neighbors::refine(*res_ptr, dataset, queries, candidates, indices, distances, metric);
57+
cuvs::neighbors::refine(*res_ptr, dataset, queries, candidates, indices, distances, static_cast<cuvs::distance::DistanceType>((int)metric));
5658
}
5759
}
5860
} // namespace

c/src/neighbors/tiered_index.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
#include <raft/core/serialize.hpp>
1313

1414
#include <cuvs/core/c_api.h>
15+
#include <cuvs/distance/distance.h>
16+
#include <cuvs/distance/distance.hpp>
1517
#include <cuvs/neighbors/tiered_index.h>
1618
#include <cuvs/neighbors/tiered_index.hpp>
1719

@@ -35,7 +37,7 @@ void convert_c_index_params(cuvsTieredIndexParams params,
3537
{
3638
out->min_ann_rows = params.min_ann_rows;
3739
out->create_ann_index_on_extend = params.create_ann_index_on_extend;
38-
out->metric = params.metric;
40+
out->metric = static_cast<cuvs::distance::DistanceType>((int)params.metric);
3941

4042
if constexpr (std::is_same_v<T, cagra::index_params>) {
4143
if (params.cagra_params != NULL) {
@@ -314,7 +316,7 @@ extern "C" cuvsError_t cuvsTieredIndexParamsCreate(cuvsTieredIndexParams_t* para
314316
return cuvs::core::translate_exceptions([=] {
315317
cuvs::neighbors::tiered_index::index_params<cagra::index_params> cpp_params;
316318
*params = new cuvsTieredIndexParams{
317-
.metric = cpp_params.metric,
319+
.metric = static_cast<cuvsDistanceType>((int)cpp_params.metric),
318320
.algo = CUVS_TIERED_INDEX_ALGO_CAGRA,
319321
.min_ann_rows = cpp_params.min_ann_rows,
320322
.create_ann_index_on_extend = cpp_params.create_ann_index_on_extend};

c/tests/neighbors/brute_force_c.cu

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -161,11 +161,11 @@ void cpu_sddmm(value_t* A,
161161
norms_B += B[b_index] * B[b_index];
162162
}
163163
vals[j] = alpha * sum + beta * vals[j];
164-
if (metric == cuvs::distance::DistanceType::L2Expanded) {
164+
if (metric == L2Expanded) {
165165
vals[j] = value_t(-2.0) * vals[j] + norms_A + norms_B;
166-
} else if (metric == cuvs::distance::DistanceType::L2SqrtExpanded) {
166+
} else if (metric == L2SqrtExpanded) {
167167
vals[j] = std::sqrt(value_t(-2.0) * vals[j] + norms_A + norms_B);
168-
} else if (metric == cuvs::distance::DistanceType::CosineExpanded) {
168+
} else if (metric == CosineExpanded) {
169169
vals[j] = value_t(1.0) - vals[j] / std::sqrt(norms_A * norms_B);
170170
}
171171
}
@@ -306,7 +306,7 @@ void recall_eval(T* query_data,
306306
n_rows,
307307
n_dim,
308308
n_neighbors,
309-
static_cast<cuvs::distance::DistanceType>((uint16_t)metric));
309+
static_cast<cuvs::distance::DistanceType>((int)metric));
310310

311311
size_t size = n_queries * n_neighbors;
312312
std::vector<IdxT> neighbors_h(size);
@@ -363,7 +363,7 @@ void recall_eval_with_filter(T* query_data,
363363
raft::copy(queries_h.data(), query_data, n_queries * n_dim, stream);
364364
raft::copy(indices_h.data(), index_data, n_rows * n_dim, stream);
365365

366-
bool select_min = cuvs::distance::is_min_close(metric);
366+
bool select_min = cuvs::distance::is_min_close(static_cast<cuvs::distance::DistanceType>((int)metric));
367367

368368
cpu_brute_force_with_filter(queries_h.data(),
369369
indices_h.data(),
@@ -377,7 +377,7 @@ void recall_eval_with_filter(T* query_data,
377377
n_neighbors,
378378
nnz,
379379
select_min,
380-
static_cast<cuvs::distance::DistanceType>((uint16_t)metric));
380+
metric);
381381

382382
// verify output
383383
double min_recall = 0.95;
@@ -453,7 +453,7 @@ void run_test_with_filter(int64_t n_samples,
453453
int64_t nnz = create_sparse_matrix(n_rows_filter, n_samples, sparsity, filter_h);
454454

455455
cuvsDistanceType metric = L2Expanded;
456-
bool select_min = cuvs::distance::is_min_close(metric);
456+
bool select_min = cuvs::distance::is_min_close(static_cast<cuvs::distance::DistanceType>((int)metric));
457457

458458
std::vector<float> distances_ref_h(
459459
n_queries * n_neighbors,

0 commit comments

Comments
 (0)