10 changes: 5 additions & 5 deletions bitsandbytes/backends/cuda/ops.py
@@ -8,7 +8,7 @@
from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr

from ..._ops import register_kernel
-from ...cextension import HIP_ENVIRONMENT, lib
+from ...cextension import ROCM_WARP_SIZE_64, lib


@register_kernel("bitsandbytes::int8_linear_matmul", "cuda")
@@ -211,7 +211,7 @@ def _get_col_absmax(
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)

-if HIP_ENVIRONMENT:
+if ROCM_WARP_SIZE_64:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
else:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
@@ -269,7 +269,7 @@ def _(
def _dequantize_blockwise_impl(
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
) -> None:
-if HIP_ENVIRONMENT:
+if ROCM_WARP_SIZE_64:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
else:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
@@ -303,7 +303,7 @@ def _dequantize_blockwise_impl(
def _(
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
-if HIP_ENVIRONMENT:
+if ROCM_WARP_SIZE_64:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
else:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
@@ -385,7 +385,7 @@ def _dequantize_4bit_impl(
dtype: torch.dtype,
out: torch.Tensor,
) -> None:
-if HIP_ENVIRONMENT:
+if ROCM_WARP_SIZE_64:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
else:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
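The pattern above repeats across the blockwise kernels: the 64-element blocksize is only accepted when the device does not use a 64-wide wavefront. A minimal sketch of that check, assuming ROCM_WARP_SIZE_64 is the boolean exported by bitsandbytes.cextension (the helper name below is hypothetical; the library inlines the check in each op):

    import torch

    def _check_blocksize(blocksize: int, rocm_warp_size_64: bool) -> None:
        # Blocksizes supported by the blockwise quantization kernels; 64 is only
        # compiled for 32-wide warps (CUDA, or ROCm with a 32-wide wavefront).
        supported = [4096, 2048, 1024, 512, 256, 128]
        if not rocm_warp_size_64:
            supported.append(64)
        torch._check(blocksize in supported)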
9 changes: 8 additions & 1 deletion bitsandbytes/cextension.py
@@ -9,7 +9,13 @@
import torch

from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR
-from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs, get_cuda_version_tuple, get_rocm_gpu_arch
+from bitsandbytes.cuda_specs import (
+    CUDASpecs,
+    get_cuda_specs,
+    get_cuda_version_tuple,
+    get_rocm_gpu_arch,
+    get_rocm_warpsize,
+)

logger = logging.getLogger(__name__)

@@ -298,6 +304,7 @@ def get_native_library() -> BNBNativeLibrary:


ROCM_GPU_ARCH = get_rocm_gpu_arch()
+ROCM_WARP_SIZE_64 = True if get_rocm_warpsize() == 64 else False

try:
# to support Intel CPU/GPU (XPU) backend
26 changes: 26 additions & 0 deletions bitsandbytes/cuda_specs.py
@@ -100,3 +100,29 @@ def get_rocm_gpu_arch() -> str:
""",
)
return "unknown"


+def get_rocm_warpsize() -> int:
+    """Get ROCm warp size."""
+    logger = logging.getLogger(__name__)
+    try:
+        if torch.version.hip:
+            result = subprocess.run(["rocminfo"], capture_output=True, text=True)
+            match = re.search(r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)", result.stdout)
+            if match:
+                return int(match.group(1))
+            else:
+                # default to 64 to be safe
+                return 64
+        else:
+            # nvidia cards always use 32 warp size
+            return 32
+    except Exception as e:
+        logger.error(f"Could not detect ROCm warp size: {e}. Defaulting to 64. (some 4-bit functions may not work!)")
+        if torch.cuda.is_available():
+            logger.warning(
+                """
+                ROCm warp size detection failed despite ROCm being available.
+                """,
+            )
+        return 64
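For context, a quick sanity check of the "Wavefront Size" parsing used by get_rocm_warpsize(); the sample string below is illustrative rather than captured from a real rocminfo run:

    import re

    sample = "  Wavefront Size:             64(0x40)"
    match = re.search(r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)", sample)
    print(int(match.group(1)) if match else 64)  # prints 64; falls back to 64 when nothing matches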
14 changes: 7 additions & 7 deletions bitsandbytes/functional.py
@@ -15,7 +15,7 @@

from bitsandbytes.utils import _reverse_4bit_compress_format, pack_dict_to_tensor, unpack_tensor_to_dict

-from .cextension import HIP_ENVIRONMENT, ipex_cpu, ipex_xpu, lib
+from .cextension import ROCM_WARP_SIZE_64, ipex_cpu, ipex_xpu, lib

name2qmap = {}

@@ -804,7 +804,7 @@ def quantize_fp4(
quant_storage=torch.uint8,
):
if blocksize is None:
-blocksize = 64 if not HIP_ENVIRONMENT else 128
+blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage)


@@ -817,7 +817,7 @@ def quantize_nf4(
quant_storage=torch.uint8,
):
if blocksize is None:
-blocksize = 64 if not HIP_ENVIRONMENT else 128
+blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage)


@@ -855,7 +855,7 @@ def quantize_4bit(
"""

if blocksize is None:
-blocksize = 64 if not HIP_ENVIRONMENT else 128
+blocksize = 64 if not ROCM_WARP_SIZE_64 else 128

input_shape = A.shape

@@ -910,7 +910,7 @@ def dequantize_fp4(
blocksize: Optional[int] = None,
) -> torch.Tensor:
if blocksize is None:
-blocksize = 64 if not HIP_ENVIRONMENT else 128
+blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4")


@@ -922,7 +922,7 @@ def dequantize_nf4(
blocksize: Optional[int] = None,
) -> torch.Tensor:
if blocksize is None:
-blocksize = 64 if not HIP_ENVIRONMENT else 128
+blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4")


@@ -962,7 +962,7 @@ def dequantize_4bit(
"""

if blocksize is None:
-blocksize = 64 if not HIP_ENVIRONMENT else 128
+blocksize = 64 if not ROCM_WARP_SIZE_64 else 128

if quant_state is None:
assert absmax is not None and out is not None
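The same default-blocksize rule appears in each of the quantize/dequantize entry points above. A minimal sketch of that rule (the helper name is hypothetical; the library inlines the expression at each call site):

    def default_4bit_blocksize(rocm_warp_size_64: bool) -> int:
        # 64 elements per block on 32-wide warps, 128 on ROCm devices with 64-wide wavefronts.
        return 128 if rocm_warp_size_64 else 64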
4 changes: 2 additions & 2 deletions bitsandbytes/nn/modules.py
@@ -11,7 +11,7 @@
import torch.nn.functional as F

import bitsandbytes as bnb
-from bitsandbytes.cextension import HIP_ENVIRONMENT
+from bitsandbytes.cextension import ROCM_WARP_SIZE_64
from bitsandbytes.functional import QuantState, _enable_ipex_fusion, ipex_cpu, ipex_xpu
from bitsandbytes.optim import GlobalOptimManager
from bitsandbytes.utils import (
@@ -225,7 +225,7 @@ def __new__(
data = torch.empty(0)

if blocksize is None:
-blocksize = 64 if not HIP_ENVIRONMENT else 128
+blocksize = 64 if not ROCM_WARP_SIZE_64 else 128

self = torch.Tensor._make_subclass(cls, data, requires_grad)
self.blocksize = blocksize
36 changes: 27 additions & 9 deletions csrc/kernels.hip
@@ -3044,23 +3044,29 @@ MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit)
-//MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit)
+#if WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit)
+#endif

MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4)
MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4)
MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4)
MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4)
MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4)
MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4)
-//MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4)
+#if WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4)
+#endif

MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4)
MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4)
MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4)
MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4)
MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4)
MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4)
-//MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4)
+#if WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4)
+#endif

MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit)
@@ -3069,23 +3075,29 @@ MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit)
-//MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit)
+#if WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit)
+#endif

MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4)
MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4)
MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4)
MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4)
MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4)
MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4)
-//MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4)
+#if WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4)
+#endif

MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4)
MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4)
MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4)
MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4)
MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4)
MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4)
-//MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4)
+#if WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4)
+#endif

MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 1, General8bit)
@@ -3094,23 +3106,29 @@ MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, General8bit)
-//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, General8bit)
+#if WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, General8bit)
+#endif

MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, FP4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, FP4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, FP4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, FP4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, FP4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, FP4)
-//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, FP4)
+#if WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, FP4)
+#endif

MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, NF4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, NF4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, NF4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, NF4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, NF4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, NF4)
-//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, NF4)
+#if WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, NF4)
+#endif

template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
4 changes: 2 additions & 2 deletions csrc/ops.hip
@@ -57,8 +57,8 @@ template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(floa
hipLaunchKernelGGL(( kQuantizeBlockwise<T, 256, 2, 0, DATA_TYPE>), dim3(num_blocks), dim3(128), 0, 0, code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 128)
hipLaunchKernelGGL(( kQuantizeBlockwise<T, 128, 2, 0, DATA_TYPE>), dim3(num_blocks), dim3(64), 0, 0, code, A, absmax, out, rand, rand_offset, n);
-//else if(blocksize == 64)
-//  hipLaunchKernelGGL(( kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE>), dim3(num_blocks), dim3(32), 0, 0, code, A, absmax, out, rand, rand_offset, n);
+else if(blocksize == 64 && warpSize == 32)
+hipLaunchKernelGGL(( kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE>), dim3(num_blocks), dim3(32), 0, 0, code, A, absmax, out, rand, rand_offset, n);

Reviewer comment (Contributor) on the added condition: warpSize will be deprecated in 7.0, we just added a WARP_SIZE macro, please use it instead.


CUDA_CHECK_RETURN(hipPeekAtLastError());
10 changes: 5 additions & 5 deletions tests/test_functional.py
@@ -9,7 +9,7 @@

import bitsandbytes as bnb
from bitsandbytes import functional as F
-from bitsandbytes.cextension import HIP_ENVIRONMENT, ROCM_GPU_ARCH
+from bitsandbytes.cextension import HIP_ENVIRONMENT, ROCM_GPU_ARCH, ROCM_WARP_SIZE_64
from tests.helpers import (
BOOLEAN_TUPLES,
TRUE_FALSE,
@@ -95,7 +95,7 @@ class Test8BitBlockwiseQuantizeFunctional:
@pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested"))
@pytest.mark.parametrize(
"blocksize",
-[4096, 2048, 1024, 512, 256, 128, 64] if not HIP_ENVIRONMENT else [4096, 2048, 1024, 512, 256, 128],
+[4096, 2048, 1024, 512, 256, 128, 64] if not ROCM_WARP_SIZE_64 else [4096, 2048, 1024, 512, 256, 128],
)
@pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed"))
def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed):
@@ -1107,7 +1107,7 @@ class TestQuantize4BitFunctional:
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize(
"blocksize",
-[64, 128, 256, 512, 1024, 2048, 4096] if not HIP_ENVIRONMENT else [128, 256, 512, 1024, 2048, 4096],
+[64, 128, 256, 512, 1024, 2048, 4096] if not ROCM_WARP_SIZE_64 else [128, 256, 512, 1024, 2048, 4096],
)
def test_4bit_quant(self, device, dtype, quant_type, blocksize):
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype):
@@ -1174,7 +1174,7 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):

@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128], ids=id_formatter("blocksize"))
@pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128], ids=id_formatter("blocksize"))
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype)
def test_4bit_compressed_stats(self, device, quant_type, blocksize, dtype):
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype):
@@ -1241,7 +1241,7 @@ def test_bench_4bit_dequant(self, quant_type):
# print((time.time()-t0)/iters*1e6)

@pytest.mark.skipif(
-HIP_ENVIRONMENT, reason="gemv 4bit tests are partially enabled on MI300, others being fixed for warpsize 64"
+ROCM_WARP_SIZE_64, reason="gemv 4bit tests are partially enabled on MI300, others being fixed for warpsize 64"
)
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}")
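A hedged sketch of the parametrization pattern used in these tests: the 64-element blocksize is only exercised when the platform is not a 64-wide-wavefront ROCm device. The test body is a placeholder, not one of the real test functions, and it assumes bitsandbytes is installed with its native library so the import succeeds:

    import pytest

    from bitsandbytes.cextension import ROCM_WARP_SIZE_64

    BLOCKSIZES = [4096, 2048, 1024, 512, 256, 128] + ([] if ROCM_WARP_SIZE_64 else [64])

    @pytest.mark.parametrize("blocksize", BLOCKSIZES)
    def test_blocksize_smoke(blocksize):
        # Placeholder assertion; the real tests quantize and dequantize at each blocksize.
        assert blocksize % 64 == 0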
8 changes: 4 additions & 4 deletions tests/test_linear4bit.py
@@ -8,7 +8,7 @@
import torch

import bitsandbytes as bnb
-from bitsandbytes.cextension import HIP_ENVIRONMENT
+from bitsandbytes.cextension import ROCM_WARP_SIZE_64
from tests.helpers import (
TRUE_FALSE,
describe_dtype,
@@ -192,7 +192,7 @@ def test_linear_serialization(

@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128])
@pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_copy_param(device, quant_type, blocksize, compress_statistics):
if device == "hpu" and not is_supported_on_hpu(quant_type):
@@ -249,7 +249,7 @@ def test_params4bit_torch_chunk_split(device, quant_type):

@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128])
@pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
if device == "hpu" and not is_supported_on_hpu(quant_type):
@@ -278,7 +278,7 @@ def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):

@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128])
@pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics):
if device == "hpu" and not is_supported_on_hpu(quant_type):