From a47a40e76e8f6ab4f374adcb75f57bfadafdd06f Mon Sep 17 00:00:00 2001 From: cyy Date: Mon, 18 Aug 2025 14:32:51 +0800 Subject: [PATCH 1/5] Fix QuantState.as_dict Signed-off-by: cyy --- bitsandbytes/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 7cca33dcf..c5ecba140 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -513,7 +513,7 @@ def as_dict(self, packed=False): "blocksize": self.blocksize, "quant_map": self.code, "dtype": str(self.dtype).strip("torch."), - "shape": tuple(self.shape), + "shape": tuple(self.shape) if self.shape is not None else None, } if self.nested: qs_dict.update( From 9b7b37305616d8cd7f46685e0960eb1056095121 Mon Sep 17 00:00:00 2001 From: cyy Date: Mon, 18 Aug 2025 14:36:27 +0800 Subject: [PATCH 2/5] Fix QuantState.from_dict Signed-off-by: cyy --- bitsandbytes/functional.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index c5ecba140..b05096a37 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -464,12 +464,13 @@ def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> "QuantState # unpacking tensor with non-tensor components qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)] - if not len(qs_key) and "quant_type" not in qs_dict: - raise ValueError("Expected packed or unpacked quant_state items, found neither") - elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys: - raise ValueError( - f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.", - ) + if "quant_type" not in qs_dict: + if not qs_key: + raise ValueError("Expected packed or unpacked quant_state items, found neither") + elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys: + raise ValueError( + f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.", + ) # unpacking minor and non-tensor quant state items if necessary if len(qs_key) == 1: From d53c8f6e36a04dc29c5ade25f5e4acbaacf0aa92 Mon Sep 17 00:00:00 2001 From: cyy Date: Mon, 18 Aug 2025 14:46:22 +0800 Subject: [PATCH 3/5] Fix QuantState.as_dict Signed-off-by: cyy --- bitsandbytes/functional.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index b05096a37..0e3c800d5 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -503,7 +503,7 @@ def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> "QuantState ) return quant_state - def as_dict(self, packed=False): + def as_dict(self, packed: bool = False) -> dict[str, Any]: """ returns dict of tensors and strings to use in serialization via _save_to_state_dict() param: packed -- returns dict[str, torch.Tensor] for state_dict fit for safetensors saving @@ -532,7 +532,10 @@ def as_dict(self, packed=False): # packed format allows serialization of non-tensor components, critical for saving in safetensors format qs_packed_dict = {k: v for k, v in qs_dict.items() if isinstance(v, torch.Tensor)} non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)} - qs_packed_dict["quant_state." + "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict) + key = "quant_state.bitsandbytes__" + if self.quant_type is not None: + key += self.quant_type + qs_packed_dict[key] = pack_dict_to_tensor(non_tensor_dict) return qs_packed_dict def to(self, device): From b169b038496d5749ca65960f85ca6c9718b41620 Mon Sep 17 00:00:00 2001 From: cyy Date: Mon, 18 Aug 2025 16:13:03 +0800 Subject: [PATCH 4/5] Add test Signed-off-by: cyy --- tests/test_functional.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_functional.py b/tests/test_functional.py index fb67430ae..8eb009e77 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -118,6 +118,9 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, for i in range(iters): A1 = torch.randn(1024, 1024, device=device, dtype=dtype) C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested) + if i == 0: + d = S.as_dict() + S = F.QuantState.from_dict(d, device=torch.device(device)) A2 = F.dequantize_blockwise(C, S) diff = torch.abs(A1 - A2).float() reldiff = diff / torch.abs(A1.float() + 1e-8) @@ -134,6 +137,9 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, for i in range(iters): A1 = torch.rand(1024, 1024, device=device, dtype=dtype) C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested, code=code) + if i == 0: + d = S.as_dict() + S = F.QuantState.from_dict(d, device=torch.device(device)) A2 = F.dequantize_blockwise(C, S) diff = torch.abs(A1 - A2).float() reldiff = diff / torch.abs(A1.float() + 1e-8) @@ -243,6 +249,9 @@ def test_fp8_quant(self, device): for i in range(10): A1 = torch.randn(1024, 1024, device=device) C, SC = F.quantize_blockwise(A1, code=code) + if i == 0: + d = SC.as_dict() + SC = F.QuantState.from_dict(d, device=torch.device(device)) A2 = F.dequantize_blockwise(C, SC) diff = torch.abs(A1 - A2) reldiff = diff / torch.abs(A1 + 1e-8) @@ -1116,6 +1125,8 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize): A1 = torch.randn(1024, 1024, device=device, dtype=dtype) qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type) + d = SA.as_dict() + SA = F.QuantState.from_dict(d, device=torch.device(device)) A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type) err = (A1 - A2).abs().float() From 525628709c5e23ca6821be7f83aa49314dd43f95 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Wed, 17 Sep 2025 13:12:32 +0800 Subject: [PATCH 5/5] Fix comment Signed-off-by: Yuanyuan Chen --- bitsandbytes/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 0e3c800d5..58716ec5d 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -943,7 +943,7 @@ def dequantize_4bit( """Dequantizes a packed 4-bit quantized tensor. The input tensor is dequantized by dividing it into blocks of `blocksize` values. - The the absolute maximum value within these blocks is used for scaling + The absolute maximum value within these blocks is used for scaling the non-linear dequantization. Args: