diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py index fc59ce3d262..6be7afa048b 100644 --- a/backends/qualcomm/_passes/__init__.py +++ b/backends/qualcomm/_passes/__init__.py @@ -42,6 +42,7 @@ from .fuse_consecutive_cast import FuseConsecutiveCast from .fuse_consecutive_transpose import FuseConsecutiveTranspose from .i64_to_i32 import I64toI32 +from .insert_cast_for_fp_act_quantized_weight import InsertCastForFpActQuantizedWeight from .insert_io_qdq import InsertIOQDQ from .insert_requantize import InsertRequantize from .insert_reshape_for_reduce_ops import InsertReshapeForReduceOps @@ -98,6 +99,7 @@ FuseConsecutiveCast, FuseConsecutiveTranspose, I64toI32, + InsertCastForFpActQuantizedWeight, InsertIOQDQ, InsertReshapeForReduceOps, InsertRequantize, diff --git a/backends/qualcomm/_passes/insert_cast_for_fp_act_quantized_weight.py b/backends/qualcomm/_passes/insert_cast_for_fp_act_quantized_weight.py new file mode 100644 index 00000000000..57b7253f242 --- /dev/null +++ b/backends/qualcomm/_passes/insert_cast_for_fp_act_quantized_weight.py @@ -0,0 +1,141 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +from executorch.backends.qualcomm.builders.node_visitor import dq_ops +from executorch.backends.qualcomm.builders.utils import is_parameter +from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.passes import dead_code_elimination_pass + +from .utils import copy_meta + +TARGET_OPS = { + exir_ops.edge.aten.convolution.default, + exir_ops.edge.aten.linear.default, +} + + +class InsertCastForFpActQuantizedWeight(ExportPass): + """ + Insert fp32↔fp16 casts around conv/linear nodes that have a quantized + weight but a floating-point activation. + + Background — QNN vs PyTorch dtype contract: + In PyTorch, a conv/linear with fp32 activation and int8 weight (e.g. + produced by fp16a8w quantization) is valid: the weight is stored as int8 + but dequantized to fp32 before the multiply-accumulate. QNN HTP, however, + requires that when the weight is quantized (int8/int4) the activation must + also be fp16, not fp32. Passing an fp32 activation to such an op causes a + QNN compilation error. + + Fix: + Wrap the offending node with an fp32→fp16 cast on the input activation and + an fp16→fp32 cast on the output, so the node itself operates in fp16 while + the surrounding graph continues to see fp32 tensors. + + Before: [fp32 act] → conv/linear(w=int8) → [fp32 out] + After: [fp32 act] → cast(fp16) → conv/linear(w=int8) → cast(fp32) → [fp32 out] + + Pattern matched: + - Node target is in TARGET_OPS (convolution, linear) + - Node has no QCOM_QUANT_ATTRS (activation is not quantized, i.e. fp32) + - Weight arg (args[1]) is a parameter with QCOM_QUANT_ATTRS, + optionally wrapped in a dequantize op + - Input activation dtype is fp32 + + The bias meta["val"] is also updated to fp16 to stay consistent with the + fp16 compute domain of the node. 
+ """ + + def __init__(self, edge_program: torch.export.ExportedProgram): + super().__init__() + self.edge_program = edge_program + + def _get_weight_param_node(self, weight: torch.fx.Node): + """Return the underlying parameter node for a weight, unwrapping a DQ op if present.""" + if is_parameter(weight, self.edge_program): + return weight + if weight.target in dq_ops: + param_node = weight.args[0] + if isinstance(param_node, torch.fx.Node) and is_parameter( + param_node, self.edge_program + ): + return param_node + return None + + def _has_quantized_weight(self, node: torch.fx.Node) -> bool: + if node.target not in TARGET_OPS or len(node.args) < 2: + return False + weight = node.args[1] + if not isinstance(weight, torch.fx.Node): + return False + param_node = self._get_weight_param_node(weight) + return param_node is not None and bool(param_node.meta.get(QCOM_QUANT_ATTRS)) + + def _insert_fp32_fp16_casts( + self, graph_module: torch.fx.GraphModule, node: torch.fx.Node + ): + """Wrap node with cast(fp32→fp16) on input and cast(fp16→fp32) on output.""" + input_act = node.args[0] + + with graph_module.graph.inserting_before(node): + cast_in = graph_module.graph.create_node( + "call_function", + exir_ops.edge.aten._to_copy.default, + (input_act,), + {"dtype": torch.float16}, + ) + cast_in.meta = copy_meta( + node.meta, + lambda m: {**m, "val": input_act.meta["val"].to(torch.float16)}, + ) + node.replace_input_with(input_act, cast_in) + + # Update bias meta["val"] to fp16 if present. 
+ if len(node.args) > 2 and node.args[2] is not None: + bias_node = node.args[2] + if isinstance(bias_node, torch.fx.Node) and "val" in bias_node.meta: + if bias_node.meta["val"].dtype == torch.float32: + bias_node.meta["val"] = bias_node.meta["val"].to(torch.float16) + + users = list(node.users.keys()) + orig_output_val = node.meta["val"] + node.meta["val"] = orig_output_val.to(torch.float16) + + with graph_module.graph.inserting_after(node): + cast_out = graph_module.graph.create_node( + "call_function", + exir_ops.edge.aten._to_copy.default, + (node,), + {"dtype": torch.float32}, + ) + cast_out.meta = copy_meta( + node.meta, + lambda m: {**m, "val": orig_output_val.to(torch.float32)}, + ) + + for user in users: + user.replace_input_with(node, cast_out) + + def call(self, graph_module: torch.fx.GraphModule): + for node in list(graph_module.graph.nodes): + if node.meta.get(QCOM_QUANT_ATTRS): + continue + if not self._has_quantized_weight(node): + continue + input_act = node.args[0] + if not isinstance(input_act, torch.fx.Node): + continue + input_val = input_act.meta.get("val") + if input_val is not None and input_val.dtype == torch.float32: + self._insert_fp32_fp16_casts(graph_module, node) + + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + dead_code_elimination_pass(graph_module) + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py index 57354af11de..0a6a909344b 100644 --- a/backends/qualcomm/_passes/qnn_pass_manager.py +++ b/backends/qualcomm/_passes/qnn_pass_manager.py @@ -47,6 +47,7 @@ FuseConsecutiveCast, FuseConsecutiveTranspose, I64toI32, + InsertCastForFpActQuantizedWeight, InsertIOQDQ, InsertRequantize, InsertReshapeForReduceOps, @@ -117,6 +118,7 @@ def get_capture_program_passes(): (FixedLinearKeepDim, True), (FoldQDQ, True), (I64toI32, True), + (InsertCastForFpActQuantizedWeight, True), (LayoutTransform, True), (RecomposePadMaxPool2d, 
True), (RecomposePixelUnshuffle, True), diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py index 04371d61e1c..5b86e4fbf33 100755 --- a/backends/qualcomm/_passes/utils.py +++ b/backends/qualcomm/_passes/utils.py @@ -79,6 +79,7 @@ def get_passes_dependency_for_capture_program(): FixedLinearKeepDim, FoldQDQ, I64toI32, + InsertCastForFpActQuantizedWeight, LayoutTransform, RecomposePadMaxPool2d, RecomposePixelUnshuffle, @@ -112,6 +113,7 @@ def get_passes_dependency_for_capture_program(): FixedLinearKeepDim: [FoldQDQ], FoldQDQ: [AnnotateQuantAttrs, AnnotateStack, AnnotateUnbind], I64toI32: [RemoveRedundancy], + InsertCastForFpActQuantizedWeight: [FoldQDQ, LayoutTransform], LayoutTransform: [ AnnotateQuantAttrs, ExpandBroadcastTensorShape, diff --git a/backends/qualcomm/quantizer/annotators/htp_rules.py b/backends/qualcomm/quantizer/annotators/htp_rules.py index cd65d02c752..9604e2ad6f1 100644 --- a/backends/qualcomm/quantizer/annotators/htp_rules.py +++ b/backends/qualcomm/quantizer/annotators/htp_rules.py @@ -234,32 +234,33 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None: if _is_annotated([node]) or not _is_float_tensor(node): return - input_qspec_map, input_nodes = {}, node.args[0] - for input in input_nodes: - input_qspec = input.meta.get(Q_ANNOTATION_KEY, None) - qspec = getattr(input_qspec, "output_qspec", None) - # keep shared qspec here for propagation the data range - # without introducing extra requantizations - if isinstance(qspec, SharedQuantizationSpec): - input_qspec_map[input] = SharedQuantizationSpec(input) - else: - input_qspec_map[input] = quantization_config.input_activation - - output_qspec = QuantizationSpec( - dtype=quantization_config.output_activation.dtype, - qscheme=quantization_config.output_activation.qscheme, - quant_max=quantization_config.output_activation.quant_max, - quant_min=quantization_config.output_activation.quant_min, - 
observer_or_fake_quant_ctr=ConcatObserver.with_args( - # we need to know the concat node in order to hack all the input observers' data range - # since deep copy of fake tensor (node.meta["val"]) is inhibited - # we could only ship grap & node name and perform postprocess inside observer currently - **{ - "node_name": node.name, - "graph": node.graph, - } - ), - ) + input_qspec_map, input_nodes, output_qspec = {}, node.args[0], None + if quantization_config.input_activation is not None: + for input in input_nodes: + input_qspec = input.meta.get(Q_ANNOTATION_KEY, None) + qspec = getattr(input_qspec, "output_qspec", None) + # keep shared qspec here for propagation the data range + # without introducing extra requantizations + if isinstance(qspec, SharedQuantizationSpec): + input_qspec_map[input] = SharedQuantizationSpec(input) + else: + input_qspec_map[input] = quantization_config.input_activation + + output_qspec = QuantizationSpec( + dtype=quantization_config.output_activation.dtype, + qscheme=quantization_config.output_activation.qscheme, + quant_max=quantization_config.output_activation.quant_max, + quant_min=quantization_config.output_activation.quant_min, + observer_or_fake_quant_ctr=ConcatObserver.with_args( + # we need to know the concat node in order to hack all the input observers' data range + # since deep copy of fake tensor (node.meta["val"]) is inhibited + # we could only ship grap & node name and perform postprocess inside observer currently + **{ + "node_name": node.name, + "graph": node.graph, + } + ), + ) node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( input_qspec_map=input_qspec_map, output_qspec=output_qspec, @@ -309,8 +310,12 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None: input_qspec_map = {} input_act = node.args[0] assert isinstance(input_act, Node) - input_qspec_map[input_act] = quantization_config.input_activation - share_qparams_with_input_node_qspec = SharedQuantizationSpec((input_act, node)) + 
share_qparams_with_input_node_qspec = None + if quantization_config.input_activation is not None: + input_qspec_map[input_act] = quantization_config.input_activation + share_qparams_with_input_node_qspec = SharedQuantizationSpec( + (input_act, node) + ) node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( input_qspec_map=input_qspec_map, @@ -522,12 +527,14 @@ def _derive_div_qparams_fn( return input_act_qspec = quantization_config.input_activation - output_act_qspec = _derived_inp1_const_div_quant_spec( - node, quantization_config.output_activation - ) + output_act_qspec = None + if input_act_qspec is not None: + output_act_qspec = _derived_inp1_const_div_quant_spec( + node, quantization_config.output_activation + ) input_qspec_map = {} input_act0 = node.args[0] - if _is_float_tensor(input_act0): + if _is_float_tensor(input_act0) and input_act_qspec is not None: input_qspec_map[input_act0] = input_act_qspec node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( @@ -722,38 +729,28 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None: input_qspec_map = {} input_act = node.args[0] - input_qspec_map[input_act] = quantization_config.input_activation + input_qspec = quantization_config.input_activation + out_act_quantization_spec = None + if input_qspec is not None: + input_qspec_map[input_act] = input_qspec - assert isinstance(input_act, Node) - out_qconf = quantization_config.output_activation + assert isinstance(input_act, Node) + out_qconf = quantization_config.output_activation - q_max = ( - torch.iinfo(out_qconf.dtype).max - if out_qconf.quant_max is None - else out_qconf.quant_max - ) - q_min = ( - torch.iinfo(out_qconf.dtype).min - if out_qconf.quant_min is None - else out_qconf.quant_min - ) + q_max = ( + torch.iinfo(out_qconf.dtype).max + if out_qconf.quant_max is None + else out_qconf.quant_max + ) + q_min = ( + torch.iinfo(out_qconf.dtype).min + if out_qconf.quant_min is None + else out_qconf.quant_min + ) - scale = 1 / (q_max - q_min 
+ 1) + scale = 1 / (q_max - q_min + 1) - output_obs_ctr = observer = FixedQParamsObserver.with_args( - scale=scale, - zero_point=0, - dtype=quantization_config.output_activation.dtype, - qscheme=torch.torch.per_tensor_affine, - quant_max=q_max, - quant_min=q_min, - ) - if quantization_config in ( - get_8a8w_qnn_qat_config(), - get_16a4w_qnn_qat_config(), - ): - output_obs_ctr = FixedQParamsFakeQuantize.with_args( - observer=observer, + output_obs_ctr = observer = FixedQParamsObserver.with_args( scale=scale, zero_point=0, dtype=quantization_config.output_activation.dtype, @@ -761,15 +758,28 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None: quant_max=q_max, quant_min=q_min, ) + if quantization_config in ( + get_8a8w_qnn_qat_config(), + get_16a4w_qnn_qat_config(), + ): + output_obs_ctr = FixedQParamsFakeQuantize.with_args( + observer=observer, + scale=scale, + zero_point=0, + dtype=quantization_config.output_activation.dtype, + qscheme=torch.torch.per_tensor_affine, + quant_max=q_max, + quant_min=q_min, + ) - # make sigmoid map to the range between 0~1 - out_act_quantization_spec = QuantizationSpec( - dtype=quantization_config.output_activation.dtype, - quant_max=q_max, - quant_min=q_min, - observer_or_fake_quant_ctr=output_obs_ctr, - qscheme=torch.torch.per_tensor_affine, - ) + # make sigmoid map to the range between 0~1 + out_act_quantization_spec = QuantizationSpec( + dtype=quantization_config.output_activation.dtype, + quant_max=q_max, + quant_min=q_min, + observer_or_fake_quant_ctr=output_obs_ctr, + qscheme=torch.torch.per_tensor_affine, + ) if _is_float_tensor(node): node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( @@ -798,11 +808,15 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None: value = node.args[3] input_qspec_map = {} - input_qspec_map[value] = quantization_config.input_activation + input_qspec = quantization_config.input_activation + output_qspec = None + if input_qspec is not None: + 
input_qspec_map[value] = input_qspec + output_qspec = SharedQuantizationSpec((value, node)) node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( input_qspec_map=input_qspec_map, - output_qspec=SharedQuantizationSpec((value, node)), + output_qspec=output_qspec, _annotated=True, ) @@ -818,11 +832,15 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None: value = node.args[2] input_qspec_map = {} - input_qspec_map[value] = quantization_config.input_activation + input_qspec = quantization_config.input_activation + output_qspec = None + if input_qspec is not None: + input_qspec_map[value] = input_qspec + output_qspec = SharedQuantizationSpec((value, node)) node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( input_qspec_map=input_qspec_map, - output_qspec=SharedQuantizationSpec((value, node)), + output_qspec=output_qspec, _annotated=True, ) @@ -942,7 +960,8 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None: act_node = node.args[0] assert isinstance(act_node, Node) input_spec = quantization_config.input_activation - input_qspec_map[act_node] = input_spec + if input_spec is not None: + input_qspec_map[act_node] = input_spec weight_node = node.args[1] assert isinstance(weight_node, Node) @@ -1027,18 +1046,22 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None: return input_qspec_map = {} - for input_node in node.args: - assert isinstance(input_node, Node) - if _is_float_tensor(input_node): - input_qspec_map[input_node] = quantization_config.input_activation - - node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=( + input_spec = quantization_config.input_activation + output_spec = None + if input_spec is not None: + for input_node in node.args: + assert isinstance(input_node, Node) + if _is_float_tensor(input_node): + input_qspec_map[input_node] = input_spec + output_spec = ( quantization_config.output_activation if _is_float_tensor(node) else 
None - ), + ) + + node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_spec, _annotated=True, ) @@ -1058,16 +1081,16 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None: input_qspec_map = {} input_act0 = node.args[0] - if isinstance(input_act0, Node): + if isinstance(input_act0, Node) and input_act_qspec is not None: input_qspec_map[input_act0] = input_act_qspec input_act1 = node.args[1] if isinstance(input_act1, Node): # In matmul, QNN_DATATYPE_SFIXED_POINT_16 Input1 must have QNN_DATATYPE_UFIXED_POINT_16 Input0 and must be symmetric quantized. - if input_act_qspec.dtype == torch.int32: + if input_act_qspec is not None and input_act_qspec.dtype == torch.int32: # we should use int16 for mm / bmm instead of int4 input_qspec_map[input_act1] = get_16a16w_qnn_ptq_config().weight - else: + elif input_act_qspec is not None: input_qspec_map[input_act1] = input_act_qspec node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( @@ -1391,38 +1414,28 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None: input_qspec_map = {} input_act = node.args[0] - input_qspec_map[input_act] = quantization_config.input_activation + input_qspec = quantization_config.input_activation + out_act_quantization_spec = None + if input_qspec is not None: + input_qspec_map[input_act] = input_qspec - assert isinstance(input_act, Node) - out_qconf = quantization_config.output_activation + assert isinstance(input_act, Node) + out_qconf = quantization_config.output_activation - q_max = ( - torch.iinfo(out_qconf.dtype).max - if out_qconf.quant_max is None - else out_qconf.quant_max - ) - q_min = ( - torch.iinfo(out_qconf.dtype).min - if out_qconf.quant_min is None - else out_qconf.quant_min - ) + q_max = ( + torch.iinfo(out_qconf.dtype).max + if out_qconf.quant_max is None + else out_qconf.quant_max + ) + q_min = ( + torch.iinfo(out_qconf.dtype).min + if out_qconf.quant_min is None + else 
out_qconf.quant_min + ) - scale = 1 / (q_max - q_min + 1) + scale = 1 / (q_max - q_min + 1) - output_obs_ctr = observer = FixedQParamsObserver.with_args( - scale=scale, - zero_point=0, - dtype=quantization_config.output_activation.dtype, - qscheme=torch.torch.per_tensor_affine, - quant_max=q_max, - quant_min=q_min, - ) - if quantization_config in ( - get_8a8w_qnn_qat_config(), - get_16a4w_qnn_qat_config(), - ): - output_obs_ctr = FixedQParamsFakeQuantize.with_args( - observer=observer, + output_obs_ctr = observer = FixedQParamsObserver.with_args( scale=scale, zero_point=0, dtype=quantization_config.output_activation.dtype, @@ -1430,15 +1443,28 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None: quant_max=q_max, quant_min=q_min, ) + if quantization_config in ( + get_8a8w_qnn_qat_config(), + get_16a4w_qnn_qat_config(), + ): + output_obs_ctr = FixedQParamsFakeQuantize.with_args( + observer=observer, + scale=scale, + zero_point=0, + dtype=quantization_config.output_activation.dtype, + qscheme=torch.torch.per_tensor_affine, + quant_max=q_max, + quant_min=q_min, + ) - # make sigmoid map to the range between 0~1 - out_act_quantization_spec = QuantizationSpec( - dtype=quantization_config.output_activation.dtype, - quant_max=q_max, - quant_min=q_min, - observer_or_fake_quant_ctr=output_obs_ctr, - qscheme=torch.torch.per_tensor_affine, - ) + # make sigmoid map to the range between 0~1 + out_act_quantization_spec = QuantizationSpec( + dtype=quantization_config.output_activation.dtype, + quant_max=q_max, + quant_min=q_min, + observer_or_fake_quant_ctr=output_obs_ctr, + qscheme=torch.torch.per_tensor_affine, + ) if _is_float_tensor(node): node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( @@ -1472,12 +1498,16 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None: value = node.args[1] input_qspec_map = {} - input_qspec_map[input] = quantization_config.input_activation - input_qspec_map[value] = SharedQuantizationSpec((input, 
node)) + input_act_qspec = quantization_config.input_activation + output_qspec = None + if input_act_qspec is not None: + input_qspec_map[input] = input_act_qspec + input_qspec_map[value] = SharedQuantizationSpec((input, node)) + output_qspec = SharedQuantizationSpec((input, node)) node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( input_qspec_map=input_qspec_map, - output_qspec=SharedQuantizationSpec((input, node)), + output_qspec=output_qspec, _annotated=True, ) @@ -1513,16 +1543,19 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None: first_input_node = input_nodes[0] input_qspec_map = {} - assert isinstance(first_input_node, Node) - input_qspec_map[first_input_node] = quantization_config.input_activation - share_qparams_with_input_act0_qspec = SharedQuantizationSpec( - (first_input_node, node) - ) + input_act_qspec = quantization_config.input_activation + share_qparams_with_input_act0_qspec = None + if input_act_qspec is not None: + assert isinstance(first_input_node, Node) + input_qspec_map[first_input_node] = input_act_qspec + share_qparams_with_input_act0_qspec = SharedQuantizationSpec( + (first_input_node, node) + ) - for input_node in input_nodes[1:]: - if input_node not in input_qspec_map: - assert isinstance(input_node, Node) - input_qspec_map[input_node] = share_qparams_with_input_act0_qspec + for input_node in input_nodes[1:]: + if input_node not in input_qspec_map: + assert isinstance(input_node, Node) + input_qspec_map[input_node] = share_qparams_with_input_act0_qspec node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( input_qspec_map=input_qspec_map, @@ -1562,29 +1595,19 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None: return input_qspec_map = {} - input_act = node.args[0] - assert isinstance(input_act, Node) - input_qspec_map[input_act] = quantization_config.input_activation - - out_act_quantization_spec = quantization_config.output_activation - # Based on quantization constraints in QNN 
document, for the uint16 data type, the scale should be set to 1/32768.0 and the zero_point should be 32768. - if out_act_quantization_spec.dtype == torch.int32: - scale = 1 / 32768.0 - zero_point = 32768 - output_obs_ctr = observer = FixedQParamsObserver.with_args( - scale=scale, - zero_point=zero_point, - dtype=quantization_config.output_activation.dtype, - qscheme=torch.torch.per_tensor_affine, - quant_max=quantization_config.output_activation.quant_max, - quant_min=quantization_config.output_activation.quant_min, - ) - if isinstance( - quantization_config.output_activation.observer_or_fake_quant_ctr, - torch.ao.quantization.fake_quantize.FakeQuantizeBase, - ): - output_obs_ctr = FixedQParamsFakeQuantize.with_args( - observer=observer, + input_act_qspec = quantization_config.input_activation + out_act_quantization_spec = None + if input_act_qspec is not None: + input_act = node.args[0] + assert isinstance(input_act, Node) + input_qspec_map[input_act] = input_act_qspec + + out_act_quantization_spec = quantization_config.output_activation + # Based on quantization constraints in QNN document, for the uint16 data type, the scale should be set to 1/32768.0 and the zero_point should be 32768. 
+ if out_act_quantization_spec.dtype == torch.int32: + scale = 1 / 32768.0 + zero_point = 32768 + output_obs_ctr = observer = FixedQParamsObserver.with_args( scale=scale, zero_point=zero_point, dtype=quantization_config.output_activation.dtype, @@ -1592,14 +1615,27 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None: quant_max=quantization_config.output_activation.quant_max, quant_min=quantization_config.output_activation.quant_min, ) - - out_act_quantization_spec = QuantizationSpec( - dtype=quantization_config.output_activation.dtype, - quant_max=quantization_config.output_activation.quant_max, - quant_min=quantization_config.output_activation.quant_min, - observer_or_fake_quant_ctr=output_obs_ctr, - qscheme=torch.torch.per_tensor_affine, - ) + if isinstance( + quantization_config.output_activation.observer_or_fake_quant_ctr, + torch.ao.quantization.fake_quantize.FakeQuantizeBase, + ): + output_obs_ctr = FixedQParamsFakeQuantize.with_args( + observer=observer, + scale=scale, + zero_point=zero_point, + dtype=quantization_config.output_activation.dtype, + qscheme=torch.torch.per_tensor_affine, + quant_max=quantization_config.output_activation.quant_max, + quant_min=quantization_config.output_activation.quant_min, + ) + + out_act_quantization_spec = QuantizationSpec( + dtype=quantization_config.output_activation.dtype, + quant_max=quantization_config.output_activation.quant_max, + quant_min=quantization_config.output_activation.quant_min, + observer_or_fake_quant_ctr=output_obs_ctr, + qscheme=torch.torch.per_tensor_affine, + ) if _is_float_tensor(node): node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( @@ -1617,14 +1653,18 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None: return input_qspec_map = {} - if _is_float_tensor(node.args[0]): - input_act = node.args[0] - assert isinstance(input_act, Node) - input_qspec_map[input_act] = quantization_config.input_activation + input_act_qspec = 
quantization_config.input_activation + out_act_quantization_spec = None + if input_act_qspec is not None: + if _is_float_tensor(node.args[0]): + input_act = node.args[0] + assert isinstance(input_act, Node) + input_qspec_map[input_act] = input_act_qspec + out_act_quantization_spec = SharedQuantizationSpec((input_act, node)) node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( input_qspec_map=input_qspec_map, - output_qspec=SharedQuantizationSpec((input_act, node)), + output_qspec=out_act_quantization_spec, _annotated=True, ) @@ -1693,10 +1733,14 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None: if _is_annotated([node]) or not _is_float_tensor(input_act): return input_qspec_map = {} - - assert isinstance(input_act, Node) - share_qparams_with_out_node0_qspec = SharedQuantizationSpec((input_act, node)) - input_qspec_map[input_act] = quantization_config.input_activation + input_act_qspec = quantization_config.input_activation + share_qparams_with_out_node0_qspec = None + if input_act_qspec is not None: + assert isinstance(input_act, Node) + share_qparams_with_out_node0_qspec = SharedQuantizationSpec( + (input_act, node) + ) + input_qspec_map[input_act] = input_act_qspec node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( input_qspec_map=input_qspec_map, @@ -1744,17 +1788,21 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None: return input_qspec_map = {} - for input_node in node.args: - assert isinstance(input_node, Node) - if _is_float_tensor(input_node): - input_qspec_map[input_node] = quantization_config.input_activation - node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=( + input_act_qspec = quantization_config.input_activation + output_qspec = None + if input_act_qspec is not None: + for input_node in node.args: + assert isinstance(input_node, Node) + if _is_float_tensor(input_node): + input_qspec_map[input_node] = input_act_qspec + output_qspec = ( 
quantization_config.output_activation if _is_float_tensor(node) else None - ), + ) + node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_qspec, _annotated=True, ) diff --git a/backends/qualcomm/quantizer/qconfig.py b/backends/qualcomm/quantizer/qconfig.py index b3c5edf9910..2ea2b866ee0 100644 --- a/backends/qualcomm/quantizer/qconfig.py +++ b/backends/qualcomm/quantizer/qconfig.py @@ -110,6 +110,144 @@ def _derive_bias_qparams_fn( ) +def get_fp16a8w_qnn_ptq_config( + act_symmetric: bool = False, + act_observer=MovingAverageMinMaxObserver, + eps: float = None, +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": eps if eps else DEFAULT_EPS_8BIT} + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=torch.iinfo(torch.int8).min + 1, + quant_max=torch.iinfo(torch.int8).max, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + # input_activation=None, output_activation=None means FP activation (no quantization) + return QuantizationConfig( + input_activation=None, + output_activation=None, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + +def get_fp16a8w_per_channel_quant_config( + act_observer=MovingAverageMinMaxObserver, + act_symmetric: bool = False, + ch_axis: int = 0, + eps: float = None, +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": eps if eps else DEFAULT_EPS_8BIT} + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=torch.iinfo(torch.int8).min + 1, + quant_max=torch.iinfo(torch.int8).max, + qscheme=torch.per_channel_symmetric, + ch_axis=ch_axis, + 
observer_or_fake_quant_ctr=PerChannelParamObserver.with_args(**extra_args), + ) + + return QuantizationConfig( + input_activation=None, + output_activation=None, + weight=weight_quantization_spec, + bias=None, + ) + + +# TODO merge qat and ptq to a function, and use a bool flag to control it +def get_fp16a8w_qnn_qat_config( + act_symmetric: bool = False, + act_observer=MovingAverageMinMaxObserver, + eps: float = None, +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": eps if eps else DEFAULT_EPS_8BIT} + + weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args( + dtype=torch.int8, + quant_min=torch.iinfo(torch.int8).min + 1, + quant_max=torch.iinfo(torch.int8).max, + qscheme=torch.per_tensor_symmetric, + observer=MovingAverageMinMaxObserver.with_args(**extra_args), + ) + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=torch.iinfo(torch.int8).min + 1, + quant_max=torch.iinfo(torch.int8).max, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=weight_fake_quant_ctr, + ) + + bias_fake_quant_ctr = FakeQuantize.with_args( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer=MovingAverageMinMaxObserver.with_args(**extra_args), + ) + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=bias_fake_quant_ctr, + ) + + # input_activation=None, output_activation=None means FP activation (no quantization) + return QuantizationConfig( + input_activation=None, + output_activation=None, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + +def get_fp16a8w_qat_per_channel_quant_config( + act_observer=MovingAverageMinMaxObserver, + act_symmetric: bool = False, + ch_axis: int = 0, + eps: float = None, +) -> QuantizationConfig: + 
extra_args: Dict[str, Any] = {"eps": eps if eps else DEFAULT_EPS_8BIT} + + weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args( + dtype=torch.int8, + quant_min=torch.iinfo(torch.int8).min + 1, + quant_max=torch.iinfo(torch.int8).max, + qscheme=torch.per_channel_symmetric, + observer=MovingAveragePerChannelMinMaxObserver.with_args(**extra_args), + ) + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=torch.iinfo(torch.int8).min + 1, + quant_max=torch.iinfo(torch.int8).max, + qscheme=torch.per_channel_symmetric, + ch_axis=ch_axis, + observer_or_fake_quant_ctr=weight_fake_quant_ctr, + ) + + return QuantizationConfig( + input_activation=None, + output_activation=None, + weight=weight_quantization_spec, + bias=None, + ) + + def get_8a8w_qnn_ptq_config( act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver, diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index 5d297ef14c4..7512ddb93d6 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -51,6 +51,10 @@ get_8a4w_qnn_ptq_config, get_8a8w_qnn_ptq_config, get_8a8w_qnn_qat_config, + get_fp16a8w_per_channel_quant_config, + get_fp16a8w_qat_per_channel_quant_config, + get_fp16a8w_qnn_ptq_config, + get_fp16a8w_qnn_qat_config, get_ptq_per_block_quant_config, get_ptq_per_channel_quant_config, get_qat_per_block_quant_config, @@ -89,6 +93,7 @@ class QuantDtype(IntEnum): use_16a4w_block = 3 use_8a8w = 4 use_8a4w = 5 + use_fp16a8w = 6 QUANT_CONFIG_DICT = { @@ -147,6 +152,16 @@ class QuantDtype(IntEnum): ), None, ), + (QuantDtype.use_fp16a8w, False): ( + get_fp16a8w_qnn_ptq_config, + get_fp16a8w_per_channel_quant_config, + None, + ), + (QuantDtype.use_fp16a8w, True): ( + get_fp16a8w_qnn_qat_config, + get_fp16a8w_qat_per_channel_quant_config, + None, + ), # QAT, (QuantDtype.use_16a4w, True): ( get_16a4w_qnn_qat_config, diff --git a/backends/qualcomm/quantizer/rules.py 
b/backends/qualcomm/quantizer/rules.py index 878acfea422..f3c33d544f3 100644 --- a/backends/qualcomm/quantizer/rules.py +++ b/backends/qualcomm/quantizer/rules.py @@ -97,13 +97,16 @@ def annotate_single_in_share_out( return input_qspec_map = {} - if _is_float_tensor(node.args[0]): - input_act = node.args[0] + input_act_qspec = quantization_config.input_activation + input_act = node.args[0] + if _is_float_tensor(input_act) and input_act_qspec is not None: assert isinstance(input_act, Node) - input_qspec_map[input_act] = quantization_config.input_activation + input_qspec_map[input_act] = input_act_qspec output_act_qspec = ( - SharedQuantizationSpec((input_act, node)) if _is_float_tensor(node) else None + SharedQuantizationSpec((input_act, node)) + if _is_float_tensor(node) and input_act_qspec is not None + else None ) if len(input_qspec_map) > 0 or output_act_qspec is not None: node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( @@ -118,9 +121,11 @@ def annotate_single_in(node: Node, quantization_config: QuantizationConfig) -> N return input_qspec_map = {} + input_act_qspec = quantization_config.input_activation input_act = node.args[0] assert isinstance(input_act, Node) - input_qspec_map[input_act] = quantization_config.input_activation + if input_act_qspec is not None: + input_qspec_map[input_act] = input_act_qspec if len(input_qspec_map) > 0: node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( @@ -136,10 +141,11 @@ def annotate_single_in_single_out( return input_qspec_map = {} - if _is_float_tensor(node.args[0]): + input_act_qspec = quantization_config.input_activation + if _is_float_tensor(node.args[0]) and input_act_qspec is not None: input_act = node.args[0] assert isinstance(input_act, Node) - input_qspec_map[input_act] = quantization_config.input_activation + input_qspec_map[input_act] = input_act_qspec output_act_qspec = ( quantization_config.output_activation if _is_float_tensor(node) else None @@ -164,11 +170,11 @@ def annotate_binary(node: Node, 
quantization_config: QuantizationConfig) -> None input_qspec_map = {} input_act0 = node.args[0] - if _is_float_tensor(input_act0): + if _is_float_tensor(input_act0) and input_act_qspec is not None: input_qspec_map[input_act0] = input_act_qspec input_act1 = node.args[1] - if _is_float_tensor(input_act1): + if _is_float_tensor(input_act1) and input_act_qspec is not None: input_qspec_map[input_act1] = input_act_qspec if len(input_qspec_map) > 0 or output_act_qspec is not None: @@ -190,10 +196,11 @@ def annotate_conv(node: Node, quantization_config: QuantizationConfig) -> None: ) input_qspec_map = {} + input_act_qspec = quantization_config.input_activation input_act = node.args[0] assert isinstance(input_act, Node) - input_spec = quantization_config.input_activation - input_qspec_map[input_act] = input_spec + if input_act_qspec is not None: + input_qspec_map[input_act] = input_act_qspec weight = node.args[1] assert isinstance(weight, Node) diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 12d5e0902db..db863051d73 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -2208,13 +2208,21 @@ def forward(self, x): class SimpleModel(torch.nn.Module): - def __init__(self): + def __init__(self, kernel_size=3): super().__init__() kernel_sz = 32 - self.conv1 = torch.nn.Conv2d(kernel_sz, kernel_sz, 3, padding=1, bias=True) - self.conv2 = torch.nn.Conv2d(kernel_sz, kernel_sz, 3, padding=1, bias=True) - self.conv3 = torch.nn.Conv2d(kernel_sz, kernel_sz, 3, padding=1, bias=False) - self.conv4 = torch.nn.Conv2d(kernel_sz, kernel_sz, 3, padding=1, bias=False) + self.conv1 = torch.nn.Conv2d( + kernel_sz, kernel_sz, kernel_size, padding=1, bias=True + ) + self.conv2 = torch.nn.Conv2d( + kernel_sz, kernel_sz, kernel_size, padding=1, bias=True + ) + self.conv3 = torch.nn.Conv2d( + kernel_sz, kernel_sz, kernel_size, padding=1, bias=False + ) + self.conv4 = torch.nn.Conv2d( + kernel_sz, kernel_sz, kernel_size, 
padding=1, bias=False + ) self.hardtanh = torch.nn.Hardtanh(min_val=0, max_val=6) self.relu = torch.nn.ReLU() self.batch_norm = torch.nn.BatchNorm2d(kernel_sz) diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index d76e3ea1df7..0497974137c 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -882,6 +882,86 @@ def test_qnn_backend_expm1(self): module = ExpM1() # noqa: F405 self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_fp16a8w_conv2d(self): + # fp16a8w: FP16 activation + INT8 weight; weight kernel must be [1,1] + modules = [ + Conv2dSingle( # noqa: F405 + in_channel=2, out_channel=4, kernel_size=1, padding=0 + ), + Conv2dSingle( # noqa: F405 + in_channel=2, out_channel=4, kernel_size=1, padding=0, bias=False + ), + ] + sample_input = (torch.randn([1, 2, 3, 3]),) + for i, module in enumerate(modules): + with self.subTest(i=i): + module = self.get_qdq_module( + module, sample_input, quant_dtype=QuantDtype.use_fp16a8w + ) + self.lower_module_and_test_output(module, sample_input) + + def test_qnn_backend_fp16a8w_conv2d_qat(self): + # fp16a8w QAT: FP16 activation + INT8 weight; weight kernel must be [1,1] + # QAT fake quantize (FusedMovingAvgObsFakeQuantize) requires float32 tensors, so QAT runs in float32. + modules = [ + Conv2dSingle( # noqa: F405 + in_channel=2, out_channel=4, kernel_size=1, padding=0 + ), + Conv2dSingle( # noqa: F405 + in_channel=2, out_channel=4, kernel_size=1, padding=0, bias=False + ), + ] + sample_input = (torch.randn([1, 2, 3, 3]),) + for i, module in enumerate(modules): + with self.subTest(i=i): + # QAT in float32 + prepared = self.get_prepared_qat_module( + module, sample_input, quant_dtype=QuantDtype.use_fp16a8w + ) + module = self.get_converted_sgd_trained_module( + module, prepared, sample_input + ) + self.lower_module_and_test_output(module, sample_input) + + def test_qnn_backend_fp16a8w_linear(self): + # fp16a8w: FP16 
activation + INT8 weight for linear (per-channel weight quantization) + modules = [Linear(), Linear(use_bias=False)] # noqa: F405 + sample_input = (torch.randn([1, 512]),) + for i, module in enumerate(modules): + with self.subTest(i=i): + module = self.get_qdq_module( + module, + sample_input, + quant_dtype=QuantDtype.use_fp16a8w, + is_linear_per_channel=True, + ) + self.lower_module_and_test_output(module, sample_input) + + def test_qnn_backend_fp16a8w_simple_model(self): + module = SimpleModel(kernel_size=1) # noqa: F405 + sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) + module = self.get_qdq_module( + module, + sample_input, + quant_dtype=QuantDtype.use_fp16a8w, + is_linear_per_channel=True, + ) + self.lower_module_and_test_output(module, sample_input) + + def test_qnn_backend_fp16a8w_fp16_simple_model(self): + module = SimpleModel(kernel_size=1).to(torch.float16) # noqa: F405 + sample_input = ( + torch.ones(1, 32, 28, 28, dtype=torch.float16), + torch.ones(1, 32, 28, 28, dtype=torch.float16), + ) + module = self.get_qdq_module( + module, + sample_input, + quant_dtype=QuantDtype.use_fp16a8w, + is_linear_per_channel=True, + ) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_flip(self): sample_input = (torch.randn(3, 4, 5, 6),) module = Flip() # noqa: F405