diff --git a/Makefile b/Makefile
index 5f997a122..9b8063972 100644
--- a/Makefile
+++ b/Makefile
@@ -6,7 +6,11 @@ GPP:= /usr/bin/g++
 ifeq ($(CUDA_HOME),)
 	CUDA_HOME:= $(shell which nvcc | rev | cut -d'/' -f3- | rev)
 endif
+ifeq ($(ROCM_HOME),)
+	ROCM_HOME:= $(shell which hipcc | rev | cut -d'/' -f3- | rev)
+endif
 
+ifneq ($(CUDA_HOME),)
 ifndef CUDA_VERSION
 ifneq ($(MAKECMDGOALS),clean)
 $(warning WARNING: CUDA_VERSION not set. Call make with CUDA string, for example: make cuda11x CUDA_VERSION=115 or make cpuonly CUDA_VERSION=CPU)
@@ -14,9 +18,20 @@ CUDA_VERSION:=
 endif
 endif
+else ifneq ($(ROCM_HOME),)
+ifndef ROCM_TARGET
+$(error ERROR: ROCM_TARGET not set. Call make with a ROCm target string (see https://www.llvm.org/docs/AMDGPUUsage.html#processors), for example: make hip ROCM_TARGET=gfx1030)
+ROCM_TARGET:=
+endif
+else
+$(warning WARNING: Unable to find hipcc in PATH, falling back to ROCM_HOME /opt/rocm)
+ROCM_HOME:=/opt/rocm
+endif
+
 NVCC := $(CUDA_HOME)/bin/nvcc
+HIPCC:= $(ROCM_HOME)/bin/hipcc
 
 ###########################################
 
@@ -28,7 +43,8 @@ FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c
 INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/include
 LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcusparse -L $(CONDA_PREFIX)/lib
-
+HIP_INCLUDE := -I $(ROCM_HOME)/include -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/include
+HIP_LIB := -L $(ROCM_HOME)/lib -lhipblas -lhiprand -lhipsparse #-lhipblaslt, currently only gfx90a
 
 # NVIDIA NVCC compilation flags
 COMPUTE_CAPABILITY += -gencode arch=compute_50,code=sm_50 # Maxwell
 COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell
@@ -115,6 +131,12 @@ cuda12x: $(BUILD_DIR) env
 cpuonly: $(BUILD_DIR) env
 	$(GPP) -std=c++14 -shared -fPIC -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/include $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cpu.so
 
+hip: $(BUILD_DIR)
+	$(HIPCC) -std=c++14 -c -fPIC --offload-arch=$(ROCM_TARGET) $(HIP_INCLUDE) -o $(BUILD_DIR)/ops.o -DNO_CUBLASLT -DBNB_USE_HIP $(CSRC)/ops.cu
+	$(HIPCC) -std=c++14 -c -fPIC --offload-arch=$(ROCM_TARGET) $(HIP_INCLUDE) -o $(BUILD_DIR)/kernels.o -DNO_CUBLASLT -DBNB_USE_HIP $(CSRC)/kernels.cu
+	# HCC is deprecated, but the hipBLASLt header still references it. BLASLt itself is never called here, so the define only exists so the header compiles.
+	$(GPP) -std=c++14 -D__HIP_PLATFORM_HCC__ -D__HIP_PLATFORM_AMD__ -DBUILD_CUDA -DBNB_USE_HIP -shared -fPIC $(HIP_INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(FILES_CPP) $(HIP_LIB) -o ./bitsandbytes/libbitsandbytes_hip_nohipblaslt.so
+
 env:
 	@echo "ENVIRONMENT"
 	@echo "============================"
diff --git a/README.md b/README.md
index ebf40909f..fb3850a98 100644
--- a/README.md
+++ b/README.md
@@ -22,10 +22,14 @@ Python >=3.8. Linux distribution (Ubuntu, MacOS, etc.) + CUDA > 10.0.
 In some cases it can happen that you need to compile from source. If this happens please consider submitting a bug report with `python -m bitsandbytes` information. What now follows is some short instructions which might work out of the box if `nvcc` is installed. If these do not work see further below.
 
 Compilation quickstart:
+
 ```bash
 git clone https://github.com/timdettmers/bitsandbytes.git
 cd bitsandbytes
+```
+For CUDA
+```bash
 # CUDA_VERSIONS in {110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 120}
 # make argument in {cuda110, cuda11x, cuda12x}
 # if you do not know what CUDA you have, try looking at the output of: python -m bitsandbytes
@@ -33,6 +37,17 @@ CUDA_VERSION=117 make cuda11x
 python setup.py install
 ```
 
+For ROCm
+```bash
+# Requires ROCm 5.6+
+# Check whether your GPU supports Wave32 with: rocminfo | grep "Wavefront Size"
+# If this prints 64 instead of 32, this library will not work
+
+# Your ROCm target can be found with: rocminfo | grep gfx
+ROCM_TARGET=gfx1030 make hip
+pip install .
+```
+
 **Using Int8 inference with HuggingFace Transformers**
 
 ```python
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 19f224391..5d400b0e8 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -224,7 +224,7 @@ def backward(ctx, grad_output):
 
 def supports_igemmlt(device: torch.device) -> bool:
     """check if this device supports the optimized int8 kernel"""
-    if torch.cuda.get_device_capability(device=device) < (7, 5):
+    if torch.cuda.get_device_capability(device=device) < (7, 5) or torch.version.hip:
         return False
     device_name = torch.cuda.get_device_name(device=device)
     nvidia16_models = ('GTX 1630', 'GTX 1650', 'GTX 1660') # https://en.wikipedia.org/wiki/GeForce_16_series
diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/cuda_setup/main.py
index 34c035425..59931cae2 100644
--- a/bitsandbytes/cuda_setup/main.py
+++ b/bitsandbytes/cuda_setup/main.py
@@ -338,7 +338,9 @@ def evaluate_cuda_setup():
         cuda_setup.add_log_entry(('Welcome to bitsandbytes. For bug reports, please run\n\npython -m bitsandbytes\n\n'),
             ('and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues'))
         cuda_setup.add_log_entry('='*80)
+
     if not torch.cuda.is_available(): return 'libbitsandbytes_cpu.so', None, None, None
+    if torch.version.hip: return 'libbitsandbytes_hip_nohipblaslt.so', None, None, None
 
     cudart_path = determine_cuda_runtime_lib_path()
     ccs = get_compute_capabilities()
diff --git a/compile_from_source.md b/compile_from_source.md
index c2f97088d..cccc448f0 100644
--- a/compile_from_source.md
+++ b/compile_from_source.md
@@ -38,3 +38,16 @@ If you have problems compiling the library with these instructions from source,
 
 Since 0.39.1 bitsandbytes installed via pip no longer provides Kepler binaries and these need to be compiled from source. Follow the steps above and instead of `cuda11x_nomatmul` etc use `cuda11x_nomatmul_kepler`
 
+## Compilation with ROCm
+
+Since this library requires the hipBLASLt header, it only supports **ROCm 5.6+**.
+It works well with these Docker images:
+- [rocm/pytorch](https://hub.docker.com/r/rocm/pytorch)
+- [rocm/pytorch-nightly](https://hub.docker.com/r/rocm/pytorch-nightly)
+
+To install, run:
+```bash
+make hip ROCM_TARGET=gfx1030
+pip install .
+```
+See https://www.llvm.org/docs/AMDGPUUsage.html#processors to find your ROCM_TARGET (e.g. gfx1030 for the 6800 XT and 6900 XT), or run `rocminfo | grep gfx`.
\ No newline at end of file
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 9ebe0a69e..77447b6e0 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -4,6 +4,23 @@
 // LICENSE file in the root directory of this source tree.
 
 #include <kernels.cuh>
+
+#ifdef BNB_USE_HIP
+#include <hip/hip_runtime.h>
+#include <hipcub/hipcub.hpp>
+#include <hipcub/block/block_radix_sort.hpp>
+#include <hipcub/warp/warp_reduce.hpp>
+#include <hipcub/block/block_load.hpp>
+#include <hipcub/block/block_discontinuity.hpp>
+#include <hipcub/block/block_store.hpp>
+#include <hipcub/block/block_reduce.hpp>
+#include <hip/hip_math_constants.h>
+#define cub hipcub
+#define __syncwarp __syncthreads //HIP doesn't have this, so just sync threads
+
+#else
+#include <math_constants.h>
+#include <mma.h>
 #include <cub/block/block_radix_sort.cuh>
 #include <cub/warp/warp_reduce.cuh>
 #include <cub/block/block_load.cuh>
@@ -11,18 +28,17 @@
 #include <cub/block/block_discontinuity.cuh>
 #include <cub/block/block_store.cuh>
 #include <cub/block/block_reduce.cuh>
-#include <math_constants.h>
+#endif
+
 #include <thrust/host_vector.h>
 #include <thrust/device_vector.h>
-#include <mma.h>
-
 #define HLF_MAX 65504
 #define TH 1024
 #define NUM 4
 #define NUM_BLOCK 4096
-
+#ifndef BNB_USE_HIP
 // source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
 __device__ float atomicMax(float* address, float val) {
   int* address_as_i = reinterpret_cast<int*>(address);
@@ -47,6 +63,7 @@ __device__ float atomicMin(float* address, float val) {
   } while (assumed != old);
   return __int_as_float(old);
 }
+#endif
 
 __device__ float dDequantizeFP4(unsigned char val, float absmax)
 {
@@ -723,21 +740,28 @@ template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TYPE>
 __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n)
 {
+  const int CUB_NUM_PER_TH = NUM_PER_TH;
+  const int DATA_NUM_PER_TH = (DATA_TYPE > 0) ? NUM_PER_TH/2 : CUB_NUM_PER_TH;
+
   const int n_full = gridDim.x * BLOCK_SIZE;
   int valid_items = 0;
   const int base_idx = (blockIdx.x * BLOCK_SIZE);
 
-  T vals[NUM_PER_TH];
-  float rand_vals[NUM_PER_TH];
-  unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH];
+  T vals[CUB_NUM_PER_TH];
+  float rand_vals[CUB_NUM_PER_TH];
+  unsigned char qvals[DATA_NUM_PER_TH];
   //float local_abs_max = -FLT_MAX;
   float local_abs_max = 0.0f;
   int local_rand_idx = 0;
 
-  typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
-  typedef cub::BlockStore<unsigned char, BLOCK_SIZE/NUM_PER_TH, (DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
-  typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_PER_TH> BlockReduce;
-  typedef cub::BlockLoad<float, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
+  typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_PER_TH, CUB_NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
+  typedef cub::BlockStore<unsigned char, BLOCK_SIZE/NUM_PER_TH, DATA_NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
+  typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_PER_TH> BlockReduce;
+  typedef cub::BlockLoad<float, BLOCK_SIZE/NUM_PER_TH, CUB_NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
 
   __shared__ typename LoadT::TempStorage loadt;
   __shared__ typename LoadFloat::TempStorage loadf;
@@ -762,8 +786,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
   // 2. broadcast local max
   // 3. normalize inputs and quantize
 
-  #pragma unroll NUM_PER_TH
-  for(int j = 0; j < NUM_PER_TH; j++)
+  #pragma unroll CUB_NUM_PER_TH
+  for(int j = 0; j < CUB_NUM_PER_TH; j++)
     local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j]));
 
   local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, cub::Max(), valid_items);
@@ -792,8 +816,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
     switch(DATA_TYPE)
     {
       case General8bit:
-        #pragma unroll NUM_PER_TH
-        for(int j = 0; j < NUM_PER_TH; j++)
+        #pragma unroll CUB_NUM_PER_TH
+        for(int j = 0; j < CUB_NUM_PER_TH; j++)
         {
           if(!STOCHASTIC)
             qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max);
@@ -802,8 +826,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
         }
         break;
       case FP4:
-        #pragma unroll NUM_PER_TH
-        for(int j = 0; j < NUM_PER_TH/2; j++)
+        #pragma unroll CUB_NUM_PER_TH
+        for(int j = 0; j < DATA_NUM_PER_TH; j++)
         {
          packed_4bit |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4;
          packed_4bit |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max);
@@ -811,8 +835,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
        }
        break;
      case NF4:
-        #pragma unroll NUM_PER_TH
-        for(int j = 0; j < NUM_PER_TH/2; j++)
+        #pragma unroll CUB_NUM_PER_TH
+        for(int j = 0; j < DATA_NUM_PER_TH; j++)
        {
          packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4;
          packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max);
diff --git a/csrc/ops.cu b/csrc/ops.cu
index 97761216c..4a7c80328 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -5,12 +5,17 @@
 #include <ops.cuh>
 #include <kernels.cuh>
-#include <cub/device/device_scan.cuh>
 #include <limits>
 #include <BinSearch.h>
 #include <cassert>
 #include <common.h>
 
+#ifdef BNB_USE_HIP
+#include <hipcub/hipcub.hpp>
+#else
+#include <cub/device/device_scan.cuh>
+#endif
+
 using namespace BinSearch;
 using std::cout;
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
index f37b3b3af..87203edae 100644
--- a/csrc/ops.cuh
+++ b/csrc/ops.cuh
@@ -12,16 +12,62 @@
 #include <stdio.h>
 #include <iostream>
+
+#ifdef BNB_USE_HIP
+
+#include <hip/hip_runtime_api.h>
+#include <hipblas/hipblas.h>
+#include <hiprand/hiprand.h>
+#include <hipblaslt/hipblaslt.h> //only using header to allow redefines
+#include <hipsparse/hipsparse.h>
+
+#define cudaPeekAtLastError hipPeekAtLastError
+#define cudaMemset hipMemset
+#define cudaMemAttachHost hipMemAttachHost
+#define cudaMemPrefetchAsync hipMemPrefetchAsync
+#define cudaMallocManaged hipMallocManaged
+#define cudaDevAttrConcurrentManagedAccess hipDeviceAttributeConcurrentManagedAccess
+#define cudaDeviceGetAttribute hipDeviceGetAttribute
+#define cublasGemmEx hipblasGemmEx
+#define cublasStatus_t hipblasStatus_t
+#define CUBLAS_OP_T HIPBLAS_OP_T
+#define CUBLAS_OP_N HIPBLAS_OP_N
+#define CUDA_R_8I HIPBLAS_R_8I
+#define CUDA_R_32I HIPBLAS_R_32I
+#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
+#define cublasStatus_t hipblasStatus_t
+#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
+#define cublasOperation_t hipblasOperation_t
+#define cublasLtMatrixLayoutCreate hipblasLtMatrixLayoutCreate
+#define cudaError_t hipError_t
+#define cudaGetErrorString hipGetErrorString
+#define cudaSuccess hipSuccess
+#define cusparseStatus_t hipsparseStatus_t
+#define CUSPARSE_STATUS_SUCCESS HIPSPARSE_STATUS_SUCCESS
+#define cublasStatus_t hipblasStatus_t
+#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
+#define cublasHandle_t hipblasHandle_t
+#define cublasCreate_v2 hipblasCreate
+#define cusparseHandle_t hipsparseHandle_t
+#define cusparseCreate hipsparseCreate
+#define __nv_bfloat16 hip_bfloat16
+#define cublasLtHandle_t hipblasLtHandle_t
+#define cublasLtCreate hipblasLtCreate
+#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
+#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
+
+#else
 #include <cuda_runtime_api.h>
 #include <cuda_fp16.h>
 #include <cublas_v2.h>
 #include <cublasLt.h>
 #include <cusparse.h>
-#include <vector>
-#include <functional>
 #include <thrust/host_vector.h>
 #include <thrust/device_vector.h>
+#endif
+#include <vector>
+#include <functional>
diff --git a/include/Algo-Direct2.h b/include/Algo-Direct2.h
index d5fa58d12..f9387e5cd 100644
--- a/include/Algo-Direct2.h
+++ b/include/Algo-Direct2.h
@@ -93,8 +93,8 @@ struct AlgoVecBase<I, T, A, typename std::enable_if<DirectAux::IsDirect2<A>::val
         __m128 vxp = _mm_shuffle_ps(xp01, xp23, (1) + (3 << 2) + (1 << 4) + (3 << 6));
 #endif
         IVec<SSE,float> i(u.vec);
-        IVec<SSE,float> vlem = vz < vxm;
-        IVec<SSE,float> vlep = vz < vxp;
+        IVec<SSE,float> vlem = operator< (vz,vxm);
+        IVec<SSE,float> vlep = operator< (vz,vxp);
         i = i + vlem + vlep;
         i.store(pr);
     }
@@ -123,8 +123,8 @@ struct AlgoVecBase<I, T, A, typename std::enable_if<DirectAux::IsDirect2<A>::val
         __m128d vxp = _mm_shuffle_pd(vx0, vx1, 3);
 
         IVec<SSE,double> i(b1, b0);
-        IVec<SSE,double> vlem = (vz < vxm);
-        IVec<SSE,double> vlep = (vz < vxp);
+        IVec<SSE,double> vlem = operator< (vz, vxm);
+        IVec<SSE,double> vlep = operator< (vz, vxp);
         i = i + vlem + vlep;
 
         union {
@@ -227,8 +227,8 @@ struct AlgoVecBase<I, T, A, typename std::enable_if<DirectAux::IsDirect2<A>::val
 #endif
 
-        IVec<AVX,float> vlem = vz < vxm;
-        IVec<AVX,float> vlep = vz < vxp;
+        IVec<AVX,float> vlem = operator< (vz, vxm);
+        IVec<AVX,float> vlep = operator< (vz, vxp);
         ip = ip + vlem + vlep;
         ip.store(pr);
 
@@ -277,8 +277,8 @@ struct AlgoVecBase<I, T, A, typename std::enable_if<DirectAux::IsDirect2<A>::val
         // FVec<AVX,double> vxp = _mm256_insertf128_pd(_mm256_castpd128_pd256(h01p), h23p, 1);
 
         IVec<AVX,double> i(u.vec);
-        IVec<AVX,double> vlem = vz < vxm;
-        IVec<AVX,double> vlep = vz < vxp;
+        IVec<AVX,double> vlem = operator< (vz,vxm);
+        IVec<AVX,double> vlep = operator< (vz,vxp);
         i = i + vlem + vlep;
         i.extractLo32s().store(pr);
     }
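
Not part of the patch, but a useful smoke test once it is applied: the sketch below round-trips a tensor through the blockwise quantization kernels that `make hip` compiles into `libbitsandbytes_hip_nohipblaslt.so`. It assumes a ROCm build of PyTorch, where AMD GPUs are exposed through the `torch.cuda` API and `torch.version.hip` is set; the functions used are the public `bitsandbytes.functional` API, not anything added by this diff.

```python
# Post-install sanity check for the ROCm build (assumes ROCm PyTorch,
# where AMD GPUs appear under the "cuda" device type).
import torch
import bitsandbytes.functional as F

assert torch.cuda.is_available(), "no HIP device visible to PyTorch"
print("HIP runtime:", torch.version.hip)  # None on CUDA builds of PyTorch

x = torch.randn(4096, device="cuda")
# Round-trips through the HIP-compiled kQuantizeBlockwise/kDequantizeBlockwise kernels.
q, state = F.quantize_blockwise(x)
y = F.dequantize_blockwise(q, state)
print("max abs round-trip error:", (x - y).abs().max().item())  # should be small
```

Running `python -m bitsandbytes` afterwards should likewise report that `libbitsandbytes_hip_nohipblaslt.so` was selected, per the `evaluate_cuda_setup` change above.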