diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index 30cad3e34..abc1fd223 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -8,7 +8,7 @@ from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr from ..._ops import register_kernel -from ...cextension import HIP_ENVIRONMENT, lib +from ...cextension import ROCM_WARP_SIZE_64, lib @register_kernel("bitsandbytes::int8_linear_matmul", "cuda") @@ -211,7 +211,7 @@ def _get_col_absmax( def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: torch._check_is_size(blocksize) - if HIP_ENVIRONMENT: + if ROCM_WARP_SIZE_64: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) else: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) @@ -269,7 +269,7 @@ def _( def _dequantize_blockwise_impl( A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor ) -> None: - if HIP_ENVIRONMENT: + if ROCM_WARP_SIZE_64: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) else: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) @@ -303,7 +303,7 @@ def _dequantize_blockwise_impl( def _( A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype ) -> tuple[torch.Tensor, torch.Tensor]: - if HIP_ENVIRONMENT: + if ROCM_WARP_SIZE_64: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) else: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) @@ -385,7 +385,7 @@ def _dequantize_4bit_impl( dtype: torch.dtype, out: torch.Tensor, ) -> None: - if HIP_ENVIRONMENT: + if ROCM_WARP_SIZE_64: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) else: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 899a83314..42fdd3164 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -9,7 +9,13 @@ import torch from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR -from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs, get_cuda_version_tuple, get_rocm_gpu_arch +from bitsandbytes.cuda_specs import ( + CUDASpecs, + get_cuda_specs, + get_cuda_version_tuple, + get_rocm_gpu_arch, + get_rocm_warpsize, +) logger = logging.getLogger(__name__) @@ -298,6 +304,7 @@ def get_native_library() -> BNBNativeLibrary: ROCM_GPU_ARCH = get_rocm_gpu_arch() +ROCM_WARP_SIZE_64 = True if get_rocm_warpsize() == 64 else False try: # to support Intel CPU/GPU (XPU) backend diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py index 32563a159..71e7568a9 100644 --- a/bitsandbytes/cuda_specs.py +++ b/bitsandbytes/cuda_specs.py @@ -100,3 +100,29 @@ def get_rocm_gpu_arch() -> str: """, ) return "unknown" + + +def get_rocm_warpsize() -> int: + """Get ROCm warp size.""" + logger = logging.getLogger(__name__) + try: + if torch.version.hip: + result = subprocess.run(["rocminfo"], capture_output=True, text=True) + match = re.search(r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)", result.stdout) + if match: + return int(match.group(1)) + else: + # default to 64 to be safe + return 64 + else: + # nvidia cards always use 32 warp size + return 32 + except Exception as e: + logger.error(f"Could not detect ROCm warp size: {e}. Defaulting to 64. (some 4-bit functions may not work!)") + if torch.cuda.is_available(): + logger.warning( + """ +ROCm warp size detection failed despite ROCm being available. + """, + ) + return 64 diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index c9f5ece60..f670dfe7c 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -15,7 +15,7 @@ from bitsandbytes.utils import _reverse_4bit_compress_format, pack_dict_to_tensor, unpack_tensor_to_dict -from .cextension import HIP_ENVIRONMENT, ipex_cpu, ipex_xpu, lib +from .cextension import ROCM_WARP_SIZE_64, ipex_cpu, ipex_xpu, lib name2qmap = {} @@ -804,7 +804,7 @@ def quantize_fp4( quant_storage=torch.uint8, ): if blocksize is None: - blocksize = 64 if not HIP_ENVIRONMENT else 128 + blocksize = 64 if not ROCM_WARP_SIZE_64 else 128 return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage) @@ -817,7 +817,7 @@ def quantize_nf4( quant_storage=torch.uint8, ): if blocksize is None: - blocksize = 64 if not HIP_ENVIRONMENT else 128 + blocksize = 64 if not ROCM_WARP_SIZE_64 else 128 return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage) @@ -855,7 +855,7 @@ def quantize_4bit( """ if blocksize is None: - blocksize = 64 if not HIP_ENVIRONMENT else 128 + blocksize = 64 if not ROCM_WARP_SIZE_64 else 128 input_shape = A.shape @@ -910,7 +910,7 @@ def dequantize_fp4( blocksize: Optional[int] = None, ) -> torch.Tensor: if blocksize is None: - blocksize = 64 if not HIP_ENVIRONMENT else 128 + blocksize = 64 if not ROCM_WARP_SIZE_64 else 128 return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4") @@ -922,7 +922,7 @@ def dequantize_nf4( blocksize: Optional[int] = None, ) -> torch.Tensor: if blocksize is None: - blocksize = 64 if not HIP_ENVIRONMENT else 128 + blocksize = 64 if not ROCM_WARP_SIZE_64 else 128 return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4") @@ -962,7 +962,7 @@ def dequantize_4bit( """ if blocksize is None: - blocksize = 64 if not HIP_ENVIRONMENT else 128 + blocksize = 64 if not ROCM_WARP_SIZE_64 else 128 if quant_state is None: assert absmax is not None and out is not None diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 1cef1f5e9..a43194e6a 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -11,7 +11,7 @@ import torch.nn.functional as F import bitsandbytes as bnb -from bitsandbytes.cextension import HIP_ENVIRONMENT +from bitsandbytes.cextension import ROCM_WARP_SIZE_64 from bitsandbytes.functional import QuantState, _enable_ipex_fusion, ipex_cpu, ipex_xpu from bitsandbytes.optim import GlobalOptimManager from bitsandbytes.utils import ( @@ -225,7 +225,7 @@ def __new__( data = torch.empty(0) if blocksize is None: - blocksize = 64 if not HIP_ENVIRONMENT else 128 + blocksize = 64 if not ROCM_WARP_SIZE_64 else 128 self = torch.Tensor._make_subclass(cls, data, requires_grad) self.blocksize = blocksize diff --git a/csrc/kernels.hip b/csrc/kernels.hip index ec3f7f025..6956ebac4 100644 --- a/csrc/kernels.hip +++ b/csrc/kernels.hip @@ -3044,7 +3044,9 @@ MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit) MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit) MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit) MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit) -//MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit) +#if WARP_SIZE == 32 + MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit) +#endif MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4) MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4) @@ -3052,7 +3054,9 @@ MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4) MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4) MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4) MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4) -//MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4) +#if WARP_SIZE == 32 + MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4) +#endif MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4) MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4) @@ -3060,7 +3064,9 @@ MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4) MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4) MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4) MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4) -//MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4) +#if WARP_SIZE == 32 + MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4) +#endif MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit) MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit) @@ -3069,7 +3075,9 @@ MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit) MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit) MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit) MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit) -//MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit) +#if WARP_SIZE == 32 + MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit) +#endif MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4) MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4) @@ -3077,7 +3085,9 @@ MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4) MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4) MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4) MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4) -//MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4) +#if WARP_SIZE == 32 + MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4) +#endif MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4) MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4) @@ -3085,7 +3095,9 @@ MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4) MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4) MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4) MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4) -//MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4) +#if WARP_SIZE == 32 + MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4) +#endif MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, General8bit) MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 1, General8bit) @@ -3094,7 +3106,9 @@ MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, General8bit) MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, General8bit) MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, General8bit) MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, General8bit) -//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, General8bit) +#if WARP_SIZE == 32 + MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, General8bit) +#endif MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, FP4) MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, FP4) @@ -3102,7 +3116,9 @@ MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, FP4) MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, FP4) MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, FP4) MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, FP4) -//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, FP4) +#if WARP_SIZE == 32 + MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, FP4) +#endif MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, NF4) MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, NF4) @@ -3110,7 +3126,9 @@ MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, NF4) MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, NF4) MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, NF4) MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, NF4) -//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, NF4) +#if WARP_SIZE == 32 + MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, NF4) +#endif template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); diff --git a/csrc/ops.hip b/csrc/ops.hip index 260b74b30..4f7ce9abb 100644 --- a/csrc/ops.hip +++ b/csrc/ops.hip @@ -5,6 +5,7 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. +#include #include #include #include @@ -57,8 +58,8 @@ template void quantizeBlockwise(floa hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(128), 0, 0, code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 128) hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(64), 0, 0, code, A, absmax, out, rand, rand_offset, n); - //else if(blocksize == 64) - // hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(32), 0, 0, code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 64 && BNB_WARP_SIZE == 32) + hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(32), 0, 0, code, A, absmax, out, rand, rand_offset, n); CUDA_CHECK_RETURN(hipPeekAtLastError()); diff --git a/tests/test_functional.py b/tests/test_functional.py index fc37cb4c3..e9bcbc267 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -9,7 +9,7 @@ import bitsandbytes as bnb from bitsandbytes import functional as F -from bitsandbytes.cextension import HIP_ENVIRONMENT, ROCM_GPU_ARCH +from bitsandbytes.cextension import HIP_ENVIRONMENT, ROCM_GPU_ARCH, ROCM_WARP_SIZE_64 from tests.helpers import ( BOOLEAN_TUPLES, TRUE_FALSE, @@ -95,7 +95,7 @@ class Test8BitBlockwiseQuantizeFunctional: @pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested")) @pytest.mark.parametrize( "blocksize", - [4096, 2048, 1024, 512, 256, 128, 64] if not HIP_ENVIRONMENT else [4096, 2048, 1024, 512, 256, 128], + [4096, 2048, 1024, 512, 256, 128, 64] if not ROCM_WARP_SIZE_64 else [4096, 2048, 1024, 512, 256, 128], ) @pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed")) def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed): @@ -1107,7 +1107,7 @@ class TestQuantize4BitFunctional: @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize( "blocksize", - [64, 128, 256, 512, 1024, 2048, 4096] if not HIP_ENVIRONMENT else [128, 256, 512, 1024, 2048, 4096], + [64, 128, 256, 512, 1024, 2048, 4096] if not ROCM_WARP_SIZE_64 else [128, 256, 512, 1024, 2048, 4096], ) def test_4bit_quant(self, device, dtype, quant_type, blocksize): if device == "hpu" and not is_supported_on_hpu(quant_type, dtype): @@ -1174,7 +1174,7 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128], ids=id_formatter("blocksize")) + @pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128], ids=id_formatter("blocksize")) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype) def test_4bit_compressed_stats(self, device, quant_type, blocksize, dtype): if device == "hpu" and not is_supported_on_hpu(quant_type, dtype): @@ -1241,7 +1241,7 @@ def test_bench_4bit_dequant(self, quant_type): # print((time.time()-t0)/iters*1e6) @pytest.mark.skipif( - HIP_ENVIRONMENT, reason="gemv 4bit tests are partially enabled on MI300, others being fixed for warpsize 64" + ROCM_WARP_SIZE_64, reason="gemv 4bit tests are partially enabled on MI300, others being fixed for warpsize 64" ) @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}") diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index 1c5e77a32..398fb83d3 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -8,7 +8,7 @@ import torch import bitsandbytes as bnb -from bitsandbytes.cextension import HIP_ENVIRONMENT +from bitsandbytes.cextension import ROCM_WARP_SIZE_64 from tests.helpers import ( TRUE_FALSE, describe_dtype, @@ -192,7 +192,7 @@ def test_linear_serialization( @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) -@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) +@pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_copy_param(device, quant_type, blocksize, compress_statistics): if device == "hpu" and not is_supported_on_hpu(quant_type): @@ -249,7 +249,7 @@ def test_params4bit_torch_chunk_split(device, quant_type): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) -@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) +@pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): if device == "hpu" and not is_supported_on_hpu(quant_type): @@ -278,7 +278,7 @@ def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) -@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) +@pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics): if device == "hpu" and not is_supported_on_hpu(quant_type): diff --git a/tests/test_ops.py b/tests/test_ops.py index 8aa0560fd..17961fed1 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -4,7 +4,7 @@ import torch import bitsandbytes -from bitsandbytes.cextension import HIP_ENVIRONMENT +from bitsandbytes.cextension import ROCM_WARP_SIZE_64 from bitsandbytes.functional import ipex_xpu from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, is_supported_on_hpu @@ -103,7 +103,7 @@ def test_int8_scaled_mm(self, device, dtype, has_bias): class TestInt8BlockwiseQuantOps: @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512]) def test_quantize_blockwise(self, device, dtype, blocksize): if device == "cpu": if dtype != torch.float32: @@ -127,7 +127,7 @@ def test_quantize_blockwise(self, device, dtype, blocksize): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512]) def test_dequantize_blockwise(self, device, dtype, blocksize): if device == "cpu" and dtype != torch.float32: pytest.skip("CPU implementation is only available for float32") @@ -157,7 +157,7 @@ class Test4bitBlockwiseQuantOps: @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512]) def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype): pytest.skip("This configuration is not supported on HPU.") @@ -181,7 +181,7 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512]) def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype): pytest.skip("This configuration is not supported on HPU.") @@ -215,7 +215,7 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512]) def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype): pytest.skip("This configuration is not supported on HPU.")