diff --git a/.github/scripts/build-rocm.sh b/.github/scripts/build-rocm.sh new file mode 100644 index 000000000..b508fac69 --- /dev/null +++ b/.github/scripts/build-rocm.sh @@ -0,0 +1,21 @@ +#!/bin/bash +declare build_arch +declare build_os +declare rocm_version + +set -xeuo pipefail +bnb_rocm_arch="gfx90a;gfx942;gfx1100" +if [ "${build_os:0:6}" == ubuntu ]; then + image=rocm/dev-ubuntu-22.04:${rocm_version}-complete + echo "Using image $image" + docker run --rm --platform "linux/$build_arch" -i \ + -w /src -v "$PWD:/src" "$image" sh -c \ + "apt-get update \ + && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \ + && cmake -DCOMPUTE_BACKEND=hip -DBNB_ROCM_ARCH=\"${bnb_rocm_arch}\" . \ + && cmake --build ." +fi + +output_dir="output/${build_os}/${build_arch}" +mkdir -p "${output_dir}" +(shopt -s nullglob && cp bitsandbytes/*.{so,dylib,dll} "${output_dir}") diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index d191342ab..0ccb10880 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -58,6 +58,7 @@ jobs: # This job matrix builds the CUDA versions of the libraries for platforms that support CUDA (Linux x64/aarch64 + Windows x64) ## build-shared-libs-cuda: + if: github.ref_name != 'multi-backend-refactor' strategy: fail-fast: false matrix: @@ -102,11 +103,55 @@ jobs: name: shared_library_cuda_${{ matrix.os }}_${{ matrix.arch }}_${{ matrix.cuda_version }} path: output/* retention-days: 7 - + build-shared-libs-rocm: + strategy: + matrix: + os: [ubuntu-latest] + arch: [x86_64] + rocm_version: + ["6.1.2", "6.2.4", "6.3.2"] + runs-on: ${{ matrix.os }} # One day, we could run them on native agents. Azure supports this now but it's planned only for Q3 2023 for hosted agents + steps: + - uses: actions/checkout@v4 + - name: Set up Docker multiarch + if: startsWith(matrix.os, 'ubuntu') + uses: docker/setup-qemu-action@v2 + - name: Clean up disk space + run: | + sudo rm -rf \ + /usr/share/dotnet \ + /opt/ghc \ + "/usr/local/share/boost" \ + "$AGENT_TOOLSDIRECTORY" \ + /opt/hostedtoolcache \ + /opt/google/chrome \ + /opt/microsoft/msedge \ + /opt/microsoft/powershell \ + /opt/pipx \ + /usr/lib/mono \ + /usr/local/julia* \ + /usr/local/lib/android \ + /usr/local/lib/node_modules \ + /usr/local/share/chromium \ + /usr/local/share/powershell \ + /usr/share/swift + - name: Build C++ + run: bash .github/scripts/build-rocm.sh + env: + build_os: ${{ matrix.os }} + build_arch: ${{ matrix.arch }} + rocm_version: ${{ matrix.rocm_version }} + - name: Upload build artifact + uses: actions/upload-artifact@v4 + with: + name: shared_library_rocm_${{ matrix.os }}_${{ matrix.arch }}_${{ matrix.rocm_version }} + path: output/* + retention-days: 7 build-wheels: needs: - build-shared-libs - - build-shared-libs-cuda + # - build-shared-libs-cuda reduce the pkg size + build times for the preview release + - build-shared-libs-rocm strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest] @@ -123,7 +168,16 @@ jobs: runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v4 - - name: Download build artifacts + with: + fetch-depth: 0 # Needed for setuptools_scm. 
+ # with: + # fetch-depth: 1 # shallow clone + # - name: Fetch tags for dynamic versioning in setup.py + # run: | + # git fetch --depth=1 origin --tags + # echo "Available Git tags:" + # git tag -n + - name: Download build artifact uses: actions/download-artifact@v4 with: merge-multiple: true @@ -140,7 +194,7 @@ jobs: python-version: ${{ matrix.python-version }} cache: pip - run: pip install build wheel - - run: python -m build . + - run: python -m build . -w - name: Determine and Set Platform Tag, then Tag Wheel shell: bash run: | @@ -157,7 +211,7 @@ jobs: upload-pre-release-wheels: name: Create release and upload artifacts runs-on: ubuntu-latest - if: github.ref_name == 'main' + if: github.ref_name == 'multi-backend-refactor' permissions: contents: write needs: @@ -188,8 +242,8 @@ jobs: with: files: wheels/*.whl prerelease: true - name: Latest `main` wheel - tag_name: continuous-release_main + name: Multi-Backend Preview + tag_name: continuous-release_multi-backend-refactor make_latest: false draft: false target_commitish: ${{ github.sha }} diff --git a/.gitignore b/.gitignore index aca1983d3..8e0fae35c 100644 --- a/.gitignore +++ b/.gitignore @@ -153,6 +153,8 @@ dmypy.json # vim *.swp +# BNB-specific stuff dependencies cuda_build output/ +bitsandbytes/_version.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8ac37502e..0a6a9e874 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,3 +21,4 @@ repos: rev: v1.26.0 hooks: - id: typos + exclude: ^.*\.hip$ diff --git a/CMakeLists.txt b/CMakeLists.txt index d16e94690..9d0628616 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,7 +3,7 @@ # For GCC: `cmake -B build . && cmake --build build` # For MSVC: `cmake -B build . && cmake --build build --config Release` # You can also use the following options and variables -# - COMPUTE_BACKEND: Set to `cpu`, `cuda`, or `mps` to select the backend +# - COMPUTE_BACKEND: Set to `cpu`, `cuda`, `hip`, `mps` or `npu` to select the backend # - CUDA_VERSION: The expected CUDA version, for sanity checking. The actual version # is whatever CMake finds on your path. # - COMPUTE_CAPABILITY: Which GPU Arch/Compute codes to provide to NVCC. 
@@ -25,13 +25,15 @@ endif() # Define included source files set(CPP_FILES csrc/common.cpp csrc/cpu_ops.cpp csrc/pythonInterface.cpp) set(CUDA_FILES csrc/ops.cu csrc/kernels.cu) +set(HIP_FILES csrc/ops.hip csrc/kernels.hip) set(MPS_FILES csrc/mps_ops.mm) set(METAL_FILES csrc/mps_kernels.metal) +set(NPU_FILES csrc/npu_ops.cpp) # C++ sources are always included list(APPEND SRC_FILES ${CPP_FILES}) -set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, mps)") -set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda mps) +set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps, npu)") +set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps npu) option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF) if(APPLE) @@ -47,15 +49,32 @@ if(${COMPUTE_BACKEND} STREQUAL "cuda") message(FATAL_ERROR "CUDA is not supported on macOS" ) endif() set(BUILD_CUDA ON) + set(BUILD_HIP OFF) + set(BUILD_MPS OFF) + message(STATUS "NO_CUBLASLT := ${NO_CUBLASLT}") +elseif(${COMPUTE_BACKEND} STREQUAL "hip") + if(APPLE) + message(FATAL_ERROR "HIP is not supported on macOS" ) + endif() + option(NO_CUBLASLT "Disable HIPBLASLT" OFF) + set(BUILD_CUDA OFF) + set(BUILD_HIP ON) set(BUILD_MPS OFF) elseif(${COMPUTE_BACKEND} STREQUAL "mps") if(NOT APPLE) message(FATAL_ERROR "MPS is only supported on macOS" ) endif() set(BUILD_CUDA OFF) + set(BUILD_HIP OFF) set(BUILD_MPS ON) +elseif(${COMPUTE_BACKEND} STREQUAL "npu") + set(BUILD_CUDA OFF) + set(BUILD_HIP OFF) + set(BUILD_MPS OFF) + set(BUILD_NPU ON) else() set(BUILD_CUDA OFF) + set(BUILD_HIP OFF) set(BUILD_MPS OFF) endif() @@ -175,6 +194,36 @@ if(BUILD_CUDA) string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}") add_compile_definitions(BUILD_CUDA) +elseif(BUILD_HIP) + enable_language(HIP) + message(STATUS "HIP Compiler: ${CMAKE_HIP_COMPILER}") + if(DEFINED BNB_ROCM_ARCH) + set(CMAKE_HIP_ARCHITECTURES ${BNB_ROCM_ARCH}) + else() + if (NOT AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES) + set(CMAKE_HIP_ARCHITECTURES "gfx90a;gfx942;gfx1100") + elseif (AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES) + set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS}) + endif() + endif() + message(STATUS "HIP Targets: ${CMAKE_HIP_ARCHITECTURES}") + + list(APPEND SRC_FILES ${HIP_FILES}) + + string(APPEND BNB_OUTPUT_NAME "_rocm") + + # get hip version + execute_process(COMMAND hipconfig --version OUTPUT_VARIABLE HIP_CONFIG_VERSION) + string(REGEX MATCH "[0-9]+\\.[0-9]+" HIP_VERSION "${HIP_CONFIG_VERSION}") + string(REPLACE "." "" HIP_VERSION_SHORT "${HIP_VERSION}") + + string(APPEND BNB_OUTPUT_NAME "${HIP_VERSION_SHORT}") + if(NO_CUBLASLT OR HIP_VERSION VERSION_LESS "6.1") + string(APPEND BNB_OUTPUT_NAME "_nohipblaslt") + endif() + add_compile_definitions(__HIP_PLATFORM_AMD__) + add_compile_definitions(__HIP_PLATFORM_HCC__) + add_compile_definitions(BUILD_HIP) elseif(BUILD_MPS) if(NOT APPLE) message(FATAL_ERROR "MPS is only supported on macOS" ) @@ -194,6 +243,33 @@ elseif(BUILD_MPS) COMMENT "Compiling Metal kernels" VERBATIM) add_custom_target(metallib DEPENDS "bitsandbytes/bitsandbytes.metallib") +elseif(BUILD_NPU) + list(APPEND SRC_FILES ${NPU_FILES}) + + set(SOC_VERSION "Ascend910B4" CACHE STRING "system on chip type") + set(ASCEND_CANN_PACKAGE_PATH $ENV{ASCEND_HOME_PATH} CACHE + STRING "ASCEND CAN package installation directory" + ) + + # ${KERNEL_FILES} are used to compile library, push files written by ascendc in ${KERNEL_FILES}. 
+ # ref to cmake/npu.cmake ascendc_library, cmake/cpu.cmake add_library + # file(GLOB KERNEL_FILES ${CMAKE_CURRENT_SOURCE_DIR}/csrc/npu_kernels.cpp) + file(GLOB KERNEL_FILES csrc/npu_kernels.cpp) + + if(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake) + set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake) + elseif(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/tools/tikcpp/ascendc_kernel_cmake) + set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/tools/tikcpp/ascendc_kernel_cmake) + else() + message(FATAL_ERROR "ascendc_kernel_cmake does not exist ,please check whether the can package is installed") + endif() + include(${ASCENDC_CMAKE_DIR}/ascendc.cmake) + + # ascendc_library use to add kernel file to generate ascendc library + ascendc_library(ascendc_kernels_npu STATIC ${KERNEL_FILES}) + + string(APPEND BNB_OUTPUT_NAME "_npu") + add_compile_definitions(BUILD_NPU) else() string(APPEND BNB_OUTPUT_NAME "_cpu") set(GPU_SOURCES) @@ -211,7 +287,11 @@ endif() set_source_files_properties(${CPP_FILES} PROPERTIES LANGUAGE CXX) add_library(bitsandbytes SHARED ${SRC_FILES}) -target_compile_features(bitsandbytes PUBLIC cxx_std_14) +if(BUILD_NPU) + target_compile_features(bitsandbytes PUBLIC cxx_std_17) +else() + target_compile_features(bitsandbytes PUBLIC cxx_std_14) +endif() target_include_directories(bitsandbytes PUBLIC csrc include) @@ -223,10 +303,49 @@ if(BUILD_CUDA) CUDA_SEPARABLE_COMPILATION ON ) endif() +if(BUILD_HIP) + if(NOT DEFINED ENV{ROCM_PATH}) + set(ROCM_PATH /opt/rocm) + else() + set(ROCM_PATH $ENV{ROCM_PATH}) + endif() + list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH}) + macro(find_package_and_print_version PACKAGE_NAME) + find_package("${PACKAGE_NAME}" ${ARGN}) + message("${PACKAGE_NAME} VERSION: ${${PACKAGE_NAME}_VERSION}") + endmacro() + find_package_and_print_version(hipblas REQUIRED) + find_package_and_print_version(hiprand REQUIRED) + find_package_and_print_version(hipsparse REQUIRED) + + ## hacky way of excluding hip::amdhip64 (with it linked many tests unexpectedly fail e.g. 
adam8bit because of inaccuracies) + set_target_properties(hip::host PROPERTIES INTERFACE_LINK_LIBRARIES "") + set_target_properties(hip-lang::host PROPERTIES INTERFACE_LINK_LIBRARIES "") + set(CMAKE_HIP_IMPLICIT_LINK_LIBRARIES "") + + target_include_directories(bitsandbytes PRIVATE ${CMAKE_SOURCE_DIR} ${CMAKE_SOURCE_DIR}/include ${ROCM_PATH}/include /include) + target_link_directories(bitsandbytes PRIVATE ${ROCM_PATH}/lib /lib) + target_link_libraries(bitsandbytes PUBLIC roc::hipblas hip::hiprand roc::hipsparse) + + target_compile_definitions(bitsandbytes PUBLIC BNB_USE_HIP) + set_source_files_properties(${HIP_FILES} PROPERTIES LANGUAGE HIP) + set_target_properties(bitsandbytes PROPERTIES LINKER_LANGUAGE CXX) + + if(NO_CUBLASLT OR HIP_VERSION VERSION_LESS "6.1") + target_compile_definitions(bitsandbytes PUBLIC NO_HIPBLASLT) + else() + find_package(hipblaslt) + target_link_libraries(bitsandbytes PUBLIC roc::hipblaslt) + endif() +endif() if(BUILD_MPS) add_dependencies(bitsandbytes metallib) target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph") endif() +if(BUILD_NPU) + target_compile_options(bitsandbytes PRIVATE -O2 -std=c++17) + target_link_libraries(bitsandbytes PRIVATE $ ascendc_kernels_npu) +endif() if(WIN32) set_target_properties(bitsandbytes PROPERTIES PREFIX "lib") diff --git a/_typos.toml b/_typos.toml index 955c6cb79..2ced2a049 100644 --- a/_typos.toml +++ b/_typos.toml @@ -3,6 +3,7 @@ [default] extend-ignore-re = [ "@Ther-nul", # valid Github user + "CANN", # CANN (Compute Architecture for Neural Networks) is a heterogeneous computing architecture for Ascend NPU ] extend-ignore-identifiers-re = [ ".*arange.*", @@ -11,6 +12,8 @@ extend-ignore-identifiers-re = [ [type.py.extend-words] "BA" = "BA" # used as a commented-out variable in tests +"cann" = "cann" # cann (Compute Architecture for Neural Networks) is a heterogeneous computing architecture for Ascend NPU + [type.cuda.extend-words] "subtile" = "subtile" diff --git a/benchmarking/generation_benchmark.py b/benchmarking/generation_benchmark.py new file mode 100755 index 000000000..a03bf7e83 --- /dev/null +++ b/benchmarking/generation_benchmark.py @@ -0,0 +1,67 @@ +import argparse + +import torch +import torch.utils.benchmark as benchmark +from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig + +parser = argparse.ArgumentParser() + +parser.add_argument( + "--model_name", default="meta-llama/Llama-3.1-8B-Instruct", required=False, type=str, help="model_name" +) +parser.add_argument("--quant_type", default="int8", type=str, help="quant type", choices=["int8", "nf4", "fp4"]) +parser.add_argument("--device_map", default="cpu", type=str, help="device_map", choices=["cpu", "xpu", "cuda"]) +args = parser.parse_args() + +model_name = args.model_name +device_map = args.device_map +if args.quant_type == "int8": + quantization_config = BitsAndBytesConfig(load_in_8bit=True) +else: + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type=args.quant_type, + bnb_4bit_use_double_quant=True, + bnb_4bit_compute_dtype=torch.bfloat16, + ) +quantized_model = AutoModelForCausalLM.from_pretrained( + model_name, torch_dtype="auto", device_map=device_map, quantization_config=quantization_config +) +tokenizer = AutoTokenizer.from_pretrained(model_name) +input_text = "What are we having for dinner?" 
+input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device) + +output = quantized_model.generate(**input_ids, max_new_tokens=10) +print(tokenizer.decode(output[0], skip_special_tokens=True)) + + +# benchmark the performance +def benchmark_fn(f, *args, **kwargs): + # Manual warmup + for _ in range(2): + f(*args, **kwargs) + + t0 = benchmark.Timer( + stmt="f(*args, **kwargs)", + globals={"args": args, "kwargs": kwargs, "f": f}, + num_threads=torch.get_num_threads(), + ) + return t0.blocked_autorange().mean + + +MAX_NEW_TOKENS = 100 + +quantized_model_latency = benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS) + +bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map, torch_dtype=torch.bfloat16) +bf16_model_latency = benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS) + +print(f"bnb model latency: {quantized_model_latency:.3f}") +print(f"bf16 model latency: {bf16_model_latency:.3f}") +print(f"BNB vs. bf16 model speed-up: {(bf16_model_latency / quantized_model_latency):.3f}") + +print(f"BNB model memory: {(quantized_model.get_memory_footprint() / 1024 / 1024 / 1024):.3f} GB") +print(f"bf16 model memory: {(bf16_model.get_memory_footprint() / 1024 / 1024 / 1024):.3f} GB") +print( + f"BNB vs. bf16 model memory ratio: {(bf16_model.get_memory_footprint() / quantized_model.get_memory_footprint()):.3f}" +) diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index c8530a76c..59f881cc9 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -3,6 +3,11 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +# Import the dynamically generated version from _version.py (see setup.py) +from ._version import __version__ # isort: skip # type: ignore + +import torch + from . import research, utils from .autograd._functions import ( MatmulLtState, @@ -12,13 +17,74 @@ matmul_cublas, mm_cublas, ) -from .nn import modules -from .optim import adam +from .backends import backends, register_backend +from .backends.cpu import CPUBackend +from .backends.npu import NPUBackend +from .cextension import lib + +features = {"multi_backend"} +supported_torch_devices = { + "cuda", # includes ROCm + "npu", # Ascend NPU + "xpu", # Intel GPU + "cpu", + "hpu", # Intel Gaudi +} + +# Always register the CPU backend. +register_backend("cpu", CPUBackend()) + +# Register HPU Backend, if available +if hasattr(torch, "hpu") and torch.hpu.is_available(): + from .backends.hpu import HPUBackend + + register_backend("hpu", HPUBackend()) + +# Register either CUDA or ROCm backend, if available. +# Only one of these backends can be used at a time, since the torch.device semantics are +# the same for both torch+rocm and torch+cuda (e.g. device name is "cuda") +if torch.cuda.is_available(): + # TODO: Consider deferring loading of cextension - should backend class implement that? + + if torch.version.cuda: + from .backends.cuda import CUDABackend + + register_backend("cuda", CUDABackend()) + elif torch.version.hip: + from .backends.rocm import ROCmBackend + + register_backend("cuda", ROCmBackend()) + +# Register MPS backend, if available. +if torch.backends.mps.is_available() and torch.backends.mps.is_built(): + from .backends.mps import MPSBackend + + register_backend("mps", MPSBackend()) + +# Register Intel XPU backend, if available. 
+if hasattr(torch, "xpu") and torch.xpu.is_available(): + from .backends.xpu import XPUBackend + + register_backend("xpu", XPUBackend()) + +# Register Ascend NPU backend, if available. +if hasattr(torch, "npu") and torch.npu.is_available(): + register_backend("npu", NPUBackend()) + + +# import module after decided backends +if backends: + from .nn import modules + +# TODO: Other potential backends: +# XLA - Google TPU / PJRT runtime +# HPU - Habana / Intel Gaudi +# IPU - Graphcore +# Note that we may not map 1:1 with a device type, e.g. SYCL, XLA +# In this case, it will be up to each backend to dispatch as needed __pdoc__ = { "libbitsandbytes": False, "optim.optimizer.Optimizer8bit": False, "optim.optimizer.MockArgs": False, } - -__version__ = "0.45.3.dev0" diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index f66cdf68d..d224cfe1c 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -7,6 +7,7 @@ import torch from typing_extensions import deprecated +from bitsandbytes.cextension import BNB_HIP_VERSION import bitsandbytes.functional as F # The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov: @@ -88,6 +89,9 @@ def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) - :param tile_indices: reverse transformation indices, from get_inverse_transform_indices :return: contiguous row-major tensor """ + # CPU has no change on layout + if permuted_tensor.device.type == "cpu": + return permuted_tensor (rows, cols), (tile_rows, tile_cols) = permuted_tensor.shape, tile_indices.shape assert rows % tile_rows == cols % tile_cols == 0, "tensor must contain a whole number of tiles" tensor = permuted_tensor.reshape(-1, tile_indices.numel()).t() @@ -216,6 +220,10 @@ def backward(ctx, grad_output): @deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) def supports_igemmlt(device: torch.device) -> bool: """check if this device supports the optimized int8 kernel""" + if device == torch.device("cpu") or torch.device("xpu"): + return True + if torch.version.hip: + return False if BNB_HIP_VERSION < 601 else True if torch.cuda.get_device_capability(device=device) < (7, 5): return False device_name = torch.cuda.get_device_name(device=device) @@ -311,8 +319,11 @@ def forward( input_shape = A.shape # Cast A to fp16 - if A.dtype != torch.float16: - warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization") + A_dtype = torch.float16 + if A.device == torch.device("cpu"): + A_dtype = torch.bfloat16 + if A.dtype != A_dtype: + warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to {A_dtype} during quantization") if len(A.shape) == 3: A = A.reshape(-1, A.shape[-1]) @@ -459,7 +470,12 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState] # 1. Dequantize # 2. MatmulnN - output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias) + if A.device.type == "npu": + output = torch.matmul(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t()) + if bias is not None: + output += bias + else: + output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias) # 3. Save state ctx.state = quant_state @@ -490,11 +506,37 @@ def backward(ctx, grad_output): # not supported by PyTorch. 
TODO: create work-around # if req_gradB: grad_B = torch.matmul(grad_output.t(), A) if req_gradA: - grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t()) + if grad_output.device.type == "npu": + grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype)) + else: + grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t()) return grad_A, grad_B, None, grad_bias, None +class MatMul8bitFp(torch.autograd.Function): + # For Intel CPU and XPU, the double quant has many unsafe operations which will breaks the finetune. + # We'd like to use dequant + matmul to run finetune currently. + + @staticmethod + def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): + CB = B.data.to(A.dtype).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)).t() + output = torch.matmul(A, CB).to(A.dtype) + ctx.state = state + ctx.dtype_A = A.dtype + ctx.grad_shape = A.shape + return output + + @staticmethod + def backward(ctx, grad_output): + state = ctx.state + B = state.CxB if state.CxB is not None else state.CB + CB = B.to(ctx.dtype_A).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) + grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) + + return grad_A, None, None, None, None + + def matmul( A: torch.Tensor, B: torch.Tensor, @@ -506,6 +548,8 @@ def matmul( state = state or MatmulLtState() if threshold > 0.0: state.threshold = threshold + if A.device.type in ("cpu", "xpu") and state.is_training: + return MatMul8bitFp.apply(A, B, out, bias, state) return MatMul8bitLt.apply(A, B, out, bias, state) @@ -517,8 +561,16 @@ def matmul_4bit( bias: Optional[torch.Tensor] = None, ): assert quant_state is not None - - if A.numel() == A.shape[-1] and A.requires_grad == False: + if A.device.type in ("cpu", "xpu") and A.requires_grad == False: + if getattr(quant_state, "ipex", False): + B = B.t() if len(B.shape) == 2 else B + out = F.gemv_4bit(A, B, out, state=quant_state) + if bias is not None: + out += bias + return out + else: + return MatMul4Bit.apply(A, B, out, bias, quant_state) + elif A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "npu": if A.shape[-1] % quant_state.blocksize != 0: warn( f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). 
Matrix input size found: {A.shape}", diff --git a/bitsandbytes/backends/__init__.py b/bitsandbytes/backends/__init__.py new file mode 100644 index 000000000..30f08073a --- /dev/null +++ b/bitsandbytes/backends/__init__.py @@ -0,0 +1,15 @@ +from typing import Dict + +from bitsandbytes.backends.base import Backend + +backends: Dict[str, Backend] = {} + + +def register_backend(backend_name: str, backend_instance: Backend): + backends[backend_name.lower()] = backend_instance + + +def ensure_backend_is_available(device_type: str): + """Check if a backend is available for the given device type.""" + if device_type.lower() not in backends: + raise NotImplementedError(f"Device backend for {device_type} is currently not supported.") diff --git a/bitsandbytes/backends/base.py b/bitsandbytes/backends/base.py new file mode 100644 index 000000000..69fabe1c6 --- /dev/null +++ b/bitsandbytes/backends/base.py @@ -0,0 +1,419 @@ +from abc import ABC, abstractmethod +from typing import Literal, Optional, Tuple + +import torch + +from bitsandbytes.utils import QuantState + + +class Backend(ABC): + """Base class for devices backends that will implement their own 8bits and 4bits functions.""" + + @abstractmethod + def int8_double_quant( + self, + A: torch.Tensor, + col_stats: Optional[torch.Tensor] = None, + row_stats: Optional[torch.Tensor] = None, + out_col: Optional[torch.Tensor] = None, + out_row: Optional[torch.Tensor] = None, + threshold=0.0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm. + + The statistics are determined both row-wise and column-wise (transposed). + + For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339). + + + This function is useful for training, but for inference it is advised to use [`int8_vectorwise_quant`] instead. + This implementation performs additional column-wise transposed calculations which are not optimized. + + + Args: + A (`torch.Tensor` with dtype `torch.float16`): The input matrix. + col_stats (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the column-wise quantization scales. + row_stats (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the row-wise quantization scales. + out_col (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the column-wise quantized data. + out_row (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the row-wise quantized data. + threshold (`float`, *optional*): + An optional threshold for sparse decomposition of outlier features. + + No outliers are held back when 0.0. Defaults to 0.0. + + Returns: + `Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`: A tuple containing the quantized tensor and relevant statistics. + - `torch.Tensor` with dtype `torch.int8`: The row-wise quantized data. + - `torch.Tensor` with dtype `torch.int8`: The column-wise quantized data. + - `torch.Tensor` with dtype `torch.float32`: The row-wise quantization scales. + - `torch.Tensor` with dtype `torch.float32`: The column-wise quantization scales. + - `torch.Tensor` with dtype `torch.int32`, *optional*: A list of column indices which contain outlier features. + """ + ... + + @abstractmethod + def int8_linear_matmul( + self, + A: torch.Tensor, + B: torch.Tensor, + out: Optional[torch.Tensor] = None, + dtype=torch.int32, + ) -> torch.Tensor: + """Performs an 8-bit integer matrix multiplication. 
+ + A linear transformation is applied such that `out = A @ B.T`. When possible, integer tensor core hardware is + utilized to accelerate the operation. + + Args: + A (`torch.Tensor`): The first matrix operand with the data type `torch.int8`. + B (`torch.Tensor`): The second matrix operand with the data type `torch.int8`. + out (`torch.Tensor`, *optional*): A pre-allocated tensor used to store the result. + dtype (`torch.dtype`, *optional*): The expected data type of the output. Defaults to `torch.int32`. + + Raises: + `NotImplementedError`: The operation is not supported in the current environment. + `RuntimeError`: Raised when the cannot be completed for any other reason. + + Returns: + `torch.Tensor`: The result of the operation. + """ + ... + + @abstractmethod + def int8_mm_dequant( + self, + A: torch.Tensor, + row_stats: torch.Tensor, + col_stats: torch.Tensor, + out: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Performs dequantization on the result of a quantized int8 matrix multiplication. + + Args: + A (`torch.Tensor` with dtype `torch.int32`): The result of a quantized int8 matrix multiplication. + row_stats (`torch.Tensor`): The row-wise quantization statistics for the lhs operand of the matrix multiplication. + col_stats (`torch.Tensor`): The column-wise quantization statistics for the rhs operand of the matrix multiplication. + out (`torch.Tensor`, *optional*): A pre-allocated tensor to store the output of the operation. + bias (`torch.Tensor`, *optional*): An optional bias vector to add to the result. + + Returns: + `torch.Tensor`: The dequantized result with an optional bias, with dtype `torch.float16`. + """ + ... + + @abstractmethod + def int8_vectorwise_dequant(self, A: torch.Tensor, stats: torch.Tensor): + """Dequantizes a tensor with dtype `torch.int8` to `torch.float32`. + + Args: + A (`torch.Tensor` with dtype `torch.int8`): The quantized int8 tensor. + stats (`torch.Tensor` with dtype `torch.float32`): The row-wise quantization statistics. + + Returns: + `torch.Tensor` with dtype `torch.float32`: The dequantized tensor. + """ + # To dequantize we divide by 127, or multiply by the reciprocal. + return A * stats.view(-1, 1) * 7.874015718698502e-3 + + @abstractmethod + def int8_vectorwise_quant(self, A: torch.Tensor, threshold=0.0): + """Quantizes a tensor with dtype `torch.float16` to `torch.int8` in accordance to the `LLM.int8()` algorithm. + + For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339). + + Args: + A (`torch.Tensor` with dtype `torch.float16`): The input tensor. + threshold (`float`, *optional*): + An optional threshold for sparse decomposition of outlier features. + + No outliers are held back when 0.0. Defaults to 0.0. + + Returns: + `Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`: A tuple containing the quantized tensor and relevant statistics. + - `torch.Tensor` with dtype `torch.int8`: The quantized data. + - `torch.Tensor` with dtype `torch.float32`: The quantization scales. + - `torch.Tensor` with dtype `torch.int32`, *optional*: A list of column indices which contain outlier features. + """ + ... 
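[Editorial aside, not part of the patch] The row-wise absmax scheme that `int8_vectorwise_quant` and `int8_vectorwise_dequant` document above can be sketched in plain PyTorch. The helper names below are illustrative only; the constant `1/127` is the reciprocal written as `7.874015718698502e-3` in the default implementation above.

```python
# Minimal sketch of row-wise (vectorwise) int8 quantization, assuming one
# float scale per row equal to that row's absolute maximum.
import torch

def vectorwise_quant_sketch(A: torch.Tensor):
    # One scale per row; guard against all-zero rows.
    stats = A.abs().amax(dim=1, keepdim=True).float().clamp_min(1e-8)
    quantized = torch.clamp(torch.round(A * (127.0 / stats)), -128, 127).to(torch.int8)
    return quantized, stats.squeeze(1)

def vectorwise_dequant_sketch(quantized: torch.Tensor, stats: torch.Tensor):
    # Multiplying by stats/127 undoes the scaling applied above.
    return quantized * stats.view(-1, 1) * (1.0 / 127.0)

A = torch.randn(4, 8, dtype=torch.float16)
q, s = vectorwise_quant_sketch(A)
A_hat = vectorwise_dequant_sketch(q, s)
print((A.float() - A_hat).abs().max())  # small round-trip error, bounded by stats / 254
```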
+ + @abstractmethod + def transform( + self, + A: torch.Tensor, + to_order: str, + from_order="row", + out: Optional[torch.Tensor] = None, + transpose=False, + state: Optional[Tuple[torch.Size, str]] = None, + ld=None, + ): + raise NotImplementedError + + @abstractmethod + def extract_outliers( + self, + A: torch.Tensor, + SA: Tuple[torch.Size, str], + idx: torch.Tensor, + ) -> torch.Tensor: + raise NotImplementedError + + @abstractmethod + def quantize_4bit( + self, + A: torch.Tensor, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=64, + compress_statistics=False, + quant_type: Literal["fp4", "nf4"] = "fp4", + quant_storage=torch.uint8, + ) -> Tuple[torch.Tensor, QuantState]: + """Quantize tensor A in blocks of 4-bit values. + + Quantizes tensor A by dividing it into blocks which are independently quantized. + + Args: + A (`torch.Tensor`): The input tensor. Supports `float16`, `bfloat16`, or `float32` datatypes. + absmax (`torch.Tensor`, *optional*): A tensor to use to store the absmax values. + out (`torch.Tensor`, *optional*): A tensor to use to store the result. + blocksize (`int`, *optional*): + The size of the blocks. Defaults to 64. + Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. + compress_statistics (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False. + quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`. + quant_storage (`torch.dtype`, *optional*): The dtype of the tensor used to store the result. Defaults to `torch.uint8`. + + Raises: + ValueError: Raised when the input data type is not supported. + + Returns: + Tuple[`torch.Tensor`, `QuantState`]: A tuple containing the quantization results. + - `torch.Tensor`: The quantized tensor with packed 4-bit values. + - [`QuantState`]: The state object used to undo the quantization. + """ + ... + + @abstractmethod + def dequantize_4bit( + self, + A: torch.Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 64, + quant_type: Literal["fp4", "nf4"] = "fp4", + ) -> torch.Tensor: + """Dequantizes a packed 4-bit quantized tensor. + + The input tensor is dequantized by dividing it into blocks of `blocksize` values. + The the absolute maximum value within these blocks is used for scaling + the non-linear dequantization. + + Args: + A (`torch.Tensor`): The quantized input tensor. + quant_state ([`QuantState`], *optional*): + The quantization state as returned by [`quantize_4bit`]. + Required if `absmax` is not provided. + absmax (`torch.Tensor`, *optional*): + A tensor containing the scaling values. + Required if `quant_state` is not provided and ignored otherwise. + out (`torch.Tensor`, *optional*): A tensor to use to store the result. + blocksize (`int`, *optional*): + The size of the blocks. Defaults to 64. + Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. + quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`. + + Raises: + ValueError: Raised when the input data type or blocksize is not supported. + + Returns: + `torch.Tensor`: The dequantized tensor. + """ + ... + + @abstractmethod + def gemv_4bit( + self, + A: torch.Tensor, + B: torch.Tensor, + out: Optional[torch.Tensor] = None, + transposed_A=False, + transposed_B=False, + state: QuantState = None, + ) -> torch.Tensor: ... 
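[Editorial aside, not part of the patch] The `backends` dict and `register_backend()` in `bitsandbytes/backends/__init__.py`, together with the `Backend` ABC in this file, form a device-keyed registry. The toy below shrinks the interface to a single method purely to illustrate the dispatch pattern; `ToyBackend` and `ToyCPUBackend` are made-up names, not bitsandbytes classes.

```python
# Toy sketch of the backend-registry pattern: an abstract interface, a dict
# keyed by torch.device type, and an availability check.
from abc import ABC, abstractmethod
from typing import Dict

class ToyBackend(ABC):
    @abstractmethod
    def name(self) -> str: ...

_registry: Dict[str, ToyBackend] = {}

def register_backend(device_type: str, backend: ToyBackend) -> None:
    # Keyed by torch.device type, e.g. "cpu", "cuda" (also used for ROCm), "npu".
    _registry[device_type.lower()] = backend

def ensure_backend_is_available(device_type: str) -> None:
    if device_type.lower() not in _registry:
        raise NotImplementedError(f"Device backend for {device_type} is currently not supported.")

class ToyCPUBackend(ToyBackend):
    def name(self) -> str:
        return "cpu"

register_backend("cpu", ToyCPUBackend())
ensure_backend_is_available("cpu")  # passes; unknown device types raise NotImplementedError
```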
+ + @abstractmethod + def quantize_blockwise( + self, + A: torch.Tensor, + code: Optional[torch.Tensor] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=4096, + nested=False, + ) -> Tuple[torch.Tensor, QuantState]: + """Quantize a tensor in blocks of values. + + The input tensor is quantized by dividing it into blocks of `blocksize` values. + The the absolute maximum value within these blocks is calculated for scaling + the non-linear quantization. + + Args: + A (`torch.Tensor`): The input tensor. Supports `float16`, `bfloat16`, or `float32` datatypes. + code (`torch.Tensor`, *optional*): + A mapping describing the low-bit data type. Defaults to a signed 8-bit dynamic type. + For more details, see (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561]. + absmax (`torch.Tensor`, *optional*): A tensor to use to store the absmax values. + out (`torch.Tensor`, *optional*): A tensor to use to store the result. + blocksize (`int`, *optional*): + The size of the blocks. Defaults to 4096. + Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. + nested (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False. + + Raises: + ValueError: Raised when the input data type is not supported. + + Returns: + `Tuple[torch.Tensor, QuantState]`: A tuple containing the quantization results. + - `torch.Tensor`: The quantized tensor. + - [`QuantState`]: The state object used to undo the quantization. + """ + ... + + @abstractmethod + def dequantize_blockwise( + self, + A: torch.Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + code: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 4096, + nested=False, + ) -> torch.Tensor: + """Dequantize a tensor in blocks of values. + + The input tensor is dequantized by dividing it into blocks of `blocksize` values. + The the absolute maximum value within these blocks is used for scaling + the non-linear dequantization. + + Args: + A (`torch.Tensor`): The quantized input tensor. + quant_state ([`QuantState`], *optional*): + The quantization state as returned by [`quantize_blockwise`]. + Required if `absmax` is not provided. + absmax (`torch.Tensor`, *optional*): + A tensor containing the scaling values. + Required if `quant_state` is not provided and ignored otherwise. + code (`torch.Tensor`, *optional*): + A mapping describing the low-bit data type. Defaults to a signed 8-bit dynamic type. + For more details, see (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561]. + Ignored when `quant_state` is provided. + out (`torch.Tensor`, *optional*): A tensor to use to store the result. + blocksize (`int`, *optional*): + The size of the blocks. Defaults to 4096. + Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. + Ignored when `quant_state` is provided. + + Raises: + ValueError: Raised when the input data type is not supported. + + Returns: + `torch.Tensor`: + The dequantized tensor. The datatype is indicated by `quant_state.dtype` and defaults to `torch.float32`. + """ + ... 
+ + @abstractmethod + def optimizer_update_8bit_blockwise( + self, + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + beta1: float, + beta2: float, + beta3: float, + alpha: float, + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: Optional[torch.Tensor], + absmax1: torch.Tensor, + absmax2: Optional[torch.Tensor], + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, + ) -> None: + """ + Performs an in-place optimizer update with one or two optimizer states. + + Args: + optimizer_name (`str`): The name of the optimizer, e.g. `adam` + g (`torch.Tensor`): Gradient tensor. + p (`torch.Tensor`): Parameter tensor. + state1 (`torch.Tensor`): Optimizer state 1. + state2 (`torch.Tensor`, optional): Optimizer state 2. + beta1 (`float`): Optimizer beta1. + beta2 (`float`): Optimizer beta2. + eps (`float`): Optimizer epsilon. + step (`int`): Current optimizer step. + lr (`float`): The learning rate. + qmap1 (`torch.Tensor`): Quantization map for the first state. + qmap2 (`torch.Tensor`, optional): Quantization map for the second state. + absmax1 (`torch.Tensor`): Max value for the first state update. + absmax2 (`torch.Tensor`, optional): Max value for the second state update. + weight_decay (`float`, optional): Weight decay. Defaults to 0.0. + gnorm_scale (`float`, optional): The factor to rescale the gradient to the max clip value. Defaults to 1.0. + skip_zeros (`bool`, optional): Whether to skip zero-valued gradients or not. Defaults to False. + """ + raise NotImplementedError + + @abstractmethod + def optimizer_update_32bit( + self, + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + beta1: float, + eps: float, + step: int, + lr: float, + state2: Optional[torch.Tensor] = None, + beta2: float = 0.0, + beta3: float = 0.0, + alpha: float = 0.0, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + unorm_vec: Optional[torch.Tensor] = None, + max_unorm: float = 0.0, + skip_zeros=False, + ) -> None: + """ + Performs an in-place optimizer update with one or two optimizer states. + + Universal optimizer update for 32-bit state and 32/16-bit gradients/weights + + Args: + optimizer_name (`str`): The name of the optimizer, e.g. `adam` + g (`torch.Tensor`): Gradient tensor. + p (`torch.Tensor`): Parameter tensor. + state1 (`torch.Tensor`): Optimizer state 1. + beta1 (`float`): Optimizer beta1. + eps (`float`): Optimizer epsilon. + step (`int`): Current optimizer step. + lr (`float`): The learning rate. + state2 (`torch.Tensor`, optional): Optimizer state 2. Defaults to None. + beta2 (`float`, optional): Optimizer beta2. Defaults to 0.0. + weight_decay (`float`, optional): Defaults to 0.0. + gnorm_scale (`float`, optional): The factor to rescale the gradient to the max clip value. Defaults to 1.0. + unorm_vec (`torch.Tensor`, optional): The tensor for the update norm. Defaults to None. + max_unorm (`float`, optional): The maximum update norm relative to the weight norm.. Defaults to 0.0. + skip_zeros (`bool`, optional): Whether to skip zero-valued gradients or not. Defaults to False. 
+ """ + raise NotImplementedError diff --git a/bitsandbytes/backends/cpu.py b/bitsandbytes/backends/cpu.py new file mode 100644 index 000000000..afe71c080 --- /dev/null +++ b/bitsandbytes/backends/cpu.py @@ -0,0 +1,241 @@ +from typing import Literal, Optional, Tuple + +import torch + +from bitsandbytes.utils import QuantState + +from .base import Backend +from .cpu_xpu_common import ( + dequantize_4bit_impl, + double_quant_impl, + gemm_4bit_impl, + int8_linear_matmul_impl, + int8_mm_dequant_impl, + quantize_4bit_impl, +) + +Tensor = torch.Tensor + + +def assert_on_cpu(tensors): + on_cpu = True + for t in tensors: + if t is None: + continue # NULL pointers are fine + on_cpu &= t.device.type == "cpu" + if not on_cpu: + raise TypeError( + "All input tensors need to be on CPU, but found some tensors to not be on CPU:\n" + f" {[(t.shape, t.device) if isinstance(t, Tensor) else None for t in tensors]}" + ) + return on_cpu + + +class CPUBackend(Backend): + mm_dequant_compute_dtype = torch.bfloat16 + mm_dequant_output_dtype = torch.bfloat16 + + def device_synchronize(self): + pass + + def int8_double_quant( + self, + A: torch.Tensor, + col_stats: Optional[torch.Tensor] = None, + row_stats: Optional[torch.Tensor] = None, + out_col: Optional[torch.Tensor] = None, + out_row: Optional[torch.Tensor] = None, + threshold=0.0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + assert_on_cpu([A, col_stats, row_stats, out_col, out_row]) + return double_quant_impl(A, col_stats, row_stats, out_col, out_row, threshold) + + def transform( + self, + A: torch.Tensor, + to_order: str, + from_order="row", + out: Optional[torch.Tensor] = None, + transpose=False, + state: Optional[Tuple[torch.Size, str]] = None, + ld=None, + ): + """ + Transform tensor A to to_order. It is originally designed for CUDA. + For CPU, it returns the original tensor if transpose=False. + Otherwise, it returns the transpose of A + """ + assert_on_cpu([A, out]) + if transpose: + if out is not None: + out.copy_(A.T) + else: + out = A.T + else: + if out is not None: + out.copy_(A) + else: + out = A + return out, state + + def int8_linear_matmul( + self, + A: torch.Tensor, + B: torch.Tensor, + out: Optional[torch.Tensor] = None, + dtype=torch.int32, + ) -> torch.Tensor: + assert_on_cpu([A, B]) + return int8_linear_matmul_impl(A, B, out, dtype) + + def int8_mm_dequant( + self, + A: torch.Tensor, + row_stats: torch.Tensor, + col_stats: torch.Tensor, + out: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert_on_cpu([A, row_stats, col_stats, out, bias]) + return int8_mm_dequant_impl( + A, + row_stats, + col_stats, + out, + bias, + self.mm_dequant_compute_dtype, + self.mm_dequant_output_dtype, + ) + + def int8_vectorwise_dequant(self, A, stats): + return super().int8_vectorwise_dequant(A, stats) + + def int8_vectorwise_quant(self, A: torch.Tensor, threshold=0.0): + # TODO: We can optimize this as we don't actually need column-wise quant. 
+ out, _, stats, _, outlier_cols = self.int8_double_quant(A, threshold=threshold) + return out, stats, outlier_cols + + def extract_outliers( + self, + A: torch.Tensor, + SA: Tuple[torch.Size, str], + idx: torch.Tensor, + ) -> torch.Tensor: + """ + Extract columns of A by idx + """ + assert_on_cpu([A]) + return A[:, idx].contiguous() + + def quantize_4bit( + self, + A: torch.Tensor, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=64, + compress_statistics=False, + quant_type: Literal["fp4", "nf4"] = "fp4", + quant_storage=torch.uint8, + ) -> Tuple[torch.Tensor, QuantState]: + if blocksize is None: + blocksize = 64 + assert_on_cpu([A, absmax, out]) + return quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type, quant_storage) + + def dequantize_4bit( + self, + A: torch.Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 64, + quant_type: Literal["fp4", "nf4"] = "fp4", + ) -> torch.Tensor: + if blocksize is None: + blocksize = 64 + assert_on_cpu([A, absmax, out]) + return dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type) + + def gemv_4bit( + self, + A: torch.Tensor, + B: torch.Tensor, + out: Optional[torch.Tensor] = None, + transposed_A=False, + transposed_B=False, + state: QuantState = None, + ) -> torch.Tensor: + assert_on_cpu([A, B, out]) + if state is None: + raise ValueError("state cannot be None. gemv_4bit() requires the state from quantize_4bit()") + + return gemm_4bit_impl(A, B, out, transposed_A, transposed_B, state) + + def dequantize_blockwise( + self, + A: torch.Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + code: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 4096, + nested=False, + ) -> torch.Tensor: + raise NotImplementedError("Not yet implemented for CPU backend") + + def quantize_blockwise( + self, + A: torch.Tensor, + code: Optional[torch.Tensor] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=4096, + nested=False, + ) -> Tuple[torch.Tensor, QuantState]: + raise NotImplementedError("Not yet implemented for CPU backend") + + def optimizer_update_8bit_blockwise( + self, + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + beta1: float, + beta2: float, + beta3: float, + alpha: float, + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: Optional[torch.Tensor], + absmax1: torch.Tensor, + absmax2: Optional[torch.Tensor], + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, + ) -> None: + raise NotImplementedError("Not yet implemented for CPU backend") + + def optimizer_update_32bit( + self, + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + beta1: float, + eps: float, + step: int, + lr: float, + state2: Optional[torch.Tensor] = None, + beta2: float = 0.0, + beta3: float = 0.0, + alpha: float = 0.0, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + unorm_vec: Optional[torch.Tensor] = None, + max_unorm: float = 0.0, + skip_zeros=False, + ) -> None: + raise NotImplementedError("Not yet implemented for CPU backend") diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py new file mode 100644 index 000000000..22e2563d9 --- /dev/null +++ 
b/bitsandbytes/backends/cpu_xpu_common.py @@ -0,0 +1,591 @@ +import subprocess +from typing import Optional +import warnings + +import torch +import torch.nn.functional as F + +from bitsandbytes.functional import ( + QuantState, + create_dynamic_map, + get_4bit_type, +) +from bitsandbytes.utils import reverse_4bit_compress_format + +try: + # to support Intel CPU/GPU (XPU) backend + import intel_extension_for_pytorch as ipex + + ipex_cpu = ipex if ipex._C._has_cpu() else None + ipex_xpu = ipex if ipex._C._has_xpu() else None + ipex_cpu_only = ipex._C._has_cpu() and (not ipex._C._has_xpu()) +except BaseException: + ipex_cpu = None + ipex_xpu = None + ipex_cpu_only = None + + +gxx_available = False +try: + subprocess.run(["g++", "--version"], capture_output=True) # hide terminal output + gxx_available = True +except BaseException: + warnings.warn("g++ not found, torch.compile disabled for CPU/XPU.") + + +Tensor = torch.Tensor + + +def _torch_version_prereq(major, minor): + ver_major = int(torch.__version__.split(".")[0]) + ver_minor = int(torch.__version__.split(".")[1]) + return ver_major * 32 + ver_minor >= major * 32 + minor + + +def _ipex_cpu_version_prereq(major, minor): + if ipex_cpu is not None: + ver_major = ipex_cpu.__version__.split(".")[0] + ver_minor = ipex_cpu.__version__.split(".")[1] + return int(ver_major) * 32 + int(ver_minor) >= major * 32 + minor + return False + + +def _ipex_xpu_version_prereq(major, minor): + if ipex_xpu is not None: + ver_major = ipex_xpu.__version__.split(".")[0] + ver_minor = ipex_xpu.__version__.split(".")[1] + return int(ver_major) * 32 + int(ver_minor) >= major * 32 + minor + return False + + +def _maybe_torch_compile(func): + # torch.compile requires g++ and pytorch >= 2.0 + if gxx_available and _torch_version_prereq(2, 0) and ipex_cpu_only: + options = {} + # fx_graph_cache requires pytorch >= 2.2 + if _torch_version_prereq(2, 2): + options.update({"fx_graph_cache": True}) + return torch.compile(func, dynamic=True, options=options) + return func + + +@_maybe_torch_compile +def double_quant_impl(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): + """ + Find absolute max values of each row/column of a tensor, and symmetrically quantize it to int8. + If threshold > 0.0, only values <= threshold are counted. All outliers are zeroed out in + the original tensor and they are kept in COO format: (rows, cols, values) + If threshold == 0.0, there are no outliers. + Args: + A The tensor to be analyzed and quantized. + col_stats Absolute max values of each column of A. If it is not None, use the values directly. + Otherwise, find the values. + row_stats Absolute max values of each row of A. If it is not None, use the values directly. + Otherwise, find the values. + out_col Output buffer for the result quantized per column if it is not None + out_row Output buffer for the result quantized per row if it is not None + threshold The threshold for finding outliers if it is > 0.0. Otherwise it has no effect. 
+ Return: + A tuple of output quantized per row, output quantized per column, absolute max values of + each row of A, absolute max values of each column of A, outliers in COO format + """ + + cols = A.shape[-1] + if len(A.shape) == 3: + rows = A.shape[0] * A.shape[1] + else: + assert A.dim() == 2, f"double_quant: Input tensor should be 2d or 3d but got {A.dim()}d" + rows = A.shape[0] + A = A.reshape(rows, cols) + + def get_row_col_stats(A): + row_stats = torch.max(torch.abs(A), 1).values # absolute max of each row + col_stats = torch.max(torch.abs(A), 0).values # absolute max of each col + return row_stats, col_stats + + def quant_to_int8(A, stats): + return torch.clamp(torch.round(A * (127.0 / stats)), -128, 127).to(torch.int8) + + if threshold == 0.0: + if row_stats is None or col_stats is None: + row_stats, col_stats = get_row_col_stats(A) + outlier_cols = None + else: + outlier_indices = torch.abs(A) >= threshold # find outliers + outlier_cols = torch.argwhere(outlier_indices.any(dim=0)).view(-1) + outlier_values = A[outlier_indices].clone() + + # outlier_indices = torch.abs(A) >= threshold # find outliers + # outlier_coord = outlier_indices.nonzero() # get outlier coordinates + # outlier_rows = outlier_coord[:, 0] # outlier row for COO sparse tensor + # outlier_cols = outlier_coord[:, 1] # outlier column for COO sparse tensor + # outlier_values = A[outlier_indices] # outlier values for COO sparse tensor + # coo_tensor = COOSparseTensor( + # A.shape[0], A.shape[1], outlier_values.numel(), outlier_rows.int(), outlier_cols.int(), outlier_values + # ) + if row_stats is None or col_stats is None: + A[outlier_indices] = 0 # zero out outliers + row_stats, col_stats = get_row_col_stats(A) + + quant_by_row = quant_to_int8(A, row_stats.unsqueeze(-1)) + quant_by_col = quant_to_int8(A, col_stats.unsqueeze(0)) + + if outlier_cols is not None: + A[outlier_indices] = outlier_values # restore outliers for later use + + if rows > 1: + # zero out outlier columns for all rows + quant_by_row[:, outlier_cols] = 0 + + if out_row is not None: + out_row.copy_(quant_by_row) + else: + out_row = quant_by_row + if out_col is not None: + out_col.copy_(quant_by_col) + else: + out_col = quant_by_col + # Return float stats to align with CUDA impl + return out_row, out_col, row_stats.float(), col_stats.float(), outlier_cols + + +def int8_linear_matmul_impl( + A: torch.Tensor, + B: torch.Tensor, + out: Optional[torch.Tensor] = None, + dtype=torch.int32, +) -> torch.Tensor: + """ + Do GEMMM computation. Data type: int8 * int8 -> int32. 
+ Args: + A Activation of linear, data type is int8 + B Weight of linear, data type is int8 + out Specified output tensor if it is not None + dtype Data type of output + Return: + A tuple of GEMM result in dtype and Sout + """ + + assert A.dtype == torch.int8 + assert B.dtype == torch.int8 + if out is not None: + assert out.dtype == dtype + + dimsA = A.ndim + dimsB = B.ndim + shapeA = A.shape + shapeB = B.shape + assert dimsA in [2, 3], "Only two or three dimensional matrices are supported for argument A" + assert dimsB == 2, "Only two dimensional matrices are supported for argument B" + + if dimsA == 2: + m = shapeA[0] + elif dimsA == 3: + m = shapeA[0] * shapeA[1] + n = shapeB[0] + k = shapeA[-1] + assert shapeA[-1] == shapeB[-1], f"Shapes of A and B do not match, got {shapeA} and {shapeB}" + + # if the tensor is empty, return a transformed empty tensor with the right dimensions + if shapeA[0] == 0 and dimsA == 2: + return torch.empty((0, n), device=A.device, dtype=A.dtype) + elif shapeA[1] == 0 and dimsA == 3: + return torch.empty(tuple(shapeA[:2] + [n]), device=A.device, dtype=A.dtype) + + A_reshaped = A.reshape(m, k) + + # torch._int_mm is available on CPU since torch 2.4, XPU since torch 2.6 + if (A.device.type == "cpu" and _torch_version_prereq(2, 4)) or ( + A.device.type == "xpu" and _torch_version_prereq(2, 6) + ): + C = torch._int_mm(A_reshaped, B.T).to(dtype) + else: + C = torch.matmul(A_reshaped.float(), B.t().float()).to(dtype) + if C.ndim != dimsA: + assert dimsA == 3 + shapeOut = (shapeA[0], m // shapeA[0], C.shape[-1]) + C = C.reshape(shapeOut) + if out is not None: + out.copy_(C) + else: + out = C + + return out + + +@_maybe_torch_compile +def int8_mm_dequant_impl( + A: torch.Tensor, + row_stats: torch.Tensor, + col_stats: torch.Tensor, + out: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + compute_dtype=torch.float32, + output_dtype=torch.float32, +) -> torch.Tensor: + """ + Dequant and add bias + out = A_int32 * (abs_max_A * abs_max_B) / 127 * 127 + bias + Args: + A The output of int8 gemm, whose dtype is int32 + row_stats Absolute max value of each row of input (A) of gemm + col_stats Absolute max value of each row of weight (B) of gemm + out Output buffer + bias Bias of linear + compute_dtype Data type for computation + output_dtype Data type for output + Return: + The result + """ + assert A.dtype == torch.int32 + out_shape = A.shape + if len(out_shape) == 3: + out_shape = (out_shape[0] * out_shape[1], out_shape[2]) + + if compute_dtype not in [torch.float32, torch.bfloat16]: + warnings.warn( + f"mm_dequant_{A.device}: compute_dtype {compute_dtype} is not supported, will use bfloat16 instead" + ) + compute_dtype = torch.bfloat16 + A_reshaped = A.reshape(out_shape).to(compute_dtype) + row_stats = row_stats.reshape(-1).unsqueeze(-1).to(compute_dtype) + col_stats = col_stats.reshape(-1).unsqueeze(0).to(compute_dtype) + out = A_reshaped * row_stats * col_stats / (127 * 127) + if bias is not None: + out = out + bias.to(compute_dtype) + out = out.to(output_dtype) + return out + + +NF4_QUANT_TABLE = [ + -1.0 - 1e-2, # 0b0000 + -0.8480964004993439, # 0b0001 + -0.6106329262256622, # 0b0010 + -0.4599952697753906, # 0b0011 + -0.33967943489551544, # 0b0100 + -0.23460740596055984, # 0b0101 + -0.13791173323988914, # 0b0110 + -0.045525018125772476, # 0b0111 + 0.03979014977812767, # 0b1000 + 0.1202552504837513, # 0b1001 + 0.2035212516784668, # 0b1010 + 0.2920137718319893, # 0b1011 + 0.3893125355243683, # 0b1100 + 0.5016634166240692, # 0b1101 + 
0.6427869200706482, # 0b1110 + 0.8614784181118011, # 0b1111 +] + + +FP4_QUANT_TABLE = { + 0 - 1e-2: 0, # 0b0000 + 0.00260417: 1, # 0b0001 + 0.0859375: 6, # 0b0110 + 0.20833333: 7, # 0b0111 + 0.29166667: 4, # 0b0100 + 0.4166667: 5, # 0b0101 + 0.583333: 2, # 0b0010 + 0.8333333: 3, # 0b0011 +} + +INT8_QUANT_TABLE = create_dynamic_map().tolist() + + +def quantize_4bit_impl( + A: Tensor, + absmax: Tensor = None, + out: Tensor = None, + blocksize=64, + compress_statistics=False, + quant_type="nf4", + quant_storage=torch.uint8, +) -> Tensor: + """ + Quantize tensor A in blocks of 4-bit values. + + Quantizes tensor A by dividing it into blocks which are independently quantized to FP4. + + Parameters + ---------- + A : torch.Tensor + The input tensor. + absmax : torch.Tensor + The absmax values. + out : torch.Tensor + The output tensor (8-bit). + blocksize : int + The blocksize used in quantization. + quant_type : str + The 4-bit quantization data type {fp4, nf4}, only nf4 is supported now + quant_storage: torch.dtype + We can use bytes to convert storage type. + + Returns + ------- + torch.Tensor: + The 8-bit tensor with packed 4-bit values. + tuple(torch.Tensor, torch.Size, torch.dtype, int): + The quantization state to undo the quantization. + """ + if quant_type not in ["nf4", "fp4", "int8"]: + raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented for CPU/XPU.") + if quant_type == "fp4": + warnings.warn("fp4 quantization is currently slow on CPU/XPU. Please Use nf4 instead for better performance.") + assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] + n = A.numel() + input_shape = A.shape + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + + if absmax is None: + absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype) + + if out is None: + out = torch.zeros(((n + 1) // 2), dtype=torch.uint8, device=A.device) + + rem = n % blocksize + has_rem = rem > 0 + + # Scale tensor to [-1, 1] + A_reshaped = A.reshape(n) + A_com = A_reshaped[: n - rem] + A_com_reshaped = A_com.reshape(n // blocksize, blocksize) + absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0] + scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1) + scaled_A = scaled_A.reshape(-1) + if has_rem: + absmax[-1] = torch.abs(A_reshaped[n - rem :]).max() + scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1) + scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0) + # map [-1, 1] to nf4/fp4 + out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8, device=A.device) + if quant_type == "nf4": + for i in range(len(NF4_QUANT_TABLE)): + out_uint8[scaled_A > NF4_QUANT_TABLE[i]] = i + elif quant_type == "fp4": + sign = scaled_A < 0 + abs_scaled_A = torch.abs(scaled_A) + for key, val in FP4_QUANT_TABLE.items(): + out_uint8[abs_scaled_A > key] = val + out_uint8 += sign.to(torch.uint8) * 8 + elif quant_type == "int8": + map = torch.tensor(INT8_QUANT_TABLE, device=scaled_A.device) + diff = torch.abs(scaled_A.unsqueeze(-1) - map) + out_uint8 = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device) + + if quant_type == "int8": + out = out_uint8 + code = torch.Tensor(INT8_QUANT_TABLE).to(A.device) + else: + if out_uint8.size(-1) % 2: + out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0) + out[:] = out_uint8[::2].bitwise_left_shift(4).bitwise_or_(out_uint8[1::2]) + code = get_4bit_type(quant_type, device=A.device) + + if compress_statistics: + offset = absmax.mean() + absmax -= offset + 
qabsmax, state2 = quantize_4bit_impl(absmax, blocksize=256, quant_type="int8") + del absmax + state = QuantState( + absmax=qabsmax, + shape=input_shape, + dtype=A.dtype, + blocksize=blocksize, + code=code, + quant_type=quant_type, + offset=offset, + state2=state2, + ) + else: + state = QuantState( + absmax=absmax, + shape=input_shape, + dtype=A.dtype, + blocksize=blocksize, + code=code, + quant_type=quant_type, + ) + + if quant_storage != torch.uint8: + bytes_value = out.cpu().numpy().tobytes() + out = torch.frombuffer(bytes_value, dtype=quant_storage).to(A.device) + + return out.reshape(-1, 1), state + + +def dequant_8bit(A, offset, quant_state): + assert A.dtype == torch.uint8 + absmax = quant_state.code[A.reshape(-1).int()] + blocks = absmax.shape[-1] // 256 + res = absmax.shape[-1] % 256 + if res != 0: + absmax = F.pad(absmax, (0, 256 - res), mode="constant", value=0) + absmax = (absmax.view(-1, 256) * quant_state.absmax.view(-1, 1)).to(quant_state.dtype).reshape(-1) + absmax = absmax[: blocks * 256 + res] + absmax = absmax.reshape(A.shape) + absmax += offset + return absmax + + +# Compile will fail in torch.frombuffer +# @_maybe_torch_compile +def dequantize_4bit_impl( + A: Tensor, + quant_state=None, + absmax: Tensor = None, + out: Tensor = None, + blocksize: int = 64, + quant_type="nf4", +) -> Tensor: + """ + Dequantizes 4-bit blockwise quantized values. + Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize. + + Parameters + ---------- + A : torch.Tensor + The input 8-bit tensor (packed 4-bit values). + quant_state : QuantState + object with quantisation stats, incl. absmax values, original tensor shape and original dtype. + absmax : torch.Tensor + The absmax values. + out : torch.Tensor + Dequantized output tensor. + blocksize : int + The blocksize used in quantization. + quant_type : str + The 4-bit quantization data type {fp4, nf4} + + Returns + ------- + torch.Tensor: + Dequantized tensor. + """ + transpose = True if A.shape[0] == 1 else False + A = A.reshape(-1) + device = A.device + if A.dtype != torch.uint8: + bytes_value = A.cpu().numpy().tobytes() + A = torch.frombuffer(bytes_value, dtype=torch.uint8).to(device) + + if quant_state is None: + assert absmax is not None and out is not None + + quant_state = QuantState( + absmax=absmax, + shape=out.shape, + dtype=out.dtype, + blocksize=blocksize, + quant_type=quant_type, + ) + + else: + absmax = quant_state.absmax + + if quant_type not in ["nf4", "fp4"]: + raise NotImplementedError( + f"4-bit quantization data type {quant_state.quant_type} is not implemented for CPU/XPU." 
+ ) + + if quant_state.nested: + absmax = dequant_8bit(absmax, quant_state.offset, quant_state.state2) + + if ipex_cpu_only and _ipex_cpu_version_prereq(2, 5) and getattr(quant_state, "ipex", False): + ipex_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", quant_state.shape, 2) + A = reverse_4bit_compress_format(ipex_weight) + quant_state.ipex = False + + # Map nf4 to [-1, 1] + out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device) + n = out_dq.numel() + out_dq[1::2] = A & 0xF + out_dq[::2] = A >> 4 + # quant_state.code is fp32, cast to quant_state dtype to avoid the mismatch issue + quant_state.code = quant_state.code.to(quant_state.dtype) + out_dq = quant_state.code[out_dq] + + # Apply scales + if out_dq.numel() != n: + assert out_dq.numel() == n + 1 + out_dq = torch.narrow(out_dq, 0, 0, n) + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + rem = n % blocksize + has_rem = rem > 0 + + if has_rem: + if out is None: + out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device) + out_reshaped = out.reshape(-1) + out_reshaped[: n - rem] = ( + out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1) + ).reshape(-1) + out_reshaped[n - rem :] = out_dq[n - rem :] * absmax[-1] + else: + out = (out_dq.view(-1, blocksize) * absmax.view(-1, 1)).reshape(quant_state.shape).to(quant_state.dtype) + + # take transpose here because weight is transposed (again) for computation + if transpose: + out = out.t() + + return out + + +# Do not need torch.compile here as we are calling torch/ipex kernel +def gemm_4bit_impl( + A: torch.Tensor, + B: torch.Tensor, + out: Optional[torch.Tensor] = None, + transposed_A=False, + transposed_B=False, + state: QuantState = None, +) -> torch.Tensor: + """ + Matrix-matrix multiplication with 4-bit quantization. + + Parameters + ---------- + A : torch.Tensor + The first input tensor. Usually the activation tensor. + B : torch.Tensor + The second input tensor. Usually the weight tensor. + out : torch.Tensor + The output tensor. + transposed_A : bool + Whether A is transposed + transposed_B : bool + Whether B is transposed + state : QuantState + Contains quantization info, such as blocksize and dtype + + Returns + ------- + torch.Tensor: + GEMM output tensor. 
+ """ + if getattr(state, "ipex", False): + # compute_dtype: 1 indicates fp16, 2 indicates bf16 + compute_dtype = 2 if A.dtype == torch.bfloat16 else 1 + output = torch.ops.torch_ipex.woq_linear( + A, + B, + "nf4", + state.shape, + state.new_scales, + state.new_zeros, + None, + None, + state.blocksize, + compute_dtype, + 1, + state.compensation, + ) + else: + dqB = dequantize_4bit_impl(B, state, blocksize=state.blocksize) + output = torch.matmul(A, dqB.to(A.dtype)) + if out is not None: + out.copy_(output) + else: + out = output + return out diff --git a/bitsandbytes/backends/cuda.py b/bitsandbytes/backends/cuda.py new file mode 100644 index 000000000..a3a610580 --- /dev/null +++ b/bitsandbytes/backends/cuda.py @@ -0,0 +1,812 @@ +import ctypes as ct +from typing import Literal, Optional, Tuple + +import torch + +from bitsandbytes.cextension import HIP_ENVIRONMENT, lib +from bitsandbytes.functional import ( + CUBLAS_Context, + _cuda_device_of, + _get_tensor_stream, + dequantize_blockwise, + dtype2bytes, + get_4bit_type, + get_colrow_absmax, + get_ptr, + get_transform_buffer, + is_on_gpu, + nvidia_transform, + post_call, + pre_call, + prod, + quantize_blockwise, +) +from bitsandbytes.utils import QuantState + +from .base import Backend + +if lib and lib.compiled_with_cuda: + """C FUNCTIONS FOR OPTIMIZERS""" + str2optimizer32bit = { + "adam": ( + lib.cadam32bit_grad_fp32, + lib.cadam32bit_grad_fp16, + lib.cadam32bit_grad_bf16, + ), + "momentum": ( + lib.cmomentum32bit_grad_32, + lib.cmomentum32bit_grad_16, + ), + "rmsprop": ( + lib.crmsprop32bit_grad_32, + lib.crmsprop32bit_grad_16, + ), + "lion": ( + lib.clion32bit_grad_fp32, + lib.clion32bit_grad_fp16, + lib.clion32bit_grad_bf16, + ), + "adagrad": ( + lib.cadagrad32bit_grad_32, + lib.cadagrad32bit_grad_16, + ), + "lamb": ( + lib.cadam32bit_grad_fp32, + lib.cadam32bit_grad_fp16, + lib.cadam32bit_grad_bf16, + ), + "ademamix": ( + lib.cademamix32bit_grad_fp32, + lib.cademamix32bit_grad_fp16, + lib.cademamix32bit_grad_bf16, + ), + } + + str2optimizer8bit_blockwise = { + "adam": ( + lib.cadam_8bit_blockwise_grad_fp32, + lib.cadam_8bit_blockwise_grad_fp16, + lib.cadam_8bit_blockwise_grad_bf16, + ), + "momentum": ( + lib.cmomentum_8bit_blockwise_grad_fp32, + lib.cmomentum_8bit_blockwise_grad_fp16, + lib.cmomentum_8bit_blockwise_grad_bf16, + ), + "rmsprop": ( + lib.crmsprop_8bit_blockwise_grad_fp32, + lib.crmsprop_8bit_blockwise_grad_fp16, + lib.crmsprop_8bit_blockwise_grad_bf16, + ), + "lion": ( + lib.clion_8bit_blockwise_grad_fp32, + lib.clion_8bit_blockwise_grad_fp16, + lib.clion_8bit_blockwise_grad_bf16, + ), + "adagrad": ( + lib.cadagrad_8bit_blockwise_grad_fp32, + lib.cadagrad_8bit_blockwise_grad_fp16, + lib.cadagrad_8bit_blockwise_grad_bf16, + ), + "ademamix": ( + lib.cademamix_8bit_blockwise_grad_fp32, + lib.cademamix_8bit_blockwise_grad_fp16, + lib.cademamix_8bit_blockwise_grad_bf16, + ), + } + + +class CUDABackend(Backend): + def device_synchronize(self): + torch.cuda.synchronize() + + def transform( + self, + A: torch.Tensor, + to_order: str, + from_order="row", + out: Optional[torch.Tensor] = None, + transpose=False, + state: Optional[Tuple[torch.Size, str]] = None, + ld=None, + ): + if HIP_ENVIRONMENT: + # transform kernel formats (col32/col_turing/col_ampere) are not applicable to ROCm + # Use nvidia_transform instead + return nvidia_transform(A, to_order, from_order, out, transpose, state, ld) + + prev_device = pre_call(A.device) + if state is None: + state = (A.shape, from_order) + else: + from_order = state[1] + + if out is None: + 
out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose) + else: + new_state = (state[0], to_order) # (shape, order) + + shape = state[0] + if len(shape) == 2: + dim1 = ct.c_int32(shape[0]) + dim2 = ct.c_int32(shape[1]) + else: + dim1 = ct.c_int32(shape[0] * shape[1]) + dim2 = ct.c_int32(shape[2]) + + is_on_gpu([A, out]) + if to_order == "col32": + if transpose: + lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2) + else: + lib.ctransform_row2col32(get_ptr(A), get_ptr(out), dim1, dim2) + + elif to_order == "col_turing": + if transpose: + lib.ctransform_row2turingT(get_ptr(A), get_ptr(out), dim1, dim2) + else: + lib.ctransform_row2turing(get_ptr(A), get_ptr(out), dim1, dim2) + + elif to_order == "col_ampere": + if transpose: + lib.ctransform_row2ampereT(get_ptr(A), get_ptr(out), dim1, dim2) + else: + lib.ctransform_row2ampere(get_ptr(A), get_ptr(out), dim1, dim2) + + elif to_order == "row": + if from_order == "col_turing": + lib.ctransform_turing2row(get_ptr(A), get_ptr(out), dim1, dim2) + elif from_order == "col_ampere": + lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2) + + else: + raise NotImplementedError(f"Transform function not implemented: From {from_order} to {to_order}") + + post_call(prev_device) + + return out, new_state + + def int8_double_quant( + self, + A: torch.Tensor, + col_stats: Optional[torch.Tensor] = None, + row_stats: Optional[torch.Tensor] = None, + out_col: Optional[torch.Tensor] = None, + out_row: Optional[torch.Tensor] = None, + threshold=0.0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + # TODO: Optimize/write CUDA kernel for this? + + # Use CUDA kernel for rowwise and COO tensor + quant_row, row_stats, outlier_cols = self.int8_vectorwise_quant(A, threshold=threshold) + + # PyTorch impl for colwise + _, col_stats, outlier_mask = get_colrow_absmax(A, threshold=threshold) + if threshold > 0.0 and outlier_mask is not None: + A = A.masked_fill(outlier_mask, 0.0) + quant_col = torch.round(A.mul(127) / col_stats.unsqueeze(0)).to(torch.int8) + + if out_row is not None: + quant_row = out_row.copy_(quant_row) + if out_col is not None: + quant_col = out_col.copy_(quant_col) + + return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols + + def int8_linear_matmul( + self, + A: torch.Tensor, + B: torch.Tensor, + out: Optional[torch.Tensor] = None, + dtype=torch.int32, + ) -> torch.Tensor: + # + # To use the IMMA tensor core kernels without special Turing/Ampere layouts, + # cublasLt has some rules, namely: A must be transposed, B must not be transposed. + # The C++ API will calculate `C = A.T @ B` in with A, B, C in col-major. + # This will typically be used with row-major tensors to efficiently + # calculate the linear layer with `C = B @ A.T` without any transformations. + # We will swap A and B in the API invocation, so that we get `C = A @ B.T`. + # + # Quick explanation: + # With row-major A and B tensors, `C = A.T.T @ B.T = A @ B.T`. + # To get row-major output, `C.T = (A @ B.T).T = B @ A.T`. 
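The shape algebra in the comment above can be checked with a small example in plain PyTorch, independent of the cublasLt path (the shapes below are arbitrary): swapping the operands and transposing the result reproduces the usual row-major linear-layer output.

import torch

A = torch.randint(-128, 128, (16, 64), dtype=torch.int8)   # activations (rows, inputs)
B = torch.randint(-128, 128, (32, 64), dtype=torch.int8)   # weights (outputs, inputs)

ref = A.float() @ B.float().t()          # the linear layer: C = A @ B.T
alt = (B.float() @ A.float().t()).t()    # C.T = B @ A.T, read back transposed

assert torch.equal(ref, alt)             # exact, since all values are small integers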
+ # + A, B = B, A + + shapeA = A.shape + shapeB = B.shape + + assert A.dtype == torch.int8 + assert B.dtype == torch.int8 + assert A.ndim == 2, "Only two dimensional matrices are supported for argument B" + assert B.ndim in [2, 3], "Only two or three dimensional matrices are supported for argument A" + assert prod(shapeB) > 0, f"Input tensor dimensions need to be > 0: {shapeB}" + assert out is None or out.dtype == dtype + + shapeC = (*shapeB[:-1], shapeA[0]) + + k, m = shapeA + n = prod(shapeB[:-1]) + lda = shapeA[-1] # Weights (outputs, inputs) + ldb = shapeB[-1] # Activations (batch, tokens, inputs) + ldc = shapeC[-1] # Output (batch, tokens, outputs) + + assert ( + lda == ldb + ), f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}" + + # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. + # We'll fall back to a slower fp32 calculation in this circumstance. + # Fortunately, this should not be very common. + if lda % 4 != 0: + result = torch.matmul(B.float(), A.float().t()).to(torch.int32) + if out is not None: + result = out.copy_(result) + return result + + if out is None: + out = torch.empty(shapeC, device=A.device, dtype=dtype) + + is_on_gpu([A, B, out]) + + with _cuda_device_of(A): + ctx = CUBLAS_Context.get_instance().get_context(A.device) + ptrA = get_ptr(A) + ptrB = get_ptr(B) + ptrC = get_ptr(out) + ptrRowScale = None + m = ct.c_int32(m) + n = ct.c_int32(n) + k = ct.c_int32(k) + lda = ct.c_int32(lda) + ldb = ct.c_int32(ldb) + ldc = ct.c_int32(ldc) + stream = _get_tensor_stream(A) + + if dtype == torch.int32: + has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) + else: + has_error = lib.cigemmlt_8(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) + + if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` + raise NotImplementedError("int8_linear_matmul not implemented!") + + if has_error: + raise RuntimeError( + f"cublasLt ran into an error!\n" + f"\t{shapeA=}, {shapeB=}, {shapeC=}\n" + f"\t{(lda, ldb, ldc)=}\n" + f"\t{(m, n, k)=}" + ) + + return out + + def int8_mm_dequant( + self, + A: torch.Tensor, + row_stats: torch.Tensor, + col_stats: torch.Tensor, + out: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert A.dtype == torch.int32 + + if bias is not None: + assert bias.dtype == torch.float16 + + if out is None: + out = torch.empty_like(A, dtype=torch.float16) + + ptrA = get_ptr(A) + ptrOut = get_ptr(out) + ptrRowStats = get_ptr(row_stats) + ptrColStats = get_ptr(col_stats) + ptrBias = get_ptr(bias) + numRows = ct.c_int32(prod(A.shape[:-1])) + numCols = ct.c_int32(A.shape[-1]) + + is_on_gpu([A, row_stats, col_stats, out, bias]) + + with _cuda_device_of(A): + lib.cdequant_mm_int32_fp16( + ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) + ) + + return out + + def int8_vectorwise_dequant( + self, + A: torch.Tensor, + stats: torch.Tensor, + ) -> torch.Tensor: + # To dequantize we divide by 127, or multiply by the reciprocal. 
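The constant 7.874015718698502e-3 used below is the float32 value of 1/127, so this multiplication undoes the per-row scaling applied by int8_vectorwise_quant. A minimal round-trip sketch in plain PyTorch with hypothetical shapes:

import torch

x = torch.randn(4, 8, dtype=torch.float16)
stats = x.abs().amax(dim=1).float()                             # per-row absmax
q = torch.round(x.float() * (127.0 / stats.unsqueeze(1))).to(torch.int8)

x_hat = q.float() * stats.view(-1, 1) * 7.874015718698502e-3    # same constant: float32(1/127)
print((x_hat - x.float()).abs().max())                          # bounded by half a quantization step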
+ return A * stats.view(-1, 1) * 7.874015718698502e-3 + + def int8_vectorwise_quant( + self, + A: torch.Tensor, + threshold=0.0, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + assert A.dtype == torch.half + is_on_gpu([A]) + + rows = prod(A.shape[:-1]) + cols = A.shape[-1] + + row_stats = torch.empty(rows, device=A.device, dtype=torch.float32) + out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) + + outlier_cols = None + + if threshold > 0.0: + # TODO we could improve perf of this + outliers = A.abs() >= threshold + + if outliers.any(): + outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1) + + with _cuda_device_of(A): + lib.cint8_vector_quant( + get_ptr(A), + get_ptr(out_row), + get_ptr(row_stats), + ct.c_float(threshold), + ct.c_int32(rows), + ct.c_int32(cols), + _get_tensor_stream(A), + ) + + # Zero out values from outlier columns across all rows. + # The kernel will handle this for outliers themselves, so we can optimize for rows=1. + if rows > 1 and outlier_cols is not None: + out_row[:, outlier_cols] = 0 + + return out_row, row_stats, outlier_cols + + def extract_outliers(self, A: torch.Tensor, SA: Tuple[torch.Size, str], idx: torch.Tensor): + shapeA = SA[0] + formatA = SA[1] + if not HIP_ENVIRONMENT: + assert formatA in ["col_turing", "col_ampere"] + else: + # HIP uses col format + assert formatA in ["col"] + assert A.device.type == "cuda" + + out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device) + + idx_size = ct.c_int32(idx.numel()) + rows = ct.c_int32(shapeA[0]) + cols = ct.c_int32(shapeA[1]) + ptrA = get_ptr(A) + ptrIdx = get_ptr(idx) + ptrOut = get_ptr(out) + + prev_device = pre_call(A.device) + + if formatA == "col_turing" or HIP_ENVIRONMENT: + lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) + elif formatA == "col_ampere": + lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) + + post_call(prev_device) + + return out + + def quantize_4bit( + self, + A: torch.Tensor, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: Optional[int] = None, + compress_statistics=False, + quant_type: Literal["fp4", "nf4"] = "fp4", + quant_storage=torch.uint8, + ) -> Tuple[torch.Tensor, QuantState]: + if blocksize is None: + # Some AMD GPUs have warpsize 64 + # Set default blocksize to 128 (~warpsize 64 in kernel) for HIP + blocksize = 64 if not HIP_ENVIRONMENT else 128 + if A.device.type != "cuda": + raise NotImplementedError(f"Device type not supported for FP4 quantization: {A.device.type}") + if quant_type not in ["fp4", "nf4"]: + raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.") + + n = A.numel() + input_shape = A.shape + + if absmax is None: + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) + + if out is None: + mod = dtype2bytes[quant_storage] * 2 + out = torch.zeros(((n + 1) // mod, 1), dtype=quant_storage, device=A.device) + + # Some AMD GPUs have warpsize 64 + # Set min blocksize to 128 (~warpsize 64 in kernel) for HIP + if not HIP_ENVIRONMENT: + assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] + else: + assert blocksize in [4096, 2048, 1024, 512, 256, 128] + + is_on_gpu([A, out, absmax]) + + with _cuda_device_of(A): + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) + + if A.dtype == torch.bfloat16: + if quant_type == "fp4": + 
lib.cquantize_blockwise_bf16_fp4(*args) + else: + lib.cquantize_blockwise_bf16_nf4(*args) + elif A.dtype == torch.float16: + if quant_type == "fp4": + lib.cquantize_blockwise_fp16_fp4(*args) + else: + lib.cquantize_blockwise_fp16_nf4(*args) + elif A.dtype == torch.float32: + if quant_type == "fp4": + lib.cquantize_blockwise_fp32_fp4(*args) + else: + lib.cquantize_blockwise_fp32_nf4(*args) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + + code = get_4bit_type(quant_type, device=A.device) + + if compress_statistics: + offset = absmax.mean() + absmax -= offset + qabsmax, state2 = quantize_blockwise(absmax, blocksize=256) + del absmax + state = QuantState( + absmax=qabsmax, + shape=input_shape, + dtype=A.dtype, + blocksize=blocksize, + code=code, + quant_type=quant_type, + offset=offset, + state2=state2, + ) + else: + state = QuantState( + absmax=absmax, + shape=input_shape, + dtype=A.dtype, + blocksize=blocksize, + code=code, + quant_type=quant_type, + ) + + return out, state + + def dequantize_4bit( + self, + A: torch.Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: Optional[int] = None, + quant_type: Literal["fp4", "nf4"] = "fp4", + ) -> torch.Tensor: + # Some AMD GPUs have warpsize 64 + # Set default blocksize to 128 (~warpsize 64 in kernel) for HIP + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 + supported_blocksizes = [2048, 4096, 1024, 512, 256, 128, 64] + if HIP_ENVIRONMENT: + supported_blocksizes = supported_blocksizes[:-1] + if blocksize not in supported_blocksizes: + raise ValueError( + f"The blockwise of {blocksize} is not supported. Supported values: {supported_blocksizes}" + ) + + if quant_type not in ["fp4", "nf4"]: + raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.") + + if quant_state is None: + assert absmax is not None and out is not None + + quant_state = QuantState( + absmax=absmax, shape=out.shape, dtype=out.dtype, blocksize=blocksize, quant_type=quant_type + ) + else: + absmax = quant_state.absmax + + if quant_state.nested: + absmax = self.dequantize_blockwise(quant_state.absmax, quant_state.state2) + absmax += quant_state.offset + if absmax.dtype != torch.float32: + absmax = absmax.float() + + if out is None: + out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device) + + n = out.numel() + + is_on_gpu([A, absmax, out]) + stream = _get_tensor_stream(A) + + with _cuda_device_of(A): + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + stream, + ) + + if out.dtype == torch.bfloat16: + if quant_state.quant_type == "fp4": + lib.cdequantize_blockwise_bf16_fp4(*args) + else: + lib.cdequantize_blockwise_bf16_nf4(*args) + elif out.dtype == torch.float16: + if quant_state.quant_type == "fp4": + lib.cdequantize_blockwise_fp16_fp4(*args) + else: + lib.cdequantize_blockwise_fp16_nf4(*args) + elif out.dtype == torch.float32: + if quant_state.quant_type == "fp4": + lib.cdequantize_blockwise_fp32_fp4(*args) + else: + lib.cdequantize_blockwise_fp32_nf4(*args) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}") + + if A.shape[0] == 1: # is transposed, transpose back + return out.t() + + return out + + def gemv_4bit( + self, + A: torch.Tensor, + B: torch.Tensor, + out: Optional[torch.Tensor] = None, + transposed_A=False, + 
transposed_B=False, + state: QuantState = None, + ): + if state is None: + raise ValueError("state cannot be None. gemv_4bit() requires the state from quantize_4bit()") + + if A.numel() != A.shape[-1]: + raise ValueError( + 'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]', + ) + + Bshape = state.shape + bout = Bshape[0] + absmax = state.absmax + if state.nested: + absmax = self.dequantize_blockwise(state.absmax, state.state2) + absmax += state.offset + + if out is None: + if len(A.shape) == 3: + out = torch.empty(size=(A.shape[0], A.shape[1], bout), dtype=A.dtype, device=A.device) + else: + out = torch.empty(size=(A.shape[0], bout), dtype=A.dtype, device=A.device) + + n = 1 + m = Bshape[0] + k = Bshape[1] + lda = Bshape[0] + ldc = Bshape[0] + ldb = (A.shape[-1] + 1) // 2 + is_on_gpu([B, A, out, absmax, state.code]) + m = ct.c_int32(m) + n = ct.c_int32(n) + k = ct.c_int32(k) + lda = ct.c_int32(lda) + ldb = ct.c_int32(ldb) + ldc = ct.c_int32(ldc) + + inference_args = [ + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(state.code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(state.blocksize), + _get_tensor_stream(A), + ] + + if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]: + if A.dtype == torch.float16: + lib.cgemm_4bit_inference_naive_fp16(*inference_args) + elif A.dtype == torch.bfloat16: + lib.cgemm_4bit_inference_naive_bf16(*inference_args) + elif A.dtype == torch.float32: + lib.cgemm_4bit_inference_naive_fp32(*inference_args) + else: + raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}") + else: + raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}") + + return out + + def dequantize_blockwise( + self, + A: torch.Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + code: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 4096, + nested=False, + ) -> torch.Tensor: + # TODO: Move from bnb.functional + return dequantize_blockwise( + A, + quant_state=quant_state, + absmax=absmax, + code=code, + out=out, + blocksize=blocksize, + nested=nested, + ) + + def quantize_blockwise( + self, + A: torch.Tensor, + code: Optional[torch.Tensor] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=4096, + nested=False, + ) -> Tuple[torch.Tensor, QuantState]: + # TODO: Move from bnb.functional + return quantize_blockwise( + A, + absmax=absmax, + code=code, + out=out, + blocksize=blocksize, + nested=nested, + ) + + def optimizer_update_8bit_blockwise( + self, + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + beta1: float, + beta2: float, + beta3: float, + alpha: float, + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: Optional[torch.Tensor], + absmax1: torch.Tensor, + absmax2: Optional[torch.Tensor], + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, + ) -> None: + optim_func = None + + is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2]) + if g.dtype == torch.float32 and state1.dtype == torch.uint8: + optim_func = str2optimizer8bit_blockwise[optimizer_name][0] + elif g.dtype == torch.float16 and state1.dtype == torch.uint8: + optim_func = str2optimizer8bit_blockwise[optimizer_name][1] + elif ( + g.dtype == torch.bfloat16 + and state1.dtype == torch.uint8 + and 
len(str2optimizer8bit_blockwise[optimizer_name]) == 3 + ): + optim_func = str2optimizer8bit_blockwise[optimizer_name][2] + else: + raise ValueError( + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", + ) + + is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2]) + + with _cuda_device_of(g): + optim_func( + get_ptr(p), + get_ptr(g), + get_ptr(state1), + get_ptr(state2), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(beta3), + ct.c_float(alpha), + ct.c_float(eps), + ct.c_int32(step), + ct.c_float(lr), + get_ptr(qmap1), + get_ptr(qmap2), + get_ptr(absmax1), + get_ptr(absmax2), + ct.c_float(weight_decay), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel()), + ) + + def optimizer_update_32bit( + self, + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + beta1: float, + eps: float, + step: int, + lr: float, + state2: Optional[torch.Tensor] = None, + beta2: float = 0.0, + beta3: float = 0.0, + alpha: float = 0.0, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + unorm_vec: Optional[torch.Tensor] = None, + max_unorm: float = 0.0, + skip_zeros=False, + ) -> None: + param_norm = 0.0 + if max_unorm > 0.0: + param_norm = torch.norm(p.data.float()) + + optim_func = None + if g.dtype == torch.float32: + optim_func = str2optimizer32bit[optimizer_name][0] + elif g.dtype == torch.float16: + optim_func = str2optimizer32bit[optimizer_name][1] + elif g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name]) == 3: + optim_func = str2optimizer32bit[optimizer_name][2] + else: + raise ValueError( + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", + ) + + is_on_gpu([g, p, state1, state2, unorm_vec]) + + with _cuda_device_of(g): + optim_func( + get_ptr(g), + get_ptr(p), + get_ptr(state1), + get_ptr(state2), + get_ptr(unorm_vec), + ct.c_float(max_unorm), + ct.c_float(param_norm), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(beta3), + ct.c_float(alpha), + ct.c_float(eps), + ct.c_float(weight_decay), + ct.c_int32(step), + ct.c_float(lr), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel()), + ) diff --git a/bitsandbytes/backends/hpu.py b/bitsandbytes/backends/hpu.py new file mode 100644 index 000000000..03308cd5d --- /dev/null +++ b/bitsandbytes/backends/hpu.py @@ -0,0 +1,315 @@ +import math +from typing import Literal, Optional, Tuple +import warnings +import torch + +from bitsandbytes.utils import QuantState + +from .base import Backend +from .cpu_xpu_common import ( + double_quant_impl, + dequant_8bit, + NF4_QUANT_TABLE, + INT8_QUANT_TABLE, +) +from bitsandbytes.functional import ( + QuantState, + get_4bit_type, +) + +Tensor = torch.Tensor + +def assert_on_hpu(tensors): + on_hpu = True + for t in tensors: + if t is None: + continue # NULL pointers are fine + on_hpu &= t.device.type == "hpu" + if not on_hpu: + raise TypeError( + "All input tensors need to be on HPU, but found some tensors to not be on HPU:\n" + f" {[(t.shape, t.device) if isinstance(t, Tensor) else None for t in tensors]}" + ) + return on_hpu + +class HPUBackend(Backend): + + def int8_double_quant( + self, + A: torch.Tensor, + col_stats: Optional[torch.Tensor] = None, + row_stats: Optional[torch.Tensor] = None, + out_col: Optional[torch.Tensor] = None, + out_row: Optional[torch.Tensor] = None, + threshold=0.0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + 
assert_on_hpu([A, col_stats, row_stats, out_col, out_row]) + return double_quant_impl(A, col_stats, row_stats, out_col, out_row, threshold) + + def transform( + self, + A: torch.Tensor, + to_order: str, + from_order="row", + out: Optional[torch.Tensor] = None, + transpose=False, + state: Optional[Tuple[torch.Size, str]] = None, + ld=None, + ): + raise NotImplementedError("Not yet implemented for HPU backend") + + def int8_linear_matmul( + self, + A: torch.Tensor, + B: torch.Tensor, + out: Optional[torch.Tensor] = None, + dtype=torch.int32, + ) -> torch.Tensor: + raise NotImplementedError("Not yet implemented for HPU backend") + + def int8_mm_dequant( + self, + A: torch.Tensor, + row_stats: torch.Tensor, + col_stats: torch.Tensor, + out: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError("Not yet implemented for HPU backend") + + def extract_outliers( + self, + A: torch.Tensor, + SA: Tuple[torch.Size, str], + idx: torch.Tensor, + ) -> torch.Tensor: + raise NotImplementedError("Not yet implemented for HPU backend") + + def quantize_4bit( + self, + A: torch.Tensor, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=64, + compress_statistics=False, + quant_type: Literal["nf4"] = "nf4", + quant_storage=torch.uint8, + ) -> Tuple[torch.Tensor, QuantState]: + if blocksize is None: + blocksize = 64 + assert_on_hpu([A, absmax, out]) + assert quant_storage == torch.uint8, "HPU backend only supports uint8 quant_storage" + return self.quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type) + + def quantize_4bit_impl( + self, + A: Tensor, + absmax: Tensor = None, + out: Tensor = None, + blocksize=64, + compress_statistics=False, + quant_type="nf4", + ) -> Tensor: + if quant_type not in ["nf4", "int8"]: + raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented for HPU.") + assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] + n = A.numel() + input_shape = A.shape + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + + if absmax is None: + absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype) + + if out is None: + out = torch.zeros(((n + 1) // 2), dtype=torch.uint8, device=A.device) + + rem = n % blocksize + has_rem = rem > 0 + + # Scale tensor to [-1, 1] + A_reshaped = A.reshape(n) + A_com = A_reshaped[: n - rem] + A_com_reshaped = A_com.reshape(n // blocksize, blocksize) + absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0] + scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1) + scaled_A = scaled_A.reshape(-1) + if has_rem: + absmax[-1] = torch.abs(A_reshaped[n - rem :]).max() + scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1) + scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0) + # map [-1, 1] to nf4 + out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8, device=A.device) + if quant_type == "nf4": + for i in range(len(NF4_QUANT_TABLE)): + out_uint8[scaled_A > NF4_QUANT_TABLE[i]] = i + elif quant_type == "int8": + map = torch.tensor(INT8_QUANT_TABLE, device=scaled_A.device) + diff = torch.abs(scaled_A.unsqueeze(-1) - map) + out_uint8 = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device) + + if quant_type == "int8": + out = out_uint8 + code = torch.Tensor(INT8_QUANT_TABLE).to(A.device) + else: + if out_uint8.size(-1) % 2: + out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0) + # To align with HPU 
dequantize operator + out[:] = out_uint8[1::2].bitwise_left_shift(4).bitwise_or_(out_uint8[::2]) + code = get_4bit_type(quant_type, device=A.device) + + if compress_statistics: + raise AssertionError("Double quantization is not supported for HPU backend") + offset = absmax.mean() + absmax -= offset + qabsmax, state2 = self.hpu_quantize_4bit_impl(absmax, blocksize=256, quant_type="int8") + del absmax + state = QuantState( + absmax=qabsmax, + shape=input_shape, + dtype=A.dtype, + blocksize=blocksize, + code=code, + quant_type=quant_type, + offset=offset, + state2=state2, + ) + else: + state = QuantState( + absmax=absmax, + shape=input_shape, + dtype=A.dtype, + blocksize=blocksize, + code=code, + quant_type=quant_type, + ) + return out, state + + def dequantize_nf4_impl( + self, + input: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_state: QuantState, + ) -> torch.Tensor: + """ + HPU dequantization function for NF4 quantized tensors. + """ + assert_on_hpu([input, absmax]) + out_shape = (math.prod(quant_state.shape), ) + out_dq = torch.ops.hpu.dequantize_nf4(input, absmax, blocksize, + out_shape=out_shape, + out_dtype=quant_state.dtype) + output = out_dq.reshape(quant_state.shape).T + return output + + def dequantize_4bit( + self, + A: torch.Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 64, + quant_type: Literal["nf4"] = "nf4", + ) -> torch.Tensor: + if blocksize is None: + blocksize = 64 + + assert_on_hpu([A, absmax, out]) + if quant_state.nested: + raise AssertionError("Double quantization is not supported for HPU backend") + absmax = dequant_8bit(absmax, quant_state.offset, quant_state.state2) + return self.dequantize_nf4_impl(A, absmax, blocksize, quant_state) + + def gemv_4bit( + self, + A: torch.Tensor, + B: torch.Tensor, + out: Optional[torch.Tensor] = None, + transposed_A=False, + transposed_B=False, + state: QuantState = None, + ) -> torch.Tensor: + assert_on_hpu([A, B, out]) + if state is None: + raise ValueError( + "state cannot be None. 
gemv_4bit() requires the state from quantize_4bit()" + ) + dqB = self.dequantize_nf4_impl(B, state.absmax, state.blocksize, state) + output = torch.matmul(A, dqB.to(A.dtype)) + if out is not None: + out.copy_(output) + else: + out = output + return out + + def int8_vectorwise_dequant(self, A: torch.Tensor, stats: torch.Tensor): + raise NotImplementedError("Not yet implemented for HPU backend") + + def int8_vectorwise_quant(self, A: torch.Tensor, threshold=0.0): + raise NotImplementedError("Not yet implemented for HPU backend") + + def dequantize_blockwise( + self, + A: torch.Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + code: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 4096, + nested=False, + ) -> torch.Tensor: + raise NotImplementedError("Not yet implemented for HPU backend") + + def quantize_blockwise( + self, + A: torch.Tensor, + code: Optional[torch.Tensor] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=4096, + nested=False, + ) -> Tuple[torch.Tensor, QuantState]: + raise NotImplementedError("Not yet implemented for HPU backend") + + def optimizer_update_8bit_blockwise( + self, + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + beta1: float, + beta2: float, + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: Optional[torch.Tensor], + absmax1: torch.Tensor, + absmax2: Optional[torch.Tensor], + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, + ) -> None: + raise NotImplementedError("Not yet implemented for HPU backend") + + def optimizer_update_32bit( + self, + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + beta1: float, + eps: float, + step: int, + lr: float, + state2: Optional[torch.Tensor] = None, + beta2: float = 0.0, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + unorm_vec: Optional[torch.Tensor] = None, + max_unorm: float = 0.0, + skip_zeros=False, + ) -> None: + raise NotImplementedError("Not yet implemented for HPU backend") diff --git a/bitsandbytes/backends/mps.py b/bitsandbytes/backends/mps.py new file mode 100644 index 000000000..9400699a9 --- /dev/null +++ b/bitsandbytes/backends/mps.py @@ -0,0 +1,167 @@ +from typing import Literal, Optional, Tuple, Union + +import torch + +from bitsandbytes.utils import QuantState + +from .base import Backend + + +class MPSBackend(Backend): + def device_synchronize(self): + torch.mps.synchronize() + + def double_quant( + self, + A: torch.Tensor, + col_stats: Optional[torch.Tensor] = None, + row_stats: Optional[torch.Tensor] = None, + out_col: Optional[torch.Tensor] = None, + out_row: Optional[torch.Tensor] = None, + threshold=0.0, + ): + raise NotImplementedError + + def transform( + self, + A: torch.Tensor, + to_order: str, + from_order="row", + out: Optional[torch.Tensor] = None, + transpose=False, + state: Optional[Tuple[torch.Size, str]] = None, + ld=None, + ): + raise NotImplementedError + + def igemmlt( + self, + A: torch.Tensor, + B: torch.Tensor, + SA: Tuple[torch.Size, str], + SB: Tuple[torch.Size, str], + out: Optional[torch.Tensor] = None, + Sout: Optional[Tuple[torch.Size, str]] = None, + dtype=torch.int32, + ) -> Union[torch.Tensor, Tuple[Optional[Tuple[torch.Tensor, Tuple[torch.Size, str]]]]]: + raise NotImplementedError + + def mm_dequant( + self, + A: torch.Tensor, + quant_state: Tuple[torch.Size, str], + row_stats: 
torch.Tensor, + col_stats: torch.Tensor, + out: Optional[torch.Tensor] = None, + new_row_stats: Optional[torch.Tensor] = None, + new_col_stats: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError + + def extract_outliers( + self, + A: torch.Tensor, + SA: Tuple[torch.Size, str], + idx: torch.Tensor, + ) -> torch.Tensor: + raise NotImplementedError + + def quantize_4bit( + self, + A: torch.Tensor, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=64, + compress_statistics=False, + quant_type: Literal["fp4", "nf4"] = "fp4", + quant_storage=torch.uint8, + ) -> Tuple[torch.Tensor, QuantState]: + raise NotImplementedError + + def dequantize_4bit( + self, + A: torch.Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 64, + quant_type: Literal["fp4", "nf4"] = "fp4", + ) -> torch.Tensor: + raise NotImplementedError + + def gemv_4bit( + self, + A: torch.Tensor, + B: torch.Tensor, + out: Optional[torch.Tensor] = None, + transposed_A=False, + transposed_B=False, + state: QuantState = None, + ) -> torch.Tensor: + raise NotImplementedError + + def dequantize_blockwise( + self, + A: torch.Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + code: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 4096, + nested=False, + ) -> torch.Tensor: + raise NotImplementedError + + def quantize_blockwise( + self, + A: torch.Tensor, + code: Optional[torch.Tensor] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=4096, + nested=False, + ) -> Tuple[torch.Tensor, QuantState]: + raise NotImplementedError + + def optimizer_update_8bit_blockwise( + self, + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + beta1: float, + beta2: float, + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: Optional[torch.Tensor], + absmax1: torch.Tensor, + absmax2: Optional[torch.Tensor], + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, + ) -> None: + raise NotImplementedError + + def optimizer_update_32bit( + self, + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + beta1: float, + eps: float, + step: int, + lr: float, + state2: Optional[torch.Tensor] = None, + beta2: float = 0.0, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + unorm_vec: Optional[torch.Tensor] = None, + max_unorm: float = 0.0, + skip_zeros=False, + ) -> None: + raise NotImplementedError diff --git a/bitsandbytes/backends/npu.py b/bitsandbytes/backends/npu.py new file mode 100644 index 000000000..cd3933879 --- /dev/null +++ b/bitsandbytes/backends/npu.py @@ -0,0 +1,313 @@ +import ctypes as ct +from typing import Literal, Optional, Tuple + +import torch + +try: + # to support Ascend NPU backend + import torch_npu # noqa: F401 +except ImportError: + pass + +from bitsandbytes.cextension import lib +from bitsandbytes.functional import ( + get_4bit_type, + get_ptr, +) +from bitsandbytes.utils import QuantState + +from .base import Backend + + +def assert_on_npu(tensors): + if not all(t.device.type == "npu" for t in tensors if t is not None): + raise TypeError( + "All input tensors to be on NPU, but found some tensors not be on NPU:\n" + f"{[(t.shape, t.device) if isinstance(t, torch.Tensor) 
else None for t in tensors]}" + ) + return True + + +class NPUBackend(Backend): + def device_synchronize(self): + torch.npu.synchronize() + + def int8_double_quant( + self, + A: torch.Tensor, + col_stats: Optional[torch.Tensor] = None, + row_stats: Optional[torch.Tensor] = None, + out_col: Optional[torch.Tensor] = None, + out_row: Optional[torch.Tensor] = None, + threshold=0.0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + raise NotImplementedError + + def int8_vectorwise_dequant(self, A, stats): + return super().int8_vectorwise_dequant(A, stats) + + def int8_vectorwise_quant( + self, + A: torch.Tensor, + threshold=0.0, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + raise NotImplementedError + + def transform( + self, + A: torch.Tensor, + to_order: str, + from_order="row", + out: Optional[torch.Tensor] = None, + transpose=False, + state: Optional[Tuple[torch.Size, str]] = None, + ld=None, + ): + raise NotImplementedError + + def int8_linear_matmul( + self, + A: torch.Tensor, + B: torch.Tensor, + out: Optional[torch.Tensor] = None, + dtype=torch.int32, + ) -> torch.Tensor: + raise NotImplementedError + + def int8_mm_dequant( + self, + A: torch.Tensor, + row_stats: torch.Tensor, + col_stats: torch.Tensor, + out: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError + + def extract_outliers( + self, + A: torch.Tensor, + SA: Tuple[torch.Size, str], + idx: torch.Tensor, + ) -> torch.Tensor: + raise NotImplementedError + + def quantize_4bit( + self, + A: torch.Tensor, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: Optional[int] = None, + compress_statistics=False, + quant_type: Literal["fp4", "nf4"] = "nf4", + quant_storage=torch.uint8, + ) -> Tuple[torch.Tensor, QuantState]: + if quant_type not in ["nf4"]: + raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.") + if compress_statistics: + raise NotImplementedError("compress_statistics is not implemented.") + if blocksize is None: + blocksize = 128 + + prev_device = torch.npu.current_device() + torch.npu.set_device(A.device) + if A.dtype in [torch.float32, torch.float16, torch.bfloat16]: + data = [ + -1.0, + -0.6961928009986877, + -0.5250730514526367, + -0.39491748809814453, + -0.28444138169288635, + -0.18477343022823334, + -0.09105003625154495, + 0.0, + 0.07958029955625534, + 0.16093020141124725, + 0.24611230194568634, + 0.33791524171829224, + 0.44070982933044434, + 0.5626170039176941, + 0.7229568362236023, + 1.0, + ] + data = torch.tensor(data, device="npu", dtype=torch.float32).view(1, -1) + absmax = A.view(-1, blocksize).abs().max(dim=1, keepdim=True).values + a = A.view(-1, blocksize) / absmax.float() + diff = torch.abs(a.unsqueeze(-1) - data) + out = (torch.argmin(diff, dim=-1) + 8) % 16 + out = out.reshape(-1, 2) + out = (out[:, 0] + out[:, 1] * 16).to(torch.uint8) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + assert_on_npu([A, absmax, out]) + torch.npu.set_device(prev_device) + + code = get_4bit_type(quant_type, device=A.device) + state = QuantState( + absmax=absmax, + shape=A.shape, + dtype=A.dtype, + blocksize=blocksize, + code=code, + quant_type=quant_type, + ) + + return out, state + + def dequantize_4bit( + self, + A: torch.Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + 
blocksize: Optional[int] = None, + quant_type: Literal["fp4", "nf4"] = "nf4", + ) -> torch.Tensor: + if blocksize is None: + blocksize = 128 + supported_blocksizes = [2048, 4096, 1024, 512, 256, 128, 64] + if blocksize not in supported_blocksizes: + raise ValueError( + f"The blockwise of {blocksize} is not supported. Supported values: {supported_blocksizes}" + ) + + if quant_state is None: + assert absmax is not None and out is not None + quant_state = QuantState( + absmax=absmax, shape=out.shape, dtype=out.dtype, blocksize=blocksize, quant_type=quant_type + ) + else: + absmax = quant_state.absmax + + if out is None: + out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device) + + n = out.numel() + + prev_device = torch.npu.current_device() + torch.npu.set_device(A.device) + assert_on_npu([A, absmax, out]) + + if quant_state.quant_type not in ["nf4"]: + raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.") + + if out.dtype == torch.float32: + lib.cdequantize_blockwise_fp32_nf4( + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + torch.npu.current_stream(), + ) + elif out.dtype == torch.float16: + lib.cdequantize_blockwise_fp16_nf4( + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + torch.npu.current_stream(), + ) + elif out.dtype == torch.bfloat16: + # bf16: bf16 -> fp32 -> op -> fp32 -> bf16 + absmax = absmax.to(torch.float32) + out = out.to(torch.float32) + lib.cdequantize_blockwise_fp32_nf4( + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + torch.npu.current_stream(), + ) + out = out.to(torch.bfloat16) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + torch.npu.set_device(prev_device) + is_transposed = True if A.shape[0] == 1 else False + + if is_transposed: + return out.t() + else: + return out + + def gemv_4bit( + self, + A: torch.Tensor, + B: torch.Tensor, + out: Optional[torch.Tensor] = None, + transposed_A=False, + transposed_B=False, + state: QuantState = None, + ) -> torch.Tensor: + raise NotImplementedError + + def dequantize_blockwise( + self, + A: torch.Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + code: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 4096, + nested=False, + ) -> torch.Tensor: + raise NotImplementedError + + def quantize_blockwise( + self, + A: torch.Tensor, + code: Optional[torch.Tensor] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=4096, + nested=False, + ) -> Tuple[torch.Tensor, QuantState]: + raise NotImplementedError + + def optimizer_update_8bit_blockwise( + self, + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + beta1: float, + beta2: float, + beta3: float, + alpha: float, + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: Optional[torch.Tensor], + absmax1: torch.Tensor, + absmax2: Optional[torch.Tensor], + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, + ) -> None: + raise NotImplementedError + + def optimizer_update_32bit( + self, + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + beta1: float, + eps: float, + step: int, + lr: float, + state2: Optional[torch.Tensor] = None, + beta2: float = 0.0, + 
beta3: float = 0.0, + alpha: float = 0.0, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + unorm_vec: Optional[torch.Tensor] = None, + max_unorm: float = 0.0, + skip_zeros=False, + ) -> None: + raise NotImplementedError diff --git a/bitsandbytes/backends/rocm.py b/bitsandbytes/backends/rocm.py new file mode 100644 index 000000000..d74f10ead --- /dev/null +++ b/bitsandbytes/backends/rocm.py @@ -0,0 +1,12 @@ +from .cuda import CUDABackend + + +class ROCmBackend(CUDABackend): + """ + Backend for AMD ROCm implementation. + + The interface is largely the same as the CUDA implementation, so only any + differences need to be implemented here. + """ + + pass diff --git a/bitsandbytes/backends/xpu.py b/bitsandbytes/backends/xpu.py new file mode 100644 index 000000000..702c3c386 --- /dev/null +++ b/bitsandbytes/backends/xpu.py @@ -0,0 +1,318 @@ +from typing import Literal, Optional, Tuple + +import torch + +from bitsandbytes.utils import QuantState + +from .base import Backend +from .cpu_xpu_common import ( + dequantize_4bit_impl, + double_quant_impl, + gemm_4bit_impl, + int8_linear_matmul_impl, + int8_mm_dequant_impl, + quantize_4bit_impl, + _ipex_xpu_version_prereq +) +try: + import intel_extension_for_pytorch as ipex + ipex_xpu = ipex if ipex._C._has_xpu() else None +except BaseException: + ipex_xpu = None + +Tensor = torch.Tensor + + +str2optimizer8bit_blockwise = {} +if ipex_xpu is not None and _ipex_xpu_version_prereq(2, 7): + str2optimizer8bit_blockwise = { + "adam": ( + ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_fp32, + ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_fp16, + ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_bf16, + ), + } + + +def assert_on_xpu(tensors): + on_xpu = True + for t in tensors: + if t is None: + continue # NULL pointers are fine + on_xpu &= t.device.type == "xpu" + if not on_xpu: + raise TypeError( + "All input tensors need to be on XPU, but found some tensors to not be on XPU:\n" + f" {[(t.shape, t.device) if isinstance(t, Tensor) else None for t in tensors]}" + ) + return on_xpu + + +class XPUBackend(Backend): + mm_dequant_compute_dtype = torch.bfloat16 + mm_dequant_output_dtype = torch.bfloat16 + + def device_synchronize(self): + torch.xpu.synchronize() + + def int8_double_quant( + self, + A: torch.Tensor, + col_stats: Optional[torch.Tensor] = None, + row_stats: Optional[torch.Tensor] = None, + out_col: Optional[torch.Tensor] = None, + out_row: Optional[torch.Tensor] = None, + threshold=0.0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + assert_on_xpu([A, col_stats, row_stats, out_col, out_row]) + output = double_quant_impl(A, col_stats, row_stats, out_col, out_row, threshold) + return output + + def transform( + self, + A: torch.Tensor, + to_order: str, + from_order="row", + out: Optional[torch.Tensor] = None, + transpose=False, + state: Optional[Tuple[torch.Size, str]] = None, + ld=None, + ): + """ + Transform tensor A to to_order. It is originally designed for CUDA. + For XPU, it returns the original tensor if transpose=False. 
+ Otherwise, it returns the transpose of A + """ + assert_on_xpu([A, out]) + if transpose: + if out is not None: + out.copy_(A.T) + else: + out = A.T + else: + if out is not None: + out.copy_(A) + else: + out = A + return out, state + + def int8_linear_matmul( + self, + A: torch.Tensor, + B: torch.Tensor, + out: Optional[torch.Tensor] = None, + dtype=torch.int32, + ) -> torch.Tensor: + assert_on_xpu([A, B]) + output = int8_linear_matmul_impl(A, B, out, dtype) + return output + + def int8_mm_dequant( + self, + A: torch.Tensor, + row_stats: torch.Tensor, + col_stats: torch.Tensor, + out: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert_on_xpu([A, row_stats, col_stats, out, bias]) + output = int8_mm_dequant_impl( + A, + row_stats, + col_stats, + out, + bias, + self.mm_dequant_compute_dtype, + self.mm_dequant_output_dtype, + ) + return output + + def int8_vectorwise_dequant(self, A, stats): + return super().int8_vectorwise_dequant(A, stats) + + def int8_vectorwise_quant(self, A: torch.Tensor, threshold=0.0): + # TODO: We can optimize this as we don't actually need column-wise quant. + out, _, stats, _, outlier_cols = self.int8_double_quant(A, threshold=threshold) + return out, stats, outlier_cols + + def extract_outliers( + self, + A: torch.Tensor, + SA: Tuple[torch.Size, str], + idx: torch.Tensor, + ) -> torch.Tensor: + assert_on_xpu([A]) + output = A[:, idx].contiguous() + return output + + def quantize_4bit( + self, + A: torch.Tensor, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=64, + compress_statistics=False, + quant_type: Literal["fp4", "nf4"] = "fp4", + quant_storage=torch.uint8, + ) -> Tuple[torch.Tensor, QuantState]: + if blocksize is None: + blocksize = 64 + assert_on_xpu([A, absmax, out]) + output = quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type, quant_storage) + return output + + def dequantize_4bit( + self, + A: torch.Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 64, + quant_type: Literal["fp4", "nf4"] = "fp4", + ) -> torch.Tensor: + if blocksize is None: + blocksize = 64 + assert_on_xpu([A, absmax, out]) + if quant_type == "nf4" and getattr(quant_state, "ipex", False): + output = torch.ops.torch_ipex.dequantize_4bit(A, "nf4", quant_state.shape, absmax, None, blocksize).t() + else: + output = dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type) + + return output + + def gemv_4bit( + self, + A: torch.Tensor, + B: torch.Tensor, + out: Optional[torch.Tensor] = None, + transposed_A=False, + transposed_B=False, + state: QuantState = None, + ) -> torch.Tensor: + assert_on_xpu([A, B, out]) + if state is None: + raise ValueError("state cannot be None. 
gemv_4bit() requires the state from quantize_4bit()") + output = gemm_4bit_impl(A, B, out, transposed_A, transposed_B, state) + return output + + def dequantize_blockwise( + self, + A: torch.Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + code: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 4096, + nested=False, + ) -> torch.Tensor: + if ipex_xpu is None or not _ipex_xpu_version_prereq(2, 7): + raise RuntimeError("Please install intel_extension_for_ipex >= 2.7 for 8bit optimizer backend on XPU device.") + + # void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream) + if out.dtype == torch.float16: + ipex.xpu.bitsandbytes.cdequantize_blockwise_fp16(code, A, absmax, out, blocksize, A.numel()) + elif out.dtype == torch.bfloat16: + ipex.xpu.bitsandbytes.cdequantize_blockwise_bf16(code, A, absmax, out, blocksize, A.numel()) + elif out.dtype == torch.float32: + ipex.xpu.bitsandbytes.cdequantize_blockwise_fp32(code, A, absmax, out, blocksize, A.numel()) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}") + + + def quantize_blockwise( + self, + A: torch.Tensor, + code: Optional[torch.Tensor] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=4096, + nested=False, + ) -> Tuple[torch.Tensor, QuantState]: + raise NotImplementedError + + def optimizer_update_8bit_blockwise( + self, + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + beta1: float, + beta2: float, + beta3: float, + alpha: float, + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: Optional[torch.Tensor], + absmax1: torch.Tensor, + absmax2: Optional[torch.Tensor], + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, + ) -> None: + optim_func = None + if ipex_xpu is None or not _ipex_xpu_version_prereq(2, 7): + raise RuntimeError("Please install intel_extension_for_ipex >= 2.7 for 8bit optimizer backend on XPU device.") + + assert_on_xpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2]) + + if g.dtype == torch.float32 and state1.dtype == torch.uint8: + optim_func = str2optimizer8bit_blockwise[optimizer_name][0] + elif g.dtype == torch.float16 and state1.dtype == torch.uint8: + optim_func = str2optimizer8bit_blockwise[optimizer_name][1] + elif ( + g.dtype == torch.bfloat16 + and state1.dtype == torch.uint8 + and len(str2optimizer8bit_blockwise[optimizer_name]) == 3 + ): + optim_func = str2optimizer8bit_blockwise[optimizer_name][2] + else: + raise ValueError( + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", + ) + optim_func( + p, + g, + state1, + state2, + beta1, + beta2, + beta3, + alpha, + eps, + step, + lr, + qmap1, + qmap2, + absmax1, + absmax2, + weight_decay, + gnorm_scale, + skip_zeros, + g.numel() + ) + + + def optimizer_update_32bit( + self, + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + beta1: float, + eps: float, + step: int, + lr: float, + state2: Optional[torch.Tensor] = None, + beta2: float = 0.0, + beta3: float = 0.0, + alpha: float = 0.0, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + unorm_vec: Optional[torch.Tensor] = None, + max_unorm: float = 0.0, + skip_zeros=False, + ) -> None: + raise NotImplementedError diff --git 
a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index ae738363a..007bdbf8e 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -7,7 +7,8 @@ import torch from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR -from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs +from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs, get_rocm_gpu_arch +from bitsandbytes.npu_specs import get_npu_specs logger = logging.getLogger(__name__) @@ -19,9 +20,15 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path: The library is not guaranteed to exist at the returned path. """ - library_name = f"libbitsandbytes_cuda{cuda_specs.cuda_version_string}{DYNAMIC_LIBRARY_SUFFIX}" + if torch.version.hip: + if BNB_HIP_VERSION < 601: + return PACKAGE_DIR / f"libbitsandbytes_rocm{BNB_HIP_VERSION_SHORT}_nohipblaslt{DYNAMIC_LIBRARY_SUFFIX}" + else: + return PACKAGE_DIR / f"libbitsandbytes_rocm{BNB_HIP_VERSION_SHORT}{DYNAMIC_LIBRARY_SUFFIX}" + library_name = f"libbitsandbytes_cuda{cuda_specs.cuda_version_string}{DYNAMIC_LIBRARY_SUFFIX}" override_value = os.environ.get("BNB_CUDA_VERSION") + if override_value: library_name = re.sub(r"cuda\d+", f"cuda{override_value}", library_name, count=1) logger.warning( @@ -55,7 +62,10 @@ class CudaBNBNativeLibrary(BNBNativeLibrary): def __init__(self, lib: ct.CDLL): super().__init__(lib) lib.get_context.restype = ct.c_void_p - lib.get_cusparse.restype = ct.c_void_p + if torch.version.cuda: + lib.get_cusparse.restype = ct.c_void_p + elif torch.version.hip: + lib.get_hipsparse.restype = ct.c_void_p lib.cget_managed_ptr.restype = ct.c_void_p @@ -67,34 +77,58 @@ def get_native_library() -> BNBNativeLibrary: if cuda_binary_path.exists(): binary_path = cuda_binary_path else: - logger.warning("Could not find the bitsandbytes CUDA binary at %r", cuda_binary_path) + logger.warning("Could not find the bitsandbytes %s binary at %r", BNB_BACKEND, cuda_binary_path) + npu_specs = get_npu_specs() + if npu_specs: + binary_path = PACKAGE_DIR / f"libbitsandbytes_npu{DYNAMIC_LIBRARY_SUFFIX}" + logger.debug(f"Loading bitsandbytes native library from: {binary_path}") dll = ct.cdll.LoadLibrary(str(binary_path)) if hasattr(dll, "get_context"): # only a CUDA-built library exposes this return CudaBNBNativeLibrary(dll) - logger.warning( - "The installed version of bitsandbytes was compiled without GPU support. " - "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.", - ) return BNBNativeLibrary(dll) +ROCM_GPU_ARCH = get_rocm_gpu_arch() + try: + import intel_extension_for_pytorch as ipex + + assert ipex._C._has_cpu() or ipex._C._has_xpu() + is_ipex_available = True +except Exception: + is_ipex_available = False + +try: + if torch.version.hip: + hip_major, hip_minor = map(int, torch.version.hip.split(".")[0:2]) + HIP_ENVIRONMENT, BNB_HIP_VERSION = True, hip_major * 100 + hip_minor + BNB_HIP_VERSION_SHORT = f"{hip_major}{hip_minor}" + BNB_BACKEND = "ROCm" + else: + HIP_ENVIRONMENT, BNB_HIP_VERSION = False, 0 + BNB_HIP_VERSION_SHORT = "" + BNB_BACKEND = "CUDA" + lib = get_native_library() except Exception as e: lib = None - logger.error(f"Could not load bitsandbytes native library: {e}", exc_info=True) - if torch.cuda.is_available(): - logger.warning( - """ -CUDA Setup failed despite CUDA being available. Please run the following command to get more information: - -python -m bitsandbytes - -Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them -to your LD_LIBRARY_PATH. 
If you suspect a bug, please take the information from python -m bitsandbytes -and open an issue at: https://github.com/bitsandbytes-foundation/bitsandbytes/issues -""", + if not is_ipex_available: + logger.error( + f"Could not load bitsandbytes native library: {e}. If you use Intel CPU or XPU, please pip install intel_extension_for_pytorch", + exc_info=True, ) + if torch.cuda.is_available(): + logger.warning( + f""" + {BNB_BACKEND} Setup failed despite {BNB_BACKEND} being available. Please run the following command to get more information: + + python -m bitsandbytes + + Inspect the output of the command and see if you can locate {BNB_BACKEND} libraries. You might need to add them + to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes + and open an issue at: https://github.com/bitsandbytes-foundation/bitsandbytes/issues + """, + ) diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py index e72d57590..dd6fd9892 100644 --- a/bitsandbytes/cuda_specs.py +++ b/bitsandbytes/cuda_specs.py @@ -1,4 +1,7 @@ import dataclasses +import logging +import re +import subprocess from typing import List, Optional, Tuple import torch @@ -21,7 +24,10 @@ def get_compute_capabilities() -> List[Tuple[int, int]]: def get_cuda_version_tuple() -> Tuple[int, int]: # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION - major, minor = map(int, torch.version.cuda.split(".")) + if torch.version.cuda: + major, minor = map(int, torch.version.cuda.split(".")) + elif torch.version.hip: + major, minor = map(int, torch.version.hip.split(".")[0:2]) return major, minor @@ -39,3 +45,26 @@ def get_cuda_specs() -> Optional[CUDASpecs]: cuda_version_string=(get_cuda_version_string()), cuda_version_tuple=get_cuda_version_tuple(), ) + + +def get_rocm_gpu_arch() -> str: + logger = logging.getLogger(__name__) + try: + if torch.version.hip: + result = subprocess.run(["rocminfo"], capture_output=True, text=True) + match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout) + if match: + return "gfx" + match.group(1) + else: + return "unknown" + else: + return "unknown" + except Exception as e: + logger.error(f"Could not detect ROCm GPU architecture: {e}") + if torch.cuda.is_available(): + logger.warning( + """ +ROCm GPU architecture detection failed despite ROCm being available. + """, + ) + return "unknown" diff --git a/bitsandbytes/diagnostics/cuda.py b/bitsandbytes/diagnostics/cuda.py index 45dc98dea..ea6dd4f8d 100644 --- a/bitsandbytes/diagnostics/cuda.py +++ b/bitsandbytes/diagnostics/cuda.py @@ -5,7 +5,7 @@ import torch -from bitsandbytes.cextension import get_cuda_bnb_library_path +from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path from bitsandbytes.consts import NONPYTORCH_DOC_URL from bitsandbytes.cuda_specs import CUDASpecs from bitsandbytes.diagnostics.utils import print_dedented @@ -32,15 +32,20 @@ "_", # current Python interpreter } -CUDA_RUNTIME_LIB_PATTERNS = ( - "cudart64*.dll", # Windows - "libcudart*.so*", # libcudart.so, libcudart.so.11.0, libcudart.so.12.0, libcudart.so.12.1, libcudart.so.12.2 etc. - "nvcuda*.dll", # Windows -) - logger = logging.getLogger(__name__) +def get_runtime_lib_patterns() -> tuple: + if HIP_ENVIRONMENT: + return ("libamdhip64.so*",) + else: + return ( + "cudart64*.dll", # Windows + "libcudart*.so*", # libcudart.so, libcudart.so.11.0, libcudart.so.12.0, libcudart.so.12.1, libcudart.so.12.2 etc. 
+ "nvcuda*.dll", # Windows + ) + + def find_cuda_libraries_in_path_list(paths_list_candidate: str) -> Iterable[Path]: for dir_string in paths_list_candidate.split(os.pathsep): if not dir_string: @@ -55,9 +60,9 @@ def find_cuda_libraries_in_path_list(paths_list_candidate: str) -> Iterable[Path continue except OSError: # Assume an esoteric error trying to poke at the directory pass - for lib_pattern in CUDA_RUNTIME_LIB_PATTERNS: + for lib_pattern in get_runtime_lib_patterns(): for pth in dir.glob(lib_pattern): - if pth.is_file(): + if pth.is_file() and not pth.is_symlink(): yield pth except (OSError, PermissionError): pass @@ -104,7 +109,7 @@ def find_cudart_libraries() -> Iterator[Path]: yield from find_cuda_libraries_in_path_list(value) -def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None: +def _print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None: print( f"PyTorch settings found: CUDA_VERSION={cuda_specs.cuda_version_string}, " f"Highest Compute Capability: {cuda_specs.highest_compute_capability}.", @@ -149,10 +154,40 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None: # (2) Multiple CUDA versions installed -def print_cuda_runtime_diagnostics() -> None: +def _print_hip_diagnostics(cuda_specs: CUDASpecs) -> None: + print(f"PyTorch settings found: ROCM_VERSION={cuda_specs.cuda_version_string}") + + binary_path = get_cuda_bnb_library_path(cuda_specs) + if not binary_path.exists(): + print_dedented( + f""" + Library not found: {binary_path}. + Maybe you need to compile it from source? If you compiled from source, check that ROCM_VERSION + in PyTorch Settings matches your ROCm install. If not, reinstall PyTorch for your ROCm version + and rebuild bitsandbytes. + """, + ) + + hip_major, hip_minor = cuda_specs.cuda_version_tuple + if (hip_major, hip_minor) < (6, 1): + print_dedented( + """ + WARNING: bitsandbytes is fully supported only from ROCm 6.1. + """, + ) + + +def print_diagnostics(cuda_specs: CUDASpecs) -> None: + if HIP_ENVIRONMENT: + _print_hip_diagnostics(cuda_specs) + else: + _print_cuda_diagnostics(cuda_specs) + + +def _print_cuda_runtime_diagnostics() -> None: cudart_paths = list(find_cudart_libraries()) if not cudart_paths: - print("CUDA SETUP: WARNING! CUDA runtime files not found in any environmental path.") + print("WARNING! CUDA runtime files not found in any environmental path.") elif len(cudart_paths) > 1: print_dedented( f""" @@ -174,3 +209,33 @@ def print_cuda_runtime_diagnostics() -> None: ) for pth in cudart_paths: print(f"* Found CUDA runtime at: {pth}") + + +def _print_hip_runtime_diagnostics() -> None: + cudart_paths = list(find_cudart_libraries()) + if not cudart_paths: + print("WARNING! ROCm runtime files not found in any environmental path.") + elif len(cudart_paths) > 1: + print_dedented( + f""" + Found duplicate ROCm runtime files (see below). + + We select the PyTorch default ROCm runtime, which is {torch.version.hip}, + but this might mismatch with the ROCm version that is needed for bitsandbytes. + + To resolve it, install PyTorch built for the ROCm version you want to use + + and set LD_LIBRARY_PATH to your ROCm install path, e.g. 
+ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm-6.1.2/lib, + """, + ) + + for pth in cudart_paths: + print(f"* Found ROCm runtime at: {pth}") + + +def print_runtime_diagnostics() -> None: + if HIP_ENVIRONMENT: + _print_hip_runtime_diagnostics() + else: + _print_cuda_runtime_diagnostics() diff --git a/bitsandbytes/diagnostics/main.py b/bitsandbytes/diagnostics/main.py index 1ce096f69..8dc43ed2a 100644 --- a/bitsandbytes/diagnostics/main.py +++ b/bitsandbytes/diagnostics/main.py @@ -3,11 +3,12 @@ import torch +from bitsandbytes.cextension import BNB_BACKEND, HIP_ENVIRONMENT from bitsandbytes.consts import PACKAGE_GITHUB_URL from bitsandbytes.cuda_specs import get_cuda_specs from bitsandbytes.diagnostics.cuda import ( - print_cuda_diagnostics, - print_cuda_runtime_diagnostics, + print_diagnostics, + print_runtime_diagnostics, ) from bitsandbytes.diagnostics.utils import print_dedented, print_header @@ -16,12 +17,13 @@ def sanity_check(): from bitsandbytes.cextension import lib if lib is None: + compute_backend = "cuda" if not HIP_ENVIRONMENT else "hip" print_dedented( - """ + f""" Couldn't load the bitsandbytes library, likely due to missing binaries. Please ensure bitsandbytes is properly installed. - For source installations, compile the binaries with `cmake -DCOMPUTE_BACKEND=cuda -S .`. + For source installations, compile the binaries with `cmake -DCOMPUTE_BACKEND={compute_backend} -S .`. See the documentation for more details if needed. Trying a simple check anyway, but this will likely fail... @@ -49,19 +51,24 @@ def main(): print_header("OTHER") cuda_specs = get_cuda_specs() - print("CUDA specs:", cuda_specs) + if HIP_ENVIRONMENT: + rocm_specs = f" rocm_version_string='{cuda_specs.cuda_version_string}'," + rocm_specs += f" rocm_version_tuple={cuda_specs.cuda_version_tuple}" + print(f"{BNB_BACKEND} specs:{rocm_specs}") + else: + print(f"{BNB_BACKEND} specs:{cuda_specs}") if not torch.cuda.is_available(): - print("Torch says CUDA is not available. Possible reasons:") - print("1. CUDA driver not installed") - print("2. CUDA not installed") - print("3. You have multiple conflicting CUDA libraries") + print(f"Torch says {BNB_BACKEND} is not available. Possible reasons:") + print(f"1. {BNB_BACKEND} driver not installed") + print(f"2. {BNB_BACKEND} not installed") + print(f"3. 
You have multiple conflicting {BNB_BACKEND} libraries") if cuda_specs: - print_cuda_diagnostics(cuda_specs) - print_cuda_runtime_diagnostics() + print_diagnostics(cuda_specs) + print_runtime_diagnostics() print_header("") print_header("DEBUG INFO END") print_header("") - print("Checking that the library is importable and CUDA is callable...") + print(f"Checking that the library is importable and {BNB_BACKEND} is callable...") try: sanity_check() print("SUCCESS!") diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index a5cc4a9f0..d1b3dd581 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -5,16 +5,17 @@ import ctypes as ct import itertools from math import prod -from typing import Any, Dict, Iterable, Optional, Tuple, Union +from typing import Iterable, Optional, Tuple, Union import numpy as np import torch from torch import Tensor from typing_extensions import deprecated -from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict +from bitsandbytes.backends import backends, ensure_backend_is_available +from bitsandbytes.utils import QuantState -from .cextension import lib +from .cextension import HIP_ENVIRONMENT, lib name2qmap = {} @@ -172,7 +173,11 @@ def __init__(self): raise RuntimeError("Call get_instance() instead") def initialize(self): - self.context = ct.c_void_p(lib.get_cusparse()) + # self.context = ct.c_void_p(lib.get_cusparse()) + if torch.version.cuda: + self.context = ct.c_void_p(lib.get_cusparse()) + elif torch.version.hip: + self.context = ct.c_void_p(lib.get_hipsparse()) @classmethod def get_instance(cls): @@ -539,13 +544,13 @@ def get_transform_buffer(shape, dtype, device, to_order, from_order="row", trans rows = shape[0] * shape[1] cols = shape[-1] - state = (shape, to_order) if transpose: # swap dims tmp = rows rows = cols cols = tmp - state = (shape[::-1], to_order) + shape = shape[::-1] + state = (shape, to_order) if to_order == "row" or to_order == "col": return init_func(shape, dtype=dtype, device=device), state @@ -580,12 +585,18 @@ def nvidia_transform( state=None, ld=None, ): + if HIP_ENVIRONMENT: + # col32/col_turing/col_ampere are not applicable to ROCm + # Use col format instead + to_order = "col" if to_order in ["col32", "col_turing", "col_ampere"] else to_order + from_order = "col" if from_order in ["col32", "col_turing", "col_ampere"] else from_order + if state is None: state = (A.shape, from_order) else: from_order = state[1] if out is None: - out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1]) + out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose) else: new_state = (state[1], to_order) func = get_transform_func(A.dtype, from_order, to_order, transpose) @@ -675,180 +686,8 @@ def estimate_quantiles( return out -class QuantState: - """container for quantization state components to work with Params4bit and similar classes""" - - valid_quant_types = ("fp4", "nf4") - valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types] - valid_qs_keys = [ - "absmax", - "quant_map", - "nested_absmax", - "nested_quant_map", - "quant_state", - "quant_type", - "blocksize", - "dtype", - "shape", - "nested_blocksize", - "nested_dtype", - "nested_offset", - ] - - def __init__( - self, - absmax, - shape=None, - code=None, - blocksize=None, - quant_type=None, - dtype=None, - offset=None, - state2=None, - ): - self.absmax = absmax - self.shape = shape - self.code = code - self.dtype = dtype - self.blocksize = blocksize - 
self.quant_type = quant_type - self.offset = offset - self.state2 = state2 - self.nested = state2 is not None - - def __get_item__(self, idx): - """ - ensures compatibility with older quant state scheme with nested lists. - assumes the following layout: - state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type] - state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type] - """ - if self.nested: - list_repr = [ - self.absmax, - self.shape, - self.dtype, - self.blocksize, - [self.offset, self.state2], - self.quant_type, - ] - else: - list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, None, self.quant_type] - return list_repr[idx] - - @classmethod - def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> "QuantState": - """ - unpacks components of state_dict into QuantState - where necessary, convert into strings, torch.dtype, ints, etc. - - qs_dict: based on state_dict, with only relevant keys, striped of prefixes. - - item with key `quant_state.bitsandbytes__[nf4/fp4]` may contain minor and non-tensor quant state items. - """ - - # unpacking tensor with non-tensor components - qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)] - if not len(qs_key) and "quant_type" not in qs_dict: - raise ValueError("Expected packed or unpacked quant_state items, found neither") - elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys: - raise ValueError( - f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.", - ) - - # unpacking minor and non-tensor quant state items if necessary - if len(qs_key) == 1: - first_qs_key = qs_key[0] - qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(first_qs_key))) - - qs_dict = {k.split(".")[-1]: v for k, v in qs_dict.items()} # strip prefixes - assert set(qs_dict.keys()).issubset(cls.valid_qs_keys) - - if "nested_absmax" in qs_dict: - offset = torch.tensor(float(qs_dict["nested_offset"])).to(device) - state2 = cls( - absmax=qs_dict["nested_absmax"].to(device), - blocksize=qs_dict["nested_blocksize"], - code=qs_dict["nested_quant_map"].to(device), - dtype=getattr(torch, qs_dict["nested_dtype"]), - ) - else: - offset, state2 = None, None - - quant_state = cls( - quant_type=qs_dict["quant_type"], - absmax=qs_dict["absmax"].to(device), - blocksize=qs_dict["blocksize"], - code=qs_dict["quant_map"].to(device), - dtype=getattr(torch, qs_dict["dtype"]), - shape=torch.Size(qs_dict["shape"]) if qs_dict["shape"] is not None else None, - offset=offset, - state2=state2, - ) - return quant_state - - def as_dict(self, packed=False): - """ - returns dict of tensors and strings to use in serialization via _save_to_state_dict() - param: packed -- returns dict[str, torch.Tensor] for state_dict fit for safetensors saving - """ - qs_dict = { - "quant_type": self.quant_type, - "absmax": self.absmax, - "blocksize": self.blocksize, - "quant_map": self.code, - "dtype": str(self.dtype).strip("torch."), - "shape": tuple(self.shape), - } - if self.nested: - qs_dict.update( - { - "nested_absmax": self.state2.absmax, - "nested_blocksize": self.state2.blocksize, - "nested_quant_map": self.state2.code.clone(), # un-shared to avoid restoring it after shared tensors are removed by safetensors - "nested_dtype": str(self.state2.dtype).strip("torch."), - "nested_offset": self.offset.item(), - }, - ) - if not packed: - return qs_dict - - # packed format allows serialization of non-tensor components, critical for saving in 
safetensors format - qs_packed_dict = {k: v for k, v in qs_dict.items() if isinstance(v, torch.Tensor)} - non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)} - qs_packed_dict["quant_state." + "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict) - return qs_packed_dict - - def to(self, device): - # make sure the quantization state is on the right device - self.absmax = self.absmax.to(device) - if self.nested: - self.offset = self.offset.to(device) - self.state2.absmax = self.state2.absmax.to(device) - self.state2.code = self.state2.code.to(device) - - def __eq__(self, other): - if not isinstance(other, QuantState): - return False - - return ( - torch.allclose(self.absmax, other.absmax, atol=1e-6) - and self.shape == other.shape - and torch.allclose(self.code, other.code, atol=1e-6) - and self.dtype == other.dtype - and self.blocksize == other.blocksize - and self.quant_type == other.quant_type - and ( - self.offset == other.offset - if self.offset is not None and other.offset is not None - else self.offset is other.offset - ) - and ( - self.state2 == other.state2 - if self.state2 is not None and other.state2 is not None - else self.state2 is other.state2 - ) - ) +# maintain the compatibility as F.QuantState +QuantState = QuantState def quantize_blockwise( @@ -900,7 +739,12 @@ def quantize_blockwise( out = torch.zeros_like(A, dtype=torch.uint8) if A.device.type != "cpu": - assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] + # Some AMD GPUs have warpsize 64 + # Set min blocksize to 128 (~warpsize 64 in kernel) for HIP + if not HIP_ENVIRONMENT: + assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] + else: + assert blocksize in [4096, 2048, 1024, 512, 256, 128] code = code.to(A.device) @@ -1015,11 +859,25 @@ def dequantize_blockwise( if out is None: out = torch.empty(A.shape, dtype=quant_state.dtype, device=A.device) - if A.device.type != "cpu": + if A.device.type == "xpu": + backends[A.device.type].dequantize_blockwise( + A=A, + quant_state=quant_state, + absmax=absmax, + code=quant_state.code, + out=out, + blocksize=blocksize, + nested=quant_state.nested,) + elif A.device.type != "cpu": code = quant_state.code.to(A.device) - if quant_state.blocksize not in [4096, 2048, 1024, 512, 256, 128, 64]: + supported_blocksizes = [2048, 4096, 1024, 512, 256, 128, 64] + # Some AMD GPUs have warpsize 64 + # Set min blocksize to 128 (~warpsize 64 in kernel) for HIP + if HIP_ENVIRONMENT: + supported_blocksizes = supported_blocksizes[:-1] + if quant_state.blocksize not in supported_blocksizes: raise ValueError( - f"The blocksize of {quant_state.blocksize} is not supported. Supported values: [4096, 2048, 1024, 512, 256, 128, 64]", + f"The blocksize of {quant_state.blocksize} is not supported. 
Supported values: {supported_blocksizes}", ) is_on_gpu([A, absmax, out]) @@ -1143,10 +1001,14 @@ def quantize_fp4( A: torch.Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize=64, + blocksize=None, compress_statistics=False, quant_storage=torch.uint8, ): + if blocksize is None: + # Some AMD GPUs have warpsize 64 + # Set default blocksize to 128 (~warpsize 64 in kernel) for HIP + blocksize = 64 if not HIP_ENVIRONMENT else 128 return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage) @@ -1154,10 +1016,14 @@ def quantize_nf4( A: torch.Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize=64, + blocksize=None, compress_statistics=False, quant_storage=torch.uint8, ): + if blocksize is None: + # Some AMD GPUs have warpsize 64 + # Set default blocksize to 128 (~warpsize 64 in kernel) for HIP + blocksize = 64 if not HIP_ENVIRONMENT else 128 return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage) @@ -1165,7 +1031,7 @@ def quantize_4bit( A: torch.Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize=64, + blocksize=None, compress_statistics=False, quant_type="fp4", quant_storage=torch.uint8, @@ -1193,83 +1059,16 @@ def quantize_4bit( - `torch.Tensor`: The quantized tensor with packed 4-bit values. - [`QuantState`]: The state object used to undo the quantization. """ - - if A.device.type != "cuda": - raise NotImplementedError(f"Device type not supported for FP4 quantization: {A.device.type}") - if quant_type not in ["fp4", "nf4"]: - raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.") - - n = A.numel() - input_shape = A.shape - - if absmax is None: - blocks = -(n // -blocksize) - absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) - - if out is None: - mod = dtype2bytes[quant_storage] * 2 - out = torch.zeros(((n + 1) // mod, 1), dtype=quant_storage, device=A.device) - - assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] - - is_on_gpu([A, out, absmax]) - - with _cuda_device_of(A): - args = ( - None, - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) - - if A.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cquantize_blockwise_bf16_fp4(*args) - else: - lib.cquantize_blockwise_bf16_nf4(*args) - elif A.dtype == torch.float16: - if quant_type == "fp4": - lib.cquantize_blockwise_fp16_fp4(*args) - else: - lib.cquantize_blockwise_fp16_nf4(*args) - elif A.dtype == torch.float32: - if quant_type == "fp4": - lib.cquantize_blockwise_fp32_fp4(*args) - else: - lib.cquantize_blockwise_fp32_nf4(*args) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - - code = get_4bit_type(quant_type, device=A.device) - - if compress_statistics: - offset = absmax.mean() - absmax -= offset - qabsmax, state2 = quantize_blockwise(absmax, blocksize=256) - del absmax - state = QuantState( - absmax=qabsmax, - shape=input_shape, - dtype=A.dtype, - blocksize=blocksize, - code=code, - quant_type=quant_type, - offset=offset, - state2=state2, - ) - else: - state = QuantState( - absmax=absmax, - shape=input_shape, - dtype=A.dtype, - blocksize=blocksize, - code=code, - quant_type=quant_type, - ) - - return out, state + ensure_backend_is_available(A.device.type) + return backends[A.device.type].quantize_4bit( + A, + absmax=absmax, + out=out, + blocksize=blocksize, + 
compress_statistics=compress_statistics, + quant_type=quant_type, + quant_storage=quant_storage, + ) def dequantize_fp4( @@ -1277,7 +1076,7 @@ def dequantize_fp4( quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize: int = 64, + blocksize: Optional[int] = None, ) -> torch.Tensor: return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4") @@ -1287,7 +1086,7 @@ def dequantize_nf4( quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize: int = 64, + blocksize: Optional[int] = None, ) -> torch.Tensor: return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4") @@ -1297,8 +1096,8 @@ def dequantize_4bit( quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize: int = 64, - quant_type="fp4", + blocksize: Optional[int] = None, + quant_type: Optional[str] = "fp4", ) -> torch.Tensor: """Dequantizes a packed 4-bit quantized tensor. @@ -1316,9 +1115,9 @@ def dequantize_4bit( Required if `quant_state` is not provided and ignored otherwise. out (`torch.Tensor`, *optional*): A tensor to use to store the result. blocksize (`int`, *optional*): - The size of the blocks. Defaults to 64. + The size of the blocks. Defaults to 64 if not HIP_ENVIRONMENT else 128. Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. - quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`. + quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to "fp4". Raises: ValueError: Raised when the input data type or blocksize is not supported. @@ -1326,74 +1125,18 @@ def dequantize_4bit( Returns: `torch.Tensor`: The dequantized tensor. """ - - if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: - raise ValueError( - f"The blockwise of {blocksize} is not supported. 
Supported values: [2048, 4096, 1024, 512, 256, 128, 64]", - ) - if quant_type not in ["fp4", "nf4"]: - raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.") - - if quant_state is None: - assert absmax is not None and out is not None - - quant_state = QuantState( - absmax=absmax, - shape=out.shape, - dtype=out.dtype, - blocksize=blocksize, - quant_type=quant_type, - ) - - else: + ensure_backend_is_available(A.device.type) + if quant_state is not None: absmax = quant_state.absmax - - if quant_state.nested: - absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) - absmax += quant_state.offset - if absmax.dtype != torch.float32: - absmax = absmax.float() - - if out is None: - out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device) - - n = out.numel() - - is_on_gpu([A, absmax, out]) - stream = _get_tensor_stream(A) - - with _cuda_device_of(A): - args = ( - None, - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(n), - stream, - ) - - if out.dtype == torch.bfloat16: - if quant_state.quant_type == "fp4": - lib.cdequantize_blockwise_bf16_fp4(*args) - else: - lib.cdequantize_blockwise_bf16_nf4(*args) - elif out.dtype == torch.float16: - if quant_state.quant_type == "fp4": - lib.cdequantize_blockwise_fp16_fp4(*args) - else: - lib.cdequantize_blockwise_fp16_nf4(*args) - elif out.dtype == torch.float32: - if quant_state.quant_type == "fp4": - lib.cdequantize_blockwise_fp32_fp4(*args) - else: - lib.cdequantize_blockwise_fp32_nf4(*args) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}") - - if A.shape[0] == 1: # is transposed, transpose back - return out.t() - return out + quant_type = quant_state.quant_type + blocksize = quant_state.blocksize + if blocksize is None: + # Some AMD GPUs have warpsize 64 + # Set default blocksize to 128 (~warpsize 64 in kernel) for HIP + blocksize = 64 if not HIP_ENVIRONMENT else 128 + return backends[A.device.type].dequantize_4bit( + A, quant_state=quant_state, absmax=absmax, out=out, blocksize=blocksize, quant_type=quant_type + ) @deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) @@ -1519,88 +1262,26 @@ def optimizer_update_32bit( max_unorm: float = 0.0, skip_zeros=False, ) -> None: - """ - Performs an inplace optimizer update with one or two optimizer states. - - Universal optimizer update for 32-bit state and 32/16-bit gradients/weights. - - Parameters - ---------- - optimizer_name : str - The name of the optimizer: {adam}. - g : torch.Tensor - Gradient tensor. - p : torch.Tensor - Parameter tensor. - state1 : torch.Tensor - Optimizer state 1. - beta1 : float - Optimizer beta1. - eps : float - Optimizer epsilon. - weight_decay : float - Weight decay. - step : int - Current optimizer step. - lr : float - The learning rate. - state2 : torch.Tensor - Optimizer state 2. - beta2 : float - Optimizer beta2. - beta3 : float - Optimizer beta3. - alpha : float - Optimizer alpha. - gnorm_scale : float - The factor to rescale the gradient to the max clip value. - unorm_vec : torch.Tensor - The tensor for the update norm. - max_unorm : float - The maximum update norm relative to the weight norm. - skip_zeros : bool - Whether to skip zero-valued gradients or not (default: False). 
- """ - - param_norm = 0.0 - if max_unorm > 0.0: - param_norm = torch.norm(p.data.float()) - - optim_func = None - if g.dtype == torch.float32: - optim_func = str2optimizer32bit[optimizer_name][0] - elif g.dtype == torch.float16: - optim_func = str2optimizer32bit[optimizer_name][1] - elif g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name]) == 3: - optim_func = str2optimizer32bit[optimizer_name][2] - else: - raise ValueError( - f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", - ) - - is_on_gpu([g, p, state1, state2, unorm_vec]) - - with _cuda_device_of(g): - optim_func( - get_ptr(g), - get_ptr(p), - get_ptr(state1), - get_ptr(state2), - get_ptr(unorm_vec), - ct.c_float(max_unorm), - ct.c_float(param_norm), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(beta3), - ct.c_float(alpha), - ct.c_float(eps), - ct.c_float(weight_decay), - ct.c_int32(step), - ct.c_float(lr), - ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), - ct.c_int32(g.numel()), - ) + ensure_backend_is_available(g.device.type) + return backends[g.device.type].optimizer_update_32bit( + optimizer_name=optimizer_name, + g=g, + p=p, + state1=state1, + beta1=beta1, + eps=eps, + step=step, + lr=lr, + state2=state2, + beta2=beta2, + beta3=beta3, + alpha=alpha, + weight_decay=weight_decay, + gnorm_scale=gnorm_scale, + unorm_vec=unorm_vec, + max_unorm=max_unorm, + skip_zeros=skip_zeros, + ) @deprecated( @@ -1762,47 +1443,28 @@ def optimizer_update_8bit_blockwise( gnorm_scale: float = 1.0, skip_zeros=False, ) -> None: - optim_func = None - - if g.dtype == torch.float32 and state1.dtype == torch.uint8: - optim_func = str2optimizer8bit_blockwise[optimizer_name][0] - elif g.dtype == torch.float16 and state1.dtype == torch.uint8: - optim_func = str2optimizer8bit_blockwise[optimizer_name][1] - elif ( - g.dtype == torch.bfloat16 - and state1.dtype == torch.uint8 - and len(str2optimizer8bit_blockwise[optimizer_name]) == 3 - ): - optim_func = str2optimizer8bit_blockwise[optimizer_name][2] - else: - raise ValueError( - f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", - ) - - is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2]) - - with _cuda_device_of(g): - optim_func( - get_ptr(p), - get_ptr(g), - get_ptr(state1), - get_ptr(state2), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(beta3), - ct.c_float(alpha), - ct.c_float(eps), - ct.c_int32(step), - ct.c_float(lr), - get_ptr(qmap1), - get_ptr(qmap2), - get_ptr(absmax1), - get_ptr(absmax2), - ct.c_float(weight_decay), - ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), - ct.c_int32(g.numel()), - ) + ensure_backend_is_available(g.device.type) + return backends[g.device.type].optimizer_update_8bit_blockwise( + optimizer_name=optimizer_name, + g=g, + p=p, + state1=state1, + state2=state2, + beta1=beta1, + beta2=beta2, + beta3=beta3, + alpha=alpha, + eps=eps, + step=step, + lr=lr, + qmap1=qmap1, + qmap2=qmap2, + absmax1=absmax1, + absmax2=absmax2, + weight_decay=weight_decay, + gnorm_scale=gnorm_scale, + skip_zeros=skip_zeros, + ) @deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) @@ -1958,100 +1620,15 @@ def gemv_4bit( transposed_B=False, state=None, ): - # sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) - if state is None: - raise ValueError("state cannot be None. 
gemv_4bit() requires the state from quantize_4bit()") - - if A.numel() != A.shape[-1]: - raise ValueError( - 'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]', - ) - - Bshape = state.shape - bout = Bshape[0] - absmax = state.absmax - if state.nested: - absmax = dequantize_blockwise(state.absmax, state.state2) - absmax += state.offset - - if out is None: - if len(A.shape) == 3: - out = torch.empty(size=(A.shape[0], A.shape[1], bout), dtype=A.dtype, device=A.device) - else: - out = torch.empty(size=(A.shape[0], bout), dtype=A.dtype, device=A.device) - - n = 1 - m = Bshape[0] - k = Bshape[1] - lda = Bshape[0] - ldc = Bshape[0] - ldb = (A.shape[-1] + 1) // 2 - is_on_gpu([B, A, out, absmax, state.code]) - m = ct.c_int32(m) - n = ct.c_int32(n) - k = ct.c_int32(k) - lda = ct.c_int32(lda) - ldb = ct.c_int32(ldb) - ldc = ct.c_int32(ldc) - stream = _get_tensor_stream(A) - - with _cuda_device_of(A): - if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]: - if A.dtype == torch.float16: - lib.cgemm_4bit_inference_naive_fp16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(state.code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(state.blocksize), - stream, - ) - elif A.dtype == torch.bfloat16: - lib.cgemm_4bit_inference_naive_bf16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(state.code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(state.blocksize), - stream, - ) - elif A.dtype == torch.float32: - lib.cgemm_4bit_inference_naive_fp32( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(state.code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(state.blocksize), - stream, - ) - else: - raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}") - - else: - raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}") - - return out + ensure_backend_is_available(A.device.type) + return backends[A.device.type].gemv_4bit( + A, + B, + out=out, + transposed_A=transposed_A, + transposed_B=transposed_B, + state=state, + ) def igemm( @@ -2272,7 +1849,9 @@ def igemmlt( return result, (result.shape, "row") -def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32): +def int8_linear_matmul( + A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32 +) -> torch.Tensor: """Performs an 8-bit integer matrix multiplication. A linear transformation is applied such that `out = A @ B.T`. When possible, integer tensor core hardware is @@ -2291,88 +1870,8 @@ def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Ten Returns: `torch.Tensor`: The result of the operation. """ - - # - # To use the IMMA tensor core kernels without special Turing/Ampere layouts, - # cublasLt has some rules, namely: A must be transposed, B must not be transposed. - # The C++ API will calculate `C = A.T @ B` in with A, B, C in col-major. - # This will typically be used with row-major tensors to efficiently - # calculate the linear layer with `C = B @ A.T` without any transformations. - # We will swap A and B in the API invocation, so that we get `C = A @ B.T`. - # - # Quick explanation: - # With row-major A and B tensors, `C = A.T.T @ B.T = A @ B.T`. - # To get row-major output, `C.T = (A @ B.T).T = B @ A.T`. 
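Aside: the operand swap described in the removed comment above (the logic now lives in the backend implementations) rests on a plain transpose identity. A self-contained check, independent of bitsandbytes and using assumed toy shapes:

    import torch

    A = torch.randn(8, 16)   # activations: rows x inputs
    B = torch.randn(32, 16)  # weights: outputs x inputs
    C = A @ B.T              # the row-major result we want
    assert torch.allclose(C.T, B @ A.T)  # so computing B @ A.T and reading it transposed yields C
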
- # - A, B = B, A - - shapeA = A.shape - shapeB = B.shape - - assert A.dtype == torch.int8 - assert B.dtype == torch.int8 - assert A.ndim == 2, "Only two dimensional matrices are supported for argument B" - assert B.ndim in [2, 3], "Only two or three dimensional matrices are supported for argument A" - assert prod(shapeB) > 0, f"Input tensor dimensions need to be > 0: {shapeB}" - assert out is None or out.dtype == dtype - - shapeC = (*shapeB[:-1], shapeA[0]) - - k, m = shapeA - n = prod(shapeB[:-1]) - lda = shapeA[-1] # Weights (outputs, inputs) - ldb = shapeB[-1] # Activations (batch, tokens, inputs) - ldc = shapeC[-1] # Output (batch, tokens, outputs) - - assert ( - lda == ldb - ), f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}" - - # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. - # We'll fall back to a slower fp32 calculation in this circumstance. - # Fortunately, this should not be very common. - if lda % 4 != 0: - result = torch.matmul(B.float(), A.float().t()).to(torch.int32) - if out is not None: - result = out.copy_(result) - return result - - if out is None: - out = torch.empty(shapeC, device=A.device, dtype=dtype) - - is_on_gpu([A, B, out]) - - with _cuda_device_of(A): - ctx = CUBLAS_Context.get_instance().get_context(A.device) - ptrA = get_ptr(A) - ptrB = get_ptr(B) - ptrC = get_ptr(out) - ptrRowScale = None - m = ct.c_int32(m) - n = ct.c_int32(n) - k = ct.c_int32(k) - lda = ct.c_int32(lda) - ldb = ct.c_int32(ldb) - ldc = ct.c_int32(ldc) - stream = _get_tensor_stream(A) - - if dtype == torch.int32: - has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) - else: - has_error = lib.cigemmlt_8(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) - - if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` - raise NotImplementedError("int8_linear_matmul not implemented!") - - if has_error: - raise RuntimeError( - f"cublasLt ran into an error!\n" - f"\t{shapeA=}, {shapeB=}, {shapeC=}\n" - f"\t{(lda, ldb, ldc)=}\n" - f"\t{(m, n, k)=}" - ) - - return out + ensure_backend_is_available(A.device.type) + return backends[A.device.type].int8_linear_matmul(A, B, out=out, dtype=dtype) def int8_mm_dequant( @@ -2381,7 +1880,7 @@ def int8_mm_dequant( col_stats: torch.Tensor, out: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, -): +) -> torch.Tensor: """Performs dequantization on the result of a quantized int8 matrix multiplication. Args: @@ -2394,31 +1893,14 @@ def int8_mm_dequant( Returns: `torch.Tensor`: The dequantized result with an optional bias, with dtype `torch.float16`. """ - - assert A.dtype == torch.int32 - - if bias is not None: - assert bias.dtype == torch.float16 - - if out is None: - out = torch.empty_like(A, dtype=torch.float16) - - ptrA = get_ptr(A) - ptrOut = get_ptr(out) - ptrRowStats = get_ptr(row_stats) - ptrColStats = get_ptr(col_stats) - ptrBias = get_ptr(bias) - numRows = ct.c_int32(prod(A.shape[:-1])) - numCols = ct.c_int32(A.shape[-1]) - - is_on_gpu([A, row_stats, col_stats, out, bias]) - - with _cuda_device_of(A): - lib.cdequant_mm_int32_fp16( - ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) - ) - - return out + ensure_backend_is_available(A.device.type) + return backends[A.device.type].int8_mm_dequant( + A, + row_stats, + col_stats, + out=out, + bias=bias, + ) @deprecated("mm_dequant is deprecated. 
Please use int8_mm_dequant() instead.", category=FutureWarning) @@ -2535,7 +2017,10 @@ def __init__( ): assert rowidx.dtype == torch.int32 assert colidx.dtype == torch.int32 - assert values.dtype == torch.float16 + if values.device == torch.device("cpu") or torch.device("xpu"): + assert values.dtype in [torch.bfloat16, torch.half, torch.float] + else: + assert values.dtype == torch.float16 assert values.numel() == nnz assert rowidx.numel() == nnz assert colidx.numel() == nnz @@ -2714,24 +2199,10 @@ def int8_double_quant( - `torch.Tensor` with dtype `torch.float32`: The column-wise quantization scales. - `torch.Tensor` with dtype `torch.int32`, *optional*: A list of column indices which contain outlier features. """ - - # TODO: Optimize/write CUDA kernel for this? - - # Use CUDA kernel for rowwise and COO tensor - quant_row, row_stats, outlier_cols = int8_vectorwise_quant(A, threshold=threshold) - - # PyTorch impl for colwise - _, col_stats, outlier_mask = get_colrow_absmax(A, threshold=threshold) - if threshold > 0.0 and outlier_mask is not None: - A = A.masked_fill(outlier_mask, 0.0) - quant_col = torch.round(A.mul(C) / col_stats.unsqueeze(0)).to(torch.int8) - - if out_row is not None: - quant_row = out_row.copy_(quant_row) - if out_col is not None: - quant_col = out_col.copy_(quant_col) - - return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols + ensure_backend_is_available(A.device.type) + return backends[A.device.type].int8_double_quant( + A, col_stats=col_stats, row_stats=row_stats, out_col=out_col, out_row=out_row, threshold=threshold + ) def int8_vectorwise_dequant(A: torch.Tensor, stats: torch.Tensor): @@ -2766,42 +2237,8 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0): - `torch.Tensor` with dtype `torch.float32`: The quantization scales. - `torch.Tensor` with dtype `torch.int32`, *optional*: A list of column indices which contain outlier features. """ - - assert A.dtype == torch.half - is_on_gpu([A]) - - rows = prod(A.shape[:-1]) - cols = A.shape[-1] - - row_stats = torch.empty(rows, device=A.device, dtype=torch.float32) - out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) - - outlier_cols = None - - if threshold > 0.0: - # TODO we could improve perf of this - outliers = A.abs() >= threshold - - if outliers.any(): - outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1) - - with _cuda_device_of(A): - lib.cint8_vector_quant( - get_ptr(A), - get_ptr(out_row), - get_ptr(row_stats), - ct.c_float(threshold), - ct.c_int32(rows), - ct.c_int32(cols), - _get_tensor_stream(A), - ) - - # Zero out values from outlier columns across all rows. - # The kernel will handle this for outliers themselves, so we can optimize for rows=1. 
- if rows > 1 and outlier_cols is not None: - out_row[:, outlier_cols] = 0 - - return out_row, row_stats, outlier_cols + ensure_backend_is_available(A.device.type) + return backends[A.device.type].int8_vectorwise_quant(A, threshold) @deprecated( @@ -2809,51 +2246,10 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0): category=FutureWarning, ) def transform(A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None): - prev_device = pre_call(A.device) - if state is None: - state = (A.shape, from_order) - else: - from_order = state[1] - if out is None: - out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose) - else: - new_state = (state[0], to_order) # (shape, order) - - shape = state[0] - if len(shape) == 2: - dim1 = ct.c_int32(shape[0]) - dim2 = ct.c_int32(shape[1]) - else: - dim1 = ct.c_int32(shape[0] * shape[1]) - dim2 = ct.c_int32(shape[2]) - - is_on_gpu([A, out]) - if to_order == "col32": - if transpose: - lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2) - else: - lib.ctransform_row2col32(get_ptr(A), get_ptr(out), dim1, dim2) - elif to_order == "col_turing": - if transpose: - lib.ctransform_row2turingT(get_ptr(A), get_ptr(out), dim1, dim2) - else: - lib.ctransform_row2turing(get_ptr(A), get_ptr(out), dim1, dim2) - elif to_order == "col_ampere": - if transpose: - lib.ctransform_row2ampereT(get_ptr(A), get_ptr(out), dim1, dim2) - else: - lib.ctransform_row2ampere(get_ptr(A), get_ptr(out), dim1, dim2) - elif to_order == "row": - if from_order == "col_turing": - lib.ctransform_turing2row(get_ptr(A), get_ptr(out), dim1, dim2) - elif from_order == "col_ampere": - lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2) - else: - raise NotImplementedError(f"Transform function not implemented: From {from_order} to {to_order}") - - post_call(prev_device) - - return out, new_state + ensure_backend_is_available(A.device.type) + return backends[A.device.type].transform( + A, to_order, from_order=from_order, out=out, transpose=transpose, state=state, ld=ld + ) def spmm_coo( @@ -3020,7 +2416,7 @@ def vectorwise_quant(x, dim=1, quant_type="vector"): return xq, max1 elif quant_type in ["vector", "row"]: max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) - xq = torch.round(x * (C / max1)).to(torch.int8) + xq = torch.clamp(torch.round(x * (C / max1)), -128, 127).to(torch.int8) return xq, max1 elif quant_type == "zeropoint": dtype = x.dtype @@ -3149,28 +2545,8 @@ def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half): @deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) def extract_outliers(A, SA, idx): - shapeA = SA[0] - formatA = SA[1] - assert formatA in ["col_turing", "col_ampere"] - assert A.device.type == "cuda" - - out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device) - - idx_size = ct.c_int32(idx.numel()) - rows = ct.c_int32(shapeA[0]) - cols = ct.c_int32(shapeA[1]) - ptrA = get_ptr(A) - ptrIdx = get_ptr(idx) - ptrOut = get_ptr(out) - - prev_device = pre_call(A.device) - if formatA == "col_turing": - lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) - elif formatA == "col_ampere": - lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) - post_call(prev_device) - - return out + ensure_backend_is_available(A.device.type) + return backends[A.device.type].extract_outliers(A, SA, idx) @deprecated("This function is deprecated and will be removed in a future release.", 
category=FutureWarning) diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py index 20aff67a3..e27c88c8a 100644 --- a/bitsandbytes/nn/__init__.py +++ b/bitsandbytes/nn/__init__.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from ..backends import backends from .modules import ( Embedding, Embedding4bit, @@ -18,9 +19,12 @@ StableEmbedding, SwitchBackLinearBnb, ) -from .triton_based_modules import ( - StandardLinear, - SwitchBackLinear, - SwitchBackLinearGlobal, - SwitchBackLinearVectorwise, -) + +# CPU and XPU backend do not need triton, and XPU so not support triton for now. +if "xpu" not in backends.keys() and len(backends.keys()) > 1: + from .triton_based_modules import ( + StandardLinear, + SwitchBackLinear, + SwitchBackLinearGlobal, + SwitchBackLinearVectorwise, + ) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py old mode 100644 new mode 100755 index e63cd8db9..cdeaebc27 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -12,11 +12,14 @@ import bitsandbytes as bnb from bitsandbytes.autograd._functions import get_tile_inds, undo_layout +from bitsandbytes.cextension import HIP_ENVIRONMENT from bitsandbytes.functional import QuantState from bitsandbytes.optim import GlobalOptimManager from bitsandbytes.utils import ( INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, OutlierTracer, + enable_ipex_fusion, + reverse_4bit_compress_format, ) T = TypeVar("T", bound="torch.nn.Module") @@ -213,7 +216,7 @@ def __new__( data: Optional[torch.Tensor] = None, requires_grad=False, # quantized weights should be frozen by default quant_state: Optional[QuantState] = None, - blocksize: int = 64, + blocksize: Optional[int] = None, compress_statistics: bool = True, quant_type: str = "fp4", quant_storage: torch.dtype = torch.uint8, @@ -223,6 +226,9 @@ def __new__( if data is None: data = torch.empty(0) + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 + self = torch.Tensor._make_subclass(cls, data, requires_grad) self.blocksize = blocksize self.compress_statistics = compress_statistics @@ -310,6 +316,18 @@ def _quantize(self, device): def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False): return self.to(device="cuda" if device is None else device, non_blocking=non_blocking) + def cpu(self, non_blocking: bool = False): + return self.to(device="cpu", non_blocking=non_blocking) + + def npu(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False): + # `torch.Tensor.to()` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)). + if isinstance(device, int): + device = f"npu:{device}" + return self.to(device="npu" if device is None else device, non_blocking=non_blocking) + + def xpu(self, non_blocking: bool = False): + return self.to(device="xpu", non_blocking=non_blocking) + @overload def to( self: T, @@ -327,7 +345,7 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ... 
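Aside: a minimal usage sketch for the device-movement paths this hunk extends. Names, shapes, and the device string are assumptions; "cpu", "xpu", and "npu" are routed the same way by the updated to() below.

    import torch
    import bitsandbytes as bnb

    fp16 = torch.nn.Linear(1024, 4096, bias=False, dtype=torch.float16)
    q = bnb.nn.Linear4bit(1024, 4096, bias=False, compute_dtype=torch.float16, quant_type="nf4")
    q.weight = bnb.nn.Params4bit(data=fp16.weight.data, quant_type="nf4", requires_grad=False)
    q = q.to("cuda")  # first move to a supported device quantizes the weight; under HIP the default blocksize is 128
    y = q(torch.randn(2, 1024, dtype=torch.float16, device="cuda"))
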
def to(self, *args, **kwargs): device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) - if device is not None and device.type == "cuda" and not self.bnb_quantized: + if device is not None and device.type in ["cuda", "cpu", "npu", "xpu", "hpu"] and not self.bnb_quantized: return self._quantize(device) else: if self.quant_state is not None: @@ -430,6 +448,7 @@ def __init__( # self.persistent_buffers = [] # TODO consider as way to save quant state self.compute_dtype = compute_dtype self.compute_type_is_set = False + self.ipex_linear_is_set = False self.quant_state = None self.quant_storage = quant_storage @@ -458,13 +477,40 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): save weight and bias, then fill state_dict with components of quant_state """ + if getattr(self.weight, "quant_state", None) is not None and getattr(self.weight.quant_state, "ipex", False): + if self.weight.device.type == "cpu": + original_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight( + self.weight, "nf4", self.weight.quant_state.shape, 2 + ) + self.weight.data = reverse_4bit_compress_format(original_weight.data) + elif self.weight.device.type == "xpu": + self.weight.data = reverse_4bit_compress_format(self.weight.data.reshape(1, -1)) + + self.weight.quant_state.ipex = False + self.ipex_linear_is_set = False + super()._save_to_state_dict(destination, prefix, keep_vars) # saving weight and bias if getattr(self.weight, "quant_state", None) is not None: for k, v in self.weight.quant_state.as_dict(packed=True).items(): destination[prefix + "weight." + k] = v if keep_vars else v.detach() + def set_ipex_linear(self, x: torch.Tensor): + if ( + not getattr(self.weight.quant_state, "ipex", False) + and self.weight.data.dtype == torch.uint8 + and self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0 + and self.weight.quant_state.quant_type == "nf4" + ): + if x.device.type == "xpu" or (x.device.type == "cpu" and not self.training and x.requires_grad == False): + enable_ipex_fusion(self, x) + def forward(self, x: torch.Tensor): + # Check if ipex fusion can be used + if not self.ipex_linear_is_set: + self.set_ipex_linear(x) + self.ipex_linear_is_set = True + fix_4bit_weight_quant_state_from_module(self) # weights are cast automatically as Int8Params, but the bias has to be cast manually @@ -480,8 +526,8 @@ def forward(self, x: torch.Tensor): x = x.to(self.compute_dtype) bias = None if self.bias is None else self.bias.to(self.compute_dtype) - - return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype) + weight = self.weight.t() if len(self.weight.shape) == 2 else self.weight + return bnb.matmul_4bit(x, weight, bias=bias, quant_state=self.weight.quant_state).to(inp_dtype) class LinearFP4(Linear4bit): @@ -605,6 +651,32 @@ def __deepcopy__(self, memo): ) return new_instance + def cpu(self): + # we store the 8-bit rows-major weight + B = self.data.contiguous().to(torch.bfloat16).cpu() + CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B) + if CBt is not None: + del CBt + if SCBt is not None: + del SCBt + self.data = CB + self.CB = CB + self.SCB = SCB + return self + + def xpu(self, device): + # we store the 8-bit rows-major weight + B = self.data.contiguous().to(torch.float16).xpu(device) + CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B) + if CBt is not None: + del CBt + if SCBt is not None: + del SCBt + self.data = CB + self.CB = CB + self.SCB = SCB + return self + @overload def to( 
self: T, @@ -622,18 +694,30 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ... def to(self, *args, **kwargs): device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) - if device is not None and device.type == "cuda" and self.data.device.type == "cpu": - return self.cuda(device) - else: - new_param = Int8Params( - super().to(device=device, dtype=dtype, non_blocking=non_blocking), - requires_grad=self.requires_grad, - has_fp16_weights=self.has_fp16_weights, - ) - new_param.CB = self.CB - new_param.SCB = self.SCB + if device is not None: + if device.type == "cuda" and self.data.device.type == "cpu": + return self.cuda(device) + elif device.type == "cpu": + if self.data.dtype == torch.int8: + self.CB = self.data + else: + return self.cpu() + elif device.type == "xpu": + if self.data.dtype == torch.int8: + self.data = self.data.contiguous() + self.CB = self.data + if self.data.device.type == "cpu": + return self.xpu(device) + + new_param = Int8Params( + super().to(device=device, dtype=dtype, non_blocking=non_blocking), + requires_grad=self.requires_grad, + has_fp16_weights=self.has_fp16_weights, + ) + new_param.CB = self.CB + new_param.SCB = self.SCB - return new_param + return new_param def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): diff --git a/bitsandbytes/npu_specs.py b/bitsandbytes/npu_specs.py new file mode 100644 index 000000000..7c7cd707e --- /dev/null +++ b/bitsandbytes/npu_specs.py @@ -0,0 +1,20 @@ +import dataclasses + +import torch + +try: + import torch_npu # noqa: F401 +except ImportError: + pass + + +@dataclasses.dataclass(frozen=True) +class NPUSpecs: + cann_version_string: str + + +def get_npu_specs(): + if hasattr(torch, "npu") and torch.npu.is_available(): + return NPUSpecs(cann_version_string=torch.version.cann) + else: + return None diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index 03e0e01d7..0a78b4ade 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -10,6 +10,7 @@ import torch import bitsandbytes.functional as F +from bitsandbytes.backends import backends class MockArgs: @@ -289,11 +290,11 @@ def step(self, closure=None): self.prefetch_state(p) self.update_step(group, p, gindex, pindex) - torch.cuda.synchronize() + backends[p.device.type].device_synchronize() if self.is_paged: # all paged operation are asynchronous, we need # to sync to make sure all tensors are in the right state - torch.cuda.synchronize() + backends[p.device.type].device_synchronize() return loss diff --git a/bitsandbytes/research/autograd/_functions.py b/bitsandbytes/research/autograd/_functions.py index d9718382b..e0af9719e 100644 --- a/bitsandbytes/research/autograd/_functions.py +++ b/bitsandbytes/research/autograd/_functions.py @@ -6,6 +6,7 @@ import torch from bitsandbytes.autograd._functions import GlobalOutlierPooler, MatmulLtState +from bitsandbytes.cextension import HIP_ENVIRONMENT import bitsandbytes.functional as F @@ -340,7 +341,10 @@ def backward(ctx, grad_output): def get_block_sizes(input_matrix, weight_matrix): input_features = input_matrix.shape[-1] output_features = weight_matrix.shape[0] if weight_matrix.shape[1] == input_features else weight_matrix.shape[1] - array = [4096, 2048, 1024, 512, 256, 128, 64, 0] + if not HIP_ENVIRONMENT: + array = [4096, 2048, 1024, 512, 256, 128, 64, 0] + else: + array = [4096, 2048, 1024, 512, 256, 128, 0] bsz, bsz2 = 1024, 1024 for i, k in enumerate(array): if 
input_features > array[i + 1]: diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index a88ddf5f9..e3748685e 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -1,7 +1,7 @@ import json import shlex import subprocess -from typing import Tuple +from typing import Any, Dict, Tuple import torch @@ -200,5 +200,236 @@ def unpack_tensor_to_dict(tensor_data): return unpacked_dict +def reverse_4bit_compress_format(weight): + # swap the high and low 4-bit nibbles of each packed byte + out_1 = (weight & 0xF0) >> 4 + out_2 = (weight & 0xF) << 4 + out = out_1 | out_2 + return out + + +def enable_ipex_fusion(linear, x): + from bitsandbytes.backends.cpu_xpu_common import ( + _ipex_cpu_version_prereq, + _ipex_xpu_version_prereq, + dequant_8bit, + ipex_cpu, + ipex_xpu, + ) + + quant_state = linear.weight.quant_state + + if quant_state.nested: + quant_state.absmax = dequant_8bit(quant_state.absmax, quant_state.offset, quant_state.state2) + quant_state.nested = False + delattr(quant_state, "state2") + + if x.device.type == "cpu" and ipex_cpu and _ipex_cpu_version_prereq(2, 5): + converted_weight = reverse_4bit_compress_format(linear.weight.data) + new_weight, new_scales, new_zeros, _, compensation = torch.ops.ipex_prepack.woq_linear_pack_weight( + converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2]), + "nf4", + quant_state.shape, # weight shape + quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize), # scales + None, # zero_points + None, # bias + None, # batch_size + quant_state.blocksize, + 2, + ) + elif x.device.type == "xpu" and ipex_xpu and _ipex_xpu_version_prereq(2, 5): + converted_weight = reverse_4bit_compress_format(linear.weight.data) + new_weight = converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2]) + new_scales = quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize) + new_zeros = None + compensation = None + else: + raise ValueError( + "Please check the device and ipex version. The device should be cpu or xpu and the ipex version should be >= 2.5" + ) + + linear.weight.data = new_weight.data + linear.weight.quant_state.ipex = True + linear.weight.quant_state.new_scales = new_scales + linear.weight.quant_state.new_zeros = new_zeros + linear.weight.quant_state.compensation = compensation + + +class QuantState: + """container for quantization state components to work with Params4bit and similar classes""" + + valid_quant_types = ("fp4", "nf4") + valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types] + valid_qs_keys = [ + "absmax", + "quant_map", + "nested_absmax", + "nested_quant_map", + "quant_state", + "quant_type", + "blocksize", + "dtype", + "shape", + "nested_blocksize", + "nested_dtype", + "nested_offset", + ] + + def __init__( + self, + absmax, + shape=None, + code=None, + blocksize=None, + quant_type=None, + dtype=None, + offset=None, + state2=None, + ): + self.absmax = absmax + self.shape = shape + self.code = code + self.dtype = dtype + self.blocksize = blocksize + self.quant_type = quant_type + self.offset = offset + self.state2 = state2 + self.nested = state2 is not None + + def __get_item__(self, idx): + """ + ensures compatibility with older quant state scheme with nested lists.
+ assumes the following layout: + state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type] + state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type] + """ + if self.nested: + list_repr = [ + self.absmax, + self.shape, + self.dtype, + self.blocksize, + [self.offset, self.state2], + self.quant_type, + ] + else: + list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, None, self.quant_type] + return list_repr[idx] + + @classmethod + def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> "QuantState": + """ + unpacks components of state_dict into QuantState + where necessary, converts into strings, torch.dtype, ints, etc. + + qs_dict: based on state_dict, with only relevant keys, stripped of prefixes. + + item with key `quant_state.bitsandbytes__[nf4/fp4]` may contain minor and non-tensor quant state items. + """ + + # unpacking tensor with non-tensor components + qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)] + if not len(qs_key) and "quant_type" not in qs_dict: + raise ValueError("Expected packed or unpacked quant_state items, found neither") + elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys: + raise ValueError( + f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.", + ) + + # unpacking minor and non-tensor quant state items if necessary + if len(qs_key) == 1: + first_qs_key = qs_key[0] + qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(first_qs_key))) + + qs_dict = {k.split(".")[-1]: v for k, v in qs_dict.items()} # strip prefixes + assert set(qs_dict.keys()).issubset(cls.valid_qs_keys) + + if "nested_absmax" in qs_dict: + offset = torch.tensor(float(qs_dict["nested_offset"])).to(device) + state2 = cls( + absmax=qs_dict["nested_absmax"].to(device), + blocksize=qs_dict["nested_blocksize"], + code=qs_dict["nested_quant_map"].to(device), + dtype=getattr(torch, qs_dict["nested_dtype"]), + ) + else: + offset, state2 = None, None + + quant_state = cls( + quant_type=qs_dict["quant_type"], + absmax=qs_dict["absmax"].to(device), + blocksize=qs_dict["blocksize"], + code=qs_dict["quant_map"].to(device), + dtype=getattr(torch, qs_dict["dtype"]), + shape=torch.Size(qs_dict["shape"]) if qs_dict["shape"] is not None else None, + offset=offset, + state2=state2, + ) + return quant_state + + def as_dict(self, packed=False): + """ + returns dict of tensors and strings to use in serialization via _save_to_state_dict() + param: packed -- returns dict[str, torch.Tensor] for state_dict fit for safetensors saving + """ + qs_dict = { + "quant_type": self.quant_type, + "absmax": self.absmax, + "blocksize": self.blocksize, + "quant_map": self.code, + "dtype": str(self.dtype).strip("torch."), + "shape": tuple(self.shape), + } + if self.nested: + qs_dict.update( + { + "nested_absmax": self.state2.absmax, + "nested_blocksize": self.state2.blocksize, + "nested_quant_map": self.state2.code.clone(), # un-shared to avoid restoring it after shared tensors are removed by safetensors + "nested_dtype": str(self.state2.dtype).strip("torch."), + "nested_offset": self.offset.item(), + }, + ) + if not packed: + return qs_dict + + # packed format allows serialization of non-tensor components, critical for saving in safetensors format + qs_packed_dict = {k: v for k, v in qs_dict.items() if isinstance(v, torch.Tensor)} + non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)} + qs_packed_dict["quant_state."
+ "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict) + return qs_packed_dict + + def to(self, device): + # make sure the quantization state is on the right device + self.absmax = self.absmax.to(device) + if self.nested: + self.offset = self.offset.to(device) + self.state2.absmax = self.state2.absmax.to(device) + self.state2.code = self.state2.code.to(device) + + def __eq__(self, other): + if not isinstance(other, QuantState): + return False + + return ( + torch.allclose(self.absmax, other.absmax, atol=1e-6) + and self.shape == other.shape + and torch.allclose(self.code, other.code, atol=1e-6) + and self.dtype == other.dtype + and self.blocksize == other.blocksize + and self.quant_type == other.quant_type + and ( + self.offset == other.offset + if self.offset is not None and other.offset is not None + else self.offset is other.offset + ) + and ( + self.state2 == other.state2 + if self.state2 is not None and other.state2 is not None + else self.state2 is other.state2 + ) + ) + + LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {"row": 0, "col32": 1, "col_turing": 2, "col_ampere": 3} INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {val: name for (name, val) in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING.items()} diff --git a/csrc/kernels.hip b/csrc/kernels.hip new file mode 100644 index 000000000..b804d7ccc --- /dev/null +++ b/csrc/kernels.hip @@ -0,0 +1,4101 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include + +//#include + + +#define HLF_MAX 65504 +#define TH 1024 +#define NUM 4 +#define NUM_BLOCK 4096 + + +// source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda +// Luckily we have atomicmax and atomicmin in ROCm + +__device__ float dDequantizeFP4(unsigned char val, float absmax) +{ + float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f; + if((val & 0b0110) == 0) + { + // subnormal + if((val & 0b0001) == 0) + return 0.0f; + else + return sign*0.0625f*absmax; + } + else + { + // normal + float exponent = ((val & 0b0100) == 4 ? 2.0f : 8.0f) + ((val & 0b0010) == 2 ? 0.0f : 2.0f); + float fraction = (val & 0b0001) == 1 ? 1.5f : 1.0f; + + return sign*exponent*fraction*absmax; + } +} + +__device__ float d2DequantizeFP4(unsigned char val) +{ + float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f; + if((val & 0b0110) == 0) + { + // subnormal + if((val & 0b0001) == 0) + return 0.0f; + else + return sign*0.0625f; + } + else + { + // normal + float exponent = ((val & 0b0100) == 4 ? 2.0f : 8.0f) + ((val & 0b0010) == 2 ? 0.0f : 2.0f); + float fraction = (val & 0b0001) == 1 ? 1.5f : 1.0f; + + return sign*exponent*fraction; + } +} + +__device__ float dDequantizeFP4Tree(unsigned char val, float absmax) +{ + float sign = (val & 0b1000) == 8 ? 
-1.0f : 1.0f; + if((val & 0b0100) == 4) // 0 + if((val & 0b0010) == 2) //01 + if((val & 0b0001) == 1) // 111 + return 0.25000000f*absmax*sign; // 1111 + else + return 0.16666667f*absmax*sign; // 1110 + else + if((val & 0b0001) == 1) // 110 + return 0.50000000f*absmax*sign; // 1101 + else + return 0.33333333f*absmax*sign; // 1100 + else + if((val & 0b0010) == 2) //10 + if((val & 0b0001) == 1) // 101 + return 1.00000000f*absmax*sign; // 1011 + else + return 0.66666667f*absmax*sign; // 1010 + else + if((val & 0b0001) == 1) // 100 + return 5.208333333e-03f*absmax*sign; // 1001 + else + return 0.00000000f*absmax*sign; // 1000 +} + +__device__ unsigned char dQuantizeFP4(float x) +{ + // FP4 with bias of 3 + // first bit is a sign + // subnormals + // 0b000 = 0 + // 0b001 = 0.0625 + // 0b110 = 2 + // 0b111 = 3 + // 0b100 = 4 + // 0b101 = 6 + // 0b010 = 8 + // 0b011 = 12 + + + // we do a binary search + // the pivots are divided by 12 (the FP4 absmax) + // since we assume input data is in [-1.0, 1.0] + + // !be careful here, its easy to make a mistake + // that is difficult to notice if you add an extra + // zero somewhere! + + int sign = x < 0 ? 0b1000 : 0b0000; + x = fabsf(x); + if(x > 0.29166667f) + if( x > 0.583333f) + if( x > 0.8333333f) + return 0b0011+sign; + else + return 0b0010+sign; + else + if(x > 0.4166667f) + return 0b101+sign; + else + return 0b100+sign; + else + if(x > 0.0859375f) + if(x > 0.20833333f) + return 0b0111+sign; + else + return 0b0110+sign; + else + if(x > 0.00260417f) + return 0b0001+sign; + else + return 0b0000+sign; +} + +__device__ half dhDequantizeNF4(unsigned char val) +{ + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if((val & 0b1000) == 8) + if((val & 0b0100) == 4) // 1 + if((val & 0b0010) == 2) // 11 + if((val & 0b0001) == 1) // 111 + return 1.0f; + else + return 0.7229568362236023f; + else + if((val & 0b0001) == 1) // 110 + return 0.5626170039176941f; + else + return 0.44070982933044434f; + else + if((val & 0b0010) == 2) //10 + if((val & 0b0001) == 1) // 101 + return 0.33791524171829224f; + else + return 0.24611230194568634f; + else + if((val & 0b0001) == 1) // 100 + return 0.16093020141124725f; + else + return 0.07958029955625534f; + + else + if((val & 0b0100) == 4) // 0 + if((val & 0b0010) == 2) //01 + if((val & 0b0001) == 1) // 011 + return 0.0f; + else + return -0.09105003625154495f; + else + if((val & 0b0001) == 1) // 010 + return -0.18477343022823334f; + else + return -0.28444138169288635f; + else + if((val & 0b0010) == 2) //00 + if((val & 0b0001) == 1) // 001 + return -0.39491748809814453f; + else + return -0.5250730514526367f; + else + if((val & 0b0001) == 1) // 000 + return -0.6961928009986877f; + else + return -1.0f; + +} + +__device__ float dDequantizeNF4(unsigned char val) +{ + + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if((val & 0b1000) == 8) + if((val & 0b0100) == 4) // 1 + if((val & 0b0010) == 2) // 11 + if((val & 0b0001) == 1) // 111 + return 1.0f; + else + return 0.7229568362236023f; + else + if((val & 0b0001) == 1) // 110 + return 0.5626170039176941f; + else + return 0.44070982933044434f; + else + if((val & 0b0010) == 2) //10 + if((val & 0b0001) == 1) // 101 + return 0.33791524171829224f; + else + return 0.24611230194568634f; + else + if((val & 0b0001) == 1) // 100 + return 0.16093020141124725f; + else + return 0.07958029955625534f; + + else + if((val & 0b0100) == 4) // 0 + if((val & 0b0010) == 2) //01 + if((val 
& 0b0001) == 1) // 011 + return 0.0f; + else + return -0.09105003625154495f; + else + if((val & 0b0001) == 1) // 010 + return -0.18477343022823334f; + else + return -0.28444138169288635f; + else + if((val & 0b0010) == 2) //00 + if((val & 0b0001) == 1) // 001 + return -0.39491748809814453f; + else + return -0.5250730514526367f; + else + if((val & 0b0001) == 1) // 000 + return -0.6961928009986877f; + else + return -1.0f; + +} + +__device__ unsigned char dQuantizeNF4(float x) +{ + + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if(x > 0.03979014977812767f) + if(x > 0.3893125355243683f) // 1 + if(x > 0.6427869200706482f) // 11 + if(x > 0.8614784181118011f) // 111 + return 0b1111; + else + return 0b1110; + else + if(x > 0.5016634166240692f) // 110 + return 0b1101; + else + return 0b1100; + else + if(x > 0.2035212516784668f) // 10 + if(x > 0.2920137718319893f) // 101 + return 0b1011; + else + return 0b1010; + else + if(x > 0.1202552504837513f) // 100 + return 0b1001; + else + return 0b1000; + else + if(x > -0.33967943489551544f) // 0 + if(x > -0.13791173323988914f) // 01 + if(x > -0.045525018125772476f) // 011 + return 0b0111; + else + return 0b0110; + else + if(x > -0.23460740596055984f) // 010 + return 0b0101; + else + return 0b0100; + else + if(x > -0.6106329262256622f) // 00 + if(x > -0.4599952697753906f) // 001 + return 0b0011; + else + return 0b0010; + else + if(x > -0.8480964004993439f) // 000 + return 0b0001; + else + return 0b0000; +} +// sign function for lion +// taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA + +template __device__ int sgn(T val) +{ + return (T(0) < val) - (val < T(0)); +} + +template +__device__ unsigned char dQuantize(float* smem_code, const float rand, float x) +{ + int pivot = 127; + int upper_pivot = 255; + int lower_pivot = 0; + + float lower = -1.0f; + float upper = 1.0f; + + float val = smem_code[pivot]; + // i>>=1 = {32, 16, 8, 4, 2, 1} + for(int i = 64; i > 0; i>>=1) + { + if(x > val) + { + lower_pivot = pivot; + lower = val; + pivot+=i; + } + else + { + upper_pivot = pivot; + upper = val; + pivot-=i; + } + val = smem_code[pivot]; + } + + if(upper_pivot == 255) + upper = smem_code[upper_pivot]; + if(lower_pivot == 0) + lower = smem_code[lower_pivot]; + + if(!STOCHASTIC) + { + if(x > val) + { + float midpoint = (upper+val)*0.5f; + if(x > midpoint) + { + return upper_pivot; + } + else + return pivot; + } + else + { + float midpoint = (lower+val)*0.5f; + if(x < midpoint) + return lower_pivot; + else + return pivot; + } + } + else + { + if(x > val) + { + float dist_to_upper = fabsf(upper-x); + float dist_full = upper-val; + if(rand >= dist_to_upper/dist_full) return upper_pivot; + else return pivot; + } + else + { + float dist_to_lower = fabsf(lower-x); + float dist_full = val-lower; + if(rand >= dist_to_lower/dist_full) return lower_pivot; + else return pivot; + } + } +} + +template +__device__ __forceinline__ unsigned char quantize_2D(float *__restrict__ quadrants, float *__restrict__ const smem_code, float x) +{ + int pivot = 127; + int upper_pivot = 255; + int lower_pivot = 0; + + float lower = SIGNED ? -1.0f : 0.0f; + float upper = 1.0f; + float midpoint; + float val = quadrants[1]; + int local_pivot = 1; + int offset = 1; + + // i>>=1 = {32, 16, 8, 4, 2, 1} + for(int i = 64; i > 0; i>>=1) + { + if(x > val) + { + lower_pivot = pivot; + lower = val; + pivot+=i; + //val = i == 64 ? 
quadrants[2] : smem_code[pivot]; + local_pivot += offset; + } + else + { + upper_pivot = pivot; + upper = val; + pivot-=i; + //val = i == 64 ? quadrants[0] : smem_code[pivot]; + local_pivot -= offset; + } + val = i >= 64 ? quadrants[local_pivot] : smem_code[pivot]; + offset -= 1; + } + + if(x > val) + { + midpoint = (upper+val)*0.5f; + if(x > midpoint) + return upper_pivot; + else + return pivot; + } + else + { + midpoint = (lower+val)*0.5f; + if(x < midpoint) + return lower_pivot; + else + return pivot; + } +} + +template +__device__ __forceinline__ unsigned char quantize_quadrant(int QUADRANT, float *__restrict__ const smem_code, float x, float lower, float midpoint, float upper) +{ + int lower_pivot = QUADRANT*16-1 - 0; + int pivot = QUADRANT*16-1 + 16; + int upper_pivot = QUADRANT*16-1 + 31; + + float val = midpoint; + + // i>>=1 = {32, 16, 8, 4, 2, 1} + for(int i = 16; i > 0; i>>=1) + { + if(x > val) + { + lower_pivot = pivot; + lower = val; + pivot+=i; + } + else + { + upper_pivot = pivot; + upper = val; + pivot-=i; + } + val = smem_code[pivot]; + } + + if(x > val) + { + midpoint = (upper+val)*0.5f; + if(x > midpoint) + return upper_pivot; + else + return pivot; + } + else + { + midpoint = (lower+val)*0.5f; + if(x < midpoint) + return lower_pivot; + else + return pivot; + } +} + +__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n) +{ + const int tid = threadIdx.x + (blockDim.x*blockIdx.x); + const int numThreads = blockDim.x*gridDim.x; + + for(int i = tid; i < n; i+=numThreads) + { + int idx = (index1[i]*maxidx1) + index2[i]; + atomicAdd(&histogram[idx], src[i]); + } +} + +template +__global__ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* out_idx, const int n) +{ + typedef hipcub::WarpReduce WarpReduce; + __shared__ typename WarpReduce::TempStorage temp_storage; + typedef hipcub::BlockLoad LoadT; + __shared__ typename LoadT::TempStorage loadt; + + const int warp_idx = threadIdx.x/32; + const int valid_items = n - (blockIdx.x*BLOCK_SIZE) > BLOCK_SIZE ? BLOCK_SIZE : n - (blockIdx.x*BLOCK_SIZE); + + // BLOCK_SIZE/32 == number of warps + __shared__ int smem_max_indices[8*BLOCK_SIZE/32]; + __shared__ float smem_max_values[8*BLOCK_SIZE/32]; + + T values[8]; + T max1 = -64000.0f; + T max2 = -64000.0f; + int max_idx1 = -1; + int max_idx2 = -1; + int sign1 = -1; + int sign2 = -1; + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + + LoadT(loadt).Load(&(A[(blockIdx.x*BLOCK_SIZE)]), values, valid_items, (T)0.0f); + #pragma unroll 8 + for(int i = 0; i < 8; i++) + { + T absval = fabsf(values[i]); + if(absval > max1) + { + max1 = values[i]; + sign1 = signbit(values[i]); + max_idx1 = 8*threadIdx.x + i; + } + else if(absval > max2) + { + max2 = values[i]; + sign2 = signbit(values[i]); + max_idx2 = 8*threadIdx.x + i; + } + } + + float warp_max; + for(int i = 0; i < 8; i++) + { + // 3. do warp reduction + broadcast back + warp_max = WarpReduce(temp_storage).Reduce(max1, hipcub::Max()); + warp_max = hipcub::ShuffleIndex<32>(warp_max, 0, 0xffffffff); + + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + if(warp_max == max1) + { + smem_max_values[warp_idx*8 + i] = sign1 != 0 ? 
-max1 : max1; + smem_max_indices[warp_idx*8 + i] = max_idx1; + + sign1 = sign2; + max1 = max2; + max_idx1 = max_idx2; + + max2 = -64000.0f; + } + //__syncwarp(); + } + + if(threadIdx.x % 32 < 8) + { + // offset: 8 values per 256 input values + // + int offset = BLOCK_SIZE*blockIdx.x*BLOCK_SIZE/32*8; + } + +} + +#define THREADS_ESTIMATE 512 +#define NUM_ESTIMATE 8 +#define BLOCK_ESTIMATE 4096 + +template +__launch_bounds__(THREADS_ESTIMATE, 1) +__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n) +{ + const int n_full = (BLOCK_ESTIMATE*(n/BLOCK_ESTIMATE)) + (n % BLOCK_ESTIMATE == 0 ? 0 : BLOCK_ESTIMATE); + int valid_items = (blockIdx.x+1 == gridDim.x) ? n - (blockIdx.x*BLOCK_ESTIMATE) : BLOCK_ESTIMATE; + const int base_idx = (blockIdx.x * BLOCK_ESTIMATE); + const float reciprocal_num_blocks = 1.0f/(n < 4096 ? 1.0f : (n/BLOCK_ESTIMATE)); + + T vals[NUM_ESTIMATE]; + + typedef hipcub::BlockRadixSort BlockRadixSort; + typedef hipcub::BlockLoad LoadFloat; + + __shared__ union { + typename LoadFloat::TempStorage loadf; + typename BlockRadixSort::TempStorage sort; + int smem_qidx[BLOCK_ESTIMATE]; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_ESTIMATE) + { + valid_items = n - i > BLOCK_ESTIMATE ? BLOCK_ESTIMATE : n - i; + + // do not process half-blocks + if(valid_items < BLOCK_ESTIMATE && n > BLOCK_ESTIMATE){ continue; } + + #pragma unroll 4 + for(int j = 0; j < NUM_ESTIMATE; j++) + vals[j] = max_val; + + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(A[i]), vals, valid_items); + + #pragma unroll 4 + for(int j = 0; j < NUM_ESTIMATE; j++) + vals[j] = ((float)vals[j]) * reciprocal_num_blocks; + + + __syncthreads(); + // sort into striped pattern to mitigate bank conflicts + // striped pattern index for thread 0 [0, 1024, 2048, 3096] + // striped pattern index for thread 1 [1, 1025, 2049, 3097] + BlockRadixSort(temp_storage.sort).SortBlockedToStriped(vals); + + __syncthreads(); + for(int j = threadIdx.x; j < BLOCK_ESTIMATE; j+=blockDim.x) + temp_storage.smem_qidx[j] = -1; + + __syncthreads(); + + if(threadIdx.x < 256) + { + float q_interval = (1.0f-(2.0f*offset))/255.0f; + int local_idx = round(((offset+(threadIdx.x*q_interval))*(valid_items-1))); + temp_storage.smem_qidx[local_idx] = threadIdx.x; + } + + __syncthreads(); + + for(int i = threadIdx.x; i < BLOCK_ESTIMATE; i+=blockDim.x) + { + if(temp_storage.smem_qidx[i] != -1) + atomicAdd(&code[temp_storage.smem_qidx[i]], vals[i/THREADS_ESTIMATE]); + } + } +} + + +__launch_bounds__(TH, 4) +__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n) +{ + const int n_full = (NUM_BLOCK*(n/NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK); + int valid_items = (blockIdx.x+1 == gridDim.x) ? 
n - (blockIdx.x*NUM_BLOCK) : NUM_BLOCK; + const int base_idx = (blockIdx.x * NUM_BLOCK); + + float vals[NUM]; + unsigned char qvals[NUM]; + //const int lane_id = threadIdx.x % 2; + + typedef hipcub::BlockLoad LoadFloat; + typedef hipcub::BlockStore StoreChar; + + __shared__ typename LoadFloat::TempStorage loadf; + __shared__ typename StoreChar::TempStorage storec; + __shared__ float smem_code[256]; + //__shared__ float smem_code[2][257]; + + if(threadIdx.x < 256) + { + smem_code[threadIdx.x] = code[threadIdx.x]; + //smem_code[0][threadIdx.x] = code[threadIdx.x]; + //smem_code[1][threadIdx.x] = smem_code[0][threadIdx.x]; + } + + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_BLOCK) + { + // number of values already processed in blocks + + // number of values already processed in this block + + // rand_offset % mod value + valid_items = n - i > NUM_BLOCK ? NUM_BLOCK : n - i; + + __syncthreads(); + LoadFloat(loadf).Load(&(A[i]), vals, valid_items); + + + #pragma unroll 4 + for(int j = 0; j < NUM; j++) + qvals[j] = dQuantize<0>(smem_code, 0.0f, vals[j]); + + __syncthreads(); + StoreChar(storec).Store(&(out[i]), qvals, valid_items); + } +} + +template +//__launch_bounds__(TH, 4) +__global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n) +{ + const int n_full = gridDim.x * BLOCK_SIZE; + int valid_items = 0; + const int base_idx = (blockIdx.x * BLOCK_SIZE); + + T vals[NUM_PER_TH]; + float rand_vals[NUM_PER_TH]; + unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH]; + //float local_abs_max = -FLT_MAX; + float local_abs_max = 0.0f; + int local_rand_idx = 0; + + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockStore 0) ? NUM_PER_TH/2 : NUM_PER_TH, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar; + typedef hipcub::BlockReduce BlockReduce; + typedef hipcub::BlockLoad LoadFloat; + + __shared__ typename LoadT::TempStorage loadt; + __shared__ typename LoadFloat::TempStorage loadf; + __shared__ typename StoreChar::TempStorage storec; + __shared__ typename BlockReduce::TempStorage reduce; + __shared__ float smem_code[256]; + __shared__ float smem_absmax_value[1]; + + if(DATA_TYPE == General8bit) + for(int i = threadIdx.x; i < 256; i+=blockDim.x) + smem_code[i] = code[i]; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i; + local_abs_max = -FLT_MAX; + + __syncthreads(); + LoadT(loadt).Load(&(A[i]), vals, valid_items, (T)0.0f); + + // 1. compute local max + // 2. broadcast local max + // 3. 
normalize inputs and quantize + + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j])); + + local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, hipcub::Max(), valid_items); + + if(threadIdx.x == 0) + smem_absmax_value[0] = local_abs_max; + + __syncthreads(); + + if(threadIdx.x == 0) + absmax[i/BLOCK_SIZE] = local_abs_max; + else + local_abs_max = smem_absmax_value[0]; + + //__syncwarp(); + + local_abs_max = 1.0f/local_abs_max; + + if(STOCHASTIC) + { + local_rand_idx = ((blockIdx.x*NUM_BLOCK) + (threadIdx.x*NUM) + rand_offset) % (1024-4); + LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0); + } + + unsigned char packed_4bit = 0; + switch(DATA_TYPE) + { + case General8bit: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + if(!STOCHASTIC) + qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max); + else + qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max); + } + break; + case FP4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH/2; j++) + { + packed_4bit |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4; + packed_4bit |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max); + qvals[j] = packed_4bit; + } + break; + case NF4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH/2; j++) + { + packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4; + packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max); + qvals[j] = packed_4bit; + } + break; + } + + __syncthreads(); + StoreChar(storec).Store(&(out[(DATA_TYPE > 0) ? i/2 : i]), qvals, (DATA_TYPE > 0) ? (valid_items+1)/2 : valid_items); + } +} + +template +__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n) +{ + + const int n_load = (gridDim.x * TILE_SIZE); + int valid_items_load = 0; + int valid_items_store = 0; + const int base_idx = (blockIdx.x * TILE_SIZE); + + T vals[NUM_PER_TH*((DATA_TYPE > 0) ? 2 : 1)]; + unsigned char qvals[NUM_PER_TH]; + float local_abs_max = -FLT_MAX; + + typedef hipcub::BlockLoad LoadChar; + typedef hipcub::BlockStore 0) ? 2 : 1), hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreT; + + __shared__ typename LoadChar::TempStorage loadchar; + __shared__ typename StoreT::TempStorage storet; + + for (unsigned int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE) + { + if(DATA_TYPE > 0) + { + valid_items_load = (n+1)/2 - i > TILE_SIZE ? TILE_SIZE : (n+1)/2 - i; + valid_items_store = n - i*2 > TILE_SIZE*2 ? TILE_SIZE*2 : n - i*2; + } + else + { + valid_items_load = n - i > TILE_SIZE ? TILE_SIZE : n - i; + valid_items_store = n - i > TILE_SIZE ? 
TILE_SIZE : n - i; + } + local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH)/(blocksize)]); + + __syncthreads(); + LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128); + + switch(DATA_TYPE) + { + case General8bit: + // load code through read-only cache via __ldg + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + vals[j] = __ldg(&code[qvals[j]])*local_abs_max; + break; + case FP4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + vals[j*2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max); + vals[j*2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max); + } + break; + case NF4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + vals[j*2] = dDequantizeNF4(qvals[j] >> 4)* local_abs_max; + vals[j*2 + 1] = dDequantizeNF4(qvals[j] & 0x0F)* local_abs_max; + } + break; + } + + __syncthreads(); + StoreT(storet).Store(&(out[(DATA_TYPE > 0) ? i*2 : i]), vals, valid_items_store); + } +} + +__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n) +{ + const unsigned int numThreads = blockDim.x * gridDim.x; + const int idx = (blockIdx.x * blockDim.x) + threadIdx.x; + + __shared__ float smem_code[256]; + if(threadIdx.x < 256) + { + smem_code[threadIdx.x] = code[threadIdx.x]; + } + + __syncthreads(); + + for (int i = idx;i < n; i += numThreads) + { + out[i] = smem_code[A[i]]; + } +} + + + +template +__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1) +__global__ void kPreconditionOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n) +{ + + const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); + const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS); + int valid_items = 0; + + T g_vals[NUM_VALS]; + + float s1_vals[NUM_VALS]; + float s2_vals[NUM_VALS]; + + const float correction1 = 1.0f/(1.0f - powf(beta1, step)); + const float correction2 = 1.0f/(1.0f - powf(beta2, step)); + + typedef hipcub::BlockLoad Load; + typedef hipcub::BlockLoad LoadFloat; + typedef hipcub::BlockReduce BlockReduce; + + __shared__ union { + typename Load::TempStorage load; + typename LoadFloat::TempStorage loadf; + typename BlockReduce::TempStorage reduce; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + valid_items = n - i >= (BLOCK_SIZE) ? 
(BLOCK_SIZE) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items, 0.0f); + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + g_vals[j] = gnorm_scale*((float)g_vals[j]); + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + { + switch(OPTIMIZER) + { + case ADAM: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j])); + s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j]))); + s1_vals[j] *= correction1; + s2_vals[j] *= correction2; + s1_vals[j] = s1_vals[j]/(sqrtf(s2_vals[j])+eps); // update + s1_vals[j] *= s1_vals[j]; // update l2 norm (update*update) + break; + } + } + + # pragma unroll NUM_VALS-1 + for(unsigned int j = 1; j < NUM_VALS; j++) + s1_vals[0] += s1_vals[j]; + + __syncthreads(); + s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0]); + + if(threadIdx.x == 0) + atomicAdd(&unorm[0], s1_vals[0]); + + //__syncwarp(); + } +} + + + +#define NUM_PER_THREAD 4 + +template +__launch_bounds__(TH, 1) +__global__ void kOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n) +{ + + const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); + int valid_items = 0; + float update_scale = 0.0f; + T g_vals[NUM_PER_THREAD]; + T p_vals[NUM_PER_THREAD]; + + float s1_vals[NUM_PER_THREAD]; + float s2_vals[NUM_PER_THREAD]; + + const float correction1 = 1.0f - powf(beta1, step); + const float correction2 = sqrtf(1.0f - powf(beta2, step)); + const float step_size = -lr*correction2/correction1; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; + if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + typedef hipcub::BlockLoad Load; + typedef hipcub::BlockStore Store; + + typedef hipcub::BlockLoad LoadFloat; + typedef hipcub::BlockStore StoreFloat; + + __shared__ union { + typename Load::TempStorage load; + typename Store::TempStorage store; + typename LoadFloat::TempStorage loadf; + typename StoreFloat::TempStorage storef; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*TH*NUM_PER_THREAD) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? 
(TH*NUM_PER_THREAD) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items); + __syncthreads(); + Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + g_vals[j] = gnorm_scale*((float)g_vals[j]); + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + { + switch(OPTIMIZER) + { + case ADAM: + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j])); + s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j]))); + p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2)))); + + if(weight_decay > 0.0f) + p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); + } + break; + } + } + + __syncthreads(); + Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); + __syncthreads(); + StoreFloat(temp_storage.storef).Store(&(state2[i]), s2_vals, valid_items); + } +} + +template +__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1) +__global__ void kPreconditionOptimizer32bit1State(T* g, T* p, + float* state1, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n) +{ + + const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); + const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS); + int valid_items = 0; + + T g_vals[NUM_VALS]; + + float s1_vals[NUM_VALS]; + + typedef hipcub::BlockLoad Load; + typedef hipcub::BlockLoad LoadFloat; + typedef hipcub::BlockReduce BlockReduce; + + __shared__ union { + typename Load::TempStorage load; + typename LoadFloat::TempStorage loadf; + typename BlockReduce::TempStorage reduce; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + valid_items = n - i >= (BLOCK_SIZE) ? 
(BLOCK_SIZE) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + g_vals[j] = gnorm_scale*((float)g_vals[j]); + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + { + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = (float)g_vals[j]; // state update + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); // state update + s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm + break; + case LION: + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*(float)g_vals[j]); // state update + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update + s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value + s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm + break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); // state update + s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value + s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm + break; + } + } + + # pragma unroll + for(unsigned int j = 1; j < NUM_VALS; j++) + s1_vals[0] += s1_vals[j]; + + __syncthreads(); + s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0], valid_items); + + if(threadIdx.x == 0) + atomicAdd(&unorm[0], s1_vals[0]); + + //__syncwarp(); + } +} + +template +__launch_bounds__(TH, 1) +__global__ void kOptimizer32bit1State(T *g, T *p, + float *state1, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n) +{ + + const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); + int valid_items = 0; + float update_scale = 0.0f; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; + if(update_scale > max_unorm*param_norm+eps){ update_scale = (max_unorm*param_norm+eps)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + T g_vals[NUM_PER_THREAD]; + T p_vals[NUM_PER_THREAD]; + + float s1_vals[NUM_PER_THREAD]; + + typedef hipcub::BlockLoad Load; + typedef hipcub::BlockStore Store; + + typedef hipcub::BlockLoad LoadFloat; + typedef hipcub::BlockStore StoreFloat; + + __shared__ union { + typename Load::TempStorage load; + typename Store::TempStorage store; + typename LoadFloat::TempStorage loadf; + typename StoreFloat::TempStorage storef; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*TH*NUM_PER_THREAD) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? 
(TH*NUM_PER_THREAD) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); + __syncthreads(); + Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + { + g_vals[j] = gnorm_scale*((float)g_vals[j]); + if(weight_decay > 0.0f) + g_vals[j] = (float)g_vals[j] + (((float)p_vals[j])*weight_decay); + } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + { + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = (float)g_vals[j]; + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); + + p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j])); + break; + case LION: + p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j])))); + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*((float)g_vals[j])); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); + p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps)); + break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); + p_vals[j] = ((float)p_vals[j]) - lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps); + break; + } + } + } + + __syncthreads(); + Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); + } +} + + +#define NUM8BIT 16 +#define NUM_THREADS 256 +#define NUM_PER_BLOCK 4096 + +template +__global__ void +__launch_bounds__(NUM_THREADS, 2) +kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + const float gnorm_scale, const int n) +{ + const int n_full = gridDim.x * NUM_PER_BLOCK; + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); + int valid_items = n - (blockIdx.x*NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x*NUM_PER_BLOCK); + float g_val = 0.0f; + float local_max_s1 = -FLT_MAX; + float local_max_s2 = -FLT_MAX; + float local_unorm = 0.0f; + + float s2_vals[NUM8BIT]; + float s1_vals[NUM8BIT]; + T g_vals[NUM8BIT]; + unsigned char m_c1[NUM8BIT]; + unsigned char r_c2[NUM8BIT]; + + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadUInt8; + typedef hipcub::BlockReduce BlockReduce; + + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadUInt8::TempStorage loadc; + typename BlockReduce::TempStorage reduce; + } temp_storage; + + __shared__ float smem_quantiles1[256]; + __shared__ float smem_quantiles2[256]; + + if(threadIdx.x < 256) + { + smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; + smem_quantiles2[threadIdx.x] = quantiles2[threadIdx.x]; + } + + __syncthreads(); + + for (unsigned int i = base_idx; i < n_full; i += NUM_THREADS*gridDim.x*NUM8BIT) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? 
(TH*NUM_PER_THREAD) : n - i; + + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128); + __syncthreads(); + LoadUInt8(temp_storage.loadc).Load(&(state2[i]), r_c2, valid_items, 128); + __syncthreads(); + + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + g_val = g_vals[j]; + g_val *= gnorm_scale; + s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0]*beta1; + s1_vals[j] += (1.0f-beta1)*g_val; + local_max_s1 = fmaxf(local_max_s1, fabsf(s1_vals[j])); + } + + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + g_val = g_vals[j]; + g_val *= gnorm_scale; + s2_vals[j] = smem_quantiles2[r_c2[j]]*max2[0]*beta2; + s2_vals[j] += (1.0f-beta2)*g_val*g_val; + local_max_s2 = fmaxf(local_max_s2, fabsf(s2_vals[j])); + } + + if(unorm != NULL) + { + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + float correction1 = __fdividef(1.0f, 1.0f - powf(beta1, step)); + float correction2 = __fdividef(1.0f, 1.0f - powf(beta2, step)); + s1_vals[j] *= correction1; + s2_vals[j] *= correction2; + float update_val = s1_vals[j]/(sqrtf(s2_vals[j])+eps); // update + local_unorm += update_val*update_val; + } + } + } + + __syncthreads(); + local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, hipcub::Max(), valid_items); + __syncthreads(); + local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, hipcub::Max(), valid_items); + if(unorm != NULL) + { + __syncthreads(); + local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, hipcub::Sum(), valid_items); + } + + if(threadIdx.x == 0) + { + atomicMax(&new_max1[0], local_max_s1); + atomicMax(&new_max2[0], local_max_s2); + if(unorm != NULL){ atomicAdd(&unorm[0], local_unorm); } + } +} + +#define NUM_PER_THREAD2 4 +#define NUM_THREADS2 1024 +#define NUM_PER_BLOCK2 4096 + +template +__global__ void +__launch_bounds__(NUM_THREADS2, 1) +kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2, + const float *unorm, const float max_unorm, const float param_norm, \ + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, + const float gnorm_scale, const int n) +{ + + const int n_full = (blockDim.x * gridDim.x)*NUM_PER_THREAD2; + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[NUM_PER_THREAD2]; + float s2_vals[NUM_PER_THREAD2]; + const float correction1 = 1.0f - powf(beta1, step); + const float correction2 = sqrtf(1.0f - powf(beta2, step)); + const float step_size = -lr*correction2/correction1; + //const float step_size = -lr*correction2/correction1; + float new_max_val1 = 1.0f/new_max1[0]; + float new_max_val2 = 1.0f/new_max2[0]; + float update_scale = 1.0f; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? 
sqrtf(unorm[0]) : 1.0f; + if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + unsigned char c1s[NUM_PER_THREAD2]; + unsigned char c2s[NUM_PER_THREAD2]; + T p_vals[NUM_PER_THREAD2]; + T g_vals[NUM_PER_THREAD2]; + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadChar; + + typedef hipcub::BlockStore StoreChar; + typedef hipcub::BlockStore StoreT; + + __shared__ float smem_quantiles1[256]; + __shared__ float smem_quantiles2[256]; + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + } temp_storage; + + if(threadIdx.x < 512) + { + if(threadIdx.x < 256) + smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; + else + smem_quantiles2[threadIdx.x-256] = quantiles2[threadIdx.x-256]; + } + + __syncthreads(); + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS2*NUM_PER_THREAD2) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0); + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); + + if((i + (threadIdx.x*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) + { + g_val = float(g_vals[j]); + g_val *= gnorm_scale; + s1_vals[j] = smem_quantiles1[c1s[j]]; + s1_vals[j] = s1_vals[j]*max1[0]; + + s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val)); + + c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1); + + // make sure state1 term has still the same sign after quantization + // (not needed for state2 term which has only positive values) + if(signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + + s2_vals[j] = smem_quantiles2[c2s[j]]; + s2_vals[j] = s2_vals[j]*max2[0]; + s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); + c2s[j] = dQuantize<0>(smem_quantiles2, 0.0f, s2_vals[j]*new_max_val2); + } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) + { + p_vals[j] = (T)(((float)p_vals[j]) + ((update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(correction2*eps)))))); + if(weight_decay > 0.0f) + p_vals[j] = update_scale*((float)p_vals[j])*(1.0f-(lr*weight_decay)); + } + + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); + __syncthreads(); + } +} + + +template +__global__ void +__launch_bounds__(NUM_THREADS, 2) +kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + const float weight_decay, + const float gnorm_scale, const int n) +{ + const int n_full = gridDim.x * NUM_PER_BLOCK; + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); + int valid_items = n - (blockIdx.x*NUM_PER_BLOCK) > 
NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x*NUM_PER_BLOCK); + float g_val = 0.0f; + float local_max_s1 = -FLT_MAX; + float local_unorm = 0.0f; + + float s1_vals[NUM8BIT]; + T g_vals[NUM8BIT]; + unsigned char m_c1[NUM8BIT]; + + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadUInt8; + typedef hipcub::BlockReduce BlockReduce; + + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadUInt8::TempStorage loadc; + typename BlockReduce::TempStorage reduce; + } temp_storage; + + __shared__ float smem_quantiles1[256]; + + if(threadIdx.x < 256) + smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; + + __syncthreads(); + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS*NUM8BIT) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128); + + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + g_val = g_vals[j]; + g_val *= gnorm_scale; + s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0]; + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = (float)g_vals[j]; + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); + if(unorm != NULL) + local_unorm += s1_vals[j]*s1_vals[j]; + break; + case LION: + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); + break; + } + + local_max_s1 = fmaxf(local_max_s1, fabsf(s1_vals[j])); + } + } + + __syncthreads(); + local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, hipcub::Max(), valid_items); + if(threadIdx.x == 0){ atomicMax(&new_max1[0], local_max_s1); } + if(unorm != NULL) + { + __syncthreads(); + local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, hipcub::Sum(), valid_items); + if(threadIdx.x == 0){ atomicAdd(&unorm[0], local_unorm); } + } + +} + +template +__global__ void +__launch_bounds__(1024, 1) +kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, + const float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + float weight_decay, + const float gnorm_scale, const int n) +{ + + const int n_full = (blockDim.x * gridDim.x)*NUM_PER_THREAD2; + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[NUM_PER_THREAD2]; + float new_max_val1 = 1.0f/new_max1[0]; + float update_scale = 1.0f; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? 
sqrtf(unorm[0]) : 1.0f; + if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + unsigned char c1s[NUM_PER_THREAD2]; + T p_vals[NUM_PER_THREAD2]; + T g_vals[NUM_PER_THREAD2]; + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadChar; + + typedef hipcub::BlockStore StoreChar; + typedef hipcub::BlockStore StoreT; + + __shared__ float smem_quantiles1[256]; + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + } temp_storage; + + if(threadIdx.x < 256) + smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; + + __syncthreads(); + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS2*NUM_PER_THREAD2) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); + + if((i + (threadIdx.x*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) + { + g_val = float(g_vals[j]); + g_val *= gnorm_scale; + + if(weight_decay > 0.0f) { + switch(OPTIMIZER) { + case MOMENTUM: + case RMSPROP: + g_val += ((float)p_vals[j])*weight_decay; + break; + case LION: + p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay); + break; + } + } + + s1_vals[j] = smem_quantiles1[c1s[j]]*max1[0]; + + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = g_vals[j]; + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); + + p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j])); + break; + case LION: + p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val)))); + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); + p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps)); + break; + } + + c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1); + + // make sure state1 term has still the same sign after quantization + if(signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + } + + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + __syncthreads(); + } +} + + +template +__global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n) +{ + const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); + int valid_items = 0; + + typedef hipcub::BlockReduce BlockReduce; + typedef hipcub::BlockLoad LoadT; + + __shared__ typename BlockReduce::TempStorage reduce; + + __shared__ typename LoadT::TempStorage loadT; + T vals[NUM_VALS]; + float local_sum = 0.0f; + + for (unsigned int i = (blockIdx.x * BLOCK_SIZE); i < n_full; i += gridDim.x*BLOCK_SIZE) + { + valid_items = n - i > BLOCK_SIZE ? 
BLOCK_SIZE : n - i; + local_sum = 0.0f; + + __syncthreads(); + LoadT(loadT).Load(&(g[i]), vals, valid_items, (T)0.0f); + + #pragma unroll NUM_VALS + for(int j = 0; j < NUM_VALS; j++) + local_sum += ((float)vals[j])*((float)vals[j]); + + local_sum = BlockReduce(reduce).Sum(local_sum, valid_items); + if(threadIdx.x == 0) + { + if(step == 1) + { + // initialize with the same norm for all positions + //#pragma unroll 10 + for(int j = 0; j < 100; j++) + atomicAdd(&gnorm_vec[j], local_sum); + } + else + atomicAdd(&gnorm_vec[step % 100], local_sum); + } + + } +} + + +#define LANES 2 +#define QUAD 3 +template +__launch_bounds__(256, 3) +__global__ void +kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, + const float beta1, const float beta2, const float beta3, const float alpha, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* absmax1, float* absmax2, + float weight_decay, + const float gnorm_scale, const bool skip_zeros, const int n) +{ + + //const int n_full = n + (n%BLOCK_SIZE); + const int n_full = gridDim.x * BLOCK_SIZE; + const int base_idx = (blockIdx.x * BLOCK_SIZE); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[N_PER_TH]; + float s2_vals[N_PER_TH]; + // 2-5% + const float correction1 = 1.0f - __powf(beta1, step); + const float correction2 = sqrtf(1.0f -__powf(beta2, step)); + const float step_size = __fdividef(-lr*correction2,correction1); + const int lane_id = threadIdx.x % LANES; + float new_local_abs_max1 = -FLT_MAX; + float new_local_abs_max2 = -FLT_MAX; + float quadrants1[QUAD]; + float quadrants2[QUAD]; + + unsigned char c1s[N_PER_TH]; + unsigned char c2s[N_PER_TH]; + T g_vals[N_PER_TH]; + T p_vals[N_PER_TH]; + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadChar; + + typedef hipcub::BlockStore StoreChar; + typedef hipcub::BlockStore StoreT; + + __shared__ float smem_quantiles1[LANES][257]; + __shared__ float smem_quantiles2[LANES][257]; + typedef hipcub::BlockReduce BlockReduce1; + typedef hipcub::BlockReduce BlockReduce2; + __shared__ typename BlockReduce1::TempStorage reduce1; + __shared__ typename BlockReduce2::TempStorage reduce2; + __shared__ float smem_exchange1[1]; + __shared__ float smem_exchange2[1]; + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + } temp_storage; + // init: 0.2 -> 0.23 + + // 0.23 -> 0.23 + smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x]; + smem_quantiles2[0][threadIdx.x] = quantiles2[threadIdx.x]; + # pragma unroll + for(unsigned int j = 1; j < LANES; j++) + { + smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x]; + smem_quantiles2[j][threadIdx.x] = smem_quantiles2[0][threadIdx.x]; + } + + __syncthreads(); + + #pragma unroll + for(int k = 0; k < QUAD; k++) + { + quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; + quadrants2[k] = smem_quantiles2[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; + } + + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + // loads: 0.23 -> 0.85/1.44 + valid_items = n - i >= BLOCK_SIZE ? 
BLOCK_SIZE : n - i; + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0); + + new_local_abs_max1 = -FLT_MAX; + new_local_abs_max2 = -FLT_MAX; + + // update: 2.48/1.57 -> 2.51/1.60 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) + { + s2_vals[j] = smem_quantiles2[lane_id][c2s[j]]*absmax2[i/BLOCK_SIZE]; + g_val = g_vals[j]; + //float ratio = (g_val*g_val)/fmaxf(s2_vals[j], eps*eps); + //g_val = ratio > 2.0f ? 2.0f*g_val/ratio : g_val; + g_val *= gnorm_scale; + + s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); + + s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; + s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val)); + } + else + { + s1_vals[j] = 0.0f; + s2_vals[j] = 0.0f; + } + + new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j])); + new_local_abs_max2 = fmaxf(new_local_abs_max2, fabsf(s2_vals[j])); + } + + + // reduce: 2.51/1.60 -> 2.67/1.69 + new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, hipcub::Max()); + new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, hipcub::Max()); + + if(threadIdx.x == 0) + { + smem_exchange1[0] = new_local_abs_max1; + smem_exchange2[0] = new_local_abs_max2; + } + + __syncthreads(); + + if(threadIdx.x == 0) + { + absmax1[i/BLOCK_SIZE] = new_local_abs_max1; + absmax2[i/BLOCK_SIZE] = new_local_abs_max2; + } + else + { + new_local_abs_max1 = smem_exchange1[0]; + new_local_abs_max2 = smem_exchange2[0]; + } + + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); + // reduce: 2.67/1.69 -> 2.67/1.70 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + //if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) + { + p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps))))))); + if(weight_decay > 0.0f) + p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); + } + } + + // store: 0.85/1.44 -> 2.48/1.57 + __syncthreads(); + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + + // quantizaztion: 2.67/1.70 -> 3.4/3.3 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1)); + c2s[j] = quantize_2D<0>(quadrants2, smem_quantiles2[lane_id], __fdividef(s2_vals[j],new_local_abs_max2)); + + // make sure state1 term has still the same sign after quantization + // (not needed for state2 term which has only positive values) + if(signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + } + + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); + } +} + + +#define LANES 2 +#define QUAD 3 +template +__launch_bounds__(256, 3) +__global__ void +kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* 
absmax1, + float weight_decay, + const float gnorm_scale, const bool skip_zeros, const int n) +{ + + //const int n_full = n + (n%BLOCK_SIZE); + const int n_full = gridDim.x * BLOCK_SIZE; + const int base_idx = (blockIdx.x * BLOCK_SIZE); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[N_PER_TH]; + // 2-5% + const int lane_id = threadIdx.x % LANES; + float new_local_abs_max1 = -FLT_MAX; + float quadrants1[QUAD]; + + unsigned char c1s[N_PER_TH]; + T g_vals[N_PER_TH]; + T p_vals[N_PER_TH]; + + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadChar; + + typedef hipcub::BlockStore StoreChar; + typedef hipcub::BlockStore StoreT; + + __shared__ float smem_quantiles1[LANES][257]; + typedef hipcub::BlockReduce BlockReduce1; + __shared__ typename BlockReduce1::TempStorage reduce1; + __shared__ float smem_exchange1[1]; + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + } temp_storage; + // init: 0.2 -> 0.23 + + // 0.23 -> 0.23 + smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x]; + # pragma unroll + for(unsigned int j = 1; j < LANES; j++) + smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x]; + + __syncthreads(); + + #pragma unroll + for(int k = 0; k < QUAD; k++) + quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + // loads: 0.23 -> 0.85/1.44 + valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i; + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); + + new_local_abs_max1 = -FLT_MAX; + + // update: 2.48/1.57 -> 2.51/1.60 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + g_val = float(g_vals[j]); + g_val *= gnorm_scale; + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + if(weight_decay > 0.0f) { + switch(OPTIMIZER) { + case MOMENTUM: + case ADAGRAD: + case RMSPROP: + g_val += ((float)p_vals[j])*weight_decay; + break; + case LION: + p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay); + break; + } + } + + s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; + + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = g_val; + else + s1_vals[j] = (s1_vals[j]*beta1) + g_val; + break; + case LION: + // here, using gvals[j] to store the gradient smoothed by beta1 for the following parameter update, before the momentum is updated by beta2 + g_vals[j] = lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*g_val)); + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); + break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + (g_val*g_val); + break; + } + } + + new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j])); + } + + + // reduce: 2.51/1.60 -> 2.67/1.69 + new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, hipcub::Max()); + + if(threadIdx.x == 0) + smem_exchange1[0] = new_local_abs_max1; + + __syncthreads(); + + if(threadIdx.x == 0) + absmax1[i/BLOCK_SIZE] = new_local_abs_max1; + else + new_local_abs_max1 = smem_exchange1[0]; + + // reduce: 2.67/1.69 -> 2.67/1.70 + # pragma unroll N_PER_TH + for(unsigned int j = 
0; j < N_PER_TH; j++) + { + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + switch(OPTIMIZER) + { + case MOMENTUM: + p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]); + break; + case LION: + p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]); + break; + case RMSPROP: + g_val = g_vals[j]; + p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); + break; + case ADAGRAD: + g_val = g_vals[j]; + p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); + break; + } + } + } + + // store: 0.85/1.44 -> 2.48/1.57 + __syncthreads(); + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + + // quantizaztion: 2.67/1.70 -> 3.4/3.3 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1)); + + // make sure state1 term has still the same sign after quantization + // (not needed for state2 term which has only positive values) + if(signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + } + + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + } +} + +template __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols) +{ + // 0. reset stats to -FLT_MAX + // 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD) + // 2. compute col max (per thread); store in smem due to register pressure + // 3. compute row max (per block); store in smem to accumulate full global mem transaction + // 4. store data via atomicMax + + // each block loads TILE_COLs columns and TILE_ROW rows + // after reading a tile the row counter increase by TILE_ROWS + // the col counter reset after reading TILE_COL elements + const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS; + // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached + const int base_col = (blockIdx.x*TILE_COLS) % tiledCols; + const int base_idx = (base_row*cols) + base_col; + const int items_per_load = ITEMS_PER_THREAD*THREADS; + + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockReduce BlockRowReduce; + typedef hipcub::BlockReduce BlockRowSum; + typedef hipcub::BlockExchange BlockExchange; + + __shared__ union { + typename BlockExchange::TempStorage exchange; + typename BlockRowReduce::TempStorage rowreduce; + typename BlockRowSum::TempStorage rowsum; + typename LoadT::TempStorage loadt; + } temp_storage; + + __shared__ float smem_row_absmax_values[ITEMS_PER_THREAD*THREADS]; + __shared__ int smem_row_nnz_values[TILE_ROWS]; + + half local_data[ITEMS_PER_THREAD]; + float local_data_fp32[ITEMS_PER_THREAD]; + float local_col_absmax_values[ITEMS_PER_THREAD]; + int local_row_nnz_count = 0; + float row_absmax = -FLT_MAX; + + // 0. 
reset stats to -FLT_MAX + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + //smem_col_absmax_values[threadIdx.x + (j*THREADS)] = -FLT_MAX; + smem_row_absmax_values[threadIdx.x + (j*THREADS)] = -FLT_MAX; + // smem_row_nnz_values[threadIdx.x + (j*THREADS)] = 0; + } + + #pragma unroll TILE_ROWS + for (int j = 0; j < TILE_ROWS; j++) { + smem_row_nnz_values[j] = 0; + } + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_col_absmax_values[j] = -FLT_MAX; + + __syncthreads(); + + int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col; + int i = base_idx; + // we load row after row from the base_position + // 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD) + for(int row = 0; row < TILE_ROWS; row++) + { + if(base_row+row >= rows){ break; } + local_row_nnz_count = 0; + i = base_idx + ((row)*cols); + // each thread gets data from the same column + __syncthreads(); + LoadT(temp_storage.loadt).Load(&(A[i]), local_data, valid_items, __float2half(0.0f)); + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_data[j] = fabsf(local_data[j]); + + + if(SPARSE_DECOMP) + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + if((float)local_data[j] >= nnz_threshold) + { + local_row_nnz_count += 1; + local_data[j] = 0.0f; + } + } + + // 2. compute col max (per thread); store in smem due to register pressure + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + // take the col max for this row + // we use shared memory because register pressure is too high if we do this locally + //smem_col_absmax_values[threadIdx.x + (j*THREADS)] = fmaxf(smem_col_absmax_values[threadIdx.x + (j*THREADS)], __half2float(local_data[j])); + local_col_absmax_values[j] = fmaxf(local_col_absmax_values[j], __half2float(local_data[j])); + + // 3. compute row max (per block); store in smem to accumulate full global mem transaction + + // this is slow as it uses extra registers, but we need this to be compatible with Kepler and Maxwell (no fp16 units) + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_data_fp32[j] = local_data[j]; + + __syncthreads(); + + row_absmax = (float)BlockRowReduce(temp_storage.rowreduce).Reduce(local_data_fp32, hipcub::Max()); + if(SPARSE_DECOMP) + { + __syncthreads(); + local_row_nnz_count = BlockRowSum(temp_storage.rowsum).Sum(local_row_nnz_count); + } + // we store the data temporarily in shared memory so we + // can execute a full atomic block transaction into global memory later + // we use a striped arrangement [0, 8, 16, 24, ..] for t0 for faster stores + if(threadIdx.x == 0) + { + smem_row_absmax_values[(row % ITEMS_PER_THREAD) + ((row/ITEMS_PER_THREAD)*ITEMS_PER_THREAD)] = row_absmax; + // each blockIdx.x process 16 rows and 64*4=256 columns -> we sum nnz over 256 columns and have 16 values per block + smem_row_nnz_values[row] = local_row_nnz_count; + } + + __syncthreads(); + + } + + // 4. store data via atomicMax + // to store col data efficiently we need to rewrite the smem blocked data [0, 1, 2, 3...] for t0 + // into a striped arrangement: [0, 8, 16, 24, ..] 
for t0 + __syncthreads(); + BlockExchange(temp_storage.exchange).BlockedToStriped(local_col_absmax_values); + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + if(base_col+threadIdx.x+(j*THREADS) < cols) + { + float val = colStats[base_col+(threadIdx.x+(j*THREADS))]; + if(val < local_col_absmax_values[j]) + atomicMax(&colStats[base_col+(threadIdx.x+(j*THREADS))], local_col_absmax_values[j]); + } + + for(int j = 0; j < ITEMS_PER_THREAD; j++) + if(base_row+threadIdx.x+(j*THREADS) < rows) + { + float val = rowStats[base_row+(threadIdx.x+(j*THREADS))]; + if(val < smem_row_absmax_values[threadIdx.x+(j*THREADS)]) + atomicMax(&rowStats[base_row+(threadIdx.x+(j*THREADS))], smem_row_absmax_values[threadIdx.x+(j*THREADS)]); + } + + if(SPARSE_DECOMP) + if(threadIdx.x < TILE_ROWS) + nnz_count_row[blockIdx.x*TILE_ROWS+threadIdx.x+1] = smem_row_nnz_values[threadIdx.x]; + +} + +template __global__ void kgetColRowStats(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols); +template __global__ void kgetColRowStats(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols); + +#define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f) + +template __global__ void kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, half *__restrict__ const bias, const int numRows, const int numCols, const int n) +{ + + // Strategy: To dequantize we need to load col/row statistics. This can be very expensive + // since different row/col stats need to be loaded with each thread. + // (1, bad algorithm) Loading 32 items per thread would only occur 1 row load, but this increases register pressure + // and would lead to low global load utilization. + // (2, bad algorithm) If each thread loads some columns and multiple rows one needs to do lot of row loads + // for each thread and this is duplicated by a factor of 32/num-cols-per-thread. + // (3, good algorithm) Combining (1) and (2) we use sub-tiles of size 32xk in shared memory per threadblock. + // This allows for efficient row/col loading from shared memory within the tile. + // We can run for example 32x128 sub-tiles and warp-strided loads of 4 elements so that each thread has + // the same col statistic but needs to load 4 row stats from shared memory. To prevent bank conflicts + // we use a block-striped shared memory config [1, 31, 63, 95] so no bank conflicts happen during the + // shared memory loads. + + // data is in 32 column-tile major with tile width 32 columns and numRows rows + // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory. + // L2. Load data in warp-striped arrangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) + // C1. Compute val(row_stat*col_stat)/(127*127) (load 1/(127*127 into register)) + // C2. Compute normalization values and store col values in register + // S1. Store C1 into 16-bit output + // S2. Store col/row statistics of new buffer in shared memory + + // We allow for sub-tiles to span multiple col32 tiles. This is okay + // since the items per thread only rely on a single column statistic. + + + const int n_out = numRows*numCols; + + //int num_row_tiles = (numRows/SUBTILE_ROWS) + (numRows % SUBTILE_ROWS == 0 ? 
0 : 1); + // we have tiles of size numRows*32, thus col only increases every numRows + // num_row_tiles is the tiles after which the column increases by 32 + // blockIdx.x is the index of the current tile + //int col = ((threadIdx.x % 32) + ((blockIdx.x/num_row_tiles)*32)); + // base_row increases by SUBTILE_ROWS every block. It wraps back to zero once num_row_tiles is reached + //int base_row = (blockIdx.x*SUBTILE_ROWS) % (num_row_tiles*SUBTILE_ROWS); + + int block_offset = blockIdx.x * THREADS * ITEMS_PER_THREAD; + int thread_offset = threadIdx.x * ITEMS_PER_THREAD; + + // SUBTILE_ROWS is independent from ITEMS_PER_THREAD is independent from THREADS + // subtiles have 32*SUBTILE_ROWS elements <= THREADS*ITEMS_PER_THREAD + // Total subtiles should be n/(32*SUBTILE_ROWS) where each subtile has SUBTILE_ROW*32/4 threads. + // For example for a 1024x1024 matrix with 128 SUBTILE_ROWS and 4 ITEMS_PER_THREAD we have + // 1024*1024/(128*32) = 256 tiles + // 256 tiles are 256*128*32/4 = 256*1024 threads + + // 1. Figure out how index relates to the start of the sub-tile + // 2. Each thread < SUBTILE_ROWS calculates row index + // 3. Load striped and store in shared memory + + int local_values[ITEMS_PER_THREAD]; + half local_output[ITEMS_PER_THREAD]; + //float local_rowStats[ITEMS_PER_THREAD]; + //__shared__ float smem_rowStats[SUBTILE_ROWS]; + + typedef hipcub::BlockLoad LoadInt32; + //typedef hipcub::BlockExchange ExchangeInt32; + __shared__ typename LoadInt32::TempStorage loadint32; + //__shared__ typename ExchangeInt32::TempStorage exchangeint32; + + + // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory. + //float colStat = col >= numCols ? 0.0f : colStats[col]; + //float local_biasValue = ((bias == NULL) || (col >= numCols)) ? 0.0f : __half2float(bias[col]); + int row_idx, col_idx; + float colStat[ITEMS_PER_THREAD]; + float local_biasValue[ITEMS_PER_THREAD]; + float rowStat[ITEMS_PER_THREAD]; + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + row_idx = (block_offset + thread_offset + j) / numCols; + col_idx = (block_offset + thread_offset + j) % numCols; + colStat[j] = col_idx >= numCols ? 0.0f : colStats[col_idx]; + local_biasValue[j] = ((bias == NULL) || (col_idx >= numCols)) ? 0.0f : __half2float(bias[col_idx]); + rowStat[j] = row_idx >= numRows ? 0.0f : rowStats[row_idx]; + } + // no block loads for rows for now -- keep it simple + /*for(int j = threadIdx.x; j < SUBTILE_ROWS; j+=blockDim.x) + { + // todo: is this global mem access slow due to overlaps or does the L1 cache work well here? + int row = (base_row+j) % numRows; // wrap around + // each warp accesses the same element, for four consequitive elements + // todo: update description about striped shared memory, it is not needed + // rowidx: [0, 1, 2, 3...] and each warp reads ITEMS_PER_THREAD consequitive elements + smem_rowStats[j] = rowStats[row]; + }*/ + __syncthreads(); + + int valid_items = block_offset + THREADS * ITEMS_PER_THREAD < n_out ? 
THREADS * ITEMS_PER_THREAD : n_out - block_offset; + LoadInt32(loadint32).Load(&(A[block_offset]), local_values, valid_items, 0); + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_output[j] = __float2half((local_values[j]*MM_DEQUANT_CONST*rowStat[j]*colStat[j]) + local_biasValue[j]); + + // each block processes SUBTILE_ROWS*32 elements + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + int outIdx = block_offset + thread_offset + j; + if(outIdx< n_out) + out[outIdx] = local_output[j]; + } + /*const int items_per_load = THREADS*ITEMS_PER_THREAD; + const int rows_per_load = items_per_load/32; + + int subtile_base_row = (threadIdx.x / 32)*ITEMS_PER_THREAD; // row within the tile + int row_offset = 0; + // subtile_idx starts at the base_row*32 + the total offset for a full numRow*32 tile is passed + int subtile_start = (blockIdx.x/num_row_tiles)*(numRows*32) + (base_row*32); + for(int subtile_idx = subtile_start; subtile_idx < subtile_start + (SUBTILE_ROWS*32); subtile_idx+=items_per_load) + { + int valid_rows = numRows - (base_row+row_offset) > rows_per_load ? rows_per_load : numRows - (base_row+row_offset); + int valid_items = valid_rows*32; + if(valid_items <= 0) // the sub-tile might have more elements than the tile itself + break; + + // L2. Load data in warp-striped arrangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) + LoadInt32(loadint32).Load(&(A[subtile_idx]), local_values, valid_items, 0); + ExchangeInt32(exchangeint32).BlockedToWarpStriped(local_values, local_values); + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_rowStats[j] = smem_rowStats[subtile_base_row+row_offset+j]; + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_output[j] = __float2half((local_values[j]*MM_DEQUANT_CONST*local_rowStats[j]*colStat) + local_biasValue); + //absmax_col = fmax(fabsf(local_output[j]), absmax_col); + + // we store data in row major + // to store data efficiently, we want to use block exchange: [0, 32, 64, 92] -> [0, 1, 2, 3] + // so that each thread holds ITEMS_PER_THREAD consecutive items for each row + // this way throughput into storage is increased by a factor of ~2x + // for now we use a simple store + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + int outIdx = col + ((base_row+subtile_base_row+row_offset+j)*numCols); + if(outIdx< n_out && col < numCols) + out[outIdx] = local_output[j]; + } + + row_offset += rows_per_load; + }*/ +} + + +template __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols) +{ + // assumes TILE_SIZE == THREADS*ITEMS_PER_THREAD + // Each thread reads the same column but multiple rows + // Rows are loaded in shared memory and access is shared across the threadblock (broadcast) + + // 0. Load row stats data into shared memory; load col stat (1 fixed per thread) + // 1. Load data row by row (should be at least with TILE_SIZE = 512) + // 2. quantize data with row/col stats + // 3. 
Store data (TILE_SIZE = 512 is a bit slow, but should still be close enough to good performance) + + // each block loads TILE_COLs columns and TILE_ROW rows + // after reading a tile the row counter increase by TILE_ROWS + // the col counter reset after reading TILE_COL elements + const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS; + // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached + const int base_col = (blockIdx.x*TILE_COLS) % tiledCols; + const int base_idx = (base_row*cols) + base_col; + const int items_per_load = ITEMS_PER_THREAD*THREADS; + + typedef hipcub::BlockLoad LoadHalf; + __shared__ typename LoadHalf::TempStorage loadhalf; + typedef hipcub::BlockStore StoreInt8; + __shared__ typename StoreInt8::TempStorage storeint8; + + __shared__ float smem_row_stats[TILE_ROWS]; + __shared__ unsigned int smem_nnz_row_idx[TILE_ROWS]; + + half local_data[ITEMS_PER_THREAD]; + float local_col_stats[ITEMS_PER_THREAD]; + char local_quantized_data[ITEMS_PER_THREAD]; + + // 0. Load row stats data into shared memory; load col stat (1 fixed per thread) + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + if(base_col+(threadIdx.x*ITEMS_PER_THREAD) + j < cols) + local_col_stats[j] = __fdividef(127.0f, colStats[base_col+(threadIdx.x*ITEMS_PER_THREAD)+j]); + + for(int i = threadIdx.x; i < TILE_ROWS; i+=blockDim.x) + { + if(base_row + i < rows) + smem_row_stats[i] = rowStats[base_row+i]; + + if(SPARSE_DECOMP) + smem_nnz_row_idx[i] = nnz_block_ptr[(TILE_ROWS*blockIdx.x) + i]; + } + __syncthreads(); + + // we load row after row from the base_position + // 1. Load data row by row (should be at least with TILE_SIZE = 512) + for(int row = 0; row < TILE_ROWS; row++) + { + if(base_row + row >= rows){ break; } + int i = base_idx + (row*cols); + int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col; + + + LoadHalf(loadhalf).Load(&(A[i]), local_data, valid_items, 0.0f); + float row_stat = __fdividef(127.0f, smem_row_stats[row]); + + // 2. quantize data with row/col stats + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + // we already pre-normalized the col/row stat: + // what this does is float/absmax*127 = int8 + if(SPARSE_DECOMP) + { + if(fabsf((float)local_data[j]) >= threshold) + { + local_quantized_data[j] = 0; + + int old_idx = atomicInc(&smem_nnz_row_idx[row], UINT_MAX); + + rowidx[old_idx] = base_row+row; + colidx[old_idx] = base_col+(threadIdx.x*ITEMS_PER_THREAD)+j; + val[old_idx] = local_data[j]; + } + else + { + local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*row_stat)); + } + } + else + local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*row_stat)); + } + + StoreInt8(storeint8).Store(&(out_row_normed[i]), local_quantized_data, valid_items); + + // 2. quantize data with row/col stats + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + // we already pre-normalized the col/row stat: + // what this does is float/absmax*127 = int8 + local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*local_col_stats[j])); + } + + __syncthreads(); + StoreInt8(storeint8).Store(&(out_col_normed[i]), local_quantized_data, valid_items); + + } +} + +template __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols) +{ + + // 0. Load data into 32*32 shared memory tiles + // 1. transpose / reorder in shared memory + // 2. 
store + + // COL32 FORMAT: + // rows*32 tiles + + // TURING FORMAT: + // 8*32 tiles with 4*4 subtiles + // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements) + // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero + // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32]) + // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column + // index increases by 32 + + // AMPERE FORMAT: + // 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows: + // row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... + // the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32] + + + // To have efficient loads and stores if we transpose we need 128 consequitive bytes which at 1 byte are 128 values + // As such we need: + // at least 32*4 shared memory tiles for col32; preferably 32*32 + // at least 32*6 shared memory tiles for col32_ampere: preferably 32*32 + // at least 32*8 shared memory tiles for col4_turing: preferably 32*32 + // for efficient loading of row major we need to load 128 elements and repeat this 32 items + // this would imply a 32x128 shared memory tile -> 4kb + // It is more efficient to have more than 1 warp, so with 64 threads we need 32x128 -> 8 kb + // we have 64k sharded mem per SM in Turing which is 8 blocks per SM which is 2*8 = 32 warps = 100% occupancy + // for turing and 50% for A100 and 75% for RTX 30s / A40 which is probably good enough + // register pressure should be low with: 8 registers from local memoryh per block and 64 registers per SM + // + // to make the shared memory work with that occupancy we might need to union the block loads/stores + + // each block loads TILE_COLs columns and TILE_ROW rows + // after reading a tile the row counter increase by TILE_ROWS + // the col counter reset after reading TILE_COL elements + const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS; + // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached + const int base_col = (blockIdx.x*TILE_COLS) % tiledCols; + const int base_idx = (base_row*cols) + base_col; + + // we load 128 bytes per warp with + // 32 rows for transposes that fill col32 types + // so that we can have contiguous stores + __shared__ char smem_data[32*33*ITEMS_PER_THREAD]; + char local_data[ITEMS_PER_THREAD]; + typedef hipcub::BlockExchange BlockExchange; + + // we load row after row from the base_position + // Load data row by row + int warps = blockDim.x/32; + int warp_id = threadIdx.x/32; + int warp_lane = threadIdx.x % 32; + int offset = 0; + + int smem_row = 0; + // each warp loads one row of 128 bytes + for(int row = warp_id; row < TILE_ROWS; row+=warps) + { + int i = base_idx + (row*cols); + // we load up to 128 bytes/items per load + int valid_items = cols - base_col > 32*ITEMS_PER_THREAD ? 32*ITEMS_PER_THREAD : cols - base_col; + + // 0. 
Load data into 32*32 shared memory tiles + if(base_row + row < rows) + { + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + int col_idx = warp_lane+(j*32); + if(col_idx < valid_items) + local_data[j] = A[i+col_idx]; + else + local_data[j] = 0; + } + } + else + { + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_data[j] = 0; + } + + if(TRANSPOSE) + { + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + int local_col = (32*j)+warp_lane; + //int local_row = row; + // store as 256x32 + smem_data[(local_col*33) + row] = local_data[j]; + } + } + else + { + // treat smem as 32x256, that is 32 rows and 256 columns + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + smem_data[row*32*ITEMS_PER_THREAD + (warp_lane) + (j*32)] = local_data[j]; + } + + + + smem_row += warps; + + // 1. transpose / reorder in shared memory + if(smem_row % 32 == 0) + { + smem_row = 0; + __syncthreads(); + + for(int subrow = warp_id; subrow < 32; subrow+=warps) + { + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + + switch(FORMAT) + { + case COL32: + if(TRANSPOSE) + { + // data lies in shared memory in the following way: + // row0 [col0 col1 ... col31] + // row1 [col0 col1 ... col31] + // ... + // + // As such we read consecutive entries with 256 threads (8rows x 32 columns) + // as j increase, the row increase by a factor of 8 + // We load 8 rows per subrow loop, and subrow increase by 8 per loop + // so we have an offset of 8 rows every loop or (subrow/warps)*8 = (subrow/8)*8 + const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j + const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps) + //const int local_row = warp_id; // each warp_id is one row + //const int block_row = base_col; // block offset for row + //const int local_col = warp_lane + //const int global_col = base_row; // block offset for col + if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows)) + { + // each row hae 32 columns and is offset by 1 to prevent bank conflict during storage into smem + char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane]; + + // each 32 columns we have new tile + // each tile has size outRows*32 and base_row is done in increments of 32 + offset = base_row*outRows; + out[offset + (base_col + jrow + subrow_loop_row)*32 + threadIdx.x] = data; + } + } + else + { + if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols)) + { + offset = (base_col/32)*(32*rows); + char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane]; + out[offset+(base_row+subrow)*32 + ((j)*rows*32)+warp_lane] = data; + } + } + break; + case COL_TURING: + // TURING FORMAT: + // 8*32 tiles with 4*4 subtiles + // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements) + // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero + // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32]) + // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column + // index increases by 32 + // + // [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...] 
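+                      // Worked example of the 4x4 sub-tile offsets used in the branches below
+                      // (for clarity only; the constants restate the even/odd offset computation):
+                      //   even subrow: offset = (subcol/4)*16 + (subcol%4) + (subrow%8)*2
+                      //   odd  subrow: offset = 128 + (subcol/4)*16 + (subcol%4) + ((subrow%8)-1)*2
+                      //   e.g. (subrow=2, subcol=5): 16 + 1 + 4       = 21   (even half of the 8x32 tile)
+                      //        (subrow=3, subcol=5): 128 + 16 + 1 + 4 = 149  (odd half starts at offset 128)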
+ if(TRANSPOSE) + { + const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j + const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps) + //const int local_row = warp_id; // each warp_id is one row + //const int block_row = base_col; // block offset for row + //const int local_col = warp_lane + //const int global_col = base_row; // block offset for col + if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows)) + { + // each row hae 32 columns and is offset by 1 to prevent bank conflict during storage into smem + char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane]; + + // each 32 columns we have new tile + // each tile has size 8*32 = 256 elements offset + // for each row offset of 8 we increaes the tile first + // after all rows are exhausted, we increase the col + int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/8)*256; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows + + // we increase by row_tile_column every 32 columns + // base_row increase in increments of 32 + //int row_tile_column = 256*outRows/8; // there are outRows/8 row tiles, and each tile is 256 elements + //int col_offset = (base_row/32)*row_tile_column; + // -> we can remove the divisions to speed up compute since outRows is always a multiple of 8 + // 256*outRows/8*base_row/32 = outRows*base_row + int col_offset = outRows*base_row; + + offset = row_offset+col_offset; + + // since we process even number of rows with each j (8) and with each subrow (8j) we can determine + // odd or even rows with the warp_id (each warp processes one row) + // the col is warp_lane (max 32 columns per row) and the row warp_id + if(warp_id % 2 == 1) + // odd + offset += 128 + (warp_lane/4)*16 + (warp_lane%4) + (((warp_id%8)-1)*2); + else + // even + offset += 0 + (warp_lane/4)*16 + (warp_lane%4) + ((warp_id%8)*2); + + out[offset] = data; + } + } + else + { + if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols)) + { + char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane]; + // set offset designates the tile offset among the 8*32 tiles + // we first increase rows and then columns. Since we load 128 columns at once + // we increase the offset by outRows*32 every 32 columns + // additionally, we increase the offset by 8*32=256 every 8 rows + offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/8)*256); // global offset (8x32 tile) + // first 4 rows are reserved for even rows, [0, 2, 4, 6], the next 4 for odd + // each of these has 32 values in total for 32*4 = 128 as offset if odd + // every set of 4 columns increases the total offset by 16 + // each even row increase the offset by 4, for example row 2 is offset by 4, 4 by 6 etc so: subrow/2*4 = subrow*2 + // this happens every 8 rows anew (subrow % 8) + // one writes 4 columns at once that is (col % 4) for the particular index in the subtile + int subcol = warp_lane; + + // add local offset (4x4 sub-tile) + if(subrow % 2 == 1) + // odd + offset += 128 + (subcol/4)*16 + (subcol%4) + (((subrow%8)-1)*2); + else + // even + offset += 0 + (subcol/4)*16 + (subcol%4) + ((subrow%8)*2); + + out[offset] = data; + } + } + break; + case COL_AMPERE: + // AMPERE FORMAT: + // 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows: + // row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... 
+ // the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32] + if(TRANSPOSE) + { + const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j + const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps) + //const int local_row = warp_id; // each warp_id is one row + //const int block_row = base_col; // block offset for row + //const int local_col = warp_lane + //const int global_col = base_row; // block offset for col + if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows)) + { + // each row hae 32 columns and is offset by 1 to prevent bank conflict during storage into smem + char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane]; + + // each 32 columns we have new tile + // each tile has size 32*32 = 1024 elements offset + // for each row offset of 32 we increaes the tile first + // after all rows are exhausted, we increase the col + int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/32)*1024; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows + + // we increase by row_tile_column every 32 columns + // base_row increase in increments of 32 + //int row_tile_column = 1024*outRows/32; // there are outRows/32 row tiles, and each tile is 1024 elements + //int col_offset = (base_row/32)*row_tile_column; + // -> we can remove the divisions to speed up compute since outRows is always a multiple of 8 + // 1024*outRows/32*base_row/32 = outRows*base_row + int col_offset = outRows*base_row; + + offset = row_offset+col_offset; + + + // same as in the non-transpose case (see below) + // the difference is that now rows = cols + // in this case warp_id = subrow + + // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... + // subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc + // subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row + // every 2 rows, the offset increases by two [0, 1, 8, 9...] + // every 2 rows, the row index increase by 8 [0, 1, 8, 9...] + int local_row = (jrow + warp_id) % 32; // offset for row > 32 is already calculated into row_offset + int ampere_row = ((local_row % 8)/2)*8 + (local_row/8)*2 + (local_row % 2); + + // global offset + row with 32 cols each + 32 cols per j + col_idx=warp_lane + out[offset + (ampere_row*32) + warp_lane] = data; + } + } + else + { + if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols)) + { + char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane]; + + // set offset designates the tile offset among the 32*32 tiles + // we first increase rows and then columns. Since we load 128 columns at once + // we increase the offset by outRows*32 every 32 columns + // additionally, we increase the offset by 32*32=1024 every 32 rows + offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/32)*1024); // global offset (32x32 tile) + + // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... + // subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc + // subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row + // every 2 rows, the offset increases by two [0, 1, 8, 9...] + // every 2 rows, the row index increase by 8 [0, 1, 8, 9...] 
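+                        // Sanity check of the interleaving formula below (subrow -> Ampere row index):
+                        //   subrow 0..7  map to 0 1 8 9 16 17 24 25
+                        //   subrow 8..15 map to 2 3 10 11 18 19 26 27
+                        //   e.g. subrow=10: ((10%8)/2)*8 + (10/8)*2 + (10%2) = 8 + 2 + 0 = 10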
+ int local_row = ((subrow % 8)/2)*8 + (subrow/8)*2 + (subrow % 2); + + // global offset + row with 32 cols each + 32 cols per j + col_idx + out[offset + (local_row*32) + warp_lane] = data; + } + } + break; + } + } + } + } + } +} + +#define DENORM 1.0f/127.0f +#define MAX_SPARSE_COUNT 32 +#define SMEM_SIZE 8*256 +template +__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB) +{ + + // 0. load balancing: We process rows with most columns first (count_vec)and we process one row per block + // If a block finishes, the next one is scheduled. Since the last blocks like have fewer + // elements they finish faster "fillin up" the gaps left by larger blocks + + // without tensor cores + // 1. use rowidx_length to find what to load (as many blocks as there are rows) + // 2. Load A into registers + // 3. each warp loads all required rows of B but each warp is offset by k + // 4. Do mma operations that accumulate into registers + // 5. Each warp stores its output row into matrix C + + const int count = max_count[blockIdx.x]; + const int local_max_idx = max_idx[blockIdx.x]; + const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx-1]; + const int local_row_idx = rowidx[offset]; + + const int warp_id = threadIdx.x / 32; + const int warp_idx = threadIdx.x % 32; + const int warp_offset = (warp_id*32)*SPMM_ITEMS; + const int num_items = BITS == 8 ? 8 : 8; + int idx_col_B = warp_offset; + int local_idx_col_B_offset = 0; + + half local_valA[MAX_SPARSE_COUNT]; + int local_colidxA[MAX_SPARSE_COUNT]; + half local_valC[SPMM_ITEMS]; + T local_valsB[num_items]; + half local_valOut[num_items]; + // 128 byte loads per warp == 4 bytes per thread + + // 2. Load A into registers + for(int j = 0; j < MAX_SPARSE_COUNT; j++) + { + local_valA[j] = j < count ? values[offset+j] : __float2half(0.0f); + local_colidxA[j] = j < count ? colidx[offset+j] : 0; + } + + // each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=x192 + // we expect each warp to be SPMM_ITEMS*32 apart + // we have a total of 128 bytes for the bank with a bank size of 4 bytes + // added 3 bytes = 6 values between warps should reduce bank conflicts + __shared__ half smem_dequant_stats[SMEM_SIZE]; + + + while(idx_col_B < colsB) + { + + if(dequant_stats != NULL) + { + for(int i = threadIdx.x; i < SMEM_SIZE; i+=blockDim.x) + if((idx_col_B+i-local_idx_col_B_offset) < colsB) + smem_dequant_stats[i] = dequant_stats[idx_col_B+i-local_idx_col_B_offset]; + + __syncthreads(); + } + + #pragma unroll SPMM_ITEMS + for(int j = 0; j < SPMM_ITEMS; j++) + local_valC[j] = 0.0f; + + #pragma unroll + for(int i = 0; i < count; i++) + { + // 3. each warp loads all required rows of B but each warp is offset by k + int row_offset = colsB*local_colidxA[i]; + + #pragma unroll SPMM_ITEMS + for(int j = 0; j < SPMM_ITEMS; j+=num_items) + { + // 4. 
Multiply the tile -> accumulate outputs in shared memory until 128 bytes it reached + int idx = idx_col_B + (warp_idx*SPMM_ITEMS) + j; + if(idx >= colsB){ break; } + if((idx+num_items < colsB)) + { + if(BITS == 8) + reinterpret_cast(local_valsB)[0] = reinterpret_cast(B)[(row_offset+ idx)/num_items]; + else + reinterpret_cast(local_valsB)[0] = reinterpret_cast(B)[(row_offset+ idx)/num_items]; + } + else + { + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + if(idx+k < colsB) + local_valsB[k] = B[row_offset+idx+k]; + else + local_valsB[k] = 0.0f; + } + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + { + if(BITS == 8 && dequant_stats != NULL) + // we do texture cache reads (__ldg) on dequant_stats which should be super fast + { + float valB = local_valsB[k]; + float valA = local_valA[i]; + if(valB != 0.0 && valA != 0.0) + local_valC[j+k] = (float)local_valC[j+k] + ((float)smem_dequant_stats[idx+k-local_idx_col_B_offset])*DENORM*valB*valA; + } + else + local_valC[j+k] = (float)local_valC[j+k] + (float)local_valsB[k]*(float)local_valA[i]; + } + } + } + + int idx_row_C = (colsB*local_row_idx); + + #pragma unroll SPMM_ITEMS + for(int j = 0; j < SPMM_ITEMS; j+=num_items) + { + //int idx_col_C = idx_col_B + (32*j) + warp_idx; + int idx_col_C = idx_col_B + warp_idx*SPMM_ITEMS + j; + int idx_val = idx_col_C + idx_row_C; + + if(idx_col_C +num_items < colsB) + { + + // load outputs to do inplace addition + reinterpret_cast(local_valOut)[0] = reinterpret_cast(out)[idx_val/num_items]; + + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + local_valC[(j/num_items) + k] = (float)local_valC[(j/num_items) + k] + (float)local_valOut[k]; + + reinterpret_cast(out)[idx_val/num_items] = reinterpret_cast(local_valC)[j/num_items]; + } + else + { + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + if(idx_col_C + k < colsB) + out[idx_val+k] = (float)out[idx_val+k]+(float)local_valC[j+k]; + } + } + + idx_col_B += blockDim.x*SPMM_ITEMS; + local_idx_col_B_offset += blockDim.x*SPMM_ITEMS; + } +} + +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA) +{ + int local_colidx = idx[blockIdx.x]; + + /*if(FORMAT==COL_TURING) + { + // TURING FORMAT: + // 8*32 tiles with 4*4 subtiles + // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*8 = 128 elements) + // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero + // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32]) + // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column + // index increases by 32 + // columns are grouped in increments of 4, meaning that one has the following rows and columns + // rows: [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...] + // cols: [0 1 2 3, 0 1 2 4, 0 1 2 3, 0 1 2 3, 4 5 6 7 ...] 
+ + // each thread reads 1 element = 1 row + for(int row = threadIdx.x; row < rowsA; row+= blockDim.x) + { + int offset_per_col_tile = ((rowsA+7)/8)*32*8; + int tile_offset_rows = (row/8)*32*8; + int tile_offset_cols = (local_colidx/32)*offset_per_col_tile; + int offset = 0; + int subtile_col_idx = local_colidx%32; + int subtile_row_idx = row % 8; + if(row % 2 == 1) + offset += 128 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + ((subtile_row_idx-1)*2); + else + // even + offset += 0 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + (subtile_row_idx*2); + + offset += tile_offset_rows + tile_offset_cols; + + char val = A[offset]; + + int out_idx = (row*idx_size) + blockIdx.x; + out[out_idx] = val; + } + } + else if(FORMAT == COL_AMPERE) + { + + for(int row = threadIdx.x; row < rowsA; row+= blockDim.x) + { + // we got 32x32 tiles and we use the magic equation from the cublasLt doc to get the element + // within each tile. + int offset_per_col_tile = ((rowsA+31)/32)*32*32; + int tile_offset_rows = (row/32)*32*32; + int tile_offset_cols = (local_colidx/32)*offset_per_col_tile; + int subtile_col_idx = local_colidx%32; + int subtile_row_idx = row % 32; + // this magic is taken from the cublasLt doc (search for COL32) + int offset = (((subtile_row_idx%8)/2*4+subtile_row_idx/8)*2+subtile_row_idx%2)*32+subtile_col_idx; + offset += tile_offset_cols + tile_offset_rows; + + char val = A[offset]; + int out_idx = (row*idx_size) + blockIdx.x; + out[out_idx] = val; + } + }*/ + + //Only col format is used on ROCm + for(int row = threadIdx.x; row < rowsA; row+= blockDim.x) + { + //col-major offset + int offset = local_colidx * rowsA + row; + + char val = A[offset]; + int out_idx = (row*idx_size) + blockIdx.x; + out[out_idx] = val; + } +} + + +//template __global__ void kMatmul_inference_4bit(INPT *A, unsigned char *B, OUTT *out, int lda, int ldb, int rowsA, int colsA, int colsB) +//{ +//// element-wise kernel +//// 1. Load batch x k into registers +//// 2. Load k x k into registers +//// 3. dequantize and store in second pair of k x k +//// 4. matmul +//// 5. sum with hipcub +//// 6. store outputs +//// TC kernel +//// use k warps per thread block +//// 1. threadblock use read-only cache to read in register tile for A into shared memory +//// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments +//// 3. each warp reads a segment of values 16x32 from B +//// 4. do dequantization from register of B into second pair of registers +//// 5. store (4) into fragment +//// 6. matmul aggregate into fragment C +//// 7. aggregate files of C into shared memory block C +//// 8. sum (7) +//// 9. 
write outputs to matmul output matrix +//} + +template __device__ inline void vector_load(T *local, T * __restrict__ const buffer, int idx, int limit_base, int limit, float zero_value = 0.0f) +{ + if(limit_base + ITEMS <= limit) + reinterpret_cast(local)[0] = reinterpret_cast(buffer)[idx/ITEMS]; + else + { + for(int k = 0; k < ITEMS; k++) + { + if(limit_base + k < limit) + local[k] = buffer[idx+k]; + else + local[k] = (T)zero_value; + } + } +} + +#define WARPS 3 +template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) +{ + +#if __CUDA_ARCH__ >= 750 + using namespace nvcuda; + int col_offset = blockIdx.x *32; + const int warp_id = threadIdx.x / 32; + const int half_warp_id = threadIdx.x / 16; + const int half_warp_lane = threadIdx.x % 16; + const int batch_size_warps = (WARPS-1)*2; + const int val_per_iter = blockDim.x-32; + + T local_A[4]; + T local_B[128]; + + const int a_tile_offset = 16; + const int b_tile_offset = (16*32 + 16); + + __shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))]; + __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; + //__shared__ T smem_C[8*32]; + + rocwmma::fragment a_frag; + rocwmma::fragment b_frag; + rocwmma::fragment c_frag; + rocwmma::fill_fragment(c_frag, 0.0f); + + int ticktock = 0; + int idx = 0 + threadIdx.x; + int loaded_values = 0; + // prefetch + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+(1*val_per_iter)]; + local_A[2] = A[idx+(2*val_per_iter)]; + local_A[3] = A[idx+(3*val_per_iter)]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B[col] = B[(col_offset+col)*ldb+idx]; + local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; + local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; + local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; + } + loaded_values = 3; + } + else + { + + if(loaded_values == 3) + { + local_A[0] = local_A[1]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(32)]; + } + else if(loaded_values == 2) + { + local_A[0] = local_A[2]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(64)]; + } + else + { + local_A[0] = local_A[3]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(96)]; + } + loaded_values--; + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 
1 : 0; + + //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + { + idx = base_idx + threadIdx.x; + + __syncthreads(); + if(idx < K && warp_id < (WARPS-1)) + { + //local_A[0] = A[idx]; + + //#pragma unroll 32 + //for(int col = 0; col < 32; col++) + // local_B[col] = B[(col_offset+col)*ldb+idx]; + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+(1*val_per_iter)]; + local_A[2] = A[idx+(2*val_per_iter)]; + local_A[3] = A[idx+(3*val_per_iter)]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B[col] = B[(col_offset+col)*ldb+idx]; + local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; + local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; + local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; + } + loaded_values = 3; + + } + else + { + + if(loaded_values == 3) + { + local_A[0] = local_A[1]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(32)]; + } + else if(loaded_values == 2) + { + local_A[0] = local_A[2]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(64)]; + } + else + { + local_A[0] = local_A[3]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(96)]; + } + loaded_values--; + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + + if(warp_id == (WARPS-1)) + for(int k = 0; k < batch_size_warps; k++) + { + rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + } + + __syncthreads(); + if(warp_id != (WARPS-1)){ return; } + // only warp_id == (WARPS-1) from here + int warp_lane = threadIdx.x % 32; + + ticktock = ticktock == 0 ? 
1 : 0; + for(int k = 0; k < batch_size_warps; k++) + { + rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + + // 129 mu + if(warp_id == (WARPS-1)) + rocwmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, rocwmma::mem_row_major); + + if(col_offset + warp_lane < M) + out[col_offset + warp_lane] = smem_A[warp_lane]; +#endif +} + + +template __device__ void printnonzero(T *A, int num_values, const char * strval) +{ + for(int i = 0; i < num_values; i++) + if((float)A[i] != 0.0) + printf("%s %i %f\n", strval, i, (float)A[i]); +} + +template __device__ void printnonzero(float *A, int num_values, const char*strval); +template __device__ void printnonzero(half *A, int num_values, const char*strval); + +__device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0}; +template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) +{ + +#if __CUDA_ARCH__ >= 750 + using namespace nvcuda; + int col_offset = blockIdx.x *32; + const int warp_id = threadIdx.x / 32; + const int warp_idx = threadIdx.x % 32; + const int half_warp_id = threadIdx.x / 16; + const int half_warp_lane = threadIdx.x % 16; + const int batch_size_warps = (WARPS-1)*2; + + T quant_map[16]; + + #pragma unroll 16 + for(int i = 0; i < 16; i++) + quant_map[i] = nf4_data[i]; + //__shared__ T quant_map[16*160]; + + T local_A[2]; + T local_B[64]; + unsigned char local_B_4bit[32]; + + + const int a_tile_offset = 16; + const int b_tile_offset = (16*32 + 16); + + __shared__ T smem_A[8*16 + (16*(batch_size_warps-1))]; + __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; + __shared__ T smem_C[8*32]; + + rocwmma::fragment a_frag; + rocwmma::fragment b_frag; + rocwmma::fragment c_frag; + rocwmma::fill_fragment(c_frag, 0.0f); + + for(int i = threadIdx.x; i < (8*32); i+=blockDim.x) + smem_C[i] = 0.0f; + + __syncthreads(); + + int ticktock = 0; + int idx = 0 + threadIdx.x; + int loaded_values = 0; + // prefetch + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+blockDim.x-32]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B_4bit[col] = B[(col_offset+col)*ldb+idx]; + + loaded_values = 1; + } + else + { + local_A[0] = local_A[1]; + loaded_values--; + + #pragma unroll 64 + for(int col = 0; col < 64; col+=2) + { + //local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f); + //local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f); + //local_B[col] = d2DequantizeFP4(local_B_4bit[col/2] >> 4)*(float)(17.0); + //local_B[col+1] = d2DequantizeFP4(local_B_4bit[col/2] & 0x0F)*(float)(17.0); + //local_B[col] = 127*(local_B_4bit[col/2] >> 4)*(float)(17.0); + //local_B[col+1] = 127*(local_B_4bit[col/2] & 0x0F)*(float)(17.0); + + //local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(17.0); + //local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(17.0); + local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(17.0); + 
local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(17.0); + } + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + //if(threadIdx.x == 0) + //printf("aa %i %i\n", idx, loaded_values); + + //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + { + idx = base_idx + threadIdx.x; + //if(threadIdx.x == 0) + //printf("%i %i\n", idx, loaded_values); + + //__syncthreads(); + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+blockDim.x-32]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B_4bit[col] = B[(col_offset+col)*ldb+idx]; + local_B_4bit[col+16] = B[(col_offset+col)*ldb+idx]; + } + + loaded_values = 1; + } + else + { + local_A[0] = local_A[1]; + loaded_values--; + + int absidx = (idx + col_offset)/blocksize; + half local_absmax = __ldg(&(absmax[absidx])); + + #pragma unroll 64 + for(int col = 0; col < 64; col+=2) + { + //local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx); + //local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx); + //local_B[col] = T(127)*T(local_B_4bit[col/2] >> 4)*T(absidx); + //local_B[col+1] = T(127)*T(local_B_4bit[col/2] & 0x0F)*T(absidx); + + //local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(local_absmax); + //local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(local_absmax); + local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(absidx); + local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(absidx); + } + //printnonzero(local_B, 128, ""); + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 
1 : 0; + + if(warp_id == (WARPS-1)) + for(int k = 0; k < batch_size_warps; k++) + { + rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + } + + __syncthreads(); + //if(threadIdx.x == 0) + //{ + // printnonzero(smem_A, 8*16 + (2*16*(batch_size_warps-1)), "A: "); + // printnonzero(smem_B, 2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1)), "B: "); + //} + if(warp_id != (WARPS-1)){ return; } + // only warp_id == (WARPS-1) from here + int warp_lane = threadIdx.x % 32; + + ticktock = ticktock == 0 ? 1 : 0; + for(int k = 0; k < batch_size_warps; k++) + { + //if(warp_lane == 0) + //printf("%i %i %i %i\n", (ticktock*batch_size_warps + k)*a_tile_offset, k, ticktock, threadIdx.x); + rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + + // 129 mu + if(warp_id == (WARPS-1)) + rocwmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, rocwmma::mem_row_major); + + //printnonzero(smem_C, 32, ""); + + if(col_offset + warp_lane < M) + out[col_offset + warp_lane] = smem_C[warp_lane]; +#endif +} + +#define num_values_4bit 32 +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + // per threadblock: + // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps] + // 4 warps -> 4 loads per iter + // 1x32 * 32x4 -> 1x4 outputs per thread block + typedef hipcub::WarpReduce WarpReduce; + __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32]; + + const int warp_idx = threadIdx.x / 32; + const int warp_lane = threadIdx.x % 32; + const int row_B = (THREADS/32)*blockIdx.x + warp_idx; + const int num_values_8bit = num_values_4bit/2; + float local_C = 0.0f; + + unsigned char local_B_4bit[num_values_8bit]; + T local_B[num_values_4bit/4]; + T local_A[num_values_4bit/4]; + __shared__ T quant_map[16]; + T local_absmax = T(0.0f); + + for(int i = threadIdx.x; i < 16; i++) + quant_map[i] = T(datatype[i]); + __syncthreads(); + + // A: [1, K] + // B: [N, K] + for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += 32*num_values_4bit) + { + int inner_idx_halved = inner_idx/2; + int offset_B = ldb*row_B; + int absidx = ((2*offset_B)+inner_idx)/blocksize; + local_absmax = __ldg(&(absmax[absidx])); + + if(row_B < M) + { + if((inner_idx_halved + num_values_8bit) < (K/2)) + { + // this is the most important for performance considerations + reinterpret_cast(local_B_4bit)[0] = reinterpret_cast(B)[(offset_B+(inner_idx_halved))/(num_values_8bit)]; + } + else + { + #pragma unroll + for(int j = 0; j < (num_values_8bit); j++) + if((inner_idx_halved) + j < (K/2)) + local_B_4bit[j] = B[offset_B+inner_idx_halved + j]; + else + local_B_4bit[j] = 0b01110111; + } + } + else + { + #pragma unroll + for(int j = 0; j < (num_values_8bit); j++) + local_B_4bit[j] = 0b01110111; + } + + for(int i = 0; i < 4; i++) + { + #pragma unroll + for(int k = 0; k < num_values_8bit/4; k++) + { + #if __CUDA_ARCH__ >= 800 + local_B[k*2] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*local_absmax; + local_B[k*2 + 1] = 
quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*local_absmax; + #else + // bf16 multipliation not supported + local_B[k*2] = T((float)quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*(float)local_absmax); + local_B[k*2 + 1] = T((float)quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*(float)local_absmax); + #endif + } + + if(inner_idx+(num_values_4bit/4) + (i*num_values_4bit/4) < K) + { + // this is also relatively important for performance + if(BITS==16) + { + reinterpret_cast(local_A)[0] = reinterpret_cast(A)[inner_idx/(num_values_4bit/4) + i]; + } + else + { + reinterpret_cast(local_A)[0] = reinterpret_cast(A)[inner_idx/(num_values_4bit/8) + (2*i) + 0]; + reinterpret_cast(local_A)[1] = reinterpret_cast(A)[inner_idx/(num_values_4bit/8) + (2*i) + 1]; + } + + } + else + #pragma unroll + for(int k = 0; k < num_values_4bit/4; k++) + if(inner_idx + (i*num_values_4bit/4) + k < K) + local_A[k] = A[inner_idx + k + (i*num_values_4bit/4)]; + else + local_A[k] = T(0.0f); + + + // accumulate in float; small performance hit for Ampere, but lower error for outputs + #pragma unroll + for(int k = 0; k < num_values_4bit/4; k++) + { + #if __CUDA_ARCH__ >= 800 + local_C += (float)(local_A[k]*local_B[k]); + #else + // bf16 multipliation not supported + local_C += ((float)local_A[k]*(float)local_B[k]); + #endif + } + } + } + + local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C); + + if(row_B < M && warp_lane == 0) + out[row_B] = T(local_C); + +} + + +//#define ROWS 2 +//template __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc) +//{ +//// 0. We want to fill a 8x128 tile for a thread block so we have 8x16 tile for each warp +//// 1. Load dataB into register +//// 2. Dequantize B +//// 3. Fetch data from A and multiply +// +// typedef hipcub::BlockLoad LoadA; +// //__shared__ typename LoadA::TempStorage loada; +// typedef hipcub::BlockLoad LoadB; +// //__shared__ typename LoadB::TempStorage loadb; +// typedef hipcub::BlockReduce BlockReduce; +// // Allocate shared memory for BlockReduce +// //__shared__ typename BlockReduce::TempStorage reduce; +// +// __shared__ union { +// typename BlockReduce::TempStorage reduce; +// typename LoadB::TempStorage loadb; +// typename LoadA::TempStorage loada; +// } temp_storage; +// +// +// T dataA[ITEMS]; +// T local_B[ITEMS]; +// T local_accC[ROWS]; +// int valid_items = 0; +// const int col_offset = blockIdx.x * 8; +// +// __shared__ T tileA[ROWS*THREADS*ITEMS]; +// __shared__ T accumulatorC[ROWS*8]; +// +// //#pragma unroll 8 +// //for(int i = 0; i < 8; i++) +// // tileA[threadIdx.x + (i*256)] = 0.0f; +// //__syncthreads(); +// if(threadIdx.x < 64) +// accumulatorC[threadIdx.x] = 0.0f; +// __syncthreads(); +// +// +// for(int inner_idx = 0; inner_idx < K; inner_idx+= THREADS*ITEMS) +// { +// valid_items = K - inner_idx > THREADS*ITEMS ? THREADS*ITEMS : K - inner_idx; +// int baserow = 0; +// for(int row = baserow; row < (baserow+ROWS) && row < N; row++) +// { +// LoadA(temp_storage.loada).Load(&(A[(row*K) + inner_idx]), dataA, valid_items, 0.0f); +// +// #pragma unroll ITEMS +// for(int k = 0; k < ITEMS; k++) +// tileA[row*THREADS*ITEMS + threadIdx.x + (k*THREADS)] = dataA[k]; +// +// __syncthreads(); +// } +// baserow += ROWS; +// +// // load 16 columns from B at a time. 
B is transposed, so its like loading rows +// // each warp loads one row +// // each thread loads 128 byte +// +// // col: inner_idx + warp_lane +// // row: ldb*(offset + warp_id) +// for(int col = 0; col < 8 && (col_offset + col) < M; col++) +// { +// int colB = col_offset + col; +// +// for(int k = 0; k < ROWS; k++) +// local_accC[k] = 0.0f; +// +// int base_idxB = ldb*colB; +// valid_items = K - inner_idx > THREADS*ITEMS ? THREADS*ITEMS : K - inner_idx; +// LoadB(temp_storage.loadb).Load(&(B[base_idxB + inner_idx]), local_B, valid_items, 0.0f); +// __syncthreads(); +// +// for(int row = 0; row < ROWS && row < N; row++) +// { +// #pragma unroll ITEMS +// for(int k = 0; k < ITEMS; k++) +// { +// int idxA = row*THREADS*ITEMS + threadIdx.x + (THREADS*k); +// local_accC[row] += tileA[idxA]*local_B[k]; +// } +// +// local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], hipcub::Sum()); +// if(threadIdx.x == 0) +// atomicAdd(&accumulatorC[row*8 + col], local_accC[row]); +// } +// } +// } +// +// for(int row = 0; row < ROWS && row < N; row++) +// { +// int out_idx = ldc*row + col_offset; +// +// //if(threadIdx.x < 8) +// // if(accumulatorC[row*8 + threadIdx.x] != 0.0) +// // printf("%i %i %i %i %f idx %i %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx, blockIdx.x); +// +// if(threadIdx.x < 8 && (col_offset + threadIdx.x) < M) +// { +// //printf("%i %i %i %i %f idx %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx); +// out[out_idx + threadIdx.x] = accumulatorC[row*8 + threadIdx.x]; +// } +// } +// +// +// +//} + + +template __global__ void kfunc(T *A, T *B, T value, long n) +{ + for(long i = (blockDim.x*blockIdx.x) + threadIdx.x; i < n; i+=(blockDim.x*gridDim.x)) + { + switch(FUNC) + { + case FILL: + A[i] = (T)value; + break; + case ARANGE: + A[i] = (T)i; + break; + case _MUL: + A[i] = A[i]*B[i]; + break; + } + } +} + + +//============================================================== +// TEMPLATE DEFINITIONS +//============================================================== + +template __global__ void kfunc(float *A, float *B, float value, long n); +template __global__ void kfunc(unsigned char *A, unsigned char *B, unsigned char value, long n); +template __global__ void kfunc(float *A, float *B, float value, long n); +template __global__ void kfunc(float *A, float *B, float value, long n); + +// these are not used and make no sense, but the compiler needs them +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ 
const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +// these are not used and make no sense, but the compiler needs them + +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); + +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); + +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, hip_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, hip_bfloat16 * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, float * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, float * out, int lda, int ldb, int ldc, int blocksize); + +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); + +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int 
*colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); + +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); + +template __global__ void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n); + +template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); +template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); + +template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x); +template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x); + +template __global__ void kEstimateQuantiles(float *__restrict__ const A, float *code, const float offset, const float max_val, const int n); +template __global__ void kEstimateQuantiles(half *__restrict__ const A, float *code, const float offset, const half max_val, const int n); + +#define 
MAKE_PreconditionOptimizer32bit1State(oname, gtype) \ +template __global__ void kPreconditionOptimizer32bit1State(gtype* g, gtype* p, \ + float* state1, float *unorm, \ + const float beta1, const float beta2, const float eps, const float weight_decay, \ + const int step, const float lr, const float gnorm_scale, const int n); \ + +MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half) +MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float) +MAKE_PreconditionOptimizer32bit1State(RMSPROP, half) +MAKE_PreconditionOptimizer32bit1State(RMSPROP, float) +MAKE_PreconditionOptimizer32bit1State(LION, half) +MAKE_PreconditionOptimizer32bit1State(LION, float) +MAKE_PreconditionOptimizer32bit1State(LION, hip_bfloat16) +MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half) +MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) + +#define MAKE_Optimizer32bit1State(oname, gtype) \ +template __global__ void kOptimizer32bit1State(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \ + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \ + +MAKE_Optimizer32bit1State(MOMENTUM, half) +MAKE_Optimizer32bit1State(MOMENTUM, float) +MAKE_Optimizer32bit1State(RMSPROP, half) +MAKE_Optimizer32bit1State(RMSPROP, float) +MAKE_Optimizer32bit1State(LION, half) +MAKE_Optimizer32bit1State(LION, float) +MAKE_Optimizer32bit1State(LION, hip_bfloat16) +MAKE_Optimizer32bit1State(ADAGRAD, half) +MAKE_Optimizer32bit1State(ADAGRAD, float) + +#define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \ +template __global__ void kPreconditionOptimizer32bit2State(gtype* g, gtype* p, \ + float* state1, float* state2, float *unorm, \ + const float beta1, const float beta2, const float eps, const float weight_decay, \ + const int step, const float lr, const float gnorm_scale, const int n); \ + +MAKE_PreconditionOptimizer32bit2State(ADAM, float) +MAKE_PreconditionOptimizer32bit2State(ADAM, half) +MAKE_PreconditionOptimizer32bit2State(ADAM, hip_bfloat16) + +template __global__ void kOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); +template __global__ void kOptimizer32bit2State(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); +template __global__ void kOptimizer32bit2State(hip_bfloat16* g, hip_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); + + +#define MAKE_PreconditionStatic8bit1State(oname, gtype) \ +template __global__ void kPreconditionOptimizerStatic8bit1State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \ + float *unorm, \ + const float beta1, \ + const float beta2, \ + const float eps, const int step, \ + 
float* __restrict__ const quantiles1, \ + float* max1, float* new_max1, \ + const float weight_decay, \ + const float gnorm_scale, \ + const int n); \ + +MAKE_PreconditionStatic8bit1State(MOMENTUM, half) +MAKE_PreconditionStatic8bit1State(MOMENTUM, float) +MAKE_PreconditionStatic8bit1State(RMSPROP, half) +MAKE_PreconditionStatic8bit1State(RMSPROP, float) +MAKE_PreconditionStatic8bit1State(LION, half) +MAKE_PreconditionStatic8bit1State(LION, float) + +#define MAKE_optimizerStatic8bit1State(oname, gtype) \ +template __global__ void kOptimizerStatic8bit1State(gtype* p, gtype* const g, unsigned char* state1, \ + const float *unorm, const float max_unorm, const float param_norm, \ + const float beta1, \ + const float beta2, \ + const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, \ + float* max1, float* new_max1, \ + float weight_decay, \ + const float gnorm_scale, \ + const int n); \ + +MAKE_optimizerStatic8bit1State(MOMENTUM, half) +MAKE_optimizerStatic8bit1State(MOMENTUM, float) +MAKE_optimizerStatic8bit1State(RMSPROP, half) +MAKE_optimizerStatic8bit1State(RMSPROP, float) +MAKE_optimizerStatic8bit1State(LION, half) +MAKE_optimizerStatic8bit1State(LION, float) + +#define MAKE_PreconditionStatic8bit2State(oname, gtype) \ +template __global__ void kPreconditionOptimizerStatic8bit2State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, \ + float *unorm, \ + const float beta1, const float beta2, \ + const float eps, const int step, \ + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ + float* max1, float* max2, float* new_max1, float* new_max2, \ + const float gnorm_scale, \ + const int n); \ + +MAKE_PreconditionStatic8bit2State(ADAM, half) +MAKE_PreconditionStatic8bit2State(ADAM, float) + +#define MAKE_optimizerStatic8bit2State(oname, gtype) \ +template __global__ void kOptimizerStatic8bit2State(gtype* p, gtype* const g, unsigned char* state1, unsigned char* state2, \ + const float *unorm, const float max_unorm, const float param_norm, \ + const float beta1, const float beta2, \ + const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ + float* max1, float* max2, float* new_max1, float* new_max2, \ + float weight_decay, \ + const float gnorm_scale, \ + const int n); \ + +MAKE_optimizerStatic8bit2State(ADAM, half) +MAKE_optimizerStatic8bit2State(ADAM, float) + +template __global__ void kPercentileClipping(float * __restrict__ g, float *gnorm_vec, int step, const int n); +template __global__ void kPercentileClipping(half * __restrict__ g, float *gnorm_vec, int step, const int n); + +#define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \ +template __global__ void kQuantizeBlockwise(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); \ + +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit) +//MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit) + +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 
2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4) +//MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4) + +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4) +//MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4) + +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit) +//MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit) + +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4) +//MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4) + +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4) +//MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4) + +MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, General8bit) +//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, General8bit) + +MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, FP4) +//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, FP4) + +MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, NF4) +//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, NF4) + +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, 
const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, hip_bfloat16 *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, hip_bfloat16 *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, hip_bfloat16 *out, const int blocksize, const int n); + +#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \ +template __global__ void kOptimizerStatic8bit2StateBlockwise(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \ + const float beta1, const float beta2, const float beta3, const float alpha, \ + const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ + float* absmax1, float* absmax2, \ + float weight_decay, \ + const float gnorm_scale, const bool skip_zeros, const int n); \ + +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 2048, 8) +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 2048, 8) +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, hip_bfloat16, 2048, 8) + + +#define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \ +template __global__ void kOptimizerStatic8bit1StateBlockwise( \ + gtype* p, gtype* __restrict__ const g, unsigned char* state1, \ + const float beta1, const float beta2, \ + const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, \ + float* absmax1, \ + float weight_decay, \ + const float gnorm_scale, const bool skip_zeros, const int n); \ + +MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, hip_bfloat16, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8) diff --git a/csrc/kernels_hip.cuh b/csrc/kernels_hip.cuh new file mode 100644 index 000000000..6768302f9 --- /dev/null +++ b/csrc/kernels_hip.cuh @@ -0,0 +1,134 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. 
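+//
+// This header mirrors the prototypes in csrc/kernels.cuh, with the CUDA-only pieces mapped to
+// their HIP equivalents (cub -> hipcub, __nv_bfloat16 -> hip_bfloat16, nvcuda::wmma -> rocwmma).
+// Host wrappers in csrc/ops.hip launch these kernels through the hipified launch macro; a minimal
+// sketch, assuming the General8bit template arguments used by the dequantizeBlockwise wrapper
+// later in this patch (illustrative only, not part of the generated header):
+//
+//   int tile_size = 512;
+//   hipLaunchKernelGGL(( kDequantizeBlockwise<half, 512, 64, 8, General8bit>),
+//                      dim3((n + tile_size - 1) / tile_size), dim3(64), 0, stream,
+//                      code, A, absmax, out, blocksize, n);
+//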
+ +#include +#include + +#ifndef kernels +#define kernels + +//template __global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB); + +template__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n); + +__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n); +__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n); + +template __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n); + +template +__global__ void kPreconditionOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n); + +template +__global__ void kOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); + +template +__global__ void kPreconditionOptimizer32bit1State(T* g, T* p, + float* state1, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n); + +template +__global__ void kOptimizer32bit1State(T* g, T* p, + float* state1, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); + +template +__global__ void +kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + const float weight_decay, + const float gnorm_scale, const int n); + + +template +__global__ void +kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, + const float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + float weight_decay, const float gnorm_scale, const int n); + + + +template +__global__ void +kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + const float gnorm_scale, const int n); + + +template +__global__ void +kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2, + const float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float 
beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, const float gnorm_scale, const int n); + +template __global__ void kOptimizerStatic8bit2StateBlockwise( + T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n); + +template __global__ void kOptimizerStatic8bit1StateBlockwise( + T* p, T* __restrict__ const g, unsigned char* state1, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* absmax1, + float weight_decay, + const float gnorm_scale, const bool skip_zeros, const int n); + + +template __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n); + + +__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n); + + +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); + +template __global__ void kdequant_mm_int32_fp16( + int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, + half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n); + +template __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols); +template __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); + +template __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); + +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); + +template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc); +template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize); + +template __global__ void kfunc(T *A, T *B, T value, long n); + +#endif diff --git a/csrc/npu_kernels.cpp b/csrc/npu_kernels.cpp new file mode 100644 index 000000000..c70e71681 --- /dev/null +++ b/csrc/npu_kernels.cpp @@ -0,0 +1,222 @@ +#include "kernel_operator.h" +#include "npu_ops.h" + +using namespace AscendC; + +constexpr int32_t BUFFER_NUM = 1; + +constexpr half Q_COFF_0 = -0.377685546875; +constexpr 
half Q_COFF_1 = -3.193359375; +constexpr half Q_COFF_2 = 0.583984375; +constexpr half Q_COFF_3 = 6.02734375; +constexpr half Q_COFF_4 = 1.9560546875; +constexpr half Q_COFF_5 = 7.08984375; + +#define CEIL32(num) (((num) + 32 - 1) / 32 * 32) +#define CEIL_BASE(num, base) (((num) + (base) - 1) / (base) * (base)) + + +template +class KernelDequantizeBlockwiseNf4 { +public: + __aicore__ inline KernelDequantizeBlockwiseNf4() {} + + __aicore__ inline void Init(GM_ADDR A, GM_ADDR absmax, GM_ADDR out, GM_ADDR tilingDevice, TPipe &pipe) + { + ASSERT(GetBlockNum() != 0 && "block dim can not be zero!"); + auto *tiling_data = reinterpret_cast<__gm__ BlockwiseNf4TilingData *>(tilingDevice); + this->blocksize = tiling_data->blocksize; + uint32_t coreNum = tiling_data->coreNum; + uint32_t singleCoreNumel = tiling_data->singleCoreNumel; + uint32_t singleCoreNumelTail = tiling_data->singleCoreNumelTail; + uint32_t numel = tiling_data->numel; + uint32_t ubSize = tiling_data->ubSize; + uint32_t blockIdx = (uint32_t)GetBlockIdx(); + if (coreNum - blockIdx == 1) { + this->CurCoreFP16Num = singleCoreNumelTail; + } else { + this->CurCoreFP16Num = singleCoreNumel; + } + constexpr uint32_t ELEMENT_BYTES = (TypeMode == 1) ? 4 : 2; // FP32: 4bytes, FP16/BF16: 2bytes + uint32_t eachBatchPkgNum = (ubSize - 16 * ELEMENT_BYTES) / + (this->blocksize / 2 * BUFFER_NUM + ELEMENT_BYTES * BUFFER_NUM + this->blocksize * + (ELEMENT_BYTES * BUFFER_NUM + sizeof(half) + sizeof(uint32_t) + ELEMENT_BYTES)); + if (eachBatchPkgNum >= 32 / ELEMENT_BYTES) { + eachBatchPkgNum = (eachBatchPkgNum / (32 / ELEMENT_BYTES)) * (32 / ELEMENT_BYTES); + } else { + eachBatchPkgNum = (eachBatchPkgNum / 2) * 2; + } + this->eachBatchFP16Num = this->blocksize * eachBatchPkgNum; // 64 * 288 + + // gm, 32-byte alignment + uint32_t AOffset = singleCoreNumel / 2 * blockIdx; + uint32_t ABufferSize = singleCoreNumel / 2; + AGm.SetGlobalBuffer((__gm__ int8_t*)A + AOffset, ABufferSize); + uint32_t absmaxOffset = singleCoreNumel / this->blocksize * blockIdx; + uint32_t absmaxBufferSize = singleCoreNumel / this->blocksize; + absmaxGm.SetGlobalBuffer((__gm__ T*)absmax + absmaxOffset, absmaxBufferSize); + uint32_t outOffset = singleCoreNumel * blockIdx; + uint32_t outBufferSize = singleCoreNumel; + outGm.SetGlobalBuffer((__gm__ T*)out + outOffset, outBufferSize); + + // TQue, 32-byte alignment + pipe.InitBuffer(inQueueA, BUFFER_NUM, this->eachBatchFP16Num / 2); + pipe.InitBuffer(inQueueAbsmax, BUFFER_NUM, CEIL32(eachBatchPkgNum * ELEMENT_BYTES)); + pipe.InitBuffer(outQueueOut, BUFFER_NUM, this->eachBatchFP16Num * ELEMENT_BYTES); + + // TBuf, 32-byte alignment + pipe.InitBuffer(calcNf4ToFloat, 16 * ELEMENT_BYTES); + pipe.InitBuffer(calcAFP16, this->eachBatchFP16Num * sizeof(half)); + pipe.InitBuffer(calcAUint32, this->eachBatchFP16Num * sizeof(uint32_t)); + pipe.InitBuffer(calcAbsmaxBuf, this->eachBatchFP16Num * ELEMENT_BYTES); + } + + __aicore__ inline void Process(void) + { + Compute(); + } + +private: + __aicore__ inline void initNf4ToFloat(LocalTensor &nf4ToFloat) + { + if constexpr (TypeMode == 1) { + nf4ToFloat(0) = static_cast(-1.0); + nf4ToFloat(1) = static_cast(-0.6961928009986877); + nf4ToFloat(2) = static_cast(-0.5250730514526367); + nf4ToFloat(3) = static_cast(-0.39491748809814453); + nf4ToFloat(4) = static_cast(-0.28444138169288635); + nf4ToFloat(5) = static_cast(-0.18477343022823334); + nf4ToFloat(6) = static_cast(-0.09105003625154495); + nf4ToFloat(7) = static_cast(0.0); + nf4ToFloat(8) = static_cast(0.07958029955625534); + nf4ToFloat(9) = 
static_cast(0.16093020141124725); + nf4ToFloat(10) = static_cast(0.24611230194568634); + nf4ToFloat(11) = static_cast(0.33791524171829224); + nf4ToFloat(12) = static_cast(0.44070982933044434); + nf4ToFloat(13) = static_cast(0.5626170039176941); + nf4ToFloat(14) = static_cast(0.7229568362236023); + nf4ToFloat(15) = static_cast(1.0); + } else if constexpr (TypeMode == 2) { + nf4ToFloat(0) = static_cast(-1.0); + nf4ToFloat(1) = static_cast(-0.6962890625); + nf4ToFloat(2) = static_cast(-0.52490234375); + nf4ToFloat(3) = static_cast(-0.39501953125); + nf4ToFloat(4) = static_cast(-0.284423828125); + nf4ToFloat(5) = static_cast(-0.184814453125); + nf4ToFloat(6) = static_cast(-0.091064453125); + nf4ToFloat(7) = static_cast(0.0); + nf4ToFloat(8) = static_cast(0.07958984375); + nf4ToFloat(9) = static_cast(0.160888671875); + nf4ToFloat(10) = static_cast(0.24609375); + nf4ToFloat(11) = static_cast(0.337890625); + nf4ToFloat(12) = static_cast(0.440673828125); + nf4ToFloat(13) = static_cast(0.5625); + nf4ToFloat(14) = static_cast(0.72314453125); + nf4ToFloat(15) = static_cast(1.0); + } + } + + __aicore__ inline void Compute(void) + { + constexpr uint32_t ELEMENT_BYTES = (TypeMode == 1) ? 4 : 2; // FP32: 4bytes, FP16/BF16: 2bytes + LocalTensor ALocal = inQueueA.AllocTensor(); + LocalTensor absmaxLocal = inQueueAbsmax.AllocTensor(); + LocalTensor outLocal = outQueueOut.AllocTensor(); + + LocalTensor AFP16 = calcAFP16.Get(); + LocalTensor AInt32 = calcAUint32.Get(); + LocalTensor absmaxBuf = calcAbsmaxBuf.Get(); + LocalTensor nf4ToFloat = calcNf4ToFloat.Get(); + initNf4ToFloat(nf4ToFloat); + + DataCopyParams dataCopyParams = {1, 0, 0, 0}; + uint32_t curBatchNumel = this->eachBatchFP16Num; + uint32_t curBatchPkgNum = curBatchNumel / this->blocksize; + + uint32_t batchCount = (this->CurCoreFP16Num + this->eachBatchFP16Num - 1) / this->eachBatchFP16Num; + for (uint32_t batchIdx = 0; batchIdx < batchCount; batchIdx++) { + if (batchCount - batchIdx == 1) { + curBatchNumel = this->CurCoreFP16Num - this->eachBatchFP16Num * batchIdx; + curBatchPkgNum = (curBatchNumel + this->blocksize - 1) / this->blocksize; + } + + dataCopyParams.blockLen = curBatchNumel / 2; // Byte + DataCopyPad(ALocal, AGm[this->eachBatchFP16Num / 2 * batchIdx], dataCopyParams, {true, 0, 0, 0}); + dataCopyParams.blockLen = ELEMENT_BYTES * curBatchPkgNum; // Byte + uint32_t gmOffset = this->eachBatchFP16Num / this->blocksize * batchIdx; + DataCopyPad(absmaxLocal, absmaxGm[gmOffset], dataCopyParams, {true, 0, 0, 0}); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + pipe_barrier(PIPE_ALL); + + LocalTensor AInt4 = ALocal.ReinterpretCast(); + Cast(AFP16, AInt4, RoundMode::CAST_NONE, curBatchNumel); + pipe_barrier(PIPE_V); + Adds(AFP16, AFP16, static_cast(8), curBatchNumel); + pipe_barrier(PIPE_V); + if constexpr (TypeMode == 1) { + Muls(AFP16, AFP16, static_cast(4), curBatchNumel); + } else { + Muls(AFP16, AFP16, static_cast(2), curBatchNumel); + } + pipe_barrier(PIPE_V); + Cast(AInt32, AFP16, RoundMode::CAST_ROUND, curBatchNumel); + pipe_barrier(PIPE_V); + LocalTensor AUint32 = AInt32.ReinterpretCast(); + Gather(outLocal, nf4ToFloat, AUint32, 0, curBatchNumel); + pipe_barrier(PIPE_V); + uint32_t dstShape[] = {curBatchPkgNum, this->blocksize}; + uint32_t srcShape[] = {curBatchPkgNum, 1}; + BroadCast(absmaxBuf, absmaxLocal, dstShape, srcShape); + pipe_barrier(PIPE_ALL); + Mul(outLocal, outLocal, absmaxBuf, curBatchNumel); + pipe_barrier(PIPE_ALL); + + dataCopyParams.blockLen = ELEMENT_BYTES * curBatchNumel; // 
Byte + DataCopyPad(outGm[batchIdx * this->eachBatchFP16Num], outLocal, dataCopyParams); + pipe_barrier(PIPE_MTE3); + } + pipe_barrier(PIPE_ALL); + + inQueueA.FreeTensor(ALocal); + inQueueAbsmax.FreeTensor(absmaxLocal); + outQueueOut.FreeTensor(outLocal); + } + +private: + TQue inQueueA; + TQue inQueueAbsmax; + TQue outQueueOut; + TBuf calcAFP16; + TBuf calcAUint32; + TBuf calcNf4ToFloat; + TBuf calcAbsmaxBuf; + GlobalTensor AGm; + GlobalTensor absmaxGm; + GlobalTensor outGm; + uint32_t blocksize; + uint32_t CurCoreFP16Num; + uint32_t eachBatchFP16Num; +}; + + + +extern "C" { + +__global__ __aicore__ void dequantize_blockwise_fp32_nf4(GM_ADDR A, GM_ADDR absmax, GM_ADDR out, GM_ADDR tiling) +{ + TPipe pipe; + KernelDequantizeBlockwiseNf4 op; + op.Init(A, absmax, out, tiling, pipe); + op.Process(); +} + +__global__ __aicore__ void dequantize_blockwise_fp16_nf4(GM_ADDR A, GM_ADDR absmax, GM_ADDR out, GM_ADDR tiling) +{ + TPipe pipe; + KernelDequantizeBlockwiseNf4 op; + op.Init(A, absmax, out, tiling, pipe); + op.Process(); +} + +} diff --git a/csrc/npu_ops.cpp b/csrc/npu_ops.cpp new file mode 100644 index 000000000..fb5ecef2f --- /dev/null +++ b/csrc/npu_ops.cpp @@ -0,0 +1,51 @@ +#include +#include "acl/acl.h" +#include "tiling/platform/platform_ascendc.h" +#include "npu_ops.h" + +#include "aclrtlaunch_dequantize_blockwise_fp32_nf4.h" +#include "aclrtlaunch_dequantize_blockwise_fp16_nf4.h" + + +extern "C" { + +int32_t get_dequantize_blockwise_nf4_tiling(uint32_t blocksize, uint32_t n, BlockwiseNf4TilingData *tiling) { + tiling->ubSize = 196 * 1024; + uint32_t coreNum = 40; + uint32_t totalPkgNum = (n + blocksize - 1) / blocksize; + uint32_t singleCorePkgNum = (totalPkgNum + coreNum - 1) / coreNum; + coreNum = (totalPkgNum + singleCorePkgNum - 1) / singleCorePkgNum; + uint32_t singleCoreNumel = singleCorePkgNum * blocksize; + uint32_t singleCoreNumelTail = n % singleCoreNumel; + if (singleCoreNumelTail == 0) { + singleCoreNumelTail = singleCoreNumel; + } + tiling->coreNum = coreNum; + tiling->blocksize = blocksize; + tiling->numel = n; + tiling->singleCoreNumel = singleCoreNumel; + tiling->singleCoreNumelTail = singleCoreNumelTail; + return 0; +} + +void dequantizeBlockwiseNf4(uint8_t *A, uint8_t *absmax, uint8_t *out, uint32_t blocksize, uint32_t n, void* stream, const uint32_t type_mode) { + uint32_t blockDim = 40; + size_t tilingSize = sizeof(struct BlockwiseNf4TilingData); + BlockwiseNf4TilingData *tilingHost; + tilingHost = (struct BlockwiseNf4TilingData *)malloc(tilingSize); + uint32_t error = get_dequantize_blockwise_nf4_tiling(blocksize, n, tilingHost); + if (error != 0) { + printf("[!] 
error\n"); + } + uint8_t *tilingDevice = nullptr; + aclrtMalloc((void **)&tilingDevice, tilingSize, ACL_MEM_MALLOC_NORMAL_ONLY); + aclrtMemcpyAsync((void *)tilingDevice, tilingSize, tilingHost, tilingSize, ACL_MEMCPY_HOST_TO_DEVICE, stream); + if (type_mode == 1) { + ACLRT_LAUNCH_KERNEL(dequantize_blockwise_fp32_nf4)(blockDim, stream, A, absmax, out, tilingDevice); + } else if (type_mode == 2) { + ACLRT_LAUNCH_KERNEL(dequantize_blockwise_fp16_nf4)(blockDim, stream, A, absmax, out, tilingDevice); + } + aclrtFree(tilingDevice); +} + +} diff --git a/csrc/npu_ops.h b/csrc/npu_ops.h new file mode 100644 index 000000000..d7a26cd34 --- /dev/null +++ b/csrc/npu_ops.h @@ -0,0 +1,28 @@ +#ifndef NPU_OPS_H +#define NPU_OPS_H +#include + +#define CHECK_ACL(x) \ + do { \ + aclError __ret = x; \ + if (__ret != ACL_ERROR_NONE) { \ + std::cerr << __FILE__ << ":" << __LINE__ << " aclError:" << __ret << std::endl; \ + } \ + } while (0); + + +struct BlockwiseNf4TilingData { + uint32_t coreNum; + uint32_t blocksize; + uint32_t numel; + uint32_t singleCoreNumel; + uint32_t singleCoreNumelTail; + uint32_t ubSize; +}; + +extern "C" { + +void dequantizeBlockwiseNf4(uint8_t *A, uint8_t *absmax, uint8_t *out, uint32_t blocksize, uint32_t n, void* stream, const uint32_t type_mode); + +} +#endif diff --git a/csrc/ops.hip b/csrc/ops.hip new file mode 100644 index 000000000..8398a6171 --- /dev/null +++ b/csrc/ops.hip @@ -0,0 +1,1025 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include +#ifndef NO_HIPBLASLT +#include +#endif +#include +#include +#include +#include + +#define ERR_NOT_IMPLEMENTED 100 + +using namespace BinSearch; +using std::cout; +using std::endl; + +void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n) +{ + int threads = 512; + int num_blocks = n/threads; + num_blocks = n % threads == 0 ? num_blocks : num_blocks + 1; + hipLaunchKernelGGL(( kHistogramScatterAdd2D), dim3(num_blocks), dim3(512), 0, 0, histogram, index1, index2, src, maxidx1, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +template void estimateQuantiles(T *A, float *code, float offset, int n) +{ + int num_blocks = n/4096; + num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; + CUDA_CHECK_RETURN(hipMemset(code, 0, 256*sizeof(float))); + hipLaunchKernelGGL(( kEstimateQuantiles), dim3(num_blocks), dim3(512), 0, 0, A, code, offset, std::numeric_limits::max(), n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +void quantize(float *code, float *A, unsigned char *out, int n) +{ + int num_blocks = n/1024; + num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; + hipLaunchKernelGGL(( kQuantize), dim3(num_blocks), dim3(1024), 0, 0, code, A, out, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +void dequantize(float *code, unsigned char *A, float *out, int n, hipStream_t stream) +{ + int num_blocks = n/1024; + num_blocks = n % 1024 == 0 ? 
num_blocks : num_blocks + 1; + hipLaunchKernelGGL(( kDequantize), dim3(num_blocks), dim3(1024), 0, stream, code, A, out, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n) +{ + int num_blocks = n/blocksize; + num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; + + if(blocksize == 4096) + hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(1024), 0, 0, code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 2048) + hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(512), 0, 0, code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 1024) + hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(256), 0, 0, code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 512) + hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(256), 0, 0, code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 256) + hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(128), 0, 0, code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 128) + hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(64), 0, 0, code, A, absmax, out, rand, rand_offset, n); + //else if(blocksize == 64) + // hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(32), 0, 0, code, A, absmax, out, rand, rand_offset, n); + + + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n, hipStream_t stream) +{ + int num_blocks = n/blocksize; + num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; + int tile_size = (DATA_TYPE > 0) ? 1024 : 512; + + if(DATA_TYPE > 0) + hipLaunchKernelGGL(( kDequantizeBlockwise), dim3((n+tile_size-1)/tile_size), dim3(64), 0, stream, code, A, absmax, out, blocksize/2, n); + else + hipLaunchKernelGGL(( kDequantizeBlockwise), dim3((n+tile_size-1)/tile_size), dim3(64), 0, stream, code, A, absmax, out, blocksize, n); + + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + + +//void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB) +//{ +// int num_blocks = (colsB+32-1)/32; +// kMatmul_inference_4bit<<>>(A, B, out, lda, ldb, rowsA, colsA, colsB); +// CUDA_CHECK_RETURN(hipPeekAtLastError()); +//} + + +template void optimizer32bit(T* g, T* p, + float* state1, float* state2, float *unorm, float max_unorm, float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, + const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) +{ + int num_blocks = n/4096; + num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; + switch(OPTIMIZER) + { + case ADAM: + if(max_unorm > 0.0f) + { + CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); + hipLaunchKernelGGL(( kPreconditionOptimizer32bit2State), dim3(num_blocks), dim3(512), 0, 0, g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + } + hipLaunchKernelGGL(( kOptimizer32bit2State), dim3(num_blocks), dim3(1024), 0, 0, g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + break; + case ADEMAMIX: + // TODO: Not implemented! 
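+      // ADEMAMIX is currently a stub on the HIP path: no kernel is launched here, so the
+      // parameters and optimizer state are returned unchanged. The beta3/alpha arguments are
+      // already threaded through kOptimizer32bit2State above, so only this dispatch arm is
+      // missing an implementation.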
+ break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + if(max_unorm > 0.0f) + { + CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); + hipLaunchKernelGGL(( kPreconditionOptimizer32bit1State), dim3(num_blocks), dim3(512), 0, 0, g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + } + + hipLaunchKernelGGL(( kOptimizer32bit1State), dim3(num_blocks), dim3(1024), 0, 0, g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + break; + case LION: + // in lion, the momentum update after the parameter update + hipLaunchKernelGGL(( kOptimizer32bit1State), dim3(num_blocks), dim3(1024), 0, 0, g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + + if(max_unorm > 0.0f) + { + CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); + hipLaunchKernelGGL(( kPreconditionOptimizer32bit1State), dim3(num_blocks), dim3(512), 0, 0, g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + } + break; + } +} + +template void optimizerStatic8bit(T* p, T* g, + unsigned char* state1, unsigned char* state2, + float *unorm, float max_unorm, float param_norm, + float beta1, float beta2, + float eps, int step, float lr, + float* quantiles1, float* quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, + const float gnorm_scale, int n) +{ + int num_blocks = n/4096; + num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; + + if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); } + + switch(OPTIMIZER) + { + case ADAM: + CUDA_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float))); + CUDA_CHECK_RETURN(hipMemset(new_max2, 0, 1*sizeof(float))); + hipLaunchKernelGGL(( kPreconditionOptimizerStatic8bit2State), dim3(num_blocks), dim3(256), 0, 0, p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + hipLaunchKernelGGL(( kOptimizerStatic8bit2State), dim3(num_blocks), dim3(1024), 0, 0, p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, + quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + CUDA_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float))); + hipLaunchKernelGGL(( kPreconditionOptimizerStatic8bit1State), dim3(num_blocks), dim3(256), 0, 0, p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + hipLaunchKernelGGL(( kOptimizerStatic8bit1State), dim3(num_blocks), dim3(1024), 0, 0, p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, + quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + break; + case LION: + // in lion, the momentum update happens after the parameter update + hipLaunchKernelGGL(( kOptimizerStatic8bit1State), dim3(num_blocks), dim3(1024), 0, 0, p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, + quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + + 
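// For orientation only (not taken from the kernels): the comments above note that in Lion the
// momentum update happens after the parameter update. A scalar sketch of one Lion step in that
// order, using the generic published rule; the decoupled weight-decay placement and the lambda
// name are assumptions, and the quantized kernels add 8-bit state handling on top of this:
{
  auto lion_step_scalar = [](float& p, float& m, float g,
                             float lr, float b1, float b2, float wd) {
    float c = b1 * m + (1.0f - b1) * g;                              // direction from the *old* momentum
    float sign_c = (c > 0.0f) ? 1.0f : ((c < 0.0f) ? -1.0f : 0.0f);  // sign(c)
    p -= lr * (sign_c + wd * p);                                     // parameter update first
    m = b2 * m + (1.0f - b2) * g;                                    // momentum refreshed afterwards
  };
  float p = 0.5f, m = 0.0f;
  lion_step_scalar(p, m, /*g=*/0.1f, /*lr=*/1e-4f, /*b1=*/0.9f, /*b2=*/0.99f, /*wd=*/0.0f);
}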
CUDA_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float))); + hipLaunchKernelGGL(( kPreconditionOptimizerStatic8bit1State), dim3(num_blocks), dim3(256), 0, 0, p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + break; + default: + break; + } +} + +#define BLOCKSIZE_2STATE 2048 +#define NUM_2STATE 8 +#define BLOCKSIZE_1STATE 2048 +#define NUM_1STATE 8 + +template void optimizerStatic8bitBlockwise(T* p, T* g, + unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) +{ + + int num_blocks = 0; + switch(OPTIMIZER) + { + case ADAM: + num_blocks = n/BLOCKSIZE_2STATE; + num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1; + hipLaunchKernelGGL(( kOptimizerStatic8bit2StateBlockwise), dim3(num_blocks), dim3(BLOCKSIZE_2STATE/NUM_2STATE), 0, 0, p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, + quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + break; + case ADEMAMIX: break; // TODO: Not implemented! + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + case LION: + num_blocks = n/BLOCKSIZE_1STATE; + num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1; + hipLaunchKernelGGL(( kOptimizerStatic8bit1StateBlockwise), dim3(num_blocks), dim3(BLOCKSIZE_1STATE/NUM_1STATE), 0, 0, p, g, state1, beta1, beta2, eps, step, lr, + quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + break; + } +} + + + +template void percentileClipping(T * g, float *gnorm_vec, int step, const int n) +{ + int num_blocks = n/2048; + num_blocks = n % 2048 == 0 ? num_blocks : num_blocks + 1; + CUDA_CHECK_RETURN(hipMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float))); + hipLaunchKernelGGL(( kPercentileClipping), dim3(num_blocks), dim3(512), 0, 0, g, gnorm_vec, step, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc) +{ + const int falpha = 1; + const int fbeta = 0; + const void * alpha = &falpha; + const void * beta = &fbeta; + hipblasStatus_t status; + + status = hipblasGemmEx(context->m_handle, + transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, + m, n, k, + alpha, A, HIPBLAS_R_8I, lda, B, HIPBLAS_R_8I, ldb, beta, + C, HIPBLAS_R_32I, ldc, + HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT); + + if (status != HIPBLAS_STATUS_SUCCESS) + { + std::cout << "HIPBLAS ERROR: Status " << status << std::endl; + } + +} + +void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, + long long int strideA, long long int strideB, long long int strideC, int batchCount) +{ + const int falpha = 1; + const int fbeta = 0; + const void * alpha = &falpha; + const void * beta = &fbeta; + hipblasStatus_t status; + + //cout << transposeA << transposeB << endl; + //printf("%i %i %i\n", m,n,k); + //printf("%i %i %i\n", lda,ldb,ldc); + //printf("%i %i %i\n", strideA, strideB, strideC); + //printf("%i\n", batchCount); + + status = hipblasGemmStridedBatchedEx(context->m_handle, + transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transposeB ? 
HIPBLAS_OP_T : HIPBLAS_OP_N, + m, n, k, + alpha, A, HIPBLAS_R_8I, lda, (long long int)strideA, B, HIPBLAS_R_8I, ldb, (long long int)strideB, beta, + C, HIPBLAS_R_32I, ldc, (long long int)strideC, batchCount, + HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT); + + if (status != HIPBLAS_STATUS_SUCCESS) + { + std::cout << "HIPBLAS ERROR: Status " << status << std::endl; + } + +} + +int roundoff(int v, int d) { + return (v + d - 1) / d * d; +} + + +#ifdef NO_HIPBLASLT +#else +template hipblasLtOrder_t get_order() +{ + switch(ORDER) + { + case ROW: + return HIPBLASLT_ORDER_ROW; + break; + case COL: + return HIPBLASLT_ORDER_COL; + break; + case COL32: + //return HIPBLASLT_ORDER_COL32; + return HIPBLASLT_ORDER_COL; + break; + case COL_TURING: + //return HIPBLASLT_ORDER_COL4_4R2_8C; + return HIPBLASLT_ORDER_COL; + break; + case COL_AMPERE: + //return HIPBLASLT_ORDER_COL32_2R_4R4; + return HIPBLASLT_ORDER_COL; + break; + default: + break; + } + + return HIPBLASLT_ORDER_ROW; +} + +template hipblasLtOrder_t get_order(); +template hipblasLtOrder_t get_order(); +template hipblasLtOrder_t get_order(); +//template hipblasLtOrder_t get_order(); +//template hipblasLtOrder_t get_order(); +#endif + + +template int get_leading_dim(int dim1, int dim2) +{ + switch(ORDER) + { + case ROW: + return dim2; + break; + case COL: + return dim1; + break; + default: + return dim1; + break; + /*case COL32: + // 32*row tiles + return dim1*32; + break; + case COL_TURING: + return 32*roundoff(dim1, 8); + break; + case COL_AMPERE: + // 32*32 tiles + return 32*roundoff(dim1, 32); + break; + default: + return 0; + break; + */ + } +} + +template int get_leading_dim(int dim1, int dim2); +template int get_leading_dim(int dim1, int dim2); +template int get_leading_dim(int dim1, int dim2); + +template void transform(hipblasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2) +{ +#ifdef NO_HIPBLASLT +#else + hipblasLtOrder_t orderA = get_order(); + hipblasLtOrder_t orderOut = get_order(); + int ldA = get_leading_dim(dim1, dim2); + int ldOut; + if (TARGET==COL && transpose) { + ldOut = dim2; + } else { + ldOut = get_leading_dim(dim1, dim2); + } + + hipblasLtMatrixLayout_t A_desc = NULL, out_desc = NULL, B_desc = NULL; + T B = T(0); + hipblasLtMatrixTransformDesc_t A2Out_desc = NULL; + hipblasOperation_t opTranspose = HIPBLAS_OP_T; + float transformAlpha = 1.0f, transformBeta = 0.0f; + + if(DTYPE == 8) + { + checkHipblasStatus(hipblasLtMatrixLayoutCreate(&A_desc, HIP_R_8I, dim1, dim2, ldA)); + checkHipblasStatus(hipblasLtMatrixLayoutCreate(&B_desc, HIP_R_8I, 0, 0, 0)); + if (TARGET==COL && transpose) { + checkHipblasStatus(hipblasLtMatrixLayoutCreate(&out_desc, HIP_R_8I, dim2, dim1, ldOut)); + } else { + checkHipblasStatus(hipblasLtMatrixLayoutCreate(&out_desc, HIP_R_8I, dim1, dim2, ldOut)); + } + } + else if(DTYPE == 32) + { + checkHipblasStatus(hipblasLtMatrixLayoutCreate(&A_desc, HIP_R_32I, dim1, dim2, ldA)); + checkHipblasStatus(hipblasLtMatrixLayoutCreate(&B_desc, HIP_R_32I, 0, 0, 0)); + if (TARGET==COL && transpose) { + checkHipblasStatus(hipblasLtMatrixLayoutCreate(&out_desc, HIP_R_32I, dim2, dim1, ldOut)); + } else { + checkHipblasStatus(hipblasLtMatrixLayoutCreate(&out_desc, HIP_R_32I, dim1, dim2, ldOut)); + } + } + else + { + printf("ERROR WRONG TYPE FOR TRANSFORM: %i\n", DTYPE); + } + + checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(A_desc, HIPBLASLT_MATRIX_LAYOUT_ORDER, &orderA, sizeof(orderA))); + checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(out_desc, HIPBLASLT_MATRIX_LAYOUT_ORDER, &orderOut, sizeof(orderOut))); + + 
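// Small illustration (not part of this patch): because get_order() above currently maps COL32,
// COL_TURING and COL_AMPERE to plain column-major on ROCm, the leading dimension reduces to the
// simple rule in get_leading_dim(): row-major uses the column count, column-major uses the row
// count. A sketch with a hypothetical 4 x 8 matrix, values chosen only for the example:
{
  const int rows = 4, cols = 8;
  const int ld_row_major = cols;  // what get_leading_dim returns for ROW: a row is contiguous
  const int ld_col_major = rows;  // what get_leading_dim returns for COL: a column is contiguous
  (void)ld_row_major; (void)ld_col_major;
}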
checkHipblasStatus(hipblasLtMatrixTransformDescCreate(&A2Out_desc, HIP_R_32F)); + + if(transpose){ checkHipblasStatus(hipblasLtMatrixTransformDescSetAttribute(A2Out_desc, HIPBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &opTranspose, sizeof(opTranspose))); } + + checkHipblasStatus(hipblasLtMatrixTransform(ltHandle, A2Out_desc, &transformAlpha, A, A_desc, &transformBeta, A, B_desc, out, out_desc, 0)); + + if (A_desc) checkHipblasStatus(hipblasLtMatrixLayoutDestroy(A_desc)); + if (B_desc) checkHipblasStatus(hipblasLtMatrixLayoutDestroy(B_desc)); + if (out_desc) checkHipblasStatus(hipblasLtMatrixLayoutDestroy(out_desc)); + if (A2Out_desc) checkHipblasStatus(hipblasLtMatrixTransformDescDestroy(A2Out_desc)); +#endif +} + +template void transform(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(hipblasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); +template void transform(hipblasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); +template void transform(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(hipblasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); +template void transform(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(hipblasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); +template void transform(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(hipblasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); + +static std::string hipError_to_string(const hipError_t ret) +{ + switch(ret) + { + case hipSuccess: + return "hipSuccess"; + case hipErrorInvalidContext: + return "hipErrorInvalidContext"; + case hipErrorInvalidKernelFile: + return "hipErrorInvalidKernelFile"; + case hipErrorMemoryAllocation: + return "hipErrorMemoryAllocation"; + case hipErrorInitializationError: + return "hipErrorInitializationError"; + case hipErrorLaunchFailure: + return "hipErrorLaunchFailure"; + case hipErrorLaunchOutOfResources: + return "hipErrorLaunchOutOfResources"; + case hipErrorInvalidDevice: + return "hipErrorInvalidDevice"; + case hipErrorInvalidValue: + return "hipErrorInvalidValue"; + case hipErrorInvalidDevicePointer: + return "hipErrorInvalidDevicePointer"; + case hipErrorInvalidMemcpyDirection: + return "hipErrorInvalidMemcpyDirection"; + case hipErrorUnknown: + return "hipErrorUnknown"; + case hipErrorInvalidResourceHandle: + return "hipErrorInvalidResourceHandle"; + case hipErrorNotReady: + return "hipErrorNotReady"; + case hipErrorNoDevice: + return "hipErrorNoDevice"; + case hipErrorPeerAccessAlreadyEnabled: + return "hipErrorPeerAccessAlreadyEnabled"; + case hipErrorPeerAccessNotEnabled: + return "hipErrorPeerAccessNotEnabled"; + case hipErrorRuntimeMemory: + return "hipErrorRuntimeMemory"; + case hipErrorRuntimeOther: + return "hipErrorRuntimeOther"; + case hipErrorHostMemoryAlreadyRegistered: + return "hipErrorHostMemoryAlreadyRegistered"; + case hipErrorHostMemoryNotRegistered: + return "hipErrorHostMemoryNotRegistered"; + case 
hipErrorMapBufferObjectFailed:
+ return "hipErrorMapBufferObjectFailed";
+ case hipErrorTbd:
+ return "hipErrorTbd";
+ default:
+ throw std::runtime_error("unknown hipError");
+ }
+}
+
+template int igemmlt(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, hipStream_t stream)
+{
+#ifdef NO_HIPBLASLT
+ return ERR_NOT_IMPLEMENTED;
+#else
+
+ // Calculate C = A^T @ B, in col-major layout.
+ //
+ // Using the IMMA kernels requires:
+ // * A must be transposed and B must be non-transposed.
+ // * Dimensions m and k must be multiples of 4.
+ // * All pointers must be 4-byte aligned; 16-byte alignment preferred.
+
+ int has_error = 0;
+ const int64_t max_workspace_size = 0; // set to 0 to avoid choosing GSU kernel
+
+ hipblasLtMatmulDesc_t matmulDesc;
+ hipblasLtMatrixLayout_t aDesc, bDesc, cDesc;
+ hipblasOperation_t opT = HIPBLAS_OP_T;
+
+ hipDataType outType = DTYPE_OUT == 32 ? HIP_R_32I : HIP_R_8I;
+ hipDataType scaleType = DTYPE_OUT == 32 ? HIP_R_32I : HIP_R_32F;
+
+ hipblasLtPointerMode_t pointerMode = HIPBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST;
+
+ has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&aDesc, HIP_R_8I, m, k, lda));
+ has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&bDesc, HIP_R_8I, m, n, ldb));
+ has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&cDesc, outType, k, n, ldc));
+
+ // Default layout order is col major
+
+ has_error |= checkHipblasStatus(hipblasLtMatmulDescCreate(&matmulDesc, HIPBLAS_COMPUTE_32I, scaleType));
+ has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSA, &opT, sizeof(opT)));
+
+ if (DTYPE_OUT == 32) {
+
+ /* Algo and workspace TODO: need to rework to not be duplicated */
+ // Set User Preference attributes
+ hipblasLtMatmulPreference_t pref;
+ checkHipblasStatus(hipblasLtMatmulPreferenceCreate(&pref));
+ checkHipblasStatus(
+ hipblasLtMatmulPreferenceSetAttribute(pref,
+ HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
+ &max_workspace_size,
+ sizeof(max_workspace_size)));
+
+ const int request_solutions = 1;
+ hipblasLtMatmulHeuristicResult_t heuristicResult[request_solutions];
+ int returnedAlgoCount = 0;
+ checkHipblasStatus(hipblasLtMatmulAlgoGetHeuristic(ltHandle,
+ matmulDesc,
+ aDesc,
+ bDesc,
+ cDesc,
+ cDesc,
+ pref,
+ request_solutions,
+ heuristicResult,
+ &returnedAlgoCount));
+
+ if (returnedAlgoCount == 0)
+ {
+ has_error = 1;
+ fprintf(stderr, "Error: Matmul Algo Heuristic didn't return algorithms\n");
+ } else {
+ int alpha = 1, beta = 0;
+ has_error |= checkHipblasStatus(hipblasLtMatmul(
+ ltHandle, matmulDesc,
+ &alpha, A, aDesc,
+ B, bDesc, &beta,
+ (int32_t*)C, cDesc,
+ (int32_t*)C, cDesc,
+ &heuristicResult[0].algo, NULL, 0, stream
+ ));
+ }
+ } else {
+ // This path is unlikely to be used, as 8-bit accumulation can easily overflow.
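// Hedged sketch (not part of this patch): the IMMA requirements listed at the top of igemmlt
// (m and k multiples of 4, 4-byte-aligned pointers) could be validated on the host before
// dispatch. The lambda name is invented, <stdint.h> is assumed to be available, and roundoff()
// earlier in this file can pad a dimension up to the next multiple if needed:
{
  auto igemmlt_shape_ok = [](int m_, int k_, const void* A_, const void* B_, const void* C_) {
    const bool dims_ok = (m_ % 4 == 0) && (k_ % 4 == 0);
    const bool aligned_ok = ((uintptr_t)A_ % 4 == 0) && ((uintptr_t)B_ % 4 == 0) && ((uintptr_t)C_ % 4 == 0);
    return dims_ok && aligned_ok;  // e.g. pad beforehand with k_ = roundoff(k_, 4);
  };
  (void)igemmlt_shape_ok;
}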
+ + if (!SCALE_ROWS) { + float alpha = 1.0f, beta = 0.0f; + has_error |= checkHipblasStatus(hipblasLtMatmul( + ltHandle, matmulDesc, + &alpha, A, aDesc, + B, bDesc, &beta, + (int8_t*)C, cDesc, + (int8_t*)C, cDesc, + NULL, NULL, 0, stream + )); + } else { + hipblasLtPointerMode_t alphaVec = HIPBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST; + float beta = 0.0f; + has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute( + matmulDesc, + HIPBLASLT_MATMUL_DESC_POINTER_MODE, + &pointerMode, + sizeof(alphaVec) + )); + has_error |= checkHipblasStatus(hipblasLtMatmul( + ltHandle, matmulDesc, + row_scale, A, aDesc, + B, bDesc, &beta, + (int8_t*)C, cDesc, + (int8_t*)C, cDesc, + NULL, NULL, 0, stream + )); + } + } + + has_error |= checkHipblasStatus(hipblasLtMatrixLayoutDestroy(cDesc)); + has_error |= checkHipblasStatus(hipblasLtMatrixLayoutDestroy(bDesc)); + has_error |= checkHipblasStatus(hipblasLtMatrixLayoutDestroy(aDesc)); + has_error |= checkHipblasStatus(hipblasLtMatmulDescDestroy(matmulDesc)); + + if(has_error == 1) + printf("error detected"); + + return has_error; +#endif // NO_HIPBLASLT +} + +int fill_up_to_nearest_multiple(int value, int multiple) +{ + return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple))); +} + +void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half *bias, int numRows, int numCols, hipStream_t stream) +{ + const int threads = 512; + const int num_per_thread = 4; + const int num_per_block = threads * num_per_thread; + const int n = numRows*numCols; + const int num_blocks = (n + num_per_block - 1) / num_per_block; + + hipLaunchKernelGGL(( kdequant_mm_int32_fp16<4, 128, 512>), dim3(num_blocks), dim3(threads), 0, stream, A, rowStats, colStats, out, bias, numRows, numCols, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +#define STATS_THREADS 64 +#define STATS_ITEMS 4 +#define STATS_ROWS 16 +void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols) +{ + int tile_cols = STATS_THREADS*STATS_ITEMS; + int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); + int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS); + int row_tiles = (tiledRows/STATS_ROWS); + int col_tiles = (tiledCols/tile_cols); + row_tiles = row_tiles > 0 ? row_tiles : 1; + col_tiles = col_tiles > 0 ? col_tiles : 1; + int num_blocks = row_tiles * col_tiles; + + if(nnz_threshold == 0.0) + hipLaunchKernelGGL(( kgetColRowStats), dim3(num_blocks), dim3(STATS_THREADS), 0, 0, A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols); + else if(nnz_threshold != 0.0) + hipLaunchKernelGGL(( kgetColRowStats), dim3(num_blocks), dim3(STATS_THREADS), 0, 0, A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + +} + +void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols) +{ + int threads = 64; + int items_per_thread = 4; + int tile_cols = threads*items_per_thread; + int tile_rows = 16; + int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); + int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows); + int row_tiles = (tiledRows/tile_rows); + int col_tiles = (tiledCols/tile_cols); + row_tiles = row_tiles > 0 ? row_tiles : 1; + col_tiles = col_tiles > 0 ? 
col_tiles : 1; + int num_blocks = row_tiles * col_tiles; + + + if(threshold > 0.0f) + hipLaunchKernelGGL(( kDoubleRowColQuant<64, 4, 16, 64*4, 1>), dim3(num_blocks), dim3(threads), 0, 0, A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols); + else + hipLaunchKernelGGL(( kDoubleRowColQuant<64, 4, 16, 64*4, 0>), dim3(num_blocks), dim3(threads), 0, 0, A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols); + + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +template void transformRowToFormat(char * A, char *out, int rows, int cols) +{ + int threads = 256; + int items_per_thread = 8; + // we load 128 column values per warp + int tile_cols = 32*items_per_thread; + int tile_rows = 32; + int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); + int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows); + int row_tiles = (tiledRows/tile_rows); + int col_tiles = (tiledCols/tile_cols); + row_tiles = row_tiles > 0 ? row_tiles : 1; + col_tiles = col_tiles > 0 ? col_tiles : 1; + int num_blocks = row_tiles * col_tiles; + + int outCols = fill_up_to_nearest_multiple(cols, 32); + int outRows = fill_up_to_nearest_multiple(rows, 32); + if(FORMAT == COL_TURING) + { + if(TRANSPOSE) + outRows = fill_up_to_nearest_multiple(cols, 8); + else + outRows = fill_up_to_nearest_multiple(rows, 8); + } + else if(FORMAT == COL_AMPERE) + { + if(TRANSPOSE) + outRows = fill_up_to_nearest_multiple(cols, 32); + else + outRows = fill_up_to_nearest_multiple(rows, 32); + } + else + { + if(TRANSPOSE) + { + outCols = fill_up_to_nearest_multiple(rows, 32); + outRows = cols; + } + } + + hipLaunchKernelGGL(( kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT>), dim3(num_blocks), dim3(threads), 0, 0, A, out, rows, cols, tiledCols, outRows, outCols); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +void spmm_coo(hipsparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B) +{ + +#ifdef NO_HIPBLASLT +#else + + hipsparseSpMatDescr_t descA; + hipsparseDnMatDescr_t descB, descC; + + float alpha = 1.0f; + float beta = 0.0f; + void *dBuffer = NULL; + size_t bufferSize = 0; + + CHECK_HIPSPARSE( hipsparseCreateCoo(&descA, A_rows, A_cols, A_nnz, + A_rowidx, A_colidx, A_vals, + HIPSPARSE_INDEX_32I, + HIPSPARSE_INDEX_BASE_ZERO, HIP_R_16F) ); + // Create dense matrix C + CHECK_HIPSPARSE( hipsparseCreateDnMat(&descC, A_rows, B_cols, ldc, C, + HIP_R_16F, HIPSPARSE_ORDER_ROW) ); + // Create dense matrix B + if(transposed_B) + { + int tmp = A_cols; + A_cols = B_cols; + B_cols = tmp; + } + + CHECK_HIPSPARSE( hipsparseCreateDnMat(&descB, A_cols, B_cols, ldb, B, + HIP_R_16F, HIPSPARSE_ORDER_ROW) ); + // allocate an external buffer if needed + CHECK_HIPSPARSE( hipsparseSpMM_bufferSize( + handle, + HIPSPARSE_OPERATION_NON_TRANSPOSE, + transposed_B ? HIPSPARSE_OPERATION_TRANSPOSE : HIPSPARSE_OPERATION_NON_TRANSPOSE, + &alpha, descA, descB, &beta, descC, HIP_R_32F, + HIPSPARSE_SPMM_ALG_DEFAULT, &bufferSize) ); + CUDA_CHECK_RETURN( hipMalloc(&dBuffer, bufferSize) ); + + // execute SpMM + CHECK_HIPSPARSE( hipsparseSpMM(handle, + HIPSPARSE_OPERATION_NON_TRANSPOSE, + transposed_B ? 
HIPSPARSE_OPERATION_TRANSPOSE : HIPSPARSE_OPERATION_NON_TRANSPOSE, + &alpha, descA, descB, &beta, descC, HIP_R_32F, + HIPSPARSE_SPMM_ALG_DEFAULT, dBuffer)); + + // destroy matrix/vector descriptors + CHECK_HIPSPARSE( hipsparseDestroySpMat(descA) ); + CHECK_HIPSPARSE( hipsparseDestroyDnMat(descB) ); + CHECK_HIPSPARSE( hipsparseDestroyDnMat(descC) ); + CUDA_CHECK_RETURN( hipFree(dBuffer) ); +#endif +} + +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) +{ + + hipLaunchKernelGGL(( kspmm_coo_very_sparse_naive), dim3(nnz_rows), dim3(256), 0, 0, max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + + +template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols) +{ + int threads = 256; + // we load 128 column values per warp + int tiledCols = tiledCols = fill_up_to_nearest_multiple(cols, 32); + int tiledRows = 0; + + int num_blocks = idx_size; + + /*if(FORMAT == COL_TURING) + { + tiledRows = fill_up_to_nearest_multiple(rows, 8); + } + else if(FORMAT == COL_AMPERE) + { + tiledRows = fill_up_to_nearest_multiple(rows, 32); + }*/ + + //for col format on ROCm + tiledRows = rows; + + hipLaunchKernelGGL(( kExtractOutliers), dim3(num_blocks), dim3(threads), 0, 0, A, idx, out, idx_size, rows, cols, tiledRows, tiledCols); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + + + + +template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits) +{ + + int num_blocks = (m+31)/32; + + //cout << num_blocks << endl; + //cout << lda << endl; + //cout << ldb << endl; + //cout << ldc << endl; + + //cout << m << endl; + //cout << n << endl; + //cout << k << endl; + //if(bits == 32) + //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + if(bits == 16) + //gemm_device<<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + hipLaunchKernelGGL(( gemm_device), dim3(num_blocks), dim3(160), 0, 0 , m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 64, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); +} + +template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + int num_blocks = (m+31)/32; + + //cout << num_blocks << endl; + //cout << lda << endl; + //cout << ldb << endl; + //cout << ldc << endl; + + //cout << m << endl; + //cout << n << endl; + //cout << k << endl; + hipLaunchKernelGGL(( kgemm_4bit_inference), dim3(num_blocks), dim3(96), 0, 0 , m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + //kgemm_4bit_inference<<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + //kgemm_4bit_inference<<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + //kgemm_4bit_inference<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); +} + +template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float 
*absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + int num_blocks = (m+3)/4; + + hipLaunchKernelGGL(( kgemm_4bit_inference_naive), dim3(num_blocks), dim3(128), 0, 0 , m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +template void func(T *A, T *B, T value, long n) +{ + int threads = 512; + int blocks = n/threads; + blocks = n % threads == 0 ? blocks : blocks + 1; + blocks = blocks > 65535 ? 65535 : blocks; + hipLaunchKernelGGL(( kfunc), dim3(blocks), dim3(512), 0, 0, A, B, value, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +//============================================================== +// TEMPLATE DEFINITIONS +//============================================================== + +template void func(float *A, float *B, float value, long n); +template void func(unsigned char *A, unsigned char *B, unsigned char value, long n); +template void func(float *A, float *B, float value, long n); +template void func(float *A, float *B, float value, long n); + +template void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, hip_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, hip_bfloat16 * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize); + +//template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits); +template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits); +template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); +template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); + +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); + +template int igemmlt<32, 0>(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, hipStream_t stream); +template int igemmlt<8, 0>(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, hipStream_t stream); +template int igemmlt<8, 1>(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, hipStream_t stream); + +template void transformRowToFormat(char * A, char *out, int rows, int cols); +template void transformRowToFormat(char * A, char *out, int rows, int cols); +template void transformRowToFormat(char * A, char *out, int rows, int cols); +template void transformRowToFormat(char * A, char *out, int rows, int cols); 
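// Illustrative aside (not part of this patch): func() above clamps its grid to 65535 blocks,
// the classic one-dimensional grid limit, after the usual ceil division. A small sketch of that
// clamping with made-up values; how elements beyond the clamped grid are covered is left to the
// kernel and is not shown here:
static void func_grid_clamp_example()
{
  const long n = 100000000L;                 // hypothetical element count
  const int threads = 512;
  long blocks = n / threads;
  blocks = n % threads == 0 ? blocks : blocks + 1;
  blocks = blocks > 65535 ? 65535 : blocks;  // 195313 would be clamped to 65535, as in func()
  (void)blocks;
}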
+template void transformRowToFormat(char * A, char *out, int rows, int cols); +template void transformRowToFormat(char * A, char *out, int rows, int cols); + +template void estimateQuantiles(half *A, float *code, float offset, int n); +template void estimateQuantiles(float *A, float *code, float offset, int n); + +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); + +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, hipStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, hipStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, hipStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, hipStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, hipStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, hipStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, const int n, hipStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, const int n, hipStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, const int n, hipStream_t stream); + +#define MAKE_optimizer32bit(name, gtype) \ +template void optimizer32bit(gtype* g, gtype* p, \ + 
float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \ + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay, \ + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); + +MAKE_optimizer32bit(ADAM, half) +MAKE_optimizer32bit(ADAM, float) +MAKE_optimizer32bit(ADAM, hip_bfloat16) +MAKE_optimizer32bit(MOMENTUM, half) +MAKE_optimizer32bit(MOMENTUM, float) +MAKE_optimizer32bit(MOMENTUM, hip_bfloat16) +MAKE_optimizer32bit(RMSPROP, half) +MAKE_optimizer32bit(RMSPROP, float) +MAKE_optimizer32bit(RMSPROP, hip_bfloat16) +MAKE_optimizer32bit(LION, half) +MAKE_optimizer32bit(LION, float) +MAKE_optimizer32bit(LION, hip_bfloat16) +MAKE_optimizer32bit(ADAGRAD, half) +MAKE_optimizer32bit(ADAGRAD, float) +MAKE_optimizer32bit(ADAGRAD, hip_bfloat16) +MAKE_optimizer32bit(ADEMAMIX, half) +MAKE_optimizer32bit(ADEMAMIX, hip_bfloat16) +MAKE_optimizer32bit(ADEMAMIX, float) + +#define MAKE_optimizerStatic8bit(name, gtype) \ +template void optimizerStatic8bit(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ + float *unorm, float max_unorm, float param_norm, \ + float beta1, float beta2, \ + float eps, int step, float lr, \ + float* quantiles1, float* quantiles2, \ + float* max1, float* max2, float* new_max1, float* new_max2, \ + float weight_decay, \ + const float gnorm_scale, int n); \ + +MAKE_optimizerStatic8bit(ADAM, half) +MAKE_optimizerStatic8bit(ADAM, float) +MAKE_optimizerStatic8bit(MOMENTUM, half) +MAKE_optimizerStatic8bit(MOMENTUM, float) +MAKE_optimizerStatic8bit(RMSPROP, half) +MAKE_optimizerStatic8bit(RMSPROP, float) +MAKE_optimizerStatic8bit(LION, half) +MAKE_optimizerStatic8bit(LION, float) + +#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \ +template void optimizerStatic8bitBlockwise(gtype* p, gtype* g, \ + unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, \ + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \ + +MAKE_optimizerStatic8bitBlockwise(half, ADAM); +MAKE_optimizerStatic8bitBlockwise(float, ADAM); +MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, ADAM); +MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM); +MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM); +MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, MOMENTUM); +MAKE_optimizerStatic8bitBlockwise(half, RMSPROP); +MAKE_optimizerStatic8bitBlockwise(float, RMSPROP); +MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, RMSPROP); +MAKE_optimizerStatic8bitBlockwise(half, LION); +MAKE_optimizerStatic8bitBlockwise(float, LION); +MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, LION); +MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD); +MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD); +MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, ADAGRAD); +MAKE_optimizerStatic8bitBlockwise(half, ADEMAMIX); +MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, ADEMAMIX); +MAKE_optimizerStatic8bitBlockwise(float, ADEMAMIX); + +template void percentileClipping(float * g, float *gnorm_vec, int step, const int n); +template void percentileClipping(half * g, float *gnorm_vec, int step, const int n); diff --git a/csrc/ops_hip.cuh b/csrc/ops_hip.cuh new file mode 100644 index 000000000..bc7b61e08 --- /dev/null +++ b/csrc/ops_hip.cuh @@ -0,0 +1,203 @@ +// !!! This is a file automatically generated by hipify!!! +// Copyright (c) Facebook, Inc. 
and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + + +#ifndef ops_H +#define ops_H + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#define CUDA_CHECK_RETURN(value) { \ + hipError_t _m_cudaStat = value; \ + if (_m_cudaStat != hipSuccess) { \ + fprintf(stderr, "Error %s at line %d in file %s\n", \ + hipGetErrorString(_m_cudaStat), __LINE__, __FILE__); \ + exit(1); \ + } } + +#define THREADS_PER_BLOCKS (512) + +#define CHECK_HIPSPARSE(value) { \ + hipsparseStatus_t _m_hipStat = value; \ + if (_m_hipStat != HIPSPARSE_STATUS_SUCCESS) { \ + fprintf(stderr, "Error at line %d in file %s\n", \ + __LINE__, __FILE__); \ + exit(1); \ + } } + + + +#define THREADS_PER_BLOCKS (512) + + +inline void v(hipError_t status) { + if (status != hipSuccess) { + printf("hip API failed with status %d: %s\n", status, hipGetErrorString(status)); + throw std::logic_error("hip API failed"); + } +} + +inline int checkHipblasStatus(hipblasStatus_t status) { + if (status != HIPBLAS_STATUS_SUCCESS) { + printf("hipBLAS API failed with status %d\n", status); + //throw std::logic_error("cuBLAS API failed"); + return 1; + } + return 0; +} + +typedef enum Operations_t +{ + ksmul = 0, +} Operations_t; + +typedef enum Optimizer_t +{ + ADAM = 0, + MOMENTUM = 1, + RMSPROP = 2, + LARS = 3, + ADAGRAD = 4, + LION = 5, + ADEMAMIX = 6, +} Optimizer_t; + +typedef enum Transform_t +{ + ROW = 0, + COL = 1, + COL32 = 2, + COL_TURING = 3, + COL_AMPERE = 4, +} Transform_t; + +typedef enum DataType_t +{ + General8bit = 0, + FP4 = 1, + NF4 = 2, +} DataType_t; + +typedef enum Funcs_t +{ + FILL = 0, + ARANGE = 1, + _MUL = 2, +} Funcs_t; + +class Context +{ + public: + rocblas_handle m_handle; + + Context() + { + rocblas_handle handle; + rocblas_create_handle(&handle); + m_handle = handle; + } + +}; + +class ContextLt +{ + public: + hipblasLtHandle_t m_handle; + + ContextLt() + { + hipblasLtHandle_t handle; + hipblasLtCreate(&handle); + m_handle = handle; + } +}; + +class ContextHipsparse +{ + public: + hipsparseHandle_t m_handle; + + ContextHipsparse() + { + hipsparseHandle_t handle; + hipsparseCreate(&handle); + m_handle = handle; + } + +}; + + +template void estimateQuantiles(T *A, float *code, float offset, int n); + +void quantize(float *code, float *A, unsigned char *out, int n); +void dequantize(float *code, unsigned char *A, float *out, int n, hipStream_t stream); +template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n, hipStream_t stream); + +template void optimizer32bit(T* g, T* p, + float* state1, float* state2, float *unorm, float max_unorm, float param_norm, + float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, + int step, float lr, const float gnorm_scale, bool skip_zeros, int n); + +template void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2, + float *unorm, float max_unorm, float param_norm, + float beta1, float beta2, + float eps, int step, float lr, + float* quantiles1, float* quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, + const float gnorm_scale, int n); + +template void optimizerStatic8bitBlockwise(T* p, T* g, + unsigned char* state1, unsigned 
char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, + bool skip_zeros, int n); + +template void percentileClipping(T * g, float *gnorm_vec, int step, const int n); + +void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n); + +void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); +void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, + long long int strideA, long long int strideB, long long int strideC, int batchCount); + + +template int igemmlt(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, hipStream_t stream); + +template void transform(hipblasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2); +void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); +void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols, hipStream_t stream); +void getRowStats(half * A, float *rowStats, float threshold, int rows, int cols, hipStream_t stream); +void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, hipStream_t stream); + +template void transformRowToFormat(char * A, char *out, int rows, int cols); + +void spmm_coo(hipsparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B); + +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); + +template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); + +void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB); + +template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits); +template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream); + +template void func(T *A, T *B, T value, long n); + +#endif diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 0ced0394c..489151a87 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -6,11 +6,24 @@ #if BUILD_CUDA #include #endif +#if BUILD_HIP +#include +#endif #if BUILD_MPS // #include #endif +#if BUILD_NPU +#include +#endif #include +// Compatibility between HIP/CUDA APIs +#if BUILD_HIP +#define cudaStream_t hipStream_t +#define __nv_bfloat16 hip_bfloat16 +#define cublasLtHandle_t hipblasLtHandle_t +#endif + // We cannot call templated code from C, so we wrap the template in a C compatible call here if necessary. // We use macro functions to expand all the different optimizers. 
Looks ugly, and is ugly, but its better than to // maintain all that boilerplate @@ -18,7 +31,7 @@ // UNMANGLED CALLS //=================================================================================== -#if BUILD_CUDA +#if defined(BUILD_CUDA) || defined(BUILD_HIP) void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles(A, code, offset, n); } void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles(A, code, offset, n); } @@ -34,8 +47,13 @@ void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, floa void gemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream) { gemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } +#if defined(BUILD_CUDA) void gemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream) { gemm_4bit_inference_naive<__nv_bfloat16, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } +#elif defined(BUILD_HIP) +void gemm_4bit_inference_naive_bf16(int m, int n, int k, hip_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, hip_bfloat16 * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream) +{ gemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } +#endif void gemm_4bit_inference_naive_fp32(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream) { gemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } @@ -61,12 +79,20 @@ MAKE_FUNC32(momentum, MOMENTUM, float, 32) MAKE_FUNC32(momentum, MOMENTUM, half, 16) MAKE_FUNC32(adam, ADAM, float, fp32) MAKE_FUNC32(adam, ADAM, half, fp16) +#if defined(BUILD_CUDA) MAKE_FUNC32(adam, ADAM, __nv_bfloat16, bf16) +#elif defined(BUILD_HIP) +MAKE_FUNC32(adam, ADAM, hip_bfloat16, bf16) +#endif MAKE_FUNC32(rmsprop, RMSPROP, float, 32) MAKE_FUNC32(rmsprop, RMSPROP, half, 16) MAKE_FUNC32(lion, LION, float, fp32) MAKE_FUNC32(lion, LION, half, fp16) +#if defined(BUILD_CUDA) MAKE_FUNC32(lion, LION, __nv_bfloat16, bf16) +#elif defined(BUILD_HIP) +MAKE_FUNC32(lion, LION, hip_bfloat16, bf16) +#endif MAKE_FUNC32(adagrad, ADAGRAD, float, 32) MAKE_FUNC32(adagrad, ADAGRAD, half, 16) MAKE_FUNC32(ademamix, ADEMAMIX, float, fp32) @@ -102,24 +128,39 @@ void fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \ float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)\ { optimizerStatic8bitBlockwise(p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\ +MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, fp16) +MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, fp32) +MAKE_BLOCKWISE8(ademamix, ADEMAMIX, half, fp16) +MAKE_BLOCKWISE8(ademamix, ADEMAMIX, float, fp32) MAKE_BLOCKWISE8(adam, ADAM, half, fp16) -MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16) MAKE_BLOCKWISE8(adam, ADAM, float, fp32) +MAKE_BLOCKWISE8(lion, LION, half, fp16) +MAKE_BLOCKWISE8(lion, LION, float, fp32) MAKE_BLOCKWISE8(momentum, MOMENTUM, half, fp16) -MAKE_BLOCKWISE8(momentum, MOMENTUM, 
__nv_bfloat16, bf16) MAKE_BLOCKWISE8(momentum, MOMENTUM, float, fp32) MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, fp16) -MAKE_BLOCKWISE8(rmsprop, RMSPROP, __nv_bfloat16, bf16) MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, fp32) -MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, fp16) + + +#if defined(BUILD_CUDA) MAKE_BLOCKWISE8(adagrad, ADAGRAD, __nv_bfloat16, bf16) -MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, fp32) -MAKE_BLOCKWISE8(lion, LION, half, fp16) -MAKE_BLOCKWISE8(lion, LION, __nv_bfloat16, bf16) -MAKE_BLOCKWISE8(lion, LION, float, fp32) -MAKE_BLOCKWISE8(ademamix, ADEMAMIX, half, fp16) +MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16) MAKE_BLOCKWISE8(ademamix, ADEMAMIX, __nv_bfloat16, bf16) -MAKE_BLOCKWISE8(ademamix, ADEMAMIX, float, fp32) +MAKE_BLOCKWISE8(lion, LION, __nv_bfloat16, bf16) +MAKE_BLOCKWISE8(momentum, MOMENTUM, __nv_bfloat16, bf16) +MAKE_BLOCKWISE8(rmsprop, RMSPROP, __nv_bfloat16, bf16) +#elif defined(BUILD_HIP) +MAKE_BLOCKWISE8(adagrad, ADAGRAD, hip_bfloat16, bf16) +MAKE_BLOCKWISE8(adam, ADAM, hip_bfloat16, bf16) +MAKE_BLOCKWISE8(ademamix, ADEMAMIX, hip_bfloat16, bf16) +MAKE_BLOCKWISE8(lion, LION, hip_bfloat16, bf16) +MAKE_BLOCKWISE8(momentum, MOMENTUM, hip_bfloat16, bf16) +MAKE_BLOCKWISE8(rmsprop, RMSPROP, hip_bfloat16, bf16) +#endif + + + + void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping(g, gnorm_vec, step, n); } @@ -129,9 +170,15 @@ void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char void quantizeBlockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } void quantizeBlockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } +#if defined(BUILD_CUDA) void quantizeBlockwise_bf16(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<__nv_bfloat16, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); } void quantizeBlockwise_bf16_fp4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<__nv_bfloat16, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); } void quantizeBlockwise_bf16_nf4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<__nv_bfloat16, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); } +#elif defined(BUILD_HIP) +void quantizeBlockwise_bf16(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_bf16_fp4(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_bf16_nf4(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } +#endif void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } void quantizeBlockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } @@ -145,17 +192,35 @@ void 
dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, floa void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); } void dequantizeBlockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); } + +#if defined(BUILD_CUDA) void dequantizeBlockwise_bf16(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<__nv_bfloat16, General8bit>(code, A, absmax, out, blocksize, n, stream); } void dequantizeBlockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<__nv_bfloat16, FP4>(NULL, A, absmax, out, blocksize, n, stream); } void dequantizeBlockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<__nv_bfloat16, NF4>(NULL, A, absmax, out, blocksize, n, stream); } +#elif defined(BUILD_HIP) +void dequantizeBlockwise_bf16(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream); } +void dequantizeBlockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); } +void dequantizeBlockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); } +#endif +#if BUILD_CUDA #define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \ void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \ { \ transform(ltHandle, A, out, dim1, dim2); \ } \ +#endif +#if BUILD_HIP +#define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \ +void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(hipblasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \ +{ \ + transform(ltHandle, A, out, dim1, dim2); \ +} \ + +#endif + MAKE_FUNC_TRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8); MAKE_FUNC_TRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8); MAKE_FUNC_TRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8); @@ -165,6 +230,14 @@ MAKE_FUNC_TRANSFORM(8, row, col_ampere, n, int8_t, ROW, COL_AMPERE, false, 8); MAKE_FUNC_TRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8); MAKE_FUNC_TRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32); +#if defined(BUILD_HIP) +MAKE_FUNC_TRANSFORM(8, row, col, t, int8_t, ROW, COL, true, 8); +MAKE_FUNC_TRANSFORM(32, row, col, n, int32_t, ROW, COL, false, 32); +MAKE_FUNC_TRANSFORM(32, row, col, t, int32_t, ROW, COL, true, 32); +MAKE_FUNC_TRANSFORM(8, col, row, n, int8_t, COL, ROW, false, 8); +MAKE_FUNC_TRANSFORM(32, col, row, n, int32_t, COL, ROW, false, 32); +#endif + void transform_row2col32(char * A, char *out, int rows, int cols){ transformRowToFormat(A, out, rows, cols); } void transform_row2col32T(char * A, char *out, int rows, int cols){ transformRowToFormat(A, 
out, rows, cols); } void transform_row2turing(char * A, char *out, int rows, int cols){ transformRowToFormat(A, out, rows, cols); } @@ -190,11 +263,12 @@ void spmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_r void spmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) { spmm_coo_very_sparse_naive(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); } + #endif extern "C" { -#if BUILD_CUDA +#if defined(BUILD_CUDA) || defined(BUILD_HIP) void cestimate_quantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles_fp32(A, code, offset, n); } void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); } void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); } @@ -216,6 +290,7 @@ extern "C" void cdequantize_blockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n, stream); } void cdequantize_blockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n, stream); } +#if defined(BUILD_CUDA) void cquantize_blockwise_bf16(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16(code, A, absmax, out, blocksize, n); } void cquantize_blockwise_bf16_fp4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n); } void cquantize_blockwise_bf16_nf4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n); } @@ -224,6 +299,15 @@ extern "C" void cdequantize_blockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n, stream); } void cdequantize_blockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n, stream); } +#elif defined(BUILD_HIP) + void cquantize_blockwise_bf16(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_bf16_fp4(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_bf16_nf4(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n); } + + void cdequantize_blockwise_bf16(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_bf16(code, A, absmax, out, blocksize, n, stream); } + void cdequantize_blockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, 
const int n, cudaStream_t stream){ dequantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n, stream); } + void cdequantize_blockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n, stream); } +#endif #define MAKE_CFUNC32(name, gtype, gbits) \ void c##name##32bit_grad_##gbits(gtype *g, gtype *p, \ float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \ @@ -234,19 +318,29 @@ extern "C" MAKE_CFUNC32(adam, float, fp32) MAKE_CFUNC32(adam, half, fp16) + #if defined(BUILD_CUDA) MAKE_CFUNC32(adam, __nv_bfloat16, bf16) + MAKE_CFUNC32(ademamix, __nv_bfloat16, bf16) + #elif defined(BUILD_HIP) + MAKE_CFUNC32(adam, hip_bfloat16, bf16) + MAKE_CFUNC32(ademamix, hip_bfloat16, bf16) + #endif MAKE_CFUNC32(momentum, float, 32) MAKE_CFUNC32(momentum, half, 16) MAKE_CFUNC32(rmsprop, float, 32) MAKE_CFUNC32(rmsprop, half, 16) MAKE_CFUNC32(lion, float, fp32) MAKE_CFUNC32(lion, half, fp16) + #if defined(BUILD_CUDA) MAKE_CFUNC32(lion, __nv_bfloat16, bf16) + #elif defined(BUILD_HIP) + MAKE_CFUNC32(lion, hip_bfloat16, bf16) + #endif MAKE_CFUNC32(adagrad, float, 32) MAKE_CFUNC32(adagrad, half, 16) MAKE_CFUNC32(ademamix, float, fp32) MAKE_CFUNC32(ademamix, half, fp16) - MAKE_CFUNC32(ademamix, __nv_bfloat16, bf16) + #define MAKE_CFUNC8(name, gtype, gbits) \ void c##name##_static_8bit_grad_##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ @@ -276,24 +370,38 @@ extern "C" float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \ { fname##_8bit_blockwise_grad_##gbits(p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \ + MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, half, fp16) + MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, float, fp32) + MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, fp16) + MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, fp32) MAKE_CBLOCKWISE8(adam, ADAM, half, fp16) MAKE_CBLOCKWISE8(adam, ADAM, float, fp32) - MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16) + MAKE_CBLOCKWISE8(lion, LION, half, fp16) + MAKE_CBLOCKWISE8(lion, LION, float, fp32) MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, fp16) MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, fp32) - MAKE_CBLOCKWISE8(momentum, MOMENTUM, __nv_bfloat16, bf16) + MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, fp16) MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, fp32) - MAKE_CBLOCKWISE8(rmsprop, RMSPROP, __nv_bfloat16, bf16) - MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, fp16) - MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, fp32) + + + #if defined(BUILD_CUDA) MAKE_CBLOCKWISE8(adagrad, ADAGRAD, __nv_bfloat16, bf16) - MAKE_CBLOCKWISE8(lion, LION, half, fp16) - MAKE_CBLOCKWISE8(lion, LION, float, fp32) - MAKE_CBLOCKWISE8(lion, LION, __nv_bfloat16, bf16) - MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, half, fp16) - MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, float, fp32) + MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16) MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, __nv_bfloat16, bf16) + MAKE_CBLOCKWISE8(lion, LION, __nv_bfloat16, bf16) + MAKE_CBLOCKWISE8(momentum, MOMENTUM, __nv_bfloat16, bf16) + MAKE_CBLOCKWISE8(rmsprop, RMSPROP, __nv_bfloat16, bf16) + #elif defined(BUILD_HIP) + MAKE_CBLOCKWISE8(adagrad, ADAGRAD, hip_bfloat16, bf16) + MAKE_CBLOCKWISE8(adam, ADAM, hip_bfloat16, bf16) + MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, hip_bfloat16, bf16) + 
MAKE_CBLOCKWISE8(lion, LION, hip_bfloat16, bf16) + MAKE_CBLOCKWISE8(momentum, MOMENTUM, hip_bfloat16, bf16) + MAKE_CBLOCKWISE8(rmsprop, RMSPROP, hip_bfloat16, bf16) + #endif + + void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); } void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); } @@ -306,24 +414,31 @@ extern "C" { strided_gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc, strideA, strideB, strideC, batchCount); } Context *get_context(){ return new Context(); } +#if BUILD_CUDA ContextCusparse *get_cusparse(){ return new ContextCusparse(); } +#endif - int cigemmlt_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) { - return igemmlt_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream); - } - int cigemmlt_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) { - return igemmlt_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream); - } - int cigemmlt_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) { - return igemmlt_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream); - } +#if BUILD_HIP + ContextHipsparse *get_hipsparse(){ return new ContextHipsparse(); } +#endif - #define MAKE_FUNC_CTRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \ +int cigemmlt_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) { + return igemmlt_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream); +} +int cigemmlt_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) { + return igemmlt_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream); +} +int cigemmlt_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) { + return igemmlt_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream); +} +#define MAKE_FUNC_CTRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \ void ctransform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(Context *context, dtype *A, dtype *out, int dim1, int dim2) \ { \ transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose((cublasLtHandle_t) context->m_handle, A, out, dim1, dim2); \ } \ + + #if defined(BUILD_CUDA) MAKE_FUNC_CTRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8) MAKE_FUNC_CTRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8) MAKE_FUNC_CTRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8) @@ -332,6 +447,13 @@ extern "C" MAKE_FUNC_CTRANSFORM(8, row, col_ampere, n, int8_t, ROW, COL_AMPERE, false, 8) MAKE_FUNC_CTRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8) MAKE_FUNC_CTRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32) + #elif defined(BUILD_HIP) + MAKE_FUNC_CTRANSFORM(8, row, col, t, int8_t, 
ROW, COL, true, 8) + MAKE_FUNC_CTRANSFORM(32, row, col, n, int32_t, ROW, COL, false, 32) + MAKE_FUNC_CTRANSFORM(32, row, col, t, int32_t, ROW, COL, true, 32) + MAKE_FUNC_CTRANSFORM(8, col, row, n, int8_t, COL, ROW, false, 8) + MAKE_FUNC_CTRANSFORM(32, col, row, n, int32_t, COL, ROW, false, 32) + #endif void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols, cudaStream_t stream) { dequant_mm_int32_fp16(A, rowStats, colStats, out, bias, numRows, numCols, stream); } @@ -360,8 +482,15 @@ extern "C" void ctransform_row2ampereT(char * A, char *out, int rows, int cols) { transform_row2ampereT(A, out, rows, cols); } +#if BUILD_CUDA void cspmm_coo(ContextCusparse *context, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B) { spmm_coo((cusparseHandle_t) context->m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C, transposed_B); } +#endif + +#if BUILD_HIP + void cspmm_coo(ContextHipsparse *context, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B) + { spmm_coo((hipsparseHandle_t) context->m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C, transposed_B); } +#endif void cspmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) { spmm_coo_very_sparse_naive_fp16(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); } @@ -381,6 +510,7 @@ extern "C" void cgemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) { gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } +#if BUILD_CUDA void *cget_managed_ptr(size_t bytes) { void *ptr; @@ -400,6 +530,29 @@ extern "C" CUDA_CHECK_RETURN(cudaMemPrefetchAsync(ptr, bytes, device, 0)); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } +#endif + +#if BUILD_HIP + void *cget_managed_ptr(size_t bytes) + { + void *ptr; + CUDA_CHECK_RETURN(hipMallocManaged(&ptr, bytes, hipMemAttachHost)); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + + return ptr; + } + + void cprefetch(void *ptr, size_t bytes, int device) + { + + int hasPrefetch = 0; + CUDA_CHECK_RETURN(hipDeviceGetAttribute(&hasPrefetch, hipDeviceAttributeConcurrentManagedAccess, device)); // 40ns overhead + if (hasPrefetch == 0) return; + + CUDA_CHECK_RETURN(hipMemPrefetchAsync(ptr, bytes, device, 0)); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + } +#endif #define CMAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \ void c##fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ fname##_##type_name(A, B, value, n); } \ @@ -412,14 +565,27 @@ extern "C" void cgemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream) { gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } + #if defined(BUILD_CUDA) void cgemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t 
stream) + #elif defined(BUILD_HIP) + void cgemm_4bit_inference_naive_bf16(int m, int n, int k, hip_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, hip_bfloat16 * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream) + #endif { gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } + void cgemm_4bit_inference_naive_fp32(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream) { gemm_4bit_inference_naive_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } #endif +#if BUILD_NPU + void cdequantize_blockwise_fp32_nf4(uint8_t *A, uint8_t *absmax, uint8_t *out, uint32_t blocksize, uint32_t n, void* stream) + { dequantizeBlockwiseNf4(A, absmax, out, blocksize, n, stream, 1); } + + void cdequantize_blockwise_fp16_nf4(uint8_t *A, uint8_t *absmax, uint8_t *out, uint32_t blocksize, uint32_t n, void* stream) + { dequantizeBlockwiseNf4(A, absmax, out, blocksize, n, stream, 2); } +#endif + void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); } } diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index da4ab3b44..17b2d37d5 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -254,6 +254,13 @@ Compatible hardware and functioning `import intel_extension_for_pytorch as ipex` Please refer to [the official Intel installations instructions](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=cpu&version=v2.4.0%2bcpu&os=linux%2fwsl2) for guidance on how to pip install the necessary `intel_extension_for_pytorch` dependency. + + + +Compatible hardware and functioning `import torch_npu` capable environment with Python `3.10` as the minimum requirement. + +Please refer to [the official Ascend installations instructions](https://www.hiascend.com/document/detail/zh/Pytorch/60RC3/configandinstg/instg/insg_0001.html) for guidance on how to pip install the necessary `torch_npu` dependency. + @@ -286,6 +293,7 @@ pip install --force-reinstall 'https://github.com/bitsandbytes-foundation/bitsan ``` + Compatible hardware and functioning `import torch_npu` capable environment with Python `3.10` as the minimum requirement. @@ -333,10 +341,10 @@ pip install -e . # `-e` for "editable" install, when developing BNB (otherwise -#### Intel CPU +#### Intel CPU / XPU > [!TIP] -> Intel CPU backend only supports building from source; for now, please follow the instructions below. +> Intel CPU / XPU backend only supports building from source; for now, please follow the instructions below. Similar to the CUDA case, you can compile bitsandbytes from source for Linux and Windows systems. diff --git a/docs/source/non_cuda_backends.mdx b/docs/source/non_cuda_backends.mdx index 728606b7b..fda78e589 100644 --- a/docs/source/non_cuda_backends.mdx +++ b/docs/source/non_cuda_backends.mdx @@ -27,18 +27,25 @@ Thank you for your support! ### Intel -The following performance data is collected from Intel 4th Gen Xeon (SPR) platform. 
The tables show speed-up and memory compared with different data types of [Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf).
+The performance data below was collected on an Intel 4th Gen Xeon (SPR) platform. The tables show the speed-up and memory usage of different data types for [meta-llama/Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct).
+
+You can run `benchmarking/generation_benchmark.py` to reproduce the model memory and inference results below. Note that you need to bind cores when benchmarking on the CPU; for example, run `numactl -C 0-55 -m 0 python generation_benchmark.py --quant_type nf4` on a single socket of an Intel 4th Gen Xeon.
+
+The fine-tuning results are based on the `olora_finetuning` example in [peft](https://github.com/huggingface/peft/blob/main/examples/olora_finetuning/olora_finetuning.py).
+
+#### Model memory (CPU)
+| Data Type | BF16 | INT8 | NF4 | FP4 |
+|---|---|---|---|---|
+| Memory (GB) | 15.0 | 8.5 | 5.2 | 5.2 |
#### Inference (CPU)
| Data Type | BF16 | INT8 | NF4 | FP4 |
|---|---|---|---|---|
-| Speed-Up (vs BF16) | 1.0x | 0.6x | 2.3x | 0.03x |
-| Memory (GB) | 13.1 | 7.6 | 5.0 | 4.6 |
+| Speed-Up (vs BF16) | 1.0x | 0.57x | 2.6x | 0.1x |
#### Fine-Tuning (CPU)
-| Data Type | AMP BF16 | INT8 | NF4 | FP4 |
+| Data Type | BF16 | INT8 | NF4 | FP4 |
|---|---|---|---|---|
-| Speed-Up (vs AMP BF16) | 1.0x | 0.38x | 0.07x | 0.07x |
-| Memory (GB) | 40 | 9 | 6.6 | 6.6 |
+| Speed-Up (vs BF16) | 1.0x | 0.91x | 1.0x | 1.0x |
diff --git a/pyproject.toml b/pyproject.toml
index 515d90385..93d7b1829 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,5 +1,5 @@
[build-system]
-requires = ["setuptools >= 63.0.0"]
+requires = ["setuptools >= 63.0.0", "setuptools_scm >= 8.0"]
build-backend = "setuptools.build_meta"
[project]
@@ -78,7 +78,11 @@ package-data = { "*" = ["libbitsandbytes*.*"] }
include = ["bitsandbytes*"]
[tool.setuptools.dynamic]
-version = {attr = "bitsandbytes.__version__"}
+#version = {attr = "bitsandbytes.__version__"}
+
+[tool.setuptools_scm]
+local_scheme = "no-local-version"
+version_file = "bitsandbytes/_version.py"
[tool.pytest.ini_options]
addopts = "-rP"
diff --git a/setup.py b/setup.py
index 74b9debb9..2263f3bdd 100644
--- a/setup.py
+++ b/setup.py
@@ -1,15 +1,19 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
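For context on the versioning change above: the `[tool.setuptools_scm]` table switches the package to git-derived versions written into `bitsandbytes/_version.py`. The snippet below is a minimal sketch of how such a configuration resolves a version; it assumes `setuptools_scm >= 8.0` is installed and that it runs from a git checkout, and the fallback string is an illustrative placeholder rather than anything defined in this PR.

```python
# Sketch only: derive a version the same way the [tool.setuptools_scm] settings above do.
# Assumes setuptools_scm >= 8.0 and a git checkout; the fallback value is a placeholder.
from setuptools_scm import get_version

try:
    # local_scheme="no-local-version" drops the "+g<hash>" local suffix, matching the config above.
    version = get_version(root=".", local_scheme="no-local-version")
except LookupError:
    # No git metadata available (e.g. building from a plain source archive).
    version = "0.0.0.dev0"

print(version)
# During an actual build, version_file = "bitsandbytes/_version.py" tells setuptools_scm to
# write __version__ into that module so it can be imported at runtime.
```

This mirrors, under the stated assumptions, what the `write_version_file` helper added to `setup.py` just below does by hand for the preview build.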
from setuptools import find_packages, setup from setuptools.dist import Distribution +VERSION = "1.0.0" -# Tested with wheel v0.29.0 + +# Tested with wheel v0.45.1 class BinaryDistribution(Distribution): def has_ext_modules(self): return True -setup(version="0.45.3.dev0", packages=find_packages(), distclass=BinaryDistribution) +def write_version_file(version, filepath="bitsandbytes/_version.py"): + with open(filepath, "w") as f: + f.write(f'__version__ = "{version}"\n') + return version + + +setup(packages=find_packages(), distclass=BinaryDistribution, version=write_version_file(VERSION)) diff --git a/tests/helpers.py b/tests/helpers.py index 02cb881a3..e93c11b70 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -62,3 +62,10 @@ def id_formatter(label: str): def describe_dtype(dtype: torch.dtype) -> str: return DTYPE_NAMES.get(dtype) or str(dtype).rpartition(".")[2] + + +def get_blocksizes(hip_env: bool) -> List[int]: + if not hip_env: + return [4096, 2048, 1024, 512, 256, 128, 64] + else: + return [4096, 2048, 1024, 512, 256, 128] diff --git a/tests/test_autograd.py b/tests/test_autograd.py index ae2529542..259d9928a 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -4,6 +4,7 @@ import torch import bitsandbytes as bnb +from bitsandbytes.cextension import BNB_HIP_VERSION from tests.helpers import ( BOOLEAN_TRIPLES, BOOLEAN_TUPLES, @@ -199,6 +200,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool assert (idx == 0).sum().item() < n * 0.02 +@pytest.mark.skipif(0 < BNB_HIP_VERSION < 601, reason="this test is supported on ROCm from 6.1") @pytest.mark.parametrize("dim1", [40], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3")) diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index 79406472e..3d8b688ee 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -1,6 +1,6 @@ import pytest -from bitsandbytes.cextension import get_cuda_bnb_library_path +from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path from bitsandbytes.cuda_specs import CUDASpecs @@ -13,11 +13,13 @@ def cuda120_spec() -> CUDASpecs: ) +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm") def test_get_cuda_bnb_library_path(monkeypatch, cuda120_spec): monkeypatch.delenv("BNB_CUDA_VERSION", raising=False) assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda120" +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm") def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog): monkeypatch.setenv("BNB_CUDA_VERSION", "110") assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110" diff --git a/tests/test_functional.py b/tests/test_functional.py index c8ac20896..6e3de18c0 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -11,10 +11,12 @@ import bitsandbytes as bnb from bitsandbytes import functional as F +from bitsandbytes.cextension import BNB_HIP_VERSION, HIP_ENVIRONMENT, ROCM_GPU_ARCH from tests.helpers import ( BOOLEAN_TUPLES, TRUE_FALSE, describe_dtype, + get_blocksizes, get_test_dims, id_formatter, ) @@ -116,7 +118,7 @@ def test_estimate_quantiles(dtype): @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested")) 
-@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64]) +@pytest.mark.parametrize("blocksize", get_blocksizes(HIP_ENVIRONMENT)) @pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed")) def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed): diffs = [] @@ -431,13 +433,17 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose): @pytest.mark.parametrize("dim4", [512], ids=id_formatter("dim4")) @pytest.mark.parametrize("dims", (2, 3), ids=id_formatter("dims")) @pytest.mark.parametrize("ldb", (0,), ids=id_formatter("ldb")) -def test_int8_linear_matmul(dim1, dim2, dim3, dim4, dims, ldb): +@pytest.mark.parametrize("device", ("cuda", "cpu"), ids=id_formatter("device")) +def test_int8_linear_matmul(dim1, dim2, dim3, dim4, dims, ldb, device): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("this test is not supported on ROCm yet") + for i in range(k): if dims == 2: - A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(torch.int8) + A = torch.randint(-128, 127, size=(dim1, dim3), device=device).to(torch.int8) elif dims == 3: - A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(torch.int8) - B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8) + A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device=device).to(torch.int8) + B = torch.randint(-128, 127, size=(dim4, dim3), device=device).to(torch.int8) C1 = torch.matmul(A.float(), B.t().float()) C2 = F.int8_linear_matmul(A, B) @@ -510,6 +516,33 @@ def test_dequant_mm(dim1, dim4, dims, has_bias): assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01 * n)) +@pytest.mark.parametrize("dim1", [64, 256], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim4", [64, 1024], ids=id_formatter("dim4")) +@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) +@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias")) +def test_dequant_mm_cpu(dim1, dim4, dims, has_bias): + inner = torch.randint(1, 128, size=(1,)).item() + bias = None + if has_bias: + bias = torch.randn(dim4, device="cpu", dtype=torch.bfloat16) + for i in range(1): + A = torch.randn(dim1, inner, device="cpu") + B = torch.randn(dim4, inner, device="cpu") + + A1, maxA = F.vectorwise_quant(A, dim=1) + B1, maxB = F.vectorwise_quant(B, dim=1) + + C2 = F.int8_linear_matmul(A1, B1) + + C3 = F.vectorwise_mm_dequant(C2.bfloat16(), maxA, maxB.t()) + if has_bias: + C3 += bias + + C4 = F.int8_mm_dequant(C2, maxA.flatten(), maxB.flatten(), bias=bias) + + torch.testing.assert_close(C3.float(), C4.float(), atol=0.05, rtol=0.1) + + @pytest.mark.parametrize("dim1", [1 * 1024], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [1 * 1024], ids=id_formatter("dim2")) @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) @@ -551,9 +584,14 @@ def test_colrow_absmax(dim1, dim2, dims, threshold): @pytest.mark.parametrize("dim1", [2048, 4096], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [512, 1024], ids=id_formatter("dim2")) -def test_int8_double_quant(dim1, dim2): +@pytest.mark.parametrize("device", ["cuda", "cpu"], ids=id_formatter("device")) +@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16], ids=id_formatter("dtype")) +def test_int8_double_quant(dim1, dim2, device, dtype): + if device == "cuda" and dtype == torch.bfloat16: + pytest.skip("bfloat16 is not implemented for this operation on CUDA backend") + for i in range(k): - A = torch.randn(dim1, dim2, device="cuda").half() + A = torch.randn(dim1, 
dim2, device=device).to(dtype) out_col1, Scol = F.vectorwise_quant(A, dim=0) out_row1, Srow = F.vectorwise_quant(A, dim=1) @@ -580,6 +618,7 @@ def test_int8_double_quant(dim1, dim2): torch.testing.assert_close(Scol.flatten().float(), statsAt) +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") @pytest.mark.parametrize( ("dim1", "dim4", "inner"), ( @@ -695,42 +734,91 @@ def test_igemmlt_row_scale(dim1, dim4, inner): print(sum(err3) / len(err3)) -@pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) -def test_coo_double_quant(dim1, dim2): - threshold = 2.00 +@pytest.mark.parametrize("dim1", get_test_dims(2, 1024, n=2), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(2, 1024, n=2), ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", [0], ids=id_formatter("dim3")) +@pytest.mark.parametrize("dims", [2], ids=id_formatter("dims")) +@pytest.mark.parametrize("dtype", [torch.int8], ids=describe_dtype) +@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA")) +@pytest.mark.parametrize("orderOut", ["col32", "col_turing", "col_ampere"], ids=id_formatter("orderOut")) +@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose")) +@pytest.mark.deprecated +def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): for i in range(k): - A = torch.randn(dim1, dim2, device="cuda").half() + if dims == 2: + A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(dtype) + elif dims == 3: + A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cuda").to(dtype) - idx = torch.abs(A) >= threshold - CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold) + A.view(-1)[-1] = -1 + if transpose: + At = A.t().contiguous() + out1, S1 = F.nvidia_transform(At, to_order=orderOut) + else: + out1, S1 = F.nvidia_transform(A, to_order=orderOut) + out2, S2 = F.transform(A, to_order=orderOut, transpose=transpose) - if outlier_cols is not None: - A1 = A * idx - A2 = torch.zeros_like(A) + A1 - torch.testing.assert_close(A1, A2) + assert S1[0][0] == S2[0][0] + assert S1[0][1] == S2[0][1] + # print(out1) + # print(out2) - A[:, outlier_cols] = 0 - A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() - torch.testing.assert_close(A, A2, rtol=0.05, atol=1.5e-2) + torch.testing.assert_close(out1, out2) + + +@pytest.mark.parametrize("dim1", get_test_dims(2, 1024, n=2), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(2, 1024, n=2), ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", [0], ids=id_formatter("dim3")) +@pytest.mark.parametrize("dims", [2], ids=id_formatter("dims")) +@pytest.mark.parametrize("dtype", [torch.int8], ids=describe_dtype) +@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA")) +@pytest.mark.parametrize("orderOut", ["col32", "col_turing", "col_ampere"], ids=id_formatter("orderOut")) +@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose")) +@pytest.mark.deprecated +def test_transform_cpu(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): + for i in range(k): + if dims == 2: + A = torch.randint(10, 99, size=(dim1, dim2), device="cpu").to(dtype) + elif dims == 3: + A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cpu").to(dtype) + + A.view(-1)[-1] = -1 + if transpose: + out1 = A.t().contiguous() + else: + out1 = A + out2, S2 = F.transform(A, to_order=orderOut, transpose=transpose) + + 
assert S2 is None + + torch.testing.assert_close(out1, out2) @pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) -def test_coo_int8_vectorwise_quant(dim1, dim2): - threshold = 3.00 +@pytest.mark.parametrize("device", ["cuda", "cpu"], ids=id_formatter("device")) +@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16], ids=id_formatter("dtype")) +def test_coo_int8_vectorwise_quant(dim1, dim2, device, dtype): + if device == "cuda" and dtype == torch.bfloat16: + pytest.skip("bfloat16 is not implemented for this operation on CUDA backend") + threshold = 2.00 for i in range(k): - A = torch.randn(dim1, dim2, device="cuda").half() + A = torch.randn(dim1, dim2, device=device).to(dtype) idx = torch.abs(A) >= threshold CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold) if outlier_cols is not None: - A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() + A1 = A * idx + A2 = torch.zeros_like(A) + A1 + torch.testing.assert_close(A1, A2) + A[:, outlier_cols] = 0 - torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2) + A2 = (CA.float() * statsA.unsqueeze(1) / 127).to(dtype) + torch.testing.assert_close(A, A2, rtol=0.05, atol=1.5e-2) +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") @pytest.mark.parametrize("dim1", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim2")) @pytest.mark.parametrize("transposed_B", TRUE_FALSE, ids=id_formatter("transposed_B")) @@ -762,6 +850,7 @@ def test_spmm_coo(dim1, dim2, transposed_B): assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30) +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") @pytest.mark.benchmark def test_spmm_bench(): batch = 2 @@ -804,6 +893,7 @@ def test_spmm_bench(): print(tsp / t8) +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") @pytest.mark.parametrize("dim1", [256, 1024], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [256, 1024], ids=id_formatter("dim2")) def test_integrated_sparse_decomp(dim1, dim2): @@ -949,6 +1039,7 @@ def test_coo2csc(): torch.testing.assert_close(A2.t()[idx], cscA.values) +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") @pytest.mark.parametrize("dim1", [1 * 2048]) @pytest.mark.parametrize("dim2", [2048]) @pytest.mark.parametrize("dtype", [torch.int8]) @@ -1136,13 +1227,14 @@ def quant_zp(x): print(err1, err2, err3, err4, err5, err6) +@pytest.mark.skipif(0 < BNB_HIP_VERSION < 601, reason="this test is supported on ROCm from 6.1") +@pytest.mark.parametrize("device", ["cuda", "cpu"]) @pytest.mark.deprecated -def test_extract_outliers(): +def test_extract_outliers(device): for i in range(k): shapeA = (4096, 4096 * 4) - idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda() - # idx = torch.Tensor([0]).int().cuda() - A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) + idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).to(device=device) + A = torch.randint(-128, 127, size=shapeA, device=device).to(torch.int8) outliers1 = A[:, idx.long()] CA, SA = F.transform(A, "col_turing") @@ -1335,10 +1427,12 @@ def test_bench_dequantization(): # print((time.time()-t0)/1e6) +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") 
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096]) -def test_4bit_quant(dtype, quant_type, blocksize): +@pytest.mark.parametrize("device", ["cuda", "cpu"]) +def test_4bit_quant(dtype, quant_type, blocksize, device): vals = list(product([0, 1], repeat=4)) code = {} @@ -1362,9 +1456,11 @@ def test_4bit_quant(dtype, quant_type, blocksize): result = sign * exp * frac code[idx] = result - A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype) + A1 = torch.randn(1024, 1024, device=device, dtype=dtype) qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type) A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type) + # if device == "cpu": + # A2 = A2.t() err = (A1 - A2).abs().float() relerr = (err / (A1.abs().float() + 1e-8)).mean() @@ -1394,7 +1490,8 @@ def test_4bit_quant(dtype, quant_type, blocksize): @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) def test_4bit_compressed_stats(quant_type): - for blocksize in [128, 64]: + blocksizes = [128, 64] if not HIP_ENVIRONMENT else [128] + for blocksize in blocksizes: errs1 = [] errs2 = [] for i in range(10): @@ -1474,6 +1571,9 @@ def test_normal_map_tree(): # print(pivots) +@pytest.mark.skipif( + HIP_ENVIRONMENT, reason="gemv 4bit tests are partially enabled on MI300, others being fixed for warpsize 64" +) @pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}") @pytest.mark.parametrize("storage_type", ["nf4", "fp4"]) @pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"]) @@ -1614,6 +1714,50 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind): assert maxratio < 1.02 and maxratio > 0.98 +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") +@pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"]) +@pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) +def test_gemv_4bit_cpu(dtype, quant_type, kind): + """ + Test 4bit GEMV for CPU. It is simplified a lot from the cuda version, since + the CPU backend does not support double_quant or quant_storage other than uint8. 
+ Also, the CPU backend has different numeric accuracy from that of CUDA + """ + for dim in [128, 256, 512, 1024]: + for i in range(10): + if kind == "fc1": + A = torch.randn(1, dim, dtype=dtype, device="cpu") + B = torch.randn(dim * 4, dim, dtype=dtype, device="cpu") / math.sqrt(dim) + elif kind == "fc2": + A = torch.randn(1, 4 * dim, dtype=dtype, device="cpu") + B = torch.randn(dim, 4 * dim, dtype=dtype, device="cpu") / math.sqrt(dim) + elif kind == "attn": + A = torch.randn(1, dim, dtype=dtype, device="cpu") + B = torch.randn(dim, dim, dtype=dtype, device="cpu") / math.sqrt(dim) + elif kind == "attn_packed": + A = torch.randn(1, dim, dtype=dtype, device="cpu") + B = torch.randn(dim * 3, dim, dtype=dtype, device="cpu") / math.sqrt(dim) + + qB, state = F.quantize_4bit( + B, + quant_type=quant_type, + compress_statistics=False, + quant_storage=torch.uint8, + ) + dqB = F.dequantize_4bit(qB, state) + C3 = torch.matmul(A, dqB.t()) + C2 = F.gemv_4bit(A, qB.t(), state=state) + A.requires_grad = True + C1 = bnb.matmul_4bit(A, qB.t(), state) + + c = int(C1.numel() * 0.0014 * (dim / 256)) + 1 + rtol = 1e-3 if dtype != torch.bfloat16 else 1e-2 + atol = 1e-2 if dtype != torch.bfloat16 else 5e-2 + assert_all_approx_close(C1, C2, rtol, atol, count=c) + assert_all_approx_close(C3, C2, rtol, atol, count=c) + + @pytest.mark.skip("Row scale has some bugs for ampere") def test_managed(): n = 32 * 10 @@ -1640,6 +1784,10 @@ def test_managed(): @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"]) +@pytest.mark.skipif( + HIP_ENVIRONMENT and ROCM_GPU_ARCH == "gfx90a", + reason="this test is not supported on ROCm with gfx90a architecture yet", +) def test_gemv_eye_4bit(storage_type, dtype, double_quant): dims = 10 torch.random.manual_seed(np.random.randint(0, 412424242)) @@ -1663,6 +1811,7 @@ def test_gemv_eye_4bit(storage_type, dtype, double_quant): # torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080) +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") @pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=1), ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=1), ids=id_formatter("dim2")) @pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=1), ids=id_formatter("dim3")) @@ -1750,44 +1899,15 @@ def test_percentile_clipping(gtype): torch.testing.assert_close(gnorm1, gnorm2) -@pytest.mark.parametrize("dim1", get_test_dims(2, 1024, n=2), ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim2", get_test_dims(2, 1024, n=2), ids=id_formatter("dim2")) -@pytest.mark.parametrize("dim3", [0], ids=id_formatter("dim3")) -@pytest.mark.parametrize("dims", [2], ids=id_formatter("dims")) -@pytest.mark.parametrize("dtype", [torch.int8], ids=describe_dtype) -@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA")) -@pytest.mark.parametrize("orderOut", ["col32", "col_turing", "col_ampere"], ids=id_formatter("orderOut")) -@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose")) -@pytest.mark.deprecated -def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): - for i in range(k): - if dims == 2: - A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(dtype) - elif dims == 3: - A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cuda").to(dtype) - - A.view(-1)[-1] = -1 - if transpose: 
- At = A.t().contiguous() - out1, S1 = F.nvidia_transform(At, to_order=orderOut) - else: - out1, S1 = F.nvidia_transform(A, to_order=orderOut) - out2, S2 = F.transform(A, to_order=orderOut, transpose=transpose) - - assert S1[0][0] == S2[0][0] - assert S1[0][1] == S2[0][1] - # print(out1) - # print(out2) - - torch.testing.assert_close(out1, out2) - - +@pytest.mark.skipif(0 < BNB_HIP_VERSION < 601, reason="this test is supported on ROCm from 6.1") @pytest.mark.parametrize("dim1", get_test_dims(2, 256, n=2), ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", get_test_dims(2, 256, n=2), ids=id_formatter("dim2")) @pytest.mark.parametrize("dim3", get_test_dims(2, 256, n=2), ids=id_formatter("dim3")) @pytest.mark.parametrize("dtype", [torch.int8, torch.int32], ids=describe_dtype) @pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA")) -@pytest.mark.parametrize("orderOut", ["col", "row", "col32"], ids=id_formatter("orderOut")) +@pytest.mark.parametrize( + "orderOut", ["col", "row"] if HIP_ENVIRONMENT else ["col", "row", "col32"], ids=id_formatter("orderOut") +) @pytest.mark.parametrize("transpose", [False], ids=id_formatter("transpose")) @pytest.mark.parametrize("dims", [2, 3], ids=id_formatter("dims")) @pytest.mark.deprecated diff --git a/tests/test_generation.py b/tests/test_generation.py index 911aa14da..8e689261b 100644 --- a/tests/test_generation.py +++ b/tests/test_generation.py @@ -4,6 +4,7 @@ import pytest import torch +from bitsandbytes.cextension import HIP_ENVIRONMENT from tests.helpers import TRUE_FALSE, describe_dtype, id_formatter transformers = pytest.importorskip("transformers") @@ -71,6 +72,7 @@ def model_and_tokenizer(request): del model +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") @pytest.mark.parametrize("DQ", TRUE_FALSE, ids=id_formatter("dq")) @pytest.mark.parametrize("inference_kernel", TRUE_FALSE, ids=id_formatter("inference_kernel")) @pytest.mark.parametrize("dtype", [torch.float16], ids=describe_dtype) diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py index bc9e2600f..822633de4 100644 --- a/tests/test_linear8bitlt.py +++ b/tests/test_linear8bitlt.py @@ -10,6 +10,7 @@ import bitsandbytes as bnb from bitsandbytes import functional as F from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout +from bitsandbytes.cextension import HIP_ENVIRONMENT from bitsandbytes.nn.modules import Linear8bitLt from tests.helpers import ( TRUE_FALSE, @@ -22,6 +23,7 @@ # https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") @pytest.mark.skipif( not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5), reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs", @@ -40,6 +42,7 @@ def test_layout_exact_match(): assert torch.all(torch.eq(restored_x, x)) +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") def test_linear_no_igemmlt(): linear = torch.nn.Linear(1024, 3072) x = torch.randn(3, 1024, dtype=torch.half) @@ -79,6 +82,7 @@ def test_linear_no_igemmlt(): torch.testing.assert_close(x_ref.grad, x_ours.grad, atol=0.01, rtol=1e-5) +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights")) @pytest.mark.parametrize("serialize_before_forward", TRUE_FALSE, 
ids=id_formatter("serialize_before_forward")) @pytest.mark.parametrize("deserialize_before_cuda", TRUE_FALSE, ids=id_formatter("deserialize_before_cuda")) diff --git a/tests/test_modules.py b/tests/test_modules.py index c2583550d..147f6ec4c 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -7,6 +7,7 @@ from torch import nn import bitsandbytes as bnb +from bitsandbytes.cextension import HIP_ENVIRONMENT from tests.helpers import id_formatter @@ -297,6 +298,7 @@ def forward(self, x): return LinearFunction.apply(x, self.weight, self.bias, self.args) +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") @pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("threshold")) def test_linear8bitlt_inference(threshold): l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half() @@ -311,6 +313,7 @@ def test_linear8bitlt_inference(threshold): assert l1.state.CB is not None +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") def test_linear8bitlt_accumulated_gradient(): l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]) l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)]) @@ -519,6 +522,7 @@ def test_linear_kbit_fp32_bias(module): } +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") @pytest.mark.parametrize("module", module_dict.values(), ids=module_dict.keys()) def test_kbit_backprop(module): b = 16 @@ -538,7 +542,7 @@ def test_kbit_backprop(module): kbit[1].bias.detach().copy_(ref[1].bias) ref = ref.half().cuda() kbit = kbit.half().cuda() - kbit = kbit.half().to("cuda") + # kbit = kbit.half().to("cuda") errs1 = [] errs2 = [] diff --git a/tests/test_triton.py b/tests/test_triton.py index 3624fb5e9..1c5422c0d 100644 --- a/tests/test_triton.py +++ b/tests/test_triton.py @@ -1,12 +1,14 @@ import pytest import torch +from bitsandbytes.cextension import HIP_ENVIRONMENT from bitsandbytes.nn import Linear8bitLt from bitsandbytes.nn.triton_based_modules import SwitchBackLinear from bitsandbytes.triton.triton_utils import is_triton_available from tests.helpers import TRUE_FALSE +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") @pytest.mark.skipif( not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8, reason="This test requires triton and a GPU with compute capability 8.0 or higher.",