-
Notifications
You must be signed in to change notification settings - Fork 204
feat(c-api): expose attach_dataset_on_build and add cuvsCagraUpdateDataset #1842
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
736603e
9bc80ec
5f46412
401a42f
d3f1aaf
c4d9936
24320cd
3b5a03f
e4622a0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -129,6 +129,26 @@ void* _build(cuvsResources_t res, cuvsCagraIndexParams params, DLManagedTensor* | |
| return index; | ||
| } | ||
|
|
||
| template <typename T> | ||
| void _update_dataset(cuvsResources_t res, | ||
| cuvsCagraIndex index, | ||
| DLManagedTensor* dataset_tensor) | ||
| { | ||
| auto dataset = dataset_tensor->dl_tensor; | ||
| auto res_ptr = reinterpret_cast<raft::resources*>(res); | ||
| auto index_ptr = reinterpret_cast<cuvs::neighbors::cagra::index<T, uint32_t>*>(index.addr); | ||
|
|
||
| if (cuvs::core::is_dlpack_device_compatible(dataset)) { | ||
| using mdspan_type = raft::device_matrix_view<T const, int64_t, raft::row_major>; | ||
| auto mds = cuvs::core::from_dlpack<mdspan_type>(dataset_tensor); | ||
| index_ptr->update_dataset(*res_ptr, mds); | ||
| } else if (cuvs::core::is_dlpack_host_compatible(dataset)) { | ||
| using mdspan_type = raft::host_matrix_view<T const, int64_t, raft::row_major>; | ||
| auto mds = cuvs::core::from_dlpack<mdspan_type>(dataset_tensor); | ||
| index_ptr->update_dataset(*res_ptr, mds); | ||
| } | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's no else branch here. If The } else {
RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d",
dataset.dtype.code,
dataset.dtype.bits);
}we should do the same here
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good catch - thanks! |
||
| } | ||
|
|
||
| template <typename T> | ||
| void* _from_args(cuvsResources_t res, | ||
| cuvsDistanceType _metric, | ||
|
|
@@ -443,6 +463,7 @@ void convert_c_index_params(cuvsCagraIndexParams params, | |
| out->metric = static_cast<cuvs::distance::DistanceType>((int)params.metric); | ||
| out->intermediate_graph_degree = params.intermediate_graph_degree; | ||
| out->graph_degree = params.graph_degree; | ||
| out->attach_dataset_on_build = params.attach_dataset_on_build; | ||
| _set_graph_build_params(out->graph_build_params, params, params.build_algo, n_rows, dim); | ||
|
|
||
| if (auto* cparams = params.compression; cparams != nullptr) { | ||
|
|
@@ -589,6 +610,27 @@ extern "C" cuvsError_t cuvsCagraBuild(cuvsResources_t res, | |
| }); | ||
| } | ||
|
|
||
| extern "C" cuvsError_t cuvsCagraUpdateDataset(cuvsResources_t res, | ||
| DLManagedTensor* dataset_tensor, | ||
| cuvsCagraIndex_t index) | ||
| { | ||
| return cuvs::core::translate_exceptions([=] { | ||
| if (index->dtype.code == kDLFloat && index->dtype.bits == 32) { | ||
| _update_dataset<float>(res, *index, dataset_tensor); | ||
| } else if (index->dtype.code == kDLFloat && index->dtype.bits == 16) { | ||
| _update_dataset<half>(res, *index, dataset_tensor); | ||
| } else if (index->dtype.code == kDLInt && index->dtype.bits == 8) { | ||
| _update_dataset<int8_t>(res, *index, dataset_tensor); | ||
| } else if (index->dtype.code == kDLUInt && index->dtype.bits == 8) { | ||
| _update_dataset<uint8_t>(res, *index, dataset_tensor); | ||
| } else { | ||
| RAFT_FAIL("Unsupported index dtype: %d and bits: %d", | ||
| index->dtype.code, | ||
| index->dtype.bits); | ||
| } | ||
| }); | ||
| } | ||
|
|
||
| extern "C" cuvsError_t cuvsCagraIndexFromArgs(cuvsResources_t res, | ||
| cuvsDistanceType metric, | ||
| DLManagedTensor* graph_tensor, | ||
|
|
@@ -737,7 +779,10 @@ extern "C" cuvsError_t cuvsCagraIndexParamsCreate(cuvsCagraIndexParams_t* params | |
| .intermediate_graph_degree = 128, | ||
| .graph_degree = 64, | ||
| .build_algo = IVF_PQ, | ||
| .nn_descent_niter = 20}; | ||
| .nn_descent_niter = 20, | ||
| .compression = nullptr, | ||
| .graph_build_params = nullptr, | ||
| .attach_dataset_on_build = true}; | ||
| (*params)->graph_build_params = new cuvsIvfPqParams{nullptr, nullptr, 1}; | ||
| }); | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -337,6 +337,110 @@ TEST(CagraC, BuildExtendSearch) | |
| cuvsResourcesDestroy(res); | ||
| } | ||
|
|
||
| TEST(CagraC, BuildNoDatasetThenUpdateAndSearch) | ||
| { | ||
| // Test the attach_dataset_on_build = false workflow: | ||
| // 1. Build index without attaching dataset (saves a full dataset copy) | ||
| // 2. Attach dataset via cuvsCagraUpdateDataset | ||
| // 3. Search and verify correctness | ||
|
|
||
| // create cuvsResources_t | ||
| cuvsResources_t res; | ||
| cuvsResourcesCreate(&res); | ||
| cudaStream_t stream; | ||
| cuvsStreamGet(res, &stream); | ||
|
|
||
| // create dataset DLTensor | ||
| DLManagedTensor dataset_tensor; | ||
| dataset_tensor.dl_tensor.data = dataset; | ||
| dataset_tensor.dl_tensor.device.device_type = kDLCPU; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If I'm not mistaken, the motivating scenario (dataset in device, avoiding the redundant copy) is not tested. Adding a variant that creates a
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You're right — added |
||
| dataset_tensor.dl_tensor.ndim = 2; | ||
| dataset_tensor.dl_tensor.dtype.code = kDLFloat; | ||
| dataset_tensor.dl_tensor.dtype.bits = 32; | ||
| dataset_tensor.dl_tensor.dtype.lanes = 1; | ||
| int64_t dataset_shape[2] = {4, 2}; | ||
| dataset_tensor.dl_tensor.shape = dataset_shape; | ||
| dataset_tensor.dl_tensor.strides = nullptr; | ||
|
|
||
| // create index | ||
| cuvsCagraIndex_t index; | ||
| cuvsCagraIndexCreate(&index); | ||
|
|
||
| // build index with attach_dataset_on_build = false | ||
| cuvsCagraIndexParams_t build_params; | ||
| cuvsCagraIndexParamsCreate(&build_params); | ||
| build_params->attach_dataset_on_build = false; | ||
| ASSERT_EQ(cuvsCagraBuild(res, build_params, &dataset_tensor, index), CUVS_SUCCESS); | ||
|
|
||
| // now attach the dataset | ||
| ASSERT_EQ(cuvsCagraUpdateDataset(res, &dataset_tensor, index), CUVS_SUCCESS); | ||
|
|
||
| // create queries DLTensor | ||
| rmm::device_uvector<float> queries_d(4 * 2, stream); | ||
| raft::copy(queries_d.data(), (float*)queries, 4 * 2, stream); | ||
|
|
||
| DLManagedTensor queries_tensor; | ||
| queries_tensor.dl_tensor.data = queries_d.data(); | ||
| queries_tensor.dl_tensor.device.device_type = kDLCUDA; | ||
| queries_tensor.dl_tensor.ndim = 2; | ||
| queries_tensor.dl_tensor.dtype.code = kDLFloat; | ||
| queries_tensor.dl_tensor.dtype.bits = 32; | ||
| queries_tensor.dl_tensor.dtype.lanes = 1; | ||
| int64_t queries_shape[2] = {4, 2}; | ||
| queries_tensor.dl_tensor.shape = queries_shape; | ||
| queries_tensor.dl_tensor.strides = nullptr; | ||
|
|
||
| // create neighbors DLTensor | ||
| rmm::device_uvector<uint32_t> neighbors_d(4, stream); | ||
|
|
||
| DLManagedTensor neighbors_tensor; | ||
| neighbors_tensor.dl_tensor.data = neighbors_d.data(); | ||
| neighbors_tensor.dl_tensor.device.device_type = kDLCUDA; | ||
| neighbors_tensor.dl_tensor.ndim = 2; | ||
| neighbors_tensor.dl_tensor.dtype.code = kDLUInt; | ||
| neighbors_tensor.dl_tensor.dtype.bits = 32; | ||
| neighbors_tensor.dl_tensor.dtype.lanes = 1; | ||
| int64_t neighbors_shape[2] = {4, 1}; | ||
| neighbors_tensor.dl_tensor.shape = neighbors_shape; | ||
| neighbors_tensor.dl_tensor.strides = nullptr; | ||
|
|
||
| // create distances DLTensor | ||
| rmm::device_uvector<float> distances_d(4, stream); | ||
|
|
||
| DLManagedTensor distances_tensor; | ||
| distances_tensor.dl_tensor.data = distances_d.data(); | ||
| distances_tensor.dl_tensor.device.device_type = kDLCUDA; | ||
| distances_tensor.dl_tensor.ndim = 2; | ||
| distances_tensor.dl_tensor.dtype.code = kDLFloat; | ||
| distances_tensor.dl_tensor.dtype.bits = 32; | ||
| distances_tensor.dl_tensor.dtype.lanes = 1; | ||
| int64_t distances_shape[2] = {4, 1}; | ||
| distances_tensor.dl_tensor.shape = distances_shape; | ||
| distances_tensor.dl_tensor.strides = nullptr; | ||
|
|
||
| cuvsFilter filter; | ||
| filter.type = NO_FILTER; | ||
| filter.addr = (uintptr_t)NULL; | ||
|
|
||
| // search index | ||
| cuvsCagraSearchParams_t search_params; | ||
| cuvsCagraSearchParamsCreate(&search_params); | ||
| cuvsCagraSearch( | ||
| res, search_params, index, &queries_tensor, &neighbors_tensor, &distances_tensor, filter); | ||
|
|
||
| // verify output — should match the standard BuildSearch test results | ||
| ASSERT_TRUE( | ||
| cuvs::devArrMatchHost(neighbors_exp, neighbors_d.data(), 4, cuvs::Compare<uint32_t>())); | ||
| ASSERT_TRUE(cuvs::devArrMatchHost( | ||
| distances_exp, distances_d.data(), 4, cuvs::CompareApprox<float>(0.001f))); | ||
|
|
||
| // de-allocate index and res | ||
| cuvsCagraSearchParamsDestroy(search_params); | ||
| cuvsCagraIndexParamsDestroy(build_params); | ||
| cuvsCagraIndexDestroy(index); | ||
| cuvsResourcesDestroy(res); | ||
| } | ||
|
|
||
| TEST(CagraC, BuildSearchFiltered) | ||
| { | ||
| // create cuvsResources_t | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Export the new C API symbol.
`cuvsCagraUpdateDataset` is declared without `CUVS_EXPORT`, so it may not be visible from shared-library builds. Please export it like other public C APIs.
Proposed fix
📝 Committable suggestion
🤖 Prompt for AI Agents