Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
ddf0071
HIP compilation
ryanstocks00 Jul 18, 2024
6070afd
Add hip version of mgga kernels
ryanstocks00 Jul 18, 2024
bb4559f
Removed commented print
ryanstocks00 Jul 18, 2024
6d2d5f0
Copy paste cleanup
ryanstocks00 Jul 18, 2024
6b46eb6
More missing HIP functions
ryanstocks00 Jul 18, 2024
c81132f
Missing HIP kernel
ryanstocks00 Jul 18, 2024
c2e3cc2
Removed register from hip
ryanstocks00 Jul 18, 2024
e9bf3a2
Reduced shared mem req
ryanstocks00 Jul 18, 2024
e9c616d
HIP discovery fixes
ajaypanyala Jul 20, 2024
b781db3
update readme [skip ci]
ajaypanyala Jul 20, 2024
08ae311
update ExchEXX hash
ajaypanyala Jul 28, 2024
6c80aff
Merge remote-tracking branch 'upstream/master'
ryanstocks00 Sep 21, 2024
5b2273e
Fixed HIP compilation
ryanstocks00 Sep 21, 2024
9d9145d
hipblas.h -> hipblas/hipblas.h
ryanstocks00 Sep 21, 2024
1ad6fd4
Renamed SM_BLOCK_Y for cuda compilation
ryanstocks00 Sep 21, 2024
af72d05
Move a bunch of cuda -> hip
ryanstocks00 Sep 22, 2024
7c26939
Allow passing additional flags to obara saika host compilation
ryanstocks00 Sep 26, 2024
2bb4783
Moved obara saika compile flags override
ryanstocks00 Sep 26, 2024
ecf6eac
Compiling HIP on NVIDIA
ryanstocks00 Sep 30, 2024
23c78e9
Pseudofunctional HIP on NVIDIA
ryanstocks00 Oct 8, 2024
031fb0a
Fixed mem access violation
ryanstocks00 Oct 8, 2024
eeff105
Copy zmat from cuda
ryanstocks00 Oct 8, 2024
2089af6
Small refactor of cuda vvar kernel to support any grid/block dims
ryanstocks00 Oct 9, 2024
d4675df
Revert SM block size changes
ryanstocks00 Oct 9, 2024
bfd8803
More forceful double instead of double2
ryanstocks00 Oct 9, 2024
f0b1a51
AMD compilation
ryanstocks00 Oct 14, 2024
63bb6cd
Merge branch 'wavefunction91:master' into master
ryanstocks00 May 26, 2025
5882e88
Remove shmem to fix bug in HIP vvar_grad kernel
ryanstocks00 Jun 30, 2025
39498dd
Fix MGGA on AMD
ryanstocks00 Jul 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,12 @@ endif()
if( GAUXC_ENABLE_HIP )
enable_language( HIP )
set( GAUXC_HAS_HIP TRUE CACHE BOOL "GauXC has HIP and will build HIP bindings" FORCE )
if(NOT DEFINED ROCM_PATH)
message(FATAL_ERROR "ROCM_PATH must be set")
endif()
if( NOT DEFINED CMAKE_HIP_ARCHITECTURES )
message( FATAL_ERROR "CMAKE_HIP_ARCHITECTURES must be set" )
endif()
endif()

# Decided if we're compiling device bindings
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,10 @@ target_link_libraries( my_target PUBLIC gauxc::gauxc )
| `GAUXC_ENABLE_MPI` | Enable MPI Bindings | `ON` |
| `GAUXC_ENABLE_OPENMP` | Enable OpenMP Bindings | `ON` |
| `CMAKE_CUDA_ARCHITECTURES` | CUDA architechtures (e.g. 70 for Volta, 80 for Ampere) | -- |
| `CMAKE_HIP_ARCHITECTURES` | HIP architechtures (e.g. gfx90a for MI250X) | -- |
| `BLAS_LIBRARIES` | Full BLAS linker. | -- |
| `MAGMA_ROOT_DIR` | Install prefix for MAGMA. | -- |
| `ROCM_PATH` | Install prefix for ROCM. | -- |



Expand Down
14 changes: 14 additions & 0 deletions cmake/gauxc-config.cmake.in
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,20 @@ if( GAUXC_HAS_CUDA )
endif()
endif()

if( GAUXC_HAS_HIP )
enable_language( HIP )
set (CMAKE_HIP_ARCHITECTURES @CMAKE_HIP_ARCHITECTURES@)
set (ROCM_PATH @ROCM_PATH@)

list (PREPEND CMAKE_PREFIX_PATH ${ROCM_PATH} ${ROCM_PATH}/hip ${ROCM_PATH}/hipblas)
set(GPU_TARGETS "${CMAKE_HIP_ARCHITECTURES}" CACHE STRING "AMD GPU targets to compile for")

find_package( hip REQUIRED )
find_package( hipblas REQUIRED )

list(REMOVE_AT CMAKE_PREFIX_PATH 0 1 2)
endif

if( GAUXC_HAS_MPI )
find_dependency( MPI )
endif()
Expand Down
4 changes: 2 additions & 2 deletions cmake/gauxc-dep-versions.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ set( GAUXC_CUB_REVISION 1.10.0 )
set( GAUXC_CUTLASS_REPOSITORY https://github.com/NVIDIA/cutlass.git )
set( GAUXC_CUTLASS_REVISION v2.10.0 )

set( GAUXC_EXCHCXX_REPOSITORY https://github.com/wavefunction91/ExchCXX.git )
set( GAUXC_EXCHCXX_REVISION 21a4700a826ec0beae1311a1d59677393bcb168f )
set( GAUXC_EXCHCXX_REPOSITORY https://github.com/ryanstocks00/ExchCXX.git )
set( GAUXC_EXCHCXX_REVISION 8a0004609afc710bdad4367867026a9daa0a758b)

set( GAUXC_GAU2GRID_REPOSITORY https://github.com/dgasmith/gau2grid.git )
set( GAUXC_GAU2GRID_REVISION v2.0.6 )
Expand Down
16 changes: 16 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,19 @@ install( FILES
# Install Custom Find Modules
include( ${linalg-cmake-modules_SOURCE_DIR}/LinAlgModulesMacros.cmake )
install_linalg_modules( INSTALL_CONFIGDIR )

# This allows specifying a lower compiler optimization level for NVHPC which fails to compile with the -O3 flag whilst leaving the remaining flags unchanged
if (DEFINED GAUXC_OBARA_SAIKA_COMPILE_OPTIMIZATION_OPTIONS)
get_target_property(default_compile_options gauxc COMPILE_OPTIONS)
get_target_property(gauxc_sources gauxc SOURCES)
set_target_properties(gauxc PROPERTIES COMPILE_OPTIONS "")
set_source_files_properties(${gauxc_sources} PROPERTIES COMPILE_OPTIONS "${default_compile_options}")

file(GLOB OB_HOST_SRC_FILES ${CMAKE_CURRENT_LIST_DIR}/xc_integrator/local_work_driver/host/obara_saika/src/*.cxx)
set(adjusted_compile_options ${default_compile_options})
foreach (flag "[\\/\\-]O3" "[\\/\\-]Ofast" "[\\/\\-]fast")
string(REGEX REPLACE ${flag} ${GAUXC_OBARA_SAIKA_COMPILE_OPTIMIZATION_OPTIONS} adjusted_compile_options "${adjusted_compile_options}")
endforeach()
message("-- Setting Obara-Saika COMPILE_OPTIONS to: ${adjusted_compile_options}")
set_source_files_properties(${OB_HOST_SRC_FILES} PROPERTIES COMPILE_OPTIONS "${adjusted_compile_options}")
endif()
2 changes: 1 addition & 1 deletion src/exceptions/hipblas_exception.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

#ifdef GAUXC_HAS_HIP
#include "hip/hip_runtime.h"
#include <hipblas.h>
#include <hipblas/hipblas.h>

namespace GauXC {

Expand Down
5 changes: 5 additions & 0 deletions src/runtime_environment/device/hip/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,13 @@
# See LICENSE.txt for details
#

list (PREPEND CMAKE_PREFIX_PATH ${ROCM_PATH} ${ROCM_PATH}/hip ${ROCM_PATH}/hipblas)
set(GPU_TARGETS "${CMAKE_HIP_ARCHITECTURES}" CACHE STRING "AMD GPU targets to compile for")

find_package( hip REQUIRED )
find_package( hipblas REQUIRED )

list(REMOVE_AT CMAKE_PREFIX_PATH 0 1 2)

target_sources( gauxc PRIVATE hip_backend.cxx )
target_link_libraries( gauxc PUBLIC hip::host roc::hipblas )
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@
namespace GauXC {
namespace hip {

#ifdef __HIP_PLATFORM_NVIDIA__
static constexpr uint32_t warp_size = 32;
#else
static constexpr uint32_t warp_size = 64;
static constexpr uint32_t max_threads_per_thread_block = 1024;
#endif
static constexpr uint32_t max_threads_per_thread_block = 512;
static constexpr uint32_t max_warps_per_thread_block =
max_threads_per_thread_block / warp_size;

Expand Down
26 changes: 16 additions & 10 deletions src/xc_integrator/local_work_driver/device/cuda/kernels/uvvars.cu
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ void eval_uvars_mgga( size_t ntasks, size_t npts_total, int32_t nbf_max,
{
dim3 threads( cuda::warp_size, cuda::max_warps_per_thread_block / 2, 1 );
dim3 blocks( std::min(uint64_t(4), util::div_ceil( nbf_max, 4 )),
std::min(uint64_t(16), util::div_ceil( nbf_max, 16 )),
std::min(uint64_t(MGGA_KERNEL_SM_BLOCK), util::div_ceil( npts_max, MGGA_KERNEL_SM_BLOCK )),
ntasks );
if(do_lapl)
eval_uvars_mgga_rks_kernel<true><<< blocks, threads, 0, stream >>>( ntasks, device_tasks );
Expand Down Expand Up @@ -614,17 +614,23 @@ __global__ void eval_vvar_kern( size_t ntasks,

const auto* den_basis_prod_device = task.zmat;

const int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
const int tid_y = blockIdx.y * blockDim.y + threadIdx.y;

register double den_reg = 0.;

if( tid_x < nbf and tid_y < npts ) {
int start_y = blockIdx.y * blockDim.y + threadIdx.y;

for (int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
tid_x < nbf;
tid_x += blockDim.x * gridDim.x ) {

for (int tid_y = start_y;
tid_y < npts;
tid_y += blockDim.y * gridDim.y ) {

const double* bf_col = basis_eval_device + tid_x*npts;
const double* db_col = den_basis_prod_device + tid_x*npts;
const double* bf_col = basis_eval_device + tid_x*npts;
const double* db_col = den_basis_prod_device + tid_x*npts;

den_reg = bf_col[ tid_y ] * db_col[ tid_y ];
den_reg += bf_col[ tid_y ] * db_col[ tid_y ];
}

}

Expand All @@ -634,8 +640,8 @@ __global__ void eval_vvar_kern( size_t ntasks,
den_reg = cuda::warp_reduce_sum<warp_size>( den_reg );


if( threadIdx.x == 0 and tid_y < npts ) {
atomicAdd( den_eval_device + tid_y, den_reg );
if( threadIdx.x == 0 and start_y < npts ) {
atomicAdd( den_eval_device + start_y, den_reg );
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,76 @@ GPGAUEVAL_INLINE __device__ void collocation_spherical_unnorm_angular_3_deriv1(

}

template <typename T>
GPGAUEVAL_INLINE __device__ void collocation_spherical_unnorm_angular_4(
int32_t npts,
const T bf,
const T x,
const T y,
const T z,
T* __restrict__ eval
) {

eval[npts * 0] = sqrt_35*bf*x*y*(x*x - y*y)/2;
eval[npts * 1] = sqrt_70*bf*y*z*(3*x*x - y*y)/4;
eval[npts * 2] = sqrt_5*bf*x*y*(-x*x - y*y + 6*z*z)/2;
eval[npts * 3] = sqrt_10*bf*y*z*(-3*x*x - 3*y*y + 4*z*z)/4;
eval[npts * 4] = bf*(3*x*x*x*x + 6*x*x*y*y - 24*x*x*z*z + 3*y*y*y*y - 24*y*y*z*z + 8*z*z*z*z)/8;
eval[npts * 5] = sqrt_10*bf*x*z*(-3*x*x - 3*y*y + 4*z*z)/4;
eval[npts * 6] = sqrt_5*bf*(-x*x*x*x + 6*x*x*z*z + y*y*y*y - 6*y*y*z*z)/4;
eval[npts * 7] = sqrt_70*bf*x*z*(x*x - 3*y*y)/4;
eval[npts * 8] = sqrt_35*bf*(x*x*x*x - 6*x*x*y*y + y*y*y*y)/8;

}

template <typename T>
GPGAUEVAL_INLINE __device__ void collocation_spherical_unnorm_angular_4_deriv1(
const int32_t npts,
const T bf,
const T bf_x,
const T bf_y,
const T bf_z,
const T x,
const T y,
const T z,
T* __restrict__ eval_x,
T* __restrict__ eval_y,
T* __restrict__ eval_z
) {

eval_x[npts * 0] = sqrt_35*y*(bf*(3*x*x - y*y) + bf_x*x*(x*x - y*y))/2;
eval_x[npts * 1] = sqrt_70*y*z*(6*bf*x + bf_x*(3*x*x - y*y))/4;
eval_x[npts * 2] = sqrt_5*y*(-bf*(3*x*x + y*y - 6*z*z) - bf_x*x*(x*x + y*y - 6*z*z))/2;
eval_x[npts * 3] = sqrt_10*y*z*(-6*bf*x - bf_x*(3*x*x + 3*y*y - 4*z*z))/4;
eval_x[npts * 4] = 3*bf*x*(x*x + y*y - 4*z*z)/2 + bf_x*(3*x*x*x*x + 6*x*x*y*y - 24*x*x*z*z + 3*y*y*y*y - 24*y*y*z*z + 8*z*z*z*z)/8;
eval_x[npts * 5] = sqrt_10*z*(-bf*(9*x*x + 3*y*y - 4*z*z) - bf_x*x*(3*x*x + 3*y*y - 4*z*z))/4;
eval_x[npts * 6] = sqrt_5*(-bf*x*(x*x - 3*z*z) - bf_x*(x*x*x*x - 6*x*x*z*z - y*y*y*y + 6*y*y*z*z)/4);
eval_x[npts * 7] = sqrt_70*z*(3*bf*(x*x - y*y) + bf_x*x*(x*x - 3*y*y))/4;
eval_x[npts * 8] = sqrt_35*(4*bf*x*(x*x - 3*y*y) + bf_x*(x*x*x*x - 6*x*x*y*y + y*y*y*y))/8;

eval_y[npts * 0] = sqrt_35*x*(-bf*(-x*x + 3*y*y) + bf_y*y*(x*x - y*y))/2;
eval_y[npts * 1] = sqrt_70*z*(-3*bf*(-x*x + y*y) + bf_y*y*(3*x*x - y*y))/4;
eval_y[npts * 2] = sqrt_5*x*(-bf*(x*x + 3*y*y - 6*z*z) - bf_y*y*(x*x + y*y - 6*z*z))/2;
eval_y[npts * 3] = sqrt_10*z*(-bf*(3*x*x + 9*y*y - 4*z*z) - bf_y*y*(3*x*x + 3*y*y - 4*z*z))/4;
eval_y[npts * 4] = 3*bf*y*(x*x + y*y - 4*z*z)/2 + bf_y*(3*x*x*x*x + 6*x*x*y*y - 24*x*x*z*z + 3*y*y*y*y - 24*y*y*z*z + 8*z*z*z*z)/8;
eval_y[npts * 5] = sqrt_10*x*z*(-6*bf*y - bf_y*(3*x*x + 3*y*y - 4*z*z))/4;
eval_y[npts * 6] = sqrt_5*(bf*y*(y*y - 3*z*z) - bf_y*(x*x*x*x - 6*x*x*z*z - y*y*y*y + 6*y*y*z*z)/4);
eval_y[npts * 7] = sqrt_70*x*z*(-6*bf*y + bf_y*(x*x - 3*y*y))/4;
eval_y[npts * 8] = sqrt_35*(-4*bf*y*(3*x*x - y*y) + bf_y*(x*x*x*x - 6*x*x*y*y + y*y*y*y))/8;

eval_z[npts * 0] = sqrt_35*bf_z*x*y*(x*x - y*y)/2;
eval_z[npts * 1] = sqrt_70*y*(bf + bf_z*z)*(3*x*x - y*y)/4;
eval_z[npts * 2] = sqrt_5*x*y*(12*bf*z - bf_z*(x*x + y*y - 6*z*z))/2;
eval_z[npts * 3] = sqrt_10*y*(3*bf*(-x*x - y*y + 4*z*z) - bf_z*z*(3*x*x + 3*y*y - 4*z*z))/4;
eval_z[npts * 4] = -2*bf*z*(3*x*x + 3*y*y - 2*z*z) + bf_z*(3*x*x*x*x + 6*x*x*y*y - 24*x*x*z*z + 3*y*y*y*y - 24*y*y*z*z + 8*z*z*z*z)/8;
eval_z[npts * 5] = sqrt_10*x*(3*bf*(-x*x - y*y + 4*z*z) - bf_z*z*(3*x*x + 3*y*y - 4*z*z))/4;
eval_z[npts * 6] = sqrt_5*(12*bf*z*(x*x - y*y) - bf_z*(x*x*x*x - 6*x*x*z*z - y*y*y*y + 6*y*y*z*z))/4;
eval_z[npts * 7] = sqrt_70*x*(bf + bf_z*z)*(x*x - 3*y*y)/4;
eval_z[npts * 8] = sqrt_35*bf_z*(x*x*x*x - 6*x*x*y*y + y*y*y*y)/8;

}



template <typename T>
GPGAUEVAL_INLINE __device__ void collocation_spherical_unnorm_angular(
Expand Down Expand Up @@ -239,8 +309,14 @@ GPGAUEVAL_INLINE __device__ void collocation_spherical_unnorm_angular(

collocation_spherical_unnorm_angular_3( npts, bf, x, y, z, eval );

} else if( l == 4 ) {

collocation_spherical_unnorm_angular_4( npts, bf, x, y, z, eval );

} else {

assert( false && "L < L_MAX" );

}

} // collocation_spherical_unnorm_angular
Expand Down Expand Up @@ -284,6 +360,11 @@ GPGAUEVAL_INLINE __device__ void collocation_spherical_unnorm_angular_deriv1(
collocation_spherical_unnorm_angular_3( npts, bf, x, y, z, eval );
collocation_spherical_unnorm_angular_3_deriv1( npts, bf, bf_x, bf_y, bf_z, x, y, z, eval_x, eval_y, eval_z );

} else if( l == 4 ) {

collocation_spherical_unnorm_angular_4( npts, bf, x, y, z, eval );
collocation_spherical_unnorm_angular_4_deriv1( npts, bf, bf_x, bf_y, bf_z, x, y, z, eval_x, eval_y, eval_z );

} else {
assert( false && "L < L_MAX" );
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@

namespace GauXC {

constexpr double sqrt_15 = 3.872983346207417;
constexpr double sqrt_3 = 1.7320508075688772;
constexpr double sqrt_6 = 2.449489742783178;
constexpr double sqrt_5 = 2.23606797749979;
constexpr double sqrt_15 = 3.872983346207417;
constexpr double sqrt_10 = 3.1622776601683795;
constexpr double sqrt_6 = 2.449489742783178;
constexpr double sqrt_35 = 5.916079783099616;
constexpr double sqrt_70 = 8.366600265340756;

} // namespace GauXC
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,45 @@ namespace hip {
template <size_t warp_sz, typename T>
__device__ T warp_reduce_sum( T val ) {

#ifdef __HIP_PLATFORM_NVIDIA__
for(int i=(warp_sz/2); i>=1; i/=2)
val += __shfl_xor_sync(0xffffffff, val, i, warp_sz);

return val;
#else
using warp_reducer = hipcub::WarpReduce<double>;
static __shared__ typename warp_reducer::TempStorage
static __shared__ typename warp_reducer::TempStorage
temp_storage[hip::max_warps_per_thread_block];
int tid =
int tid =
threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y;

int warp_lane = tid / warp_size;

return warp_reducer( temp_storage[warp_lane] ).Sum( val );
#endif

}

template <size_t warp_sz, typename T>
__device__ T warp_reduce_prod( T val ) {

#ifdef __HIP_PLATFORM_NVIDIA__
for(int i=(warp_sz/2); i>=1; i/=2)
val *= __shfl_xor_sync(0xffffffff, val, i, warp_sz);

return val;
#else
using warp_reducer = hipcub::WarpReduce<double>;
static __shared__ typename warp_reducer::TempStorage
static __shared__ typename warp_reducer::TempStorage
temp_storage[hip::max_warps_per_thread_block];
int tid =
int tid =
threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y;

int warp_lane = tid / warp_size;

return warp_reducer( temp_storage[warp_lane] ).Reduce( val,
[](const T& a, const T& b){ return a * b; } );
#endif

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,12 @@ void modify_weights_ssf_kernel_2d( int32_t npts, int32_t natoms,
int cont = (iCenter < natoms);

// We will continue iterating until all of the threads have cont set to 0

#ifdef __HIP_PLATFORM_NVIDIA__
while (__any_sync(__activemask(), cont)) {
#else
while (__any(cont)) {
#endif
if (cont) {
double2 rj[weight_unroll/2];
double2 rab_val[weight_unroll/2];
Expand All @@ -131,8 +136,16 @@ void modify_weights_ssf_kernel_2d( int32_t npts, int32_t natoms,

#pragma unroll
for (int k = 0; k < weight_unroll/2; k++) {
rj[k] = *((double2*)(local_dist_scratch + jCenter) + k);
rab_val[k] = *((double2*)(local_rab + jCenter) + k);
double* addr = (double*)((double2*)(local_dist_scratch + jCenter) + k);
rj[k].x = addr[0];
rj[k].y = addr[1];
double* addr2 = (double*)((double2*)(local_rab + jCenter) + k);
rab_val[k].x = addr2[0];
rab_val[k].y = addr2[1];
// These caused a memory access violation when lddist is not a
// multiple of 2 as then there can be an unaligned access
// rj[k] = *((double2*)(local_dist_scratch + jCenter) + k);
// rab_val[k] = *((double2*)(local_rab + jCenter) + k);
}

#pragma unroll
Expand Down Expand Up @@ -177,7 +190,6 @@ void modify_weights_ssf_kernel_2d( int32_t npts, int32_t natoms,
sum += parent_weight;
weights[ipt] *= parent_weight / sum;
}

#endif
}

Expand Down
Loading