diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorScan.h b/unsupported/Eigen/CXX11/src/Tensor/TensorScan.h index 1f545ef1a..0ec62490b 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorScan.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorScan.h @@ -39,8 +39,9 @@ struct nested, 1, { typedef TensorScanOp type; }; -} // end namespace internal + +} // end namespace internal /** \class TensorScan * \ingroup CXX11_Tensor_Module * @@ -242,14 +243,15 @@ struct ScanLauncher { } }; -#if defined(EIGEN_USE_GPU) && defined(EIGEN_CUDACC) // GPU implementation of scan // TODO(ibab) This placeholder implementation performs multiple scans in // parallel, but it would be better to use a parallel scan algorithm and // optimize memory access. +#if defined(EIGEN_CUDACC) || defined(__HIPCC__) template -__global__ void ScanKernel(Self self, Index total_size, typename Self::CoeffReturnType* data) { +__global__ void +ScanKernel(Self self, Index total_size, typename Self::CoeffReturnType* data) { // Compute offset as in the CPU version Index val = threadIdx.x + blockIdx.x * blockDim.x; Index offset = (val / self.stride()) * self.stride() * self.size() + val % self.stride(); @@ -271,7 +273,9 @@ __global__ void ScanKernel(Self self, Index total_size, typename Self::CoeffRetu __syncthreads(); } +#endif +#if defined(EIGEN_USE_GPU) && defined(EIGEN_CUDACC) template struct ScanLauncher { void operator()(const Self& self, typename Self::CoeffReturnType* data) { @@ -283,6 +287,20 @@ struct ScanLauncher { }; #endif // EIGEN_USE_GPU && EIGEN_CUDACC + +#if defined(EIGEN_USE_GPU) && defined(__HIPCC__) +template +struct ScanLauncher { + void operator()(const Self& self, typename Self::CoeffReturnType* data) { + Index total_size = internal::array_prod(self.dimensions()); + Index num_blocks = (total_size / self.size() + 63) / 64; + Index block_size = 64; + hipLaunchKernelGGL(HIP_KERNEL_NAME(ScanKernel), dim3(num_blocks), dim3(block_size), 0, self.device().stream(), self, total_size, data); + } +}; 
+#endif // EIGEN_USE_GPU && __HIPCC__ + + } // end namespace Eigen #endif // EIGEN_CXX11_TENSOR_TENSOR_SCAN_H