1717from collections import defaultdict
1818from collections import deque
1919from copy import deepcopy
20+ from enum import Enum
2021from typing import Any , Optional , TypeVar , Union
2122
2223import nncf
3536
3637InplaceInsertionFNType = TypeVar ("InplaceInsertionFNType" )
3738AggregationAxes = tuple [int , ...]
39+ Axes = tuple [int , ...]
40+
41+
42+ class AxesMode (Enum ):
43+ """
44+ Represents different strategies for handling tensor axes.
45+
46+ :param REDUCTION: Indicates that the specified axes should be reduced during an operation.
47+ :param KEEP: Indicates that the specified axes should be preserved and not reduced during
48+ an operation.
49+ """
50+
51+ REDUCTION = "reduction"
52+ KEEP = "keep"
53+
54+
55+ def determine_reduction_axes (
56+ ndim : int , axes : Optional [Axes ] = None , axes_mode : AxesMode = AxesMode .REDUCTION
57+ ) -> ReductionAxes :
58+ """
59+ Determines the set of axes along which a reduction operation should be performed
60+ based on the specified axes mode.
61+
62+ :param ndim: The number of dimensions in the input tensor.
63+ :param axes: The axes specified for the reduction operation. If `None`, all axes
64+ are considered (i.e., `tuple(range(ndim))`).
65+
66+ :param axes_mode: Defines how the specified axes are interpreted:
67+ - `AxesMode.REDUCTION`: the given axes will be reduced.
68+ - `AxesMode.KEEP`: all axes except the specified ones will be reduced.
69+ :return: The resolved set of axes along which the reduction operation should be performed.
70+ """
71+ if axes is None :
72+ return tuple (range (ndim ))
73+
74+ if axes_mode == AxesMode .REDUCTION :
75+ return axes
76+
77+ all_axes = tuple (range (ndim ))
78+ if len (all_axes ) > 1 :
79+ # Ensure that all axes have positive values
80+ keep_axes = tuple (all_axes [i ] for i in axes )
81+ return tuple (set (all_axes ) - set (keep_axes ))
82+ return ()
3883
3984
4085class TensorReducerBase (ABC ):
@@ -43,13 +88,21 @@ class TensorReducerBase(ABC):
4388 the specified rule. Could handle tensors inplace or out of place.
4489 """
4590
46- def __init__ (self , reduction_axes : Optional [ReductionAxes ] = None , inplace : bool = False ):
91+ def __init__ (
92+ self ,
93+ axes : Optional [Axes ] = None ,
94+ axes_mode : AxesMode = AxesMode .REDUCTION ,
95+ inplace : bool = False ,
96+ ):
4797 """
48- :param reduction_axes: Reduction axes for reduction calculation. Equal to list(range(len(input.shape)))
49- if empty.
98+ :param axes: The axes along which the reduction operation should be applied.
99+ If `None`, the operation will be applied to all axes (i.e., `tuple(range(tensor.ndim))`).
100+ :param axes_mode: Determines how the specified `axes` are treated during the operation.
101+ Use `AxesMode.REDUCTION` to reduce over the given axes, or `AxesMode.KEEP` to preserve them.
50102 :param inplace: Whether should be calculated inplace or out of place.
51103 """
52- self ._reduction_axes = reduction_axes
104+ self ._axes = axes
105+ self ._axes_mode = axes_mode
53106 self ._inplace = inplace
54107 self ._keepdims = True
55108
@@ -97,17 +150,13 @@ def __call__(self, x: list[Tensor]):
97150 def __eq__ (self , __o : object ) -> bool :
98151 return (
99152 isinstance (__o , self .__class__ )
100- and self ._reduction_axes == __o ._reduction_axes
153+ and self ._axes == __o ._axes
154+ and self ._axes_mode == __o ._axes_mode
101155 and self ._inplace == __o .inplace
102156 )
103157
104158 def __hash__ (self ) -> int :
105- return hash ((self .__class__ .__name__ , self .inplace , self ._reduction_axes ))
106-
107- def _get_reduction_axes (self , tensor : Tensor ) -> ReductionAxes :
108- if self ._reduction_axes is not None :
109- return self ._reduction_axes
110- return tuple (range (len (tensor .shape )))
159+ return hash ((self .__class__ .__name__ , self .inplace , self ._axes , self ._axes_mode ))
111160
112161
113162class AggregatorBase :
@@ -444,92 +493,94 @@ def get_inplace_fn(self) -> Optional[InplaceInsertionFNType]:
444493class MinReducer (TensorReducerBase ):
445494 def _reduce_out_of_place (self , x : list [Tensor ]) -> list [Tensor ]:
446495 x = x [0 ]
447- reduction_axes = self . _get_reduction_axes ( x )
496+ reduction_axes = determine_reduction_axes ( x . ndim , self . _axes , self . _axes_mode )
448497 return [fns .min (x , reduction_axes , keepdims = self ._keepdims )]
449498
450499
451500class MaxReducer (TensorReducerBase ):
452501 def _reduce_out_of_place (self , x : list [Tensor ]) -> list [Tensor ]:
453502 x = x [0 ]
454- reduction_axes = self . _get_reduction_axes ( x )
503+ reduction_axes = determine_reduction_axes ( x . ndim , self . _axes , self . _axes_mode )
455504 return [fns .max (x , reduction_axes , keepdims = self ._keepdims )]
456505
457506
458507class AbsMaxReducer (TensorReducerBase ):
459508 def _reduce_out_of_place (self , x : list [Tensor ]) -> list [Tensor ]:
460509 x = fns .abs (x [0 ])
461- reduction_axes = self . _get_reduction_axes ( x )
510+ reduction_axes = determine_reduction_axes ( x . ndim , self . _axes , self . _axes_mode )
462511 return [fns .max (x , reduction_axes , keepdims = self ._keepdims )]
463512
464513
465514class MeanReducer (TensorReducerBase ):
466515 def _reduce_out_of_place (self , x : list [Tensor ]) -> list [Tensor ]:
467516 x = x [0 ]
468- reduction_axes = self . _get_reduction_axes ( x )
517+ reduction_axes = determine_reduction_axes ( x . ndim , self . _axes , self . _axes_mode )
469518 return [fns .mean (x , reduction_axes , keepdims = self ._keepdims )]
470519
471520
472521class MeanVarianceReducer (TensorReducerBase ):
473522 def _reduce_out_of_place (self , x : list [Tensor ]) -> list [Tensor ]:
474523 x = x [0 ]
475- reduction_axes = self . _get_reduction_axes ( x )
524+ reduction_axes = determine_reduction_axes ( x . ndim , self . _axes , self . _axes_mode )
476525 variance = fns .var (x , reduction_axes )
477526 return [fns .mean (variance )]
478527
479528
480529class MaxVarianceReducer (TensorReducerBase ):
481530 def _reduce_out_of_place (self , x : list [Tensor ]) -> list [Tensor ]:
482531 x = x [0 ]
483- reduction_axes = self . _get_reduction_axes ( x )
532+ reduction_axes = determine_reduction_axes ( x . ndim , self . _axes , self . _axes_mode )
484533 variance = fns .var (x , reduction_axes )
485534 return [fns .max (variance )]
486535
487536
488537class MeanAbsMaxReducer (TensorReducerBase ):
489538 def _reduce_out_of_place (self , x : list [Tensor ]) -> list [Tensor ]:
490539 x = fns .abs (x [0 ])
491- reduction_axes = self . _get_reduction_axes ( x )
540+ reduction_axes = determine_reduction_axes ( x . ndim , self . _axes , self . _axes_mode )
492541 abs_max = fns .max (x , reduction_axes , keepdims = self ._keepdims )
493542 return [fns .mean (abs_max )]
494543
495544
496545class QuantileReducerBase (TensorReducerBase ):
497546 def __init__ (
498547 self ,
499- reduction_axes : Optional [ReductionAxes ] = None ,
548+ axes : Optional [Axes ] = None ,
549+ axes_mode : AxesMode = AxesMode .REDUCTION ,
500550 quantile : Optional [Union [float , tuple [float ]]] = None ,
501551 inplace : bool = False ,
502552 ):
503- super ().__init__ (reduction_axes = reduction_axes , inplace = False )
553+ super ().__init__ (axes , axes_mode , False )
504554 self ._quantile = (0.01 , 0.99 ) if quantile is None else quantile
505555
506556 def __eq__ (self , __o : object ) -> bool :
507557 return super ().__eq__ (__o ) and self ._quantile == __o ._quantile
508558
509559 def __hash__ (self ) -> int :
510- return hash ((self .__class__ .__name__ , self .inplace , self ._reduction_axes , tuple (self ._quantile )))
560+ return hash ((self .__class__ .__name__ , self .inplace , self ._axes , self . _axes_mode , tuple (self ._quantile )))
511561
512562
513563class QuantileReducer (QuantileReducerBase ):
514564 def _reduce_out_of_place (self , x : list [Tensor ]) -> list [Tensor ]:
515565 x = x [0 ]
516- reduction_axes = self . _get_reduction_axes ( x )
566+ reduction_axes = determine_reduction_axes ( x . ndim , self . _axes , self . _axes_mode )
517567 return fns .quantile (x , self ._quantile , reduction_axes , keepdims = self ._keepdims )
518568
519569
520570class AbsQuantileReducer (QuantileReducerBase ):
521571 def __init__ (
522572 self ,
523- reduction_axes : Optional [ReductionAxes ] = None ,
524- quantile : Optional [Union [float , list [float ]]] = None ,
573+ axes : Optional [Axes ] = None ,
574+ axes_mode : AxesMode = AxesMode .REDUCTION ,
575+ quantile : Optional [Union [float , tuple [float ]]] = None ,
525576 inplace : bool = False ,
526577 ):
527578 quantile = (0.99 ,) if quantile is None else quantile
528- super ().__init__ (reduction_axes = reduction_axes , quantile = quantile , inplace = False )
579+ super ().__init__ (axes , axes_mode , quantile )
529580
530581 def _reduce_out_of_place (self , x : list [Tensor ]) -> list [Tensor ]:
531582 x = fns .abs (x [0 ])
532- reduction_axes = self . _get_reduction_axes ( x )
583+ reduction_axes = determine_reduction_axes ( x . ndim , self . _axes , self . _axes_mode )
533584 return fns .quantile (x , self ._quantile , reduction_axes , keepdims = self ._keepdims )
534585
535586
@@ -553,7 +604,7 @@ def __eq__(self, __o: object) -> bool:
553604 return super ().__eq__ (__o ) and self ._channel_axis == __o ._channel_axis
554605
555606 def __hash__ (self ) -> int :
556- return hash ((self .__class__ .__name__ , self .inplace , self ._reduction_axes , self ._channel_axis ))
607+ return hash ((self .__class__ .__name__ , self .inplace , self ._axes , self . _axes_mode , self ._channel_axis ))
557608
558609
559610##################################################
0 commit comments