18 changes: 18 additions & 0 deletions cmake/modules/RocmSetup.cmake
@@ -27,6 +27,24 @@ if(FBGEMM_BUILD_VARIANT STREQUAL BUILD_VARIANT_ROCM)
-Wno-ignored-attributes
-Wno-unused-result)

# Detect whether torch.utils.hipify is v2; if so, pass -DHIPIFY_V2 to hipcc
# so sources can adapt to the v2 naming scheme.
execute_process(
  COMMAND "${Python_EXECUTABLE}" -c
          "from torch.utils.hipify import __version__; print(__version__)"
  WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}"
  OUTPUT_VARIABLE _tempvar
  RESULT_VARIABLE _resvar
  ERROR_VARIABLE _errvar
  OUTPUT_STRIP_TRAILING_WHITESPACE)
if(NOT "${_resvar}" EQUAL "0")
  message(WARNING "Failed to execute Python (${Python_EXECUTABLE})\n"
                  "Result: ${_resvar}\n"
                  "Error: ${_errvar}\n")
endif()

# Match on the major version only; a plain substring search for "2" would
# also match versions such as 1.2.
if("${_tempvar}" MATCHES "^2")
  list(APPEND HIP_HCC_FLAGS -DHIPIFY_V2)
endif()

BLOCK_PRINT(
"HIP found: ${HIP_FOUND}"
"HIPCC compiler flags:"
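For context: HIP_HCC_FLAGS is the flag list handed to hipcc, so the define added above is visible to every ROCm translation unit at preprocessing time. A minimal, hypothetical probe (not part of this PR) showing how a source file can branch on it:

// probe.cpp -- hypothetical sketch; build with `hipcc probe.cpp` and
// `hipcc -DHIPIFY_V2 probe.cpp` to exercise both paths.
#include <cstdio>

int main() {
#ifdef HIPIFY_V2
  std::printf("built against torch.utils.hipify v2\n");
#else
  std::printf("built against torch.utils.hipify v1\n");
#endif
  return 0;
}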
3 changes: 2 additions & 1 deletion fbgemm_gpu/CMakeLists.txt
@@ -200,7 +200,8 @@ if(FBGEMM_BUILD_VARIANT STREQUAL BUILD_VARIANT_ROCM)
${CMAKE_CURRENT_SOURCE_DIR}/src
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/experimental/example
${CMAKE_CURRENT_SOURCE_DIR}/experimental/gen_ai)
${CMAKE_CURRENT_SOURCE_DIR}/experimental/gen_ai
${CMAKE_CURRENT_SOURCE_DIR}/experimental/gen_ai/src/quantize/common/include/fbgemm_gpu/quantize)

# HIPify all .CU and .CUH sources under the current directory (`/fbgemm_gpu`)
#
@@ -23,6 +23,10 @@
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp"
#include "kernels/bf16_grouped_kernel_manifest.h"

#ifdef HIPIFY_V2
#define getCurrentHIPStream getCurrentCUDAStream
#endif

namespace fbgemm_gpu {

// Define useful types that are needed for various kernels.
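This three-line guard is the core source-side trick of the PR: when HIPIFY_V2 is defined, every call spelled getCurrentHIPStream is redirected by the preprocessor to getCurrentCUDAStream, so a single spelling compiles under both hipify generations. A self-contained sketch of the technique (the Stream type and both getters below are stand-ins, not the real c10 API):

#include <cstdio>

struct Stream { const char* origin; };

// Stand-in for the CUDA-named getter that is always available.
Stream getCurrentCUDAStream() { return {"cuda-named getter"}; }

#ifdef HIPIFY_V2
// As in the PR: redirect the HIP spelling to the CUDA-named symbol.
#define getCurrentHIPStream getCurrentCUDAStream
#else
// Without the macro, the HIP spelling must resolve to a real function.
Stream getCurrentHIPStream() { return {"hip-named getter"}; }
#endif

int main() {
  // The same call site binds to either getter depending on -DHIPIFY_V2.
  std::printf("%s\n", getCurrentHIPStream().origin);
  return 0;
}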
@@ -7,10 +7,10 @@
*/

#include <ATen/ATen.h>
#ifdef USE_ROCM
#include <c10/hip/HIPStream.h>
#else
#include <c10/cuda/CUDAStream.h>

#ifdef HIPIFY_V2
#define getCurrentHIPStream getCurrentCUDAStream
#endif

#include "ck/ck.hpp"
@@ -17,6 +17,10 @@
#include <ATen/ATen.h>
#include <c10/hip/HIPStream.h>

#ifdef HIPIFY_V2
#define getCurrentHIPStream getCurrentCUDAStream
#endif

#if defined(USE_ROCM)

#include "ck/ck.hpp"
@@ -14,6 +14,10 @@
#include <ATen/ATen.h>
#include <c10/hip/HIPStream.h>

#ifdef HIPIFY_V2
#define getCurrentHIPStream getCurrentCUDAStream
#endif

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
@@ -9,10 +9,10 @@
#include <iostream>

#include <ATen/ATen.h>
#ifdef USE_ROCM
#include <c10/hip/HIPStream.h>
#else
#include <c10/cuda/CUDAStream.h>

#ifdef HIPIFY_V2
#define getCurrentHIPStream getCurrentCUDAStream
#endif

#include "ck/ck.hpp"
@@ -7,10 +7,10 @@
*/

#include <ATen/ATen.h>
#ifdef USE_ROCM
#include <c10/hip/HIPStream.h>
#else
#include <c10/cuda/CUDAStream.h>

#ifdef HIPIFY_V2
#define getCurrentHIPStream getCurrentCUDAStream
#endif

#include "ck/ck.hpp"
@@ -14,6 +14,10 @@
#include <ATen/core/Tensor.h>
#include <c10/hip/HIPStream.h>

#ifdef HIPIFY_V2
#define getCurrentHIPStream getCurrentCUDAStream
#endif

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp"
#include "kernels/fp8_rowwise_grouped_kernel_manifest.h"
@@ -7,10 +7,10 @@
*/
#undef __HIP_NO_HALF_CONVERSIONS__
#include <ATen/ATen.h>
#ifdef USE_ROCM
#include <c10/hip/HIPStream.h>
#else
#include <c10/cuda/CUDAStream.h>

#ifdef HIPIFY_V2
#define getCurrentHIPStream getCurrentCUDAStream
#endif

#include "ck/ck.hpp"
@@ -9,10 +9,10 @@
#include <iostream>

#include <ATen/ATen.h>
#ifdef USE_ROCM
#include <c10/hip/HIPStream.h>
#else
#include <c10/cuda/CUDAStream.h>

#ifdef HIPIFY_V2
#define getCurrentHIPStream getCurrentCUDAStream
#endif

#include "ck/ck.hpp"
@@ -14,6 +14,10 @@
#include <ATen/ATen.h>
#include <c10/hip/HIPStream.h>

#ifdef HIPIFY_V2
#define getCurrentHIPStream getCurrentCUDAStream
#endif

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
@@ -7,6 +7,10 @@

#include <c10/hip/HIPStream.h>

#ifdef HIPIFY_V2
#define getCurrentHIPStream getCurrentCUDAStream
#endif
#include <atomic>
#include <cassert>
#include <cmath>
@@ -14,26 +14,9 @@

#include <ATen/ATen.h>

// In OSS, hipification of this include does not work, so we hipify it manually.
#ifdef USE_ROCM
#include <ATen/hip/HIPEvent.h> // @manual
#include <ATen/hip/HIPGraph.h> // @manual
#include <hip/hip_runtime.h>
#define GPUStream at::hip::HIPStreamMasqueradingAsCUDA
#define GPUStreamGuard at::hip::HIPStreamGuardMasqueradingAsCUDA
#define getStreamFromPool at::hip::getStreamFromPoolMasqueradingAsCUDA
#define gpuStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed
#define gpuEventDefault hipEventDefault
#else
#include <ATen/cuda/CUDAEvent.h>
#include <ATen/cuda/CUDAGraph.h>
#include <cuda_runtime.h>
#define GPUStream at::cuda::CUDAStream
#define GPUStreamGuard at::cuda::CUDAStreamGuard
#define getStreamFromPool at::cuda::getStreamFromPool
#define gpuStreamCaptureModeRelaxed cudaStreamCaptureModeRelaxed
#define gpuEventDefault cudaEventDefault
#endif

#include <ostream>

@@ -232,8 +215,8 @@ class TuningCache final {
at::cuda::CUDAGraph graph;
{
// CUDAGraph capture must happen on non-default stream
GPUStream stream = getStreamFromPool(true);
GPUStreamGuard streamGuard(stream);
at::cuda::CUDAStream stream = at::cuda::getStreamFromPool(true);
at::cuda::CUDAStreamGuard streamGuard(stream);

// For flexibility, we use cudaStreamCaptureModeRelaxed.
// - cudaStreamCaptureModeGlobal prevents other threads from calling
@@ -242,7 +225,7 @@
// - cudaStreamCaptureModeThreadLocal prevents CCA from freeing memory.
// Since CUDA graph is preferred for offline benchmark this should be
// fine.
graph.capture_begin({0, 0}, gpuStreamCaptureModeRelaxed);
graph.capture_begin({0, 0}, cudaStreamCaptureModeRelaxed);
for (int i = 0; i < num_iters; ++i) {
kernel(std::forward<Args>(args)...);
}
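For reference, the capture logic above reduced to a standalone sketch. It assumes a CUDA build of PyTorch and mirrors the calls this file makes (pooled non-default stream, stream guard, relaxed capture mode); kernel_launch is a hypothetical stand-in for the kernel being benchmarked:

#include <ATen/cuda/CUDAGraph.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>

void benchmark_with_graph(void (*kernel_launch)(), int num_iters) {
  at::cuda::CUDAGraph graph;
  {
    // Capture must happen on a non-default stream.
    at::cuda::CUDAStream stream = at::cuda::getStreamFromPool(true);
    at::cuda::CUDAStreamGuard guard(stream);
    // Relaxed mode keeps CUDA calls on other threads legal during capture.
    graph.capture_begin({0, 0}, cudaStreamCaptureModeRelaxed);
    for (int i = 0; i < num_iters; ++i) {
      kernel_launch();
    }
    graph.capture_end();
  }
  // Replay re-issues all captured launches with minimal CPU overhead.
  graph.replay();
}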
@@ -296,8 +279,8 @@

constexpr static std::string_view FBGEMM_CACHE_DIR = ".fbgemm";

at::cuda::CUDAEvent start_ = at::cuda::CUDAEvent(gpuEventDefault);
at::cuda::CUDAEvent stop_ = at::cuda::CUDAEvent(gpuEventDefault);
at::cuda::CUDAEvent start_ = at::cuda::CUDAEvent(cudaEventDefault);
at::cuda::CUDAEvent stop_ = at::cuda::CUDAEvent(cudaEventDefault);
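A hedged sketch of how such an event pair is typically used for timing (time_on_gpu is a hypothetical helper built on ATen's CUDAEvent API):

#include <ATen/cuda/CUDAEvent.h>
#include <cuda_runtime.h>

// Returns the GPU time in milliseconds spent in `work` on the current stream.
template <typename F>
float time_on_gpu(F&& work) {
  at::cuda::CUDAEvent start(cudaEventDefault);
  at::cuda::CUDAEvent stop(cudaEventDefault);
  start.record();     // records on the current CUDA stream
  work();             // the kernel launch(es) being timed
  stop.record();
  stop.synchronize(); // block until the stop event completes
  return start.elapsed_time(stop);
}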

// If FBGEMM_AUTOTUNE_USE_CUDA_GRAPH is set, use CUDA graph for benchmarking.
// CUDA graphs use a separate memory pool to do allocation in PyTorch