Skip to content

Commit 0379ea1

Browse files
[ONNX][SmoothQuant] Introduce new axes and axes_mode parameters (#3687)
### Changes

- This PR introduces new `axes` and `axes_mode` parameters for `TensorReducerBase`. These parameters have the following meaning:
  - `axes`: The axes along which the reduction operation should be applied. If `None`, the operation will be applied to all axes (i.e., `tuple(range(tensor.ndim))`).
  - `axes_mode`: Determines how the specified `axes` are treated during the operation. Use `AxesMode.REDUCTION` to reduce over the given axes, or `AxesMode.KEEP` to preserve them.

  These parameters are used to calculate the reduction axes (via the `determine_reduction_axes()` function) during statistic collection, which lets us avoid requiring the actual tensor shape before inference (only the number of dimensions, `ndim`, is required).
- Modifies the `SmoothQuant` algorithm to use the `axes` and `axes_mode` parameters for the ONNX backend instead of relying on the tensor shape from the NNCF graph, as this shape isn't always available.

### Related tickets

Ref: 173880, Ref: 174334

### Tests

- Build post_training_quantization # 735 (# 739)
- tests/onnx/test_nncf_graph_builder.py::test_unknown_shape
1 parent 5663afd commit 0379ea1

File tree

21 files changed

+410
-204
lines changed

21 files changed

+410
-204
lines changed

src/nncf/experimental/common/tensor_statistics/collectors.py

Lines changed: 78 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from collections import defaultdict
1818
from collections import deque
1919
from copy import deepcopy
20+
from enum import Enum
2021
from typing import Any, Optional, TypeVar, Union
2122

2223
import nncf
@@ -35,6 +36,50 @@
3536

3637
InplaceInsertionFNType = TypeVar("InplaceInsertionFNType")
3738
AggregationAxes = tuple[int, ...]
39+
Axes = tuple[int, ...]
40+
41+
42+
class AxesMode(Enum):
    """
    Strategy that tells a reduction how to interpret the axes it was given.

    :param REDUCTION: The listed axes are the ones the operation reduces over.
    :param KEEP: The listed axes are preserved; the operation reduces over
        every other axis.
    """

    REDUCTION = "reduction"
    KEEP = "keep"
53+
54+
55+
def determine_reduction_axes(
    ndim: int, axes: Optional[Axes] = None, axes_mode: AxesMode = AxesMode.REDUCTION
) -> ReductionAxes:
    """
    Resolves the axes that a reduction operation should reduce over for a tensor
    with `ndim` dimensions.

    :param ndim: The number of dimensions in the input tensor.
    :param axes: The axes specified for the reduction operation. If `None`, all axes
        are reduced (i.e., `tuple(range(ndim))`) regardless of `axes_mode`.
    :param axes_mode: Defines how the specified axes are interpreted:
        - `AxesMode.REDUCTION`: the given axes are returned unchanged (negative
          indices are not normalized).
        - `AxesMode.KEEP`: all axes except the specified ones are reduced. Negative
          indices are accepted. For tensors with fewer than two dimensions nothing
          is reduced and `axes` is not inspected.
    :return: The resolved axes along which the reduction operation should be performed.
        In `AxesMode.KEEP` mode they are non-negative and in ascending order.
    :raises IndexError: In `AxesMode.KEEP` mode, if `ndim > 1` and an entry of `axes`
        is out of range for a tensor with `ndim` dimensions.
    """
    if axes is None:
        return tuple(range(ndim))

    if axes_mode == AxesMode.REDUCTION:
        return axes

    # Tensors with fewer than two dimensions have nothing to reduce in KEEP mode.
    if ndim <= 1:
        return ()

    all_axes = tuple(range(ndim))
    # Indexing into `all_axes` maps negative indices to their non-negative equivalents.
    keep_axes = {all_axes[i] for i in axes}
    # Filter in positional order so the result does not depend on set iteration order.
    return tuple(axis for axis in all_axes if axis not in keep_axes)
3883

3984

4085
class TensorReducerBase(ABC):
@@ -43,13 +88,21 @@ class TensorReducerBase(ABC):
4388
the specified rule. Could handle tensors inplace or out of place.
4489
"""
4590

46-
def __init__(self, reduction_axes: Optional[ReductionAxes] = None, inplace: bool = False):
91+
def __init__(
92+
self,
93+
axes: Optional[Axes] = None,
94+
axes_mode: AxesMode = AxesMode.REDUCTION,
95+
inplace: bool = False,
96+
):
4797
"""
48-
:param reduction_axes: Reduction axes for reduction calculation. Equal to list(range(len(input.shape)))
49-
if empty.
98+
:param axes: The axes along which the reduction operation should be applied.
99+
If `None`, the operation will be applied to all axes (i.e., `tuple(range(tensor.ndim))`).
100+
:param axes_mode: Determines how the specified `axes` are treated during the operation.
101+
Use `AxesMode.REDUCTION` to reduce over the given axes, or `AxesMode.KEEP` to preserve them.
50102
:param inplace: Whether should be calculated inplace or out of place.
51103
"""
52-
self._reduction_axes = reduction_axes
104+
self._axes = axes
105+
self._axes_mode = axes_mode
53106
self._inplace = inplace
54107
self._keepdims = True
55108

@@ -97,17 +150,13 @@ def __call__(self, x: list[Tensor]):
97150
def __eq__(self, __o: object) -> bool:
98151
return (
99152
isinstance(__o, self.__class__)
100-
and self._reduction_axes == __o._reduction_axes
153+
and self._axes == __o._axes
154+
and self._axes_mode == __o._axes_mode
101155
and self._inplace == __o.inplace
102156
)
103157

104158
def __hash__(self) -> int:
105-
return hash((self.__class__.__name__, self.inplace, self._reduction_axes))
106-
107-
def _get_reduction_axes(self, tensor: Tensor) -> ReductionAxes:
108-
if self._reduction_axes is not None:
109-
return self._reduction_axes
110-
return tuple(range(len(tensor.shape)))
159+
return hash((self.__class__.__name__, self.inplace, self._axes, self._axes_mode))
111160

112161

113162
class AggregatorBase:
@@ -444,92 +493,94 @@ def get_inplace_fn(self) -> Optional[InplaceInsertionFNType]:
444493
class MinReducer(TensorReducerBase):
445494
def _reduce_out_of_place(self, x: list[Tensor]) -> list[Tensor]:
446495
x = x[0]
447-
reduction_axes = self._get_reduction_axes(x)
496+
reduction_axes = determine_reduction_axes(x.ndim, self._axes, self._axes_mode)
448497
return [fns.min(x, reduction_axes, keepdims=self._keepdims)]
449498

450499

451500
class MaxReducer(TensorReducerBase):
452501
def _reduce_out_of_place(self, x: list[Tensor]) -> list[Tensor]:
453502
x = x[0]
454-
reduction_axes = self._get_reduction_axes(x)
503+
reduction_axes = determine_reduction_axes(x.ndim, self._axes, self._axes_mode)
455504
return [fns.max(x, reduction_axes, keepdims=self._keepdims)]
456505

457506

458507
class AbsMaxReducer(TensorReducerBase):
459508
def _reduce_out_of_place(self, x: list[Tensor]) -> list[Tensor]:
460509
x = fns.abs(x[0])
461-
reduction_axes = self._get_reduction_axes(x)
510+
reduction_axes = determine_reduction_axes(x.ndim, self._axes, self._axes_mode)
462511
return [fns.max(x, reduction_axes, keepdims=self._keepdims)]
463512

464513

465514
class MeanReducer(TensorReducerBase):
466515
def _reduce_out_of_place(self, x: list[Tensor]) -> list[Tensor]:
467516
x = x[0]
468-
reduction_axes = self._get_reduction_axes(x)
517+
reduction_axes = determine_reduction_axes(x.ndim, self._axes, self._axes_mode)
469518
return [fns.mean(x, reduction_axes, keepdims=self._keepdims)]
470519

471520

472521
class MeanVarianceReducer(TensorReducerBase):
473522
def _reduce_out_of_place(self, x: list[Tensor]) -> list[Tensor]:
474523
x = x[0]
475-
reduction_axes = self._get_reduction_axes(x)
524+
reduction_axes = determine_reduction_axes(x.ndim, self._axes, self._axes_mode)
476525
variance = fns.var(x, reduction_axes)
477526
return [fns.mean(variance)]
478527

479528

480529
class MaxVarianceReducer(TensorReducerBase):
481530
def _reduce_out_of_place(self, x: list[Tensor]) -> list[Tensor]:
482531
x = x[0]
483-
reduction_axes = self._get_reduction_axes(x)
532+
reduction_axes = determine_reduction_axes(x.ndim, self._axes, self._axes_mode)
484533
variance = fns.var(x, reduction_axes)
485534
return [fns.max(variance)]
486535

487536

488537
class MeanAbsMaxReducer(TensorReducerBase):
489538
def _reduce_out_of_place(self, x: list[Tensor]) -> list[Tensor]:
490539
x = fns.abs(x[0])
491-
reduction_axes = self._get_reduction_axes(x)
540+
reduction_axes = determine_reduction_axes(x.ndim, self._axes, self._axes_mode)
492541
abs_max = fns.max(x, reduction_axes, keepdims=self._keepdims)
493542
return [fns.mean(abs_max)]
494543

495544

496545
class QuantileReducerBase(TensorReducerBase):
497546
def __init__(
498547
self,
499-
reduction_axes: Optional[ReductionAxes] = None,
548+
axes: Optional[Axes] = None,
549+
axes_mode: AxesMode = AxesMode.REDUCTION,
500550
quantile: Optional[Union[float, tuple[float]]] = None,
501551
inplace: bool = False,
502552
):
503-
super().__init__(reduction_axes=reduction_axes, inplace=False)
553+
super().__init__(axes, axes_mode, False)
504554
self._quantile = (0.01, 0.99) if quantile is None else quantile
505555

506556
def __eq__(self, __o: object) -> bool:
507557
return super().__eq__(__o) and self._quantile == __o._quantile
508558

509559
def __hash__(self) -> int:
510-
return hash((self.__class__.__name__, self.inplace, self._reduction_axes, tuple(self._quantile)))
560+
return hash((self.__class__.__name__, self.inplace, self._axes, self._axes_mode, tuple(self._quantile)))
511561

512562

513563
class QuantileReducer(QuantileReducerBase):
514564
def _reduce_out_of_place(self, x: list[Tensor]) -> list[Tensor]:
515565
x = x[0]
516-
reduction_axes = self._get_reduction_axes(x)
566+
reduction_axes = determine_reduction_axes(x.ndim, self._axes, self._axes_mode)
517567
return fns.quantile(x, self._quantile, reduction_axes, keepdims=self._keepdims)
518568

519569

520570
class AbsQuantileReducer(QuantileReducerBase):
521571
def __init__(
522572
self,
523-
reduction_axes: Optional[ReductionAxes] = None,
524-
quantile: Optional[Union[float, list[float]]] = None,
573+
axes: Optional[Axes] = None,
574+
axes_mode: AxesMode = AxesMode.REDUCTION,
575+
quantile: Optional[Union[float, tuple[float]]] = None,
525576
inplace: bool = False,
526577
):
527578
quantile = (0.99,) if quantile is None else quantile
528-
super().__init__(reduction_axes=reduction_axes, quantile=quantile, inplace=False)
579+
super().__init__(axes, axes_mode, quantile)
529580

530581
def _reduce_out_of_place(self, x: list[Tensor]) -> list[Tensor]:
531582
x = fns.abs(x[0])
532-
reduction_axes = self._get_reduction_axes(x)
583+
reduction_axes = determine_reduction_axes(x.ndim, self._axes, self._axes_mode)
533584
return fns.quantile(x, self._quantile, reduction_axes, keepdims=self._keepdims)
534585

535586

@@ -553,7 +604,7 @@ def __eq__(self, __o: object) -> bool:
553604
return super().__eq__(__o) and self._channel_axis == __o._channel_axis
554605

555606
def __hash__(self) -> int:
556-
return hash((self.__class__.__name__, self.inplace, self._reduction_axes, self._channel_axis))
607+
return hash((self.__class__.__name__, self.inplace, self._axes, self._axes_mode, self._channel_axis))
557608

558609

559610
##################################################

src/nncf/openvino/statistics/collectors.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,37 +44,37 @@
4444

4545
class OVMinReducer(MinReducer):
4646
def get_inplace_fn(self):
47-
return get_inplace_min_op(self._reduction_axes)
47+
return get_inplace_min_op(self._axes)
4848

4949

5050
class OVMaxReducer(MaxReducer):
5151
def get_inplace_fn(self):
52-
return get_inplace_max_op(self._reduction_axes, False)
52+
return get_inplace_max_op(self._axes, False)
5353

5454

5555
class OVAbsMaxReducer(AbsMaxReducer):
5656
def get_inplace_fn(self):
57-
return get_inplace_max_op(self._reduction_axes, True)
57+
return get_inplace_max_op(self._axes, True)
5858

5959

6060
class OVMeanReducer(MeanReducer):
6161
def get_inplace_fn(self):
62-
return get_inplace_mean_op(self._reduction_axes)
62+
return get_inplace_mean_op(self._axes)
6363

6464

6565
class OVMeanVarianceReducer(MeanVarianceReducer):
6666
def get_inplace_fn(self):
67-
return get_inplace_mean_var_op(self._reduction_axes)
67+
return get_inplace_mean_var_op(self._axes)
6868

6969

7070
class OVMaxVarianceReducer(MaxVarianceReducer):
7171
def get_inplace_fn(self):
72-
return get_inplace_max_var_op(self._reduction_axes)
72+
return get_inplace_max_var_op(self._axes)
7373

7474

7575
class OVMeanAbsMaxReducer(MeanAbsMaxReducer):
7676
def get_inplace_fn(self):
77-
return get_inplace_mean_max_op(self._reduction_axes, True)
77+
return get_inplace_mean_max_op(self._axes, True)
7878

7979

8080
class OVShapeReducer(ShapeReducer):

src/nncf/quantization/algorithms/channel_alignment/openvino_backend.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from nncf.common.graph.layer_attributes import ConvolutionLayerAttributes
2121
from nncf.common.graph.transformations.commands import TargetType
2222
from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase
23+
from nncf.experimental.common.tensor_statistics.collectors import AxesMode
2324
from nncf.experimental.common.tensor_statistics.collectors import MedianAggregator
2425
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
2526
from nncf.experimental.common.tensor_statistics.statistics import MinMaxTensorStatistic
@@ -81,7 +82,7 @@ def get_statistic_collector(
8182
reduction_axes, q: float, num_samples: int, inplace: bool
8283
) -> TensorStatisticCollectorBase:
8384
tensor_collector = TensorCollector(MinMaxTensorStatistic)
84-
quantile_reducer = OVQuantileReducer(reduction_axes, (q, 1 - q), inplace)
85+
quantile_reducer = OVQuantileReducer(reduction_axes, AxesMode.REDUCTION, (q, 1 - q), inplace)
8586

8687
for port_id, container_key in enumerate([MinMaxTensorStatistic.MIN_STAT, MinMaxTensorStatistic.MAX_STAT]):
8788
aggregator = MedianAggregator(num_samples=num_samples, aggregation_axes=(0, 1))

src/nncf/quantization/algorithms/min_max/algorithm.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -570,14 +570,12 @@ def _get_statistic_collector(
570570
else:
571571
quantile = 1 - params.quantile_outlier_prob
572572
reducer = self._backend_entity.reducer_map[statistic_type](
573-
reduction_axes=reduction_axes, inplace=inplace, quantile=[quantile]
573+
axes=reduction_axes, inplace=inplace, quantile=[quantile]
574574
)
575575
else:
576576
if use_abs_max and statistic_type == StatisticsType.MAX:
577577
statistic_type = StatisticsType.ABS_MAX
578-
reducer = self._backend_entity.reducer_map[statistic_type](
579-
reduction_axes=reduction_axes, inplace=inplace
580-
)
578+
reducer = self._backend_entity.reducer_map[statistic_type](axes=reduction_axes, inplace=inplace)
581579

582580
kwargs = {
583581
"num_samples": num_samples,

0 commit comments

Comments
 (0)