diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 2c21018deb..0e2ad6d769 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -1706,7 +1706,7 @@ template , raft::memory_type::host>> void optimize( - raft::resources const& res, + raft::resources const& res_const, raft::mdspan, raft::row_major, AccessorKnnGraph> knn_graph, raft::mdspan, raft::row_major, AccessorOutputGraph> new_graph, const bool guarantee_connectivity = true, @@ -1715,6 +1715,12 @@ void optimize( RAFT_LOG_DEBUG( "# Pruning kNN graph (size=%lu, degree=%lu)\n", knn_graph.extent(0), knn_graph.extent(1)); + // TODO(achirkin): come up with a reasonable API to initialize a non-empty stream pool. + // raft::resource::set_cuda_stream_pool below modifies the resource, so it cannot be const. + // The optimize() is a heavy function, so copying the resource and creating a private stream pool + // is not a big overhead. + raft::resources res{res_const}; + // large temporary memory for large arrays, e.g. everything >= O(graph_size) auto large_tmp_mr = raft::resource::get_large_workspace_resource_ref(res); // temporary memory for small arrays, e.g. everything <= O(batchsize * graph_degree)