Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/cuda_native/main.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ void compute_clusters(TQueue& queue, std::vector<int>& cluster_indexes) {
int main() {
cudaStream_t stream;
cudaStreamCreate(&stream);
clue::Queue queue(stream);
auto queue = clue::get_queue(stream);

std::vector<int> cluster_indexes;
compute_clusters(queue, cluster_indexes);
Expand Down
2 changes: 1 addition & 1 deletion examples/hip_native/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ void compute_clusters(TQueue& queue, std::vector<int>& cluster_indexes) {
int main() {
hipStream_t stream;
hipStreamCreate(&stream);
clue::Queue queue(stream);
auto queue = clue::get_queue(stream);

std::vector<int> cluster_indexes;
compute_clusters(queue, cluster_indexes);
Expand Down
24 changes: 20 additions & 4 deletions include/CLUEstering/utils/get_queue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ namespace clue {
/// @param device_id The index of the device
/// @return An alpaka queue created from the device corresponding to the given index
template <std::integral TIdx>
inline clue::Queue get_queue(TIdx device_id = TIdx{}) {
inline auto get_queue(TIdx device_id = TIdx{}) {
auto device = alpaka::getDevByIdx(clue::Platform{}, device_id);
return clue::Queue{device};
return clue::Queue(device);
}

/// @brief Get an alpaka queue created from a given device
Expand All @@ -28,8 +28,24 @@ namespace clue {
/// @param device The device to create the queue from
/// @return An alpaka queue created from the given device
template <concepts::device TDevice>
inline clue::Queue get_queue(const TDevice& device) {
return clue::Queue{device};
inline auto get_queue(const TDevice& device) {
return clue::Queue(device);
}

#ifdef ALPAKA_ACC_GPU_CUDA_ENABLED
/// @brief Get an alpaka queue wrapping a CUDA stream
///
/// @param stream The CUDA stream to wrap inside the alpaka queue
/// @return An alpaka queue wrapping the given CUDA stream
inline auto get_queue(cudaStream_t& stream) { return clue::Queue(stream); }
#endif

#ifdef ALPAKA_ACC_GPU_HIP_ENABLED
/// @brief Get an alpaka queue wrapping a HIP stream
///
/// @param stream The HIP stream to wrap inside the alpaka queue
/// @return An alpaka queue wrapping the given HIP stream
inline auto get_queue(hipStream_t& stream) { return clue::Queue(stream); }
#endif

} // namespace clue
40 changes: 40 additions & 0 deletions tests/test_utilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,46 @@ TEST_CASE("Test clue::get_queue utility") {
auto d_points2 = clue::PointsDevice<2>(queue2, points2.size());
CHECK(1);
}

#ifdef ALPAKA_ACC_GPU_CUDA_ENABLED
SUBCASE("Create queue using CUDA stream") {
cudaStream_t stream;
cudaStreamCreate(&stream);

{
auto queue = clue::get_queue(stream);
static_assert(std::is_same_v<decltype(queue), clue::Queue>, "Expected type clue::Queue");
CHECK(alpaka::getDev(queue) == alpaka::getDevByIdx(clue::Platform{}, 0u));

// check if data allocation works
clue::PointsHost<2> points1(queue, 1000);
auto d_points1 = clue::PointsDevice<2>(queue, points1.size());
CHECK(1);
}

cudaStreamDestroy(stream);
}
#endif

#ifdef ALPAKA_ACC_GPU_HIP_ENABLED
SUBCASE("Create queue using HIP stream") {
hipStream_t stream;
hipStreamCreate(&stream);

{
auto queue = clue::get_queue(stream);
static_assert(std::is_same_v<decltype(queue), clue::Queue>, "Expected type clue::Queue");
CHECK(alpaka::getDev(queue) == alpaka::getDevByIdx(clue::Platform{}, 0u));

// check if data allocation works
clue::PointsHost<2> points1(queue, 1000);
auto d_points1 = clue::PointsDevice<2>(queue, points1.size());
CHECK(1);
}

hipStreamDestroy(stream);
}
#endif
}

TEST_CASE("Test get_clusters host function") {
Expand Down
Loading