Thread loop optimizations RAJA launch #1949
base: develop
Changes from 19 commits
The first change adds a new header that forward-declares the launch-context policy tags:

```diff
@@ -0,0 +1,31 @@
+/*!
+ ******************************************************************************
+ *
+ * \file
+ *
+ * \brief RAJA header file containing the core components of RAJA::launch
+ *
+ ******************************************************************************
+ */
+
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
+// Copyright (c) 2016-25, Lawrence Livermore National Security, LLC
+// and RAJA project contributors. See the RAJA/LICENSE file for details.
+//
+// SPDX-License-Identifier: (BSD-3-Clause)
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
+
+#ifndef RAJA_pattern_context_policy_HPP
+#define RAJA_pattern_context_policy_HPP
+
+namespace RAJA
+{
+
+class LaunchContextDefaultPolicy;
+
+#if defined(RAJA_CUDA_ACTIVE) || defined(RAJA_HIP_ACTIVE)
+class LaunchContextDim3Policy;
+#endif
+
+}  // namespace RAJA
+#endif
```
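The `LaunchContextT` specializations selected by these tags do not appear in this view of the PR. Judging only from how the kernels below use them (the compile-time `hasDim3` flag, a `(threadIdx, blockDim)` constructor, the `shared_mem_ptr` member, and the `thread_id`/`block_dim` members read by the new loop methods), a minimal sketch might look like the following; everything beyond those names is an assumption, not RAJA's actual definition:

```cpp
#include <cuda_runtime.h>

namespace RAJA
{

// Hypothetical sketch, not the actual RAJA definitions.
template<typename LaunchContextPolicy>
struct LaunchContextT;

// Default context: no cached dim3 state; loop methods read
// threadIdx/blockDim directly on each access.
template<>
struct LaunchContextT<LaunchContextDefaultPolicy>
{
  static constexpr bool hasDim3 = false;
  char* shared_mem_ptr          = nullptr;
};

// Dim3 context: caches threadIdx/blockDim once at kernel entry so
// loop methods can reuse the copies instead of re-reading the
// hardware special registers on every iteration.
template<>
struct LaunchContextT<LaunchContextDim3Policy>
{
  static constexpr bool hasDim3 = true;
  char* shared_mem_ptr          = nullptr;
  dim3 thread_id;
  dim3 block_dim;

  __device__ LaunchContextT(dim3 tid, dim3 bdim)
      : thread_id(tid), block_dim(bdim)
  {}
};

}  // namespace RAJA
```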
The CUDA launch implementation is then updated to thread the context policy through the kernels and the `LaunchExecute` specializations:

```diff
@@ -28,33 +28,42 @@
 namespace RAJA
 {

-template<typename BODY, typename ReduceParams>
+template<typename BODY, typename LaunchContextPolicy, typename ReduceParams>
 __global__ void launch_new_reduce_global_fcn(const RAJA_CUDA_GRID_CONSTANT BODY
                                                  body_in,
                                              ReduceParams reduce_params)
 {
-  LaunchContext ctx;
-
   using RAJA::internal::thread_privatize;
   auto privatizer = thread_privatize(body_in);
   auto& body      = privatizer.get_priv();

   // Set pointer to shared memory
   extern __shared__ char raja_shmem_ptr[];
-  ctx.shared_mem_ptr = raja_shmem_ptr;

-  RAJA::expt::invoke_body(reduce_params, body, ctx);
+  if constexpr (LaunchContextT<LaunchContextPolicy>::hasDim3)
+  {
+    LaunchContextT<LaunchContextPolicy> ctx(threadIdx, blockDim);
+    ctx.shared_mem_ptr = raja_shmem_ptr;
+    RAJA::expt::invoke_body(reduce_params, body, ctx);
+  }
+  else
+  {
+    LaunchContextT<LaunchContextPolicy> ctx;
+    ctx.shared_mem_ptr = raja_shmem_ptr;
+    RAJA::expt::invoke_body(reduce_params, body, ctx);
+  }

   // Using a flatten global policy as we may use all dimensions
   RAJA::expt::ParamMultiplexer::parampack_combine(
       RAJA::cuda_flatten_global_xyz_direct {}, reduce_params);
 }

-template<bool async>
+template<bool async, typename LaunchContextPolicy>
 struct LaunchExecute<
     RAJA::policy::cuda::cuda_launch_explicit_t<async,
                                                named_usage::unspecified,
-                                               named_usage::unspecified>>
+                                               named_usage::unspecified,
+                                               LaunchContextPolicy>>
 {

   template<typename BODY_IN, typename ReduceParams>
```
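A note on the dispatch above: because `hasDim3` is a compile-time constant, the `if constexpr` selects one branch during template instantiation, so kernels built with the default context policy carry no trace of the dim3 copies; only kernels instantiated with a dim3-style policy pay the cost of capturing `threadIdx`/`blockDim` at entry.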
```diff
@@ -72,7 +81,8 @@ struct LaunchExecute<
     EXEC_POL pol {};

     auto func = reinterpret_cast<const void*>(
-        &launch_new_reduce_global_fcn<BODY, camp::decay<ReduceParams>>);
+        &launch_new_reduce_global_fcn<BODY, LaunchContextPolicy,
+                                      camp::decay<ReduceParams>>);

     resources::Cuda cuda_res = res.get<RAJA::resources::Cuda>();
```
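Note that the context policy is now part of the kernel's template arguments, so each `LaunchContextPolicy` instantiates a distinct `launch_new_reduce_global_fcn`, and the function pointer handed to the CUDA launch machinery always matches the context type constructed inside the kernel.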
```diff
@@ -137,32 +147,48 @@ struct LaunchExecute<
 template<typename BODY,
          int num_threads,
          size_t BLOCKS_PER_SM,
+         typename LaunchContextPolicy,
          typename ReduceParams>
 __launch_bounds__(num_threads, BLOCKS_PER_SM) __global__
     void launch_new_reduce_global_fcn_fixed(const RAJA_CUDA_GRID_CONSTANT BODY
                                                 body_in,
                                             ReduceParams reduce_params)
 {
-  LaunchContext ctx;
-
   using RAJA::internal::thread_privatize;
   auto privatizer = thread_privatize(body_in);
   auto& body      = privatizer.get_priv();

   // Set pointer to shared memory
   extern __shared__ char raja_shmem_ptr[];
-  ctx.shared_mem_ptr = raja_shmem_ptr;

-  RAJA::expt::invoke_body(reduce_params, body, ctx);
+  if constexpr (LaunchContextT<LaunchContextPolicy>::hasDim3)
+  {
+    LaunchContextT<LaunchContextPolicy> ctx(threadIdx, blockDim);
+    ctx.shared_mem_ptr = raja_shmem_ptr;
+    RAJA::expt::invoke_body(reduce_params, body, ctx);
+  }
+  else
+  {
+    LaunchContextT<LaunchContextPolicy> ctx;
+    ctx.shared_mem_ptr = raja_shmem_ptr;
+    RAJA::expt::invoke_body(reduce_params, body, ctx);
+  }

   // Using a flatten global policy as we may use all dimensions
   RAJA::expt::ParamMultiplexer::parampack_combine(
       RAJA::cuda_flatten_global_xyz_direct {}, reduce_params);
 }

-template<bool async, int nthreads, size_t BLOCKS_PER_SM>
+template<bool async,
+         int nthreads,
+         size_t BLOCKS_PER_SM,
+         typename LaunchContextPolicy>
 struct LaunchExecute<
-    RAJA::policy::cuda::cuda_launch_explicit_t<async, nthreads, BLOCKS_PER_SM>>
+    RAJA::policy::cuda::cuda_launch_explicit_t<async,
+                                               nthreads,
+                                               BLOCKS_PER_SM,
+                                               LaunchContextPolicy>>
 {

   template<typename BODY_IN, typename ReduceParams>
```
```diff
@@ -183,6 +209,7 @@ struct LaunchExecute<
     auto func = reinterpret_cast<const void*>(
         &launch_new_reduce_global_fcn_fixed<BODY, nthreads, BLOCKS_PER_SM,
+                                            LaunchContextPolicy,
                                             camp::decay<ReduceParams>>);

     resources::Cuda cuda_res = res.get<RAJA::resources::Cuda>();
```
```diff
@@ -245,6 +272,46 @@ struct LaunchExecute<
   }
 };

+/*
+  Loop methods which rely on a copy of threadIdx/blockDim for
+  performance. In collaboration with AMD we have found this to be
+  more performant.
+*/
+
+namespace expt
+{
+
+template<named_dim DIM>
+struct cuda_ctx_thread_loop;
+
+using cuda_ctx_thread_loop_x = cuda_ctx_thread_loop<named_dim::x>;
+using cuda_ctx_thread_loop_y = cuda_ctx_thread_loop<named_dim::y>;
+using cuda_ctx_thread_loop_z = cuda_ctx_thread_loop<named_dim::z>;
+
+}  // namespace expt
+
+template<typename SEGMENT, named_dim DIM>
+struct LoopExecute<expt::cuda_ctx_thread_loop<DIM>, SEGMENT>
+{
+
+  template<typename BODY>
+  static RAJA_INLINE RAJA_DEVICE void exec(
+      LaunchContextT<LaunchContextDim3Policy> const& ctx,
+      SEGMENT const& segment,
+      BODY const& body)
+  {
+
+    const int len = segment.end() - segment.begin();
+    constexpr int int_dim = static_cast<int>(DIM);
+
+    for (int i = ::RAJA::internal::CudaDimHelper<DIM>::get(ctx.thread_id);
+         i < len; i += ::RAJA::internal::CudaDimHelper<DIM>::get(ctx.block_dim))
+    {
+      body(*(segment.begin() + i));
+    }
+  }
+};
+
 /*
   CUDA generic loop implementations
 */
```
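To illustrate how the pieces compose, here is a hedged usage sketch. The trailing context-policy argument follows the `cuda_launch_explicit_t` signature extended above, but the policy alias, the data names (`d_a`, `d_b`, `d_out`, `N`), the launch sizes, and the use of a generic `auto ctx` lambda parameter are all assumptions for illustration, not confirmed by this diff:

```cpp
// Hedged sketch: policy spelling, data names, and sizes are assumptions.
using launch_policy = RAJA::LaunchPolicy<
    RAJA::policy::cuda::cuda_launch_explicit_t<
        false,                            // async
        RAJA::named_usage::unspecified,   // threads chosen at runtime
        RAJA::named_usage::unspecified,   // blocks-per-SM unspecified
        RAJA::LaunchContextDim3Policy>>;  // cache threadIdx/blockDim in ctx

RAJA::launch<launch_policy>(
    RAJA::LaunchParams(RAJA::Teams(n_blocks), RAJA::Threads(block_size)),
    [=] RAJA_HOST_DEVICE(auto ctx)
    {
      // Each thread in the block strides the range using the dim3 copies
      // held in ctx rather than re-reading threadIdx/blockDim per iteration.
      RAJA::loop<RAJA::expt::cuda_ctx_thread_loop_x>(
          ctx, RAJA::TypedRangeSegment<int>(0, N),
          [&](int i) { d_out[i] = d_a[i] + d_b[i]; });
    });
```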