Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions backends/qualcomm/builders/node_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,16 +293,15 @@ def make_qnn_per_tensor_config(self, quant_attrs: Dict):
}
# check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h
quant_config[QCOM_OFFSET] = -quant_attrs[QCOM_ZERO_POINT]
range_ = quant_config[QCOM_QUANT_MAX] - quant_config[QCOM_QUANT_MIN]
assert range_ > 3, (
f"2-bit quantization (range={range_}) does not support per-tensor encoding. "
"Use per-channel quantization instead."
)
# special case for 4 bits
if (
quant_config[QCOM_DTYPE] == torch.int8
and quant_config[QCOM_QUANT_MAX] - quant_config[QCOM_QUANT_MIN] <= 15
):
# special case for 4-bit / 2-bit integer weights.
quant_range = quant_config[QCOM_QUANT_MAX] - quant_config[QCOM_QUANT_MIN]
if quant_config[QCOM_DTYPE] == torch.int8 and quant_range <= 15:
if quant_range <= 3:
raise ValueError(
f"2-bit quantization (range={quant_range}) "
"does not support per-tensor encoding. Use per-channel quantization instead."
)
# special case for 4 bits
quant_config[QCOM_BITWIDTH] = 4
return (
PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET,
Expand Down
25 changes: 15 additions & 10 deletions backends/qualcomm/quantizer/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,12 +283,21 @@ def _qspec_port_encoding_type(node: Node, qspec: QuantizationSpecBase):
qscheme = qspec.qscheme

if qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]:
range_ = qspec.quant_max - qspec.quant_min
assert range_ > 3, (
f"2-bit quantization (range={range_}) does not support per-tensor encoding. "
"Use per-channel quantization instead."
)
if qspec.dtype == torch.int8 and range_ <= 15:
# quant_max/quant_min are None for non-integer activations (e.g. uint16 in
# 16a2w) whose range is not expressed as a fixed integer bound; skip the
# 4-bit BW_SCALE_OFFSET special-casing for those tensors.
if (
qspec.dtype == torch.int8
and qspec.quant_max is not None
and qspec.quant_min is not None
and (quant_range := qspec.quant_max - qspec.quant_min) <= 15
):
if quant_range <= 3:
raise ValueError(
f"2-bit quantization (range={quant_range}) "
"does not support per-tensor encoding. "
"Use per-channel quantization instead."
)
encoding_type = (
PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET
)
Expand All @@ -303,10 +312,6 @@ def _qspec_port_encoding_type(node: Node, qspec: QuantizationSpecBase):
encoding_type = (
PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION
)
elif qspec.dtype == torch.int8 and qspec.quant_max - qspec.quant_min <= 3:
encoding_type = (
PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET
)
elif qspec.dtype == torch.int8 and qspec.quant_max - qspec.quant_min <= 15:
encoding_type = (
PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET
Expand Down
Loading