Skip to content

AWQ Support for ONNX Backend #3571

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions src/nncf/onnx/graph/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from nncf.onnx.graph.onnx_helper import get_tensor
from nncf.onnx.graph.transformations.commands import ONNXInitializerUpdateCommand
from nncf.onnx.graph.transformations.commands import ONNXModelExtractionCommand
from nncf.onnx.graph.transformations.commands import ONNXMultiplyInsertionCommand
from nncf.onnx.graph.transformations.commands import ONNXOutputInsertionCommand
from nncf.onnx.graph.transformations.commands import ONNXQDQNodeRemovingCommand
from nncf.onnx.graph.transformations.commands import ONNXQuantizerInsertionCommand
Expand Down Expand Up @@ -91,6 +92,7 @@ def transform(self, transformation_layout: TransformationLayout) -> onnx.ModelPr
initializer_update_transformations = []
qdq_node_removing_transformations = []
model_extraction_transformation = None
multiply_insert_transformations = []
transformations = transformation_layout.transformations
# No transformation applied
if not transformations:
Expand All @@ -106,6 +108,8 @@ def transform(self, transformation_layout: TransformationLayout) -> onnx.ModelPr
qdq_node_removing_transformations.append(transformation)
elif isinstance(transformation, ONNXInitializerUpdateCommand):
initializer_update_transformations.append(transformation)
elif isinstance(transformation, ONNXMultiplyInsertionCommand):
multiply_insert_transformations.append(transformation)
# Inplace transformations, using deepcopy of model
if quantizer_insert_transformations or initializer_update_transformations or qdq_node_removing_transformations:
model = deepcopy(self._model)
Expand All @@ -115,6 +119,8 @@ def transform(self, transformation_layout: TransformationLayout) -> onnx.ModelPr
model = self._apply_qdq_node_removing_transformations(model, qdq_node_removing_transformations)
if initializer_update_transformations:
model = self._apply_initializer_update_transformations(model, initializer_update_transformations)
if multiply_insert_transformations:
model = self._apply_multiply_insertion_transformations(model, multiply_insert_transformations)
# Transformations that create new model
if output_insert_transformations:
model = self._apply_output_insertion_transformations(output_insert_transformations)
Expand Down Expand Up @@ -459,6 +465,48 @@ def _apply_qdq_node_removing_transformations(

return model

@staticmethod
def _apply_multiply_insertion_transformations(
    model: onnx.ModelProto, transformations: list[ONNXMultiplyInsertionCommand]
) -> onnx.ModelProto:
    """
    Inserts a Mul node with the provided scale after the output of each target node.

    For every command a new initializer holding the scale constant and a new Mul node
    that consumes the target output are added to the graph. The destination nodes listed
    in the command are then rewired to read the Mul output instead of the original tensor.

    :param model: Model to transform; its graph is modified in place.
    :param transformations: List of the multiply insertion transformations.
    :returns: Transformed model with Multiply nodes.
    """
    node_name_to_node = get_name_to_node_map(model)

    for transformation in transformations:
        target_node_name = transformation.target_point.target_node_name
        target_output_port = transformation.target_point.port_id
        target_node = node_name_to_node[target_node_name]
        output_tensor_name = target_node.output[target_output_port]

        # Create a new initializer for the scale constant
        scale_tensor_name = f"{transformation.multiply_node_name}_scale"
        scale_tensor = onnx.numpy_helper.from_array(transformation.scale_value, name=scale_tensor_name)
        model.graph.initializer.append(scale_tensor)

        # Create a new Multiply node
        mul_output_name = f"{transformation.multiply_node_name}_output"
        mul_node = onnx.helper.make_node(
            "Mul",
            inputs=[output_tensor_name, scale_tensor_name],
            outputs=[mul_output_name],
            name=transformation.multiply_node_name,
        )

        # Nodes are stored in the repeated field `graph.node`; GraphProto itself has no `insert`.
        # The Mul node is placed right after its producer.
        target_index = get_node_index(model, target_node_name)
        model.graph.node.insert(target_index + 1, mul_node)

        # Redirect only the requested consumers; other readers of the tensor stay untouched.
        for name in transformation.destination_node_names:
            node = node_name_to_node[name]
            for i, input_name in enumerate(node.input):
                if input_name == output_tensor_name:
                    node.input[i] = mul_output_name

    return model


def set_initializer(initializer_name: str, model: onnx.ModelProto, new_value: np.ndarray) -> None:
"""
Expand Down
13 changes: 13 additions & 0 deletions src/nncf/onnx/graph/transformations/command_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from nncf.common.graph.transformations.command_creation import CommandCreator
from nncf.common.graph.transformations.commands import TargetType
from nncf.onnx.graph.transformations.commands import ONNXInitializerUpdateCommand
from nncf.onnx.graph.transformations.commands import ONNXMultiplyInsertionCommand
from nncf.onnx.graph.transformations.commands import ONNXQDQNodeRemovingCommand
from nncf.onnx.graph.transformations.commands import ONNXTargetPoint

Expand Down Expand Up @@ -59,3 +60,15 @@ def create_command_to_update_weight(
@staticmethod
def create_command_to_insert_bias(node_without_bias, bias_value):
raise NotImplementedError

@staticmethod
def multiply_insertion_command(
    source_node: NNCFNode,
    destination_nodes: list[NNCFNode],
    source_out_port: int,
    scale_value: np.ndarray,
    multiply_node_name: str,
) -> ONNXMultiplyInsertionCommand:
    """
    Builds the command that inserts a Multiply node on an output of the source node.

    :param source_node: Node whose output is the insertion point.
    :param destination_nodes: Nodes whose names are passed to the command as consumers of the new node.
    :param source_out_port: Output port id of the source node.
    :param scale_value: Scale value for the Multiply node.
    :param multiply_node_name: Name of the new Multiply node.
    :return: Multiply insertion command targeting the source node output.
    """
    insertion_point = ONNXTargetPoint(
        TargetType.POST_LAYER_OPERATION,
        source_node.node_name,
        source_out_port,
    )
    consumer_names = [consumer.node_name for consumer in destination_nodes]
    return ONNXMultiplyInsertionCommand(
        target_point=insertion_point,
        scale_value=scale_value,
        destination_node_names=consumer_names,
        multiply_node_name=multiply_node_name,
    )
24 changes: 24 additions & 0 deletions src/nncf/onnx/graph/transformations/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,27 @@ def __init__(self, target_point: ONNXTargetPoint):
:param target_point: The TargetPoint instance for the layer that contains information for removing.
"""
super().__init__(TransformationType.REMOVE, target_point)


class ONNXMultiplyInsertionCommand(ONNXInsertionCommand):
    """
    Command that describes a Multiply node to insert in front of selected consumer nodes.
    """

    def __init__(
        self,
        target_point: ONNXTargetPoint,
        scale_value: np.ndarray,
        destination_node_names: list[str],
        multiply_node_name: str,
    ):
        """
        :param target_point: Insertion point; forwarded unchanged to the base insertion command.
        :param scale_value: Constant the Multiply node scales by.
        :param destination_node_names: Names of the nodes that become consumers of the new node.
        :param multiply_node_name: Name given to the new Multiply node.
        """
        super().__init__(target_point)
        # Payload of the command; stored as given, without copying or validation.
        self.scale_value = scale_value
        self.destination_node_names = destination_node_names
        self.multiply_node_name = multiply_node_name
4 changes: 4 additions & 0 deletions src/nncf/quantization/algorithms/weight_compression/awq.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ def _set_backend_entity(
from nncf.quantization.algorithms.weight_compression.torch_fx_backend import FXAWQAlgoAlgoBackend

self._backend_entity = FXAWQAlgoAlgoBackend()
elif model_backend == BackendType.ONNX:
from nncf.quantization.algorithms.weight_compression.onnx_backend import ONNXAWQAlgoAlgoBackend

self._backend_entity = ONNXAWQAlgoAlgoBackend(model)
else:
msg = f"Cannot return backend-specific AWQ entity because {model_backend.value} is not supported!"
raise nncf.UnsupportedBackendError(msg)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
from nncf.experimental.common.tensor_statistics.collectors import ShapeReducer
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.common.tensor_statistics.statistics import WCTensorStatistic
from nncf.onnx.graph.metatypes import onnx_metatypes
from nncf.onnx.graph.metatypes.groups import ATOMIC_ACTIVATIONS_OPERATIONS
from nncf.onnx.graph.metatypes.groups import CONVOLUTION_METATYPES
from nncf.onnx.graph.metatypes.groups import MATMUL_METATYPES
from nncf.onnx.graph.model_transformer import remove_initializer
Expand All @@ -43,11 +45,14 @@
from nncf.onnx.graph.onnx_helper import get_tensor_value
from nncf.onnx.graph.onnx_helper import pack_4_bits
from nncf.onnx.graph.onnx_helper import pack_int4_to_uint8
from nncf.onnx.graph.transformations.command_creation import ONNXCommandCreator
from nncf.onnx.graph.transformations.commands import ONNXTargetPoint
from nncf.onnx.quantization.ignored_patterns import create_rope
from nncf.parameters import CompressionFormat
from nncf.parameters import CompressWeightsMode
from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
from nncf.quantization.algorithms.weight_compression.awq_patterns import get_awq_patterns
from nncf.quantization.algorithms.weight_compression.backend import AWQAlgoBackend
from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
from nncf.quantization.algorithms.weight_compression.lora_correction import LoraCorrectionAlgorithm
Expand Down Expand Up @@ -181,7 +186,7 @@ def get_weight_shape(node_with_weight: NNCFNode, weight_port_id: int, graph: NNC
def set_weight(
    self, node_with_weight: NNCFNode, weight_port_id: int, model: onnx.ModelProto, graph: NNCFGraph, weight: Tensor
):
    """
    Overwrites the weight initializer of the given node with a new value.

    :param node_with_weight: Node whose weight is updated.
    :param weight_port_id: Index of the node input that holds the weight.
    :param model: ONNX model passed to `set_initializer` together with the new value.
    :param graph: Graph of the model; not used by this method.
    :param weight: New weight; its `data` is written to the initializer.
    """
    # The ONNX node is looked up by the NNCFNode's own `node_name`;
    # `target_node_name` is an attribute of target points, not of graph nodes.
    node = self.name_to_node_map[node_with_weight.node_name]
    initializer_name = node.input[weight_port_id]
    set_initializer(initializer_name, model, weight.data)

Expand Down Expand Up @@ -464,3 +469,19 @@ def _replace_matmul_with_matmulnbits(
@staticmethod
def get_ignored_patterns() -> GraphPattern:
    """
    Returns the graph pattern built by `create_rope()` as the ignored pattern of this backend.

    :return: Pattern produced by `create_rope()`.
    """
    rope_pattern = create_rope()
    return rope_pattern


class ONNXAWQAlgoAlgoBackend(AWQAlgoBackend, ONNXWeightCompressionAlgoBackend):
    """
    AWQ backend for ONNX models, layered on the ONNX weight compression backend.
    """

    @staticmethod
    def get_awq_patterns() -> dict[str, Callable]:
        """
        Returns the AWQ patterns built from the ONNX MatMul and Mul metatypes
        and the atomic activation operations.
        """
        matmul_metatype = onnx_metatypes.ONNXMatMulMetatype
        multiply_metatype = onnx_metatypes.ONNXMulLayerMetatype
        return get_awq_patterns(matmul_metatype, multiply_metatype, ATOMIC_ACTIVATIONS_OPERATIONS)

    @staticmethod
    def scale_insertion_command(
        source_node: NNCFNode, next_nodes: list[NNCFNode], source_node_output_port: int, scale: np.ndarray
    ):
        """
        Creates the command that inserts the AWQ scale as a Multiply node after the source node.

        :param source_node: Node whose output is scaled.
        :param next_nodes: Nodes passed on as destinations of the inserted Multiply node.
        :param source_node_output_port: Output port id of the source node.
        :param scale: Scale value for the Multiply node.
        :return: Command produced by `ONNXCommandCreator.multiply_insertion_command`.
        """
        # The new node is named after its producer with an "/awq_mul" suffix.
        awq_mul_name = f"{source_node.node_name}/awq_mul"
        return ONNXCommandCreator.multiply_insertion_command(
            source_node, next_nodes, source_node_output_port, scale, awq_mul_name
        )
1 change: 0 additions & 1 deletion src/nncf/quantization/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,6 @@ def compress_weights(
raise nncf.ParameterNotSupportedError(msg)

options = {
"awq": awq,
"scale_estimation": scale_estimation,
"gptq": gptq,
"lora_correction": lora_correction,
Expand Down
4 changes: 4 additions & 0 deletions tests/post_training/data/wc_reference_data.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,7 @@ tinyllama_data_free_awq_backend_TORCH:
metric_value: 0.85466
num_int4: 94
num_int8: 124
tinyllama_data_free_awq_backend_ONNX:
metric_value: 0.82562
num_int4: 264
num_int8: 84
Comment on lines 118 to +124
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd also double-check why the similarity and the numbers of int4 and int8 weights differ for the same compression configuration.

3 changes: 2 additions & 1 deletion tests/post_training/data/wc_test_durations.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@
"tests/post_training/test_quantize_conformance.py::test_weight_compression[tinyllama_data_free_awq_backend_OV]": 164,
"tests/post_training/test_quantize_conformance.py::test_weight_compression[tinyllama_data_free_awq_backend_TORCH]": 210,
"tests/post_training/test_quantize_conformance.py::test_weight_compression[tinyllama_data_free_backend_ONNX]": 182,
"tests/post_training/test_quantize_conformance.py::test_weight_compression[tinyllama_data_free_opset19_backend_ONNX]": 512
"tests/post_training/test_quantize_conformance.py::test_weight_compression[tinyllama_data_free_opset19_backend_ONNX]": 512,
"tests/post_training/test_quantize_conformance.py::test_weight_compression[tinyllama_data_free_awq_backend_ONNX]": 154
}
2 changes: 1 addition & 1 deletion tests/post_training/model_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@
),
},
# TODO: (andreyanufr) add torch.fx backend
"backends": [BackendType.OV, BackendType.TORCH],
"backends": [BackendType.OV, BackendType.TORCH, BackendType.ONNX],
},
]

Expand Down