
Commit 7363da7

skyw authored and facebook-github-bot committed
onnx export of per channel fake quantize functions (pytorch#42835)
Summary: Fixes pytorch#39502. This PR adds support for exporting **fake_quantize_per_channel_affine** to a pair of QuantizeLinear and DequantizeLinear ops. Per-tensor support was added by PR pytorch#39738. The `axis` attribute of QuantizeLinear and DequantizeLinear, which is required for per-channel support, was added in opset 13 by onnx/onnx#2772.

[update 1/20/2021]: opset 13 is now supported on master, so the added function is properly tested. The code has also been rebased onto the new master.

The function was also tested offline with the following code:

```python
import torch
from torch import quantization
from torchvision import models

qat_resnet18 = models.resnet18(pretrained=True).eval().cuda()

qat_resnet18.qconfig = quantization.QConfig(
    activation=quantization.default_fake_quant,
    weight=quantization.default_per_channel_weight_fake_quant)
quantization.prepare_qat(qat_resnet18, inplace=True)
qat_resnet18.apply(quantization.enable_observer)
qat_resnet18.apply(quantization.enable_fake_quant)

dummy_input = torch.randn(16, 3, 224, 224).cuda()
_ = qat_resnet18(dummy_input)
for module in qat_resnet18.modules():
    if isinstance(module, quantization.FakeQuantize):
        module.calculate_qparams()
qat_resnet18.apply(quantization.disable_observer)

qat_resnet18.cuda()

input_names = ["actual_input_1"]
output_names = ["output1"]
torch.onnx.export(qat_resnet18, dummy_input, "quant_model.onnx",
                  verbose=True, opset_version=13)
```

It generates the desired graph.

Pull Request resolved: pytorch#42835
Reviewed By: houseroad
Differential Revision: D26293823
Pulled By: SplitInfinity
fbshipit-source-id: 300498a2e24b7731b12fa2fbdea4e73dde80e7ea
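As a sanity check, the exported graph can be inspected for the expected QuantizeLinear/DequantizeLinear pairs. A minimal sketch, assuming the export above has run and produced `quant_model.onnx`, and that the `onnx` package is installed:

```python
import onnx

model = onnx.load("quant_model.onnx")
op_types = [node.op_type for node in model.graph.node]

# Each fake-quantize node should have been lowered to a Q/DQ pair.
print(op_types.count("QuantizeLinear"), op_types.count("DequantizeLinear"))
```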
1 parent 159c48b

3 files changed: +51 −1 lines changed


Diff for: test/onnx/test_models.py (+22 −1)

```diff
@@ -182,7 +182,7 @@ def test_fake_quant(self):
         self.exportTest(toC(FakeQuantNet()), toC(x))
 
     @skipIfUnsupportedMinOpsetVersion(10)
-    def test_qat_resnet(self):
+    def test_qat_resnet_pertensor(self):
         # Quantize ResNet50 model
         x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
         qat_resnet50 = resnet50()
@@ -202,6 +202,27 @@ def test_qat_resnet(self):
 
         self.exportTest(toC(qat_resnet50), toC(x))
 
+    @skipIfUnsupportedMinOpsetVersion(13)
+    def test_qat_resnet_per_channel(self):
+        # Quantize ResNet50 model
+        x = torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
+        qat_resnet50 = resnet50()
+
+        qat_resnet50.qconfig = quantization.QConfig(
+            activation=quantization.default_fake_quant,
+            weight=quantization.default_per_channel_weight_fake_quant)
+        quantization.prepare_qat(qat_resnet50, inplace=True)
+        qat_resnet50.apply(torch.quantization.enable_observer)
+        qat_resnet50.apply(torch.quantization.enable_fake_quant)
+
+        _ = qat_resnet50(x)
+        for module in qat_resnet50.modules():
+            if isinstance(module, quantization.FakeQuantize):
+                module.calculate_qparams()
+        qat_resnet50.apply(torch.quantization.disable_observer)
+
+        self.exportTest(toC(qat_resnet50), toC(x))
+
     @disableScriptTest()  # None type in outputs
     def test_googlenet(self):
         x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
```

Diff for: test/onnx/test_pytorch_onnx_onnxruntime.py (+14)

```diff
@@ -5998,6 +5998,20 @@ def forward(self, input):
         x = torch.randn(6, 4, 3, 3)
         self.run_test(FakeQuantizePerTensorModel(), (x))
 
+    @skipIfUnsupportedMinOpsetVersion(13)
+    def test_fake_quantize_per_channel(self):
+        class FakeQuantizePerChannelModel(torch.nn.Module):
+            def forward(self, input):
+                amax = torch.ones(4)
+                scale = amax / 127.
+                zero_point = torch.zeros_like(amax, dtype=torch.long)
+                # Quantize twice to test different branches
+                y = torch.fake_quantize_per_channel_affine(input, scale, zero_point, 1, 0, 255)
+                return torch.fake_quantize_per_channel_affine(y, scale, zero_point, 1, -128, 127)
+
+        x = torch.randn(6, 4, 3, 3)
+        self.run_test(FakeQuantizePerChannelModel(), (x))
+
     def test_batchnorm_training(self):
         class MyModule(torch.nn.Module):
             def __init__(self):
```
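
The two calls in this test use the ranges [0, 255] and [-128, 127] so that both the uint8 and the int8 Cast branch of the new symbolic (see the next diff) are exercised. For reference, the eager-mode semantics the test relies on can be reproduced by hand; a minimal sketch of what `torch.fake_quantize_per_channel_affine` computes per channel (variable names are illustrative, not from the PR):

```python
import torch

x = torch.randn(6, 4, 3, 3)
scale = torch.ones(4) / 127.
zero_point = torch.zeros(4, dtype=torch.long)

# Quantize then dequantize along axis 1: clamp(round(x / s) + z, qmin, qmax),
# followed by (q - z) * s, with s and z broadcast per channel.
s = scale.reshape(1, -1, 1, 1)
z = zero_point.reshape(1, -1, 1, 1)
manual = (torch.clamp(torch.round(x / s) + z, -128, 127) - z) * s

y = torch.fake_quantize_per_channel_affine(x, scale, zero_point, 1, -128, 127)
assert torch.allclose(y, manual)
```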

Diff for: torch/onnx/symbolic_opset13.py (+15)

```diff
@@ -121,6 +121,21 @@ def where(g, condition, self=None, other=None, _outputs=None):
         return sym_help._unbind_helper(g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs)
     return g.op("Where", condition, self, other)
 
+@parse_args('v', 'v', 'v', 'i', 'i', 'i')
+def fake_quantize_per_channel_affine(g, inputs, scale, zero_point, axis, quant_min=-128, quant_max=127):
+    if quant_min not in [0, -128] or quant_max not in [127, 255]:
+        raise RuntimeError(
+            "ONNX defines [0, 255] for quint8 and [-128, 127] for qint8, got [{}, {}]".format(quant_min, quant_max))
+
+    # ONNX defines zero_point to be int8 or uint8
+    if quant_min == 0:
+        zero_point = g.op("Cast", zero_point, to_i=sym_help.cast_pytorch_to_onnx['Byte'])
+    else:
+        zero_point = g.op("Cast", zero_point, to_i=sym_help.cast_pytorch_to_onnx['Char'])
+    return g.op(
+        "DequantizeLinear",
+        g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=axis),
+        scale, zero_point, axis_i=axis)
 
 def _reduce_op_symbolic(onnx_op_name):
     def symbolic(g, self, dim=None, keepdim=None):
```
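
End to end, this symbolic means a model containing `torch.fake_quantize_per_channel_affine` can be exported at opset 13 and checked against ONNX Runtime. A minimal sketch, not part of the PR, assuming an onnxruntime build that implements the opset-13 per-axis QuantizeLinear/DequantizeLinear (the module `M` and the tolerance are illustrative):

```python
import io

import onnxruntime
import torch

class M(torch.nn.Module):
    def forward(self, x):
        scale = torch.ones(4) / 127.
        zero_point = torch.zeros(4, dtype=torch.long)
        # Fake-quantize along channel axis 1 with the qint8 range.
        return torch.fake_quantize_per_channel_affine(x, scale, zero_point, 1, -128, 127)

x = torch.randn(6, 4, 3, 3)
f = io.BytesIO()
torch.onnx.export(M(), x, f, opset_version=13)

# Run the exported Q/DQ graph and compare against eager mode.
sess = onnxruntime.InferenceSession(f.getvalue())
(ort_out,) = sess.run(None, {sess.get_inputs()[0].name: x.numpy()})
assert torch.allclose(M()(x), torch.from_numpy(ort_out), atol=1e-5)
```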
