Skip to content

Commit e3c2bf0

Browse files
committed
Fix QuantState.from_dict
Signed-off-by: cyy <[email protected]>
1 parent 58f149c commit e3c2bf0

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

bitsandbytes/functional.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -462,12 +462,13 @@ def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> "QuantState
462462

463463
# unpacking tensor with non-tensor components
464464
qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)]
465-
if not len(qs_key) and "quant_type" not in qs_dict:
466-
raise ValueError("Expected packed or unpacked quant_state items, found neither")
467-
elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys:
468-
raise ValueError(
469-
f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.",
470-
)
465+
if "quant_type" not in qs_dict:
466+
if not qs_key:
467+
raise ValueError("Expected packed or unpacked quant_state items, found neither")
468+
elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys:
469+
raise ValueError(
470+
f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.",
471+
)
471472

472473
# unpacking minor and non-tensor quant state items if necessary
473474
if len(qs_key) == 1:

0 commit comments

Comments
 (0)