diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 000000000..c37c66736 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,118 @@ +# This CMake config hopefully makes it easier to compile. +# Ensure the CUDA Toolkit is available on your path. Then run: +# For GCC: `cmake -B build . && cmake --build build` +# For MSVC: `cmake -B build . && cmake --build build --config Release` +# You can also use the following options and variables +# - COMPUTE_BACKEND: Set to `cpu`, `cuda`, or `mps` to select the backend +# - NO_CUBLASLT: Default OFF, will skip building/linking CUBLASLT support +# - CUDA_VERSION: The expected CUDA version, for sanity checking. The actual version +# is whatever CMake finds on your path. +# - COMPUTE_CAPABILITY: Which GPU Arch/Compute codes to provide to NVCC. +# Separate by semicolons, i.e. `-DCOMPUTE_CAPABILITY=89;90` +# Check your compute capability here: https://developer.nvidia.com/cuda-gpus +# - PTXAS_VERBOSE: Pass the `-v` option to the PTX Assembler +cmake_minimum_required(VERSION 3.22.1) + +project(bitsandbytes LANGUAGES CXX) + +# If run without specifying a build type, default to using the Release configuration: +# optimizing the generated binaries for performance and also adds the `-DNDEBUG` flag, +# which turns off a bunch of asserts which seem to link to new symbols in libstdc++, +# worsening our many_linux compliance.. +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release) +endif() + +# Define included source files +set(CPP_FILES csrc/common.cpp csrc/cpu_ops.cpp csrc/pythonInterface.c) +set(CUDA_FILES csrc/ops.hip.cpp csrc/kernels.hip.cpp) +# C++ sources are always included +list(APPEND SRC_FILES ${CPP_FILES}) + +set(COMPUTE_BACKEND "hip" CACHE STRING "The compute backend to use (cpu, hip)") +set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda mps hip) +option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF) + +if(NOT DEFINED HIP_PATH) + if(NOT DEFINED ENV{HIP_PATH}) + set(HIP_PATH "/opt/rocm/hip" CACHE PATH "Path to which HIP has been installed") + else() + set(HIP_PATH $ENV{HIP_PATH} CACHE PATH "Path to which HIP has been installed") + endif() +endif() +message("HIP_PATH: " ${HIP_PATH}) +set(CMAKE_MODULE_PATH "${HIP_PATH}/cmake" ${CMAKE_MODULE_PATH}) +find_package(HIP REQUIRED) +if (HIP_FOUND) + message(STATUS "Found HIP: " ${HIP_VERSION}) +else() + message(FATAL_ERROR "Could not find HIP") +endif() +find_package(rocthrust REQUIRED) +find_package(hipblas REQUIRED) +find_package(hipsparse REQUIRED) +find_package(rocrand REQUIRED) +find_package(hipblaslt REQUIRED) +# Search for rocm in common locations +list(APPEND CMAKE_PREFIX_PATH /opt/rocm/hip /opt/rocm /opt/rocm) +list(APPEND HIP_PATH /opt/rocm/llvm/bin/) +# Find HIP. +# The user may override AMDGPU_TARGETS defined in the HIP config file +# to select the AMDGPU archs to compile for. +# ex. set(AMDGPU_TARGETS "gfx803;gfx900;gfx906") +# Find OpenMP. +#find_package(OpenMP REQUIRED) +# Set compiler and linker. 
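For orientation, a minimal sketch of how the artifact built by this configuration is consumed: the `BNB_OUTPUT_NAME`/`LIBRARY_OUTPUT_DIRECTORY` settings below place a `libbitsandbytes_hip_nohipblaslt.so` inside the `bitsandbytes/` package directory, which is the binary name the Python side expects for ROCm builds. The loader below is illustrative only and makes assumptions about paths; the real loading happens via ctypes in `bitsandbytes/cextension.py`.

```python
# Illustrative only: check that the ROCm build produced the library the Python
# package expects, and that it can be loaded. Paths here are assumptions.
import ctypes
from pathlib import Path

lib_path = Path("bitsandbytes") / "libbitsandbytes_hip_nohipblaslt.so"

if lib_path.exists():
    lib = ctypes.CDLL(str(lib_path))  # same mechanism the package uses via ctypes
    print(f"loaded {lib_path}")
else:
    print(f"{lib_path} missing - run `cmake -B build . && cmake --build build` first")
```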
+if(NOT WIN32) + set(CMAKE_CXX_COMPILER ${HIP_HIPCC_EXECUTABLE}) + set(CMAKE_CXX_LINKER ${HIP_HIPCC_EXECUTABLE}) + set(CMAKE_CXXFLAGS -D__HIP_PLATFORM_AMD__) + set(CMAKE_CFLAGS -D__HIP_PLATFORM_AMD__) +endif() +message("Current CMAKE_CXX_COMPILER (should show hipcc): " ${CMAKE_CXX_COMPILER}) +message("Current CMAKE_CXX_LINKER (should show hipcc): " ${CMAKE_CXX_LINKER}) + +set(BNB_OUTPUT_NAME "bitsandbytes") + +message(STATUS "Configuring ${PROJECT_NAME} (Backend: ${COMPUTE_BACKEND})") + +if(${COMPUTE_BACKEND} STREQUAL "hip") + set(BUILD_HIP on) + set(BUILD_CUDA OFF) + set(BUILD_MPS OFF) + set(NO_CUBLASLT ON) +else() + set(BUILD_CUDA OFF) + set(BUILD_MPS OFF) +endif() + + +if(BUILD_HIP) + list(APPEND SRC_FILES ${CUDA_FILES}) + # real name + string(APPEND BNB_OUTPUT_NAME "_hip_nohipblaslt") + add_compile_definitions(BUILD_HIP) +else() + string(APPEND BNB_OUTPUT_NAME "_cpu") + set(GPU_SOURCES) +endif() + +if (BUILD_HIP) + set_source_files_properties(${CPP_FILES} PROPERTIES LANGUAGE CXX) + message("Working on: " ${CPP_FILES}) + add_library(bitsandbytes SHARED ${SRC_FILES}) + target_include_directories(bitsandbytes PRIVATE ${CMAKE_SOURCE_DIR} ${CMAKE_SOURCE_DIR}/include /opt/rocm/include/rocwmma) + target_compile_features(bitsandbytes PUBLIC cxx_std_14) + target_compile_definitions(bitsandbytes PUBLIC NO_CUBLASLT) + target_include_directories(bitsandbytes PUBLIC csrc include) + target_link_libraries(bitsandbytes PUBLIC hip::device roc::rocthrust roc::hipblas roc::hipsparse roc::rocrand roc::rocprim roc::hipblaslt ) +else() + set_source_files_properties(${CPP_FILES} PROPERTIES LANGUAGE CXX) + add_library(bitsandbytes SHARED ${SRC_FILES}) + target_compile_features(bitsandbytes PUBLIC cxx_std_14) + target_include_directories(bitsandbytes PUBLIC csrc include) + target_link_libraries(bitsandbytes PUBLIC hip::device) +endif() + +set_target_properties(bitsandbytes PROPERTIES OUTPUT_NAME ${BNB_OUTPUT_NAME}) +set_target_properties(bitsandbytes PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${PROJECT_SOURCE_DIR}/bitsandbytes") \ No newline at end of file diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 01d5527f5..3b83a8d6d 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -3,14 +3,14 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . import cuda_setup, utils, research +from . 
import cuda_setup, research, utils from .autograd._functions import ( MatmulLtState, bmm_cublas, matmul, + matmul_4bit, matmul_cublas, mm_cublas, - matmul_4bit ) from .cextension import COMPILED_WITH_CUDA from .nn import modules @@ -24,6 +24,6 @@ "optim.optimizer.MockArgs": False, } -__version__ = "0.42.0" +__version__ = "0.43.0" PACKAGE_GITHUB_URL = "https://github.com/TimDettmers/bitsandbytes" diff --git a/bitsandbytes/__main__.py b/bitsandbytes/__main__.py index ebbf2653e..61b42e78f 100644 --- a/bitsandbytes/__main__.py +++ b/bitsandbytes/__main__.py @@ -1,44 +1,16 @@ +import glob import os import sys -import shlex -import subprocess - from warnings import warn -from typing import Tuple -from os.path import isdir import torch HEADER_WIDTH = 60 -def execute_and_return(command_string: str) -> Tuple[str, str]: - def _decode(subprocess_err_out_tuple): - return tuple( - to_decode.decode("UTF-8").strip() - for to_decode in subprocess_err_out_tuple - ) - - def execute_and_return_decoded_std_streams(command_string): - return _decode( - subprocess.Popen( - shlex.split(command_string), - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ).communicate() - ) - - std_out, std_err = execute_and_return_decoded_std_streams(command_string) - return std_out, std_err - -def find_file_recursive(folder, filename): - folder = shlex.quote(folder) - filename = shlex.quote(filename) - cmd = f'find {folder} -name {filename}' - out, err = execute_and_return(cmd) - if len(err) > 0: - raise RuntimeError('Something when wrong when trying to find file. Maybe you do not have a linux system?') - return out +def find_dynamic_library(folder, filename): + for ext in ("so", "dll", "dylib"): + yield from glob.glob(os.path.join(folder, "**", filename + ext)) def generate_bug_report_information(): @@ -47,38 +19,25 @@ def generate_bug_report_information(): print_header("") print('') - if 'CONDA_PREFIX' in os.environ: - paths = find_file_recursive(os.environ['CONDA_PREFIX'], '*cuda*so') - print_header("ANACONDA CUDA PATHS") - print(paths) - print('') - if isdir('/usr/local/'): - paths = find_file_recursive('/usr/local', '*cuda*so') - print_header("/usr/local CUDA PATHS") - print(paths) - print('') - - if isdir(os.getcwd()): - paths = find_file_recursive(os.getcwd(), '*cuda*so') - print_header("WORKING DIRECTORY CUDA PATHS") - print(paths) - print('') - - print_header("LD_LIBRARY CUDA PATHS") - if 'LD_LIBRARY_PATH' in os.environ: - lib_path = os.environ['LD_LIBRARY_PATH'].strip() - for path in set(lib_path.split(':')): - try: - if isdir(path): - print_header(f"{path} CUDA PATHS") - paths = find_file_recursive(path, '*cuda*so') - print(paths) - except: - print(f'Could not read LD_LIBRARY_PATH: {path}') - print('') - - - + path_sources = [ + ("ANACONDA CUDA PATHS", os.environ.get("CONDA_PREFIX")), + ("/usr/local CUDA PATHS", "/usr/local"), + ("CUDA PATHS", os.environ.get("CUDA_PATH")), + ("WORKING DIRECTORY CUDA PATHS", os.getcwd()), + ] + try: + ld_library_path = os.environ.get("LD_LIBRARY_PATH") + if ld_library_path: + for path in set(ld_library_path.strip().split(os.pathsep)): + path_sources.append((f"LD_LIBRARY_PATH {path} CUDA PATHS", path)) + except Exception as e: + print(f"Could not parse LD_LIBRARY_PATH: {e}") + + for name, path in path_sources: + if path and os.path.isdir(path): + print_header(name) + print(list(find_dynamic_library(path, '*cuda*'))) + print("") def print_header( @@ -89,67 +48,61 @@ def print_header( def print_debug_info() -> None: + from . 
import PACKAGE_GITHUB_URL print( "\nAbove we output some debug information. Please provide this info when " f"creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose ...\n" ) -generate_bug_report_information() +def main(): + generate_bug_report_information() + from . import COMPILED_WITH_CUDA + from .cuda_setup.main import get_compute_capabilities -from . import COMPILED_WITH_CUDA, PACKAGE_GITHUB_URL -from .cuda_setup.env_vars import to_be_ignored -from .cuda_setup.main import get_compute_capabilities - + print_header("OTHER") + print(f"COMPILED_WITH_CUDA = {COMPILED_WITH_CUDA}") + print(f"COMPUTE_CAPABILITIES_PER_GPU = {get_compute_capabilities()}") + print_header("") + print_header("DEBUG INFO END") + print_header("") + print("Checking that the library is importable and CUDA is callable...") + print("\nWARNING: Please be sure to sanitize sensitive info from any such env vars!\n") -print_header("OTHER") -print(f"COMPILED_WITH_CUDA = {COMPILED_WITH_CUDA}") -print(f"COMPUTE_CAPABILITIES_PER_GPU = {get_compute_capabilities()}") -print_header("") -print_header("DEBUG INFO END") -print_header("") -print( - """ -Running a quick check that: - + library is importable - + CUDA function is callable -""" -) -print("\nWARNING: Please be sure to sanitize sensible info from any such env vars!\n") + try: + from bitsandbytes.optim import Adam -try: - from bitsandbytes.optim import Adam + p = torch.nn.Parameter(torch.rand(10, 10).cuda()) + a = torch.rand(10, 10).cuda() - p = torch.nn.Parameter(torch.rand(10, 10).cuda()) - a = torch.rand(10, 10).cuda() + p1 = p.data.sum().item() - p1 = p.data.sum().item() + adam = Adam([p]) - adam = Adam([p]) + out = a * p + loss = out.sum() + loss.backward() + adam.step() - out = a * p - loss = out.sum() - loss.backward() - adam.step() + p2 = p.data.sum().item() - p2 = p.data.sum().item() + assert p1 != p2 + print("SUCCESS!") + print("Installation was successful!") + except ImportError: + print() + warn( + f"WARNING: {__package__} is currently running as CPU-only!\n" + "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n" + f"If you think that this is so erroneously,\nplease report an issue!" + ) + print_debug_info() + except Exception as e: + print(e) + print_debug_info() + sys.exit(1) - assert p1 != p2 - print("SUCCESS!") - print("Installation was successful!") - sys.exit(0) -except ImportError: - print() - warn( - f"WARNING: {__package__} is currently running as CPU-only!\n" - "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n" - f"If you think that this is so erroneously,\nplease report an issue!" 
- ) - print_debug_info() - sys.exit(0) -except Exception as e: - print(e) - print_debug_info() - sys.exit(1) +if __name__ == "__main__": + main() diff --git a/bitsandbytes/autograd/__init__.py b/bitsandbytes/autograd/__init__.py index 6b9a7e4d1..f262d89ed 100644 --- a/bitsandbytes/autograd/__init__.py +++ b/bitsandbytes/autograd/__init__.py @@ -1 +1 @@ -from ._functions import undo_layout, get_inverse_transform_indices +from ._functions import get_inverse_transform_indices, undo_layout diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 59b0ac7b2..6709c15a4 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -1,8 +1,8 @@ -import operator -import warnings from dataclasses import dataclass from functools import reduce # Required in Python 3 -from typing import Tuple, Optional, List +import operator +from typing import Optional, Tuple +import warnings from warnings import warn import torch @@ -14,9 +14,6 @@ def prod(iterable): return reduce(operator.mul, iterable, 1) -tensor = torch.Tensor - - # The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov: # https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py @@ -498,7 +495,7 @@ class MatMul4Bit(torch.autograd.Function): # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") @staticmethod - def forward(ctx, A, B, out=None, bias=None, quant_state: F.QuantState = None): + def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState] = None): # default of pytorch behavior if inputs are empty ctx.is_empty = False if prod(A.shape) == 0: @@ -551,10 +548,10 @@ def backward(ctx, grad_output): def matmul( - A: tensor, - B: tensor, - out: tensor = None, - state: MatmulLtState = None, + A: torch.Tensor, + B: torch.Tensor, + out: Optional[torch.Tensor] = None, + state: Optional[MatmulLtState] = None, threshold=0.0, bias=None ): @@ -564,7 +561,7 @@ def matmul( return MatMul8bitLt.apply(A, B, out, bias, state) -def matmul_4bit(A: tensor, B: tensor, quant_state: F.QuantState, out: tensor = None, bias=None): +def matmul_4bit(A: torch.Tensor, B: torch.Tensor, quant_state: F.QuantState, out: Optional[torch.Tensor] = None, bias=None): assert quant_state is not None if A.numel() == A.shape[-1] and A.requires_grad == False: if A.shape[-1] % quant_state.blocksize != 0: diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 03a208995..f42585280 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -1,12 +1,9 @@ import ctypes as ct -import os -import torch - -from pathlib import Path from warnings import warn -from bitsandbytes.cuda_setup.main import CUDASetup +import torch +from bitsandbytes.cuda_setup.main import CUDASetup setup = CUDASetup.get_instance() if setup.initialized != True: @@ -14,7 +11,7 @@ lib = setup.lib try: - if lib is None and torch.cuda.is_available() : + if lib is None and torch.cuda.is_available(): CUDASetup.get_instance().generate_instructions() CUDASetup.get_instance().print_log_stack() raise RuntimeError(''' @@ -25,7 +22,6 @@ Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them to your LD_LIBRARY_PATH. 
If you suspect a bug, please take the information from python -m bitsandbytes and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues''') - lib.cadam32bit_grad_fp32 # runs on an error if the library could not be found -> COMPILED_WITH_CUDA=False lib.get_context.restype = ct.c_void_p diff --git a/bitsandbytes/consts.py b/bitsandbytes/consts.py new file mode 100644 index 000000000..8242d104e --- /dev/null +++ b/bitsandbytes/consts.py @@ -0,0 +1,12 @@ +from pathlib import Path +import platform + +DYNAMIC_LIBRARY_SUFFIX = { + "Darwin": ".dylib", + "Linux": ".so", + "Windows": ".dll", +}.get(platform.system(), ".so") + +PACKAGE_DIR = Path(__file__).parent +PACKAGE_GITHUB_URL = "https://github.com/TimDettmers/bitsandbytes" +NONPYTORCH_DOC_URL = "https://github.com/TimDettmers/bitsandbytes/blob/main/docs/source/nonpytorchcuda.mdx" diff --git a/bitsandbytes/cuda_setup/env_vars.py b/bitsandbytes/cuda_setup/env_vars.py index e8268fcaa..4b2549653 100644 --- a/bitsandbytes/cuda_setup/env_vars.py +++ b/bitsandbytes/cuda_setup/env_vars.py @@ -26,7 +26,7 @@ def to_be_ignored(env_var: str, value: str) -> bool: def might_contain_a_path(candidate: str) -> bool: - return "/" in candidate + return os.sep in candidate def is_active_conda_env(env_var: str) -> bool: diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/cuda_setup/main.py index b4962c1a0..dfc10c3dc 100644 --- a/bitsandbytes/cuda_setup/main.py +++ b/bitsandbytes/cuda_setup/main.py @@ -17,25 +17,27 @@ """ import ctypes as ct -import os import errno -import torch -from warnings import warn -from itertools import product - +import os from pathlib import Path +import platform from typing import Set, Union +from warnings import warn + +import torch + from .env_vars import get_potentially_lib_path_containing_env_vars -# these are the most common libs names -# libcudart.so is missing by default for a conda install with PyTorch 2.0 and instead -# we have libcudart.so.11.0 which causes a lot of errors before -# not sure if libcudart.so.12.0 exists in pytorch installs, but it does not hurt -CUDA_RUNTIME_LIBS: list = ["libcudart.so", 'libcudart.so.11.0', 'libcudart.so.12.0', 'libcudart.so.12.1', 'libcudart.so.12.2'] +DYNAMIC_LIBRARY_SUFFIX = { "Darwin": ".dylib", "Windows": ".dll", "Linux": ".so"}.get(platform.system(), ".so") +if platform.system() == "Windows": # Windows + CUDA_RUNTIME_LIBS = ["cudart64_110.dll", "cudart64_12.dll"] +else: # Linux or other + # these are the most common libs names + # libcudart.so is missing by default for a conda install with PyTorch 2.0 and instead + # we have libcudart.so.11.0 which causes a lot of errors before + # not sure if libcudart.so.12.0 exists in pytorch installs, but it does not hurt + CUDA_RUNTIME_LIBS = ["libcudart.so", "libcudart.so.11.0", "libcudart.so.12.0", "libcudart.so.12.1", "libcudart.so.12.2"] -# this is a order list of backup paths to search CUDA in, if it cannot be found in the main environmental paths -backup_paths = [] -backup_paths.append('$CONDA_PREFIX/lib/libcudart.so.11.0') class CUDASetup: _instance = None @@ -103,18 +105,30 @@ def initialize(self): self.error = False def manual_override(self): - if torch.cuda.is_available(): - if 'BNB_CUDA_VERSION' in os.environ: - if len(os.environ['BNB_CUDA_VERSION']) > 0: - warn((f'\n\n{"="*80}\n' - 'WARNING: Manual override via BNB_CUDA_VERSION env variable detected!\n' - 'BNB_CUDA_VERSION=XXX can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n' - 'If this was unintended set the 
BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n' - 'If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH\n' - 'For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH: Set[Path]: - return {Path(ld_path) for ld_path in paths_list_candidate.split(":") if ld_path} + return {Path(ld_path) for ld_path in paths_list_candidate.split(os.pathsep) if ld_path} def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]: @@ -201,8 +213,8 @@ def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]: try: if path.exists(): existent_directories.add(path) - except PermissionError as pex: - # Handle the PermissionError first as it is a subtype of OSError + except PermissionError: + # Handle the PermissionError first as it is a subtype of OSError # https://docs.python.org/3/library/exceptions.html#exception-hierarchy pass except OSError as exc: @@ -211,8 +223,10 @@ def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]: non_existent_directories: Set[Path] = candidate_paths - existent_directories if non_existent_directories: - CUDASetup.get_instance().add_log_entry("The following directories listed in your path were found to " - f"be non-existent: {non_existent_directories}", is_warning=False) + CUDASetup.get_instance().add_log_entry( + f"The following directories listed in your path were found to be non-existent: {non_existent_directories}", + is_warning=False, + ) return existent_directories @@ -248,13 +262,13 @@ def warn_in_case_of_duplicates(results_paths: Set[Path]) -> None: warning_msg = ( f"Found duplicate {CUDA_RUNTIME_LIBS} files: {results_paths}.. " "We select the PyTorch default libcudart.so, which is {torch.version.cuda}," - "but this might missmatch with the CUDA version that is needed for bitsandbytes." + "but this might mismatch with the CUDA version that is needed for bitsandbytes." "To override this behavior set the BNB_CUDA_VERSION= environmental variable" "For example, if you want to use the CUDA version 122" "BNB_CUDA_VERSION=122 python ..." "OR set the environmental variable in your .bashrc: export BNB_CUDA_VERSION=122" "In the case of a manual override, make sure you set the LD_LIBRARY_PATH, e.g." 
- "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.2") + "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12.2") CUDASetup.get_instance().add_log_entry(warning_msg, is_warning=True) @@ -348,14 +362,14 @@ def evaluate_cuda_setup(): if torch.version.hip: return 'libbitsandbytes_hip_nohipblaslt.so', None, None, None cudart_path = determine_cuda_runtime_lib_path() - ccs = get_compute_capabilities() - ccs.sort() - cc = ccs[-1] # we take the highest capability + cc = get_compute_capabilities()[-1] # we take the highest capability cuda_version_string = get_cuda_version() cuda_setup.add_log_entry(f"CUDA SETUP: PyTorch settings found: CUDA_VERSION={cuda_version_string}, Highest Compute Capability: {cc}.") - cuda_setup.add_log_entry(f"CUDA SETUP: To manually override the PyTorch CUDA version please see:" - "https://github.com/TimDettmers/bitsandbytes/blob/main/how_to_use_nonpytorch_cuda.md") + cuda_setup.add_log_entry( + "CUDA SETUP: To manually override the PyTorch CUDA version please see:" + "https://github.com/TimDettmers/bitsandbytes/blob/main/how_to_use_nonpytorch_cuda.md" + ) # 7.5 is the minimum CC vor cublaslt @@ -368,10 +382,11 @@ def evaluate_cuda_setup(): # we use ls -l instead of nvcc to determine the cuda version # since most installations will have the libcudart.so installed, but not the compiler - if has_cublaslt: - binary_name = f"libbitsandbytes_cuda{cuda_version_string}.so" - else: - "if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so" - binary_name = f"libbitsandbytes_cuda{cuda_version_string}_nocublaslt.so" + binary_name = f"libbitsandbytes_cuda{cuda_version_string}" + if not has_cublaslt: + # if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt + binary_name += "_nocublaslt" + + binary_name = f"{binary_name}{DYNAMIC_LIBRARY_SUFFIX}" return binary_name, cudart_path, cc, cuda_version_string diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py new file mode 100644 index 000000000..ed19795a0 --- /dev/null +++ b/bitsandbytes/cuda_specs.py @@ -0,0 +1,41 @@ +import dataclasses +from typing import List, Optional, Tuple + +import torch + + +@dataclasses.dataclass(frozen=True) +class CUDASpecs: + highest_compute_capability: Tuple[int, int] + cuda_version_string: str + cuda_version_tuple: Tuple[int, int] + + @property + def has_cublaslt(self) -> bool: + return self.highest_compute_capability >= (7, 5) + + +def get_compute_capabilities() -> List[Tuple[int, int]]: + return sorted(torch.cuda.get_device_capability(torch.cuda.device(i)) for i in range(torch.cuda.device_count())) + + +def get_cuda_version_tuple() -> Tuple[int, int]: + # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION + major, minor = map(int, torch.version.cuda.split(".")) + return major, minor + + +def get_cuda_version_string() -> str: + major, minor = get_cuda_version_tuple() + return f"{major}{minor}" + + +def get_cuda_specs() -> Optional[CUDASpecs]: + if not torch.cuda.is_available(): + return None + + return CUDASpecs( + highest_compute_capability=(get_compute_capabilities()[-1]), + cuda_version_string=(get_cuda_version_string()), + cuda_version_tuple=get_cuda_version_tuple(), + ) diff --git a/bitsandbytes/diagnostics/__init__.py b/bitsandbytes/diagnostics/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bitsandbytes/diagnostics/cuda.py b/bitsandbytes/diagnostics/cuda.py new file mode 100644 index 000000000..f993dff7e --- /dev/null +++ b/bitsandbytes/diagnostics/cuda.py @@ -0,0 
+1,176 @@ +import logging +import os +from pathlib import Path +from typing import Dict, Iterable, Iterator + +import torch + +from bitsandbytes.cextension import get_cuda_bnb_library_path +from bitsandbytes.consts import NONPYTORCH_DOC_URL +from bitsandbytes.cuda_specs import CUDASpecs +from bitsandbytes.diagnostics.utils import print_dedented + +CUDART_PATH_PREFERRED_ENVVARS = ("CONDA_PREFIX", "LD_LIBRARY_PATH") + +CUDART_PATH_IGNORED_ENVVARS = { + "DBUS_SESSION_BUS_ADDRESS", # hardware related + "GOOGLE_VM_CONFIG_LOCK_FILE", # GCP: requires elevated permissions, causing problems in VMs and Jupyter notebooks + "HOME", # Linux shell default + "LESSCLOSE", + "LESSOPEN", # related to the `less` command + "MAIL", # something related to emails + "OLDPWD", + "PATH", # this is for finding binaries, not libraries + "PWD", # PWD: this is how the shell keeps track of the current working dir + "SHELL", # binary for currently invoked shell + "SSH_AUTH_SOCK", # SSH stuff, therefore unrelated + "SSH_TTY", + "TMUX", # Terminal Multiplexer + "XDG_DATA_DIRS", # XDG: Desktop environment stuff + "XDG_GREETER_DATA_DIR", # XDG: Desktop environment stuff + "XDG_RUNTIME_DIR", + "_", # current Python interpreter +} + +CUDA_RUNTIME_LIB_PATTERNS = ( + "cudart64*.dll", # Windows + "libcudart*.so*", # libcudart.so, libcudart.so.11.0, libcudart.so.12.0, libcudart.so.12.1, libcudart.so.12.2 etc. + "nvcuda*.dll", # Windows +) + +logger = logging.getLogger(__name__) + + +def find_cuda_libraries_in_path_list(paths_list_candidate: str) -> Iterable[Path]: + for dir_string in paths_list_candidate.split(os.pathsep): + if not dir_string: + continue + if os.sep not in dir_string: + continue + try: + dir = Path(dir_string) + try: + if not dir.exists(): + logger.warning(f"The directory listed in your path is found to be non-existent: {dir}") + continue + except OSError: # Assume an esoteric error trying to poke at the directory + pass + for lib_pattern in CUDA_RUNTIME_LIB_PATTERNS: + for pth in dir.glob(lib_pattern): + if pth.is_file(): + yield pth + except PermissionError: + pass + + +def is_relevant_candidate_env_var(env_var: str, value: str) -> bool: + return ( + env_var in CUDART_PATH_PREFERRED_ENVVARS # is a preferred location + or ( + os.sep in value # might contain a path + and env_var not in CUDART_PATH_IGNORED_ENVVARS # not ignored + and "CONDA" not in env_var # not another conda envvar + and "BASH_FUNC" not in env_var # not a bash function defined via envvar + and "\n" not in value # likely e.g. a script or something? + ) + ) + + +def get_potentially_lib_path_containing_env_vars() -> Dict[str, str]: + return {env_var: value for env_var, value in os.environ.items() if is_relevant_candidate_env_var(env_var, value)} + + +def find_cudart_libraries() -> Iterator[Path]: + """ + Searches for a cuda installations, in the following order of priority: + 1. active conda env + 2. LD_LIBRARY_PATH + 3. any other env vars, while ignoring those that + - are known to be unrelated + - don't contain the path separator `/` + + If multiple libraries are found in part 3, we optimistically try one, + while giving a warning message. 
+ """ + candidate_env_vars = get_potentially_lib_path_containing_env_vars() + + for envvar in CUDART_PATH_PREFERRED_ENVVARS: + if envvar in candidate_env_vars: + directory = candidate_env_vars[envvar] + yield from find_cuda_libraries_in_path_list(directory) + candidate_env_vars.pop(envvar) + + for env_var, value in candidate_env_vars.items(): + yield from find_cuda_libraries_in_path_list(value) + + +def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None: + print( + f"PyTorch settings found: CUDA_VERSION={cuda_specs.cuda_version_string}, " + f"Highest Compute Capability: {cuda_specs.highest_compute_capability}.", + ) + + binary_path = get_cuda_bnb_library_path(cuda_specs) + if not binary_path.exists(): + print_dedented( + f""" + Library not found: {binary_path}. Maybe you need to compile it from source? + If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION`, + for example, `make CUDA_VERSION=113`. + + The CUDA version for the compile might depend on your conda install, if using conda. + Inspect CUDA version via `conda list | grep cuda`. + """, + ) + + cuda_major, cuda_minor = cuda_specs.cuda_version_tuple + if cuda_major < 11: + print_dedented( + """ + WARNING: CUDA versions lower than 11 are currently not supported for LLM.int8(). + You will be only to use 8-bit optimizers and quantization routines! + """, + ) + + print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}") + + # 7.5 is the minimum CC for cublaslt + if not cuda_specs.has_cublaslt: + print_dedented( + """ + WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU! + If you run into issues with 8-bit matmul, you can try 4-bit quantization: + https://huggingface.co/blog/4bit-transformers-bitsandbytes + """, + ) + + # TODO: + # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible) + # (2) Multiple CUDA versions installed + + +def print_cuda_runtime_diagnostics() -> None: + cudart_paths = list(find_cudart_libraries()) + if not cudart_paths: + print("CUDA SETUP: WARNING! CUDA runtime files not found in any environmental path.") + elif len(cudart_paths) > 1: + print_dedented( + f""" + Found duplicate CUDA runtime files (see below). + + We select the PyTorch default CUDA runtime, which is {torch.version.cuda}, + but this might mismatch with the CUDA version that is needed for bitsandbytes. + To override this behavior set the `BNB_CUDA_VERSION=` environmental variable. + + For example, if you want to use the CUDA version 122, + BNB_CUDA_VERSION=122 python ... + + OR set the environmental variable in your .bashrc: + export BNB_CUDA_VERSION=122 + + In the case of a manual override, make sure you set LD_LIBRARY_PATH, e.g. 
+ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.2, + """, + ) + for pth in cudart_paths: + print(f"* Found CUDA runtime at: {pth}") diff --git a/bitsandbytes/diagnostics/main.py b/bitsandbytes/diagnostics/main.py new file mode 100644 index 000000000..1ce096f69 --- /dev/null +++ b/bitsandbytes/diagnostics/main.py @@ -0,0 +1,85 @@ +import sys +import traceback + +import torch + +from bitsandbytes.consts import PACKAGE_GITHUB_URL +from bitsandbytes.cuda_specs import get_cuda_specs +from bitsandbytes.diagnostics.cuda import ( + print_cuda_diagnostics, + print_cuda_runtime_diagnostics, +) +from bitsandbytes.diagnostics.utils import print_dedented, print_header + + +def sanity_check(): + from bitsandbytes.cextension import lib + + if lib is None: + print_dedented( + """ + Couldn't load the bitsandbytes library, likely due to missing binaries. + Please ensure bitsandbytes is properly installed. + + For source installations, compile the binaries with `cmake -DCOMPUTE_BACKEND=cuda -S .`. + See the documentation for more details if needed. + + Trying a simple check anyway, but this will likely fail... + """, + ) + + from bitsandbytes.optim import Adam + + p = torch.nn.Parameter(torch.rand(10, 10).cuda()) + a = torch.rand(10, 10).cuda() + p1 = p.data.sum().item() + adam = Adam([p]) + out = a * p + loss = out.sum() + loss.backward() + adam.step() + p2 = p.data.sum().item() + assert p1 != p2 + + +def main(): + print_header("") + print_header("BUG REPORT INFORMATION") + print_header("") + + print_header("OTHER") + cuda_specs = get_cuda_specs() + print("CUDA specs:", cuda_specs) + if not torch.cuda.is_available(): + print("Torch says CUDA is not available. Possible reasons:") + print("1. CUDA driver not installed") + print("2. CUDA not installed") + print("3. You have multiple conflicting CUDA libraries") + if cuda_specs: + print_cuda_diagnostics(cuda_specs) + print_cuda_runtime_diagnostics() + print_header("") + print_header("DEBUG INFO END") + print_header("") + print("Checking that the library is importable and CUDA is callable...") + try: + sanity_check() + print("SUCCESS!") + print("Installation was successful!") + return + except ImportError: + print( + f"WARNING: {__package__} is currently running as CPU-only!\n" + "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n" + f"If you think that this is so erroneously,\nplease report an issue!", + ) + except Exception: + traceback.print_exc() + print_dedented( + f""" + Above we output some debug information. + Please provide this info when creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose + WARNING: Please be sure to sanitize sensitive info from the output before posting it. + """, + ) + sys.exit(1) diff --git a/bitsandbytes/diagnostics/utils.py b/bitsandbytes/diagnostics/utils.py new file mode 100644 index 000000000..770209b9d --- /dev/null +++ b/bitsandbytes/diagnostics/utils.py @@ -0,0 +1,12 @@ +import textwrap + +HEADER_WIDTH = 60 + + +def print_header(txt: str, width: int = HEADER_WIDTH, filler: str = "+") -> None: + txt = f" {txt} " if txt else "" + print(txt.center(width, filler)) + + +def print_dedented(text): + print("\n".join(textwrap.dedent(text).strip().split("\n"))) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index a4e93bf37..6ef7c3261 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -3,20 +3,18 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
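To make the new diagnostics helpers concrete, a small usage sketch of `print_header` and `print_dedented` from `bitsandbytes/diagnostics/utils.py` above (the example strings are arbitrary):

```python
# Usage sketch for the helpers defined in bitsandbytes/diagnostics/utils.py.
import textwrap

HEADER_WIDTH = 60

def print_header(txt: str, width: int = HEADER_WIDTH, filler: str = "+") -> None:
    txt = f" {txt} " if txt else ""
    print(txt.center(width, filler))

def print_dedented(text):
    print("\n".join(textwrap.dedent(text).strip().split("\n")))

print_header("DEBUG INFO END")  # ' DEBUG INFO END ' centered in a 60-char row of '+'
print_header("")                # a solid 60-char row of '+'
print_dedented(
    """
    Triple-quoted text that is indented in the source
    is printed flush left, one line per line.
    """
)
```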
import ctypes as ct +from functools import reduce # Required in Python 3 import itertools import operator -import random -import torch -import itertools -import math -import numpy as np +from typing import Any, Dict, Optional, Tuple -from functools import reduce # Required in Python 3 -from typing import Tuple, Any, Dict +import numpy as np +import torch from torch import Tensor + from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict -from .cextension import COMPILED_WITH_CUDA, lib, HIP_ENVIRONMENT +from .cextension import COMPILED_WITH_CUDA, HIP_ENVIRONMENT, lib # Remark: for AMD GPU we need to disable blocksize == 64 @@ -270,7 +268,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8) # the exponent is biased to 2^(e-1) -1 == 0 evalues = [] pvalues = [] - for i, val in enumerate(range(-((2**(exponent_bits-has_sign))), 2**(exponent_bits-has_sign), 1)): + for i, val in enumerate(range(-(2**(exponent_bits-has_sign)), 2**(exponent_bits-has_sign), 1)): evalues.append(2**val) @@ -332,7 +330,7 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): non_sign_bits = total_bits - (1 if signed else 1) additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1 for i in range(max_exponent_bits): - fraction_items = int((2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1)) + fraction_items = int(2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1) boundaries = torch.linspace(0.1, 1, fraction_items) means = (boundaries[:-1] + boundaries[1:]) / 2.0 data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() @@ -452,13 +450,13 @@ def get_transform_buffer( rows = shape[0] * shape[1] cols = shape[-1] + state = (shape, to_order) if transpose: # swap dims tmp = rows rows = cols cols = tmp - shape = shape[::-1] - state = (shape, to_order) + state = (shape[::-1], to_order) if to_order == "row" or to_order == "col": return init_func(shape, dtype=dtype, device=device), state @@ -499,7 +497,7 @@ def nvidia_transform( from_order = state[1] if out is None: out, new_state = get_transform_buffer( - state[0], A.dtype, A.device, to_order, state[1], transpose + state[0], A.dtype, A.device, to_order, state[1] ) else: new_state = (state[1], to_order) @@ -524,7 +522,7 @@ def nvidia_transform( return out, new_state -def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor: +def estimate_quantiles(A: Tensor, out: Optional[torch.Tensor] = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor: ''' Estimates 256 equidistant quantiles on the input tensor eCDF. 
@@ -629,8 +627,8 @@ def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState # unpacking minor and non-tensor quant state items if necessary if len(qs_key) == 1: - qs_key = qs_key[0] - qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(qs_key))) + first_qs_key = qs_key[0] + qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(first_qs_key))) qs_dict = {k.split('.')[-1]: v for k, v in qs_dict.items()} # strip prefixes assert set(qs_dict.keys()).issubset(cls.valid_qs_keys) @@ -696,8 +694,30 @@ def to(self, device): self.state2.absmax = self.state2.absmax.to(device) self.state2.code = self.state2.code.to(device) + def __eq__(self, other): + if not isinstance(other, QuantState): + return False + + return ( + torch.allclose(self.absmax, other.absmax, atol=1e-6) and + self.shape == other.shape and + torch.allclose(self.code, other.code, atol=1e-6) and + self.dtype == other.dtype and + self.blocksize == other.blocksize and + self.quant_type == other.quant_type and + (self.offset == other.offset if self.offset is not None and other.offset is not None else self.offset is other.offset) and + (self.state2 == other.state2 if self.state2 is not None and other.state2 is not None else self.state2 is other.state2) + ) + -def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor: +def quantize_blockwise( + A: Tensor, + code: Optional[torch.Tensor] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=4096, + nested=False, +) -> Tuple[Tensor, QuantState]: """ Quantize tensor A in blocks of size 4096 values. @@ -775,10 +795,10 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ou def dequantize_blockwise( A: Tensor, - quant_state: QuantState = None, - absmax: Tensor = None, - code: Tensor = None, - out: Tensor = None, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + code: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, blocksize: int = 4096, nested=False ) -> Tensor: @@ -815,7 +835,7 @@ def dequantize_blockwise( if quant_state is None: quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32) - + absmax = quant_state.absmax if quant_state.nested: absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) @@ -888,7 +908,7 @@ def get_4bit_type(typename, device=None, blocksize=64): -0.04934812, 0., 0.04273164, 0.12934483, 0.21961274, 0.31675666, 0.42563882, 0.55496234, 0.72424863, 1.][::-1] else: - raise NotImplementedError(f'4-bit AbnormalFloats currently only support blocksize 64.') + raise NotImplementedError('4-bit AbnormalFloats currently only support blocksize 64.') if data is None: raise NotImplementedError(f'Typename {typename} not supported') @@ -900,17 +920,26 @@ def get_4bit_type(typename, device=None, blocksize=64): return data.to(device) -def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=None, compress_statistics=False, quant_storage=torch.uint8): +def quantize_fp4(A: Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize=None, compress_statistics=False, quant_storage=torch.uint8): if blocksize is None: blocksize = 64 if not HIP_ENVIRONMENT else 128 return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4', quant_storage) -def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=None, compress_statistics=False, 
quant_storage=torch.uint8): +def quantize_nf4(A: Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize=None, compress_statistics=False, quant_storage=torch.uint8): if blocksize is None: blocksize = 64 if not HIP_ENVIRONMENT else 128 return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4', quant_storage) -def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=None, compress_statistics=False, quant_type='fp4', quant_storage=torch.uint8) -> Tensor: + +def quantize_4bit( + A: Tensor, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=None, + compress_statistics=False, + quant_type='fp4', + quant_storage=torch.uint8, +) -> Tuple[Tensor, QuantState]: """ Quantize tensor A in blocks of 4-bit values. @@ -996,19 +1025,19 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz return out, state -def dequantize_fp4(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = None) -> Tensor: +def dequantize_fp4(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = None) -> Tensor: if blocksize is None: blocksize = 64 if not HIP_ENVIRONMENT else 128 return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'fp4') -def dequantize_nf4(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = None) -> Tensor: +def dequantize_nf4(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = None) -> Tensor: if blocksize is None: blocksize = 64 if not HIP_ENVIRONMENT else 128 return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'nf4') -def dequantize_4bit(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = None, quant_type='fp4') -> Tensor: +def dequantize_4bit(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = None, quant_type='fp4') -> Tensor: """ Dequantizes FP4 blockwise quantized values. @@ -1092,7 +1121,11 @@ def dequantize_4bit(A: Tensor, quant_state: QuantState = None, absmax: Tensor = else: return out -def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor: +def quantize( + A: Tensor, + code: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, +) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: if code is None: if "dynamic" not in name2qmap: name2qmap["dynamic"] = create_dynamic_map().to(A.device) @@ -1108,10 +1141,10 @@ def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor: def dequantize( A: Tensor, - state: Tuple[Tensor, Tensor] = None, - absmax: Tensor = None, - code: Tensor = None, - out: Tensor = None, + state: Optional[Tuple[Tensor, Tensor]] = None, + absmax: Optional[torch.Tensor] = None, + code: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, ) -> Tensor: assert state is not None or absmax is not None if code is None and state is None: @@ -1126,7 +1159,7 @@ def dequantize( return out * state[0] -def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: +def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor: ''' Quantizes input tensor to 8-bit. 
@@ -1155,7 +1188,7 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: return out -def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: +def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor: ''' Dequantizes the 8-bit tensor to 32-bit. @@ -1193,11 +1226,11 @@ def optimizer_update_32bit( eps: float, step: int, lr: float, - state2: Tensor = None, + state2: Optional[torch.Tensor] = None, beta2: float = 0.0, weight_decay: float = 0.0, gnorm_scale: float = 1.0, - unorm_vec: Tensor = None, + unorm_vec: Optional[torch.Tensor] = None, max_unorm: float = 0.0, skip_zeros=False, ) -> None: @@ -1296,7 +1329,7 @@ def optimizer_update_8bit( new_max2: Tensor, weight_decay: float = 0.0, gnorm_scale: float = 1.0, - unorm_vec: Tensor = None, + unorm_vec: Optional[torch.Tensor] = None, max_unorm: float = 0.0, ) -> None: """ @@ -1625,7 +1658,7 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8 def gemv_4bit( A: Tensor, B: Tensor, - out: Tensor = None, + out: Optional[torch.Tensor] = None, transposed_A=False, transposed_B=False, state=None @@ -1633,10 +1666,10 @@ def gemv_4bit( prev_device = pre_call(A.device) #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) if state is None: - raise ValueError(f'state cannot None. gem_4bit( ) requires the state from quantize_4bit( )') + raise ValueError('state cannot None. gem_4bit( ) requires the state from quantize_4bit( )') if A.numel() != A.shape[-1]: - raise ValueError(f'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]') + raise ValueError('Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]') Bshape = state.shape bout = Bshape[0] @@ -1685,7 +1718,7 @@ def gemv_4bit( def igemm( A: Tensor, B: Tensor, - out: Tensor = None, + out: Optional[torch.Tensor] = None, transposed_A=False, transposed_B=False, ): @@ -1774,7 +1807,7 @@ def igemm( def batched_igemm( A: Tensor, B: Tensor, - out: Tensor = None, + out: Optional[torch.Tensor] = None, transposed_A=False, transposed_B=False, ): @@ -2554,10 +2587,7 @@ def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half): def extract_outliers(A, SA, idx): shapeA = SA[0] formatA = SA[1] - if not HIP_ENVIRONMENT: - assert formatA in ["col_turing", "col_ampere"] - else: - assert formatA in ["col"] + assert formatA in ["col_turing", "col_ampere"] assert A.device.type == "cuda" out = torch.zeros( diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py index 6fa6d1183..96f4359bf 100644 --- a/bitsandbytes/nn/__init__.py +++ b/bitsandbytes/nn/__init__.py @@ -2,5 +2,21 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
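As a usage sketch for the 4-bit functional API whose signatures were tightened above (`quantize_nf4`/`dequantize_nf4` default to blocksize 64, or 128 when `HIP_ENVIRONMENT` is set), assuming a CUDA or ROCm device is available and with arbitrary shapes:

```python
# Round-trip sketch for the blockwise 4-bit API in bitsandbytes.functional.
import torch
import bitsandbytes.functional as F

A = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")

# Returns the packed 4-bit tensor plus a QuantState (absmax, code, blocksize, ...).
qA, state = F.quantize_nf4(A)
A_deq = F.dequantize_nf4(qA, state)

assert A_deq.shape == A.shape and A_deq.dtype == A.dtype
print("max abs error:", (A - A_deq).abs().max().item())  # small, but not zero
```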
-from .modules import Int8Params, Linear8bitLt, StableEmbedding, Linear4bit, LinearNF4, LinearFP4, Params4bit, OutlierAwareLinear, SwitchBackLinearBnb, Embedding -from .triton_based_modules import SwitchBackLinear, SwitchBackLinearGlobal, SwitchBackLinearVectorwise, StandardLinear +from .modules import ( + Embedding, + Int8Params, + Linear4bit, + Linear8bitLt, + LinearFP4, + LinearNF4, + OutlierAwareLinear, + Params4bit, + StableEmbedding, + SwitchBackLinearBnb, +) +from .triton_based_modules import ( + StandardLinear, + SwitchBackLinear, + SwitchBackLinearGlobal, + SwitchBackLinearVectorwise, +) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index e0d94d861..6c0ee6ffb 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -2,23 +2,48 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import copy from typing import Any, Dict, Optional, TypeVar, Union, overload - import warnings + import torch -import torch.nn.functional as F from torch import Tensor, device, dtype, nn +import torch.nn.functional as F import bitsandbytes as bnb +from bitsandbytes.autograd._functions import get_tile_inds, undo_layout from bitsandbytes.functional import QuantState -from bitsandbytes.autograd._functions import undo_layout, get_tile_inds from bitsandbytes.optim import GlobalOptimManager -from bitsandbytes.utils import OutlierTracer, find_outlier_dims +from bitsandbytes.utils import OutlierTracer T = TypeVar("T", bound="torch.nn.Module") class StableEmbedding(torch.nn.Embedding): + """ + Custom embedding layer designed to improve stability during training for NLP tasks by using 32-bit optimizer states. It is designed to reduce gradient variations that can result from quantization. This embedding layer is initialized with Xavier uniform initialization followed by layer normalization. + + Example: + + ``` + # Initialize StableEmbedding layer with vocabulary size 1000, embedding dimension 300 + embedding_layer = StableEmbedding(num_embeddings=1000, embedding_dim=300) + + # Reset embedding parameters + embedding_layer.reset_parameters() + + # Perform a forward pass with input tensor + input_tensor = torch.tensor([1, 2, 3]) + output_embedding = embedding_layer(input_tensor) + ``` + + Attributes: + norm (`torch.nn.LayerNorm`): Layer normalization applied after the embedding. + + Methods: + reset_parameters(): Reset embedding parameters using Xavier uniform initialization. + forward(input: Tensor) -> Tensor: Forward pass through the stable embedding layer. + """ def __init__( self, num_embeddings: int, @@ -32,6 +57,25 @@ def __init__( device=None, dtype=None, ) -> None: + """ + Args: + num_embeddings (`int`): + The number of unique embeddings (vocabulary size). + embedding_dim (`int`): + The dimensionality of the embedding. + padding_idx (`Optional[int]`): + Pads the output with zeros at the given index. + max_norm (`Optional[float]`): + Renormalizes embeddings to have a maximum L2 norm. + norm_type (`float`, defaults to `2.0`): + The p-norm to compute for the `max_norm` option. + scale_grad_by_freq (`bool`, defaults to `False`): + Scale gradient by frequency during backpropagation. + sparse (`bool`, defaults to `False`): + Computes dense gradients. Set to `True` to compute sparse gradients instead. + _weight (`Optional[Tensor]`): + Pretrained embeddings. 
+ """ super().__init__( num_embeddings, embedding_dim, @@ -83,6 +127,9 @@ def forward(self, input: Tensor) -> Tensor: class Embedding(torch.nn.Embedding): + """ + Embedding class to store and retrieve word embeddings from their indices. + """ def __init__( self, num_embeddings: int, @@ -95,6 +142,25 @@ def __init__( _weight: Optional[Tensor] = None, device: Optional[device] = None, ) -> None: + """ + Args: + num_embeddings (`int`): + The number of unique embeddings (vocabulary size). + embedding_dim (`int`): + The dimensionality of the embedding. + padding_idx (`Optional[int]`): + Pads the output with zeros at the given index. + max_norm (`Optional[float]`): + Renormalizes embeddings to have a maximum L2 norm. + norm_type (`float`, defaults to `2.0`): + The p-norm to compute for the `max_norm` option. + scale_grad_by_freq (`bool`, defaults to `False`): + Scale gradient by frequency during backpropagation. + sparse (`bool`, defaults to `False`): + Computes dense gradients. Set to `True` to compute sparse gradients instead. + _weight (`Optional[Tensor]`): + Pretrained embeddings. + """ super().__init__( num_embeddings, embedding_dim, @@ -141,12 +207,11 @@ def forward(self, input: Tensor) -> Tensor: class Params4bit(torch.nn.Parameter): - # Remark: change blocksize to 128 for AMD gpu def __new__( cls, data: Optional[torch.Tensor] = None, - requires_grad=True, - quant_state: QuantState = None, + requires_grad=False, # quantized weights should be frozen by default + quant_state: Optional[QuantState] = None, blocksize: int = 128, compress_statistics: bool = True, quant_type: str = 'fp4', @@ -168,6 +233,37 @@ def __new__( self.module = module return self + def __getstate__(self): + state = self.__dict__ + state["data"] = self.data + state["requires_grad"] = self.requires_grad + return state + + def __setstate__(self, state): + self.requires_grad = state["requires_grad"] + self.blocksize = state["blocksize"] + self.compress_statistics = state["compress_statistics"] + self.quant_type = state["quant_type"] + self.quant_state = state["quant_state"] + self.data = state["data"] + self.quant_storage = state["quant_storage"] + self.bnb_quantized = state["bnb_quantized"] + self.module = state["module"] + + def __deepcopy__(self,memo): + new_instance = type(self).__new__(type(self)) + state = self.__getstate__() + new_instance.__setstate__(state) + new_instance.quant_state = copy.deepcopy(state["quant_state"]) + new_instance.data = copy.deepcopy(state["data"]) + return new_instance + + def __copy__(self): + new_instance = type(self).__new__(type(self)) + state = self.__getstate__() + new_instance.__setstate__(state) + return new_instance + @classmethod def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], requires_grad: bool = False, device='cuda', **kwargs) -> "Params4bit": self = torch.Tensor._make_subclass(cls, data.to(device)) @@ -181,8 +277,13 @@ def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], def _quantize(self, device): w = self.data.contiguous().cuda(device) - w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, - quant_type=self.quant_type, quant_storage=self.quant_storage) + w_4bit, quant_state = bnb.functional.quantize_4bit( + w, + blocksize=self.blocksize, + compress_statistics=self.compress_statistics, + quant_type=self.quant_type, + quant_storage=self.quant_storage, + ) self.data = w_4bit self.quant_state = quant_state if self.module is not None: @@ -223,8 +324,49 
@@ def to(self, *args, **kwargs): class Linear4bit(nn.Linear): + """ + This class is the base module for the 4-bit quantization algorithm presented in [QLoRA](https://arxiv.org/abs/2305.14314). + QLoRA 4-bit linear layers uses blockwise k-bit quantization under the hood, with the possibility of selecting various + compute datatypes such as FP4 and NF4. + + In order to quantize a linear layer one should first load the original fp16 / bf16 weights into + the Linear4bit module, then call `quantized_module.to("cuda")` to quantize the fp16 / bf16 weights. + + Example: + + ```python + import torch + import torch.nn as nn + import bitsandbytes as bnb + from bnb.nn import Linear4bit + + fp16_model = nn.Sequential( + nn.Linear(64, 64), + nn.Linear(64, 64) + ) + + quantized_model = nn.Sequential( + Linear4bit(64, 64), + Linear4bit(64, 64) + ) + + quantized_model.load_state_dict(fp16_model.state_dict()) + quantized_model = quantized_model.to(0) # Quantization happens here + ``` + """ def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', quant_storage=torch.uint8, device=None): + """ + Initialize Linear4bit class. + + Args: + input_features (`str`): + Number of input features of the linear layer. + output_features (`str`): + Number of output features of the linear layer. + bias (`bool`, defaults to `True`): + Whether the linear class uses the bias term as well. + """ super().__init__(input_features, output_features, bias, device) self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type, quant_storage=quant_storage, module=self) # self.persistent_buffers = [] # TODO consider as way to save quant state @@ -243,10 +385,10 @@ def set_compute_type(self, x): if self.compute_dtype == torch.float32 and (x.numel() == x.shape[-1]): # single batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast # warn the user about this - warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.') + warnings.warn('Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.') warnings.filterwarnings('ignore', message='.*inference.') if self.compute_dtype == torch.float32 and (x.numel() != x.shape[-1]): - warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.') + warnings.warn('Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.') warnings.filterwarnings('ignore', message='.*inference or training') def _save_to_state_dict(self, destination, prefix, keep_vars): @@ -292,7 +434,19 @@ def forward(self, x: torch.Tensor): class LinearFP4(Linear4bit): + """ + Implements the FP4 data type. + """ def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None): + """ + Args: + input_features (`str`): + Number of input features of the linear layer. + output_features (`str`): + Number of output features of the linear layer. + bias (`bool`, defaults to `True`): + Whether the linear class uses the bias term as well. 
+ """ super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', quant_storage, device) @@ -308,6 +462,15 @@ class LinearNF4(Linear4bit): the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236. ''' def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None): + """ + Args: + input_features (`str`): + Number of input features of the linear layer. + output_features (`str`): + Number of output features of the linear layer. + bias (`bool`, defaults to `True`): + Whether the linear class uses the bias term as well. + """ super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', quant_storage, device) @@ -325,7 +488,9 @@ def __new__( cls.SCB = None if data is None: data = torch.empty(0) - return torch.Tensor._make_subclass(cls, data, requires_grad) + obj = torch.Tensor._make_subclass(cls, data, requires_grad) + obj.CB, obj.SCB = cls.CB, cls.SCB + return obj def cuda(self, device): if self.has_fp16_weights: @@ -338,8 +503,8 @@ def cuda(self, device): del CBt del SCBt self.data = CB - setattr(self, "CB", CB) - setattr(self, "SCB", SCB) + self.CB = CB + self.SCB = SCB return self @@ -398,8 +563,49 @@ def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_k class Linear8bitLt(nn.Linear): + """ + This class is the base module for the [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm. + To read more about it, have a look at the paper. + + In order to quantize a linear layer one should first load the original fp16 / bf16 weights into + the Linear8bitLt module, then call `int8_module.to("cuda")` to quantize the fp16 weights. + + Example: + + ```python + import torch + import torch.nn as nn + + import bitsandbytes as bnb + from bnb.nn import Linear8bitLt + + fp16_model = nn.Sequential( + nn.Linear(64, 64), + nn.Linear(64, 64) + ) + + int8_model = nn.Sequential( + Linear8bitLt(64, 64, has_fp16_weights=False), + Linear8bitLt(64, 64, has_fp16_weights=False) + ) + + int8_model.load_state_dict(fp16_model.state_dict()) + int8_model = int8_model.to(0) # Quantization happens here + ``` + """ def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0, index=None, device=None): + """ + Initialize Linear8bitLt class. + + Args: + input_features (`str`): + Number of input features of the linear layer. + output_features (`str`): + Number of output features of the linear layer. + bias (`bool`, defaults to `True`): + Whether the linear class uses the bias term as well. 
+ """ super().__init__(input_features, output_features, bias, device) assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0" self.state = bnb.MatmulLtState() diff --git a/bitsandbytes/nn/triton_based_modules.py b/bitsandbytes/nn/triton_based_modules.py index de07ac647..9c7738c59 100644 --- a/bitsandbytes/nn/triton_based_modules.py +++ b/bitsandbytes/nn/triton_based_modules.py @@ -1,16 +1,24 @@ -import torch -import torch.nn as nn -import time from functools import partial -from bitsandbytes.triton.triton_utils import is_triton_available +import torch +import torch.nn as nn from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise +from bitsandbytes.triton.int8_matmul_mixed_dequantize import ( + int8_matmul_mixed_dequantize, +) +from bitsandbytes.triton.int8_matmul_rowwise_dequantize import ( + int8_matmul_rowwise_dequantize, +) +from bitsandbytes.triton.quantize_columnwise_and_transpose import ( + quantize_columnwise_and_transpose, +) +from bitsandbytes.triton.quantize_global import ( + quantize_global, + quantize_global_transpose, +) from bitsandbytes.triton.quantize_rowwise import quantize_rowwise -from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose -from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize -from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose -from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize +from bitsandbytes.triton.triton_utils import is_triton_available class _switchback_global(torch.autograd.Function): @@ -162,7 +170,7 @@ def __init__( ): super().__init__(in_features, out_features, bias, device, dtype) - if not is_triton_available: + if not is_triton_available(): raise ImportError('''Could not import triton. Please install triton to use SwitchBackLinear. Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''') diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py index 83a57bd9f..6796b8e0e 100644 --- a/bitsandbytes/optim/__init__.py +++ b/bitsandbytes/optim/__init__.py @@ -7,10 +7,17 @@ from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit from .adam import Adam, Adam8bit, Adam32bit, PagedAdam, PagedAdam8bit, PagedAdam32bit -from .adamw import AdamW, AdamW8bit, AdamW32bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit +from .adamw import ( + AdamW, + AdamW8bit, + AdamW32bit, + PagedAdamW, + PagedAdamW8bit, + PagedAdamW32bit, +) from .lamb import LAMB, LAMB8bit, LAMB32bit from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS +from .lion import Lion, Lion8bit, Lion32bit, PagedLion, PagedLion8bit, PagedLion32bit from .optimizer import GlobalOptimManager from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit -from .lion import Lion, Lion8bit, Lion32bit, PagedLion, PagedLion8bit, PagedLion32bit from .sgd import SGD, SGD8bit, SGD32bit diff --git a/bitsandbytes/optim/adagrad.py b/bitsandbytes/optim/adagrad.py index 7d8df58ac..c2ea87ab0 100644 --- a/bitsandbytes/optim/adagrad.py +++ b/bitsandbytes/optim/adagrad.py @@ -20,6 +20,33 @@ def __init__( percentile_clipping=100, block_wise=True, ): + """ + Base Adagrad optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-2): + The learning rate. + lr_decay (`int`, defaults to 0): + The learning rate decay. 
+ weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + initial_accumulator_value (`int`, defaults to 0): + The initial momemtum values. + eps (`float`, defaults to 1e-10): + The epsilon value prevents division by zero in the optimizer. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + """ if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= weight_decay: @@ -62,6 +89,33 @@ def __init__( percentile_clipping=100, block_wise=True, ): + """ + 8-bit Adagrad optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-2): + The learning rate. + lr_decay (`int`, defaults to 0): + The learning rate decay. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + initial_accumulator_value (`int`, defaults to 0): + The initial momemtum values. + eps (`float`, defaults to 1e-10): + The epsilon value prevents division by zero in the optimizer. + optim_bits (`int`, defaults to 8): + The number of bits of the optimizer state. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + """ if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= weight_decay: @@ -105,6 +159,33 @@ def __init__( percentile_clipping=100, block_wise=True, ): + """ + 32-bit Adagrad optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-2): + The learning rate. + lr_decay (`int`, defaults to 0): + The learning rate decay. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + initial_accumulator_value (`int`, defaults to 0): + The initial momemtum values. + eps (`float`, defaults to 1e-10): + The epsilon value prevents division by zero in the optimizer. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. 
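As an illustration of the Adagrad variants documented above, a minimal training-step sketch with the 8-bit optimizer; the model and hyperparameter values are placeholders:

```python
import torch
import torch.nn as nn
import bitsandbytes as bnb

model = nn.Linear(4096, 4096).cuda()

optimizer = bnb.optim.Adagrad8bit(
    model.parameters(),
    lr=1e-2,
    min_8bit_size=4096,       # tensors smaller than this keep full-precision state
    percentile_clipping=100,  # clip using the history of the last 100 gradient norms
    block_wise=True,          # quantize the optimizer state block by block
)

loss = model(torch.randn(8, 4096, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```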
+ """ if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= weight_decay: diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py index 86981eb86..e534c8b8f 100644 --- a/bitsandbytes/optim/adam.py +++ b/bitsandbytes/optim/adam.py @@ -16,31 +16,205 @@ class Adam(Optimizer2State): def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + """ + Base Adam optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) class Adam8bit(Optimizer2State): def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + """ + 8-bit Adam optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. 
+ block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) class Adam32bit(Optimizer2State): def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + """ + 32-bit Adam optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) class PagedAdam(Optimizer2State): def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + """ + Paged Adam optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. 
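Because the Adam variants above keep the familiar signature, enabling 8-bit optimizer states is usually a one-line swap from `torch.optim.Adam`. A minimal sketch with a placeholder model:

```python
import torch
import torch.nn as nn
import bitsandbytes as bnb

model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096)).cuda()

# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3, betas=(0.9, 0.999))

for _ in range(3):
    out = model(torch.randn(8, 4096, device="cuda"))
    loss = out.pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```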
+ percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) class PagedAdam8bit(Optimizer2State): def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + """ + 8-bit paged Adam optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) class PagedAdam32bit(Optimizer2State): def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + """ + Paged 32-bit Adam optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. 
+ min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) class AnalysisAdam(torch.optim.Optimizer): diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py index 21077f1a0..8acd0ba9b 100644 --- a/bitsandbytes/optim/adamw.py +++ b/bitsandbytes/optim/adamw.py @@ -5,35 +5,208 @@ from bitsandbytes.optim.optimizer import Optimizer2State - class AdamW(Optimizer2State): def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + """ + Base AdamW optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged ) class AdamW8bit(Optimizer2State): def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + """ + 8-bit AdamW optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. 
+ weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged ) class AdamW32bit(Optimizer2State): def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + """ + 32-bit AdamW optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) class PagedAdamW(Optimizer2State): def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + """ + Paged AdamW optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. 
+ eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) class PagedAdamW8bit(Optimizer2State): def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + """ + Paged 8-bit AdamW optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) class PagedAdamW32bit(Optimizer2State): def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + """ + Paged 32-bit AdamW optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. 
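The paged AdamW classes above pass `is_paged=True` to the base optimizer, so their state lives in paged memory that can spill out of GPU memory during allocation spikes (the paged-optimizer idea from QLoRA). A hedged usage sketch with a placeholder model:

```python
import torch
import torch.nn as nn
import bitsandbytes as bnb

model = nn.Linear(4096, 4096).cuda()

# Paged 8-bit AdamW: same hyperparameters as AdamW8bit, but with paged state.
optimizer = bnb.optim.PagedAdamW8bit(model.parameters(), lr=1e-3, weight_decay=1e-2)

loss = model(torch.randn(8, 4096, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```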
+ betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) diff --git a/bitsandbytes/optim/lamb.py b/bitsandbytes/optim/lamb.py index 1fbb6fadc..ec829ee85 100644 --- a/bitsandbytes/optim/lamb.py +++ b/bitsandbytes/optim/lamb.py @@ -23,6 +23,39 @@ def __init__( block_wise=False, max_unorm=1.0, ): + """ + Base LAMB optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + bias_correction (`bool`, defaults to `True`): + Whether to apply bias correction to the first and second-order moments. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + adam_w_mode (`bool`, defaults to `True`): + Whether to use the AdamW variant. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + max_unorm (`float`, defaults to 1.0): + The maximum gradient norm. + """ super().__init__( "lamb", params, @@ -56,6 +89,37 @@ def __init__( block_wise=False, max_unorm=1.0, ): + """ + 8-bit LAMB optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. 
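For the LAMB family above, the extra knobs beyond Adam are `bias_correction`, `adam_w_mode`, and `max_unorm`. A minimal sketch of the 8-bit variant with a placeholder model:

```python
import torch
import torch.nn as nn
import bitsandbytes as bnb

model = nn.Linear(4096, 4096).cuda()

optimizer = bnb.optim.LAMB8bit(
    model.parameters(),
    lr=1e-3,
    bias_correction=True,  # correct the first and second moment estimates
    adam_w_mode=True,      # decoupled (AdamW-style) weight decay
    max_unorm=1.0,         # maximum gradient norm, as documented above
)

loss = model(torch.randn(8, 4096, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```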
+ bias_correction (`bool`, defaults to `True`): + Whether to apply bias correction to the first and second-order moments. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + adam_w_mode (`bool`, defaults to `True`): + Whether to use the AdamW variant. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + max_unorm (`float`, defaults to 1.0): + The maximum gradient norm. + """ super().__init__( "lamb", params, @@ -89,6 +153,37 @@ def __init__( block_wise=False, max_unorm=1.0, ): + """ + 32-bit LAMB optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + bias_correction (`bool`, defaults to `True`): + Whether to apply bias correction to the first and second-order moments. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + adam_w_mode (`bool`, defaults to `True`): + Whether to use the AdamW variant. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + max_unorm (`float`, defaults to 1.0): + The maximum gradient norm. + """ super().__init__( "lamb", params, diff --git a/bitsandbytes/optim/lars.py b/bitsandbytes/optim/lars.py index 73554e3cc..7449b805b 100644 --- a/bitsandbytes/optim/lars.py +++ b/bitsandbytes/optim/lars.py @@ -23,6 +23,33 @@ def __init__( percentile_clipping=100, max_unorm=0.02, ): + """ + Base LARS optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`): + The learning rate. + momentum (`float`, defaults to 0): + The momentum value speeds up the optimizer by taking bigger steps. 
+ dampening (`float`, defaults to 0): + The dampening value reduces the momentum of the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + nesterov (`bool`, defaults to `False`): + Whether to use Nesterov momentum. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + max_unorm (`float`, defaults to 0.02): + The maximum gradient norm. + """ if momentum == 0: raise NotImplementedError( "LARS without momentum is not supported!" @@ -57,6 +84,31 @@ def __init__( percentile_clipping=100, max_unorm=0.02, ): + """ + 8-bit LARS optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`): + The learning rate. + momentum (`float`, defaults to 0): + The momentum value speeds up the optimizer by taking bigger steps. + dampening (`float`, defaults to 0): + The dampening value reduces the momentum of the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + nesterov (`bool`, defaults to `False`): + Whether to use Nesterov momentum. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + max_unorm (`float`, defaults to 0.02): + The maximum gradient norm. + """ if momentum == 0: raise NotImplementedError( "LARS without momentum is not supported!" @@ -91,6 +143,31 @@ def __init__( percentile_clipping=100, max_unorm=0.02, ): + """ + 32-bit LARS optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`): + The learning rate. + momentum (`float`, defaults to 0): + The momentum value speeds up the optimizer by taking bigger steps. + dampening (`float`, defaults to 0): + The dampening value reduces the momentum of the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + nesterov (`bool`, defaults to `False`): + Whether to use Nesterov momentum. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + max_unorm (`float`, defaults to 0.02): + The maximum gradient norm. + """ if momentum == 0: raise NotImplementedError( "LARS without momentum is not supported!" diff --git a/bitsandbytes/optim/lion.py b/bitsandbytes/optim/lion.py index 2bde1a447..ce185f863 100644 --- a/bitsandbytes/optim/lion.py +++ b/bitsandbytes/optim/lion.py @@ -4,27 +4,168 @@ # LICENSE file in the root directory of this source tree. 
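Note that all three LARS constructors above raise `NotImplementedError` when `momentum == 0`, so a non-zero momentum must always be supplied. A minimal sketch with a placeholder model:

```python
import torch
import torch.nn as nn
import bitsandbytes as bnb

model = nn.Linear(4096, 4096).cuda()

optimizer = bnb.optim.LARS8bit(
    model.parameters(),
    lr=1e-2,
    momentum=0.9,    # required: LARS without momentum is not supported
    max_unorm=0.02,  # maximum gradient norm, as documented above
)

loss = model(torch.randn(8, 4096, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```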
from bitsandbytes.optim.optimizer import Optimizer1State + class Lion(Optimizer1State): def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + """ + Base Lion optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-4): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + weight_decay (`float`, defaults to 0): + The weight decay value for the optimizer. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) class Lion8bit(Optimizer1State): def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + """ + 8-bit Lion optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-4): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + weight_decay (`float`, defaults to 0): + The weight decay value for the optimizer. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) class Lion32bit(Optimizer1State): def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + """ + 32-bit Lion optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-4): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + weight_decay (`float`, defaults to 0): + The weight decay value for the optimizer. 
+ args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) class PagedLion(Optimizer1State): def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + """ + Paged Lion optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-4): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + weight_decay (`float`, defaults to 0): + The weight decay value for the optimizer. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + """ super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) class PagedLion8bit(Optimizer1State): def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + """ + Paged 8-bit Lion optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-4): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + weight_decay (`float`, defaults to 0): + The weight decay value for the optimizer. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. 
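The Lion constructors above default to a smaller learning rate (1e-4) and betas of (0.9, 0.99), and Lion tracks only a single optimizer state. A minimal sketch of the 8-bit and paged variants with a placeholder model:

```python
import torch
import torch.nn as nn
import bitsandbytes as bnb

model = nn.Linear(4096, 4096).cuda()

optimizer = bnb.optim.Lion8bit(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0)
# Paged alternative for memory-spiky runs:
# optimizer = bnb.optim.PagedLion8bit(model.parameters(), lr=1e-4)

loss = model(torch.randn(8, 4096, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```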
+ """ super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) class PagedLion32bit(Optimizer1State): def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + """ + Paged 32-bit Lion optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-4): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + weight_decay (`float`, defaults to 0): + The weight decay value for the optimizer. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + """ super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index fb83eddf0..a97afb026 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -2,8 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from collections import abc as container_abcs -from collections import defaultdict +from collections import abc as container_abcs, defaultdict from copy import deepcopy from itertools import chain @@ -19,6 +18,9 @@ def __init__(self, initial_data): class GlobalOptimManager: + """ + A global optimizer manager for enabling custom optimizer configs. + """ _instance = None def __init__(self): @@ -54,22 +56,40 @@ def override_config( self, parameters, key=None, value=None, key_value_dict=None ): """ - Overrides initial optimizer config for specific parameters. + Override initial optimizer config with specific hyperparameters. The key-values of the optimizer config for the input parameters are overridden - This can be both, optimizer parameters like "betas", or "lr" or it can be - 8-bit specific parameters like "optim_bits", "percentile_clipping". - - Parameters - ---------- - parameters : torch.Tensor or list(torch.Tensors) - The input parameters. - key : str - The hyperparamter to override. - value : object - The value for the hyperparamters. - key_value_dict : dict - A dictionary with multiple key-values to override. + This can be both, optimizer parameters like `betas` or `lr`, or it can be + 8-bit specific parameters like `optim_bits` or `percentile_clipping`. + + Arguments: + parameters (`torch.Tensor` or `list(torch.Tensors)`): + The input parameters. + key (`str`): + The hyperparamter to override. + value: + The hyperparameter values. + key_value_dict (`dict`): + A dictionary with multiple key-values to override. 
+ + Example: + + ```py + import torch + import bitsandbytes as bnb + + mng = bnb.optim.GlobalOptimManager.get_instance() + + model = MyModel() + mng.register_parameters(model.parameters()) # 1. register parameters while still on CPU + + model = model.cuda() + # use 8-bit optimizer states for all parameters + adam = bnb.optim.Adam(model.parameters(), lr=0.001, optim_bits=8) + + # 2. override: the parameter model.fc1.weight now uses 32-bit Adam + mng.override_config(model.fc1.weight, 'optim_bits', 32) + ``` """ self.uses_config_override = True if isinstance(parameters, torch.nn.Parameter): @@ -93,6 +113,17 @@ def register_module_override(self, module, param_name, config): class Optimizer8bit(torch.optim.Optimizer): def __init__(self, params, defaults, optim_bits=32, is_paged=False): + """ + Base 8-bit optimizer class. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ super().__init__(params, defaults) self.initialized = False self.name2qmap = {} @@ -126,11 +157,11 @@ def __setstate__(self, state): super().__setstate__(state) def load_state_dict(self, state_dict): - r"""Loads the optimizer state. + """Load an optimizer state. - Args: - state_dict (dict): optimizer state. Should be an object returned - from a call to :meth:`state_dict`. + Arguments: + state_dict (`dict`): + An optimizer state (should be returned from a call to `state_dict`) to load. """ # deepcopy, to be consistent with module API state_dict = deepcopy(state_dict) @@ -238,11 +269,11 @@ def check_overrides(self): @torch.no_grad() def step(self, closure=None): - """Performs a single optimization step. + """Perform a single optimization step. Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. + closure (`Callable`, *optional*, defaults to `None`): + A closure that reevaluates the model and returns the loss. """ loss = None if closure is not None: @@ -340,6 +371,39 @@ def __init__( skip_zeros=False, is_paged=False ): + """ + Base 2-state update optimizer class. + + Arguments: + optimizer_name (`str`): + The name of the optimizer. + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple`, defaults to (0.9, 0.999)): + The beta values for the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value for the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + max_unorm (`float`, defaults to 0.0): + The maximum value to normalize each block with. + skip_zeros (`bool`, defaults to `False`): + Whether to skip zero values for sparse gradients and models to ensure correct updates. 
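As a companion to the `override_config` example above, a sketch of the `key_value_dict` form documented in the Arguments list; the model and the chosen override values are placeholders:

```python
import torch.nn as nn
import bitsandbytes as bnb

mng = bnb.optim.GlobalOptimManager.get_instance()

model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64))
mng.register_parameters(model.parameters())  # register while parameters are still on CPU
model = model.cuda()

adam = bnb.optim.Adam(model.parameters(), lr=1e-3, optim_bits=8)

# Override several hyperparameters at once for one sensitive layer.
mng.override_config(
    model[0].weight,
    key_value_dict={"optim_bits": 32, "percentile_clipping": 95},
)
```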
+ is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= eps: @@ -553,6 +617,39 @@ def __init__( skip_zeros=False, is_paged=False ): + """ + Base 1-state update optimizer class. + + Arguments: + optimizer_name (`str`): + The name of the optimizer. + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple`, defaults to (0.9, 0.0)): + The beta values for the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value for the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + max_unorm (`float`, defaults to 0.0): + The maximum value to normalize each block with. + skip_zeros (`bool`, defaults to `False`): + Whether to skip zero values for sparse gradients and models to ensure correct updates. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= eps: diff --git a/bitsandbytes/optim/rmsprop.py b/bitsandbytes/optim/rmsprop.py index 2853ca723..ac371a66f 100644 --- a/bitsandbytes/optim/rmsprop.py +++ b/bitsandbytes/optim/rmsprop.py @@ -21,6 +21,35 @@ def __init__( percentile_clipping=100, block_wise=True, ): + """ + Base RMSprop optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-2): + The learning rate. + alpha (`float`, defaults to 0.99): + The alpha value is the decay rate of the squared gradients of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + momentum (`float`, defaults to 0): + The momentum value speeds up the optimizer by taking bigger steps. + centered (`bool`, defaults to `False`): + Whether the gradients are normalized by the variance. If `True`, it can help training at the expense of additional compute. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + """ if alpha == 0: raise NotImplementedError( "RMSprop with alpha==0.0 is not supported!" 
@@ -57,6 +86,35 @@ def __init__( percentile_clipping=100, block_wise=True, ): + """ + 8-bit RMSprop optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-2): + The learning rate. + alpha (`float`, defaults to 0.99): + The alpha value is the decay rate of the squared gradients of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + momentum (`float`, defaults to 0): + The momentum value speeds up the optimizer by taking bigger steps. + centered (`bool`, defaults to `False`): + Whether the gradients are normalized by the variance. If `True`, it can help training at the expense of additional compute. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + """ if alpha == 0: raise NotImplementedError( "RMSprop with alpha==0.0 is not supported!" @@ -93,6 +151,35 @@ def __init__( percentile_clipping=100, block_wise=True, ): + """ + 32-bit RMSprop optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-2): + The learning rate. + alpha (`float`, defaults to 0.99): + The alpha value is the decay rate of the squared gradients of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + momentum (`float`, defaults to 0): + The momentum value speeds up the optimizer by taking bigger steps. + centered (`bool`, defaults to `False`): + Whether the gradients are normalized by the variance. If `True`, it can help training at the expense of additional compute. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + """ if alpha == 0: raise NotImplementedError( diff --git a/bitsandbytes/optim/sgd.py b/bitsandbytes/optim/sgd.py index 3c0fc2b9f..0f0b12e4b 100644 --- a/bitsandbytes/optim/sgd.py +++ b/bitsandbytes/optim/sgd.py @@ -20,6 +20,33 @@ def __init__( percentile_clipping=100, block_wise=True, ): + """ + Base SGD optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`): + The learning rate. + momentum (`float`, defaults to 0): + The momentum value speeds up the optimizer by taking bigger steps. 
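The RMSprop constructors above refuse `alpha == 0`, and the SGD classes that follow have an analogous non-zero `momentum` requirement. A minimal 8-bit RMSprop sketch with a placeholder model:

```python
import torch
import torch.nn as nn
import bitsandbytes as bnb

model = nn.Linear(4096, 4096).cuda()

optimizer = bnb.optim.RMSprop8bit(
    model.parameters(),
    lr=1e-2,
    alpha=0.99,    # must be non-zero
    momentum=0.9,  # optional momentum term
)

loss = model(torch.randn(8, 4096, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```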
+ dampening (`float`, defaults to 0): + The dampening value reduces the momentum of the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + nesterov (`bool`, defaults to `False`): + Whether to use Nesterov momentum. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + """ if momentum == 0: raise NotImplementedError("SGD without momentum is not supported!") super().__init__( @@ -51,6 +78,31 @@ def __init__( percentile_clipping=100, block_wise=True, ): + """ + 8-bit SGD optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`): + The learning rate. + momentum (`float`, defaults to 0): + The momentum value speeds up the optimizer by taking bigger steps. + dampening (`float`, defaults to 0): + The dampening value reduces the momentum of the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + nesterov (`bool`, defaults to `False`): + Whether to use Nesterov momentum. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + """ if momentum == 0: raise NotImplementedError("SGD without momentum is not supported!") super().__init__( @@ -82,6 +134,31 @@ def __init__( percentile_clipping=100, block_wise=True, ): + """ + 32-bit SGD optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`): + The learning rate. + momentum (`float`, defaults to 0): + The momentum value speeds up the optimizer by taking bigger steps. + dampening (`float`, defaults to 0): + The dampening value reduces the momentum of the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + nesterov (`bool`, defaults to `False`): + Whether to use Nesterov momentum. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. 
+ """ if momentum == 0: raise NotImplementedError("SGD without momentum is not supported!") super().__init__( diff --git a/bitsandbytes/research/__init__.py b/bitsandbytes/research/__init__.py index 47b720d78..31db4f282 100644 --- a/bitsandbytes/research/__init__.py +++ b/bitsandbytes/research/__init__.py @@ -1,6 +1,6 @@ from . import nn from .autograd._functions import ( - switchback_bnb, matmul_fp8_global, matmul_fp8_mixed, + switchback_bnb, ) diff --git a/bitsandbytes/research/autograd/_functions.py b/bitsandbytes/research/autograd/_functions.py index 883121759..21cbe9d75 100644 --- a/bitsandbytes/research/autograd/_functions.py +++ b/bitsandbytes/research/autograd/_functions.py @@ -1,21 +1,19 @@ +from functools import reduce # Required in Python 3 import operator +from typing import Optional import warnings -from dataclasses import dataclass -from functools import reduce # Required in Python 3 import torch -import bitsandbytes.functional as F - -from bitsandbytes.autograd._functions import MatmulLtState, GlobalOutlierPooler +from bitsandbytes.autograd._functions import GlobalOutlierPooler, MatmulLtState from bitsandbytes.cextension import HIP_ENVIRONMENT +import bitsandbytes.functional as F # math.prod not compatible with python < 3.8 def prod(iterable): return reduce(operator.mul, iterable, 1) -tensor = torch.Tensor class MatMulFP8Mixed(torch.autograd.Function): # forward is the same, but we added the fallback for pre-turing GPUs @@ -187,7 +185,9 @@ def backward(ctx, grad_output): class SwitchBackBnb(torch.autograd.Function): @staticmethod - def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): + # TODO: the B008 on the line below is a likely bug; the current implementation will + # have each SwitchBackBnb instance share a single MatmulLtState instance!!! 
+ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B008 # default to pytorch behavior if inputs are empty ctx.is_empty = False if prod(A.shape) == 0: @@ -393,19 +393,38 @@ def get_block_sizes(input_matrix, weight_matrix): return bsz, bsz2 -def matmul_fp8_global(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1): + +def matmul_fp8_global( + A: torch.Tensor, + B: torch.Tensor, + fw_code: torch.Tensor, + bw_code: torch.Tensor, + out: Optional[torch.Tensor] = None, + bsz: int = -1, + bsz2: int = -1, +): if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B) return MatMulFP8Global.apply(A, B, out, fw_code, bw_code, bsz, bsz2) -def matmul_fp8_mixed(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1): + +def matmul_fp8_mixed( + A: torch.Tensor, + B: torch.Tensor, + fw_code: torch.Tensor, + bw_code: torch.Tensor, + out: Optional[torch.Tensor] = None, + bsz: int = -1, + bsz2: int = -1, +): if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B) return MatMulFP8Mixed.apply(A, B, out, fw_code, bw_code, bsz, bsz2) + def switchback_bnb( - A: tensor, - B: tensor, - out: tensor = None, - state: MatmulLtState = None, + A: torch.Tensor, + B: torch.Tensor, + out: Optional[torch.Tensor] = None, + state: Optional[MatmulLtState] = None, threshold=0.0, bias=None ): diff --git a/bitsandbytes/research/nn/__init__.py b/bitsandbytes/research/nn/__init__.py index 8faec10bb..417011218 100644 --- a/bitsandbytes/research/nn/__init__.py +++ b/bitsandbytes/research/nn/__init__.py @@ -1 +1 @@ -from .modules import LinearFP8Mixed, LinearFP8Global +from .modules import LinearFP8Global, LinearFP8Mixed diff --git a/bitsandbytes/research/nn/modules.py b/bitsandbytes/research/nn/modules.py index 2a46b40c3..7fca34d23 100644 --- a/bitsandbytes/research/nn/modules.py +++ b/bitsandbytes/research/nn/modules.py @@ -1,12 +1,9 @@ -from typing import Optional, TypeVar, Union, overload +from typing import TypeVar import torch -import torch.nn.functional as F -from torch import Tensor, device, dtype, nn +from torch import nn import bitsandbytes as bnb -from bitsandbytes.optim import GlobalOptimManager -from bitsandbytes.utils import OutlierTracer, find_outlier_dims T = TypeVar("T", bound="torch.nn.Module") diff --git a/bitsandbytes/triton/dequantize_rowwise.py b/bitsandbytes/triton/dequantize_rowwise.py index e092680b8..daa59da9c 100644 --- a/bitsandbytes/triton/dequantize_rowwise.py +++ b/bitsandbytes/triton/dequantize_rowwise.py @@ -1,6 +1,7 @@ import math + import torch -import time + from bitsandbytes.triton.triton_utils import is_triton_available if not is_triton_available(): @@ -9,7 +10,6 @@ def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): return None import triton import triton.language as tl - from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time # rowwise quantize diff --git a/bitsandbytes/triton/int8_matmul_mixed_dequantize.py b/bitsandbytes/triton/int8_matmul_mixed_dequantize.py index b0961f558..1b80ab1a0 100644 --- a/bitsandbytes/triton/int8_matmul_mixed_dequantize.py +++ b/bitsandbytes/triton/int8_matmul_mixed_dequantize.py @@ -1,4 +1,5 @@ import torch + from bitsandbytes.triton.triton_utils import is_triton_available if not is_triton_available(): @@ -57,7 +58,8 @@ def get_configs_io_bound(): triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 128, 
'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), - ] + get_configs_io_bound(), + *get_configs_io_bound(), + ], key=['M', 'N', 'K'], prune_configs_by={ 'early_config_prune': early_config_prune, diff --git a/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py b/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py index 33f4d13f2..4881e1468 100644 --- a/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py +++ b/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py @@ -57,7 +57,8 @@ def get_configs_io_bound(): triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), - ] + get_configs_io_bound(), + *get_configs_io_bound(), + ], key=['M', 'N', 'K'], prune_configs_by={ 'early_config_prune': early_config_prune, @@ -118,7 +119,7 @@ def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, acc += tl.dot(a, b) A += BLOCK_K * SPLIT_K * stride_ak B += BLOCK_K * SPLIT_K * stride_bk - + acc = (w_factor * (x_factor * (acc * divfactor))) acc = acc.to(C.dtype.element_ty) diff --git a/bitsandbytes/triton/quantize_columnwise_and_transpose.py b/bitsandbytes/triton/quantize_columnwise_and_transpose.py index 54220d95a..e7961cf53 100644 --- a/bitsandbytes/triton/quantize_columnwise_and_transpose.py +++ b/bitsandbytes/triton/quantize_columnwise_and_transpose.py @@ -1,6 +1,7 @@ import math + import torch -import time + from bitsandbytes.triton.triton_utils import is_triton_available if not is_triton_available(): @@ -9,7 +10,6 @@ def quantize_columnwise_and_transpose(x: torch.Tensor): return None import triton import triton.language as tl - from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time # This kernel does fused columnwise quantization and transpose. @@ -54,7 +54,7 @@ def _quantize_columnwise_and_transpose( max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0) output = tl.libdevice.llrint(127. 
* (x / max_val)) - new_start = pid * M + new_start = pid * M new_offsets = new_start + p2_arange tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask) tl.store(output_maxs + pid, max_val) @@ -71,4 +71,3 @@ def quantize_columnwise_and_transpose(x: torch.Tensor): grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) _quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2) return output, output_maxs - diff --git a/bitsandbytes/triton/quantize_global.py b/bitsandbytes/triton/quantize_global.py index 845db6ecd..a73a5bbaa 100644 --- a/bitsandbytes/triton/quantize_global.py +++ b/bitsandbytes/triton/quantize_global.py @@ -1,6 +1,6 @@ -import math + import torch -import time + from bitsandbytes.triton.triton_utils import is_triton_available if not is_triton_available(): @@ -10,7 +10,6 @@ def quantize_global(x: torch.Tensor): return None import triton import triton.language as tl - from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time # global quantize @triton.autotune( diff --git a/bitsandbytes/triton/quantize_rowwise.py b/bitsandbytes/triton/quantize_rowwise.py index 26d218321..078f4aa2d 100644 --- a/bitsandbytes/triton/quantize_rowwise.py +++ b/bitsandbytes/triton/quantize_rowwise.py @@ -1,6 +1,6 @@ import math + import torch -import time from bitsandbytes.triton.triton_utils import is_triton_available @@ -10,7 +10,6 @@ def quantize_rowwise(x: torch.Tensor): return None import triton import triton.language as tl - from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time # rowwise quantize @@ -47,7 +46,7 @@ def _quantize_rowwise( offsets = block_start + arange row_mask = arange < BLOCK_SIZE x = tl.load(x_ptr + offsets, mask=row_mask) - + abs_x = tl.abs(x) max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0) output = tl.libdevice.llrint(127. * (x / max_val)) @@ -65,4 +64,3 @@ def quantize_rowwise(x: torch.Tensor): grid = lambda meta: (x.shape[0],) _quantize_rowwise[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) return output, output_maxs - diff --git a/bitsandbytes/triton/triton_utils.py b/bitsandbytes/triton/triton_utils.py index c74c23962..6bbdbf1c1 100644 --- a/bitsandbytes/triton/triton_utils.py +++ b/bitsandbytes/triton/triton_utils.py @@ -1,4 +1,5 @@ import importlib + def is_triton_available(): return importlib.util.find_spec("triton") is not None diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index 48373a1fe..82bf65d79 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -1,9 +1,11 @@ import json import shlex import subprocess -import torch from typing import Tuple +import torch + + def outlier_hook(module, input): assert isinstance(module, torch.nn.Linear) tracer = OutlierTracer.get_instance() @@ -37,7 +39,7 @@ def outlier_hook(module, input): hook.remove() -class OutlierTracer(object): +class OutlierTracer: _instance = None def __init__(self): diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index e28e7b2c2..270f21333 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -1,6 +1,6 @@ #include -#include #include +#include using namespace BinSearch; @@ -32,7 +32,7 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long { long long valid_chunks = num_blocks - offset >= thread_wave_size ? 
thread_wave_size : num_blocks - offset; pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * valid_chunks); - + struct quantize_block_args **args = (quantize_block_args **) malloc(valid_chunks * sizeof(quantize_block_args *)); for(long long i = 0; i < valid_chunks; i++) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 9ebe0a69e..f4673359b 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -110,7 +110,7 @@ __device__ float dDequantizeFP4Tree(unsigned char val, float absmax) return 1.00000000f*absmax*sign; // 1011 else return 0.66666667f*absmax*sign; // 1010 - else + else if((val & 0b0001) == 1) // 100 return 5.208333333e-03f*absmax*sign; // 1001 else @@ -134,10 +134,10 @@ __device__ unsigned char dQuantizeFP4(float x) // we do a binary search // the pivots are divided by 12 (the FP4 absmax) - // since we assum input data is in [-1.0, 1.0] + // since we assume input data is in [-1.0, 1.0] // !be careful here, its easy to make a mistake - // that is difficult to noice if you add an extra + // that is difficult to notice if you add an extra // zero somewhere! int sign = x < 0 ? 0b1000 : 0b0000; @@ -174,36 +174,36 @@ __device__ half dhDequantizeNF4(unsigned char val) if((val & 0b0100) == 4) // 1 if((val & 0b0010) == 2) // 11 if((val & 0b0001) == 1) // 111 - return 1.0f; + return 1.0f; else return 0.7229568362236023f; else if((val & 0b0001) == 1) // 110 - return 0.5626170039176941f; + return 0.5626170039176941f; else - return 0.44070982933044434f; + return 0.44070982933044434f; else if((val & 0b0010) == 2) //10 if((val & 0b0001) == 1) // 101 - return 0.33791524171829224f; + return 0.33791524171829224f; else - return 0.24611230194568634f; - else + return 0.24611230194568634f; + else if((val & 0b0001) == 1) // 100 - return 0.16093020141124725f; + return 0.16093020141124725f; else - return 0.07958029955625534f; + return 0.07958029955625534f; else if((val & 0b0100) == 4) // 0 if((val & 0b0010) == 2) //01 if((val & 0b0001) == 1) // 011 - return 0.0f; + return 0.0f; else - return -0.09105003625154495f; + return -0.09105003625154495f; else if((val & 0b0001) == 1) // 010 - return -0.18477343022823334f; + return -0.18477343022823334f; else return -0.28444138169288635f; else @@ -211,12 +211,12 @@ __device__ half dhDequantizeNF4(unsigned char val) if((val & 0b0001) == 1) // 001 return -0.39491748809814453f; else - return -0.5250730514526367f; - else + return -0.5250730514526367f; + else if((val & 0b0001) == 1) // 000 - return -0.6961928009986877f; + return -0.6961928009986877f; else - return -1.0f; + return -1.0f; } @@ -229,36 +229,36 @@ __device__ float dDequantizeNF4(unsigned char val) if((val & 0b0100) == 4) // 1 if((val & 0b0010) == 2) // 11 if((val & 0b0001) == 1) // 111 - return 1.0f; + return 1.0f; else return 0.7229568362236023f; else if((val & 0b0001) == 1) // 110 - return 0.5626170039176941f; + return 0.5626170039176941f; else - return 0.44070982933044434f; + return 0.44070982933044434f; else if((val & 0b0010) == 2) //10 if((val & 0b0001) == 1) // 101 - return 0.33791524171829224f; + return 0.33791524171829224f; else - return 0.24611230194568634f; - else + return 0.24611230194568634f; + else if((val & 0b0001) == 1) // 100 - return 0.16093020141124725f; + return 0.16093020141124725f; else - return 0.07958029955625534f; + return 0.07958029955625534f; else if((val & 0b0100) == 4) // 0 if((val & 0b0010) == 2) //01 if((val & 0b0001) == 1) // 011 - return 0.0f; + return 0.0f; else - return -0.09105003625154495f; + return -0.09105003625154495f; else if((val & 0b0001) == 1) // 010 - 
return -0.18477343022823334f; + return -0.18477343022823334f; else return -0.28444138169288635f; else @@ -266,12 +266,12 @@ __device__ float dDequantizeNF4(unsigned char val) if((val & 0b0001) == 1) // 001 return -0.39491748809814453f; else - return -0.5250730514526367f; - else + return -0.5250730514526367f; + else if((val & 0b0001) == 1) // 000 - return -0.6961928009986877f; + return -0.6961928009986877f; else - return -1.0f; + return -1.0f; } @@ -654,6 +654,8 @@ __global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const f for(int j = threadIdx.x; j < BLOCK_ESTIMATE; j+=blockDim.x) temp_storage.smem_qidx[j] = -1; + __syncthreads(); + if(threadIdx.x < 256) { float q_interval = (1.0f-(2.0f*offset))/255.0f; @@ -1863,7 +1865,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char //float ratio = (g_val*g_val)/fmaxf(s2_vals[j], eps*eps); //g_val = ratio > 2.0f ? 2.0f*g_val/ratio : g_val; g_val *= gnorm_scale; - + s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; @@ -2259,8 +2261,8 @@ template__global__ void kd // data is in 32 column-tile major with tile width 32 columns and numRows rows // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory. - // L2. Load data in warp-striped arangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) + // L2. Load data in warp-striped arrangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) // C1. Compute val(row_stat*col_stat)/(127*127) (load 1/(127*127 into register)) // C2. Compute normalization values and store col values in register // S1. Store C1 into 16-bit output @@ -2383,7 +2385,7 @@ template __global__ void kd if(valid_items <= 0) // the sub-tile might have more elements than the tile itself break; - // L2. Load data in warp-striped arangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) + // L2. Load data in warp-striped arrangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) LoadInt32(loadint32).Load(&(A[subtile_idx]), local_values, valid_items, 0); ExchangeInt32(exchangeint32).BlockedToWarpStriped(local_values, local_values); @@ -2650,7 +2652,7 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * //// use k warps per thread block //// 1. threadblock use read-only cache to read in register tile for A into shared memory //// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments -//// 3. each warp reads a segment of values 16x32 from B +//// 3. each warp reads a segment of values 16x32 from B //// 4. do dequantization from register of B into second pair of registers //// 5. store (4) into fragment //// 6. matmul aggregate into fragment C -//// 7. aggreecate files of C into shared memroy block C +//// 7. aggregate files of C into shared memory block C //// 8. sum (7) //// 9. 
write outputs to matmul output matrix //} @@ -3531,7 +3533,7 @@ template __global__ void kgemm_4bit_inference(int M, i template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize) { - // per threadblock: + // per threadblock: // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps] // 4 warps -> 4 loads per iter // 1x32 * 32x4 -> 1x4 outputs per thread block @@ -3764,7 +3766,7 @@ template __global__ void kfunc(T *A, T *B, T value, long { switch(FUNC) { - case FILL: + case FILL: A[i] = (T)value; break; case ARANGE: @@ -3821,12 +3823,12 @@ template __global__ void kgemm_4bit_inference_naive(int M, int N template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int 
*offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); diff --git a/csrc/kernels.hip.cpp b/csrc/kernels.hip.cpp new file mode 100644 index 000000000..c32cfdfc8 --- /dev/null +++ b/csrc/kernels.hip.cpp @@ -0,0 +1,4108 @@ +#include "hip/hip_runtime.h" +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include "kernels.hip.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#define __syncwarp __syncthreads //TODO: HIP doesn't have this so just sync threads + +//#include +#include + +#define HLF_MAX 65504 +#define TH 1024 +#define NUM 4 +#define NUM_BLOCK 4096 + + +// source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda +// Luckily we have atomicmax and atomicmin in ROCm + +__device__ float dDequantizeFP4(unsigned char val, float absmax) +{ + float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f; + if((val & 0b0110) == 0) + { + // subnormal + if((val & 0b0001) == 0) + return 0.0f; + else + return sign*0.0625f*absmax; + } + else + { + // normal + float exponent = ((val & 0b0100) == 4 ? 2.0f : 8.0f) + ((val & 0b0010) == 2 ? 0.0f : 2.0f); + float fraction = (val & 0b0001) == 1 ? 1.5f : 1.0f; + + return sign*exponent*fraction*absmax; + } +} + +__device__ float d2DequantizeFP4(unsigned char val) +{ + float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f; + if((val & 0b0110) == 0) + { + // subnormal + if((val & 0b0001) == 0) + return 0.0f; + else + return sign*0.0625f; + } + else + { + // normal + float exponent = ((val & 0b0100) == 4 ? 2.0f : 8.0f) + ((val & 0b0010) == 2 ? 0.0f : 2.0f); + float fraction = (val & 0b0001) == 1 ? 1.5f : 1.0f; + + return sign*exponent*fraction; + } +} + +__device__ float dDequantizeFP4Tree(unsigned char val, float absmax) +{ + float sign = (val & 0b1000) == 8 ? 
-1.0f : 1.0f; + if((val & 0b0100) == 4) // 0 + if((val & 0b0010) == 2) //01 + if((val & 0b0001) == 1) // 111 + return 0.25000000f*absmax*sign; // 1111 + else + return 0.16666667f*absmax*sign; // 1110 + else + if((val & 0b0001) == 1) // 110 + return 0.50000000f*absmax*sign; // 1101 + else + return 0.33333333f*absmax*sign; // 1100 + else + if((val & 0b0010) == 2) //10 + if((val & 0b0001) == 1) // 101 + return 1.00000000f*absmax*sign; // 1011 + else + return 0.66666667f*absmax*sign; // 1010 + else + if((val & 0b0001) == 1) // 100 + return 5.208333333e-03f*absmax*sign; // 1001 + else + return 0.00000000f*absmax*sign; // 1000 +} + +__device__ unsigned char dQuantizeFP4(float x) +{ + // FP4 with bias of 3 + // first bit is a sign + // subnormals + // 0b000 = 0 + // 0b001 = 0.0625 + // 0b110 = 2 + // 0b111 = 3 + // 0b100 = 4 + // 0b101 = 6 + // 0b010 = 8 + // 0b011 = 12 + + + // we do a binary search + // the pivots are divided by 12 (the FP4 absmax) + // since we assum input data is in [-1.0, 1.0] + + // !be careful here, its easy to make a mistake + // that is difficult to noice if you add an extra + // zero somewhere! + + int sign = x < 0 ? 0b1000 : 0b0000; + x = fabsf(x); + if(x > 0.29166667f) + if( x > 0.583333f) + if( x > 0.8333333f) + return 0b0011+sign; + else + return 0b0010+sign; + else + if(x > 0.4166667f) + return 0b101+sign; + else + return 0b100+sign; + else + if(x > 0.0859375f) + if(x > 0.20833333f) + return 0b0111+sign; + else + return 0b0110+sign; + else + if(x > 0.00260417f) + return 0b0001+sign; + else + return 0b0000+sign; +} + +__device__ half dhDequantizeNF4(unsigned char val) +{ + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if((val & 0b1000) == 8) + if((val & 0b0100) == 4) // 1 + if((val & 0b0010) == 2) // 11 + if((val & 0b0001) == 1) // 111 + return 1.0f; + else + return 0.7229568362236023f; + else + if((val & 0b0001) == 1) // 110 + return 0.5626170039176941f; + else + return 0.44070982933044434f; + else + if((val & 0b0010) == 2) //10 + if((val & 0b0001) == 1) // 101 + return 0.33791524171829224f; + else + return 0.24611230194568634f; + else + if((val & 0b0001) == 1) // 100 + return 0.16093020141124725f; + else + return 0.07958029955625534f; + + else + if((val & 0b0100) == 4) // 0 + if((val & 0b0010) == 2) //01 + if((val & 0b0001) == 1) // 011 + return 0.0f; + else + return -0.09105003625154495f; + else + if((val & 0b0001) == 1) // 010 + return -0.18477343022823334f; + else + return -0.28444138169288635f; + else + if((val & 0b0010) == 2) //00 + if((val & 0b0001) == 1) // 001 + return -0.39491748809814453f; + else + return -0.5250730514526367f; + else + if((val & 0b0001) == 1) // 000 + return -0.6961928009986877f; + else + return -1.0f; + +} + +__device__ float dDequantizeNF4(unsigned char val) +{ + + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if((val & 0b1000) == 8) + if((val & 0b0100) == 4) // 1 + if((val & 0b0010) == 2) // 11 + if((val & 0b0001) == 1) // 111 + return 1.0f; + else + return 0.7229568362236023f; + else + if((val & 0b0001) == 1) // 110 + return 0.5626170039176941f; + else + return 0.44070982933044434f; + else + if((val & 0b0010) == 2) //10 + if((val & 0b0001) == 1) // 101 + return 0.33791524171829224f; + else + return 0.24611230194568634f; + else + if((val & 0b0001) == 1) // 100 + return 0.16093020141124725f; + else + return 0.07958029955625534f; + + else + if((val & 0b0100) == 4) // 0 + if((val & 0b0010) == 2) //01 + if((val & 
0b0001) == 1) // 011 + return 0.0f; + else + return -0.09105003625154495f; + else + if((val & 0b0001) == 1) // 010 + return -0.18477343022823334f; + else + return -0.28444138169288635f; + else + if((val & 0b0010) == 2) //00 + if((val & 0b0001) == 1) // 001 + return -0.39491748809814453f; + else + return -0.5250730514526367f; + else + if((val & 0b0001) == 1) // 000 + return -0.6961928009986877f; + else + return -1.0f; + +} + +__device__ unsigned char dQuantizeNF4(float x) +{ + + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if(x > 0.03979014977812767f) + if(x > 0.3893125355243683f) // 1 + if(x > 0.6427869200706482f) // 11 + if(x > 0.8614784181118011f) // 111 + return 0b1111; + else + return 0b1110; + else + if(x > 0.5016634166240692f) // 110 + return 0b1101; + else + return 0b1100; + else + if(x > 0.2035212516784668f) // 10 + if(x > 0.2920137718319893f) // 101 + return 0b1011; + else + return 0b1010; + else + if(x > 0.1202552504837513f) // 100 + return 0b1001; + else + return 0b1000; + else + if(x > -0.33967943489551544f) // 0 + if(x > -0.13791173323988914f) // 01 + if(x > -0.045525018125772476f) // 011 + return 0b0111; + else + return 0b0110; + else + if(x > -0.23460740596055984f) // 010 + return 0b0101; + else + return 0b0100; + else + if(x > -0.6106329262256622f) // 00 + if(x > -0.4599952697753906f) // 001 + return 0b0011; + else + return 0b0010; + else + if(x > -0.8480964004993439f) // 000 + return 0b0001; + else + return 0b0000; +} +// sign function for lion +// taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA + +template __device__ int sgn(T val) +{ + return (T(0) < val) - (val < T(0)); +} + +template +__device__ unsigned char dQuantize(float* smem_code, const float rand, float x) +{ + int pivot = 127; + int upper_pivot = 255; + int lower_pivot = 0; + + float lower = -1.0f; + float upper = 1.0f; + + float val = smem_code[pivot]; + // i>>=1 = {32, 16, 8, 4, 2, 1} + for(int i = 64; i > 0; i>>=1) + { + if(x > val) + { + lower_pivot = pivot; + lower = val; + pivot+=i; + } + else + { + upper_pivot = pivot; + upper = val; + pivot-=i; + } + val = smem_code[pivot]; + } + + if(upper_pivot == 255) + upper = smem_code[upper_pivot]; + if(lower_pivot == 0) + lower = smem_code[lower_pivot]; + + if(!STOCHASTIC) + { + if(x > val) + { + float midpoint = (upper+val)*0.5f; + if(x > midpoint) + { + return upper_pivot; + } + else + return pivot; + } + else + { + float midpoint = (lower+val)*0.5f; + if(x < midpoint) + return lower_pivot; + else + return pivot; + } + } + else + { + if(x > val) + { + float dist_to_upper = fabsf(upper-x); + float dist_full = upper-val; + if(rand >= dist_to_upper/dist_full) return upper_pivot; + else return pivot; + } + else + { + float dist_to_lower = fabsf(lower-x); + float dist_full = val-lower; + if(rand >= dist_to_lower/dist_full) return lower_pivot; + else return pivot; + } + } +} + +template +__device__ __forceinline__ unsigned char quantize_2D(float *__restrict__ quadrants, float *__restrict__ const smem_code, float x) +{ + int pivot = 127; + int upper_pivot = 255; + int lower_pivot = 0; + + float lower = SIGNED ? -1.0f : 0.0f; + float upper = 1.0f; + float midpoint; + float val = quadrants[1]; + int local_pivot = 1; + int offset = 1; + + // i>>=1 = {32, 16, 8, 4, 2, 1} + for(int i = 64; i > 0; i>>=1) + { + if(x > val) + { + lower_pivot = pivot; + lower = val; + pivot+=i; + //val = i == 64 ? 
quadrants[2] : smem_code[pivot]; + local_pivot += offset; + } + else + { + upper_pivot = pivot; + upper = val; + pivot-=i; + //val = i == 64 ? quadrants[0] : smem_code[pivot]; + local_pivot -= offset; + } + val = i >= 64 ? quadrants[local_pivot] : smem_code[pivot]; + offset -= 1; + } + + if(x > val) + { + midpoint = (upper+val)*0.5f; + if(x > midpoint) + return upper_pivot; + else + return pivot; + } + else + { + midpoint = (lower+val)*0.5f; + if(x < midpoint) + return lower_pivot; + else + return pivot; + } +} + +template +__device__ __forceinline__ unsigned char quantize_quadrant(int QUADRANT, float *__restrict__ const smem_code, float x, float lower, float midpoint, float upper) +{ + int lower_pivot = QUADRANT*16-1 - 0; + int pivot = QUADRANT*16-1 + 16; + int upper_pivot = QUADRANT*16-1 + 31; + + float val = midpoint; + + // i>>=1 = {32, 16, 8, 4, 2, 1} + for(int i = 16; i > 0; i>>=1) + { + if(x > val) + { + lower_pivot = pivot; + lower = val; + pivot+=i; + } + else + { + upper_pivot = pivot; + upper = val; + pivot-=i; + } + val = smem_code[pivot]; + } + + if(x > val) + { + midpoint = (upper+val)*0.5f; + if(x > midpoint) + return upper_pivot; + else + return pivot; + } + else + { + midpoint = (lower+val)*0.5f; + if(x < midpoint) + return lower_pivot; + else + return pivot; + } +} + +__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n) +{ + const int tid = threadIdx.x + (blockDim.x*blockIdx.x); + const int numThreads = blockDim.x*gridDim.x; + + for(int i = tid; i < n; i+=numThreads) + { + int idx = (index1[i]*maxidx1) + index2[i]; + atomicAdd(&histogram[idx], src[i]); + } +} + +template +__global__ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* out_idx, const int n) +{ + typedef hipcub::WarpReduce WarpReduce; + __shared__ typename WarpReduce::TempStorage temp_storage; + typedef hipcub::BlockLoad LoadT; + __shared__ typename LoadT::TempStorage loadt; + + const int warp_idx = threadIdx.x/32; + const int valid_items = n - (blockIdx.x*BLOCK_SIZE) > BLOCK_SIZE ? BLOCK_SIZE : n - (blockIdx.x*BLOCK_SIZE); + + // BLOCK_SIZE/32 == number of warps + __shared__ int smem_max_indices[8*BLOCK_SIZE/32]; + __shared__ float smem_max_values[8*BLOCK_SIZE/32]; + + T values[8]; + T max1 = -64000.0f; + T max2 = -64000.0f; + int max_idx1 = -1; + int max_idx2 = -1; + int sign1 = -1; + int sign2 = -1; + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + + LoadT(loadt).Load(&(A[(blockIdx.x*BLOCK_SIZE)]), values, valid_items, (T)0.0f); + #pragma unroll 8 + for(int i = 0; i < 8; i++) + { + T absval = fabsf(values[i]); + if(absval > max1) + { + max1 = values[i]; + sign1 = signbit(values[i]); + max_idx1 = 8*threadIdx.x + i; + } + else if(absval > max2) + { + max2 = values[i]; + sign2 = signbit(values[i]); + max_idx2 = 8*threadIdx.x + i; + } + } + + float warp_max; + for(int i = 0; i < 8; i++) + { + // 3. do warp reduction + broadcast back + warp_max = WarpReduce(temp_storage).Reduce(max1, hipcub::Max()); + warp_max = hipcub::ShuffleIndex<32>(warp_max, 0, 0xffffffff); + + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + if(warp_max == max1) + { + smem_max_values[warp_idx*8 + i] = sign1 != 0 ? 
-max1 : max1; + smem_max_indices[warp_idx*8 + i] = max_idx1; + + sign1 = sign2; + max1 = max2; + max_idx1 = max_idx2; + + max2 = -64000.0f; + } + //__syncwarp(); + } + + if(threadIdx.x % 32 < 8) + { + // offset: 8 values per 256 input values + // + int offset = BLOCK_SIZE*blockIdx.x*BLOCK_SIZE/32*8; + } + +} + +#define THREADS_ESTIMATE 512 +#define NUM_ESTIMATE 8 +#define BLOCK_ESTIMATE 4096 + +template +__launch_bounds__(THREADS_ESTIMATE, 1) +__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n) +{ + const int n_full = (BLOCK_ESTIMATE*(n/BLOCK_ESTIMATE)) + (n % BLOCK_ESTIMATE == 0 ? 0 : BLOCK_ESTIMATE); + int valid_items = (blockIdx.x+1 == gridDim.x) ? n - (blockIdx.x*BLOCK_ESTIMATE) : BLOCK_ESTIMATE; + const int base_idx = (blockIdx.x * BLOCK_ESTIMATE); + const float reciprocal_num_blocks = 1.0f/(n < 4096 ? 1.0f : (n/BLOCK_ESTIMATE)); + + T vals[NUM_ESTIMATE]; + + typedef hipcub::BlockRadixSort BlockRadixSort; + typedef hipcub::BlockLoad LoadFloat; + + __shared__ union { + typename LoadFloat::TempStorage loadf; + typename BlockRadixSort::TempStorage sort; + int smem_qidx[BLOCK_ESTIMATE]; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_ESTIMATE) + { + valid_items = n - i > BLOCK_ESTIMATE ? BLOCK_ESTIMATE : n - i; + + // do not process half-blocks + if(valid_items < BLOCK_ESTIMATE && n > BLOCK_ESTIMATE){ continue; } + + #pragma unroll 4 + for(int j = 0; j < NUM_ESTIMATE; j++) + vals[j] = max_val; + + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(A[i]), vals, valid_items); + + #pragma unroll 4 + for(int j = 0; j < NUM_ESTIMATE; j++) + vals[j] = ((float)vals[j]) * reciprocal_num_blocks; + + + __syncthreads(); + // sort into striped pattern to mitigate bank conflicts + // striped pattern index for thread 0 [0, 1024, 2048, 3096] + // striped pattern index for thread 1 [1, 1025, 2049, 3097] + BlockRadixSort(temp_storage.sort).SortBlockedToStriped(vals); + + __syncthreads(); + for(int j = threadIdx.x; j < BLOCK_ESTIMATE; j+=blockDim.x) + temp_storage.smem_qidx[j] = -1; + + __syncthreads(); + + if(threadIdx.x < 256) + { + float q_interval = (1.0f-(2.0f*offset))/255.0f; + int local_idx = round(((offset+(threadIdx.x*q_interval))*(valid_items-1))); + temp_storage.smem_qidx[local_idx] = threadIdx.x; + } + + __syncthreads(); + + for(int i = threadIdx.x; i < BLOCK_ESTIMATE; i+=blockDim.x) + { + if(temp_storage.smem_qidx[i] != -1) + atomicAdd(&code[temp_storage.smem_qidx[i]], vals[i/THREADS_ESTIMATE]); + } + } +} + + +__launch_bounds__(TH, 4) +__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n) +{ + const int n_full = (NUM_BLOCK*(n/NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK); + int valid_items = (blockIdx.x+1 == gridDim.x) ? 
n - (blockIdx.x*NUM_BLOCK) : NUM_BLOCK; + const int base_idx = (blockIdx.x * NUM_BLOCK); + + float vals[NUM]; + unsigned char qvals[NUM]; + //const int lane_id = threadIdx.x % 2; + + typedef hipcub::BlockLoad LoadFloat; + typedef hipcub::BlockStore StoreChar; + + __shared__ typename LoadFloat::TempStorage loadf; + __shared__ typename StoreChar::TempStorage storec; + __shared__ float smem_code[256]; + //__shared__ float smem_code[2][257]; + + if(threadIdx.x < 256) + { + smem_code[threadIdx.x] = code[threadIdx.x]; + //smem_code[0][threadIdx.x] = code[threadIdx.x]; + //smem_code[1][threadIdx.x] = smem_code[0][threadIdx.x]; + } + + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_BLOCK) + { + // number of values already processed in blocks + + // number of values already processed in this block + + // rand_offset % mod value + valid_items = n - i > NUM_BLOCK ? NUM_BLOCK : n - i; + + __syncthreads(); + LoadFloat(loadf).Load(&(A[i]), vals, valid_items); + + + #pragma unroll 4 + for(int j = 0; j < NUM; j++) + qvals[j] = dQuantize<0>(smem_code, 0.0f, vals[j]); + + __syncthreads(); + StoreChar(storec).Store(&(out[i]), qvals, valid_items); + } +} + +template +//__launch_bounds__(TH, 4) +__global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n) +{ + const int n_full = gridDim.x * BLOCK_SIZE; + int valid_items = 0; + const int base_idx = (blockIdx.x * BLOCK_SIZE); + + T vals[NUM_PER_TH]; + float rand_vals[NUM_PER_TH]; + unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH]; + //float local_abs_max = -FLT_MAX; + float local_abs_max = 0.0f; + int local_rand_idx = 0; + + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockStore 0) ? NUM_PER_TH/2 : NUM_PER_TH, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar; + typedef hipcub::BlockReduce BlockReduce; + typedef hipcub::BlockLoad LoadFloat; + + __shared__ typename LoadT::TempStorage loadt; + __shared__ typename LoadFloat::TempStorage loadf; + __shared__ typename StoreChar::TempStorage storec; + __shared__ typename BlockReduce::TempStorage reduce; + __shared__ float smem_code[256]; + __shared__ float smem_absmax_value[1]; + + if(DATA_TYPE == General8bit) + for(int i = threadIdx.x; i < 256; i+=blockDim.x) + smem_code[i] = code[i]; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i; + local_abs_max = -FLT_MAX; + + __syncthreads(); + LoadT(loadt).Load(&(A[i]), vals, valid_items, (T)0.0f); + + // 1. compute local max + // 2. broadcast local max + // 3. 
normalize inputs and quantize + + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j])); + + local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, hipcub::Max(), valid_items); + + if(threadIdx.x == 0) + smem_absmax_value[0] = local_abs_max; + + __syncthreads(); + + if(threadIdx.x == 0) + absmax[i/BLOCK_SIZE] = local_abs_max; + else + local_abs_max = smem_absmax_value[0]; + + //__syncwarp(); + + local_abs_max = 1.0f/local_abs_max; + + if(STOCHASTIC) + { + local_rand_idx = ((blockIdx.x*NUM_BLOCK) + (threadIdx.x*NUM) + rand_offset) % (1024-4); + LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0); + } + + unsigned char packed_4bit = 0; + switch(DATA_TYPE) + { + case General8bit: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + if(!STOCHASTIC) + qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max); + else + qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max); + } + break; + case FP4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH/2; j++) + { + packed_4bit |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4; + packed_4bit |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max); + qvals[j] = packed_4bit; + } + break; + case NF4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH/2; j++) + { + packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4; + packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max); + qvals[j] = packed_4bit; + } + break; + } + + __syncthreads(); + StoreChar(storec).Store(&(out[(DATA_TYPE > 0) ? i/2 : i]), qvals, (DATA_TYPE > 0) ? (valid_items+1)/2 : valid_items); + } +} + +template +__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n) +{ + + const int n_load = (gridDim.x * TILE_SIZE); + int valid_items_load = 0; + int valid_items_store = 0; + const int base_idx = (blockIdx.x * TILE_SIZE); + + T vals[NUM_PER_TH*((DATA_TYPE > 0) ? 2 : 1)]; + unsigned char qvals[NUM_PER_TH]; + float local_abs_max = -FLT_MAX; + + typedef hipcub::BlockLoad LoadChar; + typedef hipcub::BlockStore 0) ? 2 : 1), hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreT; + + __shared__ typename LoadChar::TempStorage loadchar; + __shared__ typename StoreT::TempStorage storet; + + for (unsigned int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE) + { + if(DATA_TYPE > 0) + { + valid_items_load = (n+1)/2 - i > TILE_SIZE ? TILE_SIZE : (n+1)/2 - i; + valid_items_store = n - i*2 > TILE_SIZE*2 ? TILE_SIZE*2 : n - i*2; + } + else + { + valid_items_load = n - i > TILE_SIZE ? TILE_SIZE : n - i; + valid_items_store = n - i > TILE_SIZE ? 
TILE_SIZE : n - i; + } + local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH)/(blocksize)]); + + __syncthreads(); + LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128); + + switch(DATA_TYPE) + { + case General8bit: + // load code through read-only cache via __ldg + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + vals[j] = __ldg(&code[qvals[j]])*local_abs_max; + break; + case FP4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + vals[j*2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max); + vals[j*2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max); + } + break; + case NF4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + vals[j*2] = dDequantizeNF4(qvals[j] >> 4)* local_abs_max; + vals[j*2 + 1] = dDequantizeNF4(qvals[j] & 0x0F)* local_abs_max; + } + break; + } + + __syncthreads(); + StoreT(storet).Store(&(out[(DATA_TYPE > 0) ? i*2 : i]), vals, valid_items_store); + } +} + +__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n) +{ + const unsigned int numThreads = blockDim.x * gridDim.x; + const int idx = (blockIdx.x * blockDim.x) + threadIdx.x; + + __shared__ float smem_code[256]; + if(threadIdx.x < 256) + { + smem_code[threadIdx.x] = code[threadIdx.x]; + } + + __syncthreads(); + + for (int i = idx;i < n; i += numThreads) + { + out[i] = smem_code[A[i]]; + } +} + + + +template +__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1) +__global__ void kPreconditionOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n) +{ + + const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); + const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS); + int valid_items = 0; + + T g_vals[NUM_VALS]; + + float s1_vals[NUM_VALS]; + float s2_vals[NUM_VALS]; + + const float correction1 = 1.0f/(1.0f - powf(beta1, step)); + const float correction2 = 1.0f/(1.0f - powf(beta2, step)); + + typedef hipcub::BlockLoad Load; + typedef hipcub::BlockLoad LoadFloat; + typedef hipcub::BlockReduce BlockReduce; + + __shared__ union { + typename Load::TempStorage load; + typename LoadFloat::TempStorage loadf; + typename BlockReduce::TempStorage reduce; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + valid_items = n - i >= (BLOCK_SIZE) ? 
(BLOCK_SIZE) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items, 0.0f); + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + g_vals[j] = gnorm_scale*((float)g_vals[j]); + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + { + switch(OPTIMIZER) + { + case ADAM: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j])); + s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j]))); + s1_vals[j] *= correction1; + s2_vals[j] *= correction2; + s1_vals[j] = s1_vals[j]/(sqrtf(s2_vals[j])+eps); // update + s1_vals[j] *= s1_vals[j]; // update l2 norm (update*update) + break; + } + } + + # pragma unroll NUM_VALS-1 + for(unsigned int j = 1; j < NUM_VALS; j++) + s1_vals[0] += s1_vals[j]; + + __syncthreads(); + s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0]); + + if(threadIdx.x == 0) + atomicAdd(&unorm[0], s1_vals[0]); + + //__syncwarp(); + } +} + + + +#define NUM_PER_THREAD 4 + +template +__launch_bounds__(TH, 1) +__global__ void kOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n) +{ + + const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); + int valid_items = 0; + float update_scale = 0.0f; + T g_vals[NUM_PER_THREAD]; + T p_vals[NUM_PER_THREAD]; + + float s1_vals[NUM_PER_THREAD]; + float s2_vals[NUM_PER_THREAD]; + + const float correction1 = 1.0f - powf(beta1, step); + const float correction2 = sqrtf(1.0f - powf(beta2, step)); + const float step_size = -lr*correction2/correction1; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; + if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + typedef hipcub::BlockLoad Load; + typedef hipcub::BlockStore Store; + + typedef hipcub::BlockLoad LoadFloat; + typedef hipcub::BlockStore StoreFloat; + + __shared__ union { + typename Load::TempStorage load; + typename Store::TempStorage store; + typename LoadFloat::TempStorage loadf; + typename StoreFloat::TempStorage storef; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*TH*NUM_PER_THREAD) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? 
(TH*NUM_PER_THREAD) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items); + __syncthreads(); + Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + g_vals[j] = gnorm_scale*((float)g_vals[j]); + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + { + switch(OPTIMIZER) + { + case ADAM: + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j])); + s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j]))); + p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2)))); + + if(weight_decay > 0.0f) + p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); + } + break; + } + } + + __syncthreads(); + Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); + __syncthreads(); + StoreFloat(temp_storage.storef).Store(&(state2[i]), s2_vals, valid_items); + } +} + +template +__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1) +__global__ void kPreconditionOptimizer32bit1State(T* g, T* p, + float* state1, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n) +{ + + const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); + const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS); + int valid_items = 0; + + T g_vals[NUM_VALS]; + + float s1_vals[NUM_VALS]; + + typedef hipcub::BlockLoad Load; + typedef hipcub::BlockLoad LoadFloat; + typedef hipcub::BlockReduce BlockReduce; + + __shared__ union { + typename Load::TempStorage load; + typename LoadFloat::TempStorage loadf; + typename BlockReduce::TempStorage reduce; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + valid_items = n - i >= (BLOCK_SIZE) ? 
(BLOCK_SIZE) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + g_vals[j] = gnorm_scale*((float)g_vals[j]); + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + { + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = (float)g_vals[j]; // state update + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); // state update + s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm + break; + case LION: + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*(float)g_vals[j]); // state update + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update + s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value + s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm + break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); // state update + s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value + s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm + break; + } + } + + # pragma unroll + for(unsigned int j = 1; j < NUM_VALS; j++) + s1_vals[0] += s1_vals[j]; + + __syncthreads(); + s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0], valid_items); + + if(threadIdx.x == 0) + atomicAdd(&unorm[0], s1_vals[0]); + + //__syncwarp(); + } +} + +template +__launch_bounds__(TH, 1) +__global__ void kOptimizer32bit1State(T *g, T *p, + float *state1, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n) +{ + + const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); + int valid_items = 0; + float update_scale = 0.0f; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; + if(update_scale > max_unorm*param_norm+eps){ update_scale = (max_unorm*param_norm+eps)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + T g_vals[NUM_PER_THREAD]; + T p_vals[NUM_PER_THREAD]; + + float s1_vals[NUM_PER_THREAD]; + + typedef hipcub::BlockLoad Load; + typedef hipcub::BlockStore Store; + + typedef hipcub::BlockLoad LoadFloat; + typedef hipcub::BlockStore StoreFloat; + + __shared__ union { + typename Load::TempStorage load; + typename Store::TempStorage store; + typename LoadFloat::TempStorage loadf; + typename StoreFloat::TempStorage storef; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*TH*NUM_PER_THREAD) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? 
(TH*NUM_PER_THREAD) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); + __syncthreads(); + Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + { + g_vals[j] = gnorm_scale*((float)g_vals[j]); + if(weight_decay > 0.0f) + g_vals[j] = (float)g_vals[j] + (((float)p_vals[j])*weight_decay); + } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + { + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = (float)g_vals[j]; + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); + + p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j])); + break; + case LION: + p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j])))); + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*((float)g_vals[j])); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); + p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps)); + break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); + p_vals[j] = ((float)p_vals[j]) - lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps); + break; + } + } + } + + __syncthreads(); + Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); + } +} + + +#define NUM8BIT 16 +#define NUM_THREADS 256 +#define NUM_PER_BLOCK 4096 + +template +__global__ void +__launch_bounds__(NUM_THREADS, 2) +kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + const float gnorm_scale, const int n) +{ + const int n_full = gridDim.x * NUM_PER_BLOCK; + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); + int valid_items = n - (blockIdx.x*NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x*NUM_PER_BLOCK); + float g_val = 0.0f; + float local_max_s1 = -FLT_MAX; + float local_max_s2 = -FLT_MAX; + float local_unorm = 0.0f; + + float s2_vals[NUM8BIT]; + float s1_vals[NUM8BIT]; + T g_vals[NUM8BIT]; + unsigned char m_c1[NUM8BIT]; + unsigned char r_c2[NUM8BIT]; + + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadUInt8; + typedef hipcub::BlockReduce BlockReduce; + + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadUInt8::TempStorage loadc; + typename BlockReduce::TempStorage reduce; + } temp_storage; + + __shared__ float smem_quantiles1[256]; + __shared__ float smem_quantiles2[256]; + + if(threadIdx.x < 256) + { + smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; + smem_quantiles2[threadIdx.x] = quantiles2[threadIdx.x]; + } + + __syncthreads(); + + for (unsigned int i = base_idx; i < n_full; i += NUM_THREADS*gridDim.x*NUM8BIT) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? 
(TH*NUM_PER_THREAD) : n - i; + + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128); + __syncthreads(); + LoadUInt8(temp_storage.loadc).Load(&(state2[i]), r_c2, valid_items, 128); + __syncthreads(); + + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + g_val = g_vals[j]; + g_val *= gnorm_scale; + s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0]*beta1; + s1_vals[j] += (1.0f-beta1)*g_val; + local_max_s1 = fmaxf(local_max_s1, fabsf(s1_vals[j])); + } + + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + g_val = g_vals[j]; + g_val *= gnorm_scale; + s2_vals[j] = smem_quantiles2[r_c2[j]]*max2[0]*beta2; + s2_vals[j] += (1.0f-beta2)*g_val*g_val; + local_max_s2 = fmaxf(local_max_s2, fabsf(s2_vals[j])); + } + + if(unorm != NULL) + { + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + float correction1 = __fdividef(1.0f, 1.0f - powf(beta1, step)); + float correction2 = __fdividef(1.0f, 1.0f - powf(beta2, step)); + s1_vals[j] *= correction1; + s2_vals[j] *= correction2; + float update_val = s1_vals[j]/(sqrtf(s2_vals[j])+eps); // update + local_unorm += update_val*update_val; + } + } + } + + __syncthreads(); + local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, hipcub::Max(), valid_items); + __syncthreads(); + local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, hipcub::Max(), valid_items); + if(unorm != NULL) + { + __syncthreads(); + local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, hipcub::Sum(), valid_items); + } + + if(threadIdx.x == 0) + { + atomicMax(&new_max1[0], local_max_s1); + atomicMax(&new_max2[0], local_max_s2); + if(unorm != NULL){ atomicAdd(&unorm[0], local_unorm); } + } +} + +#define NUM_PER_THREAD2 4 +#define NUM_THREADS2 1024 +#define NUM_PER_BLOCK2 4096 + +template +__global__ void +__launch_bounds__(NUM_THREADS2, 1) +kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2, + const float *unorm, const float max_unorm, const float param_norm, \ + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, + const float gnorm_scale, const int n) +{ + + const int n_full = (blockDim.x * gridDim.x)*NUM_PER_THREAD2; + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[NUM_PER_THREAD2]; + float s2_vals[NUM_PER_THREAD2]; + const float correction1 = 1.0f - powf(beta1, step); + const float correction2 = sqrtf(1.0f - powf(beta2, step)); + const float step_size = -lr*correction2/correction1; + //const float step_size = -lr*correction2/correction1; + float new_max_val1 = 1.0f/new_max1[0]; + float new_max_val2 = 1.0f/new_max2[0]; + float update_scale = 1.0f; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? 
sqrtf(unorm[0]) : 1.0f; + if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + unsigned char c1s[NUM_PER_THREAD2]; + unsigned char c2s[NUM_PER_THREAD2]; + T p_vals[NUM_PER_THREAD2]; + T g_vals[NUM_PER_THREAD2]; + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadChar; + + typedef hipcub::BlockStore StoreChar; + typedef hipcub::BlockStore StoreT; + + __shared__ float smem_quantiles1[256]; + __shared__ float smem_quantiles2[256]; + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + } temp_storage; + + if(threadIdx.x < 512) + { + if(threadIdx.x < 256) + smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; + else + smem_quantiles2[threadIdx.x-256] = quantiles2[threadIdx.x-256]; + } + + __syncthreads(); + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS2*NUM_PER_THREAD2) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0); + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); + + if((i + (threadIdx.x*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) + { + g_val = float(g_vals[j]); + g_val *= gnorm_scale; + s1_vals[j] = smem_quantiles1[c1s[j]]; + s1_vals[j] = s1_vals[j]*max1[0]; + + s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val)); + + c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1); + + // make sure state1 term has still the same sign after quantization + // (not needed for state2 term which has only positive values) + if(signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + + s2_vals[j] = smem_quantiles2[c2s[j]]; + s2_vals[j] = s2_vals[j]*max2[0]; + s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); + c2s[j] = dQuantize<0>(smem_quantiles2, 0.0f, s2_vals[j]*new_max_val2); + } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) + { + p_vals[j] = (T)(((float)p_vals[j]) + ((update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(correction2*eps)))))); + if(weight_decay > 0.0f) + p_vals[j] = update_scale*((float)p_vals[j])*(1.0f-(lr*weight_decay)); + } + + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); + __syncthreads(); + } +} + + +template +__global__ void +__launch_bounds__(NUM_THREADS, 2) +kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + const float weight_decay, + const float gnorm_scale, const int n) +{ + const int n_full = gridDim.x * NUM_PER_BLOCK; + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); + int valid_items = n - (blockIdx.x*NUM_PER_BLOCK) > 
NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x*NUM_PER_BLOCK); + float g_val = 0.0f; + float local_max_s1 = -FLT_MAX; + float local_unorm = 0.0f; + + float s1_vals[NUM8BIT]; + T g_vals[NUM8BIT]; + unsigned char m_c1[NUM8BIT]; + + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadUInt8; + typedef hipcub::BlockReduce BlockReduce; + + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadUInt8::TempStorage loadc; + typename BlockReduce::TempStorage reduce; + } temp_storage; + + __shared__ float smem_quantiles1[256]; + + if(threadIdx.x < 256) + smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; + + __syncthreads(); + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS*NUM8BIT) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128); + + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + g_val = g_vals[j]; + g_val *= gnorm_scale; + s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0]; + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = (float)g_vals[j]; + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); + if(unorm != NULL) + local_unorm += s1_vals[j]*s1_vals[j]; + break; + case LION: + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); + break; + } + + local_max_s1 = fmaxf(local_max_s1, fabsf(s1_vals[j])); + } + } + + __syncthreads(); + local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, hipcub::Max(), valid_items); + if(threadIdx.x == 0){ atomicMax(&new_max1[0], local_max_s1); } + if(unorm != NULL) + { + __syncthreads(); + local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, hipcub::Sum(), valid_items); + if(threadIdx.x == 0){ atomicAdd(&unorm[0], local_unorm); } + } + +} + +template +__global__ void +__launch_bounds__(1024, 1) +kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, + const float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + float weight_decay, + const float gnorm_scale, const int n) +{ + + const int n_full = (blockDim.x * gridDim.x)*NUM_PER_THREAD2; + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[NUM_PER_THREAD2]; + float new_max_val1 = 1.0f/new_max1[0]; + float update_scale = 1.0f; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? 
sqrtf(unorm[0]) : 1.0f; + if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + unsigned char c1s[NUM_PER_THREAD2]; + T p_vals[NUM_PER_THREAD2]; + T g_vals[NUM_PER_THREAD2]; + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadChar; + + typedef hipcub::BlockStore StoreChar; + typedef hipcub::BlockStore StoreT; + + __shared__ float smem_quantiles1[256]; + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + } temp_storage; + + if(threadIdx.x < 256) + smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; + + __syncthreads(); + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS2*NUM_PER_THREAD2) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); + + if((i + (threadIdx.x*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) + { + g_val = float(g_vals[j]); + g_val *= gnorm_scale; + + if(weight_decay > 0.0f) { + switch(OPTIMIZER) { + case MOMENTUM: + case RMSPROP: + g_val += ((float)p_vals[j])*weight_decay; + break; + case LION: + p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay); + break; + } + } + + s1_vals[j] = smem_quantiles1[c1s[j]]*max1[0]; + + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = g_vals[j]; + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); + + p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j])); + break; + case LION: + p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val)))); + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); + p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps)); + break; + } + + c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1); + + // make sure state1 term has still the same sign after quantization + if(signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + } + + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + __syncthreads(); + } +} + + +template +__global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n) +{ + const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); + int valid_items = 0; + + typedef hipcub::BlockReduce BlockReduce; + typedef hipcub::BlockLoad LoadT; + + __shared__ typename BlockReduce::TempStorage reduce; + + __shared__ typename LoadT::TempStorage loadT; + T vals[NUM_VALS]; + float local_sum = 0.0f; + + for (unsigned int i = (blockIdx.x * BLOCK_SIZE); i < n_full; i += gridDim.x*BLOCK_SIZE) + { + valid_items = n - i > BLOCK_SIZE ? 
BLOCK_SIZE : n - i; + local_sum = 0.0f; + + __syncthreads(); + LoadT(loadT).Load(&(g[i]), vals, valid_items, (T)0.0f); + + #pragma unroll NUM_VALS + for(int j = 0; j < NUM_VALS; j++) + local_sum += ((float)vals[j])*((float)vals[j]); + + local_sum = BlockReduce(reduce).Sum(local_sum, valid_items); + if(threadIdx.x == 0) + { + if(step == 1) + { + // initialize with the same norm for all positions + //#pragma unroll 10 + for(int j = 0; j < 100; j++) + atomicAdd(&gnorm_vec[j], local_sum); + } + else + atomicAdd(&gnorm_vec[step % 100], local_sum); + } + + } +} + + +#define LANES 2 +#define QUAD 3 +template +__launch_bounds__(256, 3) +__global__ void +kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* absmax1, float* absmax2, + float weight_decay, + const float gnorm_scale, const bool skip_zeros, const int n) +{ + + //const int n_full = n + (n%BLOCK_SIZE); + const int n_full = gridDim.x * BLOCK_SIZE; + const int base_idx = (blockIdx.x * BLOCK_SIZE); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[N_PER_TH]; + float s2_vals[N_PER_TH]; + // 2-5% + const float correction1 = 1.0f - __powf(beta1, step); + const float correction2 = sqrtf(1.0f -__powf(beta2, step)); + const float step_size = __fdividef(-lr*correction2,correction1); + const int lane_id = threadIdx.x % LANES; + float new_local_abs_max1 = -FLT_MAX; + float new_local_abs_max2 = -FLT_MAX; + float quadrants1[QUAD]; + float quadrants2[QUAD]; + + unsigned char c1s[N_PER_TH]; + unsigned char c2s[N_PER_TH]; + T g_vals[N_PER_TH]; + T p_vals[N_PER_TH]; + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadChar; + + typedef hipcub::BlockStore StoreChar; + typedef hipcub::BlockStore StoreT; + + __shared__ float smem_quantiles1[LANES][257]; + __shared__ float smem_quantiles2[LANES][257]; + typedef hipcub::BlockReduce BlockReduce1; + typedef hipcub::BlockReduce BlockReduce2; + __shared__ typename BlockReduce1::TempStorage reduce1; + __shared__ typename BlockReduce2::TempStorage reduce2; + __shared__ float smem_exchange1[1]; + __shared__ float smem_exchange2[1]; + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + } temp_storage; + // init: 0.2 -> 0.23 + + // 0.23 -> 0.23 + smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x]; + smem_quantiles2[0][threadIdx.x] = quantiles2[threadIdx.x]; + # pragma unroll + for(unsigned int j = 1; j < LANES; j++) + { + smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x]; + smem_quantiles2[j][threadIdx.x] = smem_quantiles2[0][threadIdx.x]; + } + + __syncthreads(); + + #pragma unroll + for(int k = 0; k < QUAD; k++) + { + quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; + quadrants2[k] = smem_quantiles2[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; + } + + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + // loads: 0.23 -> 0.85/1.44 + valid_items = n - i >= BLOCK_SIZE ? 
BLOCK_SIZE : n - i; + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0); + + new_local_abs_max1 = -FLT_MAX; + new_local_abs_max2 = -FLT_MAX; + + // update: 2.48/1.57 -> 2.51/1.60 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) + { + s2_vals[j] = smem_quantiles2[lane_id][c2s[j]]*absmax2[i/BLOCK_SIZE]; + g_val = g_vals[j]; + //float ratio = (g_val*g_val)/fmaxf(s2_vals[j], eps*eps); + //g_val = ratio > 2.0f ? 2.0f*g_val/ratio : g_val; + g_val *= gnorm_scale; + + s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); + + s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; + s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val)); + } + else + { + s1_vals[j] = 0.0f; + s2_vals[j] = 0.0f; + } + + new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j])); + new_local_abs_max2 = fmaxf(new_local_abs_max2, fabsf(s2_vals[j])); + } + + + // reduce: 2.51/1.60 -> 2.67/1.69 + new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, hipcub::Max()); + new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, hipcub::Max()); + + if(threadIdx.x == 0) + { + smem_exchange1[0] = new_local_abs_max1; + smem_exchange2[0] = new_local_abs_max2; + } + + __syncthreads(); + + if(threadIdx.x == 0) + { + absmax1[i/BLOCK_SIZE] = new_local_abs_max1; + absmax2[i/BLOCK_SIZE] = new_local_abs_max2; + } + else + { + new_local_abs_max1 = smem_exchange1[0]; + new_local_abs_max2 = smem_exchange2[0]; + } + + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); + // reduce: 2.67/1.69 -> 2.67/1.70 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + //if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) + { + p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps))))))); + if(weight_decay > 0.0f) + p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); + } + } + + // store: 0.85/1.44 -> 2.48/1.57 + __syncthreads(); + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + + // quantizaztion: 2.67/1.70 -> 3.4/3.3 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1)); + c2s[j] = quantize_2D<0>(quadrants2, smem_quantiles2[lane_id], __fdividef(s2_vals[j],new_local_abs_max2)); + + // make sure state1 term has still the same sign after quantization + // (not needed for state2 term which has only positive values) + if(signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + } + + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); + } +} + + +#define LANES 2 +#define QUAD 3 +template +__launch_bounds__(256, 3) +__global__ void +kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* 
absmax1, + float weight_decay, + const float gnorm_scale, const bool skip_zeros, const int n) +{ + + //const int n_full = n + (n%BLOCK_SIZE); + const int n_full = gridDim.x * BLOCK_SIZE; + const int base_idx = (blockIdx.x * BLOCK_SIZE); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[N_PER_TH]; + // 2-5% + const int lane_id = threadIdx.x % LANES; + float new_local_abs_max1 = -FLT_MAX; + float quadrants1[QUAD]; + + unsigned char c1s[N_PER_TH]; + T g_vals[N_PER_TH]; + T p_vals[N_PER_TH]; + + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadChar; + + typedef hipcub::BlockStore StoreChar; + typedef hipcub::BlockStore StoreT; + + __shared__ float smem_quantiles1[LANES][257]; + typedef hipcub::BlockReduce BlockReduce1; + __shared__ typename BlockReduce1::TempStorage reduce1; + __shared__ float smem_exchange1[1]; + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + } temp_storage; + // init: 0.2 -> 0.23 + + // 0.23 -> 0.23 + smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x]; + # pragma unroll + for(unsigned int j = 1; j < LANES; j++) + smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x]; + + __syncthreads(); + + #pragma unroll + for(int k = 0; k < QUAD; k++) + quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + // loads: 0.23 -> 0.85/1.44 + valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i; + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); + + new_local_abs_max1 = -FLT_MAX; + + // update: 2.48/1.57 -> 2.51/1.60 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + g_val = float(g_vals[j]); + g_val *= gnorm_scale; + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + if(weight_decay > 0.0f) { + switch(OPTIMIZER) { + case MOMENTUM: + case ADAGRAD: + case RMSPROP: + g_val += ((float)p_vals[j])*weight_decay; + break; + case LION: + p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay); + break; + } + } + + s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; + + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = g_val; + else + s1_vals[j] = (s1_vals[j]*beta1) + g_val; + break; + case LION: + // here, using gvals[j] to store the gradient smoothed by beta1 for the following parameter update, before the momentum is updated by beta2 + g_vals[j] = lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*g_val)); + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); + break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + (g_val*g_val); + break; + } + } + + new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j])); + } + + + // reduce: 2.51/1.60 -> 2.67/1.69 + new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, hipcub::Max()); + + if(threadIdx.x == 0) + smem_exchange1[0] = new_local_abs_max1; + + __syncthreads(); + + if(threadIdx.x == 0) + absmax1[i/BLOCK_SIZE] = new_local_abs_max1; + else + new_local_abs_max1 = smem_exchange1[0]; + + // reduce: 2.67/1.69 -> 2.67/1.70 + # pragma unroll N_PER_TH + for(unsigned int j = 
0; j < N_PER_TH; j++) + { + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + switch(OPTIMIZER) + { + case MOMENTUM: + p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]); + break; + case LION: + p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]); + break; + case RMSPROP: + g_val = g_vals[j]; + p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); + break; + case ADAGRAD: + g_val = g_vals[j]; + p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); + break; + } + } + } + + // store: 0.85/1.44 -> 2.48/1.57 + __syncthreads(); + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + + // quantizaztion: 2.67/1.70 -> 3.4/3.3 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1)); + + // make sure state1 term has still the same sign after quantization + // (not needed for state2 term which has only positive values) + if(signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + } + + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + } +} + +template __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols) +{ + // 0. reset stats to -FLT_MAX + // 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD) + // 2. compute col max (per thread); store in smem due to register pressure + // 3. compute row max (per block); store in smem to accumulate full global mem transation + // 4. store data via atomicMax + + // each block loads TILE_COLs columns and TILE_ROW rows + // after reading a tile the row counter increase by TILE_ROWS + // the col counter reset after reading TILE_COL elements + const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS; + // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached + const int base_col = (blockIdx.x*TILE_COLS) % tiledCols; + const int base_idx = (base_row*cols) + base_col; + const int items_per_load = ITEMS_PER_THREAD*THREADS; + + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockReduce BlockRowReduce; + typedef hipcub::BlockReduce BlockRowSum; + typedef hipcub::BlockExchange BlockExchange; + + __shared__ union { + typename BlockExchange::TempStorage exchange; + typename BlockRowReduce::TempStorage rowreduce; + typename BlockRowSum::TempStorage rowsum; + typename LoadT::TempStorage loadt; + } temp_storage; + + __shared__ float smem_row_absmax_values[ITEMS_PER_THREAD*THREADS]; + __shared__ int smem_row_nnz_values[TILE_ROWS]; + + half local_data[ITEMS_PER_THREAD]; + float local_data_fp32[ITEMS_PER_THREAD]; + float local_col_absmax_values[ITEMS_PER_THREAD]; + int local_row_nnz_count = 0; + float row_absmax = -FLT_MAX; + + // 0. 
reset stats to -FLT_MAX + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + //smem_col_absmax_values[threadIdx.x + (j*THREADS)] = -FLT_MAX; + smem_row_absmax_values[threadIdx.x + (j*THREADS)] = -FLT_MAX; + // smem_row_nnz_values[threadIdx.x + (j*THREADS)] = 0; + } + + #pragma unroll TILE_ROWS + for (int j = 0; j < TILE_ROWS; j++) { + smem_row_nnz_values[j] = 0; + } + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_col_absmax_values[j] = -FLT_MAX; + + __syncthreads(); + + int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col; + int i = base_idx; + // we load row after row from the base_position + // 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD) + for(int row = 0; row < TILE_ROWS; row++) + { + if(base_row+row >= rows){ break; } + local_row_nnz_count = 0; + i = base_idx + ((row)*cols); + // each thread gets data from the same column + __syncthreads(); + LoadT(temp_storage.loadt).Load(&(A[i]), local_data, valid_items, __float2half(0.0f)); + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_data[j] = fabsf(local_data[j]); + + + if(SPARSE_DECOMP) + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + if((float)local_data[j] >= nnz_threshold) + { + local_row_nnz_count += 1; + local_data[j] = 0.0f; + } + } + + // 2. compute col max (per thread); store in smem due to register pressure + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + // take the col max for this row + // we use shared memory because register pressure is too high if we do this locally + //smem_col_absmax_values[threadIdx.x + (j*THREADS)] = fmaxf(smem_col_absmax_values[threadIdx.x + (j*THREADS)], __half2float(local_data[j])); + local_col_absmax_values[j] = fmaxf(local_col_absmax_values[j], __half2float(local_data[j])); + + // 3. compute row max (per block); store in smem to accumulate full global mem transation + + // this is slow as it uses extra registers, but we need this to be compatible with Kepler and Maxwell (no fp16 units) + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_data_fp32[j] = local_data[j]; + + __syncthreads(); + + row_absmax = (float)BlockRowReduce(temp_storage.rowreduce).Reduce(local_data_fp32, hipcub::Max()); + if(SPARSE_DECOMP) + { + __syncthreads(); + local_row_nnz_count = BlockRowSum(temp_storage.rowsum).Sum(local_row_nnz_count); + } + // we store the data temporarily in shared memory so we + // can execute a full atomic block transaction into global memory later + // we use a striped arrangement [0, 8, 16, 24, ..] for t0 for faster stores + if(threadIdx.x == 0) + { + smem_row_absmax_values[(row % ITEMS_PER_THREAD) + ((row/ITEMS_PER_THREAD)*ITEMS_PER_THREAD)] = row_absmax; + // each blockIdx.x process 16 rows and 64*4=256 columns -> we sum nnz over 256 columns and have 16 values per block + smem_row_nnz_values[row] = local_row_nnz_count; + } + + __syncthreads(); + + } + + // 4. store data via atomicMax + // to store col data efficienctly we need to rewrite the smem blocked data [0, 1, 2, 3...] for t0 + // into a striped arangement: [0, 8, 16, 24, ..] 
for t0
+ __syncthreads();
+ BlockExchange(temp_storage.exchange).BlockedToStriped(local_col_absmax_values);
+
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ if(base_col+threadIdx.x+(j*THREADS) < cols)
+ {
+ float val = colStats[base_col+(threadIdx.x+(j*THREADS))];
+ if(val < local_col_absmax_values[j])
+ atomicMax(&colStats[base_col+(threadIdx.x+(j*THREADS))], local_col_absmax_values[j]);
+ }
+
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ if(base_row+threadIdx.x+(j*THREADS) < rows)
+ {
+ float val = rowStats[base_row+(threadIdx.x+(j*THREADS))];
+ if(val < smem_row_absmax_values[threadIdx.x+(j*THREADS)])
+ atomicMax(&rowStats[base_row+(threadIdx.x+(j*THREADS))], smem_row_absmax_values[threadIdx.x+(j*THREADS)]);
+ }
+
+ if(SPARSE_DECOMP)
+ if(threadIdx.x < TILE_ROWS)
+ nnz_count_row[blockIdx.x*TILE_ROWS+threadIdx.x+1] = smem_row_nnz_values[threadIdx.x];
+
+}
+
+template __global__ void kgetColRowStats(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
+template __global__ void kgetColRowStats(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
+
+#define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f)
+
+template __global__ void kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, half *__restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n)
+{
+
+ // Strategy: To dequantize we need to load col/row statistics. This can be very expensive
+ // since different row/col stats need to be loaded with each thread.
+ // (1, bad algorithm) Loading 32 items per thread would only incur 1 row load, but this increases register pressure
+ // and would lead to low global load utilization.
+ // (2, bad algorithm) If each thread loads some columns and multiple rows one needs to do a lot of row loads
+ // for each thread and this is duplicated by a factor of 32/num-cols-per-thread.
+ // (3, good algorithm) Combining (1) and (2) we use sub-tiles of size 32xk in shared memory per threadblock.
+ // This allows for efficient row/col loading from shared memory within the tile.
+ // We can run for example 32x128 sub-tiles and warp-strided loads of 4 elements so that each thread has
+ // the same col statistic but needs to load 4 row stats from shared memory. To prevent bank conflicts
+ // we use a block-striped shared memory config [1, 31, 63, 95] so no bank conflicts happen during the
+ // shared memory loads.
+
+ // data is in 32 column-tile major with tile width 32 columns and numRows rows
+ // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory.
+ // L2. Load data in warp-striped arrangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3])
+ // C1. Compute val(row_stat*col_stat)/(127*127) (load 1/(127*127) into register)
+ // C2. Compute normalization values and store col values in register
+ // S1. Store C1 into 16-bit output
+ // S2. Store col/row statistics of new buffer in shared memory
+
+ // We allow for sub-tiles to span multiple col32 tiles. This is okay
+ // since the items per thread only rely on a single column statistic.
+
+
+ const int n_out = numRows*numCols;
+
+ //int num_row_tiles = (numRows/SUBTILE_ROWS) + (numRows % SUBTILE_ROWS == 0 ?
0 : 1); + // we have tiles of size numRows*32, thus col only increases every numRows + // num_row_tiles is the tiles after which the column increases by 32 + // blockIdx.x is the index of the current tile + //int col = ((threadIdx.x % 32) + ((blockIdx.x/num_row_tiles)*32)); + // base_row increases by SUBTILE_ROWS every block. It wraps back to zero once num_row_tiles is reached + //int base_row = (blockIdx.x*SUBTILE_ROWS) % (num_row_tiles*SUBTILE_ROWS); + + int block_offset = blockIdx.x * THREADS * ITEMS_PER_THREAD; + int thread_offset = threadIdx.x * ITEMS_PER_THREAD; + + // SUBTILE_ROWS is independent from ITEMS_PER_THREAD is independent from THREADS + // subtiles have 32*SUBTILE_ROWS elements <= THREADS*ITEMS_PER_THREAD + // Total subtiles should be n/(32*SUBTILE_ROWS) where each subtile has SUBTILE_ROW*32/4 threads. + // For example for a 1024x1024 matrix with 128 SUBTILE_ROWS and 4 ITEMS_PER_THREAD we have + // 1024*1024/(128*32) = 256 tiles + // 256 tiles are 256*128*32/4 = 256*1024 threads + + // 1. Figure out how index relates to the start of the sub-tile + // 2. Each thread < SUBTILE_ROWS calculates row index + // 3. Load striped and store in shared memory + + int local_values[ITEMS_PER_THREAD]; + half local_output[ITEMS_PER_THREAD]; + //float local_rowStats[ITEMS_PER_THREAD]; + //__shared__ float smem_rowStats[SUBTILE_ROWS]; + + typedef hipcub::BlockLoad LoadInt32; + //typedef hipcub::BlockExchange ExchangeInt32; + __shared__ typename LoadInt32::TempStorage loadint32; + //__shared__ typename ExchangeInt32::TempStorage exchangeint32; + + + // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory. + //float colStat = col >= numCols ? 0.0f : colStats[col]; + //float local_biasValue = ((bias == NULL) || (col >= numCols)) ? 0.0f : __half2float(bias[col]); +int row_idx, col_idx; + float colStat[ITEMS_PER_THREAD]; + float local_biasValue[ITEMS_PER_THREAD]; + float rowStat[ITEMS_PER_THREAD]; + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + row_idx = (block_offset + thread_offset + j) / numCols; + col_idx = (block_offset + thread_offset + j) % numCols; + colStat[j] = col_idx >= numCols ? 0.0f : colStats[col_idx]; + local_biasValue[j] = ((bias == NULL) || (col_idx >= numCols)) ? 0.0f : __half2float(bias[col_idx]); + rowStat[j] = row_idx >= numRows ? 0.0f : rowStats[row_idx]; + } + // no block loads for rows for now -- keep it simple + /*for(int j = threadIdx.x; j < SUBTILE_ROWS; j+=blockDim.x) + { + // todo: is this global mem access slow due to overlaps or does the L1 cache work well here? + int row = (base_row+j) % numRows; // wrap around + // each warp accesses the same element, for four consequitive elements + // todo: update description about striped shared memory, it is not needed + // rowidx: [0, 1, 2, 3...] and each warp reads ITEMS_PER_THREAD consequitive elements + smem_rowStats[j] = rowStats[row]; + }*/ + __syncthreads(); + +int valid_items = block_offset + THREADS * ITEMS_PER_THREAD < n_out ? 
THREADS * ITEMS_PER_THREAD : n_out - block_offset; + LoadInt32(loadint32).Load(&(A[block_offset]), local_values, valid_items, 0); + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_output[j] = __float2half((local_values[j]*MM_DEQUANT_CONST*rowStat[j]*colStat[j]) + local_biasValue[j]); + + // each block processes SUBTILE_ROWS*32 elements + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + int outIdx = block_offset + thread_offset + j; + if(outIdx< n_out) + out[outIdx] = local_output[j]; + } + /*const int items_per_load = THREADS*ITEMS_PER_THREAD; + const int rows_per_load = items_per_load/32; + + int subtile_base_row = (threadIdx.x / 32)*ITEMS_PER_THREAD; // row within the tile + int row_offset = 0; + // subtile_idx starts at the base_row*32 + the total offset for a full numRow*32 tile is passed + int subtile_start = (blockIdx.x/num_row_tiles)*(numRows*32) + (base_row*32); + for(int subtile_idx = subtile_start; subtile_idx < subtile_start + (SUBTILE_ROWS*32); subtile_idx+=items_per_load) + { + int valid_rows = numRows - (base_row+row_offset) > rows_per_load ? rows_per_load : numRows - (base_row+row_offset); + int valid_items = valid_rows*32; + if(valid_items <= 0) // the sub-tile might have more elements than the tile itself + break; + + // L2. Load data in warp-striped arangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) + LoadInt32(loadint32).Load(&(A[subtile_idx]), local_values, valid_items, 0); + ExchangeInt32(exchangeint32).BlockedToWarpStriped(local_values, local_values); + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_rowStats[j] = smem_rowStats[subtile_base_row+row_offset+j]; + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_output[j] = __float2half((local_values[j]*MM_DEQUANT_CONST*local_rowStats[j]*colStat) + local_biasValue); + //absmax_col = fmax(fabsf(local_output[j]), absmax_col); + + // we store data in row major + // to store data efficiently, we want to use block exchange: [0, 32, 64, 92] -> [0, 1, 2, 3] + // so that each thread holds ITEMS_PER_THREAD consecutive items for each row + // this way throughput into storage is increased by a factor of ~2x + // for now we use a simple store + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + int outIdx = col + ((base_row+subtile_base_row+row_offset+j)*numCols); + if(outIdx< n_out && col < numCols) + out[outIdx] = local_output[j]; + } + + row_offset += rows_per_load; + }*/ +} + + +template __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols) +{ + // assumes TILE_SIZE == THREADS*ITEMS_PER_THREAD + // Each thread reads the same column but multiple rows + // Rows are loaded in shared memory and access is shared across the threadblock (broadcast) + + // 0. Load row stats data into shared memory; load col stat (1 fixed per thread) + // 1. Load data row by row (should be at least with TILE_SIZE = 512) + // 2. quantize data with row/col stats + // 3. 
Store data (TILE_SIZE = 512 is a bit slow, but should still be close enough to good performance) + + // each block loads TILE_COLs columns and TILE_ROW rows + // after reading a tile the row counter increase by TILE_ROWS + // the col counter reset after reading TILE_COL elements + const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS; + // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached + const int base_col = (blockIdx.x*TILE_COLS) % tiledCols; + const int base_idx = (base_row*cols) + base_col; + const int items_per_load = ITEMS_PER_THREAD*THREADS; + + typedef hipcub::BlockLoad LoadHalf; + __shared__ typename LoadHalf::TempStorage loadhalf; + typedef hipcub::BlockStore StoreInt8; + __shared__ typename StoreInt8::TempStorage storeint8; + + __shared__ float smem_row_stats[TILE_ROWS]; + __shared__ unsigned int smem_nnz_row_idx[TILE_ROWS]; + + half local_data[ITEMS_PER_THREAD]; + float local_col_stats[ITEMS_PER_THREAD]; + char local_quantized_data[ITEMS_PER_THREAD]; + + // 0. Load row stats data into shared memory; load col stat (1 fixed per thread) + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + if(base_col+(threadIdx.x*ITEMS_PER_THREAD) + j < cols) + local_col_stats[j] = __fdividef(127.0f, colStats[base_col+(threadIdx.x*ITEMS_PER_THREAD)+j]); + + for(int i = threadIdx.x; i < TILE_ROWS; i+=blockDim.x) + { + if(base_row + i < rows) + smem_row_stats[i] = rowStats[base_row+i]; + + if(SPARSE_DECOMP) + smem_nnz_row_idx[i] = nnz_block_ptr[(TILE_ROWS*blockIdx.x) + i]; + } + __syncthreads(); + + // we load row after row from the base_position + // 1. Load data row by row (should be at least with TILE_SIZE = 512) + for(int row = 0; row < TILE_ROWS; row++) + { + if(base_row + row >= rows){ break; } + int i = base_idx + (row*cols); + int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col; + + + LoadHalf(loadhalf).Load(&(A[i]), local_data, valid_items, 0.0f); + float row_stat = __fdividef(127.0f, smem_row_stats[row]); + + // 2. quantize data with row/col stats + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + // we already pre-normalized the col/row stat: + // what this does is float/absmax*127 = int8 + if(SPARSE_DECOMP) + { + if(fabsf((float)local_data[j]) >= threshold) + { + local_quantized_data[j] = 0; + + int old_idx = atomicInc(&smem_nnz_row_idx[row], UINT_MAX); + + rowidx[old_idx] = base_row+row; + colidx[old_idx] = base_col+(threadIdx.x*ITEMS_PER_THREAD)+j; + val[old_idx] = local_data[j]; + } + else + { + local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*row_stat)); + } + } + else + local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*row_stat)); + } + + StoreInt8(storeint8).Store(&(out_row_normed[i]), local_quantized_data, valid_items); + + // 2. quantize data with row/col stats + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + // we already pre-normalized the col/row stat: + // what this does is float/absmax*127 = int8 + local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*local_col_stats[j])); + } + + __syncthreads(); + StoreInt8(storeint8).Store(&(out_col_normed[i]), local_quantized_data, valid_items); + + } +} + +template __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols) +{ + + // 0. Load data into 32*32 shared memory tiles + // 1. transpose / reorder in shared memory + // 2. 
store + + // COL32 FORMAT: + // rows*32 tiles + + // TURING FORMAT: + // 8*32 tiles with 4*4 subtiles + // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements) + // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero + // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32]) + // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column + // index increases by 32 + + // AMPERE FORMAT: + // 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows: + // row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... + // the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32] + + + // To have efficient loads and stores if we transpose we need 128 consequitive bytes which at 1 byte are 128 values + // As such we need: + // at least 32*4 shared memory tiles for col32; preferably 32*32 + // at least 32*6 shared memory tiles for col32_ampere: preferably 32*32 + // at least 32*8 shared memory tiles for col4_turing: preferably 32*32 + // for efficient loading of row major we need to load 128 elements and repeat this 32 items + // this would imply a 32x128 shared memory tile -> 4kb + // It is more efficient to have more than 1 warp, so with 64 threads we need 32x128 -> 8 kb + // we have 64k sharded mem per SM in Turing which is 8 blocks per SM which is 2*8 = 32 warps = 100% occupancy + // for turing and 50% for A100 and 75% for RTX 30s / A40 which is probably good enough + // register pressure should be low with: 8 registers from local memoryh per block and 64 registers per SM + // + // to make the shared memory work with that occupancy we might need to union the block loads/stores + + // each block loads TILE_COLs columns and TILE_ROW rows + // after reading a tile the row counter increase by TILE_ROWS + // the col counter reset after reading TILE_COL elements + const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS; + // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached + const int base_col = (blockIdx.x*TILE_COLS) % tiledCols; + const int base_idx = (base_row*cols) + base_col; + + // we load 128 bytes per warp with + // 32 rows for transposes that fill col32 types + // so that we can have contiguous stores + __shared__ char smem_data[32*33*ITEMS_PER_THREAD]; + char local_data[ITEMS_PER_THREAD]; + typedef hipcub::BlockExchange BlockExchange; + + // we load row after row from the base_position + // Load data row by row + int warps = blockDim.x/32; + int warp_id = threadIdx.x/32; + int warp_lane = threadIdx.x % 32; + int offset = 0; + + int smem_row = 0; + // each warp loads one row of 128 bytes + for(int row = warp_id; row < TILE_ROWS; row+=warps) + { + int i = base_idx + (row*cols); + // we load up to 128 bytes/items per load + int valid_items = cols - base_col > 32*ITEMS_PER_THREAD ? 32*ITEMS_PER_THREAD : cols - base_col; + + // 0. 
Load data into 32*32 shared memory tiles + if(base_row + row < rows) + { + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + int col_idx = warp_lane+(j*32); + if(col_idx < valid_items) + local_data[j] = A[i+col_idx]; + else + local_data[j] = 0; + } + } + else + { + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_data[j] = 0; + } + + if(TRANSPOSE) + { + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + int local_col = (32*j)+warp_lane; + //int local_row = row; + // store as 256x32 + smem_data[(local_col*33) + row] = local_data[j]; + } + } + else + { + // treat smem as 32x256, that is 32 rows and 256 columns + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + smem_data[row*32*ITEMS_PER_THREAD + (warp_lane) + (j*32)] = local_data[j]; + } + + + + smem_row += warps; + + // 1. transpose / reorder in shared memory + if(smem_row % 32 == 0) + { + smem_row = 0; + __syncthreads(); + + for(int subrow = warp_id; subrow < 32; subrow+=warps) + { + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + + switch(FORMAT) + { + case COL32: + if(TRANSPOSE) + { + // data lies in shared memory in the following way: + // row0 [col0 col1 ... col31] + // row1 [col0 col1 ... col31] + // ... + // + // As such we read consequtive entries with 256 threads (8rows x 32 columns) + // as j increase, the row increase by a factor of 8 + // We load 8 rows per subrow loop, and subrow increase by 8 per loop + // so we have an offset of 8 rows every loop or (subrow/warps)*8 = (subrow/8)*8 + const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j + const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps) + //const int local_row = warp_id; // each warp_id is one row + //const int block_row = base_col; // block offset for row + //const int local_col = warp_lane + //const int global_col = base_row; // block offset for col + if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows)) + { + // each row hae 32 columns and is offset by 1 to prevent bank conflict during storage into smem + char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane]; + + // each 32 columns we have new tile + // each tile has size outRows*32 and base_row is done in increments of 32 + offset = base_row*outRows; + out[offset + (base_col + jrow + subrow_loop_row)*32 + threadIdx.x] = data; + } + } + else + { + if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols)) + { + offset = (base_col/32)*(32*rows); + char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane]; + out[offset+(base_row+subrow)*32 + ((j)*rows*32)+warp_lane] = data; + } + } + break; + case COL_TURING: + // TURING FORMAT: + // 8*32 tiles with 4*4 subtiles + // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements) + // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero + // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32]) + // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column + // index increases by 32 + // + // [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...] 
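+ //
+ // Worked example (illustrative only, following the offset computation in the branches below):
+ // within one 8*32 tile, even rows occupy offsets 0..127 and odd rows occupy offsets 128..255,
+ // with columns grouped in fours (each group of 4 cols x 4 rows covers 16 elements). So element
+ // (row 2, col 5) lands at 0 + (5/4)*16 + (5%4) + ((2%8)*2) = 21, while element (row 3, col 5)
+ // lands at 128 + (5/4)*16 + (5%4) + (((3%8)-1)*2) = 149.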
+ if(TRANSPOSE) + { + const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j + const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps) + //const int local_row = warp_id; // each warp_id is one row + //const int block_row = base_col; // block offset for row + //const int local_col = warp_lane + //const int global_col = base_row; // block offset for col + if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows)) + { + // each row hae 32 columns and is offset by 1 to prevent bank conflict during storage into smem + char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane]; + + // each 32 columns we have new tile + // each tile has size 8*32 = 256 elements offset + // for each row offset of 8 we increaes the tile first + // after all rows are exhausted, we increase the col + int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/8)*256; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows + + // we increase by row_tile_column every 32 columns + // base_row increase in increments of 32 + //int row_tile_column = 256*outRows/8; // there are outRows/8 row tiles, and each tile is 256 elements + //int col_offset = (base_row/32)*row_tile_column; + // -> we can remove the divisions to speed up compute since outRows is always a multiple of 8 + // 256*outRows/8*base_row/32 = outRows*base_row + int col_offset = outRows*base_row; + + offset = row_offset+col_offset; + + // since we process even number of rows with each j (8) and with each subrow (8j) we can determine + // odd or even rows with the warp_id (each warp processes one row) + // the col is warp_lane (max 32 columns per row) and the row warp_id + if(warp_id % 2 == 1) + // odd + offset += 128 + (warp_lane/4)*16 + (warp_lane%4) + (((warp_id%8)-1)*2); + else + // even + offset += 0 + (warp_lane/4)*16 + (warp_lane%4) + ((warp_id%8)*2); + + out[offset] = data; + } + } + else + { + if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols)) + { + char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane]; + // set offset designates the tile offset among the 8*32 tiles + // we first increase rows and then columns. Since we load 128 columns at once + // we increase the offset by outRows*32 every 32 columns + // additionally, we increase the offset by 8*32=256 every 8 rows + offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/8)*256); // global offset (8x32 tile) + // first 4 rows are reserved for even rows, [0, 2, 4, 6], the next 4 for odd + // each of these has 32 values in total for 32*4 = 128 as offset if odd + // every set of 4 columns increases the total offset by 16 + // each even row increase the offset by 4, for example row 2 is offset by 4, 4 by 6 etc so: subrow/2*4 = subrow*2 + // this happends every 8 rows anew (subrow % 8) + // one writes 4 columns at once that is (col % 4) for the particular index in the subtile + int subcol = warp_lane; + + // add local offset (4x4 sub-tile) + if(subrow % 2 == 1) + // odd + offset += 128 + (subcol/4)*16 + (subcol%4) + (((subrow%8)-1)*2); + else + // even + offset += 0 + (subcol/4)*16 + (subcol%4) + ((subrow%8)*2); + + out[offset] = data; + } + } + break; + case COL_AMPERE: + // AMPERE FORMAT: + // 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows: + // row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... 
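+ // (for illustration, continuing that pattern, the four 8-row groups of a 32-row tile land at positions
+ // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27] [4 5 12 13 20 21 28 29] [6 7 14 15 22 23 30 31])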
+ // the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32] + if(TRANSPOSE) + { + const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j + const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps) + //const int local_row = warp_id; // each warp_id is one row + //const int block_row = base_col; // block offset for row + //const int local_col = warp_lane + //const int global_col = base_row; // block offset for col + if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows)) + { + // each row hae 32 columns and is offset by 1 to prevent bank conflict during storage into smem + char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane]; + + // each 32 columns we have new tile + // each tile has size 32*32 = 1024 elements offset + // for each row offset of 32 we increaes the tile first + // after all rows are exhausted, we increase the col + int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/32)*1024; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows + + // we increase by row_tile_column every 32 columns + // base_row increase in increments of 32 + //int row_tile_column = 1024*outRows/32; // there are outRows/32 row tiles, and each tile is 1024 elements + //int col_offset = (base_row/32)*row_tile_column; + // -> we can remove the divisions to speed up compute since outRows is always a multiple of 8 + // 1024*outRows/32*base_row/32 = outRows*base_row + int col_offset = outRows*base_row; + + offset = row_offset+col_offset; + + + // same as in the non-transpose case (see below) + // the difference is that now rows = cols + // in this case warp_id = subrow + + // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... + // subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc + // subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row + // every 2 rows, the offset increases by two [0, 1, 8, 9...] + // every 2 rows, the row index increase by 8 [0, 1, 8, 9...] + int local_row = (jrow + warp_id) % 32; // offset for row > 32 is already calculated into row_offset + int ampere_row = ((local_row % 8)/2)*8 + (local_row/8)*2 + (local_row % 2); + + // global offset + row with 32 cols each + 32 cols per j + col_idx=warp_lane + out[offset + (ampere_row*32) + warp_lane] = data; + } + } + else + { + if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols)) + { + char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane]; + + // set offset designates the tile offset among the 32*32 tiles + // we first increase rows and then columns. Since we load 128 columns at once + // we increase the offset by outRows*32 every 32 columns + // additionally, we increase the offset by 32*32=1024 every 32 rows + offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/32)*1024); // global offset (32x32 tile) + + // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... + // subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc + // subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row + // every 2 rows, the offset increases by two [0, 1, 8, 9...] + // every 2 rows, the row index increase by 8 [0, 1, 8, 9...] 
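+ // Worked check of the mapping computed below (illustrative only): subrows 0..9 map to
+ // local_row 0, 1, 8, 9, 16, 17, 24, 25, 2, 3, i.e. consecutive pairs of logical rows are
+ // placed 8 apart, reproducing the [0 1 8 9 16 17 24 25] [2 3 10 11 ...] ordering above.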
+ int local_row = ((subrow % 8)/2)*8 + (subrow/8)*2 + (subrow % 2); + + // global offset + row with 32 cols each + 32 cols per j + col_idx + out[offset + (local_row*32) + warp_lane] = data; + } + } + break; + } + } + } + } + } +} + +#define DENORM 1.0f/127.0f +#define MAX_SPARSE_COUNT 32 +#define SMEM_SIZE 8*256 +template +__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB) +{ + + // 0. load balancing: We process rows with most columns first (count_vec)and we process one row per block + // If a block finishes, the next one is scheduled. Since the last blocks like have fewer + // elements they finish faster "fillin up" the gaps left by larger blocks + + // without tensor cores + // 1. use rowidx_length to find what to load (as many blocks as there are rows) + // 2. Load A into registers + // 3. each warp loads all required rows of B but each warp is offset by k + // 4. Do mma operations that accumulate into registers + // 5. Each warp stores its output row into matrix C + + const int count = max_count[blockIdx.x]; + const int local_max_idx = max_idx[blockIdx.x]; + const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx-1]; + const int local_row_idx = rowidx[offset]; + + const int warp_id = threadIdx.x / 32; + const int warp_idx = threadIdx.x % 32; + const int warp_offset = (warp_id*32)*SPMM_ITEMS; + const int num_items = BITS == 8 ? 8 : 8; + int idx_col_B = warp_offset; + int local_idx_col_B_offset = 0; + + half local_valA[MAX_SPARSE_COUNT]; + int local_colidxA[MAX_SPARSE_COUNT]; + half local_valC[SPMM_ITEMS]; + T local_valsB[num_items]; + half local_valOut[num_items]; + // 128 byte loads per warp == 4 bytes per thread + + // 2. Load A into registers + for(int j = 0; j < MAX_SPARSE_COUNT; j++) + { + local_valA[j] = j < count ? values[offset+j] : __float2half(0.0f); + local_colidxA[j] = j < count ? colidx[offset+j] : 0; + } + + // each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=x192 + // we expect each warp to be SPMM_ITEMS*32 apart + // we have a total of 128 bytes for the bank with a bank size of 4 bytes + // added 3 bytes = 6 values between warps should reduce bank conflicts + __shared__ half smem_dequant_stats[SMEM_SIZE]; + + + while(idx_col_B < colsB) + { + + if(dequant_stats != NULL) + { + for(int i = threadIdx.x; i < SMEM_SIZE; i+=blockDim.x) + if((idx_col_B+i-local_idx_col_B_offset) < colsB) + smem_dequant_stats[i] = dequant_stats[idx_col_B+i-local_idx_col_B_offset]; + + __syncthreads(); + } + + #pragma unroll SPMM_ITEMS + for(int j = 0; j < SPMM_ITEMS; j++) + local_valC[j] = 0.0f; + + #pragma unroll + for(int i = 0; i < count; i++) + { + // 3. each warp loads all required rows of B but each warp is offset by k + int row_offset = colsB*local_colidxA[i]; + + #pragma unroll SPMM_ITEMS + for(int j = 0; j < SPMM_ITEMS; j+=num_items) + { + // 4. 
Multiply the tile -> accumulate outputs in shared memory until 128 bytes it reached + int idx = idx_col_B + (warp_idx*SPMM_ITEMS) + j; + if(idx >= colsB){ break; } + if((idx+num_items < colsB)) + { + if(BITS == 8) + reinterpret_cast(local_valsB)[0] = reinterpret_cast(B)[(row_offset+ idx)/num_items]; + else + reinterpret_cast(local_valsB)[0] = reinterpret_cast(B)[(row_offset+ idx)/num_items]; + } + else + { + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + if(idx+k < colsB) + local_valsB[k] = B[row_offset+idx+k]; + else + local_valsB[k] = 0.0f; + } + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + { + if(BITS == 8 && dequant_stats != NULL) + // we do texture cache reads (__ldg) on dequant_stats which should be super fast + { + float valB = local_valsB[k]; + float valA = local_valA[i]; + if(valB != 0.0 && valA != 0.0) + local_valC[j+k] = (float)local_valC[j+k] + ((float)smem_dequant_stats[idx+k-local_idx_col_B_offset])*DENORM*valB*valA; + } + else + local_valC[j+k] = (float)local_valC[j+k] + (float)local_valsB[k]*(float)local_valA[i]; + } + } + } + + int idx_row_C = (colsB*local_row_idx); + + #pragma unroll SPMM_ITEMS + for(int j = 0; j < SPMM_ITEMS; j+=num_items) + { + //int idx_col_C = idx_col_B + (32*j) + warp_idx; + int idx_col_C = idx_col_B + warp_idx*SPMM_ITEMS + j; + int idx_val = idx_col_C + idx_row_C; + + if(idx_col_C +num_items < colsB) + { + + // load outputs to do inplace addition + reinterpret_cast(local_valOut)[0] = reinterpret_cast(out)[idx_val/num_items]; + + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + local_valC[(j/num_items) + k] = (float)local_valC[(j/num_items) + k] + (float)local_valOut[k]; + + reinterpret_cast(out)[idx_val/num_items] = reinterpret_cast(local_valC)[j/num_items]; + } + else + { + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + if(idx_col_C + k < colsB) + out[idx_val+k] = (float)out[idx_val+k]+(float)local_valC[j+k]; + } + } + + idx_col_B += blockDim.x*SPMM_ITEMS; + local_idx_col_B_offset += blockDim.x*SPMM_ITEMS; + } +} + +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA) +{ + int local_colidx = idx[blockIdx.x]; + + /*if(FORMAT==COL_TURING) + { + // TURING FORMAT: + // 8*32 tiles with 4*4 subtiles + // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*8 = 128 elements) + // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero + // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32]) + // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column + // index increases by 32 + // columns are grouped in increments of 4, meaning that one has the following rows and columns + // rows: [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...] + // cols: [0 1 2 3, 0 1 2 4, 0 1 2 3, 0 1 2 3, 4 5 6 7 ...] 
+ + // each thread reads 1 element = 1 row + for(int row = threadIdx.x; row < rowsA; row+= blockDim.x) + { + int offset_per_col_tile = ((rowsA+7)/8)*32*8; + int tile_offset_rows = (row/8)*32*8; + int tile_offset_cols = (local_colidx/32)*offset_per_col_tile; + int offset = 0; + int subtile_col_idx = local_colidx%32; + int subtile_row_idx = row % 8; + if(row % 2 == 1) + offset += 128 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + ((subtile_row_idx-1)*2); + else + // even + offset += 0 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + (subtile_row_idx*2); + + offset += tile_offset_rows + tile_offset_cols; + + char val = A[offset]; + + int out_idx = (row*idx_size) + blockIdx.x; + out[out_idx] = val; + } + } + else if(FORMAT == COL_AMPERE) + { + + for(int row = threadIdx.x; row < rowsA; row+= blockDim.x) + { + // we got 32x32 tiles and we use the magic equation from the cublasLt doc to get the element + // within each tile. + int offset_per_col_tile = ((rowsA+31)/32)*32*32; + int tile_offset_rows = (row/32)*32*32; + int tile_offset_cols = (local_colidx/32)*offset_per_col_tile; + int subtile_col_idx = local_colidx%32; + int subtile_row_idx = row % 32; + // this magic is taken from the cublasLt doc (search for COL32) + int offset = (((subtile_row_idx%8)/2*4+subtile_row_idx/8)*2+subtile_row_idx%2)*32+subtile_col_idx; + offset += tile_offset_cols + tile_offset_rows; + + char val = A[offset]; + int out_idx = (row*idx_size) + blockIdx.x; + out[out_idx] = val; + } +}*/ + + //Only col format is used on ROCm + for(int row = threadIdx.x; row < rowsA; row+= blockDim.x) + { + //col-major offset + int offset = local_colidx * rowsA + row; + + char val = A[offset]; + int out_idx = (row*idx_size) + blockIdx.x; + out[out_idx] = val; + } +} + + +//template __global__ void kMatmul_inference_4bit(INPT *A, unsigned char *B, OUTT *out, int lda, int ldb, int rowsA, int colsA, int colsB) +//{ +//// element-wise kernel +//// 1. Load batch x k into registers +//// 2. Load k x k into registers +//// 3. dequantize and store in second pair of k x k +//// 4. matmul +//// 5. sum with hipcub +//// 6. store outputs +//// TC kernel +//// use k warps per thread block +//// 1. threadblock use read-only cache to read in register tile for A into shared memory +//// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments +//// 3. each warp reads a segment of values 16x32 from B +//// 4. do dequantization from register of B into second pair of registers +//// 5. store (4) into fragment +//// 6. matmul aggregate into fragment C +//// 7. aggreecate files of C into shared memroy block C +//// 8. sum (7) +//// 9. 
write outputs to matmul output matrix +//} + +template __device__ inline void vector_load(T *local, T * __restrict__ const buffer, int idx, int limit_base, int limit, float zero_value = 0.0f) +{ + if(limit_base + ITEMS <= limit) + reinterpret_cast(local)[0] = reinterpret_cast(buffer)[idx/ITEMS]; + else + { + for(int k = 0; k < ITEMS; k++) + { + if(limit_base + k < limit) + local[k] = buffer[idx+k]; + else + local[k] = (T)zero_value; + } + } +} + +#define WARPS 3 +template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) +{ + +#if __CUDA_ARCH__ >= 750 + using namespace nvcuda; + int col_offset = blockIdx.x *32; + const int warp_id = threadIdx.x / 32; + const int half_warp_id = threadIdx.x / 16; + const int half_warp_lane = threadIdx.x % 16; + const int batch_size_warps = (WARPS-1)*2; + const int val_per_iter = blockDim.x-32; + + T local_A[4]; + T local_B[128]; + + const int a_tile_offset = 16; + const int b_tile_offset = (16*32 + 16); + + __shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))]; + __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; + //__shared__ T smem_C[8*32]; + + rocwmma::fragment a_frag; + rocwmma::fragment b_frag; + rocwmma::fragment c_frag; + rocwmma::fill_fragment(c_frag, 0.0f); + + int ticktock = 0; + int idx = 0 + threadIdx.x; + int loaded_values = 0; + // prefetch + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+(1*val_per_iter)]; + local_A[2] = A[idx+(2*val_per_iter)]; + local_A[3] = A[idx+(3*val_per_iter)]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B[col] = B[(col_offset+col)*ldb+idx]; + local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; + local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; + local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; + } + loaded_values = 3; + } + else + { + + if(loaded_values == 3) + { + local_A[0] = local_A[1]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(32)]; + } + else if(loaded_values == 2) + { + local_A[0] = local_A[2]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(64)]; + } + else + { + local_A[0] = local_A[3]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(96)]; + } + loaded_values--; + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 
1 : 0; + + //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + { + idx = base_idx + threadIdx.x; + + __syncthreads(); + if(idx < K && warp_id < (WARPS-1)) + { + //local_A[0] = A[idx]; + + //#pragma unroll 32 + //for(int col = 0; col < 32; col++) + // local_B[col] = B[(col_offset+col)*ldb+idx]; + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+(1*val_per_iter)]; + local_A[2] = A[idx+(2*val_per_iter)]; + local_A[3] = A[idx+(3*val_per_iter)]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B[col] = B[(col_offset+col)*ldb+idx]; + local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; + local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; + local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; + } + loaded_values = 3; + + } + else + { + + if(loaded_values == 3) + { + local_A[0] = local_A[1]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(32)]; + } + else if(loaded_values == 2) + { + local_A[0] = local_A[2]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(64)]; + } + else + { + local_A[0] = local_A[3]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(96)]; + } + loaded_values--; + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + + if(warp_id == (WARPS-1)) + for(int k = 0; k < batch_size_warps; k++) + { + rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + } + + __syncthreads(); + if(warp_id != (WARPS-1)){ return; } + // only warp_id == (WARPS-1) from here + int warp_lane = threadIdx.x % 32; + + ticktock = ticktock == 0 ? 
1 : 0; + for(int k = 0; k < batch_size_warps; k++) + { + rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + + // 129 mu + if(warp_id == (WARPS-1)) + rocwmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, rocwmma::mem_row_major); + + if(col_offset + warp_lane < M) + out[col_offset + warp_lane] = smem_A[warp_lane]; +#endif +} + + +template __device__ void printnonzero(T *A, int num_values, const char * strval) +{ + for(int i = 0; i < num_values; i++) + if((float)A[i] != 0.0) + printf("%s %i %f\n", strval, i, (float)A[i]); +} + +template __device__ void printnonzero(float *A, int num_values, const char*strval); +template __device__ void printnonzero(half *A, int num_values, const char*strval); + +__device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0}; +template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) +{ + +#if __CUDA_ARCH__ >= 750 + using namespace nvcuda; + int col_offset = blockIdx.x *32; + const int warp_id = threadIdx.x / 32; + const int warp_idx = threadIdx.x % 32; + const int half_warp_id = threadIdx.x / 16; + const int half_warp_lane = threadIdx.x % 16; + const int batch_size_warps = (WARPS-1)*2; + + T quant_map[16]; + + #pragma unroll 16 + for(int i = 0; i < 16; i++) + quant_map[i] = nf4_data[i]; + //__shared__ T quant_map[16*160]; + + T local_A[2]; + T local_B[64]; + unsigned char local_B_4bit[32]; + + + const int a_tile_offset = 16; + const int b_tile_offset = (16*32 + 16); + + __shared__ T smem_A[8*16 + (16*(batch_size_warps-1))]; + __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; + __shared__ T smem_C[8*32]; + + rocwmma::fragment a_frag; + rocwmma::fragment b_frag; + rocwmma::fragment c_frag; + rocwmma::fill_fragment(c_frag, 0.0f); + + for(int i = threadIdx.x; i < (8*32); i+=blockDim.x) + smem_C[i] = 0.0f; + + __syncthreads(); + + int ticktock = 0; + int idx = 0 + threadIdx.x; + int loaded_values = 0; + // prefetch + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+blockDim.x-32]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B_4bit[col] = B[(col_offset+col)*ldb+idx]; + + loaded_values = 1; + } + else + { + local_A[0] = local_A[1]; + loaded_values--; + + #pragma unroll 64 + for(int col = 0; col < 64; col+=2) + { + //local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f); + //local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f); + //local_B[col] = d2DequantizeFP4(local_B_4bit[col/2] >> 4)*(float)(17.0); + //local_B[col+1] = d2DequantizeFP4(local_B_4bit[col/2] & 0x0F)*(float)(17.0); + //local_B[col] = 127*(local_B_4bit[col/2] >> 4)*(float)(17.0); + //local_B[col+1] = 127*(local_B_4bit[col/2] & 0x0F)*(float)(17.0); + + //local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(17.0); + //local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(17.0); + local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(17.0); + 
local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(17.0); + } + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + //if(threadIdx.x == 0) + //printf("aa %i %i\n", idx, loaded_values); + + //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + { + idx = base_idx + threadIdx.x; + //if(threadIdx.x == 0) + //printf("%i %i\n", idx, loaded_values); + + //__syncthreads(); + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+blockDim.x-32]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B_4bit[col] = B[(col_offset+col)*ldb+idx]; + local_B_4bit[col+16] = B[(col_offset+col)*ldb+idx]; + } + + loaded_values = 1; + } + else + { + local_A[0] = local_A[1]; + loaded_values--; + + int absidx = (idx + col_offset)/blocksize; + half local_absmax = __ldg(&(absmax[absidx])); + + #pragma unroll 64 + for(int col = 0; col < 64; col+=2) + { + //local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx); + //local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx); + //local_B[col] = T(127)*T(local_B_4bit[col/2] >> 4)*T(absidx); + //local_B[col+1] = T(127)*T(local_B_4bit[col/2] & 0x0F)*T(absidx); + + //local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(local_absmax); + //local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(local_absmax); + local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(absidx); + local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(absidx); + } + //printnonzero(local_B, 128, ""); + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 
1 : 0; + + if(warp_id == (WARPS-1)) + for(int k = 0; k < batch_size_warps; k++) + { + rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + } + + __syncthreads(); + //if(threadIdx.x == 0) + //{ + // printnonzero(smem_A, 8*16 + (2*16*(batch_size_warps-1)), "A: "); + // printnonzero(smem_B, 2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1)), "B: "); + //} + if(warp_id != (WARPS-1)){ return; } + // only warp_id == (WARPS-1) from here + int warp_lane = threadIdx.x % 32; + + ticktock = ticktock == 0 ? 1 : 0; + for(int k = 0; k < batch_size_warps; k++) + { + //if(warp_lane == 0) + //printf("%i %i %i %i\n", (ticktock*batch_size_warps + k)*a_tile_offset, k, ticktock, threadIdx.x); + rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + + // 129 mu + if(warp_id == (WARPS-1)) + rocwmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, rocwmma::mem_row_major); + + //printnonzero(smem_C, 32, ""); + + if(col_offset + warp_lane < M) + out[col_offset + warp_lane] = smem_C[warp_lane]; +#endif +} + +#define num_values_4bit 32 +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + // per threadblock: + // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps] + // 4 warps -> 4 loads per iter + // 1x32 * 32x4 -> 1x4 outputs per thread block + typedef hipcub::WarpReduce WarpReduce; + __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32]; + + const int warp_idx = threadIdx.x / 32; + const int warp_lane = threadIdx.x % 32; + const int row_B = (THREADS/32)*blockIdx.x + warp_idx; + const int num_values_8bit = num_values_4bit/2; + float local_C = 0.0f; + + unsigned char local_B_4bit[num_values_8bit]; + T local_B[num_values_4bit/4]; + T local_A[num_values_4bit/4]; + __shared__ T quant_map[16]; + T local_absmax = T(0.0f); + + for(int i = threadIdx.x; i < 16; i++) + quant_map[i] = T(datatype[i]); + __syncthreads(); + + // A: [1, K] + // B: [N, K] + for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += 32*num_values_4bit) + { + int inner_idx_halved = inner_idx/2; + int offset_B = ldb*row_B; + int absidx = ((2*offset_B)+inner_idx)/blocksize; + local_absmax = __ldg(&(absmax[absidx])); + + if(row_B < M) + { + if((inner_idx_halved + num_values_8bit) < (K/2)) + { + // this is the most important for performance considerations + reinterpret_cast(local_B_4bit)[0] = reinterpret_cast(B)[(offset_B+(inner_idx_halved))/(num_values_8bit)]; + } + else + { + #pragma unroll + for(int j = 0; j < (num_values_8bit); j++) + if((inner_idx_halved) + j < (K/2)) + local_B_4bit[j] = B[offset_B+inner_idx_halved + j]; + else + local_B_4bit[j] = 0b01110111; + } + } + else + { + #pragma unroll + for(int j = 0; j < (num_values_8bit); j++) + local_B_4bit[j] = 0b01110111; + } + + for(int i = 0; i < 4; i++) + { + #pragma unroll + for(int k = 0; k < num_values_8bit/4; k++) + { + #if __CUDA_ARCH__ >= 800 + local_B[k*2] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*local_absmax; + local_B[k*2 + 1] = 
quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*local_absmax; + #else + // bf16 multipliation not supported + local_B[k*2] = T((float)quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*(float)local_absmax); + local_B[k*2 + 1] = T((float)quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*(float)local_absmax); + #endif + } + + if(inner_idx+(num_values_4bit/4) + (i*num_values_4bit/4) < K) + { + // this is also relatively important for performance + if(BITS==16) + { + reinterpret_cast(local_A)[0] = reinterpret_cast(A)[inner_idx/(num_values_4bit/4) + i]; + } + else + { + reinterpret_cast(local_A)[0] = reinterpret_cast(A)[inner_idx/(num_values_4bit/8) + (2*i) + 0]; + reinterpret_cast(local_A)[1] = reinterpret_cast(A)[inner_idx/(num_values_4bit/8) + (2*i) + 1]; + } + + } + else + #pragma unroll + for(int k = 0; k < num_values_4bit/4; k++) + if(inner_idx + (i*num_values_4bit/4) + k < K) + local_A[k] = A[inner_idx + k + (i*num_values_4bit/4)]; + else + local_A[k] = T(0.0f); + + + // accumulate in float; small performance hit for Ampere, but lower error for outputs + #pragma unroll + for(int k = 0; k < num_values_4bit/4; k++) + { + #if __CUDA_ARCH__ >= 800 + local_C += (float)(local_A[k]*local_B[k]); + #else + // bf16 multipliation not supported + local_C += ((float)local_A[k]*(float)local_B[k]); + #endif + } + } + } + + local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C); + + if(row_B < M && warp_lane == 0) + out[row_B] = T(local_C); + +} + + +//#define ROWS 2 +//template __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc) +//{ +//// 0. We want to fill a 8x128 tile for a thread block so we have 8x16 tile for each warp +//// 1. Load dataB into register +//// 2. Dequantize B +//// 3. Fetch data from A and multiply +// +// typedef hipcub::BlockLoad LoadA; +// //__shared__ typename LoadA::TempStorage loada; +// typedef hipcub::BlockLoad LoadB; +// //__shared__ typename LoadB::TempStorage loadb; +// typedef hipcub::BlockReduce BlockReduce; +// // Allocate shared memory for BlockReduce +// //__shared__ typename BlockReduce::TempStorage reduce; +// +// __shared__ union { +// typename BlockReduce::TempStorage reduce; +// typename LoadB::TempStorage loadb; +// typename LoadA::TempStorage loada; +// } temp_storage; +// +// +// T dataA[ITEMS]; +// T local_B[ITEMS]; +// T local_accC[ROWS]; +// int valid_items = 0; +// const int col_offset = blockIdx.x * 8; +// +// __shared__ T tileA[ROWS*THREADS*ITEMS]; +// __shared__ T accumulatorC[ROWS*8]; +// +// //#pragma unroll 8 +// //for(int i = 0; i < 8; i++) +// // tileA[threadIdx.x + (i*256)] = 0.0f; +// //__syncthreads(); +// if(threadIdx.x < 64) +// accumulatorC[threadIdx.x] = 0.0f; +// __syncthreads(); +// +// +// for(int inner_idx = 0; inner_idx < K; inner_idx+= THREADS*ITEMS) +// { +// valid_items = K - inner_idx > THREADS*ITEMS ? THREADS*ITEMS : K - inner_idx; +// int baserow = 0; +// for(int row = baserow; row < (baserow+ROWS) && row < N; row++) +// { +// LoadA(temp_storage.loada).Load(&(A[(row*K) + inner_idx]), dataA, valid_items, 0.0f); +// +// #pragma unroll ITEMS +// for(int k = 0; k < ITEMS; k++) +// tileA[row*THREADS*ITEMS + threadIdx.x + (k*THREADS)] = dataA[k]; +// +// __syncthreads(); +// } +// baserow += ROWS; +// +// // load 16 columns from B at a time. 
B is transposed, so its like loading rows +// // each warp loads one row +// // each thread loads 128 byte +// +// // col: inner_idx + warp_lane +// // row: ldb*(offset + warp_id) +// for(int col = 0; col < 8 && (col_offset + col) < M; col++) +// { +// int colB = col_offset + col; +// +// for(int k = 0; k < ROWS; k++) +// local_accC[k] = 0.0f; +// +// int base_idxB = ldb*colB; +// valid_items = K - inner_idx > THREADS*ITEMS ? THREADS*ITEMS : K - inner_idx; +// LoadB(temp_storage.loadb).Load(&(B[base_idxB + inner_idx]), local_B, valid_items, 0.0f); +// __syncthreads(); +// +// for(int row = 0; row < ROWS && row < N; row++) +// { +// #pragma unroll ITEMS +// for(int k = 0; k < ITEMS; k++) +// { +// int idxA = row*THREADS*ITEMS + threadIdx.x + (THREADS*k); +// local_accC[row] += tileA[idxA]*local_B[k]; +// } +// +// local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], hipcub::Sum()); +// if(threadIdx.x == 0) +// atomicAdd(&accumulatorC[row*8 + col], local_accC[row]); +// } +// } +// } +// +// for(int row = 0; row < ROWS && row < N; row++) +// { +// int out_idx = ldc*row + col_offset; +// +// //if(threadIdx.x < 8) +// // if(accumulatorC[row*8 + threadIdx.x] != 0.0) +// // printf("%i %i %i %i %f idx %i %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx, blockIdx.x); +// +// if(threadIdx.x < 8 && (col_offset + threadIdx.x) < M) +// { +// //printf("%i %i %i %i %f idx %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx); +// out[out_idx + threadIdx.x] = accumulatorC[row*8 + threadIdx.x]; +// } +// } +// +// +// +//} + + +template __global__ void kfunc(T *A, T *B, T value, long n) +{ + for(long i = (blockDim.x*blockIdx.x) + threadIdx.x; i < n; i+=(blockDim.x*gridDim.x)) + { + switch(FUNC) + { + case FILL: + A[i] = (T)value; + break; + case ARANGE: + A[i] = (T)i; + break; + case _MUL: + A[i] = A[i]*B[i]; + break; + } + } +} + + +//============================================================== +// TEMPLATE DEFINITIONS +//============================================================== + +template __global__ void kfunc(float *A, float *B, float value, long n); +template __global__ void kfunc(unsigned char *A, unsigned char *B, unsigned char value, long n); +template __global__ void kfunc(float *A, float *B, float value, long n); +template __global__ void kfunc(float *A, float *B, float value, long n); + +// these are not used and make no sense, but the compiler needs them +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ 
const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +// these are not used and make no sense, but the compiler needs them + +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); + +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); + +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, hip_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, hip_bfloat16 * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, float * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, float * out, int lda, int ldb, int ldc, int blocksize); + +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); + +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int 
*colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); + +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); + +template __global__ void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n); + +template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); +template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); + +template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x); +template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x); + +template __global__ void kEstimateQuantiles(float *__restrict__ const A, float *code, const float offset, const float max_val, const int n); +template __global__ void kEstimateQuantiles(half *__restrict__ const A, float *code, const float 
offset, const half max_val, const int n); + +#define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \ +template __global__ void kPreconditionOptimizer32bit1State(gtype* g, gtype* p, \ + float* state1, float *unorm, \ + const float beta1, const float beta2, const float eps, const float weight_decay, \ + const int step, const float lr, const float gnorm_scale, const int n); \ + +MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half) +MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float) +MAKE_PreconditionOptimizer32bit1State(RMSPROP, half) +MAKE_PreconditionOptimizer32bit1State(RMSPROP, float) +MAKE_PreconditionOptimizer32bit1State(LION, half) +MAKE_PreconditionOptimizer32bit1State(LION, float) +MAKE_PreconditionOptimizer32bit1State(LION, hip_bfloat16) +MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half) +MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) + +#define MAKE_Optimizer32bit1State(oname, gtype) \ +template __global__ void kOptimizer32bit1State(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \ + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \ + +MAKE_Optimizer32bit1State(MOMENTUM, half) +MAKE_Optimizer32bit1State(MOMENTUM, float) +MAKE_Optimizer32bit1State(RMSPROP, half) +MAKE_Optimizer32bit1State(RMSPROP, float) +MAKE_Optimizer32bit1State(LION, half) +MAKE_Optimizer32bit1State(LION, float) +MAKE_Optimizer32bit1State(LION, hip_bfloat16) +MAKE_Optimizer32bit1State(ADAGRAD, half) +MAKE_Optimizer32bit1State(ADAGRAD, float) + +#define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \ +template __global__ void kPreconditionOptimizer32bit2State(gtype* g, gtype* p, \ + float* state1, float* state2, float *unorm, \ + const float beta1, const float beta2, const float eps, const float weight_decay, \ + const int step, const float lr, const float gnorm_scale, const int n); \ + +MAKE_PreconditionOptimizer32bit2State(ADAM, float) +MAKE_PreconditionOptimizer32bit2State(ADAM, half) +MAKE_PreconditionOptimizer32bit2State(ADAM, hip_bfloat16) + +template __global__ void kOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); +template __global__ void kOptimizer32bit2State(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); +template __global__ void kOptimizer32bit2State(hip_bfloat16* g, hip_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); + + +#define MAKE_PreconditionStatic8bit1State(oname, gtype) \ +template __global__ void kPreconditionOptimizerStatic8bit1State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \ + float *unorm, \ + const float beta1, \ + const float beta2, \ + const float eps, const int step, \ + float* __restrict__ const quantiles1, \ + float* max1, float* 
new_max1, \ + const float weight_decay, \ + const float gnorm_scale, \ + const int n); \ + +MAKE_PreconditionStatic8bit1State(MOMENTUM, half) +MAKE_PreconditionStatic8bit1State(MOMENTUM, float) +MAKE_PreconditionStatic8bit1State(RMSPROP, half) +MAKE_PreconditionStatic8bit1State(RMSPROP, float) +MAKE_PreconditionStatic8bit1State(LION, half) +MAKE_PreconditionStatic8bit1State(LION, float) + +#define MAKE_optimizerStatic8bit1State(oname, gtype) \ +template __global__ void kOptimizerStatic8bit1State(gtype* p, gtype* const g, unsigned char* state1, \ + const float *unorm, const float max_unorm, const float param_norm, \ + const float beta1, \ + const float beta2, \ + const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, \ + float* max1, float* new_max1, \ + float weight_decay, \ + const float gnorm_scale, \ + const int n); \ + +MAKE_optimizerStatic8bit1State(MOMENTUM, half) +MAKE_optimizerStatic8bit1State(MOMENTUM, float) +MAKE_optimizerStatic8bit1State(RMSPROP, half) +MAKE_optimizerStatic8bit1State(RMSPROP, float) +MAKE_optimizerStatic8bit1State(LION, half) +MAKE_optimizerStatic8bit1State(LION, float) + +#define MAKE_PreconditionStatic8bit2State(oname, gtype) \ +template __global__ void kPreconditionOptimizerStatic8bit2State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, \ + float *unorm, \ + const float beta1, const float beta2, \ + const float eps, const int step, \ + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ + float* max1, float* max2, float* new_max1, float* new_max2, \ + const float gnorm_scale, \ + const int n); \ + +MAKE_PreconditionStatic8bit2State(ADAM, half) +MAKE_PreconditionStatic8bit2State(ADAM, float) + +#define MAKE_optimizerStatic8bit2State(oname, gtype) \ +template __global__ void kOptimizerStatic8bit2State(gtype* p, gtype* const g, unsigned char* state1, unsigned char* state2, \ + const float *unorm, const float max_unorm, const float param_norm, \ + const float beta1, const float beta2, \ + const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ + float* max1, float* max2, float* new_max1, float* new_max2, \ + float weight_decay, \ + const float gnorm_scale, \ + const int n); \ + +MAKE_optimizerStatic8bit2State(ADAM, half) +MAKE_optimizerStatic8bit2State(ADAM, float) + +template __global__ void kPercentileClipping(float * __restrict__ g, float *gnorm_vec, int step, const int n); +template __global__ void kPercentileClipping(half * __restrict__ g, float *gnorm_vec, int step, const int n); + +#define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \ +template __global__ void kQuantizeBlockwise(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); \ + +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit) +//MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit) + +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, 
FP4) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4) +//MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4) + +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4) +//MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4) + +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit) +//MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit) + +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4) +//MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4) + +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4) +//MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4) + +MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, General8bit) +//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, General8bit) + +MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, FP4) +//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, FP4) + +MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, NF4) +//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, NF4) + +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); +template __global__ void 
kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, hip_bfloat16 *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, hip_bfloat16 *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, hip_bfloat16 *out, const int blocksize, const int n); + +#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \ +template __global__ void kOptimizerStatic8bit2StateBlockwise(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \ + const float beta1, const float beta2, \ + const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ + float* absmax1, float* absmax2, \ + float weight_decay, \ + const float gnorm_scale, const bool skip_zeros, const int n); \ + +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 2048, 8) +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 2048, 8) +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, hip_bfloat16, 2048, 8) + + +#define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \ +template __global__ void kOptimizerStatic8bit1StateBlockwise( \ + gtype* p, gtype* __restrict__ const g, unsigned char* state1, \ + const float beta1, const float beta2, \ + const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, \ + float* absmax1, \ + float weight_decay, \ + const float gnorm_scale, const bool skip_zeros, const int n); \ + +MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, hip_bfloat16, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8) diff --git a/csrc/kernels.hip.h b/csrc/kernels.hip.h new file mode 100644 index 000000000..b0b942cff --- /dev/null +++ b/csrc/kernels.hip.h @@ -0,0 +1,134 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. 
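+
+// Declarations for the HIP kernels implemented in csrc/kernels.hip.cpp:
+// blockwise (de)quantization, 8-bit/32-bit optimizers, percentile clipping,
+// sparse COO matmul, row/tile format transforms and 4-bit inference GEMMs.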
+ +#include +#include + +#ifndef kernels +#define kernels + +//template __global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB); + +template__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n); + +__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n); +__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n); + +template __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n); + +template +__global__ void kPreconditionOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n); + +template +__global__ void kOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); + +template +__global__ void kPreconditionOptimizer32bit1State(T* g, T* p, + float* state1, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n); + +template +__global__ void kOptimizer32bit1State(T* g, T* p, + float* state1, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); + +template +__global__ void +kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + const float weight_decay, + const float gnorm_scale, const int n); + + +template +__global__ void +kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, + const float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + float weight_decay, const float gnorm_scale, const int n); + + + +template +__global__ void +kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + const float gnorm_scale, const int n); + + +template +__global__ void +kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2, + const float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, + const float eps, const int step, 
const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, const float gnorm_scale, const int n); + +template __global__ void kOptimizerStatic8bit2StateBlockwise( + T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, + const float beta1, const float beta2, const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n); + +template __global__ void kOptimizerStatic8bit1StateBlockwise( + T* p, T* __restrict__ const g, unsigned char* state1, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* absmax1, + float weight_decay, + const float gnorm_scale, const bool skip_zeros, const int n); + + +template __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n); + + +__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n); + + +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); + +template __global__ void kdequant_mm_int32_fp16( + int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, + half *out, float* newRowStats, float* newcolStats, half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n); + +template __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols); +template __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); + +template __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); + +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); + +template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc); +template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize); + +template __global__ void kfunc(T *A, T *B, T value, long n); + +#endif diff --git a/csrc/mps_kernels.metal b/csrc/mps_kernels.metal new file mode 100644 index 000000000..63b3bf78c --- /dev/null +++ b/csrc/mps_kernels.metal @@ -0,0 +1,117 @@ +#include +using namespace metal; + +#define HLF_MAX 65504 +#define TH 1024 +#define NUM 4 +#define NUM_BLOCK 4096 + +template +static unsigned char quantize_scalar( + float rand, + device 
float* code, + float x) +{ + int pivot = 127; + int upper_pivot = 255; + int lower_pivot = 0; + + float lower = -1.0f; + float upper = 1.0f; + + float val = code[pivot]; + // i>>=1 = {32, 16, 8, 4, 2, 1} + for(int i = 64; i > 0; i>>=1) + { + if(x > val) + { + lower_pivot = pivot; + lower = val; + pivot+=i; + } + else + { + upper_pivot = pivot; + upper = val; + pivot-=i; + } + val = code[pivot]; + } + + if(upper_pivot == 255) + upper = code[upper_pivot]; + if(lower_pivot == 0) + lower = code[lower_pivot]; + + if(!STOCHASTIC) + { + if(x > val) + { + float midpoint = (upper+val)*0.5f; + if(x > midpoint) + { + return upper_pivot; + } + else + return pivot; + } + else + { + float midpoint = (lower+val)*0.5f; + if(x < midpoint) + return lower_pivot; + else + return pivot; + } + } + else + { + if(x > val) + { + float dist_to_upper = fabs(upper-x); + float dist_full = upper-val; + if(rand >= dist_to_upper/dist_full) return upper_pivot; + else return pivot; + } + else + { + float dist_to_lower = fabs(lower-x); + float dist_full = val-lower; + if(rand >= dist_to_lower/dist_full) return lower_pivot; + else return pivot; + } + } +} + +kernel void quantize(device float* code [[buffer(0)]], + device float* A [[buffer(1)]], + device uchar* out [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint id [[thread_position_in_grid]]) { + const uint n_full = (NUM_BLOCK * (n / NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK); + uint valid_items = (id / NUM_BLOCK + 1 == (n + NUM_BLOCK - 1) / NUM_BLOCK) ? n - (id / NUM_BLOCK * NUM_BLOCK) : NUM_BLOCK; + const uint base_idx = (id / NUM_BLOCK * NUM_BLOCK); + + float vals[NUM]; + uchar qvals[NUM]; + + for (uint i = base_idx; i < n_full; i += ((n + NUM_BLOCK - 1) / NUM_BLOCK) * NUM_BLOCK) { + valid_items = n - i > NUM_BLOCK ? 
NUM_BLOCK : n - i; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (uint j = 0; j < valid_items; j++) { + vals[j] = A[i + j]; + } + + for (uint j = 0; j < valid_items; j++) { + qvals[j] = quantize_scalar(0.0f, code, vals[j]); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (uint j = 0; j < valid_items; j++) { + out[i + j] = qvals[j]; + } + } +} diff --git a/csrc/mps_ops.h b/csrc/mps_ops.h new file mode 100644 index 000000000..e69de29bb diff --git a/csrc/mps_ops.mm b/csrc/mps_ops.mm new file mode 100644 index 000000000..d198b3552 --- /dev/null +++ b/csrc/mps_ops.mm @@ -0,0 +1,67 @@ +#import + +#define HLF_MAX 65504 +#define TH 1024 +#define NUM 4 +#define NUM_BLOCK 4096 + +static inline MPSGraph* get_graph() +{ + static MPSGraph* cur = nil; + if(!cur) { + cur = [[MPSGraph alloc] init]; + } + return cur; +} + +static inline id get_device() +{ + NSError *error = nil; + static id device = nil; + if(!device) { + device = MTLCreateSystemDefaultDevice(); + } + if(!device) { + NSLog(@"Failed to get MPS device"); + abort(); + } + return device; +} + +static inline id get_library() +{ + NSError *error = nil; + static id library = nil; + if(!library) { + library = [get_device() newLibraryWithURL:[NSURL fileURLWithPath:@"bitsandbytes.metallib"] error:&error]; + } + if(!library) { + NSLog(@"Failed to load bitsandbytes.metallib"); + abort(); + } + return library; +} + +/*MPSGraphTensor* dequantize_mps(MPSGraphTensor* code, MPSGraphTensor* A, int n) +{ + id out = [get_graph() dequantizeTensor:(MPSGraphTensor*)A scaleTensor:(MPSGraphTensor*)code zeroPoint:0.0 dataType:MPSDataTypeInt8 axis:0 name:@"out"]; + return out; +}*/ + + +// MPSGraph function for quantize +extern "C" MPSGraphTensor* quantize_mps(MPSGraph* graph, MPSGraphTensor* code, MPSGraphTensor* A, int n) +{ + id device = get_device(); + id library = get_library(); + static id kernel = nil; + if(!kernel) { + kernel = [library newFunctionWithName:@"quantize"]; + if(!kernel) { + NSLog(@"Failed to load bitsandbytes.metallib"); + abort(); + } + } + NSLog(@"Not implemented"); + return nil; +} diff --git a/csrc/ops.cu b/csrc/ops.cu index 97761216c..796211fed 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -11,6 +11,8 @@ #include #include +#define ERR_NOT_IMPLEMENTED 100 + using namespace BinSearch; using std::cout; @@ -421,14 +423,7 @@ template void transform(cublasLtHandle_t ltHandl template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { #ifdef NO_CUBLASLT - cout << "" << endl; - cout << "=============================================" << endl; - cout << "ERROR: Your GPU does not support Int8 Matmul!" << endl; - cout << "=============================================" << endl; - cout << "" << endl; - assert(false); - - return 0; + return ERR_NOT_IMPLEMENTED; #else int has_error = 0; cublasLtMatmulDesc_t matmulDesc = NULL; @@ -484,7 +479,7 @@ template int igemmlt(cublasLtHandle printf("error detected"); return has_error; -#endif +#endif // NO_CUBLASLT } int fill_up_to_nearest_multiple(int value, int multiple) diff --git a/csrc/ops.cuh b/csrc/ops.cuh index f37b3b3af..da9df6af0 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -9,7 +9,6 @@ #include #include -#include #include #include diff --git a/csrc/ops.hip.cpp b/csrc/ops.hip.cpp new file mode 100644 index 000000000..e0ab1f0dd --- /dev/null +++ b/csrc/ops.hip.cpp @@ -0,0 +1,1033 @@ +// !!! This is a file automatically generated by hipify!!! 
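+// HIP port of csrc/ops.cu: host-side launchers for quantization, blockwise (de)quantization,
+// 32-bit and 8-bit optimizers, percentile clipping and the GEMM/SpMM ops, using hipBLAS,
+// hipBLASLt and hipSPARSE in place of cuBLAS, cuBLASLt and cuSPARSE.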
+#include "hip/hip_runtime.h" +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include +#ifndef NO_HIPBLASLT +#include +#endif +#include +#include +#include +#include + + +using namespace BinSearch; +using std::cout; +using std::endl; + +void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n) +{ + int threads = 512; + int num_blocks = n/threads; + num_blocks = n % threads == 0 ? num_blocks : num_blocks + 1; + hipLaunchKernelGGL(( kHistogramScatterAdd2D), dim3(num_blocks), dim3(512), 0, 0, histogram, index1, index2, src, maxidx1, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +template void estimateQuantiles(T *A, float *code, float offset, int n) +{ + int num_blocks = n/4096; + num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; + CUDA_CHECK_RETURN(hipMemset(code, 0, 256*sizeof(float))); + hipLaunchKernelGGL(( kEstimateQuantiles), dim3(num_blocks), dim3(512), 0, 0, A, code, offset, std::numeric_limits::max(), n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +void quantize(float *code, float *A, unsigned char *out, int n) +{ + int num_blocks = n/1024; + num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; + hipLaunchKernelGGL(( kQuantize), dim3(num_blocks), dim3(1024), 0, 0, code, A, out, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +void dequantize(float *code, unsigned char *A, float *out, int n) +{ + int num_blocks = n/1024; + num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; + hipLaunchKernelGGL(( kDequantize), dim3(num_blocks), dim3(1024), 0, 0, code, A, out, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n) +{ + int num_blocks = n/blocksize; + num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; + + if(blocksize == 4096) + hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(1024), 0, 0, code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 2048) + hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(512), 0, 0, code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 1024) + hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(256), 0, 0, code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 512) + hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(256), 0, 0, code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 256) + hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(128), 0, 0, code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 128) + hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(64), 0, 0, code, A, absmax, out, rand, rand_offset, n); + //else if(blocksize == 64) + // hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(32), 0, 0, code, A, absmax, out, rand, rand_offset, n); + + + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n) +{ + int num_blocks = n/blocksize; + num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; + int tile_size = (DATA_TYPE > 0) ? 
1024 : 512;
+
+ if(DATA_TYPE > 0)
+ hipLaunchKernelGGL(( kDequantizeBlockwise), dim3((n+tile_size-1)/tile_size), dim3(64), 0, 0, code, A, absmax, out, blocksize/2, n);
+ else
+ hipLaunchKernelGGL(( kDequantizeBlockwise), dim3((n+tile_size-1)/tile_size), dim3(64), 0, 0, code, A, absmax, out, blocksize, n);
+
+ CUDA_CHECK_RETURN(hipPeekAtLastError());
+}
+
+
+//void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB)
+//{
+// int num_blocks = (colsB+32-1)/32;
+// kMatmul_inference_4bit<<>>(A, B, out, lda, ldb, rowsA, colsA, colsB);
+// CUDA_CHECK_RETURN(hipPeekAtLastError());
+//}
+
+
+template void optimizer32bit(T* g, T* p,
+ float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
+ const float beta1, const float beta2, const float eps, const float weight_decay,
+ const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n)
+{
+ int num_blocks = n/4096;
+ num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
+ switch(OPTIMIZER)
+ {
+ case ADAM:
+ if(max_unorm > 0.0f)
+ {
+ CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float)));
+ hipLaunchKernelGGL(( kPreconditionOptimizer32bit2State), dim3(num_blocks), dim3(512), 0, 0, g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
+ CUDA_CHECK_RETURN(hipPeekAtLastError());
+ }
+ hipLaunchKernelGGL(( kOptimizer32bit2State), dim3(num_blocks), dim3(1024), 0, 0, g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
+ CUDA_CHECK_RETURN(hipPeekAtLastError());
+ break;
+ case MOMENTUM:
+ case RMSPROP:
+ case ADAGRAD:
+ if(max_unorm > 0.0f)
+ {
+ CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float)));
+ hipLaunchKernelGGL(( kPreconditionOptimizer32bit1State), dim3(num_blocks), dim3(512), 0, 0, g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
+ CUDA_CHECK_RETURN(hipPeekAtLastError());
+ }
+
+ hipLaunchKernelGGL(( kOptimizer32bit1State), dim3(num_blocks), dim3(1024), 0, 0, g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
+ CUDA_CHECK_RETURN(hipPeekAtLastError());
+ break;
+ case LION:
+ // in lion, the momentum update happens after the parameter update
+ hipLaunchKernelGGL(( kOptimizer32bit1State), dim3(num_blocks), dim3(1024), 0, 0, g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
+ CUDA_CHECK_RETURN(hipPeekAtLastError());
+
+ if(max_unorm > 0.0f)
+ {
+ CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float)));
+ hipLaunchKernelGGL(( kPreconditionOptimizer32bit1State), dim3(num_blocks), dim3(512), 0, 0, g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
+ CUDA_CHECK_RETURN(hipPeekAtLastError());
+ }
+ break;
+ }
+}
+
+template void optimizerStatic8bit(T* p, T* g,
+ unsigned char* state1, unsigned char* state2,
+ float *unorm, float max_unorm, float param_norm,
+ float beta1, float beta2,
+ float eps, int step, float lr,
+ float* quantiles1, float* quantiles2,
+ float* max1, float* max2, float* new_max1, float* new_max2,
+ float weight_decay,
+ const float gnorm_scale, int n)
+{
+ int num_blocks = n/4096;
+ num_blocks = n % 4096 == 0 ?
num_blocks : num_blocks + 1; + + if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); } + + switch(OPTIMIZER) + { + case ADAM: + CUDA_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float))); + CUDA_CHECK_RETURN(hipMemset(new_max2, 0, 1*sizeof(float))); + hipLaunchKernelGGL(( kPreconditionOptimizerStatic8bit2State), dim3(num_blocks), dim3(256), 0, 0, p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + hipLaunchKernelGGL(( kOptimizerStatic8bit2State), dim3(num_blocks), dim3(1024), 0, 0, p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, + quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + CUDA_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float))); + hipLaunchKernelGGL(( kPreconditionOptimizerStatic8bit1State), dim3(num_blocks), dim3(256), 0, 0, p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + hipLaunchKernelGGL(( kOptimizerStatic8bit1State), dim3(num_blocks), dim3(1024), 0, 0, p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, + quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + break; + case LION: + // in lion, the momentum update happens after the parameter update + hipLaunchKernelGGL(( kOptimizerStatic8bit1State), dim3(num_blocks), dim3(1024), 0, 0, p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, + quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + + CUDA_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float))); + hipLaunchKernelGGL(( kPreconditionOptimizerStatic8bit1State), dim3(num_blocks), dim3(256), 0, 0, p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + break; + default: + break; + } +} + +#define BLOCKSIZE_2STATE 2048 +#define NUM_2STATE 8 +#define BLOCKSIZE_1STATE 2048 +#define NUM_1STATE 8 + +template void optimizerStatic8bitBlockwise(T* p, T* g, + unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) +{ + + int num_blocks = 0; + switch(OPTIMIZER) + { + case ADAM: + num_blocks = n/BLOCKSIZE_2STATE; + num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1; + hipLaunchKernelGGL(( kOptimizerStatic8bit2StateBlockwise), dim3(num_blocks), dim3(BLOCKSIZE_2STATE/NUM_2STATE), 0, 0, p, g, state1, state2, beta1, beta2, eps, step, lr, + quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + case LION: + num_blocks = n/BLOCKSIZE_1STATE; + num_blocks = n % BLOCKSIZE_1STATE == 0 ? 
num_blocks : num_blocks + 1; + hipLaunchKernelGGL(( kOptimizerStatic8bit1StateBlockwise), dim3(num_blocks), dim3(BLOCKSIZE_1STATE/NUM_1STATE), 0, 0, p, g, state1, beta1, beta2, eps, step, lr, + quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + break; + } +} + + + +template void percentileClipping(T * g, float *gnorm_vec, int step, const int n) +{ + int num_blocks = n/2048; + num_blocks = n % 2048 == 0 ? num_blocks : num_blocks + 1; + CUDA_CHECK_RETURN(hipMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float))); + hipLaunchKernelGGL(( kPercentileClipping), dim3(num_blocks), dim3(512), 0, 0, g, gnorm_vec, step, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc) +{ + const int falpha = 1; + const int fbeta = 0; + const void * alpha = &falpha; + const void * beta = &fbeta; + hipblasStatus_t status; + + status = hipblasGemmEx(context->m_handle, + transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, + m, n, k, + alpha, A, HIPBLAS_R_8I, lda, B, HIPBLAS_R_8I, ldb, beta, + C, HIPBLAS_R_32I, ldc, + HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT); + + if (status != HIPBLAS_STATUS_SUCCESS) + { + std::cout << "HIPBLAS ERROR: Status " << status << std::endl; + } + +} + +void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, + long long int strideA, long long int strideB, long long int strideC, int batchCount) +{ + const int falpha = 1; + const int fbeta = 0; + const void * alpha = &falpha; + const void * beta = &fbeta; + hipblasStatus_t status; + + //cout << transposeA << transposeB << endl; + //printf("%i %i %i\n", m,n,k); + //printf("%i %i %i\n", lda,ldb,ldc); + //printf("%i %i %i\n", strideA, strideB, strideC); + //printf("%i\n", batchCount); + + status = hipblasGemmStridedBatchedEx(context->m_handle, + transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transposeB ? 
HIPBLAS_OP_T : HIPBLAS_OP_N, + m, n, k, + alpha, A, HIPBLAS_R_8I, lda, (long long int)strideA, B, HIPBLAS_R_8I, ldb, (long long int)strideB, beta, + C, HIPBLAS_R_32I, ldc, (long long int)strideC, batchCount, + HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT); + + if (status != HIPBLAS_STATUS_SUCCESS) + { + std::cout << "HIPBLAS ERROR: Status " << status << std::endl; + } + +} + +int roundoff(int v, int d) { + return (v + d - 1) / d * d; +} + + +#ifndef NO_HIPBLASLT +template hipblasLtOrder_t get_order() +{ + switch(ORDER) + { + case ROW: + return HIPBLASLT_ORDER_ROW; + break; + case COL: + return HIPBLASLT_ORDER_COL; + break; + case COL32: +//return HIPBLASLT_ORDER_COL32; + return HIPBLASLT_ORDER_COL; + break; + case COL_TURING: +//return HIPBLASLT_ORDER_COL4_4R2_8C; + return HIPBLASLT_ORDER_COL; + break; + case COL_AMPERE: +//return HIPBLASLT_ORDER_COL32_2R_4R4; + return HIPBLASLT_ORDER_COL; + break; + default: + break; + } + + return HIPBLASLT_ORDER_ROW; +} + +template hipblasLtOrder_t get_order(); +template hipblasLtOrder_t get_order(); +template hipblasLtOrder_t get_order(); +//template hipblasLtOrder_t get_order(); +//template hipblasLtOrder_t get_order(); +#endif + + +template int get_leading_dim(int dim1, int dim2) +{ + switch(ORDER) + { + case ROW: + return dim2; + break; + case COL: + return dim1; + break; + default: + return dim1; + break; + /*case COL32: + // 32*row tiles + return dim1*32; + break; + case COL_TURING: + return 32*roundoff(dim1, 8); + break; + case COL_AMPERE: + // 32*32 tiles + return 32*roundoff(dim1, 32); + break; + default: + return 0; + break; +*/ + } +} + +template int get_leading_dim(int dim1, int dim2); +template int get_leading_dim(int dim1, int dim2); +template int get_leading_dim(int dim1, int dim2); + +#ifndef NO_HIPBLASLT +template void transform(hipblasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2) +{ + hipblasLtOrder_t orderA = get_order(); + hipblasLtOrder_t orderOut = get_order(); + int ldA = get_leading_dim(dim1, dim2); + int ldOut; + if (TARGET==COL && transpose) { + ldOut = dim2; + } else { + ldOut = get_leading_dim(dim1, dim2); +} + + hipblasLtMatrixLayout_t A_desc = NULL, out_desc = NULL, B_desc = NULL; + T B = T(0); + hipblasLtMatrixTransformDesc_t A2Out_desc = NULL; + hipblasOperation_t opTranspose = HIPBLAS_OP_T; + float transformAlpha = 1.0f, transformBeta = 0.0f; + + if(DTYPE == 8) + { + checkHipblasStatus(hipblasLtMatrixLayoutCreate(&A_desc, HIP_R_8I, dim1, dim2, ldA)); +checkHipblasStatus(hipblasLtMatrixLayoutCreate(&B_desc, HIP_R_8I, 0, 0, 0)); + if (TARGET==COL && transpose) { + checkHipblasStatus(hipblasLtMatrixLayoutCreate(&out_desc, HIP_R_8I, dim2, dim1, ldOut)); + } else { + checkHipblasStatus(hipblasLtMatrixLayoutCreate(&out_desc, HIP_R_8I, dim1, dim2, ldOut)); + } +} + else if(DTYPE == 32) + { + checkHipblasStatus(hipblasLtMatrixLayoutCreate(&A_desc, HIP_R_32I, dim1, dim2, ldA)); +checkHipblasStatus(hipblasLtMatrixLayoutCreate(&B_desc, HIP_R_32I, 0, 0, 0)); + if (TARGET==COL && transpose) { + checkHipblasStatus(hipblasLtMatrixLayoutCreate(&out_desc, HIP_R_32I, dim2, dim1, ldOut)); + } else { + checkHipblasStatus(hipblasLtMatrixLayoutCreate(&out_desc, HIP_R_32I, dim1, dim2, ldOut)); + } +} + else + { + printf("ERROR WRONG TYPE FOR TRANSFORM: %i\n", DTYPE); + } + + checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(A_desc, HIPBLASLT_MATRIX_LAYOUT_ORDER, &orderA, sizeof(orderA))); + checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(out_desc, HIPBLASLT_MATRIX_LAYOUT_ORDER, &orderOut, sizeof(orderOut))); + + 
checkHipblasStatus(hipblasLtMatrixTransformDescCreate(&A2Out_desc, HIP_R_32F)); + + if(transpose){ checkHipblasStatus(hipblasLtMatrixTransformDescSetAttribute(A2Out_desc, HIPBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &opTranspose, sizeof(opTranspose))); } + + checkHipblasStatus(hipblasLtMatrixTransform(ltHandle, A2Out_desc, &transformAlpha, A, A_desc, &transformBeta, A, B_desc, out, out_desc, 0)); + + if (A_desc) checkHipblasStatus(hipblasLtMatrixLayoutDestroy(A_desc)); +if (B_desc) checkHipblasStatus(hipblasLtMatrixLayoutDestroy(B_desc)); + if (out_desc) checkHipblasStatus(hipblasLtMatrixLayoutDestroy(out_desc)); + if (A2Out_desc) checkHipblasStatus(hipblasLtMatrixTransformDescDestroy(A2Out_desc)); +} + +template void transform(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(hipblasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); +template void transform(hipblasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); +template void transform(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(hipblasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); +template void transform(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(hipblasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); +template void transform(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(hipblasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); +#endif +static std::string hipError_to_string(const hipError_t ret) +{ + switch(ret) + { + case hipSuccess: + return "hipSuccess"; + case hipErrorInvalidContext: + return "hipErrorInvalidContext"; + case hipErrorInvalidKernelFile: + return "hipErrorInvalidKernelFile"; + case hipErrorMemoryAllocation: + return "hipErrorMemoryAllocation"; + case hipErrorInitializationError: + return "hipErrorInitializationError"; + case hipErrorLaunchFailure: + return "hipErrorLaunchFailure"; + case hipErrorLaunchOutOfResources: + return "hipErrorLaunchOutOfResources"; + case hipErrorInvalidDevice: + return "hipErrorInvalidDevice"; + case hipErrorInvalidValue: + return "hipErrorInvalidValue"; + case hipErrorInvalidDevicePointer: + return "hipErrorInvalidDevicePointer"; + case hipErrorInvalidMemcpyDirection: + return "hipErrorInvalidMemcpyDirection"; + case hipErrorUnknown: + return "hipErrorUnknown"; + case hipErrorInvalidResourceHandle: + return "hipErrorInvalidResourceHandle"; + case hipErrorNotReady: + return "hipErrorNotReady"; + case hipErrorNoDevice: + return "hipErrorNoDevice"; + case hipErrorPeerAccessAlreadyEnabled: + return "hipErrorPeerAccessAlreadyEnabled"; + case hipErrorPeerAccessNotEnabled: + return "hipErrorPeerAccessNotEnabled"; + case hipErrorRuntimeMemory: + return "hipErrorRuntimeMemory"; + case hipErrorRuntimeOther: + return "hipErrorRuntimeOther"; + case hipErrorHostMemoryAlreadyRegistered: + return "hipErrorHostMemoryAlreadyRegistered"; + case hipErrorHostMemoryNotRegistered: + return "hipErrorHostMemoryNotRegistered"; + case 
hipErrorMapBufferObjectFailed: + return "hipErrorMapBufferObjectFailed"; + case hipErrorTbd: + return "hipErrorTbd"; + default: + throw std::runtime_error("unknown hipError"); + } +} +#ifndef NO_HIPBLASLT +template int igemmlt(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) +{ +#ifdef NO_CUBLASLT + cout << "" << endl; + cout << "=============================================" << endl; + cout << "ERROR: Your GPU does not support Int8 Matmul!" << endl; + cout << "=============================================" << endl; + cout << "" << endl; + assert(false); + + return 0; +#else + int has_error = 0; + hipblasLtMatmulDesc_t matmulDesc = NULL; + hipblasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; + hipblasOperation_t opT = HIPBLAS_OP_T; + //hipblasLtPointerMode_t alphaVec = hipblasLt_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO; + hipblasLtOrder_t col32 = HIPBLASLT_ORDER_COL; + hipblasLtOrder_t col_turing = HIPBLASLT_ORDER_COL; + hipblasLtOrder_t col_ampere = HIPBLASLT_ORDER_COL; + + has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&Adesc, HIP_R_8I, m, k, lda)); + has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&Bdesc, HIP_R_8I, n, k, ldb)); +has_error |= checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(Adesc, HIPBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); + + + if(FORMATB == COL_TURING) + has_error |= checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(Bdesc, HIPBLASLT_MATRIX_LAYOUT_ORDER, &col_turing, sizeof(col_turing))); + else + has_error |= checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(Bdesc, HIPBLASLT_MATRIX_LAYOUT_ORDER, &col_ampere, sizeof(col_ampere))); + + const int64_t max_workspace_size = 0;//set to 0 to avoid choosing GSU kernel + + if(DTYPE_OUT == 32) + { + has_error |= checkHipblasStatus(hipblasLtMatmulDescCreate(&matmulDesc, HIPBLAS_COMPUTE_32I, HIP_R_32I)); + auto opA = HIPBLAS_OP_N; + has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSA, &opA, sizeof(int32_t))); + has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(int32_t))); + hipblasLtEpilogue_t epilogue = HIPBLASLT_EPILOGUE_DEFAULT; + checkHipblasStatus(hipblasLtMatmulDescSetAttribute( + matmulDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); + has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&Cdesc, HIP_R_32I, m, n, ldc)); + has_error |= checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(Cdesc, HIPBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); + int alpha = 1, beta = 0; + + + /* Algo and workspace TODO: need to rework to not be duplicated */ + // Set User Preference attributes + hipblasLtMatmulPreference_t pref; + checkHipblasStatus(hipblasLtMatmulPreferenceCreate(&pref)); + checkHipblasStatus( + hipblasLtMatmulPreferenceSetAttribute(pref, + HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &max_workspace_size, + sizeof(max_workspace_size))); + + const int request_solutions = 1; + hipblasLtMatmulHeuristicResult_t heuristicResult[request_solutions]; + int returnedAlgoCount = 0; + checkHipblasStatus(hipblasLtMatmulAlgoGetHeuristic(ltHandle, + matmulDesc, + Adesc, + Bdesc, + Cdesc, + Cdesc, + pref, + request_solutions, + heuristicResult, + &returnedAlgoCount)); + + if (returnedAlgoCount == 0) + { + has_error = 1; + } + else + { + has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, 
(int32_t*)C, Cdesc, (int32_t*)C, Cdesc, &heuristicResult[0].algo, nullptr, 0, 0)); + } + } + else + { + has_error |= checkHipblasStatus(hipblasLtMatmulDescCreate(&matmulDesc, HIPBLAS_COMPUTE_32I, HIP_R_8I)); + hipblasOperation_t opA = HIPBLAS_OP_N; + has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSA, &opA, sizeof(opA))); + has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT))); + has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&Cdesc, HIP_R_8I, m, n, ldc)); + has_error |= checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(Cdesc, HIPBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); +/* Algo and workspace TODO: need to rework to not be duplicated */ + // Set User Preference attributes + hipblasLtMatmulPreference_t pref; + checkHipblasStatus(hipblasLtMatmulPreferenceCreate(&pref)); + checkHipblasStatus( + hipblasLtMatmulPreferenceSetAttribute(pref, + HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &max_workspace_size, + sizeof(max_workspace_size))); + + const int request_solutions = 1; + hipblasLtMatmulHeuristicResult_t heuristicResult[request_solutions]; + int returnedAlgoCount = 0; + checkHipblasStatus(hipblasLtMatmulAlgoGetHeuristic(ltHandle, + matmulDesc, + Adesc, + Bdesc, + Cdesc, + Cdesc, + pref, + request_solutions, + heuristicResult, + &returnedAlgoCount)); + + if(!SCALE_ROWS) + { + float alpha = 1.0f, beta = 0.0f; + + has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, &heuristicResult[0].algo, nullptr, 0, 0)); + } + else + { + //has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, hipblasLt_MATMUL_DESC_POINTER_MODE, &alphaVec, sizeof(alphaVec))); + float beta = 0.0f; + + has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, &heuristicResult[0].algo, nullptr, 0, 0)); + } + } + + + if (Cdesc) has_error |= checkHipblasStatus(hipblasLtMatrixLayoutDestroy(Cdesc)); + if (Bdesc) has_error |= checkHipblasStatus(hipblasLtMatrixLayoutDestroy(Bdesc)); + if (Adesc) has_error |= checkHipblasStatus(hipblasLtMatrixLayoutDestroy(Adesc)); + if (matmulDesc) has_error |= checkHipblasStatus(hipblasLtMatmulDescDestroy(matmulDesc)); + if(has_error == 1) + printf("error detected"); + + return has_error; +#endif +} +#endif + +int fill_up_to_nearest_multiple(int value, int multiple) +{ + return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple))); +} + +void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half *bias, int numRows, int numCols) +{ + int threads = 512; + //int tileCols = fill_up_to_nearest_multiple(numCols, 32); + //int n = numRows*tileCols; + int tileCols = numCols; + int n = numRows*numCols; + //int subtile_rows = 128; + //int tilesize = 32*subtile_rows; + //int num_blocks = numRows/subtile_rows; + //num_blocks += (numRows % subtile_rows == 0) ? 0 : 1; + //num_blocks = num_blocks*(tileCols/32); + //assert(threads <= tilesize); + int num_blocks = numRows * numCols / (threads * 4); + num_blocks += (numRows * numCols) % (threads * 4) == 0 ? 
0 : 1; + + hipLaunchKernelGGL(( kdequant_mm_int32_fp16<4, 128, 512>), dim3(num_blocks), dim3(threads), 0, 0, A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols, tileCols, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +#define STATS_THREADS 64 +#define STATS_ITEMS 4 +#define STATS_ROWS 16 +void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols) +{ + int tile_cols = STATS_THREADS*STATS_ITEMS; + int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); + int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS); + int row_tiles = (tiledRows/STATS_ROWS); + int col_tiles = (tiledCols/tile_cols); + row_tiles = row_tiles > 0 ? row_tiles : 1; + col_tiles = col_tiles > 0 ? col_tiles : 1; + int num_blocks = row_tiles * col_tiles; + + if(nnz_threshold == 0.0) + hipLaunchKernelGGL(( kgetColRowStats), dim3(num_blocks), dim3(STATS_THREADS), 0, 0, A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols); + else if(nnz_threshold != 0.0) + hipLaunchKernelGGL(( kgetColRowStats), dim3(num_blocks), dim3(STATS_THREADS), 0, 0, A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols); +CUDA_CHECK_RETURN(hipPeekAtLastError()); + + } + +void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols) +{ + int threads = 64; + int items_per_thread = 4; + int tile_cols = threads*items_per_thread; + int tile_rows = 16; + int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); + int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows); + int row_tiles = (tiledRows/tile_rows); + int col_tiles = (tiledCols/tile_cols); + row_tiles = row_tiles > 0 ? row_tiles : 1; + col_tiles = col_tiles > 0 ? col_tiles : 1; + int num_blocks = row_tiles * col_tiles; + + + if(threshold > 0.0f) + hipLaunchKernelGGL(( kDoubleRowColQuant<64, 4, 16, 64*4, 1>), dim3(num_blocks), dim3(threads), 0, 0, A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols); + else + hipLaunchKernelGGL(( kDoubleRowColQuant<64, 4, 16, 64*4, 0>), dim3(num_blocks), dim3(threads), 0, 0, A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols); + + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +template void transformRowToFormat(char * A, char *out, int rows, int cols) +{ + int threads = 256; + int items_per_thread = 8; + // we load 128 column values per warp + int tile_cols = 32*items_per_thread; + int tile_rows = 32; + int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); + int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows); + int row_tiles = (tiledRows/tile_rows); + int col_tiles = (tiledCols/tile_cols); + row_tiles = row_tiles > 0 ? row_tiles : 1; + col_tiles = col_tiles > 0 ? 
col_tiles : 1; + int num_blocks = row_tiles * col_tiles; + + int outCols = fill_up_to_nearest_multiple(cols, 32); + int outRows = fill_up_to_nearest_multiple(rows, 32); + if(FORMAT == COL_TURING) + { + if(TRANSPOSE) + outRows = fill_up_to_nearest_multiple(cols, 8); + else + outRows = fill_up_to_nearest_multiple(rows, 8); + } + else if(FORMAT == COL_AMPERE) + { + if(TRANSPOSE) + outRows = fill_up_to_nearest_multiple(cols, 32); + else + outRows = fill_up_to_nearest_multiple(rows, 32); + } + else + { + if(TRANSPOSE) + { + outCols = fill_up_to_nearest_multiple(rows, 32); + outRows = cols; + } + } + +hipLaunchKernelGGL(( kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT>), dim3(num_blocks), dim3(threads), 0, 0, A, out, rows, cols, tiledCols, outRows, outCols); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +void spmm_coo(hipsparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B) +{ + hipsparseSpMatDescr_t descA; + hipsparseDnMatDescr_t descB, descC; + + float alpha = 1.0f; + float beta = 0.0f; + void *dBuffer = NULL; + size_t bufferSize = 0; + + CHECK_HIPSPARSE( hipsparseCreateCoo(&descA, A_rows, A_cols, A_nnz, + A_rowidx, A_colidx, A_vals, + HIPSPARSE_INDEX_32I, + HIPSPARSE_INDEX_BASE_ZERO, HIP_R_16F) ); + // Create dense matrix C + CHECK_HIPSPARSE( hipsparseCreateDnMat(&descC, A_rows, B_cols, ldc, C, + HIP_R_16F, HIPSPARSE_ORDER_ROW) ); + // Create dense matrix B + if(transposed_B) + { + int tmp = A_cols; + A_cols = B_cols; + B_cols = tmp; + } + + CHECK_HIPSPARSE( hipsparseCreateDnMat(&descB, A_cols, B_cols, ldb, B, + HIP_R_16F, HIPSPARSE_ORDER_ROW) ); + // allocate an external buffer if needed + CHECK_HIPSPARSE( hipsparseSpMM_bufferSize( + handle, + HIPSPARSE_OPERATION_NON_TRANSPOSE, + transposed_B ? HIPSPARSE_OPERATION_TRANSPOSE : HIPSPARSE_OPERATION_NON_TRANSPOSE, + &alpha, descA, descB, &beta, descC, HIP_R_32F, + HIPSPARSE_SPMM_ALG_DEFAULT, &bufferSize) ); + CUDA_CHECK_RETURN( hipMalloc(&dBuffer, bufferSize) ); + + // execute SpMM + CHECK_HIPSPARSE( hipsparseSpMM(handle, + HIPSPARSE_OPERATION_NON_TRANSPOSE, + transposed_B ? 
HIPSPARSE_OPERATION_TRANSPOSE : HIPSPARSE_OPERATION_NON_TRANSPOSE,
+ &alpha, descA, descB, &beta, descC, HIP_R_32F,
+ HIPSPARSE_SPMM_ALG_DEFAULT, dBuffer));
+
+ // destroy matrix/vector descriptors
+ CHECK_HIPSPARSE( hipsparseDestroySpMat(descA) );
+ CHECK_HIPSPARSE( hipsparseDestroyDnMat(descB) );
+ CHECK_HIPSPARSE( hipsparseDestroyDnMat(descC) );
+ CUDA_CHECK_RETURN( hipFree(dBuffer) );
+}
+
+template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
+{
+
+ hipLaunchKernelGGL(( kspmm_coo_very_sparse_naive), dim3(nnz_rows), dim3(256), 0, 0, max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB);
+ CUDA_CHECK_RETURN(hipPeekAtLastError());
+}
+
+
+template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols)
+{
+ int threads = 256;
+ // we load 128 column values per warp
+ int tiledCols = fill_up_to_nearest_multiple(cols, 32);
+ int tiledRows = 0;
+
+ int num_blocks = idx_size;
+
+ /*if(FORMAT == COL_TURING)
+ {
+ tiledRows = fill_up_to_nearest_multiple(rows, 8);
+ }
+ else if(FORMAT == COL_AMPERE)
+ {
+ tiledRows = fill_up_to_nearest_multiple(rows, 32);
+ }*/
+
+//for col format on ROCm
+ tiledRows = rows;
+
+ hipLaunchKernelGGL(( kExtractOutliers), dim3(num_blocks), dim3(threads), 0, 0, A, idx, out, idx_size, rows, cols, tiledRows, tiledCols);
+ CUDA_CHECK_RETURN(hipPeekAtLastError());
+}
+
+
+
+
+template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits)
+{
+
+ int num_blocks = (m+31)/32;
+
+ //cout << num_blocks << endl;
+ //cout << lda << endl;
+ //cout << ldb << endl;
+ //cout << ldc << endl;
+
+ //cout << m << endl;
+ //cout << n << endl;
+ //cout << k << endl;
+ //if(bits == 32)
+ //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
+ //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
+ if(bits == 16)
+ //gemm_device<<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
+ hipLaunchKernelGGL(( gemm_device), dim3(num_blocks), dim3(160), 0, 0 , m, n, k, A, B, out, lda, ldb, ldc);
+ //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
+ //gemm_device<<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
+ //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
+ //gemm_device<<< num_blocks, 64, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
+}
+
+template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
+{
+
+ int num_blocks = (m+31)/32;
+
+ //cout << num_blocks << endl;
+ //cout << lda << endl;
+ //cout << ldb << endl;
+ //cout << ldc << endl;
+
+ //cout << m << endl;
+ //cout << n << endl;
+ //cout << k << endl;
+ hipLaunchKernelGGL(( kgemm_4bit_inference), dim3(num_blocks), dim3(96), 0, 0 , m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
+ //kgemm_4bit_inference<<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
+ //kgemm_4bit_inference<<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
+ //kgemm_4bit_inference<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
+}
+
+template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float
*datatype, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + int num_blocks = (m+3)/4; + + hipLaunchKernelGGL(( kgemm_4bit_inference_naive), dim3(num_blocks), dim3(128), 0, 0 , m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +template void func(T *A, T *B, T value, long n) +{ + int threads = 512; + int blocks = n/threads; + blocks = n % threads == 0 ? blocks : blocks + 1; + blocks = blocks > 65535 ? 65535 : blocks; + hipLaunchKernelGGL(( kfunc), dim3(blocks), dim3(512), 0, 0, A, B, value, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +//============================================================== +// TEMPLATE DEFINITIONS +//============================================================== + +template void func(float *A, float *B, float value, long n); +template void func(unsigned char *A, unsigned char *B, unsigned char value, long n); +template void func(float *A, float *B, float value, long n); +template void func(float *A, float *B, float value, long n); + +template void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, hip_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, hip_bfloat16 * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize); + +//template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits); +template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits); +template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); +template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); + +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); + +#ifndef NO_HIPBLASLT +template int igemmlt(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt(hipblasLtHandle_t ltHandle, int m, int n, int k, 
const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +#endif + +template void transformRowToFormat(char * A, char *out, int rows, int cols); +template void transformRowToFormat(char * A, char *out, int rows, int cols); +template void transformRowToFormat(char * A, char *out, int rows, int cols); +template void transformRowToFormat(char * A, char *out, int rows, int cols); +template void transformRowToFormat(char * A, char *out, int rows, int cols); +template void transformRowToFormat(char * A, char *out, int rows, int cols); + +template void estimateQuantiles(half *A, float *code, float offset, int n); +template void estimateQuantiles(float *A, float *code, float offset, int n); + +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); + +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, const int n); 
+template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, const int n); + +#define MAKE_optimizer32bit(name, gtype) \ +template void optimizer32bit(gtype* g, gtype* p, \ + float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \ + const float beta1, const float beta2, const float eps, const float weight_decay, \ + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); + +MAKE_optimizer32bit(ADAM, half) +MAKE_optimizer32bit(ADAM, float) +MAKE_optimizer32bit(ADAM, hip_bfloat16) +MAKE_optimizer32bit(MOMENTUM, half) +MAKE_optimizer32bit(MOMENTUM, float) +MAKE_optimizer32bit(RMSPROP, half) +MAKE_optimizer32bit(RMSPROP, float) +MAKE_optimizer32bit(LION, half) +MAKE_optimizer32bit(LION, float) +MAKE_optimizer32bit(LION, hip_bfloat16) +MAKE_optimizer32bit(ADAGRAD, half) +MAKE_optimizer32bit(ADAGRAD, float) + +#define MAKE_optimizerStatic8bit(name, gtype) \ +template void optimizerStatic8bit(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ + float *unorm, float max_unorm, float param_norm, \ + float beta1, float beta2, \ + float eps, int step, float lr, \ + float* quantiles1, float* quantiles2, \ + float* max1, float* max2, float* new_max1, float* new_max2, \ + float weight_decay, \ + const float gnorm_scale, int n); \ + +MAKE_optimizerStatic8bit(ADAM, half) +MAKE_optimizerStatic8bit(ADAM, float) +MAKE_optimizerStatic8bit(MOMENTUM, half) +MAKE_optimizerStatic8bit(MOMENTUM, float) +MAKE_optimizerStatic8bit(RMSPROP, half) +MAKE_optimizerStatic8bit(RMSPROP, float) +MAKE_optimizerStatic8bit(LION, half) +MAKE_optimizerStatic8bit(LION, float) + +#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \ +template void optimizerStatic8bitBlockwise(gtype* p, gtype* g, \ + unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \ + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \ + +MAKE_optimizerStatic8bitBlockwise(half, ADAM); +MAKE_optimizerStatic8bitBlockwise(float, ADAM); +MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM); +MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM); +MAKE_optimizerStatic8bitBlockwise(half, RMSPROP); +MAKE_optimizerStatic8bitBlockwise(float, RMSPROP); +MAKE_optimizerStatic8bitBlockwise(half, LION); +MAKE_optimizerStatic8bitBlockwise(float, LION); +MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, LION); +MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD); +MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD); + +template void percentileClipping(float * g, float *gnorm_vec, int step, const int n); +template void percentileClipping(half * g, float *gnorm_vec, int step, const int n); + +MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, ADAM); diff --git a/csrc/ops.hip.h b/csrc/ops.hip.h new file mode 100644 index 000000000..8e41f852a --- /dev/null +++ b/csrc/ops.hip.h @@ -0,0 +1,216 @@ +// !!! This is a file automatically generated by hipify!!! +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. 
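+//
+// HIP port of csrc/ops.cuh: declarations for the HIP backend ops plus the small handle wrappers
+// (Context, ContextLt, ContextHipsparse) that own the rocBLAS/hipBLASLt/hipSPARSE handles.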
+ + +#ifndef ops_H +#define ops_H + +#include +#include +#include +#include + +#include +#include +#include +//#ifndef NO_HIPBLASLT +#include +//#endif +#include +#include +#include + +/* +#include +#include +*/ + + +#define CUDA_CHECK_RETURN(value) { \ + hipError_t _m_cudaStat = value; \ + if (_m_cudaStat != hipSuccess) { \ + fprintf(stderr, "Error %s at line %d in file %s\n", \ + hipGetErrorString(_m_cudaStat), __LINE__, __FILE__); \ + exit(1); \ + } } + +#define THREADS_PER_BLOCKS (512) + +#define CHECK_HIPSPARSE(value) { \ + hipsparseStatus_t _m_hipStat = value; \ + if (_m_hipStat != HIPSPARSE_STATUS_SUCCESS) { \ + fprintf(stderr, "Error at line %d in file %s\n", \ + __LINE__, __FILE__); \ + exit(1); \ + } } + + + +#define THREADS_PER_BLOCKS (512) + + +inline void v(hipError_t status) { + if (status != hipSuccess) { + printf("hip API failed with status %d: %s\n", status, hipGetErrorString(status)); + throw std::logic_error("hip API failed"); + } +} + +inline int checkHipblasStatus(hipblasStatus_t status) { + if (status != HIPBLAS_STATUS_SUCCESS) { + printf("hipBLAS API failed with status %d\n", status); + //throw std::logic_error("cuBLAS API failed"); + return 1; + } + return 0; +} + +typedef enum Operations_t +{ + ksmul = 0, +} Operations_t; + +typedef enum Optimizer_t +{ + ADAM = 0, + MOMENTUM = 1, + RMSPROP = 2, + LARS = 3, + ADAGRAD = 4, + LION = 5, +} Optimizer_t; + +typedef enum Transform_t +{ + ROW = 0, + COL = 1, + COL32 = 2, + COL_TURING = 3, + COL_AMPERE = 4, +} Transform_t; + +typedef enum DataType_t +{ + General8bit = 0, + FP4 = 1, + NF4 = 2, +} DataType_t; + +typedef enum Funcs_t +{ + FILL = 0, + ARANGE = 1, + _MUL = 2, +} Funcs_t; + +class Context +{ + public: + rocblas_handle m_handle; + + Context() + { + rocblas_handle handle; + rocblas_create_handle(&handle); + m_handle = handle; + } + +}; + +#ifndef NO_HIPBLASLT +class ContextLt +{ + public: + hipblasLtHandle_t m_handle; + + ContextLt() + { + hipblasLtHandle_t handle; + hipblasLtCreate(&handle); + m_handle = handle; + } +}; +#endif + +class ContextHipsparse +{ + public: + hipsparseHandle_t m_handle; + + ContextHipsparse() + { + hipsparseHandle_t handle; + hipsparseCreate(&handle); + m_handle = handle; + } + +}; + + +template void estimateQuantiles(T *A, float *code, float offset, int n); + +void quantize(float *code, float *A, unsigned char *out, int n); +void dequantize(float *code, unsigned char *A, float *out, int n); +template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n); + +template void optimizer32bit(T* g, T* p, + float* state1, float* state2, float *unorm, float max_unorm, float param_norm, + float beta1, float beta2, float eps, float weight_decay, + int step, float lr, const float gnorm_scale, bool skip_zeros, int n); + +template void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2, + float *unorm, float max_unorm, float param_norm, + float beta1, float beta2, + float eps, int step, float lr, + float* quantiles1, float* quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, + const float gnorm_scale, int n); + +template void optimizerStatic8bitBlockwise(T* p, T* g, + unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, + float* quantiles1, float* quantiles2, float* absmax1, float* 
absmax2, float weight_decay, const float gnorm_scale, + bool skip_zeros, int n); + +template void percentileClipping(T * g, float *gnorm_vec, int step, const int n); + +void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n); + +void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); +void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, + long long int strideA, long long int strideB, long long int strideC, int batchCount); + + +#ifndef NO_HIPBLASLT +template int igemmlt(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); + +template void transform(hipblasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2); +#endif + +void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); +void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half* bias, int numRows, int numCols); +void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols); +void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, + int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols); + +template void transformRowToFormat(char * A, char *out, int rows, int cols); + +void spmm_coo(hipsparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B); + +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); + +template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); + +void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB); + +template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits); +template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize); + +template void func(T *A, T *B, T value, long n); + +#endif diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index c74357758..1583d8215 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -7,7 +7,7 @@ #include #endif #if BUILD_HIP -#include +#include #endif #include @@ -265,7 +265,6 @@ void spmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_r void spmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) { spmm_coo_very_sparse_naive(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); } - #endif extern "C" @@ -319,9 +318,9 @@ extern "C" MAKE_CFUNC32(adam, 
float, fp32) MAKE_CFUNC32(adam, half, fp16) - #if defined(BUILD_CUDA) +#if defined(BUILD_CUDA) MAKE_CFUNC32(adam, __nv_bfloat16, bf16) - #elif defined(BUILD_HIP) +#elif defined(BUILD_HIP) MAKE_CFUNC32(adam, hip_bfloat16, bf16) #endif MAKE_CFUNC32(momentum, float, 32) @@ -330,9 +329,9 @@ extern "C" MAKE_CFUNC32(rmsprop, half, 16) MAKE_CFUNC32(lion, float, fp32) MAKE_CFUNC32(lion, half, fp16) - #if defined(BUILD_CUDA) +#if defined(BUILD_CUDA) MAKE_CFUNC32(lion, __nv_bfloat16, bf16) - #elif defined(BUILD_HIP) +#elif defined(BUILD_HIP) MAKE_CFUNC32(lion, hip_bfloat16, bf16) #endif MAKE_CFUNC32(adagrad, float, 32) @@ -374,16 +373,16 @@ extern "C" MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, fp32) MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, fp16) MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, fp32) - #if defined(BUILD_CUDA) +#if defined(BUILD_CUDA) MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16) - #elif defined(BUILD_HIP) +#elif defined(BUILD_HIP) MAKE_CBLOCKWISE8(adam, ADAM, hip_bfloat16, bf16) #endif MAKE_CBLOCKWISE8(lion, LION, half, fp16) MAKE_CBLOCKWISE8(lion, LION, float, fp32) - #if defined(BUILD_CUDA) +#if defined(BUILD_CUDA) MAKE_CBLOCKWISE8(lion, LION, __nv_bfloat16, bf16) - #elif defined(BUILD_HIP) +#elif defined(BUILD_HIP) MAKE_CBLOCKWISE8(lion, LION, hip_bfloat16, bf16) #endif @@ -434,7 +433,7 @@ extern "C" { \ transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose((cublasLtHandle_t) context->m_handle, A, out, dim1, dim2); \ } \ - + #endif #if BUILD_HIP @@ -477,7 +476,7 @@ extern "C" MAKE_FUNC_CTRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8) MAKE_FUNC_CTRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32) - #if defined(BUILD_HIP) +#if defined(BUILD_HIP) MAKE_FUNC_CTRANSFORM(8, row, col, t, int8_t, ROW, COL, true, 8) MAKE_FUNC_CTRANSFORM(32, row, col, n, int32_t, ROW, COL, false, 32) MAKE_FUNC_CTRANSFORM(32, row, col, t, int32_t, ROW, COL, true, 32) @@ -555,7 +554,7 @@ extern "C" int hasPrefetch = 0; CUDA_CHECK_RETURN(cudaDeviceGetAttribute(&hasPrefetch, cudaDevAttrConcurrentManagedAccess, device)); // 40ns overhead if (hasPrefetch == 0) return; - + CUDA_CHECK_RETURN(cudaMemPrefetchAsync(ptr, bytes, device, 0)); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -594,9 +593,9 @@ extern "C" void cgemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize) { gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } - #if defined(BUILD_CUDA) +#if defined(BUILD_CUDA) void cgemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize) - #elif defined(BUILD_HIP) +#elif defined(BUILD_HIP) void cgemm_4bit_inference_naive_bf16(int m, int n, int k, hip_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, hip_bfloat16 * out, int lda, int ldb, int ldc, int blocksize) #endif { gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } diff --git a/include/Algo-Direct-Common.h b/include/Algo-Direct-Common.h index c97084904..7b40edea9 100644 --- a/include/Algo-Direct-Common.h +++ b/include/Algo-Direct-Common.h @@ -190,7 +190,7 @@ struct DirectInfo xi = xws; } else { - myassert(Gap==1, "if Gap>1 then X workspace must be provided"); + myassert((Gap==1), "if Gap>1 then X workspace must be provided"); xi = x; } diff --git a/include/Algo-Direct2.h b/include/Algo-Direct2.h index 
347ec9c5e..547ca9955 100644 --- a/include/Algo-Direct2.h +++ b/include/Algo-Direct2.h @@ -52,6 +52,7 @@ struct AlgoVecBase::val private: typedef AlgoScalarBase base_t; +#ifdef USE_SSE2 FORCE_INLINE //NO_INLINE void resolve(const FVec& vz, const IVec& bidx, uint32 *pr) const @@ -93,8 +94,8 @@ struct AlgoVecBase::val __m128 vxp = _mm_shuffle_ps(xp01, xp23, (1) + (3 << 2) + (1 << 4) + (3 << 6)); #endif IVec i(u.vec); - IVec vlem = operator< (vz, vxm); - IVec vlep = operator< (vz, vxp); + IVec vlem = vz < vxm; + IVec vlep = vz < vxp; i = i + vlem + vlep; i.store(pr); } @@ -123,8 +124,8 @@ struct AlgoVecBase::val __m128d vxp = _mm_shuffle_pd(vx0, vx1, 3); IVec i(b1, b0); - IVec vlem = operator< (vz, vxm); - IVec vlep = operator< (vz, vxp); + IVec vlem = (vz < vxm); + IVec vlep = (vz < vxp); i = i + vlem + vlep; union { @@ -135,6 +136,7 @@ struct AlgoVecBase::val pr[0] = u.ui32[0]; pr[1] = u.ui32[2]; } +#endif // USE_SSE2 #ifdef USE_AVX @@ -157,7 +159,7 @@ struct AlgoVecBase::val FVec vxp = _mm256_i32gather_ps(xi, idxp, sizeof(float)); IVec ip = idxm; -#else // do not use gather instrucions +#else // do not use gather instructions union U { __m256i vec; @@ -227,8 +229,8 @@ struct AlgoVecBase::val #endif - IVec vlem = operator< (vz, vxm); - IVec vlep = operator< (vz, vxp); + IVec vlem = vz < vxm; + IVec vlep = vz < vxp; ip = ip + vlem + vlep; ip.store(pr); @@ -277,8 +279,8 @@ struct AlgoVecBase::val // FVec vxp = _mm256_insertf128_pd(_mm256_castpd128_pd256(h01p), h23p, 1); IVec i(u.vec); - IVec vlem = operator< (vz, vxm); - IVec vlep = operator< (vz, vxp); + IVec vlem = vz < vxm; + IVec vlep = vz < vxp; i = i + vlem + vlep; i.extractLo32s().store(pr); } diff --git a/include/Portable.h b/include/Portable.h index 1710b0502..090a25065 100644 --- a/include/Portable.h +++ b/include/Portable.h @@ -4,10 +4,40 @@ #include #include +#if defined(__aarch64__) +#ifdef __CUDACC__ +#undef USE_NEON // Doesn't work with nvcc, undefined symbols +#else +#include +#undef USE_NEON // Not yet implemented +#endif +#undef USE_AVX // x86_64 only +#undef USE_AVX2 // x86_64 only +#undef USE_SSE2 // x86_64 only +#undef USE_SSE41 // x86_64 only +#undef USE_SSE42 // x86_64 only +#undef USE_FMA // x86_64 only +#ifdef USE_NEON +typedef float32x4_t __m128; +typedef int32x4_t __m128i; +typedef float64x2_t __m128d; +#else +typedef struct {float a; float b; float c; float d;} __m128; +typedef struct {int a; int b; int c; int d;} __m128i; +typedef struct {double a; double b;} __m128d; +#endif +#else +#undef USE_NEON // ARM64 only #ifdef __FMA__ #define USE_FMA #endif +#if !defined(__SSE2__) && !defined(_MSC_VER) +#error Compiler must support SSE2 +#endif +#define USE_SSE2 +#if defined(__aarch64__) +#else #ifdef __AVX2__ #define USE_AVX2 #endif @@ -24,7 +54,8 @@ #ifdef __SSE4_2__ #define USE_SSE42 #endif - +#endif +#endif #ifndef _MSC_VER #include @@ -147,5 +178,5 @@ inline T prev(T x) return x; } -} // namepsace Details +} // namespace Details } // namespace BinSearch diff --git a/include/SIMD.h b/include/SIMD.h index a2ac1a9ae..9d1410c73 100644 --- a/include/SIMD.h +++ b/include/SIMD.h @@ -2,6 +2,46 @@ #include "Portable.h" +#ifdef USE_SSE2 +#include +#if defined(USE_AVX) || defined(USE_AVX2) +#include +#else +#ifdef USE_SSE41 +#include +#endif +#endif +#endif + +namespace BinSearch { +namespace Details { + +template +struct FTOITraits{}; + +template +struct FVec; + +template +struct IVec; + +template +struct FVec1; + +template <> struct InstrFloatTraits +{ + typedef __m128 vec_t; +}; + +template <> struct InstrFloatTraits +{ + 
typedef __m128d vec_t; +}; + +} +} + +#if !defined(__aarch64__) #ifdef USE_SSE42 #ifndef _MSC_VER #include @@ -26,29 +66,11 @@ FORCE_INLINE int popcnt32(int x32) } // namespace #endif -#if defined(USE_AVX) || defined(USE_AVX2) -#include -#else -#include -#ifdef USE_SSE41 -#include -#endif -#endif - #include "Type.h" namespace BinSearch { namespace Details { -template -struct FVec; - -template -struct IVec; - -template -struct FVec1; - template <> struct InstrIntTraits { typedef __m128i vec_t; @@ -64,8 +86,8 @@ template <> struct InstrFloatTraits typedef __m128d vec_t; }; -template -struct FTOITraits +template <> +struct FTOITraits { typedef IVec vec_t; }; @@ -285,9 +307,11 @@ FORCE_INLINE FVec operator- (const FVec& a, const FVec< FORCE_INLINE FVec operator* (const FVec& a, const FVec& b) { return _mm_mul_ps( a, b ); } FORCE_INLINE FVec operator/ (const FVec& a, const FVec& b) { return _mm_div_ps( a, b ); } FORCE_INLINE IVec ftoi (const FVec& a) { return _mm_cvttps_epi32(a); } +#ifndef __clang__ // Conflicts with builtin operator FORCE_INLINE IVec operator<= (const FVec& a, const FVec& b) { return _mm_castps_si128( _mm_cmple_ps( a, b ) ); } FORCE_INLINE IVec operator>= (const FVec& a, const FVec& b) { return _mm_castps_si128( _mm_cmpge_ps( a, b ) ); } FORCE_INLINE IVec operator< (const FVec& a, const FVec& b) { return _mm_castps_si128(_mm_cmplt_ps(a, b)); } +#endif #ifdef USE_FMA FORCE_INLINE FVec mulSub(const FVec& a, const FVec& b, const FVec& c) { return _mm_fmsub_ps(a, b, c); } #endif @@ -339,9 +363,11 @@ FORCE_INLINE FVec operator- (const FVec& a, const FVec FORCE_INLINE FVec operator* (const FVec& a, const FVec& b) { return _mm_mul_pd( a, b ); } FORCE_INLINE FVec operator/ (const FVec& a, const FVec& b) { return _mm_div_pd( a, b ); } FORCE_INLINE IVec ftoi (const FVec& a) { return _mm_cvttpd_epi32(a); } +#ifndef __clang__ // Conflicts with builtin operator FORCE_INLINE IVec operator<= (const FVec& a, const FVec& b) { return _mm_castpd_si128( _mm_cmple_pd( a, b ) ); } FORCE_INLINE IVec operator< (const FVec& a, const FVec& b) { return _mm_castpd_si128(_mm_cmplt_pd(a, b)); } FORCE_INLINE IVec operator>= (const FVec& a, const FVec& b) { return _mm_castpd_si128( _mm_cmpge_pd( a, b ) ); } +#endif #ifdef USE_FMA FORCE_INLINE FVec mulSub(const FVec& a, const FVec& b, const FVec& c ) { return _mm_fmsub_pd(a, b, c); } #endif @@ -558,5 +584,6 @@ FORCE_INLINE FVec mulSub(const FVec& a, const FVec List[int]: + return [test_dims_rng.randint(min, max) for _ in range(n)] + + +def format_with_label(label: str, value: Any) -> str: + if isinstance(value, bool): + formatted = "T" if value else "F" + elif isinstance(value, (list, tuple)) and all(isinstance(v, bool) for v in value): + formatted = "".join("T" if b else "F" for b in value) + else: + formatted = str(value) + return f"{label}={formatted}" + + +def id_formatter(label: str): + """ + Return a function that formats the value given to it with the given label. 
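    A usage sketch (illustrative, assuming only the helpers defined above):

        format_with_label("dim1", 128)                 # -> "dim1=128"
        format_with_label("transpose", (False, True))  # -> "transpose=FT"
        id_formatter("req_grad")((True, True, False))  # -> "req_grad=TTF"

    pytest calls such a formatter once per parameter value to build readable test ids.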
+ """ + return lambda value: format_with_label(label, value) + + +DTYPE_NAMES = { + torch.bfloat16: "bf16", + torch.bool: "bool", + torch.float16: "fp16", + torch.float32: "fp32", + torch.float64: "fp64", + torch.int32: "int32", + torch.int64: "int64", + torch.int8: "int8", +} + + +def describe_dtype(dtype: torch.dtype) -> str: + return DTYPE_NAMES.get(dtype) or str(dtype).rpartition(".")[2] diff --git a/tests/test_autograd.py b/tests/test_autograd.py index f045fda4c..42611e3c0 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -1,51 +1,35 @@ -from itertools import permutations, product +from typing import Tuple import pytest import torch import bitsandbytes as bnb -from bitsandbytes.cextension import HIP_ENVIRONMENT - -n = 1 -k = 25 -dim1 = torch.randint(16, 64, size=(n,)).tolist() -dim2 = torch.randint(32, 96, size=(n,)).tolist() -dim3 = torch.randint(32, 96, size=(n,)).tolist() -dim4 = torch.randint(32, 96, size=(n,)).tolist() -funcs = [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)] -str_funcs = ["bmm", "matmul"] -req_grad = [(False, False), (True, False), (True, True), (False, True)] -req_grad_str = ["FF", "TF", "TT", "FT"] -transpose = [(False, False), (False, True), (True, True), (True, False)] -str_transpose = ["FF", "FT", "TT", "TF"] -dtype = [torch.float32, torch.float16] -values = list( - product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose) +from tests.helpers import ( + BOOLEAN_TRIPLES, + BOOLEAN_TUPLES, + TRUE_FALSE, + describe_dtype, + get_test_dims, + id_formatter, ) -str_values = list( - product( - dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose - ) -) -names = [ - "dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format( - *vals - ) - for vals in str_values -] - - -@pytest.mark.parametrize( - "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", - values, - ids=names, -) -def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): + +TRANSPOSE_VALS = [(False, True), (False, False)] + + +@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(32, 96, n=1), ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3")) +@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4")) +@pytest.mark.parametrize("funcs", [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)], ids=["func=bmm", "func=matmul"]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype) +@pytest.mark.parametrize("req_grad", BOOLEAN_TUPLES, ids=id_formatter("req_grad")) +@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose")) +def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool], transpose: Tuple[bool, bool]): if dim2 > 0: dim2 = dim2 - (dim2 % 16) dim3 = dim3 - (dim3 % 16) dim4 = dim4 - (dim4 % 16) - for i in range(k): + for i in range(25): # normal multiply if funcs[0] in [torch.mm, torch.matmul]: @@ -229,70 +213,17 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): assert (idx == 0).sum().item() < n * 0.02 -n = 1 -k = 3 -dim1 = torch.randint(16, 64, size=(n,)).tolist() -dim2 = torch.randint(32, 96, size=(n,)).tolist() -dim3 = torch.randint(32, 96, size=(n,)).tolist() -dim4 = torch.randint(32, 96, size=(n,)).tolist() - -dim2.append(0) - -decomp = [0.0, 6.0] -funcs = [(torch.matmul, bnb.matmul), 
(torch.matmul, bnb.research.switchback_bnb)] -str_funcs = ["matmullt", 'switchback_bnb'] -req_grad = [(False, False), (True, False), (True, True), (False, True)] -req_grad = list(product([True, False], repeat=3)) -req_grad_str = [] -for c in req_grad: - strval = '' - for v in c: - if v == True: strval += 'T' - else: strval += 'F' - req_grad_str.append(strval) - -transpose = [(False, True), (False, False)] -str_transpose = ["NT", "NN"] -dtype = [torch.float16, torch.bfloat16, torch.float32] -has_fp16_weights = [True, False] -has_bias = [True, False] -values = list( - product( - dim1, - dim2, - dim3, - dim4, - funcs, - dtype, - req_grad, - transpose, - decomp, - has_fp16_weights, - has_bias - ) -) -str_values = list( - product( - dim1, - dim2, - dim3, - dim4, - str_funcs, - dtype, - req_grad_str, - str_transpose, - decomp, - has_fp16_weights, - has_bias - ) -) -names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_decomp_{}_has_fp16_weights_{}_has_bias_{}".format(*vals) for vals in str_values] - -@pytest.mark.parametrize( - "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias", - values, - ids=names, -) +@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3")) +@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4")) +@pytest.mark.parametrize("decomp", [0.0, 6.0], ids=id_formatter("decomp")) +@pytest.mark.parametrize("funcs", [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)], ids=["func=matmul", "func=switchback_bnb"]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) +@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad")) +@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose")) +@pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights")) +@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias")) def test_matmullt( dim1, dim2, @@ -313,7 +244,7 @@ def test_matmullt( req_grad = list(req_grad) req_grad[2] = False - for i in range(k): + for i in range(3): # normal multiply if funcs[0] in [torch.mm, torch.matmul]: @@ -429,45 +360,25 @@ def test_matmullt( torch.testing.assert_close(gradBias1, gradBias2) -n = 1 -k = 3 -dim1 = torch.randint(16, 64, size=(n,)).tolist() -dim2 = torch.randint(32, 96, size=(n,)).tolist() -dim3 = torch.randint(32, 96, size=(n,)).tolist() -dim4 = torch.randint(32, 96, size=(n,)).tolist() - -dim2.append(0) - -funcs = [(torch.matmul, bnb.matmul_4bit)] -str_funcs = ["matmul"] -req_grad = list(product([True, False], repeat=3)) -req_grad_str = [] -for c in req_grad: - strval = '' - for v in c: - if v == True: strval += 'T' - else: strval += 'F' - req_grad_str.append(strval) - -transpose = [(False, True), (False, False)] -str_transpose = ["NT", "NN"] -dtype = [torch.float16, torch.float32] -compress_statistics = [False, True] -has_fp16_weights = [True, False] -has_bias = [True, False] -quant_type = ['fp4', 'nf4'] -values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type)) -str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias, compress_statistics, quant_type)) 
-names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}_compress_statistics_{}_quant_type_{}".format(*vals) for vals in str_values] -@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type", values, ids=names) -def test_matmul_4bit( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type): +@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3")) +@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4")) +@pytest.mark.parametrize("funcs", [(torch.matmul, bnb.matmul_4bit)], ids=["func=matmul"]) +@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad")) +@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose")) +@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias")) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype) +@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) +@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'], ids=id_formatter("quant_type")) +def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type): dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) if has_bias == False: req_grad = list(req_grad) req_grad[2] = False - for i in range(k): + for i in range(3): # normal multiply if funcs[0] in [torch.mm, torch.matmul]: A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype) @@ -530,32 +441,21 @@ def test_matmul_4bit( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, torch.testing.assert_close(gradBias1, gradBias2) -funcs = [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)] -str_funcs = ["matmul_fp8_mixed", 'matmul_fp8_global'] -req_grad = list(product([True, False], repeat=3)) -req_grad_str = [] -for c in req_grad: - strval = '' - for v in c: - if v == True: strval += 'T' - else: strval += 'F' - req_grad_str.append(strval) - -transpose = [(False, True), (False, False)] -str_transpose = ["NT", "NN"] -dtype = [torch.float16, torch.float32] -has_fp16_weights = [True, False] -values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose)) -str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose)) -names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format(*vals) for vals in str_values] -@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names) +@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3")) +@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4")) +@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad")) +@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose")) 
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype) +@pytest.mark.parametrize("funcs", [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)], ids=["matmul_fp8_mixed", 'matmul_fp8_global']) def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) req_grad = list(req_grad) req_grad[2] = False - for i in range(k): + for i in range(3): # normal multiply if funcs[0] in [torch.mm, torch.matmul]: A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype) @@ -566,7 +466,6 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device) bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(A.device) - if not transpose[0] and transpose[1]: out_torch = funcs[0](A, B.t()) out_bnb = funcs[1](A, B.t(), fw_code, bw_code) @@ -619,4 +518,3 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): torch.testing.assert_close( gradB1, gradB2, atol=0.18, rtol=0.3 ) - diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index 5e9ccf590..189aa75b5 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -1,13 +1,11 @@ import os -import pytest -import torch from pathlib import Path -from bitsandbytes.cextension import HIP_ENVIRONMENT +import torch + # hardcoded test. Not good, but a sanity check for now # TODO: improve this -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") def test_manual_override(requires_cuda): manual_cuda_path = str(Path('/mmfs1/home/dettmers/data/local/cuda-12.2')) @@ -21,8 +19,3 @@ def test_manual_override(requires_cuda): import bitsandbytes as bnb loaded_lib = bnb.cuda_setup.main.CUDASetup.get_instance().binary_name #assert loaded_lib == 'libbitsandbytes_cuda122.so' - - - - - diff --git a/tests/test_functional.py b/tests/test_functional.py index 5dba4ef5f..2ef6f451b 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,17 +1,24 @@ +from itertools import product import math import random import time -from itertools import product import einops +import numpy as np import pytest +from scipy.stats import norm import torch -import numpy as np import bitsandbytes as bnb from bitsandbytes import functional as F from bitsandbytes.cextension import HIP_ENVIRONMENT -from scipy.stats import norm +from tests.helpers import ( + BOOLEAN_TUPLES, + TRUE_FALSE, + describe_dtype, + get_test_dims, + id_formatter, +) torch.set_printoptions( precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000 @@ -20,12 +27,12 @@ def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0, throw=True): - idx = torch.isclose(a, b, rtol, atol) + idx = torch.isclose(a, b, rtol=rtol, atol=atol) sumval = (idx == 0).sum().item() if sumval > count: if throw: print(f"Too many values not close: assert {sumval} < {count}") - torch.testing.assert_close(a, b, rtol, atol) + torch.testing.assert_close(a, b, rtol=rtol, atol=atol) return sumval @@ -91,6 +98,7 @@ def setup(): def teardown(): pass + @pytest.mark.parametrize( "dtype", [torch.float32, torch.float16], ids=["float", "half"] ) @@ -110,6 +118,7 @@ def test_estimate_quantiles(dtype): diff = torch.abs(code - quantiles) assert (diff > 5e-02).sum().item() == 0 + def 
test_quantile_quantization(): for i in range(100): A1 = torch.randn(1024, 1024, device="cuda") @@ -158,10 +167,10 @@ def get_blocksizes(hip_env=False): return [4096, 2048, 1024, 512, 256, 128, 64] else: return [4096, 2048, 1024, 512, 256, 128] -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"]) -@pytest.mark.parametrize("nested", [False, True], ids=["False", "True"]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) +@pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested")) @pytest.mark.parametrize("blocksize", get_blocksizes(HIP_ENVIRONMENT)) -@pytest.mark.parametrize("signed", [True, False], ids=['signed_True', 'signed_False']) +@pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed")) def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed): #print('') diffs = [] @@ -284,34 +293,22 @@ def mean(xx): return sum(xx) / float(len(xx)) -# dim1 = torch.randint(1,1024*4, size=(4,)).tolist() -# dim2 = torch.randint(1,1024*4, size=(4,)).tolist() -dim1 = [1024 * 2] -dim2 = [1024 * 16] -methods = [ - ( +methods = { + "linear": ( lambda x, dim: quant(x), lambda x, dim: quant(x), dequant, dequant, mm_dequant, - ) -] -methods.append((quant_multi, quant_multi, dequant, dequant, mm_dequant)) -# methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant)) -method_names = ["linear", "vectorwise"] -batched = [False, True] -values = list(product(dim1, dim2, methods, batched)) -values_names = list(product(dim1, dim2, method_names, batched)) -names = [ - "dim1_{}_dim2_{}_quant_{}_batched_{}".format(*vals) - for vals in values_names -] + ), + "vectorwise": (quant_multi, quant_multi, dequant, dequant, mm_dequant), +} -@pytest.mark.parametrize( - "dim1, dim2, quant_methods, batched", values, ids=names -) +@pytest.mark.parametrize("dim1", [1024 * 2], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [1024 * 16], ids=id_formatter("dim2")) +@pytest.mark.parametrize("quant_methods", methods.values(), ids=methods.keys()) +@pytest.mark.parametrize("batched", TRUE_FALSE, ids=id_formatter("batched")) def test_approx_igemm(dim1, dim2, quant_methods, batched): dim1 = dim1 - (dim1 % 32) dim2 = dim2 - (dim2 % 32) @@ -355,21 +352,10 @@ def test_stable_embedding(): layer.reset_parameters() -n = 2 -hidden_dim = torch.randint(32, 256, size=(n,)).tolist() -batch_dim = torch.randint(16, 256, size=(n,)).tolist() -seq_dim = torch.randint(16, 256, size=(n,)).tolist() -transpose = [(False, False), (False, True), (True, False), (True, True)] -values = list(product(hidden_dim, batch_dim, transpose, seq_dim)) -names = [ - "hidden_dim_{}_batch_dim_{},transpose_{}_seq_dim_{}".format(*vals) - for vals in values -] - - -@pytest.mark.parametrize( - "hidden_dim, batch_dim, transpose, seq_dim", values, ids=names -) +@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 256, n=2), ids=id_formatter("hidden_dim")) +@pytest.mark.parametrize("batch_dim", get_test_dims(16, 256, n=2), ids=id_formatter("batch_dim")) +@pytest.mark.parametrize("seq_dim", get_test_dims(16, 256, n=2), ids=id_formatter("seq_dim")) +@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose")) def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): hidden_dim = hidden_dim - (hidden_dim % 32) batch_dim = batch_dim - (batch_dim % 16) @@ -421,17 +407,9 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): 
torch.testing.assert_close(out.float(), out2) -n = 3 -seq_dim = torch.randint(32, 512, size=(n,)).tolist() -hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist() -batch_dim = torch.randint(2, 16, size=(n,)).tolist() -values = list(product(seq_dim, hidden_dim, batch_dim)) -names = [ - "seq_dim{}_hidden_dim{}_batch_dim{}".format(*vals) for vals in values -] - - -@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names) +@pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=3), ids=id_formatter("seq_dim")) +@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=3), ids=id_formatter("hidden_dim")) +@pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=3), ids=id_formatter("batch_dim")) def test_dim3_igemm(seq_dim, hidden_dim, batch_dim): seq_dim = seq_dim - (seq_dim % 32) hidden_dim = hidden_dim - (hidden_dim % 32) @@ -452,21 +430,10 @@ def test_dim3_igemm(seq_dim, hidden_dim, batch_dim): torch.testing.assert_close(out.float(), out2) -n = 2 -seq_dim = torch.randint(32, 512, size=(n,)).tolist() -hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist() -batch_dim = torch.randint(2, 16, size=(n,)).tolist() -transpose = [False, True] -values = list(product(seq_dim, hidden_dim, batch_dim, transpose)) -names = [ - "seq_dim={}_hidden_dim={}_batch_dim={}_transpose{}".format(*vals) - for vals in values -] - - -@pytest.mark.parametrize( - "seq_dim, hidden_dim, batch_dim, transpose", values, ids=names -) +@pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=2), ids=id_formatter("seq_dim")) +@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=2), ids=id_formatter("hidden_dim")) +@pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=2), ids=id_formatter("batch_dim")) +@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose")) def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose): def min_max(x): maxA = torch.amax(x, dim=2, keepdim=True) @@ -536,20 +503,11 @@ def min_max(x): assert mean(relerrs) < 0.3 -n = 2 -dim1 = torch.randint(1, 64, size=(n,)).tolist() -dim2 = torch.randint(32, 128, size=(n,)).tolist() -dim3 = torch.randint(32, 256, size=(n,)).tolist() -dim4 = torch.randint(32, 256, size=(n,)).tolist() -transpose = [(False, False), (True, False), (False, True), (True, True)] -values = list(product(dim1, dim2, dim3, dim4, transpose)) -names = [ - "dim1_{}_dim2_{}_dim3_{}_dim4_{}_transpose_{}".format(*vals) - for vals in values -] - - -@pytest.mark.parametrize("dim1, dim2, dim3, dim4, transpose", values, ids=names) +@pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=2), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=2), ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=2), ids=id_formatter("dim3")) +@pytest.mark.parametrize("dim4", get_test_dims(32, 256, n=2), ids=id_formatter("dim4")) +@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose")) def test_ibmm(dim1, dim2, dim3, dim4, transpose): dim2 = dim2 - (dim2 % 16) dim3 = dim3 - (dim3 % 16) @@ -576,16 +534,10 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose): out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1])) torch.testing.assert_close(out.float(), out2.float()) - -n = 1 -dim1 = torch.randint(1, 64, size=(n,)).tolist() -dim2 = torch.randint(32, 128, size=(n,)).tolist() -dim3 = torch.randint(32, 256, size=(n,)).tolist() -values = list(product(dim1, dim2, dim3)) -names = ["dim1_{}_dim2_{}_dim3_{}".format(*vals) for vals 
in values] - +@pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=1), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=1), ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=1), ids=id_formatter("dim3")) @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") -@pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names) def test_vector_quant(dim1, dim2, dim3): dim2 = dim2 - (dim2 % 16) dim3 = dim3 - (dim3 % 16) @@ -597,28 +549,18 @@ def test_vector_quant(dim1, dim2, dim3): assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n*0.002)) - - -n = 2 -dim1 = torch.randint(2, 256, size=(n,)).tolist() -dim2 = torch.randint(2, 256, size=(n,)).tolist() -dim3 = torch.randint(2, 256, size=(n,)).tolist() -# dim1, dim2 = (256,), (256,) -dtype = [torch.int8, torch.int32] -a_order = ["row"] -out_order = ["col", "row"] if HIP_ENVIRONMENT else ["col", "row", "col32"] -transpose = [False] -dims = [2, 3] -values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)) - -names = ["dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_transpose_{}".format(*vals)for vals in values] - - -@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",values,ids=names) +@pytest.mark.parametrize("dim1", get_test_dims(2, 256, n=2), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(2, 256, n=2), ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", get_test_dims(2, 256, n=2), ids=id_formatter("dim3")) +@pytest.mark.parametrize("dtype", [torch.int8, torch.int32], ids=describe_dtype) +@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA")) +@pytest.mark.parametrize("orderOut", ["col", "row", "col32"], ids=id_formatter("orderOut")) +@pytest.mark.parametrize("transpose", [False], ids=id_formatter("transpose")) +@pytest.mark.parametrize("dims", [2, 3], ids=id_formatter("dims")) def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): - if dims == 3 and out_order != "col32": + if dims == 3 and orderOut != "col32": return - if dtype == torch.int32 and out_order != "col32": + if dtype == torch.int32 and orderOut != "col32": return try: func = F.get_transform_func(dtype, orderA, orderOut, transpose) @@ -680,27 +622,12 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans torch.testing.assert_close(A, out2) -n = 1 -dim1 = torch.randint(1, 256, size=(n,)).tolist() -dim2 = torch.randint(32, 512, size=(n,)).tolist() -dim3 = torch.randint(32, 1024, size=(n,)).tolist() -dim4 = torch.randint(32, 1024, size=(n,)).tolist() - -# dim1 = [2] -# dim2 = [2] -# dim3 = [2] -# dim4 = [2] - -dims = (2, 3) -ldb = [0] -# ldb = list(range(256, 1*1024, 256)) -values = list(product(dim1, dim2, dim3, dim4, dims, ldb)) -names = [ - "dim1_{}_dim2_{}_dim3_{}_dim4_{}_dims_{}_ldb_{}".format(*vals) - for vals in values -] - -@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims, ldb", values, ids=names) +@pytest.mark.parametrize("dim1", get_test_dims(1, 256, n=1), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(32, 512, n=1), ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", get_test_dims(32, 1024, n=1), ids=id_formatter("dim3")) +@pytest.mark.parametrize("dim4", get_test_dims(32, 1024, n=1), ids=id_formatter("dim4")) +@pytest.mark.parametrize("dims", (2, 3), ids=id_formatter("dims")) +@pytest.mark.parametrize("ldb", (0,), ids=id_formatter("ldb")) 
def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): for i in range(k): if dims == 2: @@ -722,7 +649,7 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): C3, S = F.nvidia_transform(C2, "row", state=SC) torch.testing.assert_close(C1, C3.float()) - ## transpose + # transpose B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to( torch.int8 ) @@ -734,20 +661,11 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): torch.testing.assert_close(C1, C3.float()) -dim1 = [32] -dim2 = [32] -dim3 = [32] -dim4 = [32] - -dims = (2,) -# ldb = list(range(256, 1*1024, 256)) -values = list(product(dim1, dim2, dim3, dim4, dims)) -names = [ - "dim1_{}_dim2_{}_dim3_{}_dim4_{}_dims_{}".format(*vals) - for vals in values -] - -@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims", values, ids=names) +@pytest.mark.parametrize("dim1", [32], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [32], ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3")) +@pytest.mark.parametrize("dim4", [32], ids=id_formatter("dim4")) +@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) def test_igemmlt_half(dim1, dim2, dim3, dim4, dims): formatB = F.get_special_format_str() for i in range(k): @@ -787,24 +705,15 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims): # C3, S = F.transform(C2, 'row', state=SC) # torch.testing.assert_close(C1, C3.float()) - -batch_size = 2 -seqdim = 512 -# values = [(batch_size, seqdim, 4*1024, 16*1024),(batch_size, seqdim, 5120, 4*5120),(batch_size, seqdim, 12*1024, 4*12*1024)] -values = [ - (batch_size, seqdim, 4 * 1024, 3 * 4 * 1024), - (batch_size, seqdim, 5120, 3 * 5120), - (batch_size, seqdim, 12 * 1024, 4 * 12 * 1024), -] - - -# values = list(product(batch, seq, model, hidden)) -names = [ - "batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values -] - - -@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names) +@pytest.mark.parametrize( + ("batch", "seq", "model", "hidden"), + [ + pytest.param(2, 512, 4 * 1024, 3 * 4 * 1024, id="batch=2, seq=512, model=4k, hidden=12k"), + pytest.param(2, 512, 5120, 3 * 5120, id="batch=2, seq=512, model=5k, hidden=15k"), + pytest.param(2, 512, 12 * 1024, 4 * 12 * 1024, id="batch=2, seq=512, model=12k, hidden=48k"), + ], +) +@pytest.mark.benchmark def test_bench_8bit_training(batch, seq, model, hidden): formatB = F.get_special_format_str() A = torch.randn(batch, seq, model, device="cuda").half() @@ -954,23 +863,11 @@ def test_bench_8bit_training(batch, seq, model, hidden): # print(t8) -n = 2 -dim1 = torch.randint(64, 256, size=(n,)).tolist() -dim4 = torch.randint(64, 1024, size=(n,)).tolist() - -#dim1 = [2*1024] -#dim4 = [2*1024] - -#dim1 = [4] -#dim4 = [4] - -dims = (2,) -formatB = ["col_turing", "col_ampere"] -has_bias = [True, False] -values = list(product(dim1, dim4, dims, formatB, has_bias)) -names = ["dim1_{}_dim4_{}_dims_{}_formatB_{}_has_bias_{}".format(*vals) for vals in values] - -@pytest.mark.parametrize("dim1, dim4, dims, formatB, has_bias", values, ids=names) +@pytest.mark.parametrize("dim1", get_test_dims(64, 256, n=2), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim4", get_test_dims(64, 1024, n=2), ids=id_formatter("dim4")) +@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) +@pytest.mark.parametrize("formatB", ["col_turing", "col_ampere"], ids=id_formatter("formatB")) +@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias")) def test_dequant_mm(dim1, dim4, dims, formatB, has_bias): 
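    # Note: formatB picks the tiled int8 layout consumed by igemmlt
    # ("col_turing" for Turing-class GPUs, "col_ampere" for Ampere-class GPUs).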
inner = torch.randint(1, 128, size=(1,)).item() bias = None @@ -994,33 +891,23 @@ def test_dequant_mm(dim1, dim4, dims, formatB, has_bias): if has_bias: C4 += bias # TODO: is something wrong here? If so, the problem goes deeper - #n = C1.numel() - #p = 0.06 + # n = C1.numel() + # p = 0.06 std = C1.std(0).view(1, -1) C1 /= std C4 /= std - #assert_all_approx_close(C1, C4, atol=0.02, rtol=0.1, count=int(n*0.06)) - #assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}" + # assert_all_approx_close(C1, C4, atol=0.02, rtol=0.1, count=int(n*0.06)) + # assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}" C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten(), bias=bias) - #torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1) + # torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1) n = C5.numel() - assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01*n)) - - -n = 2 -dim1 = [1 * 1024] -dim2 = [1 * 1024] -# dim1 = torch.randint(1,4*1024, size=(n,)).tolist() -# dim2 = torch.randint(1,4*1024, size=(n,)).tolist() - -dims = (2,) -# ldb = list(range(256, 1*1024, 256)) -values = list(product(dim1, dim2, dims)) -names = ["dim1_{}_dim2_{}_dims_{}".format(*vals) for vals in values] + assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01 * n)) -@pytest.mark.parametrize("dim1, dim2, dims", values, ids=names) +@pytest.mark.parametrize("dim1", [1 * 1024], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [1 * 1024], ids=id_formatter("dim2")) +@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) def test_colrow_absmax(dim1, dim2, dims): for i in range(k): threshold = 3.0 @@ -1066,17 +953,8 @@ def test_colrow_absmax(dim1, dim2, dims): assert nnz_block_ptr2 is None -n = 2 -# dim1 = [8*1024] -# dim2 = [4*1024] -dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist() -dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist() - -values = list(product(dim1, dim2)) -names = ["dim1_{}_dim2_{}".format(*vals) for vals in values] - - -@pytest.mark.parametrize("dim1, dim2", values, ids=names) +@pytest.mark.parametrize("dim1", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim2")) def test_double_quant(dim1, dim2): for i in range(k): A = torch.randn(dim1, dim2, device="cuda").half() @@ -1114,16 +992,18 @@ def test_double_quant(dim1, dim2): torch.testing.assert_close(Scol.flatten().float(), statsAt) -n = 4 -dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist() -dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist() -inner = torch.randint(1, 4 * 1024, size=(n,)).tolist() - -values = list(zip(dim1, dim4, inner)) -names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values] - -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") -@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) +@pytest.mark.parametrize( + ("dim1", "dim4", "inner"), + ( + pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}") + for (dim1, dim4, inner) + in zip( + get_test_dims(1, 4 * 1024, n=4), + get_test_dims(1, 4 * 1024, n=4), + get_test_dims(1, 4 * 1024, n=4), + ) + ) +) def test_integrated_igemmlt(dim1, dim4, inner): for i in range(k): A = torch.randn(dim1, inner, device="cuda").half() @@ -1158,16 +1038,18 @@ def test_integrated_igemmlt(dim1, dim4, inner): assert err2 <= err1 * 1.025 -n = 6 -dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist() -dim4 = torch.randint(1, 4 * 1024, 
size=(n,)).tolist() -inner = torch.randint(1, 4 * 1024, size=(n,)).tolist() - -values = list(zip(dim1, dim4, inner)) -names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values] - - -@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) +@pytest.mark.parametrize( + ("dim1", "dim4", "inner"), + ( + pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}") + for (dim1, dim4, inner) + in zip( + get_test_dims(1, 4 * 1024, n=6), + get_test_dims(1, 4 * 1024, n=6), + get_test_dims(1, 4 * 1024, n=6), + ) + ) +) @pytest.mark.skip("Row scale has some bugs for ampere") def test_igemmlt_row_scale(dim1, dim4, inner): formatB = F.get_special_format_str() @@ -1234,17 +1116,17 @@ def test_igemmlt_row_scale(dim1, dim4, inner): print(sum(err3) / len(err3)) -dim1 = [1024, 2048] -inner = [12288 * 4, 4096 * 4] -dim4 = [12288, 4096] - -values = list(zip(dim1, dim4, inner)) -names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values] - - -@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) +@pytest.mark.parametrize( + ("dim1", "dim4", "inner"), + [ + pytest.param(1024, 12288 * 4, 12288, id="1024, 12288*4, 12288"), + pytest.param(2048, 4096 * 4, 4096, id="2048, 4096*4, 4096"), + ], +) @pytest.mark.skip("Row scale has some bugs for ampere") +@pytest.mark.benchmark def test_row_scale_bench(dim1, dim4, inner): + formatB = F.get_special_format_str() err1, err2, err3 = [], [], [] relerr1, relerr2 = [], [] scale = 1 @@ -1289,33 +1171,14 @@ def test_row_scale_bench(dim1, dim4, inner): print("vector-wise", time.time() - t0) -n = 2 -dim1 = torch.randint(2, 1024, size=(n,)).tolist() -dim2 = torch.randint(2, 1024, size=(n,)).tolist() -# dim1 = [8*1024] -# dim2 = [4*1024] - -dim3 = [0] -dtype = [torch.int8] -a_order = ["row"] -out_order = ["col32", "col_turing", "col_ampere"] -transpose = [False, True] -dims = [2] -values = list( - product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose) -) -names = [ - "dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_{}".format( - *vals - ) - for vals in values -] - -@pytest.mark.parametrize( - "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", - values, - ids=names, -) +@pytest.mark.parametrize("dim1", get_test_dims(2, 1024, n=2), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(2, 1024, n=2), ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", [0], ids=id_formatter("dim3")) +@pytest.mark.parametrize("dims", [2], ids=id_formatter("dims")) +@pytest.mark.parametrize("dtype", [torch.int8], ids=describe_dtype) +@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA")) +@pytest.mark.parametrize("orderOut", ["col32", "col_turing", "col_ampere"], ids=id_formatter("orderOut")) +@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose")) def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): for i in range(k): if dims == 2: @@ -1343,23 +1206,6 @@ def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): torch.testing.assert_close(out1, out2) -n = 2 -# dim1 = torch.randint(2,1024, size=(n,)).tolist() -# dim2 = torch.randint(2,1024, size=(n,)).tolist() -dim1 = [1] -dim2 = [33] - -dtype = [torch.int8] -# a_order = ['col_turing', 'col_ampere'] -a_order = ["col_turing"] -out_order = ["row"] -values = list(product(dim1, dim2, dtype, a_order, out_order)) -names = [ - "dim1_{}_dim2_{}_dtype_{}_orderA_{}_orderOut_{}".format(*vals) - for vals in values -] - -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this 
test is not supported on ROCm yet") def test_overflow(): formatB = F.get_special_format_str() print(formatB) @@ -1374,17 +1220,8 @@ def test_overflow(): c2 = torch.matmul(a.float(), b.float().t()) -n = 2 -dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist() -dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist() -# dim1 = [4] -# dim2 = [5] - -values = list(product(dim1, dim2)) -names = ["dim1_{}_dim2_{}".format(*vals) for vals in values] - - -@pytest.mark.parametrize("dim1, dim2", values, ids=names) +@pytest.mark.parametrize("dim1", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim2")) def test_coo_double_quant(dim1, dim2): threshold = 3.00 for i in range(k): @@ -1411,17 +1248,9 @@ def test_coo_double_quant(dim1, dim2): ) -n = 2 -dim1 = torch.randint(1, 1 * 1024, size=(n,)).tolist() -dim2 = torch.randint(1, 1 * 1024, size=(n,)).tolist() -# dim1 = [7] -# dim2 = [11] -transposed_B = [False, True] -values = list(product(dim1, dim2, transposed_B)) -names = ["dim1_{}_dim2_{}_transposed_B_{}".format(*vals) for vals in values] - -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") -@pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names) +@pytest.mark.parametrize("dim1", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim2")) +@pytest.mark.parametrize("transposed_B", TRUE_FALSE, ids=id_formatter("transposed_B")) def test_spmm_coo(dim1, dim2, transposed_B): threshold = 1.5 dim3 = torch.randint(32, 128, size=(1,)).item() @@ -1452,7 +1281,7 @@ def test_spmm_coo(dim1, dim2, transposed_B): assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30) -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") +@pytest.mark.benchmark def test_spmm_bench(): batch = 2 model = 1024 * 1 @@ -1496,14 +1325,8 @@ def test_spmm_bench(): print(tsp / t8) -n = 2 -dim1 = torch.randint(256, 1 * 1024, size=(n,)).tolist() -dim2 = torch.randint(256, 1 * 1024, size=(n,)).tolist() -values = list(product(dim1, dim2)) -names = ["dim1_{}_dim2_{}".format(*vals) for vals in values] - -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") -@pytest.mark.parametrize("dim1, dim2", values, ids=names) +@pytest.mark.parametrize("dim1", get_test_dims(256, 1024, n=2), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(256, 1024, n=2), ids=id_formatter("dim2")) def test_integrated_sparse_decomp(dim1, dim2): threshold = 3.0 formatB = "col_turing" @@ -1553,23 +1376,10 @@ def test_matmuls(): print(err1, err2) -n = 2 -# dim1 = torch.randint(1,1*1024, size=(n,)).tolist() -# dim2 = torch.randint(1,4*1024, size=(n,)).tolist() -dim1 = [1 * 2048] -dim2 = [12288] -# dim1 = [32] -# dim2 = [32] -# dtype = [torch.float16, torch.int8] -dtype = [torch.float16] -out_function = ["zeros", "ones"] -values = list(product(dim1, dim2, dtype, out_function)) -names = [ - "dim1_{}_dim2_{}_dtype_{}_out_func_{}".format(*vals) for vals in values -] - - -@pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names) +@pytest.mark.parametrize("dim1", [1 * 2048], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [12288], ids=id_formatter("dim2")) +@pytest.mark.parametrize("dtype", [torch.float16], ids=describe_dtype) +@pytest.mark.parametrize("out_func", ["zeros", "ones"], ids=id_formatter("out_func")) def 
test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func): out_func = getattr(torch, out_func) @@ -1672,21 +1482,9 @@ def test_coo2csc(): torch.testing.assert_close(A2.t()[idx], cscA.values) -n = 2 -# dim1 = torch.randint(1,1*1024, size=(n,)).tolist() -# dim2 = torch.randint(1,4*1024, size=(n,)).tolist() -dim1 = [1 * 2048] -# dim2 = [12288] -dim2 = [2048] -# dim1 = [2] -# dim2 = [2] -dtype = [torch.int8] -values = list(product(dim1, dim2, dtype)) -names = ["dim1_{}_dim2_{}_dtype_{}".format(*vals) for vals in values] - - -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") -@pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names) +@pytest.mark.parametrize("dim1", [1 * 2048]) +@pytest.mark.parametrize("dim2", [2048]) +@pytest.mark.parametrize("dtype", [torch.int8]) def test_spmm_coo_dequant(dim1, dim2, dtype): threshold = 6.0 # threshold = 2.8 @@ -1787,22 +1585,11 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): print("partial matmul", time.time() - t0) -batch_size = 1 -seqdim = 1 -values = [] -#values.append((batch_size, seqdim, 768, 4 * 768)) -#values.append((batch_size, seqdim, 1024, 4*1024)) -#values.append((batch_size, seqdim, 1536, 4*1536)) -#values.append((batch_size, seqdim, 2048, 4*2048)) -#values.append((batch_size, seqdim, 2560, 4*2560)) -#values.append((batch_size, seqdim, 4096, 4*4096)) -#values.append((batch_size, seqdim, 5120, 4*5120)) -values.append((batch_size, seqdim, 6656, 4*6656)) -#values.append((batch_size, seqdim, 8192, 4*8192)) -#values.append((batch_size, seqdim, 5140, 4*5140)) -#values.append((batch_size, seqdim, 12288, 4*12288)) -names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values] -@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names) +@pytest.mark.parametrize( + ("batch", "seq", "model", "hidden"), + [pytest.param(1, 1, 6656, 4*6656, id="batch=1, seq=1, model=6656, hidden=26k")], +) +@pytest.mark.benchmark def test_bench_matmul(batch, seq, model, hidden): iters = 1000 formatB = F.get_special_format_str() @@ -2227,6 +2014,7 @@ def test_kbit_quantile_estimation(): assert err < 0.035 +@pytest.mark.benchmark def test_bench_dequantization(): a = torch.rand(1024, 1024, device='cuda').half() code =F.create_fp8_map(True, 3, 0, 4).cuda() @@ -2245,8 +2033,7 @@ def test_bench_dequantization(): -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) def test_fp4_quant(dtype): vals = list(product([0, 1], repeat=4)) @@ -2270,7 +2057,8 @@ def test_fp4_quant(dtype): code[idx] = result A1 = torch.randn(1024, 1024, device='cuda', dtype=dtype) - qa, SA = F.quantize_fp4(A1, blocksize=64) + # Can we just assign blocksize to 128 and get it to run? 
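    # Assumption behind the switch to blocksize=128 below: the ROCm-adjusted tests in this
    # diff consistently drop the 64-element blocksize (test_4bit_compressed_stats now loops
    # over [256, 128], and test_gemv_4bit's old skip reason mentions warpsize 64), so a
    # 64-element block is presumably unsupported on HIP's 64-wide wavefronts.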
+ qa, SA = F.quantize_fp4(A1, blocksize=128) A2 = F.dequantize_fp4(qa, SA) err = (A1 - A2).abs().float() @@ -2279,14 +2067,13 @@ def test_fp4_quant(dtype): err = err.mean() assert A2.dtype == dtype - assert err.item() < 0.1 + assert err.item() < 0.11 assert relerr.item() < 0.28 @pytest.mark.parametrize("quant_type", ['fp4', 'nf4']) def test_4bit_compressed_stats(quant_type): - blocksizes = [128, 64] if not HIP_ENVIRONMENT else [128] - for blocksize in blocksizes: + for blocksize in [256, 128]: errs1 = [] errs2 = [] for i in range(10): @@ -2305,7 +2092,7 @@ def test_4bit_compressed_stats(quant_type): assert err.item() < 0.11 - assert relerr.item() < 0.28 + assert relerr.item() < 0.30 err = (A1 - A3).abs().float() relerr = (err/(A1.abs().float()+1e-15)).mean() @@ -2314,7 +2101,7 @@ def test_4bit_compressed_stats(quant_type): errs2.append(err.item()) assert err.item() < 0.11 - assert relerr.item() < 0.28 + assert relerr.item() < 0.30 #print(sum(errs1)/len(errs1), blocksize, quant_type) #print(sum(errs2)/len(errs2), blocksize, quant_type) @@ -2324,6 +2111,7 @@ def test_4bit_compressed_stats(quant_type): #@pytest.mark.parametrize("quant_type", ['fp4', 'nf4']) @pytest.mark.parametrize("quant_type", ['nf4']) +@pytest.mark.benchmark def test_bench_4bit_dequant(quant_type): blocksize = 256 a = torch.rand(1024*12*4, 1024*12, device='cuda').half() @@ -2370,12 +2158,12 @@ def test_normal_map_tree(): #print(pivots) -@pytest.mark.parametrize("double_quant", [True, False], ids=['DQ_True', 'DQ_False']) -@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4']) -@pytest.mark.parametrize("kind", ['fc1', 'fc2', 'attn', 'attn_packed'], ids=['fc1', 'fc2', 'attn', 'attn_packed']) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32']) -@pytest.mark.parametrize("quant_storage", [torch.uint8, torch.float16, torch.bfloat16, torch.float32], ids=['uint8', 'fp16', 'bf16', 'fp32']) -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="gemv 4bit tests are partially enabled on MI300, others being fixed for warpsize 64") +@pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}") +@pytest.mark.parametrize("storage_type", ['nf4', 'fp4']) +@pytest.mark.parametrize("kind", ['fc1', 'fc2', 'attn', 'attn_packed']) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) +@pytest.mark.parametrize("quant_storage", [torch.uint8, torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) +#@pytest.mark.skipif(HIP_ENVIRONMENT, reason="gemv 4bit tests are partially enabled on MI300, others being fixed for warpsize 64") def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind): for dim in [128, 256, 512, 1024]: #for dim in [4*1024]: @@ -2541,12 +2329,12 @@ def test_managed(): @pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4']) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32']) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("double_quant", [False], ids=['DQ_True']) def test_gemv_eye_4bit(storage_type, dtype, double_quant): dims = 10 torch.random.manual_seed(np.random.randint(0, 412424242)) - dims = torch.randint(0, 8192, size=(dims,)).tolist() + dims = get_test_dims(0, 8192, n=dims) dims = [dim + (64-(dim % 64)) for dim in dims] #for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]: for 
dim in dims: @@ -2564,5 +2352,3 @@ def test_gemv_eye_4bit(storage_type, dtype, double_quant): torch.testing.assert_close(A, C2) #torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001) #torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080) - - diff --git a/tests/test_generation.py b/tests/test_generation.py index 54ec10475..576be5a21 100644 --- a/tests/test_generation.py +++ b/tests/test_generation.py @@ -1,21 +1,14 @@ -import pytest -import torch -import math - from itertools import product +import math +import pytest +import torch import transformers from transformers import ( - AutoConfig, AutoModelForCausalLM, - AutoTokenizer, BitsAndBytesConfig, - GenerationConfig, - set_seed, - ) -import bitsandbytes as bnb from bitsandbytes.cextension import HIP_ENVIRONMENT @@ -134,4 +127,3 @@ def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ): raise ValueError(f'Failure count: {failure_count}/{n_cases}') - diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index 478255eee..9c33ab37c 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -1,31 +1,36 @@ +import copy import os -from contextlib import nullcontext -from itertools import product +import pickle from tempfile import TemporaryDirectory import pytest import torch import bitsandbytes as bnb +from tests.helpers import TRUE_FALSE, torch_load_from_buffer, torch_save_to_buffer storage = { - 'uint8': torch.uint8, - 'float16': torch.float16, - 'bfloat16': torch.bfloat16, - 'float32': torch.float32 + "uint8": torch.uint8, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float32": torch.float32, } -@pytest.mark.parametrize( - "quant_type, compress_statistics, bias, quant_storage", - list(product(["nf4", "fp4"], [False, True], [False, True], ['uint8', 'float16', 'bfloat16', 'float32'])), -) -def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage): + +@pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"]) +@pytest.mark.parametrize("bias", TRUE_FALSE) +@pytest.mark.parametrize("compress_statistics", TRUE_FALSE) +@pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) +@pytest.mark.parametrize("save_before_forward", TRUE_FALSE) +def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage, save_before_forward): original_dtype = torch.float16 compute_dtype = None device = "cuda" layer_shape = (300, 400) - linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu") # original layer + linear = torch.nn.Linear( + *layer_shape, dtype=original_dtype, device="cpu" + ) # original layer # Quantizing original layer linear_q = bnb.nn.Linear4bit( @@ -37,7 +42,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora quant_type=quant_type, device="meta", ) - new_weight = bnb.nn.Params4bit(data=linear.weight, quant_type=quant_type, requires_grad=False) + new_weight = bnb.nn.Params4bit( + data=linear.weight, quant_type=quant_type, requires_grad=False + ) linear_q.weight = new_weight if bias: linear_q.bias = torch.nn.Parameter(linear.bias) @@ -81,7 +88,12 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora quant_storage=storage[quant_storage], device="meta", ) - linear_qs.weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False, quant_type=quant_type, quant_storage=storage[quant_storage]) + linear_qs.weight = bnb.nn.Params4bit( + data=linear.weight, + requires_grad=False, + quant_type=quant_type, + quant_storage=storage[quant_storage], + ) 
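        # quant_storage (exercised via the uint8/float16/bfloat16/float32 map above) only
        # selects the dtype of the tensor holding the packed 4-bit payload; the quantized
        # values themselves are unchanged, which is why the forward outputs compared below
        # are expected to match exactly.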
if bias: linear_qs.bias = torch.nn.Parameter(linear.bias) linear_qs = linear_qs.to(device) @@ -92,7 +104,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora q0 = a.quant_state q1 = b.quant_state - for attr in ('code', 'dtype', 'blocksize', 'absmax'): + for attr in ("code", "dtype", "blocksize", "absmax"): c, d = getattr(q0, attr), getattr(q1, attr) if isinstance(c, torch.Tensor): assert torch.equal(c, d) @@ -100,7 +112,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora assert c == d, f"{c} != {d}" if q0.state2 is not None: - for attr in ('code', 'dtype', 'blocksize', 'absmax'): + for attr in ("code", "dtype", "blocksize", "absmax"): c, d = getattr(q0.state2, attr), getattr(q1.state2, attr) if isinstance(c, torch.Tensor): assert torch.equal(c, d) @@ -113,6 +125,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora assert a.dtype == b.dtype assert torch.equal(a, b) + if save_before_forward: + bytes_4bit = torch_save_to_buffer(linear_q) + # Forward test x = torch.rand(42, layer_shape[0], device=device) a = linear_q(x) @@ -125,14 +140,23 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora assert torch.equal(a, b) assert torch.equal(a, c) + if not save_before_forward: + bytes_4bit = torch_save_to_buffer(linear_q) + linear_q3 = torch_load_from_buffer(bytes_4bit) + # Test moving to CPU and back to GPU - linear_q2.to('cpu') + linear_q2.to("cpu") linear_q2.to(device) d = linear_qs(x) assert c.dtype == d.dtype assert c.device == d.device assert torch.equal(c, d) + d = linear_q3(x) + assert c.dtype == d.dtype + assert c.device == d.device + assert torch.equal(c, d) + # Saved size ratio test. Target set for layer_shape == (300, 400) w/ bias with TemporaryDirectory() as tmpdir: state_path_4bit = os.path.join(tmpdir, "state_4bit.pth") @@ -140,10 +164,47 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora torch.save(linear.state_dict(), state_path) torch.save(linear_q.state_dict(), state_path_4bit) - size_orig, size_4 = os.path.getsize(state_path), os.path.getsize( - state_path_4bit + size_orig, size_4 = ( + os.path.getsize(state_path), + os.path.getsize(state_path_4bit), ) size_ratio = size_4 / size_orig - target_compression = 0.143 if original_dtype == torch.float32 else 0.29 # these numbers get lower as weight shape increases + target_compression = ( + 0.143 if original_dtype == torch.float32 else 0.29 + ) # these numbers get lower as weight shape increases ratio_error_msg = f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}" assert size_ratio < target_compression, ratio_error_msg + + +def test_copy_param(): + tensor = torch.tensor([1.0, 2.0, 3.0, 4.0]) + param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0) + + shallow_copy_param = copy.copy(param) + assert param.quant_state is shallow_copy_param.quant_state + assert param.data.data_ptr() == shallow_copy_param.data.data_ptr() + + +def test_deepcopy_param(): + tensor = torch.tensor([1.0, 2.0, 3.0, 4.0]) + param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0) + copy_param = copy.deepcopy(param) + assert param.quant_state is not copy_param.quant_state + assert param.data.data_ptr() != copy_param.data.data_ptr() + + +def test_params4bit_real_serialization(): + original_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32) + original_param = bnb.nn.Params4bit(data=original_tensor, quant_type="fp4") + 
+    original_param.cuda(0)  # move to CUDA to trigger quantization
+
+    serialized_param = pickle.dumps(original_param)
+    deserialized_param = pickle.loads(serialized_param)
+
+    assert torch.equal(original_param.data, deserialized_param.data)
+    assert original_param.requires_grad == deserialized_param.requires_grad == False
+    assert original_param.quant_type == deserialized_param.quant_type
+    assert original_param.blocksize == deserialized_param.blocksize
+    assert original_param.compress_statistics == deserialized_param.compress_statistics
+    assert original_param.quant_state == deserialized_param.quant_state
\ No newline at end of file
diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py
index 6d5fc6a82..edc3409cd 100644
--- a/tests/test_linear8bitlt.py
+++ b/tests/test_linear8bitlt.py
@@ -1,6 +1,5 @@
-import os
 from contextlib import nullcontext
-from itertools import product
+import os
 from tempfile import TemporaryDirectory

 import pytest
@@ -10,12 +9,16 @@
 from bitsandbytes import functional as F
 from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout
 from bitsandbytes.nn.modules import Linear8bitLt
-from bitsandbytes.cextension import HIP_ENVIRONMENT
+from tests.helpers import (
+    TRUE_FALSE,
+    id_formatter,
+    torch_load_from_buffer,
+    torch_save_to_buffer,
+)


 # contributed by Alex Borzunov, see:
 # https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
-@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
 @pytest.mark.skipif(
     not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5),
     reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs",
@@ -68,10 +71,13 @@ def test_linear_no_igemmlt():
     assert linear_custom.state.CxB is None


-@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
-@pytest.mark.parametrize("has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt",
-                         list(product([False, True], [False, True], [False, True], [False, True])))
-def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt):
+@pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
+@pytest.mark.parametrize("serialize_before_forward", TRUE_FALSE, ids=id_formatter("serialize_before_forward"))
+@pytest.mark.parametrize("deserialize_before_cuda", TRUE_FALSE, ids=id_formatter("deserialize_before_cuda"))
+@pytest.mark.parametrize("force_no_igemmlt", TRUE_FALSE, ids=id_formatter("force_no_igemmlt"))
+@pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
+@pytest.mark.parametrize("load_before_cuda", TRUE_FALSE, ids=id_formatter("load_before_cuda"))
+def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt, save_before_forward, load_before_cuda):
     linear = torch.nn.Linear(32, 96)
     x = torch.randn(3, 32, dtype=torch.half)

@@ -94,6 +100,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
     if serialize_before_forward:
         state_dict_8bit = linear_custom.state_dict()

+    if save_before_forward:
+        bytes_8bit = torch_save_to_buffer(linear_custom)
+
     x_first = x.clone().cuda().requires_grad_(True)
     fx_first = linear_custom(x_first).float()
     grad_proj = torch.randn_like(fx_first)
@@ -102,6 +111,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
     if not serialize_before_forward:
         state_dict_8bit = linear_custom.state_dict()

+    if not save_before_forward:
+        bytes_8bit = torch_save_to_buffer(linear_custom)
+
     with TemporaryDirectory() as tmpdir:
         state_path_8bit = os.path.join(tmpdir, "state_8bit.pth")
         state_path = os.path.join(tmpdir, "state.pth")
@@ -128,16 +140,28 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
        with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError):
            new_linear_custom.load_state_dict(new_state_dict, strict=True)

+    if load_before_cuda:
+        new_linear_custom2 = torch_load_from_buffer(bytes_8bit)
+
     new_linear_custom = new_linear_custom.cuda()
     if not deserialize_before_cuda:
         new_linear_custom.load_state_dict(new_state_dict, strict=True)

+    if not load_before_cuda:
+        new_linear_custom2 = torch_load_from_buffer(bytes_8bit)
+
     x_second = x.clone().cuda().requires_grad_(True)
     fx_second = new_linear_custom(x_second).float()
     (fx_second * grad_proj).mean().backward()

+    x_third = x.clone().cuda().requires_grad_(True)
+    fx_third = new_linear_custom2(x_third).float()
+    (fx_third * grad_proj).mean().backward()
+
     # if 8-bit weights were loaded before .cuda, state is incorrect anyway and RuntimeError was raised
     if has_fp16_weights or not deserialize_before_cuda:
         assert torch.allclose(fx_first, fx_second, atol=1e-5)
         assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)
+        assert torch.allclose(fx_first, fx_third, atol=1e-5)
+        assert torch.allclose(x_first.grad, x_third.grad, atol=1e-5)

diff --git a/tests/test_modules.py b/tests/test_modules.py
index 3e28a0f21..674620e29 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -1,11 +1,13 @@
-from itertools import product
+import math

+import einops
 import pytest
 import torch
 from torch import nn

 import bitsandbytes as bnb
-from bitsandbytes.cextension import HIP_ENVIRONMENT
+from tests.helpers import id_formatter
+

 class MockArgs:
     def __init__(self, initial_data):
@@ -40,11 +42,11 @@ def get_args():


 def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
-    idx = torch.isclose(a, b, rtol, atol)
+    idx = torch.isclose(a, b, rtol=rtol, atol=atol)
     sumval = (idx == 0).sum().item()
     if sumval > count:
         print(f"Too many values not close: assert {sumval} < {count}")
-        torch.testing.assert_close(a, b, rtol, atol)
+        torch.testing.assert_close(a, b, rtol=rtol, atol=atol)


 class LinearFunction(torch.autograd.Function):
@@ -310,13 +312,7 @@ def forward(self, x):
        return LinearFunction.apply(x, self.weight, self.bias, self.args)


-threshold = [0.0, 3.0]
-values = threshold
-names = [f"threshold_{vals}" for vals in values]
-
-
-@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
-@pytest.mark.parametrize("threshold", values, ids=names)
+@pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("threshold"))
 def test_linear8bitlt_inference(threshold):
     l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half()
     assert l1.weight.device.type == "cuda"
@@ -330,7 +326,6 @@ def test_linear8bitlt_inference(threshold):
     assert l1.state.CxB is not None


-@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
 def test_linear8bitlt_accumulated_gradient():
     l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)])
     l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
@@ -488,7 +483,14 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
     assert (idx == 0).sum().item() <= b1.numel() * 0.005


-@pytest.mark.parametrize("module", [lambda nin, nout, bias=True: bnb.nn.Linear8bitLt(nin, nout, bias=bias, has_fp16_weights=False), bnb.nn.LinearFP4], ids=['Int8Lt', 'FP4'])
+@pytest.mark.parametrize(
+    "module",
+    [
+        lambda n_in, n_out, bias=True: bnb.nn.Linear8bitLt(n_in, n_out, bias=bias, has_fp16_weights=False),
+        bnb.nn.LinearFP4,
+    ],
+    ids=['Int8Lt', 'FP4'],
+)
 def test_linear_kbit_fp32_bias(module):
     # casts model to fp16 -> int8 automatically
     l1 = module(32, 64).cuda()
@@ -511,19 +513,21 @@ def test_linear_kbit_fp32_bias(module):
     o1 = l1(b1)
     assert l1.bias is None

-modules = []
-modules.append(bnb.nn.Linear8bitLt)
-modules.append(bnb.nn.Linear4bit)
-modules.append(bnb.nn.LinearFP4)
-modules.append(bnb.nn.LinearNF4)
-modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True))
-modules.append(lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True))
-modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float32))
-modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float16))
-modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.bfloat16))
-names = ['Int8Lt', '4bit', 'FP4', 'NF4', 'FP4+C', 'NF4+C', 'NF4+fp32', 'NF4+fp16', 'NF4+bf16']
-@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
-@pytest.mark.parametrize("module", modules, ids=names)
+
+module_dict = {
+    "Int8Lt": bnb.nn.Linear8bitLt,
+    "4bit": bnb.nn.Linear4bit,
+    "FP4": bnb.nn.LinearFP4,
+    "NF4": bnb.nn.LinearNF4,
+    "FP4+C": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True),
+    "NF4+C": lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True),
+    "NF4+fp32": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float32),
+    "NF4+fp16": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float16),
+    "NF4+bf16": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.bfloat16),
+}
+
+
+@pytest.mark.parametrize("module", module_dict.values(), ids=module_dict.keys())
 def test_kbit_backprop(module):
     b = 17
     dim1 = 37
@@ -640,6 +644,3 @@ def test_4bit_warnings():
        net(inp)

     assert len(record) == 2
-
-
-
diff --git a/tests/test_optim.py b/tests/test_optim.py
index c373a4f14..9395b8820 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -1,19 +1,16 @@
-import ctypes
 import os
+from os.path import join
 import shutil
 import time
 import uuid
-from itertools import product
-from os.path import join

-import pytest
 from lion_pytorch import Lion
-
+import pytest
 import torch

 import bitsandbytes as bnb
 import bitsandbytes.functional as F
-from bitsandbytes.cextension import HIP_ENVIRONMENT
+from tests.helpers import describe_dtype, id_formatter

 # import apex
@@ -28,7 +25,7 @@ def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0):


 def get_temp_dir():
-    path = f"/tmp/autoswap/{str(uuid.uuid4())}"
+    path = f"/tmp/autoswap/{uuid.uuid4()}"
     os.makedirs(path, exist_ok=True)
     return path

@@ -104,15 +101,16 @@ def rm_path(path):
 str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
 str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]

-dim1 = [1024]
-dim2 = [32, 1024, 4097, 1]
-gtype = [torch.float32, torch.float16, torch.bfloat16]
-optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion', 'paged_lion']
-values = list(product(dim1, dim2, gtype, optimizer_names))
-names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values]
-@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) +optimizer_names_32bit = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion', 'paged_lion'] + + +@pytest.mark.parametrize("optim_name", optimizer_names_32bit, ids=id_formatter("opt")) +@pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) +@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2")) def test_optimizer32bit(dim1, dim2, gtype, optim_name): - if gtype == torch.bfloat16 and optim_name in ['momentum', 'rmsprop']: pytest.skip() + if gtype == torch.bfloat16 and optim_name in ['momentum', 'rmsprop']: + pytest.skip() if dim1 == 1 and dim2 == 1: return p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 @@ -137,7 +135,6 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): bnb_optimizer.step() torch_optimizer.step() - for name1, name2 in str2statenames[optim_name]: torch.testing.assert_close( torch_optimizer.state[p1][name1], @@ -148,7 +145,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): # since Lion can have pretty noisy updates where things lie at the boundary # allow up to 10 errors for Lion - assert_most_approx_close(p1, p2.float(), atol, rtol, max_error_count=10) + assert_most_approx_close(p1, p2.float(), atol=atol, rtol=rtol, max_error_count=10) if i % (k // 5) == 0 and i > 0: path = get_temp_dir() @@ -160,7 +157,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): rm_path(path) # since Lion can have pretty noisy updates where things lie at the boundary # allow up to 10 errors for Lion - assert_most_approx_close(p1, p2.float(), atol, rtol, max_error_count=10) + assert_most_approx_close(p1, p2.float(), atol=atol, rtol=rtol, max_error_count=10) for name1, name2 in str2statenames[optim_name]: # since Lion can have pretty noisy updates where things lie at the boundary # allow up to 10 errors for Lion @@ -180,14 +177,9 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0 -dim1 = [1024] -dim2 = [32, 1024, 4097] -gtype = [torch.float32, torch.float16] -values = list(product(dim1, dim2, gtype)) -names = ["dim1_{}_dim2_{}_gtype_{}".format(*vals) for vals in values] - - -@pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names) +@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2")) +@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype) def test_global_config(dim1, dim2, gtype): if dim1 == 1 and dim2 == 1: return @@ -233,10 +225,7 @@ def test_global_config(dim1, dim2, gtype): assert adam2.state[p3]["state2"].dtype == torch.uint8 -dim1 = [1024] -dim2 = [32, 1024, 4097] -gtype = [torch.float32, torch.float16, torch.bfloat16] -optimizer_names = [ +optimizer_names_8bit = [ "adam8bit", "lion8bit", "momentum8bit", @@ -246,13 +235,12 @@ def test_global_config(dim1, dim2, gtype): "momentum8bit_blockwise", "rmsprop8bit_blockwise", ] -values = list(product(dim1, dim2, gtype, optimizer_names)) -names = [ - "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values -] -@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) +@pytest.mark.parametrize("optim_name", optimizer_names_8bit, ids=id_formatter("opt")) +@pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) 
+@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) def test_optimizer8bit(dim1, dim2, gtype, optim_name): if gtype == torch.bfloat16 and optim_name not in ['adam8bit_blockwise', 'lion8bit_blockwise']: pytest.skip() if dim1 == 1 and dim2 == 1: @@ -378,18 +366,10 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): # print(sum(relerrors)/len(relerrors)) -dim1 = [1024] -dim2 = [32, 1024, 4097] -gtype = [torch.float32] -optim_bits = [32, 8] -values = list(product(dim1, dim2, gtype, optim_bits)) -names = [ - "dim1_{}_dim2_{}_gtype_{}_optim_bits_{}".format(*vals) - for vals in values -] - - -@pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names) +@pytest.mark.parametrize("optim_bits", [32, 8], ids=id_formatter("optim_bits")) +@pytest.mark.parametrize("gtype", [torch.float32], ids=describe_dtype) +@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): if dim1 == 1 and dim2 == 1: return @@ -477,22 +457,19 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): adam2.load_state_dict(torch.load(join(path, "opt.pt"))) -dim1 = [4096] -dim2 = [4096] -gtype = [torch.float32, torch.float16] -# optimizer_names = ['adam8bit_blockwise', 'adam8bit', 'lamb8bit'] -# optimizer_names = ['adam8bit_blockwise', 'adam_apex', 'adam8bit', 'adam', 'adam_pytorch'] -# optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch'] -# optimizer_names = ['lamb_apex', 'lamb8bit'] -# optimizer_names = ['lars_apex', 'lars8bit'] -optimizer_names = ["adam8bit_blockwise", 'paged_adam8bit_blockwise', 'paged_adamw8bit_blockwise', 'paged_lion8bit_blockwise'] -values = list(product(dim1, dim2, gtype, optimizer_names)) -names = [ - "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values +optimizer_names_benchmark = [ + "adam8bit_blockwise", + "paged_adam8bit_blockwise", + "paged_adamw8bit_blockwise", + "paged_lion8bit_blockwise", ] -@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) +@pytest.mark.parametrize("dim1", [4096], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [4096], ids=id_formatter("dim2")) +@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype) +@pytest.mark.parametrize("optim_name", optimizer_names_benchmark, ids=id_formatter("opt")) +@pytest.mark.benchmark def test_benchmark_blockwise(dim1, dim2, gtype, optim_name): if dim1 == 1 and dim2 == 1: return @@ -517,15 +494,12 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name): print(optim_name, gtype, s / params) # assert s < 3.9 -dim1 = [2*1024] -gtype = [torch.float16] -#mode = ['torch', 'bnb'] -mode = ['bnb'] -optimizer_names = ['paged_adamw'] -#optimizer_names = ['paged_adamw8bit_blockwise'] -values = list(product(dim1,gtype, optimizer_names, mode)) -names = ['dim1_{0}_gtype_{1}_optim_{2}_mode_{3}'.format(*vals) for vals in values] -@pytest.mark.parametrize("dim1, gtype, optim_name, mode", values, ids=names) + +@pytest.mark.parametrize("dim1", [2 * 1024], ids=id_formatter("dim1")) +@pytest.mark.parametrize("gtype", [torch.float16], ids=describe_dtype) +@pytest.mark.parametrize("optim_name", ['paged_adamw'], ids=id_formatter("optim_name")) +@pytest.mark.parametrize("mode", ['bnb'], ids=id_formatter("mode")) +@pytest.mark.benchmark def test_stream_optimizer_bench(dim1, gtype, 
     layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)]))
     layers1 = layers1.to(gtype)
diff --git a/tests/test_triton.py b/tests/test_triton.py
index 8890193fc..218a533d5 100644
--- a/tests/test_triton.py
+++ b/tests/test_triton.py
@@ -1,15 +1,15 @@
 import pytest
 import torch

-from bitsandbytes.triton.triton_utils import is_triton_available
-from bitsandbytes.nn.triton_based_modules import SwitchBackLinear
 from bitsandbytes.nn import Linear8bitLt
-from bitsandbytes.cextension import HIP_ENVIRONMENT
+from bitsandbytes.nn.triton_based_modules import SwitchBackLinear
+from bitsandbytes.triton.triton_utils import is_triton_available
+from tests.helpers import TRUE_FALSE
+


-@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
 @pytest.mark.skipif(not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8, reason="This test requires triton and a GPU with compute capability 8.0 or higher.")
-@pytest.mark.parametrize("vector_wise_quantization", [False, True])
+@pytest.mark.parametrize("vector_wise_quantization", TRUE_FALSE)
 def test_switchback(vector_wise_quantization):
     for dim in [83]:
         for batch in [13]:
@@ -58,4 +58,3 @@ def test_switchback(vector_wise_quantization):
            print('GX1', err_sb, err_baseline)

            assert err_sb < 2 * err_baseline
-