Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .github/workflows/pull.yml
Original file line number Diff line number Diff line change
Expand Up @@ -852,8 +852,7 @@ jobs:
strategy:
matrix:
dtype: [fp32]
# TODO(T12345): re-enable qnn_16a16w once OOM on linux.2xlarge is resolved
pt2e_quantize: [qnn_8a8w]
pt2e_quantize: [qnn_16a16w, qnn_8a8w]
mode: [qnn]
fail-fast: false
with:
Expand Down
3 changes: 1 addition & 2 deletions .github/workflows/trunk.yml
Original file line number Diff line number Diff line change
Expand Up @@ -951,8 +951,7 @@ jobs:
strategy:
matrix:
dtype: [fp32]
# TODO(T12345): re-enable qnn_16a16w once OOM on linux.2xlarge is resolved
pt2e_quantize: [qnn_8a8w]
pt2e_quantize: [qnn_16a16w, qnn_8a8w]
mode: [qnn]
fail-fast: false
with:
Expand Down
26 changes: 25 additions & 1 deletion backends/qualcomm/quantizer/quantizer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
Expand Down Expand Up @@ -418,6 +418,27 @@
)
return quant_range

def _get_input_quant_range(self, user_node, input_node):
"""Return the quant range of the spec assigned to `input_node` in
`user_node.meta[quantization_annotation].input_qspec_map`. Falls back
to None if no concrete spec is registered for this input — needed
when the user's output_qspec is a SharedQuantizationSpec that hides
the dtype/qmin/qmax."""
quant_info = user_node.meta.get(QCOM_QUANT_ANNOTATION_KEY, None)
if quant_info is None:
return
qspec = getattr(quant_info, "input_qspec_map", {}).get(input_node)
if qspec is None:
return
try:
dtype_info = torch.iinfo(qspec.dtype)
except:
return
return (
(dtype_info.max if qspec.quant_max is None else qspec.quant_max)
- (dtype_info.min if qspec.quant_min is None else qspec.quant_min)
)

def _get_candidates_with_infinity_args(self, graph_module: GraphModule):
binary_op_sources = [
operator.add,
Expand All @@ -441,7 +462,7 @@
torch.ops.aten.scalar_tensor.default,
}

def _replace_inf(self, graph_module: GraphModule) -> GraphModule:

Check warning on line 465 in backends/qualcomm/quantizer/quantizer.py

View workflow job for this annotation

GitHub Actions / lintrunner

FLAKE8 C901

'QnnQuantizer._replace_inf' is too complex (13) See https://www.flake8rules.com/rules/C901.html.
candidates = self._get_candidates_with_infinity_args(graph_module)
for node in graph_module.graph.nodes:
if all(
Expand Down Expand Up @@ -473,7 +494,10 @@

quant_min, quant_max = float("inf"), float("-inf")
for source_node in node.users:
if quant_range := self._get_quant_range(source_node):
if quant_range := self._get_input_quant_range(source_node, node):
quant_min = min(quant_min, -quant_range)
quant_max = max(quant_max, quant_range)
elif quant_range := self._get_quant_range(source_node):
quant_min = min(quant_min, -quant_range)
quant_max = max(quant_max, quant_range)

Expand Down
Loading