diff --git a/sycl/include/sycl/khr/work_item_queries.hpp b/sycl/include/sycl/khr/work_item_queries.hpp new file mode 100644 index 0000000000000..ce89e9523cd38 --- /dev/null +++ b/sycl/include/sycl/khr/work_item_queries.hpp @@ -0,0 +1,36 @@ +//===-- work_item_queries.hpp --- KHR work item queries extension ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#pragma once + +#ifdef __DPCPP_ENABLE_UNFINISHED_KHR_EXTENSIONS + +#include + +#define SYCL_KHR_WORK_ITEM_QUERIES 1 + +namespace sycl { +inline namespace _V1 { +namespace khr { + +template nd_item this_nd_item() { + return ext::oneapi::experimental::this_nd_item(); +} + +template group this_group() { + return ext::oneapi::this_work_item::get_work_group(); +} + +inline sub_group this_sub_group() { + return ext::oneapi::this_work_item::get_sub_group(); +} + +} // namespace khr +} // namespace _V1 +} // namespace sycl + +#endif diff --git a/sycl/include/sycl/sycl.hpp b/sycl/include/sycl/sycl.hpp index 4ac7bbef129c1..12e83e89f5b32 100644 --- a/sycl/include/sycl/sycl.hpp +++ b/sycl/include/sycl/sycl.hpp @@ -124,3 +124,4 @@ #include #include #include +#include diff --git a/sycl/test-e2e/Basic/work_item_queries/work_item_queries.cpp b/sycl/test-e2e/Basic/work_item_queries/work_item_queries.cpp new file mode 100644 index 0000000000000..0ea97b0059b40 --- /dev/null +++ b/sycl/test-e2e/Basic/work_item_queries/work_item_queries.cpp @@ -0,0 +1,99 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +//===- work_item_queries.cpp - KHR work item queries test -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#define __DPCPP_ENABLE_UNFINISHED_KHR_EXTENSIONS + +#include +#include +#include + +template static void check_this_nd_item_api() { + // Define the kernel ranges. + constexpr int Dimensions = sizeof...(Dims); + const sycl::range local_range{Dims...}; + const sycl::range global_range = local_range; + const sycl::nd_range nd_range{global_range, local_range}; + // Launch an ND-range kernel. + sycl::queue q; + sycl::buffer results{global_range}; + q.submit([&](sycl::handler &cgh) { + sycl::accessor acc{results, cgh, sycl::write_only}; + cgh.parallel_for(nd_range, [=](sycl::nd_item it) { + // Compare it to this_nd_item(). + acc[it.get_global_id()] = (it == sycl::khr::this_nd_item()); + }); + }); + // Check the test results. + sycl::host_accessor acc{results}; + for (const auto &result : acc) + assert(result); +} + +template static void check_this_group_api() { + // Define the kernel ranges. + constexpr int Dimensions = sizeof...(Dims); + const sycl::range local_range{Dims...}; + const sycl::range global_range = local_range; + const sycl::nd_range nd_range{global_range, local_range}; + // Launch an ND-range kernel. + sycl::queue q; + sycl::buffer results{global_range}; + q.submit([&](sycl::handler &cgh) { + sycl::accessor acc{results, cgh, sycl::write_only}; + cgh.parallel_for(nd_range, [=](sycl::nd_item it) { + // Compare it.get_group() to this_group(). + acc[it.get_global_id()] = + (it.get_group() == sycl::khr::this_group()); + }); + }); + // Check the test results. + sycl::host_accessor acc{results}; + for (const auto &result : acc) + assert(result); +} + +template static void check_this_sub_group_api() { + // Define the kernel ranges. + constexpr int Dimensions = sizeof...(Dims); + const sycl::range local_range{Dims...}; + const sycl::range global_range = local_range; + const sycl::nd_range nd_range{global_range, local_range}; + // Launch an ND-range kernel. + sycl::queue q; + sycl::buffer results{global_range}; + q.submit([&](sycl::handler &cgh) { + sycl::accessor acc{results, cgh, sycl::write_only}; + cgh.parallel_for(nd_range, [=](sycl::nd_item it) { + // Compare it.get_sub_group() to this_sub_group(). + acc[it.get_global_id()] = + (it.get_sub_group() == sycl::khr::this_sub_group()); + }); + }); + // Check the test results. + sycl::host_accessor acc{results}; + for (const auto &result : acc) + assert(result); +} + +int main() { + // nd_item + check_this_nd_item_api<2>(); + check_this_nd_item_api<2, 3>(); + check_this_nd_item_api<2, 3, 4>(); + // group + check_this_group_api<2>(); + check_this_group_api<2, 3>(); + check_this_group_api<2, 3, 4>(); + // sub_group + check_this_sub_group_api<2>(); + check_this_sub_group_api<2, 3>(); + check_this_sub_group_api<2, 3, 4>(); +}