From 5289d50822e557c2b378cc1160110906e9849658 Mon Sep 17 00:00:00 2001 From: "Cai, Justin" Date: Mon, 27 Mar 2023 11:54:49 -0700 Subject: [PATCH 1/2] [SYCL] Make joint_reduce work with sub_group Note: unqualified name lookup of joint_reduce inside the overload without an init param did not find the overload with an init param (because that declaration was located after it), so the call could only be resolved via ADL. With sycl::group, ADL finds both overloads of joint_reduce, but because sycl::sub_group is an alias for sycl::ext::oneapi::sub_group, ADL searches sycl::ext::oneapi and finds no joint_reduce there. Signed-off-by: Cai, Justin --- sycl/include/sycl/group_algorithm.hpp | 46 +++++++++++++-------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/sycl/include/sycl/group_algorithm.hpp b/sycl/include/sycl/group_algorithm.hpp index 1fa39d5ba3b5c..0bedd60122548 100644 --- a/sycl/include/sycl/group_algorithm.hpp +++ b/sycl/include/sycl/group_algorithm.hpp @@ -315,29 +315,6 @@ reduce_over_group(Group g, V x, T init, BinaryOperation binary_op) { } // ---- joint_reduce -template -detail::enable_if_t< - (is_group_v> && detail::is_pointer::value && - detail::is_arithmetic_or_complex< - typename detail::remove_pointer::type>::value && - detail::is_plus_or_multiplies_if_complex< - typename detail::remove_pointer::type, BinaryOperation>::value), - typename detail::remove_pointer::type> -joint_reduce(Group g, Ptr first, Ptr last, BinaryOperation binary_op) { -#ifdef __SYCL_DEVICE_ONLY__ - using T = typename detail::remove_pointer::type; - T init = detail::identity_for_ga_op(); - return joint_reduce(g, first, last, init, binary_op); -#else - (void)g; - (void)first; - (void)last; - (void)binary_op; - throw runtime_error("Group algorithms are not supported on host.", - PI_ERROR_INVALID_DEVICE); -#endif -} - template detail::enable_if_t< (is_group_v> && detail::is_pointer::value && @@ -373,6 +350,29 @@ joint_reduce(Group g, Ptr first, Ptr last, T init, 
BinaryOperation binary_op) { #endif } +template +detail::enable_if_t< + (is_group_v> && detail::is_pointer::value && + detail::is_arithmetic_or_complex< + typename detail::remove_pointer::type>::value && + detail::is_plus_or_multiplies_if_complex< + typename detail::remove_pointer::type, BinaryOperation>::value), + typename detail::remove_pointer::type> +joint_reduce(Group g, Ptr first, Ptr last, BinaryOperation binary_op) { +#ifdef __SYCL_DEVICE_ONLY__ + using T = typename detail::remove_pointer::type; + T init = detail::identity_for_ga_op(); + return joint_reduce(g, first, last, init, binary_op); +#else + (void)g; + (void)first; + (void)last; + (void)binary_op; + throw runtime_error("Group algorithms are not supported on host.", + PI_ERROR_INVALID_DEVICE); +#endif +} + // ---- any_of_group template detail::enable_if_t>, bool> From d7167621bf9298bccd0854ce2ebec39ad17e089a Mon Sep 17 00:00:00 2001 From: "Cai, Justin" Date: Tue, 28 Mar 2023 09:17:02 -0700 Subject: [PATCH 2/2] [SYCL] Add joint_reduce E2E test with sub_group Signed-off-by: Cai, Justin --- sycl/test-e2e/GroupAlgorithm/reduce_sycl2020.cpp | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/sycl/test-e2e/GroupAlgorithm/reduce_sycl2020.cpp b/sycl/test-e2e/GroupAlgorithm/reduce_sycl2020.cpp index 95e0d6119a34d..efa7062749b46 100644 --- a/sycl/test-e2e/GroupAlgorithm/reduce_sycl2020.cpp +++ b/sycl/test-e2e/GroupAlgorithm/reduce_sycl2020.cpp @@ -34,6 +34,7 @@ void test(queue q, InputContainer input, OutputContainer output, cgh.parallel_for( nd_range<1>(G, G), [=](nd_item<1> it) { group<1> g = it.get_group(); + auto sg = it.get_sub_group(); int lid = it.get_local_id(0); out[0] = reduce_over_group(g, in[lid], binary_op); out[1] = reduce_over_group(g, in[lid], init, binary_op); @@ -41,6 +42,10 @@ void test(queue q, InputContainer input, OutputContainer output, binary_op); out[3] = joint_reduce(g, in.get_pointer(), in.get_pointer() + N, init, binary_op); + out[4] = joint_reduce(sg, 
in.get_pointer(), in.get_pointer() + N, + binary_op); + out[5] = joint_reduce(sg, in.get_pointer(), in.get_pointer() + N, + init, binary_op); }); }); } @@ -54,6 +59,10 @@ void test(queue q, InputContainer input, OutputContainer output, std::accumulate(input.begin(), input.end(), identity, binary_op)); assert(output[3] == std::accumulate(input.begin(), input.end(), init, binary_op)); + assert(output[4] == + std::accumulate(input.begin(), input.end(), identity, binary_op)); + assert(output[5] == + std::accumulate(input.begin(), input.end(), init, binary_op)); } int main() { @@ -65,7 +74,7 @@ int main() { constexpr int N = 128; std::array input; - std::array output; + std::array output; std::iota(input.begin(), input.end(), 0); std::fill(output.begin(), output.end(), 0); @@ -93,7 +102,7 @@ int main() { // sycl::plus binary operation. #ifdef SYCL_EXT_ONEAPI_COMPLEX_ALGORITHMS std::array, N> input_cf; - std::array, 4> output_cf; + std::array, 6> output_cf; std::iota(input_cf.begin(), input_cf.end(), 0); std::fill(output_cf.begin(), output_cf.end(), 0); test(q, input_cf, output_cf,