@@ -71,6 +71,57 @@ void* _build(cuvsResources_t res,
7171 RAFT_FAIL (" dataset must be accessible on host or device memory" );
7272 }
7373}
74+
75+ template <typename output_mdspan_type>
76+ void _get_graph (cuvsResources_t res, cuvsNNDescentIndex_t index, DLManagedTensor* graph)
77+ {
78+ auto dtype = index->dtype ;
79+ if ((dtype.code == kDLUInt ) && (dtype.bits == 32 )) {
80+ auto index_ptr = reinterpret_cast <cuvs::neighbors::nn_descent::index<uint32_t >*>(index->addr );
81+ auto dst = cuvs::core::from_dlpack<output_mdspan_type>(graph);
82+ auto src = index_ptr->graph ();
83+ auto res_ptr = reinterpret_cast <raft::resources*>(res);
84+
85+ RAFT_EXPECTS (src.extent (0 ) == dst.extent (0 ), " Output graph has incorrect number of rows" );
86+ RAFT_EXPECTS (src.extent (1 ) == dst.extent (1 ), " Output graph has incorrect number of cols" );
87+
88+ cudaMemcpyAsync (dst.data_handle (),
89+ src.data_handle (),
90+ dst.extent (0 ) * dst.extent (1 ) * sizeof (uint32_t ),
91+ cudaMemcpyDefault,
92+ raft::resource::get_cuda_stream (*res_ptr));
93+ } else {
94+ RAFT_FAIL (" Unsupported nn-descent index dtype: %d and bits: %d" , dtype.code , dtype.bits );
95+ }
96+ }
97+
98+ template <typename output_mdspan_type>
99+ void _get_distances (cuvsResources_t res, cuvsNNDescentIndex_t index, DLManagedTensor* distances)
100+ {
101+ auto dtype = index->dtype ;
102+ if ((dtype.code == kDLUInt ) && (dtype.bits == 32 )) {
103+ auto index_ptr = reinterpret_cast <cuvs::neighbors::nn_descent::index<uint32_t >*>(index->addr );
104+ auto src = index_ptr->distances ();
105+ if (!src.has_value ()) {
106+ RAFT_FAIL (" nn-descent index doesn't contain distances - set return_distances when building" );
107+ }
108+
109+ auto res_ptr = reinterpret_cast <raft::resources*>(res);
110+ auto dst = cuvs::core::from_dlpack<output_mdspan_type>(distances);
111+
112+ RAFT_EXPECTS (src->extent (0 ) == dst.extent (0 ), " Output distances has incorrect number of rows" );
113+ RAFT_EXPECTS (src->extent (1 ) == dst.extent (1 ), " Output distances has incorrect number of cols" );
114+
115+ cudaMemcpyAsync (dst.data_handle (),
116+ src->data_handle (),
117+ dst.extent (0 ) * dst.extent (1 ) * sizeof (float ),
118+ cudaMemcpyDefault,
119+ raft::resource::get_cuda_stream (*res_ptr));
120+
121+ } else {
122+ RAFT_FAIL (" Unsupported nn-descent index dtype: %d and bits: %d" , dtype.code , dtype.bits );
123+ }
124+ }
74125} // namespace
75126
76127extern " C" cuvsError_t cuvsNNDescentIndexCreate (cuvsNNDescentIndex_t* index)
@@ -146,22 +197,32 @@ extern "C" cuvsError_t cuvsNNDescentIndexParamsDestroy(cuvsNNDescentIndexParams_
146197 return cuvs::core::translate_exceptions ([=] { delete params; });
147198}
148199
149- extern " C" cuvsError_t cuvsNNDescentIndexGetGraph (cuvsNNDescentIndex_t index,
200+ extern " C" cuvsError_t cuvsNNDescentIndexGetGraph (cuvsResources_t res,
201+ cuvsNNDescentIndex_t index,
150202 DLManagedTensor* graph)
151203{
152204 return cuvs::core::translate_exceptions ([=] {
153- auto dtype = index->dtype ;
154- if ((dtype.code == kDLUInt ) && (dtype.bits == 32 )) {
155- auto index_ptr = reinterpret_cast <cuvs::neighbors::nn_descent::index<uint32_t >*>(index->addr );
205+ if (cuvs::core::is_dlpack_device_compatible (graph->dl_tensor )) {
206+ using output_mdspan_type = raft::device_matrix_view<uint32_t , int64_t , raft::row_major>;
207+ _get_graph<output_mdspan_type>(res, index, graph);
208+ } else {
156209 using output_mdspan_type = raft::host_matrix_view<uint32_t , int64_t , raft::row_major>;
157- auto dst = cuvs::core::from_dlpack<output_mdspan_type>(graph);
158- auto src = index_ptr->graph ();
210+ _get_graph<output_mdspan_type>(res, index, graph);
211+ }
212+ });
213+ }
159214
160- RAFT_EXPECTS (src.extent (0 ) == dst.extent (0 ), " Output graph has incorrect number of rows" );
161- RAFT_EXPECTS (src.extent (1 ) == dst.extent (1 ), " Output graph has incorrect number of cols" );
162- std::copy (src.data_handle (), src.data_handle () + dst.size (), dst.data_handle ());
215+ extern " C" cuvsError_t cuvsNNDescentIndexGetDistances (cuvsResources_t res,
216+ cuvsNNDescentIndex_t index,
217+ DLManagedTensor* distances)
218+ {
219+ return cuvs::core::translate_exceptions ([=] {
220+ if (cuvs::core::is_dlpack_device_compatible (distances->dl_tensor )) {
221+ using output_mdspan_type = raft::device_matrix_view<float , int64_t , raft::row_major>;
222+ _get_distances<output_mdspan_type>(res, index, distances);
163223 } else {
164- RAFT_FAIL (" Unsupported nn-descent index dtype: %d and bits: %d" , dtype.code , dtype.bits );
224+ using output_mdspan_type = raft::host_matrix_view<float , int64_t , raft::row_major>;
225+ _get_distances<output_mdspan_type>(res, index, distances);
165226 }
166227 });
167228}
0 commit comments