Commits (20)
c131985  [AMD][ROCm] Improve support of AMD (k-artem, Jul 15, 2025)
4490ea5  [AMD][ROCm] Fixes review comments (k-artem, Jul 25, 2025)
77a7e06  [AMD][ROCm] Fixes review comments (k-artem, Aug 3, 2025)
110d6dd  Merge branch 'master' into improve_support_of_amd_hardware (sfc-gh-truwase, Aug 16, 2025)
0946828  [AMD][ROCm] Enable BF16 and fixes review's comment (k-artem, Aug 18, 2025)
c75a4b4  Merge branch 'master' into improve_support_of_amd_hardware (sfc-gh-truwase, Aug 19, 2025)
f9934bb  Merge branch 'master' into improve_support_of_amd_hardware (sfc-gh-truwase, Aug 20, 2025)
2d16fb1  Merge branch 'master' into improve_support_of_amd_hardware (loadams, Aug 20, 2025)
47cb5cc  Merge branch 'master' into improve_support_of_amd_hardware (loadams, Aug 20, 2025)
a23815a  [AMD][ROCm] Fix format (k-artem, Aug 21, 2025)
234920e  Merge branch 'master' into improve_support_of_amd_hardware (loadams, Aug 28, 2025)
4eade1e  Merge branch 'master' into improve_support_of_amd_hardware (loadams, Sep 2, 2025)
4904d94  Fix BF16 support for AMD (k-artem, Oct 13, 2025)
4a1d7b7  Remove unnecessary changes (k-artem, Oct 13, 2025)
7389a8f  Merge branch 'master' into improve_support_of_amd_hardware (k-artem, Oct 13, 2025)
2b14460  Merge branch 'master' into improve_support_of_amd_hardware (k-artem, Oct 13, 2025)
2428cb7  Merge branch 'master' into improve_support_of_amd_hardware (k-artem, Oct 22, 2025)
ab1af24  Merge branch 'master' into improve_support_of_amd_hardware (k-artem, Oct 22, 2025)
427071c  Merge branch 'master' into improve_support_of_amd_hardware (k-artem, Oct 23, 2025)
f08fe18  Merge branch 'master' into improve_support_of_amd_hardware (k-artem, Oct 24, 2025)
5 changes: 4 additions & 1 deletion csrc/fp_quantizer/fp_quantize.cu
@@ -4,7 +4,7 @@
// DeepSpeed Team

#include <stdexcept>
#include "context.h"
#include "fp_context.h"
#include "fp_quantize.h"
#include "memory_access_utils.h"
#include "reduction_utils.h"
@@ -14,6 +14,9 @@

#include <cuda_fp16.h>
#include <curand_kernel.h>
#if defined(__HIP_PLATFORM_AMD__)
#include <hip/hip_fp16.h>
#endif

#ifdef BF16_AVAILABLE
#include <cuda_bf16.h>
(diff continues in another file; header not shown)
@@ -9,6 +9,13 @@
#include <torch/extension.h>
#include <vector>

#if defined(__HIP_PLATFORM_AMD__)
#include <hip/hip_fp16.h>
#if BF16_AVAILABLE
#include <hip/hip_bf16.h>
#endif
#endif

#define DISPATCH_QUANTIZE(T_TYPE, C_TYPE, mantisa, exponent) \
if (val.options().dtype() == torch::T_TYPE) { \
launch_quantization<C_TYPE, mantisa, exponent>((C_TYPE*)val.data_ptr(), \
69 changes: 69 additions & 0 deletions csrc/includes/conversion_utils.h
@@ -59,6 +59,7 @@ DS_D_INLINE __half to(__half val)
{
return val;
}

#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE __nv_bfloat16 to(__nv_bfloat16 val)
@@ -363,42 +364,74 @@ DS_D_INLINE __nv_bfloat16 to(float val)
template <>
DS_D_INLINE __nv_bfloat16 to(int64_t val)
{
#ifdef __HIP_PLATFORM_AMD__
return __double2bfloat16(__ll2double_rn(val));
#else
return __ll2bfloat16_rn(val);
#endif
}
template <>
DS_D_INLINE __nv_bfloat16 to(int32_t val)
{
#ifdef __HIP_PLATFORM_AMD__
return __float2bfloat16(__int2float_rn(val));
#else
return __int2bfloat16_rn(val);
#endif
}
template <>
DS_D_INLINE __nv_bfloat16 to(int16_t val)
{
#ifdef __HIP_PLATFORM_AMD__
return __float2bfloat16(__int2float_rn(val));
#else
return __short2bfloat16_rn(val);
#endif
}
template <>
DS_D_INLINE __nv_bfloat16 to(int8_t val)
{
#ifdef __HIP_PLATFORM_AMD__
return __float2bfloat16(__int2float_rn(val));
#else
return __int2bfloat16_rn(val);
#endif
}
template <>
DS_D_INLINE __nv_bfloat16 to(uint64_t val)
{
#ifdef __HIP_PLATFORM_AMD__
return __double2bfloat16(__ull2double_rn(val));
#else
return __ull2bfloat16_rn(val);
#endif
}
template <>
DS_D_INLINE __nv_bfloat16 to(uint32_t val)
{
#ifdef __HIP_PLATFORM_AMD__
return __float2bfloat16(__uint2float_rn(val));
#else
return __uint2bfloat16_rn(val);
#endif
}
template <>
DS_D_INLINE __nv_bfloat16 to(uint16_t val)
{
#ifdef __HIP_PLATFORM_AMD__
return __float2bfloat16(__uint2float_rn(val));
#else
return __ushort2bfloat16_rn(val);
#endif
}
template <>
DS_D_INLINE __nv_bfloat16 to(uint8_t val)
{
#ifdef __HIP_PLATFORM_AMD__
return __float2bfloat16(__uint2float_rn(val));
#else
return __uint2bfloat16_rn(val);
#endif
}
#endif

@@ -412,7 +445,11 @@ DS_D_INLINE __nv_bfloat162 to(float2 val)
template <>
DS_D_INLINE __nv_bfloat162 to(float val)
{
#ifdef __HIP_PLATFORM_AMD__
return __bfloat162bfloat162(__float2bfloat16(val));
#else
return __float2bfloat162_rn(val);
#endif
}
template <>
DS_D_INLINE __nv_bfloat162 to(__half2 val)
@@ -444,7 +481,11 @@ DS_D_INLINE int64_t to(__half val)
template <>
DS_D_INLINE int64_t to(__nv_bfloat16 val)
{
#ifdef __HIP_PLATFORM_AMD__
return __float2ll_rn(__bfloat162float(val));
#else
return __bfloat162ll_rn(val);
#endif
}
#endif

@@ -471,7 +512,11 @@ DS_D_INLINE int32_t to(__half val)
template <>
DS_D_INLINE int32_t to(__nv_bfloat16 val)
{
#ifdef __HIP_PLATFORM_AMD__
return __float2int_rn(__bfloat162float(val));
#else
return __bfloat162int_rn(val);
#endif
}
#endif

@@ -498,7 +543,11 @@ DS_D_INLINE int16_t to(__half val)
template <>
DS_D_INLINE int16_t to(__nv_bfloat16 val)
{
#ifdef __HIP_PLATFORM_AMD__
return __float2int_rn(__bfloat162float(val));
#else
return __bfloat162int_rn(val);
#endif
}
#endif

@@ -525,7 +574,11 @@ DS_D_INLINE int8_t to(__half val)
template <>
DS_D_INLINE int8_t to(__nv_bfloat16 val)
{
#ifdef __HIP_PLATFORM_AMD__
return __float2int_rn(__bfloat162float(val));
#else
return __bfloat162int_rn(val);
#endif
}
#endif

@@ -552,7 +605,11 @@ DS_D_INLINE uint64_t to(__half val)
template <>
DS_D_INLINE uint64_t to(__nv_bfloat16 val)
{
#ifdef __HIP_PLATFORM_AMD__
return __float2ull_rn(__bfloat162float(val));
#else
return __bfloat162ull_rn(val);
#endif
}
#endif

@@ -579,7 +636,11 @@ DS_D_INLINE uint32_t to(__half val)
template <>
DS_D_INLINE uint32_t to(__nv_bfloat16 val)
{
#ifdef __HIP_PLATFORM_AMD__
return __float2uint_rn(__bfloat162float(val));
#else
return __bfloat162uint_rn(val);
#endif
}
#endif

@@ -606,7 +667,11 @@ DS_D_INLINE uint16_t to(__half val)
template <>
DS_D_INLINE uint16_t to(__nv_bfloat16 val)
{
#ifdef __HIP_PLATFORM_AMD__
return __float2uint_rn(__bfloat162float(val));
#else
return __bfloat162uint_rn(val);
#endif
}
#endif

@@ -633,7 +698,11 @@ DS_D_INLINE uint8_t to(__half val)
template <>
DS_D_INLINE uint8_t to(__nv_bfloat16 val)
{
#ifdef __HIP_PLATFORM_AMD__
return __float2uint_rn(__bfloat162float(val));
#else
return __bfloat162uint_rn(val);
#endif
}
#endif

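A quick note on what these fallbacks buy callers (a minimal sketch, not part of the diff; it assumes a build with BF16_AVAILABLE defined and conversion_utils.h on the include path): the same conversion::to<> call compiles on both platforms, with ROCm routing bf16 conversions through float and CUDA using the native __bfloat162*_rn intrinsics.

    // Sketch only: one call site, two code paths selected by the #ifdefs above.
    #include "conversion_utils.h"

    __global__ void bf16_round_to_int(const __nv_bfloat16* in, int32_t* out, int n)
    {
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) {
            // ROCm: __float2int_rn(__bfloat162float(v)); CUDA: __bfloat162int_rn(v).
            out[i] = conversion::to<int32_t>(in[i]);
        }
    }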
35 changes: 34 additions & 1 deletion csrc/includes/reduction_utils.h
@@ -9,6 +9,10 @@
#include "ds_kernel_utils.h"
#include "memory_access_utils.h"

#if defined(BF16_AVAILABLE) && defined(__HIP_PLATFORM_AMD__)
#include <hip/hip_bfloat16.h>
#endif

namespace cg = cooperative_groups;

namespace reduce {
@@ -374,7 +378,11 @@ DS_D_INLINE __half init<ROpType::Max>()
template <>
DS_D_INLINE __nv_bfloat16 init<ROpType::Max>()
{
#ifdef __HIP_PLATFORM_AMD__
constexpr __hip_bfloat16_raw neg_inf = {0xFF80};
#else
constexpr __nv_bfloat16_raw neg_inf = {0xFF80};
#endif
return __nv_bfloat16(neg_inf);
}
#endif
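A side note on the raw literal used in the ROCm branch (not part of the diff): 0xFF80 is the bfloat16 bit pattern for negative infinity, which the checks below verify field by field.

    // Sanity sketch: bfloat16 is 1 sign bit, 8 exponent bits, 7 mantissa bits, so
    // 0xFF80 = sign 1, exponent 0xFF, mantissa 0, i.e. -inf (a nonzero mantissa would be NaN).
    static_assert((0xFF80u >> 15) == 1u, "sign bit set");
    static_assert(((0xFF80u >> 7) & 0xFFu) == 0xFFu, "exponent all ones");
    static_assert((0xFF80u & 0x7Fu) == 0u, "zero mantissa selects infinity, not NaN");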
@@ -573,6 +581,24 @@ DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, T* data)
}
}

#if defined(__HIP_PLATFORM_AMD__)
template <int reduce_width, typename T, ROpType... Ops>
DS_D_INLINE void _warp_with_type_conversion(cg::thread_block_tile<hw_warp_size>& warp_arg, T* data)
{
constexpr int elems = sizeof...(Ops);
if constexpr (!(std::is_integral<T>::value || std::is_floating_point<T>::value)) {
float temp_data[elems];
#pragma unroll
for (int i = 0; i < elems; i++) { temp_data[i] = conversion::to<float>(data[i]); }
_warp<float, Ops...>(warp_arg, temp_data);
#pragma unroll
for (int i = 0; i < elems; i++) { data[i] = conversion::to<T>(temp_data[i]); }
} else {
_warp<T, Ops...>(warp_arg, data);
}
}
#endif // defined(__HIP_PLATFORM_AMD__)

/*
Implementation for primary block reduction that serves both `block` and
`partitioned_block`.
@@ -600,7 +626,11 @@ DS_D_INLINE void _block(cg::thread_block& tb,
#endif

// Always perform warp-scope reduction
#ifdef __HIP_PLATFORM_AMD__
_warp_with_type_conversion<hw_warp_size, T, Ops...>(warp_arg, data);
#else
_warp<T, Ops...>(warp_arg, data);
#endif

// If max_warps == 1 let's skip the runtime check
if (total_warps != 1) {
@@ -624,8 +654,11 @@
} else {
init<Ops...>(data);
}

#ifdef __HIP_PLATFORM_AMD__
_warp_with_type_conversion<total_warps, T, Ops...>(warp_arg, data);
#else
_warp<T, Ops..., total_warps>(warp_arg, data);
#endif

#pragma unroll
for (int i = 0; i < elems; i++) {
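To make the ROCm-only helper above concrete (a minimal sketch under the same assumptions the diff already makes: cooperative groups, hw_warp_size and DS_D_INLINE from ds_kernel_utils.h, conversion::to<> from conversion_utils.h), this is the widen-reduce-narrow shape it applies to bf16 data, shown for a single max reduction:

    // Sketch only: promote bf16 to float for the warp shuffle, reduce, then demote.
    #include <cooperative_groups.h>
    #include "conversion_utils.h"
    #include "ds_kernel_utils.h"

    namespace cg = cooperative_groups;

    DS_D_INLINE __nv_bfloat16 warp_max_bf16(cg::thread_block_tile<hw_warp_size>& warp,
                                            __nv_bfloat16 v)
    {
        float f = conversion::to<float>(v);
    #pragma unroll
        for (int offset = hw_warp_size / 2; offset > 0; offset /= 2) {
            f = fmaxf(f, warp.shfl_xor(f, offset));  // float shuffles work on both platforms
        }
        return conversion::to<__nv_bfloat16>(f);
    }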
File renamed without changes.
(diff continues in another file; header not shown)
@@ -45,7 +45,8 @@ static void Kernel_Ex(cudaStream_t stream,
static size_t SHMEM_SZ =
max(TilingConfig::SMEM_SIZE_B_TILE + SMEM_SIZE_A1_TILE + SMEM_SIZE_A2_TILE,
TilingConfig::SMEM_SIZE_C_TILE);
-cudaFuncSetAttribute(QUANT_GEMM_Kernel<TilingConfig, OutputDataType>,
+auto kernel = QUANT_GEMM_Kernel<TilingConfig, OutputDataType>;
+cudaFuncSetAttribute(reinterpret_cast<const void*>(kernel),
cudaFuncAttributeMaxDynamicSharedMemorySize,
SHMEM_SZ);
size_t dimN = (N_Global - 1) / TilingConfig::TILE_N + 1;
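The rewrite above binds the kernel to a local variable and casts it explicitly to const void*, which matches the const void* signature that both cudaFuncSetAttribute and its HIP counterpart accept, instead of passing the kernel symbol directly. A minimal sketch of the same pattern (the kernel and config names below are placeholders, not taken from this PR):

    // Sketch only: raise the dynamic shared-memory limit for a templated kernel.
    #include <cuda_runtime.h>

    template <typename Cfg>
    __global__ void my_gemm_kernel(const float* a, const float* b, float* c) {}

    template <typename Cfg>
    cudaError_t raise_smem_limit(size_t shmem_bytes)
    {
        auto kernel = my_gemm_kernel<Cfg>;  // take the address once, as the diff does
        return cudaFuncSetAttribute(reinterpret_cast<const void*>(kernel),
                                    cudaFuncAttributeMaxDynamicSharedMemorySize,
                                    static_cast<int>(shmem_bytes));
    }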
4 changes: 3 additions & 1 deletion op_builder/fp_quantizer.py
@@ -18,6 +18,8 @@ class FPQuantizerBuilder(CUDAOpBuilder):
def __init__(self, name=None):
name = self.NAME if name is None else name
super().__init__(name=name)
if self.is_rocm_pytorch():
self.enable_bf16 = True

def absolute_name(self):
return f'deepspeed.ops.fp_quantizer.{self.NAME}_op'
@@ -90,7 +92,7 @@ def filter_ccs(self, ccs):
def sources(self):
return [
"csrc/fp_quantizer/fp_quantize.cu",
"csrc/fp_quantizer/fp_quantize.cpp",
"csrc/fp_quantizer/fp_quantize_api.cu",
]

def extra_ldflags(self):
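Setting enable_bf16 in the ROCm branch of the constructor matters because the C++/CUDA sources gate every bf16 path on the BF16_AVAILABLE macro; the assumption here (not shown in this diff) is that the shared CUDA/ROCm op builder turns enable_bf16 into that define at compile time. A minimal sketch of the gate as the sources see it:

    // Sketch only: the compile-time gate the enable_bf16 flag is assumed to control.
    #ifdef BF16_AVAILABLE
    #if defined(__HIP_PLATFORM_AMD__)
    #include <hip/hip_bf16.h>  // ROCm bf16 intrinsics, as added elsewhere in this PR
    #else
    #include <cuda_bf16.h>     // CUDA bf16 intrinsics
    #endif
    #endif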
4 changes: 3 additions & 1 deletion op_builder/transformer_inference.py
@@ -13,6 +13,8 @@ class InferenceBuilder(CUDAOpBuilder):
def __init__(self, name=None):
name = self.NAME if name is None else name
super().__init__(name=name)
if self.is_rocm_pytorch():
self.enable_bf16 = True

def absolute_name(self):
return f'deepspeed.ops.transformer.inference.{self.NAME}_op'
@@ -55,7 +57,7 @@ def filter_ccs(self, ccs):

def sources(self):
return [
-'csrc/transformer/inference/csrc/pt_binding.cpp',
+'csrc/transformer/inference/csrc/pt_binding.cu',
'csrc/transformer/inference/csrc/gelu.cu',
'csrc/transformer/inference/csrc/relu.cu',
'csrc/transformer/inference/csrc/layer_norm.cu',