Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion cpp/src/neighbors/detail/cagra/graph_core.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1706,7 +1706,7 @@ template <typename IdxT = uint32_t,
typename AccessorOutputGraph =
raft::host_device_accessor<cuda::std::default_accessor<IdxT>, raft::memory_type::host>>
void optimize(
raft::resources const& res,
raft::resources const& res_const,
raft::mdspan<IdxT, raft::matrix_extent<int64_t>, raft::row_major, AccessorKnnGraph> knn_graph,
raft::mdspan<IdxT, raft::matrix_extent<int64_t>, raft::row_major, AccessorOutputGraph> new_graph,
const bool guarantee_connectivity = true,
Expand All @@ -1715,6 +1715,12 @@ void optimize(
RAFT_LOG_DEBUG(
"# Pruning kNN graph (size=%lu, degree=%lu)\n", knn_graph.extent(0), knn_graph.extent(1));

// TODO(achirkin): come up with a reasonable API to initialize a non-empty stream pool.
// raft::resource::set_cuda_stream_pool below modifies the resource, so it cannot be const.
// The optimize() is a heavy function, so copying the resource and creating a private stream pool
// is not a big overhead.
raft::resources res{res_const};

// large temporary memory for large arrays, e.g. everything >= O(graph_size)
auto large_tmp_mr = raft::resource::get_large_workspace_resource_ref(res);
// temporary memory for small arrays, e.g. everything <= O(batchsize * graph_degree)
Expand Down
Loading