[TorchFX] Use torchao for quantize_pt2e API when possible #3588
Conversation
```python
    return PassResult(graph_module, True)


def get_device(module: torch.nn.Module) -> torch.device:
```
Please reuse `get_model_device` (line 416 in cc935e4):

`def get_model_device(model: torch.nn.Module) -> torch.device:`
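For reference, a helper of this shape typically resolves the device from the module's first parameter, with a CPU fallback for parameterless modules. A minimal sketch; the actual nncf implementation at the referenced line may differ:

```python
import torch


def get_model_device(model: torch.nn.Module) -> torch.device:
    # Take the device of the first parameter; parameterless models default to CPU.
    try:
        return next(model.parameters()).device
    except StopIteration:
        return torch.device("cpu")
```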
```
:param quant_min: Minimum quant value.
:type quant_min: int
:param quant_max: Maximum quant value.
:type quant_max: int
:param scale: Defines the scale factor used for quantization.
:type scale: torch.Tensor
:param zero_point: Specifies the quantized value to which 0 in floating point maps to.
:type zero_point: torch.Tensor
:param is_per_channel: Whether quantization is applied per channel.
:type is_per_channel: bool
:param ch_axis: Channel axis used for per-channel quantization.
:type ch_axis: int
```
`:type` in docstrings is used only for API objects. Suggested change (drop the `:type` fields):

```
:param quant_min: Minimum quant value.
:param quant_max: Maximum quant value.
:param scale: Defines the scale factor used for quantization.
:param zero_point: Specifies the quantized value to which 0 in floating point maps to.
:param is_per_channel: Whether quantization is applied per channel.
:param ch_axis: Channel axis used for per-channel quantization.
```
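For illustration, under this convention a non-API helper keeps types in its signature and only `:param` descriptions in the docstring. A hypothetical example; the function body is illustrative, not the code under review:

```python
import torch


def fake_quantize(
    x: torch.Tensor,
    quant_min: int,
    quant_max: int,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
) -> torch.Tensor:
    """
    Applies per-tensor quantize-dequantize to the input.

    :param x: Input tensor.
    :param quant_min: Minimum quant value.
    :param quant_max: Maximum quant value.
    :param scale: Defines the scale factor used for quantization.
    :param zero_point: Specifies the quantized value to which 0 in floating point maps.
    :return: The fake-quantized tensor.
    """
    # Quantize: scale, shift by the zero point, clamp to the integer range.
    q = torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max)
    # Dequantize back to floating point.
    return (q - zero_point) * scale
```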
```python
    return named_param.device


def create_getattr_from_value(module: torch.nn.Module, graph: torch.fx.Graph, prefix: str, value: Any) -> torch.fx.Node:
```
I could not find a case where `value` is anything other than a `torch.Tensor`; does it really need to be typed as `Any`?
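To make the suggested narrowing concrete, a minimal sketch with `value` typed as `torch.Tensor`. The attribute-name allocation is simplified; this is not the actual nncf code:

```python
import torch
import torch.fx


def create_getattr_from_value(
    module: torch.nn.Module, graph: torch.fx.Graph, prefix: str, value: torch.Tensor
) -> torch.fx.Node:
    # Allocate a fresh attribute name with the requested prefix.
    i = 0
    while hasattr(module, f"{prefix}{i}"):
        i += 1
    attr_name = f"{prefix}{i}"
    # Register the tensor on the module and reference it from the graph.
    module.register_buffer(attr_name, value.detach().clone())
    return graph.create_node("get_attr", attr_name)
```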
| """ | ||
|
|
||
| def get_new_attr_name(module: torch.nn.Module, prefix: str): | ||
| def get_attr_name(i: int): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Context
The `torch.ao` directory is being moved to a separate repo, `torchao`, and the legacy `torch.ao` implementation was deprecated in the latest release of PyTorch (see details here). The solution on our side is to:

- Deprecate `OpenVINOQuantizer` in nncf, leaving only the ExecuTorch implementation
- Deprecate the `nncf.quantize` TorchFX backend
- Introduce a `torchao` dependency for the `quantize_pt2e` API, or remove all dependencies on `torch.ao` from `quantize_pt2e` and `torch_ao_adapter` as well

This PR does not achieve that goal, but it makes the necessary first steps toward it.
Changes

- `OpenVINOQuantizer`, `TorchAOQuantizerAdapter` and `quantize_pt2e` use `torchao` classes whenever possible (via a conditional import; see the sketch after this list)
- The TorchFX `MinMax` backend and the TorchFX transformations no longer use the `torch.ao` `FakeQuantize` class. Instead, a structure `TorchQDQParameters` is introduced in `src/nncf/experimental/torch/fx/quantization/qdq_parameters.py`
- The `transformations.py` dependency on `torch.ao` is resolved (by moving the `_fuse_conv_bn_` import to other files and moving the `create_getattr_from_value` function to the nncf `transformations.py` file)
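To make the conditional import and the new parameter structure concrete, here is a minimal sketch. The `torchao` module path and the field set are assumptions derived from this description and from the docstring reviewed above, not the PR's actual code:

```python
from dataclasses import dataclass

import torch

# Prefer the new torchao package; fall back to the legacy torch.ao implementation.
# The exact torchao module path below is an assumption, not the PR's actual import.
try:
    from torchao.quantization.pt2e.quantizer import Quantizer
except ImportError:
    from torch.ao.quantization.quantizer.quantizer import Quantizer


@dataclass
class TorchQDQParameters:
    """Container for quantize-dequantize parameters, replacing torch.ao FakeQuantize.

    Field names follow the reviewed docstring; see
    src/nncf/experimental/torch/fx/quantization/qdq_parameters.py for the actual structure.
    """

    quant_min: int
    quant_max: int
    scale: torch.Tensor
    zero_point: torch.Tensor
    is_per_channel: bool
    ch_axis: int
```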
Reason for changes

- To be able to use `OpenVINOQuantizer` from ExecuTorch in `quantize_pt2e`
- To decouple `torch.ao` from `transformations.py`

Related tickets
170072
Tests
`test_openvino_quantizer_with_torch_ao_convert_pt2e` is enabled only for the `torchao` implementation (see the gating sketch below)
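One common way to gate such a test is a `skipif` marker keyed on the availability of `torchao`. A hypothetical sketch; the guard used in the actual test suite may differ:

```python
import pytest

# Probe for torchao; the real guard may check the specific pt2e implementation instead.
try:
    import torchao  # noqa: F401

    TORCHAO_AVAILABLE = True
except ImportError:
    TORCHAO_AVAILABLE = False


@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="Requires the torchao implementation")
def test_openvino_quantizer_with_torch_ao_convert_pt2e():
    ...
```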