[ONNX] Add per channel quantization support for Onnx.QLinearConv op #3917

Open · wants to merge 4 commits into main
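For context, the ONNX QLinearConv spec allows w_scale and w_zero_point to be either scalars (per-tensor quantization) or 1-D tensors whose length equals the number of output channels M, that is, per-channel quantization along axis 0, which is the case this PR adds. A minimal self-contained sketch of that per-channel weight dequantization (plain C++ with made-up values, not torch-mlir code):

```cpp
// Per-channel weight dequantization along axis 0 (output channels), as in
// ONNX QLinearConv: real_w[m][k] = (q_w[m][k] - w_zp[m]) * w_scale[m].
// All values here are hypothetical; this only illustrates the semantics.
#include <cstdint>
#include <cstdio>

int main() {
  const int M = 2, K = 3;                          // output channels x flattened rest
  const std::uint8_t qw[M][K] = {{130, 128, 126},  // quantized weight, channel 0
                                 {64, 66, 62}};    // quantized weight, channel 1
  const float wScale[M] = {0.5f, 0.25f};           // one scale per output channel
  const std::uint8_t wZp[M] = {128, 64};           // one zero point per output channel

  for (int m = 0; m < M; ++m)
    for (int k = 0; k < K; ++k) {
      float real = (static_cast<int>(qw[m][k]) - static_cast<int>(wZp[m])) * wScale[m];
      std::printf("w[%d][%d] = %.2f\n", m, k, real);
    }
  return 0;
}
```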
120 changes: 69 additions & 51 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -318,24 +318,15 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.tensorOperands(operands, 9)) ||
binder.tensorResultType(resultType))
return failure();
- Value a = operands[0];
- Value aScale = operands[1];
- Value aZp = operands[2];
- Value b = operands[3];
- Value bScale = operands[4];
- Value bZp = operands[5];
- Value cScale = operands[6];
- Value cZp = operands[7];
- Value c = operands.size() == 9 ? operands[8] : nullptr;
-
- auto check = [](Value v) {
- auto vTy = cast<Torch::ValueTensorType>(v.getType());
- return llvm::all_of(vTy.getSizes(), [](int64_t d) { return d == 1; });
- };
- if (!check(aScale) || !check(aZp) || !check(bScale) || !check(bZp) ||
- !check(cScale) || !check(cScale))
- return rewriter.notifyMatchFailure(
- binder.op, "not supported for non per-tensor quantization");
+ Value input = operands[0];
+ Value inputScale = operands[1];
+ Value inputZp = operands[2];
+ Value weight = operands[3];
+ Value weightScale = operands[4];
+ Value weightZp = operands[5];
+ Value outputScale = operands[6];
+ Value outputZp = operands[7];
+ Value output = operands.size() == 9 ? operands[8] : nullptr;
zjgarvey marked this conversation as resolved.

auto extract = [&rewriter, &binder](Value v) {
auto vTy = cast<Torch::ValueTensorType>(v.getType());
@@ -347,34 +338,61 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
v);
};

- aZp = extract(aZp);
- bZp = extract(bZp);
- cZp = extract(cZp);
- aScale = extract(aScale);
- bScale = extract(bScale);
- cScale = extract(cScale);
-
- auto make = [&rewriter, &binder](Value v, Value scale,
- Value zp) -> Value {
+ inputZp = extract(inputZp);
+ outputZp = extract(outputZp);
+ inputScale = extract(inputScale);
+ outputScale = extract(outputScale);
+ auto makePerTensor = [&rewriter, &binder](Value v, Value scale,
+ Value zp) -> Value {
auto ty = cast<Torch::ValueTensorType>(v.getType());
auto newTy = getQTorchTypeFromTorchIntType(ty);
return rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
binder.getLoc(), newTy, v, scale, zp);
};

- a = make(a, aScale, aZp);
- b = make(b, bScale, bZp);
+ auto makePerChannel = [&rewriter, &binder](Value v, Value scale,
+ Value zp,
+ Value axis) -> Value {
+ auto ty = cast<Torch::ValueTensorType>(v.getType());
+ auto newTy = getQTorchTypeFromTorchIntType(ty);
+ return rewriter.create<Torch::Aten_MakePerChannelQuantizedTensorOp>(
+ binder.getLoc(), newTy, v, scale, zp, axis);
+ };

- auto cTy = rewriter.getType<Torch::ValueTensorType>(
+ input = makePerTensor(input, inputScale, inputZp);
+ // The ONNX QLinearConv op expects per-channel quantization only for
+ // the weight tensor, with axis = 0.
+ auto weightTy = dyn_cast<Torch::ValueTensorType>(weight.getType());
+ auto weightScaleTy =
+ dyn_cast<Torch::ValueTensorType>(weightScale.getType());
+ if (!weightTy || !weightScaleTy || !weightTy.hasSizes() ||
+ !weightScaleTy.hasSizes())
+ return failure();
+ auto weightShape = weightTy.getSizes();
+ auto weightScaleShape = weightScaleTy.getSizes();
+ Value weightScaleScalar = extract(weightScale);
Collaborator comment: extract won't work if the weight scale isn't a single element. I'd put this in the else block below.

Collaborator comment: I see you use this below to handle the quantization of the output, but this must also be per-channel if the weight is per-channel.

+ if (weightScaleShape.size() == 1 &&
+ weightScaleShape[0] != Torch::kUnknownSize &&
+ weightScaleShape[0] == weightShape[0]) {
Collaborator comment: Additionally check that weightShape[0] != 1, since we don't want to lower to per-channel when there is only one channel.

+ Value axis = rewriter.create<Torch::ConstantIntOp>(
+ binder.getLoc(), rewriter.getI64IntegerAttr(0));
+ weight = makePerChannel(weight, weightScale, weightZp, axis);
+ } else {
+ weightZp = extract(weightZp);
+ weight = makePerTensor(weight, weightScaleScalar, weightZp);
+ }
Collaborator comment on lines +380 to +383: A bit of a nit, but I'd prefer an else if here with the conditions for makePerTensor, and then an else branch with an unreachable, just to be very clear about what assumptions are being made in each case.
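A minimal, self-contained sketch of the branch structure the reviewer is suggesting: explicit conditions on each branch plus a loud failure in the default case, rather than a bare else. The condition values below are made up for illustration (including a stand-in for Torch::kUnknownSize, and folding in the earlier weightShape[0] != 1 suggestion); the real patch would test weightScaleShape/weightShape and call makePerChannel/makePerTensor as in the diff above.

```cpp
#include <cstdio>
#include <cstdlib>

int main() {
  // Stand-in shape data; in the real code these come from the MLIR types.
  const long kUnknownSize = -1;  // illustrative stand-in for Torch::kUnknownSize
  long weightShape0 = 4;         // number of output channels
  long scaleRank = 1, scaleShape0 = 4;

  bool perChannel = scaleRank == 1 && scaleShape0 != kUnknownSize &&
                    scaleShape0 == weightShape0 && weightShape0 != 1;
  bool perTensor = scaleRank == 0 || (scaleRank == 1 && scaleShape0 == 1);

  if (perChannel) {
    std::puts("quantize weight per-channel (axis = 0)");
  } else if (perTensor) {
    std::puts("quantize weight per-tensor");
  } else {
    // The "unreachable" default: any other scale shape violates assumptions.
    std::fprintf(stderr, "unsupported weight scale shape\n");
    std::abort();
  }
  return 0;
}
```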

+ weightScale = weightScaleScalar;

+ auto outputTy = rewriter.getType<Torch::ValueTensorType>(
resultType.getOptionalSizes(),
rewriter.getIntegerType(32, /*issigned=*/true));

// TODO(suderman): insert convolution operator.
- llvm::SmallVector<Value> newOperands = {a, b};
- if (c)
- newOperands.push_back(c);
+ llvm::SmallVector<Value> newOperands = {input, weight};
+ if (output)
+ newOperands.push_back(output);

- cTy = rewriter.getType<Torch::ValueTensorType>(
+ outputTy = rewriter.getType<Torch::ValueTensorType>(
Collaborator comment: Okay, this is a bit subtle. The last optional input for this op is the int32 bias, assumed to be quantized via the product of input and weight scales. This implies that the quantization of the bias (and also the output of the convolution) is also per-channel if the weight was per-channel quantized. This part is fine, but we will need to case out the logic below.
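To make that subtlety concrete, here is a small self-contained sketch of the arithmetic the comment describes (plain C++ with illustrative values, not torch-mlir code): the int32 bias and the raw convolution accumulator carry the scale input_scale * weight_scale[c], so when the weight is per-channel the requantization to the output parameters needs a distinct scale per output channel c.

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative quantization parameters (made up).
  const float inputScale = 0.02f;
  const float weightScale[2] = {0.1f, 0.05f};  // per-channel weight scales
  const float outputScale = 0.04f;
  const std::int32_t outputZp = 3;

  for (int c = 0; c < 2; ++c) {
    // Scale shared by the int32 bias and the int32 conv accumulator on
    // channel c: product of the input scale and that channel's weight scale.
    float accScale = inputScale * weightScale[c];
    std::int32_t acc = 1234;  // accumulator value (conv sum + bias), channel c
    // Dequantize with the per-channel scale, then requantize to output params.
    float real = accScale * static_cast<float>(acc);
    std::int32_t q =
        static_cast<std::int32_t>(std::lround(real / outputScale)) + outputZp;
    std::printf("channel %d: accumulator scale %.4f -> output q = %d\n", c,
                accScale, q);
  }
  return 0;
}
```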

resultType.getOptionalSizes(),
rewriter.getType<Torch::QInt32Type>());

@@ -388,36 +406,36 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
newAttributes.push_back(namedAttr);
}

- c = rewriter
- .create<Torch::OperatorOp>(binder.getLoc(), cTy, newOperands,
- newAttributes,
- binder.op->getRegions().size())
- .getResult(0);
+ output = rewriter
+ .create<Torch::OperatorOp>(binder.getLoc(), outputTy,
+ newOperands, newAttributes,
+ binder.op->getRegions().size())
+ .getResult(0);

Value outScale = rewriter.create<Torch::AtenMulFloatOp>(
- binder.getLoc(), rewriter.getType<Torch::FloatType>(), aScale,
- bScale);
+ binder.getLoc(), rewriter.getType<Torch::FloatType>(), inputScale,
Collaborator comment: Will need to possibly be float x tensor mul.

+ weightScale);
Value outZp = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
- c = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
- binder.getLoc(), cTy, c, outScale, outZp);
- cTy = rewriter.getType<Torch::ValueTensorType>(
+ output = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
+ binder.getLoc(), outputTy, output, outScale, outZp);
+ outputTy = rewriter.getType<Torch::ValueTensorType>(
resultType.getOptionalSizes(), rewriter.getF32Type());

- c = rewriter.create<Torch::AtenDequantizeSelfOp>(binder.getLoc(), cTy,
- c);
- cTy = getQTorchTypeFromTorchIntType(resultType);
+ output = rewriter.create<Torch::AtenDequantizeSelfOp>(binder.getLoc(),
+ outputTy, output);
+ outputTy = getQTorchTypeFromTorchIntType(resultType);
Value dtyVal = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(
rewriter.getIntegerType(64),
static_cast<int64_t>(
- Torch::getScalarTypeForType(cTy.getDtype()))));
- c = rewriter.create<Torch::AtenQuantizePerTensorOp>(
- binder.getLoc(), cTy, c, cScale, cZp, dtyVal);
+ Torch::getScalarTypeForType(outputTy.getDtype()))));
+ output = rewriter.create<Torch::AtenQuantizePerTensorOp>(
+ binder.getLoc(), outputTy, output, outputScale, outputZp, dtyVal);
rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op, resultType,
- c);
+ output);
return success();
});
patterns.onOp(
18 changes: 9 additions & 9 deletions test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -65,15 +65,15 @@ func.func @test_quantizelinear_f8(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch.
// -----

// CHECK-LABEL: @test_qlinearconv_nobias
- func.func @test_qlinearconv_nobias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1,1,1,1],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
- %0 = torch.operator "onnx.QLinearConv"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!torch.vtensor<[1,1,7,7],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[1,1,1,1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8>
+ func.func @test_qlinearconv_nobias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1,1,1,1],ui8>, %arg4: !torch.vtensor<[],f32>, %arg5: !torch.vtensor<[],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+ %0 = torch.operator "onnx.QLinearConv"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!torch.vtensor<[1,1,7,7],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[1,1,1,1],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8>
// CHECK: %[[aZp:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int
- // CHECK: %[[bZp:.+]] = torch.aten.item %arg5 : !torch.vtensor<[1],ui8> -> !torch.int
// CHECK: %[[cZp:.+]] = torch.aten.item %arg7 : !torch.vtensor<[],ui8> -> !torch.int
// CHECK: %[[aScale:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
- // CHECK: %[[bScale:.+]] = torch.aten.item %arg4 : !torch.vtensor<[1],f32> -> !torch.float
// CHECK: %[[cScale:.+]] = torch.aten.item %arg6 : !torch.vtensor<[],f32> -> !torch.float
// CHECK: %[[A:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[aScale]], %[[aZp]] : !torch.vtensor<[1,1,7,7],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8>
+ // CHECK: %[[bScale:.+]] = torch.aten.item %arg4 : !torch.vtensor<[],f32> -> !torch.float
+ // CHECK: %[[bZp:.+]] = torch.aten.item %arg5 : !torch.vtensor<[],ui8> -> !torch.int
// CHECK: %[[B:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[bScale]], %[[bZp]] : !torch.vtensor<[1,1,1,1],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,1,1],!torch.quint8>
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_1:.+]] = torch.constant.int 0
@@ -103,17 +103,17 @@ func.func @test_qlinearconv_nobias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1:

// -----

- // CHECK-LABEL: @test_qlinearconv_bias
- func.func @test_qlinearconv_bias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1,1,1,1],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>, %arg8 : !torch.vtensor<[7],si32>) -> !torch.vtensor<[1,1,7,7],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+ // CHECK-LABEL: @test_qlinearconv_bias_weight_per_channel
+ func.func @test_qlinearconv_bias_weight_per_channel(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1,1,1,1],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>, %arg8 : !torch.vtensor<[7],si32>) -> !torch.vtensor<[1,1,7,7],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%0 = torch.operator "onnx.QLinearConv"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.vtensor<[1,1,7,7],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[1,1,1,1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[7],si32>) -> !torch.vtensor<[1,1,7,7],ui8>
// CHECK: %[[aZp:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int
- // CHECK: %[[bZp:.+]] = torch.aten.item %arg5 : !torch.vtensor<[1],ui8> -> !torch.int
// CHECK: %[[cZp:.+]] = torch.aten.item %arg7 : !torch.vtensor<[],ui8> -> !torch.int
// CHECK: %[[aScale:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
- // CHECK: %[[bScale:.+]] = torch.aten.item %arg4 : !torch.vtensor<[1],f32> -> !torch.float
// CHECK: %[[cScale:.+]] = torch.aten.item %arg6 : !torch.vtensor<[],f32> -> !torch.float
// CHECK: %[[A:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[aScale]], %[[aZp]] : !torch.vtensor<[1,1,7,7],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8>
- // CHECK: %[[B:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[bScale]], %[[bZp]] : !torch.vtensor<[1,1,1,1],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,1,1],!torch.quint8>
+ // CHECK: %[[bScale:.+]] = torch.aten.item %arg4 : !torch.vtensor<[1],f32> -> !torch.float
+ // CHECK: %[[INT0:.+]] = torch.constant.int 0
+ // CHECK: %[[B:.+]] = torch.aten._make_per_channel_quantized_tensor %arg3, %arg4, %arg5, %[[INT0]] : !torch.vtensor<[1,1,1,1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.int -> !torch.vtensor<[1,1,1,1],!torch.quint8>
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_1:.+]] = torch.constant.int 0
// CHECK: %[[PAD:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]]