Skip to content

[FX][AWQ][Scale Estimation][Mixed Precision] Add Data Aware Algorithm Support for FX Backend #3409

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: develop
Choose a base branch
from
Open
1 change: 1 addition & 0 deletions .ci/cspell_dict.txt
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ frobenius
fsolve
funcs
fval
fxawq
fxsq
gacts
gelsy
Expand Down
3 changes: 3 additions & 0 deletions nncf/quantization/algorithms/weight_compression/awq.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,10 @@ def _set_backend_entity(
from nncf.quantization.algorithms.weight_compression.torch_backend import PTAWQAlgoAlgoBackend

self._backend_entity = PTAWQAlgoAlgoBackend()
elif model_backend == BackendType.TORCH_FX:
from nncf.quantization.algorithms.weight_compression.torch_fx_backend import FXAWQAlgoAlgoBackend

self._backend_entity = FXAWQAlgoAlgoBackend()
else:
msg = f"Cannot return backend-specific AWQ entity because {model_backend.value} is not supported!"
raise nncf.UnsupportedBackendError(msg)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,10 @@ def _set_backend_entity(self, model: TModel) -> None:
from nncf.quantization.algorithms.weight_compression.torch_backend import PTMixedPrecisionAlgoBackend

self._backend_entity = PTMixedPrecisionAlgoBackend()
elif model_backend == BackendType.TORCH_FX:
from nncf.quantization.algorithms.weight_compression.torch_fx_backend import FXMixedPrecisionAlgoBackend

self._backend_entity = FXMixedPrecisionAlgoBackend()
else:
msg = f"Cannot return backend-specific entity because {model_backend.value} is not supported!"
raise nncf.UnsupportedBackendError(msg)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ def _set_backend_entity(self, model: TModel) -> None:
from nncf.quantization.algorithms.weight_compression.torch_backend import PTWeightCompressionAlgoBackend

self._backend_entity = PTWeightCompressionAlgoBackend()
elif model_backend == BackendType.TORCH_FX:
from nncf.quantization.algorithms.weight_compression.torch_fx_backend import FXWeightCompressionAlgoBackend

self._backend_entity = FXWeightCompressionAlgoBackend()
else:
msg = (
"Cannot return backend-specific Scale Estimation entity because"
Expand Down
104 changes: 89 additions & 15 deletions nncf/quantization/algorithms/weight_compression/torch_fx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,14 @@
from nncf.parameters import CompressionFormat
from nncf.parameters import CompressWeightsMode
from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
from nncf.quantization.algorithms.weight_compression.backend import AWQAlgoBackend
from nncf.quantization.algorithms.weight_compression.backend import MixedPrecisionAlgoBackend
from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
from nncf.quantization.algorithms.weight_compression.handle_errors import handle_invalid_group_size_error
from nncf.quantization.algorithms.weight_compression.lora_correction import LoraCorrectionAlgorithm
from nncf.quantization.algorithms.weight_compression.torch_backend import PTAWQAlgoAlgoBackend
from nncf.quantization.algorithms.weight_compression.torch_backend import PTMixedPrecisionAlgoBackend
from nncf.quantization.algorithms.weight_compression.torch_backend import PTWeightCompressionAlgoBackend
from nncf.quantization.algorithms.weight_compression.weight_lowering import compress_weight
from nncf.tensor import Tensor
Expand Down Expand Up @@ -133,12 +138,11 @@ def get_activation_port_id(node: NNCFNode, graph: NNCFGraph) -> int:
def get_weight(
self, node_with_weight: NNCFNode, weight_port_id: int, model: torch.fx.GraphModule, graph: NNCFGraph
) -> Tensor:
weight_edge = graph.get_input_edge_by_port_id(node_with_weight, weight_port_id)
weight_node = weight_edge.from_node
graph_weight_node = get_graph_node_by_name(model.graph, weight_node.node_name)
graph_node_with_weight = get_graph_node_by_name(model.graph, node_with_weight.node_name)
graph_weight_node = graph_node_with_weight.all_input_nodes[weight_port_id]
weight = get_tensor_constant_from_node(graph_weight_node, model).data
if weight is None:
msg = f"Could not find a node in the model by name {weight_node}."
msg = f"Could not find a node in the model by name {graph_weight_node}."
raise nncf.InternalError(msg)

return Tensor(weight)
Expand Down Expand Up @@ -192,7 +196,8 @@ def transform_model(
advanced_parameters: AdvancedCompressionParameters = AdvancedCompressionParameters(),
) -> torch.fx.GraphModule:
transformation_layout = TransformationLayout()

invalid_node_names = []
first_caught_error = None
for wc_params in weight_compression_parameters:
compression_config = wc_params.compression_config
if compression_config.mode in [
Expand All @@ -207,15 +212,19 @@ def transform_model(
if weight is None or not isinstance(weight, Tensor):
msg = f"Could not find a nncf.tensor in the model by name {weight_name}."
raise nncf.InternalError(msg)

# calculates compressed weights and decompression parameters
compressed_weight = compress_weight(
weight,
wc_params.reduction_axes,
compression_config,
None if precomputed_scales is None else precomputed_scales.get(wc_params.weight_name),
None if precomputed_zero_points is None else precomputed_zero_points.get(wc_params.weight_name),
)
try:
# calculates compressed weights and decompression parameters
compressed_weight = compress_weight(
weight,
wc_params.reduction_axes,
compression_config,
None if precomputed_scales is None else precomputed_scales.get(wc_params.weight_name),
None if precomputed_zero_points is None else precomputed_zero_points.get(wc_params.weight_name),
)
except nncf.InvalidGroupSizeError as error:
first_caught_error = error
invalid_node_names.append(wc_params.node_with_weight.node_name)
continue

# creates weight decompressor
if compression_config.mode == CompressWeightsMode.INT8_SYM:
Expand Down Expand Up @@ -265,8 +274,73 @@ def transform_model(
)
)
)

if first_caught_error:
handle_invalid_group_size_error(first_caught_error, invalid_node_names)
# apply transformations
transformed_model = FXModelTransformer(model).transform(transformation_layout)

return transformed_model


class FXMixedPrecisionAlgoBackend(MixedPrecisionAlgoBackend, FXWeightCompressionAlgoBackend):
@staticmethod
Comment on lines +285 to +286
Copy link
Collaborator

@daniil-lyakhov daniil-lyakhov Apr 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have a plan on how to reuse the same classes across all the PyTorch backends? If not, could you please create a ticket for that?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can set a flag in a common backend class (such as FX=False) when initializing it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@anzr299, it would be great if you could open a ticket and make a proposal that we could discuss with the team.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright, I will do that

def mean_variance_statistic_collector(
    reduction_axes: Tuple[int], subset_size: Optional[int] = None
) -> TensorCollector:
    """Build a mean-variance statistic collector for mixed-precision sensitivity analysis.

    :param reduction_axes: Axes over which activation statistics are reduced.
    :param subset_size: Optional cap on the number of collected samples; None means uncapped.
    :return: The TensorCollector built by the shared PyTorch backend.
    """
    # Delegates to the PT backend so the FX and PT implementations stay identical.
    return PTMixedPrecisionAlgoBackend.mean_variance_statistic_collector(
        reduction_axes=reduction_axes, subset_size=subset_size
    )

@staticmethod
def max_variance_statistic_collector(
    reduction_axes: Tuple[int], subset_size: Optional[int] = None
) -> TensorCollector:
    """Build a max-variance statistic collector for mixed-precision sensitivity analysis.

    :param reduction_axes: Axes over which activation statistics are reduced.
    :param subset_size: Optional cap on the number of collected samples; None means uncapped.
    :return: The TensorCollector built by the shared PyTorch backend.
    """
    # Delegates to the PT backend so the FX and PT implementations stay identical.
    return PTMixedPrecisionAlgoBackend.max_variance_statistic_collector(
        reduction_axes=reduction_axes, subset_size=subset_size
    )

@staticmethod
def mean_abs_max_statistic_collector(
    reduction_axes: Tuple[int], subset_size: Optional[int] = None
) -> TensorCollector:
    """Build a mean-of-absolute-max statistic collector for mixed-precision sensitivity analysis.

    :param reduction_axes: Axes over which activation statistics are reduced.
    :param subset_size: Optional cap on the number of collected samples; None means uncapped.
    :return: The TensorCollector built by the shared PyTorch backend.
    """
    # Delegates to the PT backend so the FX and PT implementations stay identical.
    return PTMixedPrecisionAlgoBackend.mean_abs_max_statistic_collector(
        reduction_axes=reduction_axes, subset_size=subset_size
    )


class FXAWQMultiply(torch.nn.Module):
    """Module that multiplies its input elementwise by a fixed AWQ scale.

    The scale is stored as a (non-trainable) buffer so it follows the module
    across device/dtype moves and is included in the state dict.
    """

    def __init__(self, scale: torch.Tensor):
        super().__init__()
        # Buffer, not Parameter: the AWQ scale is a constant, never optimized.
        self.register_buffer("_scale_value", scale)
        # Annotation only — helps static type checkers resolve the buffer attribute.
        self._scale_value: torch.Tensor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the input scaled elementwise by the stored AWQ scale."""
        scaled = x * self._scale_value
        return scaled


class FXAWQAlgoAlgoBackend(AWQAlgoBackend, FXWeightCompressionAlgoBackend):
    """AWQ algorithm backend for TorchFX models.

    Pattern matching is reused from the PyTorch backend; scale insertion is
    implemented via an FX module-insertion transformation.
    """

    @staticmethod
    def get_awq_patterns():
        """Return the AWQ subgraph patterns shared with the PyTorch backend."""
        return PTAWQAlgoAlgoBackend.get_awq_patterns()

    @staticmethod
    def scale_insertion_command(source_node, next_nodes, source_node_output_port, scale):
        """Build a command that inserts an AWQ scale multiplication before each consumer.

        :param source_node: Node producing the activation to be scaled.
        :param next_nodes: Consumer nodes that should receive the scaled activation.
        :param source_node_output_port: Output port of the source node (unused here;
            hooks are attached on the consumers' inputs instead).
        :param scale: Scale tensor applied by the inserted FXAWQMultiply module.
        :return: An FXApplyTransformationCommand performing the insertion.
        """
        # Each consumer gets a pre-hook on its first input port.
        consumer_port_id = 0
        hook_points = [
            PTTargetPoint(
                TargetType.OPERATOR_PRE_HOOK,
                consumer.node_name,
                input_port_id=consumer_port_id,
            )
            for consumer in next_nodes
        ]
        scale_module = FXAWQMultiply(scale)
        insertion_name = f"{source_node.node_name}/awq_mul"
        return FXApplyTransformationCommand(
            module_insertion_transformation_builder(
                scale_module,
                hook_points,
                insertion_name,
            )
        )
12 changes: 0 additions & 12 deletions nncf/quantization/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,8 +568,6 @@ def compress_weights(
raise nncf.ParameterNotSupportedError(msg)

options = {
"awq": awq,
"scale_estimation": scale_estimation,
"gptq": gptq,
"lora_correction": lora_correction,
}
Expand All @@ -578,16 +576,6 @@ def compress_weights(
msg = f"TorchFX backend does not support {', '.join(unsupported_options)} option(s). Set them to None."
raise nncf.ParameterNotSupportedError(msg)

if sensitivity_metric not in [None, SensitivityMetric.WEIGHT_QUANTIZATION_ERROR]:
msg = (
"TorchFX backend only supports data-free sensitivity metric. "
"Set None or SensitivityMetric.WEIGHT_QUANTIZATION_ERROR."
)
raise nncf.ParameterNotSupportedError(msg)

if dataset:
msg = "TorchFX only supports data-free weights compression. Set the 'dataset' option to None"
raise nncf.ParameterNotSupportedError(msg)
if advanced_parameters and advanced_parameters.statistics_path:
msg = "TorchFX does not supports statistics caching."
raise nncf.ParameterNotSupportedError(msg)
Expand Down
Loading
Loading