Skip to content

Commit b20fbf3

Browse files
committed
Fix QuantState.from_dict
Signed-off-by: cyy <[email protected]>
1 parent bcfd75c commit b20fbf3

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

bitsandbytes/functional.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -463,12 +463,13 @@ def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> "QuantState
463463

464464
# unpacking tensor with non-tensor components
465465
qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)]
466-
if not len(qs_key) and "quant_type" not in qs_dict:
467-
raise ValueError("Expected packed or unpacked quant_state items, found neither")
468-
elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys:
469-
raise ValueError(
470-
f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.",
471-
)
466+
if "quant_type" not in qs_dict:
467+
if not qs_key:
468+
raise ValueError("Expected packed or unpacked quant_state items, found neither")
469+
elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys:
470+
raise ValueError(
471+
f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.",
472+
)
472473

473474
# unpacking minor and non-tensor quant state items if necessary
474475
if len(qs_key) == 1:

0 commit comments

Comments
 (0)