Skip to content

Commit

Permalink
support optional weights
Browse files Browse the repository at this point in the history
  • Loading branch information
jnke2016 committed Dec 28, 2024
1 parent 4da1c7e commit 6984645
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions python/pylibcugraph/pylibcugraph/node2vec_random_walks.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ from pylibcugraph.utils cimport (
)


def node2vec(ResourceHandle resource_handle,
def node2vec_random_walks(ResourceHandle resource_handle,
_GPUGraph graph,
seed_array,
size_t max_depth,
Expand Down Expand Up @@ -172,11 +172,14 @@ def node2vec(ResourceHandle resource_handle,
# arrays for returning.
cdef cugraph_type_erased_device_array_view_t* paths_ptr = \
cugraph_random_walk_result_get_paths(result_ptr)
cdef cugraph_type_erased_device_array_view_t* weights_ptr = \
cugraph_random_walk_result_get_weights(result_ptr)

if graph.weights_view_ptr is NULL and graph.weights_view_ptr_ptr is NULL:
cupy_weights = None
else:
weights_ptr = cugraph_random_walk_result_get_weights(result_ptr)
cupy_weights = copy_to_cupy_array(c_resource_handle_ptr, weights_ptr)

cupy_paths = copy_to_cupy_array(c_resource_handle_ptr, paths_ptr)
cupy_weights = copy_to_cupy_array(c_resource_handle_ptr, weights_ptr)

cugraph_random_walk_result_free(result_ptr)
cugraph_type_erased_device_array_view_free(seed_view_ptr)
Expand Down

0 comments on commit 6984645

Please sign in to comment.