26 changes: 24 additions & 2 deletions src/nncf/experimental/common/tensor_statistics/collectors.py
@@ -43,13 +43,25 @@ class TensorReducerBase(ABC):
the specified rule. Could handle tensors inplace or out of place.
"""

def __init__(self, reduction_axes: Optional[ReductionAxes] = None, inplace: bool = False):
def __init__(
self,
reduction_axes: Optional[ReductionAxes] = None,
Collaborator: Should we forward this parameter in the children of the TensorReducerBase?

Contributor Author: Done
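For context, a minimal sketch of what "Done" looks like in a child reducer (the exact child signatures are assumed, not taken from the diff):

    class AbsMaxReducer(TensorReducerBase):
        def __init__(
            self,
            reduction_axes: Optional[ReductionAxes] = None,
            keep_axes: Optional[tuple[int, ...]] = None,
            inplace: bool = False,
        ):
            # Forward the new parameter to the base initializer unchanged.
            super().__init__(reduction_axes=reduction_axes, keep_axes=keep_axes, inplace=inplace)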

keep_axes: Optional[tuple[int, ...]] = None,
Collaborator: Suggested change:

-        keep_axes: Optional[tuple[int, ...]] = None,
+        keep_axes: Optional[Axes] = None,

Perhaps we could rename ReductionAxes and reuse it there?

Contributor Author: Done
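A sketch of the rename the suggestion implies, assuming ReductionAxes is currently an alias of tuple[int, ...]:

    Axes = tuple[int, ...]
    ReductionAxes = Axes  # kept as an alias so existing annotations keep working (assumption)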

inplace: bool = False,
):
"""
:param reduction_axes: Reduction axes for reduction calculation. Defaults to tuple(range(len(input.shape)))
if not specified.
:param keep_axes: Axes to preserve during the reduction operation. These will be used in
`_reduce_out_of_place()` to calculate the reduction axes once the tensor shape is known.
:param inplace: Whether the reduction should be calculated inplace or out of place.
"""
if reduction_axes is not None and keep_axes is not None:
msg = "Only one of `reduction_axes` or `keep_axes` should be specified, not both."
raise nncf.ValidationError(msg)

self._reduction_axes = reduction_axes
self._keep_axes = keep_axes
self._inplace = inplace
self._keepdims = True

@@ -99,14 +111,24 @@ def __eq__(self, __o: object) -> bool:
isinstance(__o, self.__class__)
and self._reduction_axes == __o._reduction_axes
and self._inplace == __o.inplace
and self._keep_axes == __o._keep_axes
)

def __hash__(self) -> int:
return hash((self.__class__.__name__, self.inplace, self._reduction_axes))
return hash((self.__class__.__name__, self.inplace, self._reduction_axes, self._keep_axes))
Collaborator: Perhaps we should update the __hash__ methods for some of the TensorReducerBase subclasses as well

Contributor Author: Done
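A sketch of the requested follow-up, using a hypothetical child with extra state:

    class QuantileReducer(TensorReducerBase):  # illustrative child; attributes assumed
        def __hash__(self) -> int:
            # Fold the subclass-specific state into the hash together with the new _keep_axes,
            # so reducers that differ only in their quantile do not collide.
            return hash(
                (self.__class__.__name__, self.inplace, self._reduction_axes, self._keep_axes, self._quantile)
            )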


def _get_reduction_axes(self, tensor: Tensor) -> ReductionAxes:
if self._reduction_axes is not None:
return self._reduction_axes

if self._keep_axes is not None:
axes = list(range(tensor.ndim))
if len(axes) > 1:
# Normalize negative axis indices in `keep_axes` to positive values
keep_axes = tuple(axes[i] for i in self._keep_axes)
return tuple(set(axes) - set(keep_axes))
return ()

return tuple(range(len(tensor.shape)))
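A worked example of the new keep_axes branch (values invented): for a 4-D tensor and keep_axes=(-1,), the negative index is normalized via axes[-1] and everything else is reduced:

    reducer = AbsMaxReducer(keep_axes=(-1,))  # e.g. keep the channel axis
    # For tensor.ndim == 4:
    #   axes      -> [0, 1, 2, 3]
    #   keep_axes -> (axes[-1],) == (3,)
    #   result    -> tuple(set(axes) - {3}) == (0, 1, 2)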


168 changes: 126 additions & 42 deletions src/nncf/quantization/algorithms/smooth_quant/algorithm.py
@@ -28,13 +28,17 @@
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType
from nncf.common.utils.backend import get_backend
from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator
from nncf.experimental.common.tensor_statistics.collectors import NoopAggregator
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.quantization.algorithms.algorithm import Algorithm
from nncf.tensor import Tensor
from nncf.tensor import functions as fns

TModel = TypeVar("TModel")
TTensor = TypeVar("TTensor")
STATISTIC_BRANCH_KEY = "abs_max"
SHAPE_BRANCH_KEY = "shape"
ALPHA_MAP = {"convolution": 0.05, "matmul": 0.95}


@@ -98,6 +102,10 @@ def _set_backend_entity(self, model: TModel) -> None:
msg = f"Cannot return backend-specific entity because {model_backend.value} is not supported!"
raise nncf.UnsupportedBackendError(msg)

# Only the OpenVINO backend supports in-place statistics, so we should set this variable here.
if model_backend != BackendType.OPENVINO:
self._inplace_statistics = False

def apply(
self,
model: TModel,
@@ -108,18 +116,19 @@
self._set_backend_entity(model)
alpha_map = self._get_alpha_map()

nodes_to_smooth_data = self._get_nodes_to_smooth_data(graph, alpha_map.keys())
model_transformer = ModelTransformerFactory.create(model)
transformation_layout = TransformationLayout()

node_groups = self._group_nodes_by_source(nodes_to_smooth_data, graph)
nodes = self._get_nodes_to_smooth_data(graph, alpha_map.keys())
nodes = self._retrieve_shape(nodes, statistic_points)
node_groups = self._group_nodes_by_source(nodes, graph)

transformation_layout = TransformationLayout()
for group_id, nodes in track(node_groups.items(), description="Applying Smooth Quant"):
best_scale = None
best_ratio = 0.0
empty_statistic = False

source_node, input_port_id, source_output_port_id, shape = group_id

for node_to_smooth in nodes:
source_node, input_port_id, source_output_port_id, _ = group_id
activations_value = self._get_statistics_for_node(
statistic_points, node_to_smooth.node_name, input_port_id
)
@@ -168,16 +177,15 @@
)
transformation_layout.register(weight_update_command)

activations_by_output_id = {e.output_port_id: e for e in graph.get_output_edges(source_node)}
activations_shape = activations_by_output_id[source_output_port_id].tensor_shape
activation_scale = self._calculate_activation_scale(best_scale, activations_shape, nodes, graph)
activation_scale = self._calculate_activation_scale(best_scale, shape, nodes, graph)

scale_node_name = self._create_scale_node_name(source_node.node_name, source_output_port_id)
scale_insertion_command = self._backend_entity.scale_insertion_command(
source_node, activation_scale.data, source_output_port_id, nodes, scale_node_name
)
transformation_layout.register(scale_insertion_command)

model_transformer = ModelTransformerFactory.create(model)
transformed_model = model_transformer.transform(transformation_layout)
return transformed_model

@@ -204,27 +212,59 @@ def _calculate_scale_and_ratio(
ratio = scales.min() / (scales.max() + eps)
return scales, ratio
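A quick numeric check of the ratio formula (values invented):

    # scales = [0.5, 2.0], eps ~ 1e-9
    # ratio  = 0.5 / (2.0 + 1e-9) ≈ 0.25; a ratio closer to 1.0 means more uniform per-channel scales.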

def _group_nodes_by_source(self, nodes_to_smooth: list[dict], nncf_graph: NNCFGraph) -> dict[tuple, list]:
def _group_nodes_by_source(
self, nodes_to_smooth: list[tuple[NNCFNode, int, tuple[int, ...]]], nncf_graph: NNCFGraph
) -> dict[tuple, list]:
"""
Groups nodes that will be smoothed by source (parent node).

:param nodes_to_smooth: List of the nodes that will be smoothed.
:param nodes_to_smooth: A list of tuples where each tuple consists of a node, an input port, and the
shape of the tensor associated with that node and input port.
:param nncf_graph: NNCFGraph instance.
:return: Dictionary with the source info as key and grouped nodes as value.
"""
groups = defaultdict(list)
for node_data in nodes_to_smooth:
node_to_smooth = node_data["node_to_smooth"]
input_act_port = node_data["input_act_port"]
for node_to_smooth, input_act_port, shape in nodes_to_smooth:
source_node = nncf_graph.get_input_edge_by_port_id(node_to_smooth, input_act_port).from_node
edge = nncf_graph.get_edge(source_node, node_to_smooth)
# Such a group_id (node, ports, and shape) allows us to be confident
# that all parameters relevant to smoothing are equal across the successor nodes.
group_id = (source_node, input_act_port, edge.output_port_id, hash(str(edge.tensor_shape)))
group_id = (source_node, input_act_port, edge.output_port_id, shape)
groups[group_id].append(node_to_smooth)

return groups

def _retrieve_shape(
self, nodes: list[tuple[NNCFNode, int]], statistic_points: StatisticPointsContainer
) -> list[tuple[NNCFNode, int, tuple[int, ...]]]:
"""
Retrieves the shapes of tensors associated with specific nodes and input ports
from the given statistic points container.

:param nodes: A list of tuples, each containing a node and its corresponding input port index.
:param statistic_points: Container holding statistics, used to retrieve tensor shapes.
:return: A list of tuples where each tuple consists of a node, an input port, and the
shape of the tensor associated with that node and input port. If shape information is
not available, an empty tuple is returned for the shape.
"""
items = []
for node, input_port in nodes:
for tensor_collector in statistic_points.get_algo_statistics_for_node(
node.node_name,
self._backend_entity.get_filter_fn_for_statistics(input_port, self._algorithm_key),
self._algorithm_key,
):
stats = tensor_collector.get_statistics()
shape = stats[SHAPE_BRANCH_KEY]
if shape is not None:
shape = tuple(shape.tolist())
else:
shape = tuple()

items.append((node, input_port, shape))

return items
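An illustrative return value (nodes and shapes invented), showing the empty-tuple fallback:

    # [(matmul_node, 0, (1, 128, 768)),  # shape statistic collected
    #  (conv_node,   1, ())]             # shape statistic unavailable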

def _get_statistics_for_node(
self, statistic_points: StatisticPointsContainer, node_name: str, act_port: int
) -> list[TTensor]:
@@ -247,42 +287,90 @@ def _get_statistics_for_node(
return statistics_for_node

def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer:
statistic_container = StatisticPointsContainer()

model_backend = get_backend(model)
self._set_backend_entity(model)
alpha_map = self._get_alpha_map()

alpha_map = self._get_alpha_map()
nodes_to_smooth_data = self._get_nodes_to_smooth_data(graph, alpha_map.keys())

for node_data in nodes_to_smooth_data:
node_to_smooth = node_data["node_to_smooth"]
container = StatisticPointsContainer()
for node_to_smooth, input_act_port in nodes_to_smooth_data:
target_point = self._backend_entity.target_point(
target_type=self._backend_entity.pre_layer_target_type(),
target_node_name=node_to_smooth.node_name,
port_id=node_data["input_act_port"],
)
input_reduction_axes = self._calculate_input_reduction_axes(
graph, node_to_smooth, node_data["input_act_port"]
port_id=input_act_port,
)
stat_collector = self._backend_entity.get_abs_max_channel_collector(
self._subset_size, input_reduction_axes, self._inplace_statistics, STATISTIC_BRANCH_KEY
)
statistic_container.add_statistic_point(
StatisticPoint(
target_point=target_point,
tensor_collector=stat_collector,
algorithm=self._algorithm_key,

# NOTE: The OpenVINO backend performs in-place statistic calculations.
# To insert reduction operations into the model graph, the reduction axes must be known before inference.
# However, when using `keep_axes`, the reduction axes are determined during statistics collection.
# Therefore, `keep_axes` and `inplace` cannot be used together with the OpenVINO backend.
# For the ONNX backend, we can't calculate reduction axes before inference because the tensor shape
# (in fact, only the number of dimensions, ndim, is needed) is unknown for some operations.
if model_backend == BackendType.ONNX:
keep_axes = (self._backend_entity.get_activation_channel_axis(node_to_smooth, input_act_port),)
collector = self._create_tensor_collector(
self._subset_size,
self._inplace_statistics,
keep_axes=keep_axes,
)
else:
reduction_axes = self._calculate_input_reduction_axes(graph, node_to_smooth, input_act_port)
collector = self._create_tensor_collector(
self._subset_size,
self._inplace_statistics,
reduction_axes=reduction_axes,
)
)
return statistic_container

def _get_nodes_to_smooth_data(self, nncf_graph: NNCFGraph, node_metatypes: list[OperatorMetatype]) -> list[dict]:
container.add_statistic_point(StatisticPoint(target_point, collector, self._algorithm_key))

return container

def _create_tensor_collector(
self,
num_samples: int,
inplace: bool,
keep_axes: Optional[tuple[int, ...]] = None,
reduction_axes: Optional[tuple[int, ...]] = None,
) -> TensorCollector:
"""
Initializes and returns a configured tensor collector for the `SmoothQuant` algorithm.

:param num_samples: Maximum number of samples to collect for the aggregator.
:param inplace: If True, statistics will be computed in-place.
:param keep_axes: Axes to preserve during the reduction operation.
:param reduction_axes: Axes over which the reduction operation is applied.
:return: A tensor collector configured with the specified reduction and aggregation logic.
"""
if reduction_axes is not None and keep_axes is not None:
msg = "Only one of `reduction_axes` or `keep_axes` should be specified, not both."
raise nncf.ValidationError(msg)

collector = TensorCollector()

abs_max_reducer_cls = self._backend_entity.get_abs_max_reducer_cls()
collector.register_statistic_branch(
STATISTIC_BRANCH_KEY,
abs_max_reducer_cls(reduction_axes, keep_axes, inplace),
MaxAggregator(num_samples=num_samples),
)
shape_reducer_cls = self._backend_entity.get_shape_reducer_cls()
collector.register_statistic_branch(
SHAPE_BRANCH_KEY, shape_reducer_cls(inplace), NoopAggregator(num_samples=1, return_first=True)
)

return collector
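A usage sketch of the two call paths (axis values invented):

    # ONNX: reduction axes are resolved at collection time from the kept channel axis,
    # so in-place collection is not possible here.
    collector = self._create_tensor_collector(self._subset_size, inplace=False, keep_axes=(-1,))

    # Other backends: reduction axes are known up front.
    collector = self._create_tensor_collector(self._subset_size, inplace=True, reduction_axes=(0, 2, 3))

    # Both branches are later read back by their keys:
    stats = collector.get_statistics()
    abs_max, shape = stats[STATISTIC_BRANCH_KEY], stats[SHAPE_BRANCH_KEY]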

def _get_nodes_to_smooth_data(
self, nncf_graph: NNCFGraph, node_metatypes: list[OperatorMetatype]
) -> list[tuple[NNCFNode, int]]:
"""
Collects layers whose activations will be smoothed.

:param nncf_graph: NNCFGraph instance.
:param node_metatypes: Metatypes for nodes to search for.
:return: List with the data for each layer.
:return: A list of pairs, where each pair consists of a node and its corresponding
input activation port.
"""
nodes_with_weights = nncf_graph.get_nodes_by_metatypes(node_metatypes)
nodes_to_smooth_data = []
Expand All @@ -306,12 +394,8 @@ def _get_nodes_to_smooth_data(self, nncf_graph: NNCFGraph, node_metatypes: list[
if self._backend_entity.is_node_with_shared_weight(node_with_weight, nncf_graph):
continue

nodes_to_smooth_data.append(
{
"node_to_smooth": node_with_weight,
"input_act_port": activation_port_id,
}
)
nodes_to_smooth_data.append((node_with_weight, activation_port_id))

return nodes_to_smooth_data

def _calculate_activation_scale(
35 changes: 20 additions & 15 deletions src/nncf/quantization/algorithms/smooth_quant/backend.py
@@ -20,7 +20,8 @@
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.commands import TransformationCommand
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.common.tensor_statistics.collectors import AbsMaxReducer
from nncf.experimental.common.tensor_statistics.collectors import ShapeReducer
from nncf.tensor import Tensor

TModel = TypeVar("TModel")
@@ -97,20 +98,6 @@ def get_activations_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
:return: Map with the activation & weighted ports.
"""

@staticmethod
@abstractmethod
def get_abs_max_channel_collector(
num_samples: int, stats_reduction_axes: tuple[int], inplace: bool, branch_key: str
) -> TensorCollector:
"""
Returns TensorCollector with MaxAggregator and AbsMaxReducer.

:param stats_reduction_axes: Calculated reduction axes.
:param inplace: Whether to calculate statistic inplace or not.
:param branch_key: Specific string for branch key.
:return: TensorCollector instance.
"""

@staticmethod
@abstractmethod
def get_weight_value(node_with_weight: NNCFNode, model: TModel, nncf_graph: NNCFGraph) -> Tensor:
@@ -199,3 +186,21 @@ def get_filter_fn_for_statistics(activation_port_id: int, algorithm_key: str) ->
:param algorithm_key: Current algorithm key.
:return: Backend-specific callable to filter statistic containers according to its statistic point.
"""

@staticmethod
def get_abs_max_reducer_cls() -> type[AbsMaxReducer]:
"""
Returns the backend-specific `AbsMaxReducer` class.

:return: The `AbsMaxReducer` class.
"""
return AbsMaxReducer

@staticmethod
def get_shape_reducer_cls() -> type[ShapeReducer]:
"""
Returns the backend-specific `ShapeReducer` class.

:return: The `ShapeReducer` class.
"""
return ShapeReducer
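Since these two defaults return the common reducers, a backend only overrides them when it ships its own subclass; a hypothetical example (the concrete class names are assumptions):

    class MyBackendSmoothQuantBackend(SmoothQuantAlgoBackend):  # name assumed for illustration
        @staticmethod
        def get_abs_max_reducer_cls() -> type[AbsMaxReducer]:
            # Return a backend-specific subclass, e.g. one with an in-place kernel.
            return MyBackendAbsMaxReducer  # hypothetical subclass of AbsMaxReducer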
13 changes: 0 additions & 13 deletions src/nncf/quantization/algorithms/smooth_quant/onnx_backend.py
@@ -21,9 +21,6 @@
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
from nncf.experimental.common.tensor_statistics.collectors import AbsMaxReducer
from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.onnx.graph.metatypes.groups import MATMUL_METATYPES
from nncf.onnx.graph.metatypes.groups import OPERATIONS_WITH_WEIGHTS
from nncf.onnx.graph.metatypes.groups import QUANTIZE_AGNOSTIC_OPERATIONS
@@ -76,16 +73,6 @@ def get_activations_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int:

return activation_port

@staticmethod
def get_abs_max_channel_collector(
num_samples: int, stats_reduction_axes: tuple[int], inplace: bool, branch_key: str
) -> TensorCollector:
collector = TensorCollector()
reducer = AbsMaxReducer(reduction_axes=stats_reduction_axes)
aggregator = MaxAggregator(num_samples=num_samples)
collector.register_statistic_branch(branch_key, reducer, aggregator)
return collector

@staticmethod
def _get_weight_tensor_port_id(node: NNCFNode) -> int:
weight_ports = list(node.layer_attributes.weight_attrs)