diff --git a/src/nncf/experimental/common/tensor_statistics/collectors.py b/src/nncf/experimental/common/tensor_statistics/collectors.py
index b11c021a0a6..f783c8e5267 100644
--- a/src/nncf/experimental/common/tensor_statistics/collectors.py
+++ b/src/nncf/experimental/common/tensor_statistics/collectors.py
@@ -17,6 +17,7 @@
 from collections import defaultdict
 from collections import deque
 from copy import deepcopy
+from enum import Enum
 from typing import Any, Optional, TypeVar, Union

 import nncf
@@ -35,6 +36,50 @@ InplaceInsertionFNType = TypeVar("InplaceInsertionFNType")
 AggregationAxes = tuple[int, ...]
+Axes = tuple[int, ...]
+
+
+class AxesMode(Enum):
+    """
+    Represents different strategies for handling tensor axes.
+
+    :param REDUCTION: Indicates that the specified axes should be reduced during an operation.
+    :param KEEP: Indicates that the specified axes should be preserved and not reduced during
+        an operation.
+    """
+
+    REDUCTION = "reduction"
+    KEEP = "keep"
+
+
+def determine_reduction_axes(
+    ndim: int, axes: Optional[Axes] = None, axes_mode: AxesMode = AxesMode.REDUCTION
+) -> ReductionAxes:
+    """
+    Determines the set of axes along which a reduction operation should be performed
+    based on the specified axes mode.
+
+    :param ndim: The number of dimensions in the input tensor.
+    :param axes: The axes specified for the reduction operation. If `None`, all axes
+        are considered (i.e., `tuple(range(ndim))`).
+    :param axes_mode: Defines how the specified axes are interpreted:
+        - `AxesMode.REDUCTION`: the given axes will be reduced.
+        - `AxesMode.KEEP`: all axes except the specified ones will be reduced.
+    :return: The resolved set of axes along which the reduction operation should be performed.
+    """
+    if axes is None:
+        return tuple(range(ndim))
+
+    if axes_mode == AxesMode.REDUCTION:
+        return axes
+
+    all_axes = tuple(range(ndim))
+    if len(all_axes) > 1:
+        # Normalize possibly negative axis indices to their positive equivalents
+        keep_axes = tuple(all_axes[i] for i in axes)
+        return tuple(set(all_axes) - set(keep_axes))
+    return ()


 class TensorReducerBase(ABC):
@@ -43,13 +88,21 @@ class TensorReducerBase(ABC):
     the specified rule. Could handle tensors inplace or out of place.
     """

-    def __init__(self, reduction_axes: Optional[ReductionAxes] = None, inplace: bool = False):
+    def __init__(
+        self,
+        axes: Optional[Axes] = None,
+        axes_mode: AxesMode = AxesMode.REDUCTION,
+        inplace: bool = False,
+    ):
         """
-        :param reduction_axes: Reduction axes for reduction calculation. Equal to list(range(len(input.shape)))
-            if empty.
-        :param inplace: Whether should be calculated inplace or out of place.
+        :param axes: The axes along which the reduction operation should be applied.
+            If `None`, the operation will be applied to all axes (i.e., `tuple(range(tensor.ndim))`).
+        :param axes_mode: Determines how the specified `axes` are treated during the operation.
+            Use `AxesMode.REDUCTION` to reduce over the given axes, or `AxesMode.KEEP` to preserve them.
+        :param inplace: Whether the statistic should be calculated inplace or out of place.
         """
-        self._reduction_axes = reduction_axes
+        self._axes = axes
+        self._axes_mode = axes_mode
         self._inplace = inplace
         self._keepdims = True
@@ -97,17 +150,13 @@ def __call__(self, x: list[Tensor]):
     def __eq__(self, __o: object) -> bool:
         return (
             isinstance(__o, self.__class__)
-            and self._reduction_axes == __o._reduction_axes
+            and self._axes == __o._axes
+            and self._axes_mode == __o._axes_mode
             and self._inplace == __o.inplace
         )

     def __hash__(self) -> int:
-        return hash((self.__class__.__name__, self.inplace, self._reduction_axes))
-
-    def _get_reduction_axes(self, tensor: Tensor) -> ReductionAxes:
-        if self._reduction_axes is not None:
-            return self._reduction_axes
-        return tuple(range(len(tensor.shape)))
+        return hash((self.__class__.__name__, self.inplace, self._axes, self._axes_mode))


 class AggregatorBase:
@@ -444,35 +493,35 @@ def get_inplace_fn(self) -> Optional[InplaceInsertionFNType]:
 class MinReducer(TensorReducerBase):
     def _reduce_out_of_place(self, x: list[Tensor]) -> list[Tensor]:
         x = x[0]
-        reduction_axes = self._get_reduction_axes(x)
+        reduction_axes = determine_reduction_axes(x.ndim, self._axes, self._axes_mode)
         return [fns.min(x, reduction_axes, keepdims=self._keepdims)]


 class MaxReducer(TensorReducerBase):
     def _reduce_out_of_place(self, x: list[Tensor]) -> list[Tensor]:
         x = x[0]
-        reduction_axes = self._get_reduction_axes(x)
+        reduction_axes = determine_reduction_axes(x.ndim, self._axes, self._axes_mode)
         return [fns.max(x, reduction_axes, keepdims=self._keepdims)]


 class AbsMaxReducer(TensorReducerBase):
     def _reduce_out_of_place(self, x: list[Tensor]) -> list[Tensor]:
         x = fns.abs(x[0])
-        reduction_axes = self._get_reduction_axes(x)
+        reduction_axes = determine_reduction_axes(x.ndim, self._axes, self._axes_mode)
         return [fns.max(x, reduction_axes, keepdims=self._keepdims)]


 class MeanReducer(TensorReducerBase):
     def _reduce_out_of_place(self, x: list[Tensor]) -> list[Tensor]:
         x = x[0]
-        reduction_axes = self._get_reduction_axes(x)
+        reduction_axes = determine_reduction_axes(x.ndim, self._axes, self._axes_mode)
         return [fns.mean(x, reduction_axes, keepdims=self._keepdims)]


 class MeanVarianceReducer(TensorReducerBase):
     def _reduce_out_of_place(self, x: list[Tensor]) -> list[Tensor]:
         x = x[0]
-        reduction_axes = self._get_reduction_axes(x)
+        reduction_axes = determine_reduction_axes(x.ndim, self._axes, self._axes_mode)
         variance = fns.var(x, reduction_axes)
         return [fns.mean(variance)]

@@ -480,7 +529,7 @@ def _reduce_out_of_place(self, x: list[Tensor]) -> list[Tensor]:
 class MaxVarianceReducer(TensorReducerBase):
     def _reduce_out_of_place(self, x: list[Tensor]) -> list[Tensor]:
         x = x[0]
-        reduction_axes = self._get_reduction_axes(x)
+        reduction_axes = determine_reduction_axes(x.ndim, self._axes, self._axes_mode)
         variance = fns.var(x, reduction_axes)
         return [fns.max(variance)]

@@ -488,7 +537,7 @@ def _reduce_out_of_place(self, x: list[Tensor]) -> list[Tensor]:
 class MeanAbsMaxReducer(TensorReducerBase):
     def _reduce_out_of_place(self, x: list[Tensor]) -> list[Tensor]:
         x = fns.abs(x[0])
-        reduction_axes = self._get_reduction_axes(x)
+        reduction_axes = determine_reduction_axes(x.ndim, self._axes, self._axes_mode)
         abs_max = fns.max(x, reduction_axes, keepdims=self._keepdims)
         return [fns.mean(abs_max)]

@@ -496,40 +545,42 @@ def _reduce_out_of_place(self, x: list[Tensor]) -> list[Tensor]:
 class QuantileReducerBase(TensorReducerBase):
     def __init__(
         self,
-        reduction_axes: Optional[ReductionAxes] = None,
+        axes: Optional[Axes] = None,
+        axes_mode: AxesMode = AxesMode.REDUCTION,
         quantile: Optional[Union[float, tuple[float]]] = None,
         inplace: bool = False,
     ):
-        super().__init__(reduction_axes=reduction_axes, inplace=False)
+        super().__init__(axes, axes_mode, False)
         self._quantile = (0.01, 0.99) if quantile is None else quantile

     def __eq__(self, __o: object) -> bool:
         return super().__eq__(__o) and self._quantile == __o._quantile

     def __hash__(self) -> int:
-        return hash((self.__class__.__name__, self.inplace, self._reduction_axes, tuple(self._quantile)))
+        return hash((self.__class__.__name__, self.inplace, self._axes, self._axes_mode, tuple(self._quantile)))


 class QuantileReducer(QuantileReducerBase):
     def _reduce_out_of_place(self, x: list[Tensor]) -> list[Tensor]:
         x = x[0]
-        reduction_axes = self._get_reduction_axes(x)
+        reduction_axes = determine_reduction_axes(x.ndim, self._axes, self._axes_mode)
         return fns.quantile(x, self._quantile, reduction_axes, keepdims=self._keepdims)


 class AbsQuantileReducer(QuantileReducerBase):
     def __init__(
         self,
-        reduction_axes: Optional[ReductionAxes] = None,
-        quantile: Optional[Union[float, list[float]]] = None,
+        axes: Optional[Axes] = None,
+        axes_mode: AxesMode = AxesMode.REDUCTION,
+        quantile: Optional[Union[float, tuple[float]]] = None,
         inplace: bool = False,
     ):
         quantile = (0.99,) if quantile is None else quantile
-        super().__init__(reduction_axes=reduction_axes, quantile=quantile, inplace=False)
+        super().__init__(axes, axes_mode, quantile)

     def _reduce_out_of_place(self, x: list[Tensor]) -> list[Tensor]:
         x = fns.abs(x[0])
-        reduction_axes = self._get_reduction_axes(x)
+        reduction_axes = determine_reduction_axes(x.ndim, self._axes, self._axes_mode)
         return fns.quantile(x, self._quantile, reduction_axes, keepdims=self._keepdims)

@@ -553,7 +604,7 @@ def __eq__(self, __o: object) -> bool:
         return super().__eq__(__o) and self._channel_axis == __o._channel_axis

     def __hash__(self) -> int:
-        return hash((self.__class__.__name__, self.inplace, self._reduction_axes, self._channel_axis))
+        return hash((self.__class__.__name__, self.inplace, self._axes, self._axes_mode, self._channel_axis))


 ##################################################
diff --git a/src/nncf/openvino/statistics/collectors.py b/src/nncf/openvino/statistics/collectors.py
index e081015c138..35e01a55e99 100644
--- a/src/nncf/openvino/statistics/collectors.py
+++ b/src/nncf/openvino/statistics/collectors.py
@@ -44,37 +44,37 @@
 class OVMinReducer(MinReducer):
     def get_inplace_fn(self):
-        return get_inplace_min_op(self._reduction_axes)
+        return get_inplace_min_op(self._axes)


 class OVMaxReducer(MaxReducer):
     def get_inplace_fn(self):
-        return get_inplace_max_op(self._reduction_axes, False)
+        return get_inplace_max_op(self._axes, False)


 class OVAbsMaxReducer(AbsMaxReducer):
     def get_inplace_fn(self):
-        return get_inplace_max_op(self._reduction_axes, True)
+        return get_inplace_max_op(self._axes, True)


 class OVMeanReducer(MeanReducer):
     def get_inplace_fn(self):
-        return get_inplace_mean_op(self._reduction_axes)
+        return get_inplace_mean_op(self._axes)


 class OVMeanVarianceReducer(MeanVarianceReducer):
     def get_inplace_fn(self):
-        return get_inplace_mean_var_op(self._reduction_axes)
+        return get_inplace_mean_var_op(self._axes)


 class OVMaxVarianceReducer(MaxVarianceReducer):
     def get_inplace_fn(self):
-        return get_inplace_max_var_op(self._reduction_axes)
+        return get_inplace_max_var_op(self._axes)


 class OVMeanAbsMaxReducer(MeanAbsMaxReducer):
     def get_inplace_fn(self):
-        return get_inplace_mean_max_op(self._reduction_axes, True)
+        return get_inplace_mean_max_op(self._axes, True)


 class OVShapeReducer(ShapeReducer):
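The hunks above change the axis contract for every reducer, so it is worth pinning down the semantics of the new helper. A short usage sketch (the import path matches this diff; the expected values follow the parametrized test added near the end of this patch):

```python
from nncf.experimental.common.tensor_statistics.collectors import AxesMode
from nncf.experimental.common.tensor_statistics.collectors import determine_reduction_axes

# REDUCTION mode: the given axes are reduced as-is.
assert determine_reduction_axes(3, (0, 1), AxesMode.REDUCTION) == (0, 1)
# axes=None reduces everything, regardless of mode.
assert determine_reduction_axes(3, None, AxesMode.KEEP) == (0, 1, 2)
# KEEP mode inverts the selection: keep axis 1, reduce the rest.
assert determine_reduction_axes(3, (1,), AxesMode.KEEP) == (0, 2)
# Negative axes are normalized before the inversion.
assert determine_reduction_axes(2, (-1,), AxesMode.KEEP) == (0,)
```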
diff --git a/src/nncf/quantization/algorithms/channel_alignment/openvino_backend.py b/src/nncf/quantization/algorithms/channel_alignment/openvino_backend.py
index af111c72cb4..fcfc57c12bb 100644
--- a/src/nncf/quantization/algorithms/channel_alignment/openvino_backend.py
+++ b/src/nncf/quantization/algorithms/channel_alignment/openvino_backend.py
@@ -20,6 +20,7 @@
 from nncf.common.graph.layer_attributes import ConvolutionLayerAttributes
 from nncf.common.graph.transformations.commands import TargetType
 from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase
+from nncf.experimental.common.tensor_statistics.collectors import AxesMode
 from nncf.experimental.common.tensor_statistics.collectors import MedianAggregator
 from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
 from nncf.experimental.common.tensor_statistics.statistics import MinMaxTensorStatistic
@@ -81,7 +82,7 @@ def get_statistic_collector(
         reduction_axes, q: float, num_samples: int, inplace: bool
     ) -> TensorStatisticCollectorBase:
         tensor_collector = TensorCollector(MinMaxTensorStatistic)
-        quantile_reducer = OVQuantileReducer(reduction_axes, (q, 1 - q), inplace)
+        quantile_reducer = OVQuantileReducer(reduction_axes, AxesMode.REDUCTION, (q, 1 - q), inplace)

         for port_id, container_key in enumerate([MinMaxTensorStatistic.MIN_STAT, MinMaxTensorStatistic.MAX_STAT]):
             aggregator = MedianAggregator(num_samples=num_samples, aggregation_axes=(0, 1))
diff --git a/src/nncf/quantization/algorithms/min_max/algorithm.py b/src/nncf/quantization/algorithms/min_max/algorithm.py
index 643d46b61b5..5ff7c06f3b3 100644
--- a/src/nncf/quantization/algorithms/min_max/algorithm.py
+++ b/src/nncf/quantization/algorithms/min_max/algorithm.py
@@ -570,14 +570,12 @@ def _get_statistic_collector(
             else:
                 quantile = 1 - params.quantile_outlier_prob
             reducer = self._backend_entity.reducer_map[statistic_type](
-                reduction_axes=reduction_axes, inplace=inplace, quantile=[quantile]
+                axes=reduction_axes, inplace=inplace, quantile=[quantile]
             )
         else:
             if use_abs_max and statistic_type == StatisticsType.MAX:
                 statistic_type = StatisticsType.ABS_MAX
-            reducer = self._backend_entity.reducer_map[statistic_type](
-                reduction_axes=reduction_axes, inplace=inplace
-            )
+            reducer = self._backend_entity.reducer_map[statistic_type](axes=reduction_axes, inplace=inplace)

         kwargs = {
             "num_samples": num_samples,
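The two min_max call sites show the migration pattern: the old `reduction_axes=` keyword becomes `axes=`, with an optional `axes_mode`. A hedged sketch of equivalent constructions (assumes an environment with this branch of NNCF installed; the shape and axes are illustrative):

```python
import numpy as np

from nncf.experimental.common.tensor_statistics.collectors import AxesMode
from nncf.experimental.common.tensor_statistics.collectors import MaxReducer
from nncf.tensor import Tensor
from nncf.tensor import functions as fns

# Before this change: MaxReducer(reduction_axes=(0, 2, 3), inplace=False).
per_channel = MaxReducer(axes=(0, 2, 3), axes_mode=AxesMode.REDUCTION, inplace=False)
# AxesMode.KEEP expresses the same intent without spelling out every axis:
# reduce everything except the channel axis.
keep_channel = MaxReducer(axes=(1,), axes_mode=AxesMode.KEEP, inplace=False)

x = Tensor(np.random.rand(8, 4, 16, 16))
assert fns.allclose(per_channel([x])[0], keep_channel([x])[0])
```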
diff --git a/src/nncf/quantization/algorithms/smooth_quant/algorithm.py b/src/nncf/quantization/algorithms/smooth_quant/algorithm.py
index 701073a218b..4205aae4f5d 100644
--- a/src/nncf/quantization/algorithms/smooth_quant/algorithm.py
+++ b/src/nncf/quantization/algorithms/smooth_quant/algorithm.py
@@ -21,13 +21,16 @@
 from nncf.common.graph.graph import NNCFNode
 from nncf.common.graph.operator_metatypes import OperatorMetatype
 from nncf.common.graph.transformations.layout import TransformationLayout
-from nncf.common.graph.utils import get_reduction_axes
 from nncf.common.logging import nncf_logger
 from nncf.common.logging.track_progress import track
 from nncf.common.tensor_statistics.statistic_point import StatisticPoint
 from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
 from nncf.common.utils.backend import BackendType
 from nncf.common.utils.backend import get_backend
+from nncf.experimental.common.tensor_statistics.collectors import AxesMode
+from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator
+from nncf.experimental.common.tensor_statistics.collectors import NoopAggregator
+from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
 from nncf.quantization.algorithms.algorithm import Algorithm
 from nncf.tensor import Tensor
 from nncf.tensor import functions as fns
@@ -35,6 +38,7 @@
 TModel = TypeVar("TModel")
 TTensor = TypeVar("TTensor")
 STATISTIC_BRANCH_KEY = "abs_max"
+SHAPE_BRANCH_KEY = "shape"
 ALPHA_MAP = {"convolution": 0.05, "matmul": 0.95}
@@ -98,6 +102,10 @@ def _set_backend_entity(self, model: TModel) -> None:
             msg = f"Cannot return backend-specific entity because {model_backend.value} is not supported!"
             raise nncf.UnsupportedBackendError(msg)

+        # Only the OpenVINO backend supports in-place statistics, so we should set this variable here.
+        if model_backend != BackendType.OPENVINO:
+            self._inplace_statistics = False
+
     def apply(
         self,
         model: TModel,
@@ -108,18 +116,19 @@ def apply(
         self._set_backend_entity(model)
         alpha_map = self._get_alpha_map()

-        nodes_to_smooth_data = self._get_nodes_to_smooth_data(graph, alpha_map.keys())
-        model_transformer = ModelTransformerFactory.create(model)
-        transformation_layout = TransformationLayout()
-
-        node_groups = self._group_nodes_by_source(nodes_to_smooth_data, graph)
+        nodes = self._get_nodes_to_smooth_data(graph, alpha_map.keys())
+        nodes = self._retrieve_shape(nodes, statistic_points)
+        node_groups = self._group_nodes_by_source(nodes, graph)

+        transformation_layout = TransformationLayout()
         for group_id, nodes in track(node_groups.items(), description="Applying Smooth Quant"):
             best_scale = None
             best_ratio = 0.0
             empty_statistic = False
+
+            source_node, input_port_id, source_output_port_id, shape = group_id
+
             for node_to_smooth in nodes:
-                source_node, input_port_id, source_output_port_id, _ = group_id
                 activations_value = self._get_statistics_for_node(
                     statistic_points, node_to_smooth.node_name, input_port_id
                 )
@@ -168,9 +177,7 @@ def apply(
             )
             transformation_layout.register(weight_update_command)

-            activations_by_output_id = {e.output_port_id: e for e in graph.get_output_edges(source_node)}
-            activations_shape = activations_by_output_id[source_output_port_id].tensor_shape
-            activation_scale = self._calculate_activation_scale(best_scale, activations_shape, nodes, graph)
+            activation_scale = self._calculate_activation_scale(best_scale, shape, nodes, graph)

             scale_node_name = self._create_scale_node_name(source_node.node_name, source_output_port_id)
             scale_insertion_command = self._backend_entity.scale_insertion_command(
@@ -178,6 +185,7 @@ def apply(
             )
             transformation_layout.register(scale_insertion_command)

+        model_transformer = ModelTransformerFactory.create(model)
         transformed_model = model_transformer.transform(transformation_layout)
         return transformed_model

@@ -204,27 +212,56 @@ def _calculate_scale_and_ratio(
         ratio = scales.min() / (scales.max() + eps)
         return scales, ratio

-    def _group_nodes_by_source(self, nodes_to_smooth: list[dict], nncf_graph: NNCFGraph) -> dict[tuple, list]:
+    def _group_nodes_by_source(
+        self, nodes_to_smooth: list[tuple[NNCFNode, int, tuple[int, ...]]], nncf_graph: NNCFGraph
+    ) -> dict[tuple, list]:
         """
         Groups nodes that will be smoothed by source (parent node).

-        :param nodes_to_smooth: List of the nodes that will be smoothed.
+        :param nodes_to_smooth: A list of tuples where each tuple consists of a node, an input port, and the
+            shape of the tensor associated with that node and input port.
         :param nncf_graph: NNCFGraph instance.
         :return: Dictionary with the source info as key and grouped nodes as value.
         """
         groups = defaultdict(list)
-        for node_data in nodes_to_smooth:
-            node_to_smooth = node_data["node_to_smooth"]
-            input_act_port = node_data["input_act_port"]
+        for node_to_smooth, input_act_port, shape in nodes_to_smooth:
             source_node = nncf_graph.get_input_edge_by_port_id(node_to_smooth, input_act_port).from_node
             edge = nncf_graph.get_edge(source_node, node_to_smooth)
-            # Such group_id (with node, ports, and shape as a hash) allows us to be confident
-            # that all sensitive parameters are equal for successor nodes are equal.
-            group_id = (source_node, input_act_port, edge.output_port_id, hash(str(edge.tensor_shape)))
+            # Such a group_id (with node, ports, and shape) allows us to be confident
+            # that all sensitive parameters of the successor nodes are equal.
+            group_id = (source_node, input_act_port, edge.output_port_id, shape)
             groups[group_id].append(node_to_smooth)

         return groups

+    def _retrieve_shape(
+        self, nodes: list[tuple[NNCFNode, int]], statistic_points: StatisticPointsContainer
+    ) -> list[tuple[NNCFNode, int, tuple[int, ...]]]:
+        """
+        Retrieves the shapes of tensors associated with specific nodes and input ports
+        from the given statistic points container.
+
+        :param nodes: A list of tuples, each containing a node and its corresponding input port index.
+        :param statistic_points: Container holding statistics, used to retrieve tensor shapes.
+        :return: A list of tuples where each tuple consists of a node, an input port, and the
+            shape of the tensor associated with that node and input port. If shape information is
+            not available, an empty tuple is returned for the shape.
+        """
+        items = []
+        for node, input_port in nodes:
+            for tensor_collector in statistic_points.get_algo_statistics_for_node(
+                node.node_name,
+                self._backend_entity.get_filter_fn_for_statistics(input_port, self._algorithm_key),
+                self._algorithm_key,
+            ):
+                stats = tensor_collector.get_statistics()
+                shape = stats[SHAPE_BRANCH_KEY]
+                shape = tuple(shape.tolist())
+
+                items.append((node, input_port, shape))
+
+        return items
+
     def _get_statistics_for_node(
         self, statistic_points: StatisticPointsContainer, node_name: str, act_port: int
     ) -> list[TTensor]:
@@ -247,42 +284,76 @@ def _get_statistics_for_node(
         return statistics_for_node

     def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer:
-        statistic_container = StatisticPointsContainer()
-
         self._set_backend_entity(model)
-        alpha_map = self._get_alpha_map()

+        alpha_map = self._get_alpha_map()
         nodes_to_smooth_data = self._get_nodes_to_smooth_data(graph, alpha_map.keys())

-        for node_data in nodes_to_smooth_data:
-            node_to_smooth = node_data["node_to_smooth"]
+        container = StatisticPointsContainer()
+        for node_to_smooth, input_act_port in nodes_to_smooth_data:
             target_point = self._backend_entity.target_point(
                 target_type=self._backend_entity.pre_layer_target_type(),
                 target_node_name=node_to_smooth.node_name,
-                port_id=node_data["input_act_port"],
-            )
-            input_reduction_axes = self._calculate_input_reduction_axes(
-                graph, node_to_smooth, node_data["input_act_port"]
+                port_id=input_act_port,
             )
-            stat_collector = self._backend_entity.get_abs_max_channel_collector(
-                self._subset_size, input_reduction_axes, self._inplace_statistics, STATISTIC_BRANCH_KEY
-            )
-            statistic_container.add_statistic_point(
-                StatisticPoint(
-                    target_point=target_point,
-                    tensor_collector=stat_collector,
-                    algorithm=self._algorithm_key,
-                )
-            )
-        return statistic_container

-    def _get_nodes_to_smooth_data(self, nncf_graph: NNCFGraph, node_metatypes: list[OperatorMetatype]) -> list[dict]:
+            # NOTE: The OpenVINO backend performs in-place statistic calculations.
+            # To insert reduction operations into the model graph, the reduction axes must be known before inference.
+            # However, with `AxesMode.KEEP` the reduction axes are only determined during statistics collection.
+            # Therefore, `AxesMode.KEEP` and `inplace` cannot be used together with the OpenVINO backend.
+            # For the ONNX backend, the reduction axes cannot be calculated before inference because the tensor
+            # shape (only the number of dimensions, ndim, is actually required) is unknown for some operations.
+            axes_mode, axes = self._backend_entity.get_tensor_collector_axes(graph, node_to_smooth, input_act_port)
+
+            collector = self._create_tensor_collector(self._subset_size, axes, axes_mode)
+
+            container.add_statistic_point(StatisticPoint(target_point, collector, self._algorithm_key))
+
+        return container
+
+    def _create_tensor_collector(
+        self,
+        num_samples: int,
+        axes: Optional[tuple[int, ...]],
+        axes_mode: AxesMode,
+    ) -> TensorCollector:
+        """
+        Initializes and returns a configured tensor collector for the `SmoothQuant` algorithm.
+
+        :param num_samples: Maximum number of samples to collect for the aggregator.
+        :param axes: The axes specified for the reduction operation.
+        :param axes_mode: Defines how the specified axes are interpreted:
+            - `AxesMode.REDUCTION`: the given axes will be reduced.
+            - `AxesMode.KEEP`: all axes except the specified ones will be reduced.
+        :return: A tensor collector configured with the specified reduction and aggregation logic.
+        """
+        collector = TensorCollector()
+
+        abs_max_reducer_cls = self._backend_entity.get_abs_max_reducer_cls()
+        collector.register_statistic_branch(
+            STATISTIC_BRANCH_KEY,
+            abs_max_reducer_cls(axes, axes_mode, self._inplace_statistics),
+            MaxAggregator(num_samples=num_samples),
+        )
+        shape_reducer_cls = self._backend_entity.get_shape_reducer_cls()
+        collector.register_statistic_branch(
+            SHAPE_BRANCH_KEY,
+            shape_reducer_cls(self._inplace_statistics),
+            NoopAggregator(num_samples=1, return_first=True),
+        )
+
+        return collector
+
+    def _get_nodes_to_smooth_data(
+        self, nncf_graph: NNCFGraph, node_metatypes: list[OperatorMetatype]
+    ) -> list[tuple[NNCFNode, int]]:
         """
         Collects layers whose activations will be smoothed.

         :param nncf_graph: NNCFGraph instance.
         :param node_metatypes: Metatypes for nodes to search for.
-        :return: List with the data for each layer.
+        :return: A list of pairs, where each pair consists of a node and its corresponding
+            input activation port.
         """
         nodes_with_weights = nncf_graph.get_nodes_by_metatypes(node_metatypes)
         nodes_to_smooth_data = []
@@ -306,12 +377,8 @@ def _get_nodes_to_smooth_data(self, nncf_graph: NNCFGraph, node_metatypes: list[
             if self._backend_entity.is_node_with_shared_weight(node_with_weight, nncf_graph):
                 continue

-            nodes_to_smooth_data.append(
-                {
-                    "node_to_smooth": node_with_weight,
-                    "input_act_port": activation_port_id,
-                }
-            )
+            nodes_to_smooth_data.append((node_with_weight, activation_port_id))
+
         return nodes_to_smooth_data

     def _calculate_activation_scale(
@@ -362,22 +429,6 @@ def _calculate_weight_scale(self, scale_value: Tensor, node: NNCFNode, weights_v
             return weight_scale
         return scale_value

-    def _calculate_input_reduction_axes(self, nncf_graph: NNCFGraph, node: NNCFNode, input_port: int) -> tuple[int]:
-        """
-        Returns reduction axes for specified input.
-
-        :param nncf_graph: NNCFGraph instance.
-        :param node: NNCFNode to check.
-        :param input_port: Specified input port id.
-        :return: Calculated reduction axes.
-        """
-        shape = nncf_graph.get_input_edge_by_port_id(node, input_port).tensor_shape
-        reduction_axes = tuple([])
-        if len(shape) > 1:
-            channel_axis = self._backend_entity.get_activation_channel_axis(node, input_port)
-            reduction_axes = get_reduction_axes((channel_axis,), shape)
-        return reduction_axes
-
     def _process_weight_statistics(self, node: NNCFNode, weights: Tensor) -> Tensor:
         """
         Returns processed weight statistics for node.
+ """ + shape = nncf_graph.get_input_edge_by_port_id(node, input_port).tensor_shape + reduction_axes = tuple([]) + if len(shape) > 1: + channel_axis = self.get_activation_channel_axis(node, input_port) + reduction_axes = get_reduction_axes((channel_axis,), shape) + return reduction_axes + + def get_tensor_collector_axes(self, nncf_graph: NNCFGraph, node_to_smooth: NNCFNode, input_port: int): + """ + Returns axes and axes mode required for tensor collector. + + :param nncf_graph: NNCFGraph instance. + :param node: NNCFNode to smooth. + :param input_port: Specified input port id. + :return: Axes and axes mode required for tensor collector. + """ + axes_mode = AxesMode.REDUCTION + axes = self.calculate_input_reduction_axes(nncf_graph, node_to_smooth, input_port) + return axes_mode, axes diff --git a/src/nncf/quantization/algorithms/smooth_quant/onnx_backend.py b/src/nncf/quantization/algorithms/smooth_quant/onnx_backend.py index be3c3edfa7a..b88bd694a3c 100644 --- a/src/nncf/quantization/algorithms/smooth_quant/onnx_backend.py +++ b/src/nncf/quantization/algorithms/smooth_quant/onnx_backend.py @@ -21,9 +21,7 @@ from nncf.common.graph.operator_metatypes import OperatorMetatype from nncf.common.graph.transformations.commands import TargetType from nncf.common.tensor_statistics.statistic_point import StatisticPoint -from nncf.experimental.common.tensor_statistics.collectors import AbsMaxReducer -from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator -from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.experimental.common.tensor_statistics.collectors import AxesMode from nncf.onnx.graph.metatypes.groups import MATMUL_METATYPES from nncf.onnx.graph.metatypes.groups import OPERATIONS_WITH_WEIGHTS from nncf.onnx.graph.metatypes.groups import QUANTIZE_AGNOSTIC_OPERATIONS @@ -76,16 +74,6 @@ def get_activations_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int: return activation_port - @staticmethod - def get_abs_max_channel_collector( - num_samples: int, stats_reduction_axes: tuple[int], inplace: bool, branch_key: str - ) -> TensorCollector: - collector = TensorCollector() - reducer = AbsMaxReducer(reduction_axes=stats_reduction_axes) - aggregator = MaxAggregator(num_samples=num_samples) - collector.register_statistic_branch(branch_key, reducer, aggregator) - return collector - @staticmethod def _get_weight_tensor_port_id(node: NNCFNode) -> int: weight_ports = list(node.layer_attributes.weight_attrs) @@ -242,3 +230,8 @@ def filter_func(point: StatisticPoint) -> bool: ) return filter_func + + def get_tensor_collector_axes(self, nncf_graph: NNCFGraph, node_to_smooth: NNCFNode, input_port: int): + axes_mode = AxesMode.KEEP + axes = (self.get_activation_channel_axis(node_to_smooth, input_port),) + return axes_mode, axes diff --git a/src/nncf/quantization/algorithms/smooth_quant/openvino_backend.py b/src/nncf/quantization/algorithms/smooth_quant/openvino_backend.py index 212242b44fe..9aeb202707c 100644 --- a/src/nncf/quantization/algorithms/smooth_quant/openvino_backend.py +++ b/src/nncf/quantization/algorithms/smooth_quant/openvino_backend.py @@ -20,8 +20,6 @@ from nncf.common.graph.operator_metatypes import OperatorMetatype from nncf.common.graph.transformations.commands import TargetType from nncf.common.tensor_statistics.statistic_point import StatisticPoint -from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator -from nncf.experimental.common.tensor_statistics.collectors import TensorCollector from 
diff --git a/src/nncf/quantization/algorithms/smooth_quant/openvino_backend.py b/src/nncf/quantization/algorithms/smooth_quant/openvino_backend.py
index 212242b44fe..9aeb202707c 100644
--- a/src/nncf/quantization/algorithms/smooth_quant/openvino_backend.py
+++ b/src/nncf/quantization/algorithms/smooth_quant/openvino_backend.py
@@ -20,8 +20,6 @@
 from nncf.common.graph.operator_metatypes import OperatorMetatype
 from nncf.common.graph.transformations.commands import TargetType
 from nncf.common.tensor_statistics.statistic_point import StatisticPoint
-from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator
-from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
 from nncf.openvino.graph.layout import OVLayoutElem
 from nncf.openvino.graph.layout import get_linear_weights_layout_from_node
 from nncf.openvino.graph.metatypes.groups import QUANTIZE_AGNOSTIC_OPERATIONS
@@ -33,6 +31,7 @@
 from nncf.openvino.graph.transformations.commands import OVTargetPoint
 from nncf.openvino.graph.transformations.commands import OVWeightUpdateCommand
 from nncf.openvino.statistics.collectors import OVAbsMaxReducer
+from nncf.openvino.statistics.collectors import OVShapeReducer
 from nncf.quantization.algorithms.smooth_quant.backend import SmoothQuantAlgoBackend
 from nncf.tensor import Tensor
@@ -76,16 +75,6 @@ def get_activations_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
             raise nncf.InternalError(msg)
         return activation_ports[0]

-    @staticmethod
-    def get_abs_max_channel_collector(
-        num_samples: int, stats_reduction_axes: tuple[int], inplace: bool, branch_key: str
-    ) -> TensorCollector:
-        collector = TensorCollector()
-        reducer = OVAbsMaxReducer(reduction_axes=stats_reduction_axes, inplace=inplace)
-        aggregator = MaxAggregator(num_samples=num_samples)
-        collector.register_statistic_branch(branch_key, reducer, aggregator)
-        return collector
-
     @staticmethod
     def get_weight_value(node_with_weight: NNCFNode, model: ov.Model, nncf_graph: NNCFGraph) -> Tensor:
         port_id = OVSmoothQuantAlgoBackend.get_weight_tensor_port_id(node_with_weight)
@@ -165,3 +154,11 @@ def filter_func(point: StatisticPoint) -> bool:
         )

         return filter_func
+
+    @staticmethod
+    def get_abs_max_reducer_cls() -> type[OVAbsMaxReducer]:
+        return OVAbsMaxReducer
+
+    @staticmethod
+    def get_shape_reducer_cls() -> type[OVShapeReducer]:
+        return OVShapeReducer
diff --git a/src/nncf/quantization/algorithms/smooth_quant/torch_backend.py b/src/nncf/quantization/algorithms/smooth_quant/torch_backend.py
index 356f342f903..2da75ce2433 100644
--- a/src/nncf/quantization/algorithms/smooth_quant/torch_backend.py
+++ b/src/nncf/quantization/algorithms/smooth_quant/torch_backend.py
@@ -20,9 +20,6 @@
 from nncf.common.graph.transformations.commands import TargetType
 from nncf.common.quantization.quantizer_propagation.structs import QuantizationTrait
 from nncf.common.tensor_statistics.statistic_point import StatisticPoint
-from nncf.experimental.common.tensor_statistics.collectors import AbsMaxReducer
-from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator
-from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
 from nncf.quantization.algorithms.smooth_quant.backend import SmoothQuantAlgoBackend
 from nncf.tensor import Tensor
 from nncf.torch.function_hook.commands import PT2ConstUpdateCommand
@@ -78,16 +75,6 @@ def get_activations_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
         # all nodes with the metatypes have 0 activation port id.
         return 0

-    @staticmethod
-    def get_abs_max_channel_collector(
-        num_samples: int, stats_reduction_axes: tuple[int], inplace: bool, branch_key: str
-    ) -> TensorCollector:
-        collector = TensorCollector()
-        reducer = AbsMaxReducer(reduction_axes=stats_reduction_axes)
-        aggregator = MaxAggregator(num_samples=num_samples)
-        collector.register_statistic_branch(branch_key, reducer, aggregator)
-        return collector
-
     @staticmethod
     def get_weight_value(node_with_weight: NNCFNode, model: NNCFNetwork, nncf_graph: NNCFGraph) -> Tensor:
         if isinstance(model, GraphModelWrapper):
diff --git a/src/nncf/quantization/algorithms/smooth_quant/torch_fx_backend.py b/src/nncf/quantization/algorithms/smooth_quant/torch_fx_backend.py
index 883314cd3b5..e17dfef3fbd 100644
--- a/src/nncf/quantization/algorithms/smooth_quant/torch_fx_backend.py
+++ b/src/nncf/quantization/algorithms/smooth_quant/torch_fx_backend.py
@@ -20,9 +20,6 @@
 from nncf.common.graph.transformations.commands import TargetType
 from nncf.common.quantization.quantizer_propagation.structs import QuantizationTrait
 from nncf.common.tensor_statistics.statistic_point import StatisticPoint
-from nncf.experimental.common.tensor_statistics.collectors import AbsMaxReducer
-from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator
-from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
 from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand
 from nncf.experimental.torch.fx.node_utils import get_tensor_constant_from_node
 from nncf.experimental.torch.fx.transformations import constant_update_transformation_builder
@@ -83,16 +80,6 @@ def is_node_with_weights(node: NNCFNode) -> bool:
     def get_activations_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
         return 0

-    @staticmethod
-    def get_abs_max_channel_collector(
-        num_samples: int, stats_reduction_axes: tuple[int], inplace: bool, branch_key: str
-    ) -> TensorCollector:
-        collector = TensorCollector()
-        reducer = AbsMaxReducer(reduction_axes=stats_reduction_axes)
-        aggregator = MaxAggregator(num_samples=num_samples)
-        collector.register_statistic_branch(branch_key, reducer, aggregator)
-        return collector
-
     @staticmethod
     def get_weight_value(node_with_weight: NNCFNode, model: torch.fx.GraphModule, nncf_graph: NNCFGraph) -> Tensor:
         weight_node = get_const_node(node_with_weight, node_with_weight.metatype.weight_port_ids[0], nncf_graph)
diff --git a/src/nncf/torch/tensor_statistics/collectors.py b/src/nncf/torch/tensor_statistics/collectors.py
index be2b8a0e319..9fe723f8339 100644
--- a/src/nncf/torch/tensor_statistics/collectors.py
+++ b/src/nncf/torch/tensor_statistics/collectors.py
@@ -277,7 +277,7 @@ def get_mean_percentile_statistic_collector(
     """
     tensor_collector = TensorCollector(_get_wrapped_percentile_tensor_statistic(target_shape=scale_shape))
     quantiles_to_collect = np.true_divide(percentiles_to_collect, 100)
-    reducer = QuantileReducer(reduction_axes=reduction_axes, quantile=quantiles_to_collect)
+    reducer = QuantileReducer(axes=reduction_axes, quantile=quantiles_to_collect)
     for output_port_id, p in enumerate(percentiles_to_collect):
         aggregator = MeanAggregator(
             aggregation_axes=aggregation_axes,
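The torch collector above now passes `axes=` to `QuantileReducer`; combined with `AxesMode.KEEP`, the same reducer can express per-channel quantiles directly. A hedged invocation sketch modeled on the reducer tests below (assumes this branch of NNCF; the shape is illustrative):

```python
import numpy as np

from nncf.experimental.common.tensor_statistics.collectors import AxesMode
from nncf.experimental.common.tensor_statistics.collectors import QuantileReducer
from nncf.tensor import Tensor

# Per-channel 1%/99% quantiles: keep axis 1, reduce everything else.
reducer = QuantileReducer(axes=(1,), axes_mode=AxesMode.KEEP, quantile=(0.01, 0.99))
vals = reducer([Tensor(np.random.randn(1, 4, 3, 3))])

assert vals.shape[0] == 2             # one slice per requested quantile
assert vals[0].shape == (1, 4, 1, 1)  # keepdims=True preserves the rank
```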
diff --git a/tests/common/experimental/test_reducers_and_aggregators.py b/tests/common/experimental/test_reducers_and_aggregators.py
index a9c4df66846..c6027432c5b 100644
--- a/tests/common/experimental/test_reducers_and_aggregators.py
+++ b/tests/common/experimental/test_reducers_and_aggregators.py
@@ -23,6 +23,7 @@
 from nncf.common.graph.layer_attributes import Dtype
 from nncf.common.tensor import NNCFTensor
 from nncf.experimental.common.tensor_statistics.collectors import AggregationAxes
+from nncf.experimental.common.tensor_statistics.collectors import AxesMode
 from nncf.experimental.common.tensor_statistics.collectors import HAWQAggregator
 from nncf.experimental.common.tensor_statistics.collectors import HistogramAggregator
 from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator
@@ -39,6 +40,7 @@
 from nncf.experimental.common.tensor_statistics.collectors import PercentileAggregator
 from nncf.experimental.common.tensor_statistics.collectors import RawReducer
 from nncf.experimental.common.tensor_statistics.collectors import ShapeReducer
+from nncf.experimental.common.tensor_statistics.collectors import determine_reduction_axes
 from nncf.experimental.common.tensor_statistics.statistics import MinMaxTensorStatistic
 from nncf.tensor import Tensor
 from nncf.tensor import functions as fns
@@ -220,7 +222,7 @@ def test_min_max_mean_reducers(self, reducer_name, ref, reducers):
         reduction_axes = (1, 2)
         input_ = np.arange(-26, 10).reshape((4, 3, 3))
         for i, reduction_axes_ in enumerate([reduction_axes, None]):
-            reducer = reducers[reducer_name](reduction_axes=reduction_axes_, inplace=False)
+            reducer = reducers[reducer_name](axes=reduction_axes_, inplace=False)
             val = reducer([self.get_nncf_tensor(input_, Dtype.FLOAT)])
             assert len(val) == 1
             assert fns.allclose(val[0], self.get_nncf_tensor(ref[i]))
@@ -233,7 +235,7 @@ def test_quantile_reducers(self, reducer_name, ref, reducers):
         input_ = np.arange(-26, 10).reshape((1, 4, 3, 3))
         input_[0][0][0] = -20000
         input_[0][0][1] = 10000
-        reducer = reducers[reducer_name](reduction_axes=reduction_axes, inplace=False)
+        reducer = reducers[reducer_name](axes=reduction_axes, inplace=False)
         val = reducer([self.get_nncf_tensor(input_, dtype=Dtype.FLOAT)])
         assert val.shape[0] == len(ref)
         for i, ref_ in enumerate(ref):
@@ -244,7 +246,7 @@ def test_quantile_reducers(self, reducer_name, ref, reducers):
         [[None, 16.1666], [(0,), 14.25], [(0, 1), 15.875], [(0, 1, 2), 16.1666]],
     )
     def test_mean_variance_reducer(self, axes, reference):
-        reducer = MeanVarianceReducer(reduction_axes=axes)
+        reducer = MeanVarianceReducer(axes)
         nncf_data = self.get_nncf_tensor(np.array(WEIGHT_COMPRESSION_REDUCERS_DATA), dtype=Dtype.FLOAT)
         result = reducer._reduce_out_of_place([nncf_data])
         assert len(result) == 1
@@ -255,7 +257,7 @@ def test_mean_variance_reducer(self, axes, reference):
         [[None, 10.0], [(0,), 4.16666], [(0, 1), 6.33333], [(0, 1, 2), 10.0]],
     )
     def test_mean_abs_max_reducer(self, axes, reference):
-        reducer = MeanAbsMaxReducer(reduction_axes=axes)
+        reducer = MeanAbsMaxReducer(axes)
         nncf_data = self.get_nncf_tensor(np.array(WEIGHT_COMPRESSION_REDUCERS_DATA), dtype=Dtype.FLOAT)
         result = reducer._reduce_out_of_place([nncf_data])
         assert len(result) == 1
@@ -266,7 +268,7 @@ def test_mean_abs_max_reducer(self, axes, reference):
         [[None, 16.1666], [(0,), 64.0], [(0, 1), 36.1875], [(0, 1, 2), 16.1666]],
     )
     def test_max_variance_reducer(self, axes, reference):
-        reducer = MaxVarianceReducer(reduction_axes=axes)
+        reducer = MaxVarianceReducer(axes)
         nncf_data = self.get_nncf_tensor(np.array(WEIGHT_COMPRESSION_REDUCERS_DATA), dtype=Dtype.FLOAT)
         result = reducer._reduce_out_of_place([nncf_data])
         assert len(result) == 1
@@ -566,10 +568,10 @@ def test_mad_percentile_aggregators_not_implemented_aggregation_axes(self, MAD_p
     def test_reducers_name_hash_equal(self, reducer_name, reducers):
         params = {}
         if reducer_name in ["min", "max", "abs_max", "mean"]:
-            params["reduction_axes"] = [None, (0, 1, 3), (1, 2, 3)]
+            params["axes"] = [None, (0, 1, 3), (1, 2, 3)]
             params["inplace"] = [False, True]
         elif reducer_name in ["quantile", "abs_quantile"]:
-            params["reduction_axes"] = [None, (0, 1, 3), (1, 2, 3)]
+            params["axes"] = [None, (0, 1, 3), (1, 2, 3)]
             params["quantile"] = [[0.01, 0.99], [0.001, 0.999]]
         elif reducer_name == "batch_mean":
             params["inplace"] = [False, True]
@@ -716,3 +718,23 @@ def test_histogramm_aggregator(self, ref_hist, ref_min, ref_max, ref_aggr_min, r
         assert all(isinstance(val, Tensor) for val in aggr.values())
         assert fns.allclose(aggr[MinMaxTensorStatistic.MIN_STAT], ref_aggr_min)
         assert fns.allclose(aggr[MinMaxTensorStatistic.MAX_STAT], ref_aggr_max)
+
+
+@pytest.mark.parametrize(
+    "ndim, axes, axes_mode, expected_reduction_axes",
+    [
+        [3, (0, 1), AxesMode.REDUCTION, (0, 1)],
+        [3, None, AxesMode.REDUCTION, (0, 1, 2)],
+        [3, None, AxesMode.KEEP, (0, 1, 2)],
+        [2, (-1,), AxesMode.KEEP, (0,)],
+        [2, (-2,), AxesMode.KEEP, (1,)],
+        [2, (0,), AxesMode.KEEP, (1,)],
+        [2, (1,), AxesMode.KEEP, (0,)],
+        [0, (), AxesMode.KEEP, ()],
+    ],
+)
+def test_determine_reduction_axes(
+    ndim: int, axes: tuple[int, ...], axes_mode: AxesMode, expected_reduction_axes: tuple[int, ...]
+):
+    actual_reduction_axes = determine_reduction_axes(ndim, axes, axes_mode)
+    assert actual_reduction_axes == expected_reduction_axes
diff --git a/tests/common/experimental/test_tensor_collector_batch_size.py b/tests/common/experimental/test_tensor_collector_batch_size.py
index 13d2559cfaa..f91b876a8a5 100644
--- a/tests/common/experimental/test_tensor_collector_batch_size.py
+++ b/tests/common/experimental/test_tensor_collector_batch_size.py
@@ -73,7 +73,7 @@ def _create_tensor_collector(self, shape, inplace, reducer, aggregator) -> Tenso
         collector = TensorCollector(MinMaxTensorStatistic)
         reduction_axes = get_reduction_axes([batch_axis], shape)
         aggregation_axes = (0, 1)
-        kwargs = {"reduction_axes": reduction_axes, "inplace": inplace}
+        kwargs = {"axes": reduction_axes, "inplace": inplace}
         reducer = reducer(**kwargs)
         aggregator = aggregator(
             aggregation_axes=aggregation_axes,
diff --git a/tests/common/test_statistics_aggregator.py b/tests/common/test_statistics_aggregator.py
index 64ee950c323..91fe12e8e0b 100644
--- a/tests/common/test_statistics_aggregator.py
+++ b/tests/common/test_statistics_aggregator.py
@@ -839,10 +839,10 @@ def test_same_collectors_different_attrs_dont_merge(self, statistics_type, test_
         model = params["model"](dataset_samples)
         params = {}
         if statistics_type in [StatisticsType.MIN, StatisticsType.MAX, StatisticsType.ABS_MAX, StatisticsType.MEAN]:
-            params["reduction_axes"] = [None, (0, 1, 3), (1, 2, 3)]
+            params["axes"] = [None, (0, 1, 3), (1, 2, 3)]
             params["inplace"] = [False, True]
         elif statistics_type in [StatisticsType.QUANTILE, StatisticsType.ABS_QUANTILE]:
-            params["reduction_axes"] = [None, (0, 1, 3), (1, 2, 3)]
+            params["axes"] = [None, (0, 1, 3), (1, 2, 3)]
             params["quantile"] = [[0.01, 0.99], [0.001, 0.999]]
         elif statistics_type == "batch_mean":
             params["inplace"] = [False, True]
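One boundary the parametrization above pins down is `KEEP` mode on low-rank inputs; the 1-D case is not in the table, so the sketch below labels it as inferred from the implementation:

```python
from nncf.experimental.common.tensor_statistics.collectors import AxesMode
from nncf.experimental.common.tensor_statistics.collectors import determine_reduction_axes

# 0-D: nothing to reduce either way (last row of the table above).
assert determine_reduction_axes(0, (), AxesMode.KEEP) == ()
# 1-D (inferred, not in the table): KEEP leaves the single axis untouched,
# mirroring the old `len(shape) > 1` guard in _calculate_input_reduction_axes.
assert determine_reduction_axes(1, (0,), AxesMode.KEEP) == ()
# REDUCTION passes the axes through even for 1-D inputs.
assert determine_reduction_axes(1, (0,), AxesMode.REDUCTION) == (0,)
```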
diff --git a/tests/cross_fw/test_templates/test_channel_alignment.py b/tests/cross_fw/test_templates/test_channel_alignment.py
index 697ee2c3505..bc5ca86daae 100644
--- a/tests/cross_fw/test_templates/test_channel_alignment.py
+++ b/tests/cross_fw/test_templates/test_channel_alignment.py
@@ -547,7 +547,7 @@ def test_statistic_collectors(self, inplace_ref, q_ref):
         assert len(statistic_collector.reducers) == 1
         reducer = statistic_collector.reducers.pop()
         assert isinstance(reducer, QuantileReducer)
-        assert reducer._reduction_axes == reduction_axes_ref
+        assert reducer._axes == reduction_axes_ref
         assert np.allclose(reducer._quantile, (q_ref, 1 - q_ref))

         assert len(statistic_collector.aggregators) == 2
diff --git a/tests/cross_fw/test_templates/test_quantizer_config.py b/tests/cross_fw/test_templates/test_quantizer_config.py
index f9c58b4530c..71c5867560a 100644
--- a/tests/cross_fw/test_templates/test_quantizer_config.py
+++ b/tests/cross_fw/test_templates/test_quantizer_config.py
@@ -69,7 +69,7 @@ def check_is_mean_min_max_statistic_collector(self, tensor_collector: TensorColl
         assert aggrs[0].__class__ == aggrs[1].__class__

     def get_reduction_axes(self, reducer: TensorReducerBase) -> ReductionAxes:
-        return reducer._reduction_axes
+        return reducer._axes

     @staticmethod
     def _transform_to_inference_graph(nncf_graph: NNCFGraph, min_max_algo: MinMaxQuantization):
diff --git a/tests/cross_fw/test_templates/test_smooth_quant.py b/tests/cross_fw/test_templates/test_smooth_quant.py
index 0cefaa9e791..b3b54351372 100644
--- a/tests/cross_fw/test_templates/test_smooth_quant.py
+++ b/tests/cross_fw/test_templates/test_smooth_quant.py
@@ -20,7 +20,7 @@
 from nncf.common.factory import StatisticsAggregatorFactory
 from nncf.common.graph.graph import NNCFNode
 from nncf.experimental.common.tensor_statistics.collectors import AbsMaxReducer
-from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator
+from nncf.experimental.common.tensor_statistics.collectors import ShapeReducer
 from nncf.parameters import ModelType
 from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
 from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters
@@ -170,28 +170,6 @@ def test_smooth_quant_algo(self, model_cls, reference_values, tmpdir):

         self.check_scales(quantized_model, reference_values, model_cls)

-    def test_get_abs_max_channel_collector(self, inplace_statistics: bool):
-        backend = self.get_backend()
-        reduction_axes = (3, 2, 1)
-        samples = 1
-
-        backend_tensor_collector = backend.get_abs_max_channel_collector(
-            num_samples=samples,
-            stats_reduction_axes=reduction_axes,
-            inplace=inplace_statistics,
-            branch_key="test_branch",
-        )
-
-        assert len(backend_tensor_collector.aggregators) == 1
-        for aggregator in backend_tensor_collector.aggregators.values():
-            assert isinstance(aggregator, MaxAggregator)
-
-        assert len(backend_tensor_collector.reducers) == 1
-        for reducer in backend_tensor_collector.reducers:
-            assert isinstance(reducer, AbsMaxReducer)
-            assert reducer.inplace == inplace_statistics
-            assert reducer._reduction_axes == reduction_axes
-
     @pytest.mark.parametrize(
         "model_cls, references",
         (
@@ -227,7 +205,7 @@ def test__get_nodes_to_smooth_data(self, model_cls, references, tmpdir):
         algo._set_backend_entity(model)
         alpha_map = algo._get_alpha_map()
         smooth_data = algo._get_nodes_to_smooth_data(nncf_graph, alpha_map.keys())
-        smooth_data = {d["node_to_smooth"].node_name: d["input_act_port"] for d in smooth_data}
+        smooth_data = {node.node_name: input_act_port for node, input_act_port in smooth_data}

         name_map = self.get_node_name_map(model_cls)
         assert len(name_map) == len(smooth_data)
@@ -259,7 +237,13 @@ def test_empty_stats(self, mocker, tmpdir):
         algo._set_backend_entity = lambda model: backend_entity

         mocked_transformer = mocker.MagicMock()
+        empty_shapes = [
+            (node, port, ()) for node, port in algo._get_nodes_to_smooth_data(graph, algo._get_alpha_map().keys())
+        ]
         mocker.patch("nncf.common.factory.ModelTransformerFactory.create", return_value=mocked_transformer)
+        mocker.patch(
+            "nncf.quantization.algorithms.smooth_quant.algorithm.SmoothQuant._retrieve_shape", return_value=empty_shapes
+        )
         algo.apply(model, graph, algo_statistic_points)

         mocked_transformer.transform.assert_called_once()
@@ -316,3 +300,8 @@ def test_get_weight_channel_axis(self, node_metatype, layer_attributes, referenc
             pytest.xfail("Expected exception")

         assert activation_channel_axis == reference_value
+
+    def test_reducers_cls(self):
+        backend = self.get_backend()
+        assert backend.get_abs_max_reducer_cls() is AbsMaxReducer
+        assert backend.get_shape_reducer_cls() is ShapeReducer
diff --git a/tests/onnx/common.py b/tests/onnx/common.py
index 50cd89176db..f2d47eefebd 100644
--- a/tests/onnx/common.py
+++ b/tests/onnx/common.py
@@ -34,6 +34,31 @@ def __init__(self):
         self._outputs = []
         self._graph_name = "onnx-graph"

+    def add_shape(self, data: str, output: Optional[str] = None) -> str:
+        i = len(self._nodes)
+
+        output = f"Shape_{i}_output" if output is None else output
+        self._nodes.append(onnx.helper.make_node(op_type="Shape", inputs=[data], outputs=[output], name=f"Shape_{i}"))
+        return output
+
+    def add_gather(self, data: str, indices: str, axis: int = 0, output: Optional[str] = None) -> str:
+        i = len(self._nodes)
+
+        output = f"Gather_{i}_output" if output is None else output
+        self._nodes.append(
+            onnx.helper.make_node(
+                op_type="Gather", inputs=[data, indices], outputs=[output], axis=axis, name=f"Gather_{i}"
+            )
+        )
+        return output
+
+    def add_reshape(self, data: str, shape: str, output: Optional[str] = None) -> str:
+        i = len(self._nodes)
+
+        output = f"Reshape_{i}_output" if output is None else output
+        self._nodes.append(onnx.helper.make_node("Reshape", inputs=[data, shape], outputs=[output], name=f"Reshape_{i}"))
+        return output
+
     def add_input(self, name: str, shape: tuple[int]) -> str:
         self._inputs.append(onnx.helper.make_tensor_value_info(name, onnx.TensorProto.FLOAT, shape))
         return name
@@ -63,6 +88,17 @@ def add_matmul(
         )
         return output

+    def add_initializer(self, data: np.ndarray) -> str:
+        i = len(self._initializers)
+
+        name = f"Initializer_{i}"
+        tensor_dtype = onnx.helper.np_dtype_to_tensor_dtype(data.dtype)
+        initializer = onnx.helper.make_tensor(
+            name=name, data_type=tensor_dtype, dims=data.shape, vals=data.tobytes(), raw=True
+        )
+        self._initializers.append(initializer)
+        return name
+
     def add_gemm(
         self,
         input: str,
@@ -133,10 +169,33 @@ def add_selu(self, input: str, output: Optional[str] = None) -> str:
         self._nodes.append(onnx.helper.make_node(op_type="Selu", inputs=[input], outputs=[output], name=f"Selu_{i}"))
         return output

+    def add_constant(self, data: np.ndarray, output: Optional[str] = None) -> str:
+        i = len(self._nodes)
+
+        output = f"Constant_{i}_output" if output is None else output
+
+        tensor_dtype = onnx.helper.np_dtype_to_tensor_dtype(data.dtype)
+
+        self._nodes.append(
+            onnx.helper.make_node(
+                "Constant",
+                inputs=[],
+                outputs=[output],
+                value=onnx.helper.make_tensor(
+                    name=f"Constant_{i}",
+                    data_type=tensor_dtype,
+                    dims=data.shape,
+                    vals=data.flatten(),
+                ),
+            )
+        )
+
+        return output
+
     def add_unsqueeze(self, input: str, axes: tuple[int, ...], output: Optional[str] = None) -> str:
         i = len(self._nodes)

-        axes_name = "Unsqueeze_{i}_axes"
+        axes_name = f"Unsqueeze_{i}_axes"
         axes_data = np.array(axes, dtype=np.int64)
         axes_initializer = onnx.helper.make_tensor(
             name=axes_name,
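The new `ModelBuilder` helpers chain by returning output tensor names. A hedged sketch of a tiny data-dependent-reshape model (`add_concat`, `add_output`, and `build` are pre-existing helpers whose signatures are assumed from their use in `test_unknown_shape` below):

```python
import numpy as np

from tests.onnx.common import ModelBuilder

mb = ModelBuilder()
x = mb.add_input("x", ("batch", 3, 4))
# Shape -> Gather -> Unsqueeze extracts the symbolic batch dimension.
batch = mb.add_gather(mb.add_shape(x), mb.add_initializer(np.array(0, dtype=np.int64)))
batch = mb.add_unsqueeze(batch, axes=[0])
# Concatenate [batch] with a constant [12] to form the Reshape target.
target = mb.add_concat([batch, mb.add_initializer(np.array([12], dtype=np.int64))], axis=0)
y = mb.add_reshape(x, target)
mb.add_output(y, ("batch", 12))
model = mb.build(13, ir_version=9)  # same call pattern as test_unknown_shape
```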
diff --git a/tests/onnx/test_nncf_graph_builder.py b/tests/onnx/test_nncf_graph_builder.py
index d81719f73f3..24314e56fe8 100644
--- a/tests/onnx/test_nncf_graph_builder.py
+++ b/tests/onnx/test_nncf_graph_builder.py
@@ -11,14 +11,17 @@

 import os

+import numpy as np
 import onnx
 import pytest
 import torch

+from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXMatMulMetatype
 from nncf.onnx.graph.model_transformer import ONNXModelTransformer
 from nncf.onnx.graph.nncf_graph_builder import GraphConverter
 from tests.cross_fw.shared.nx_graph import compare_nx_graph_with_reference
 from tests.cross_fw.shared.paths import TEST_ROOT
+from tests.onnx.common import ModelBuilder
 from tests.onnx.conftest import ONNX_TEST_ROOT
 from tests.onnx.models import ALL_SYNTHETIC_MODELS
 from tests.onnx.models import OneConvolutionalModel
@@ -112,3 +115,28 @@ def test_add_output_nodes_with_no_parents_node():
     nx_graph = nncf_graph.get_graph_for_structure_analysis(extended=True)
     path_to_dot = REFERENCE_GRAPHS_DIR / "synthetic" / "output_with_no_parents_model.dot"
     compare_nx_graph_with_reference(nx_graph, path_to_dot, check_edge_attrs=True)
+
+
+@pytest.mark.parametrize("opset_version, ref_shape", [[13, ()], [19, (-1, -1, -1)]])
+def test_unknown_shape(opset_version: int, ref_shape: tuple[int, ...]):
+    mb = ModelBuilder()
+
+    x = mb.add_input("x", ("batch", 3, 4, 5))
+
+    y = mb.add_shape(x)
+    y = mb.add_gather(y, mb.add_initializer(np.array(0, dtype=np.int64)))
+    y = mb.add_unsqueeze(y, axes=[0])
+    y = mb.add_concat([y, mb.add_initializer(np.array([-1, 60], dtype=np.int64))], axis=0)
+
+    x = mb.add_reshape(x, y)
+    x = mb.add_matmul(x, (60, 10))
+
+    mb.add_output(x, ("batch", 1, 10))
+
+    model = mb.build(opset_version, ir_version=9)
+
+    graph = GraphConverter.create_nncf_graph(model)
+    matmul = graph.get_nodes_by_metatypes([ONNXMatMulMetatype])[0]  # only 1 matmul
+
+    for e in graph.get_input_edges(matmul):
+        assert e.tensor_shape == ref_shape
diff --git a/tests/openvino/native/quantization/test_reducers_and_aggregators.py b/tests/openvino/native/quantization/test_reducers_and_aggregators.py
index 1f0d5a65e9d..a47773ff82f 100644
--- a/tests/openvino/native/quantization/test_reducers_and_aggregators.py
+++ b/tests/openvino/native/quantization/test_reducers_and_aggregators.py
@@ -81,7 +81,7 @@ def test_mixed_precision_reducers(self, reducer_cls, reduction_axes, ref_value,
         input_ = np.arange(2 * 4 * 8).reshape(2, 4, 8)
         input_[:, :2] *= 2

-        reducer = reducer_cls(reduction_axes=reduction_axes, inplace=inplace)
+        reducer = reducer_cls(axes=reduction_axes, inplace=inplace)

         inplace_fn = reducer.get_inplace_fn()

         ov_model_input = opset.parameter(input_.shape)
diff --git a/tests/openvino/native/test_smooth_quant.py b/tests/openvino/native/test_smooth_quant.py
index 60780122084..a5c5290965e 100644
--- a/tests/openvino/native/test_smooth_quant.py
+++ b/tests/openvino/native/test_smooth_quant.py
@@ -21,6 +21,8 @@
 from nncf.openvino.graph.layout import OVLayoutElem
 from nncf.openvino.graph.metatypes.openvino_metatypes import OVConvolutionMetatype
 from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype
+from nncf.openvino.statistics.collectors import OVAbsMaxReducer
+from nncf.openvino.statistics.collectors import OVShapeReducer
 from nncf.quantization.algorithms.smooth_quant.openvino_backend import OVSmoothQuantAlgoBackend
 from tests.cross_fw.test_templates.helpers import ConvTestModel
 from tests.cross_fw.test_templates.helpers import LinearMultiShapeModel
@@ -182,3 +184,8 @@ def test_get_weight_channel_axis(self, node_metatype, weights_layout, reference_
     @staticmethod
     def get_matmul_metatype():
         return [OVMatMulMetatype]
+
+    def test_reducers_cls(self):
+        backend = self.get_backend()
+        assert backend.get_abs_max_reducer_cls() is OVAbsMaxReducer
+        assert backend.get_shape_reducer_cls() is OVShapeReducer
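Finally, a note on why `_axes_mode` joined `__eq__` and `__hash__` in the very first hunk: reducers that differ only in axes interpretation must not be merged by the statistics aggregator (the behavior `test_same_collectors_different_attrs_dont_merge` guards). A sketch, assuming this branch of NNCF:

```python
from nncf.experimental.common.tensor_statistics.collectors import AxesMode
from nncf.experimental.common.tensor_statistics.collectors import MaxReducer

reduce_mode = MaxReducer(axes=(0, 2, 3), axes_mode=AxesMode.REDUCTION)
keep_mode = MaxReducer(axes=(0, 2, 3), axes_mode=AxesMode.KEEP)

# Same axes, different interpretation: these must stay distinct so their
# statistic branches are never merged.
assert reduce_mode != keep_mode
assert hash(reduce_mode) != hash(keep_mode)

# Identical configuration still compares and hashes equal, so genuine
# duplicates continue to be deduplicated.
same = MaxReducer(axes=(0, 2, 3), axes_mode=AxesMode.REDUCTION)
assert reduce_mode == same and hash(reduce_mode) == hash(same)
```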