From d96b17725557b0102bb1c50bdbd70366416c7c41 Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Mon, 9 Dec 2024 07:31:29 +0000
Subject: [PATCH 01/11] add per channel quantization for onnx.qlinearconv op

---
 .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp    | 137 +++++++++++-------
 1 file changed, 86 insertions(+), 51 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index d3251c589ac8..a5bfa62e688b 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -9,9 +9,11 @@
 #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
 #include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/raw_ostream.h"
 
 using namespace mlir;
 using namespace mlir::torch;
@@ -332,24 +334,25 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
               binder.tensorOperands(operands, 9)) ||
           binder.tensorResultType(resultType))
         return failure();
-      Value a = operands[0];
-      Value aScale = operands[1];
-      Value aZp = operands[2];
-      Value b = operands[3];
-      Value bScale = operands[4];
-      Value bZp = operands[5];
-      Value cScale = operands[6];
-      Value cZp = operands[7];
-      Value c = operands.size() == 9 ? operands[8] : nullptr;
-
-      auto check = [](Value v) {
-        auto vTy = cast<Torch::ValueTensorType>(v.getType());
-        return llvm::all_of(vTy.getSizes(), [](int64_t d) { return d == 1; });
-      };
-      if (!check(aScale) || !check(aZp) || !check(bScale) || !check(bZp) ||
-          !check(cScale) || !check(cScale))
-        return rewriter.notifyMatchFailure(
-            binder.op, "not supported for non per-tensor quantization");
+      Value input = operands[0];
+      Value inputScale = operands[1];
+      Value inputZp = operands[2];
+      Value weight = operands[3];
+      Value weightScale = operands[4];
+      Value weightZp = operands[5];
+      Value outputScale = operands[6];
+      Value outputZp = operands[7];
+      Value output = operands.size() == 9 ? operands[8] : nullptr;
+
+      // auto check = [](Value v) {
+      //   auto vTy = cast<Torch::ValueTensorType>(v.getType());
+      //   return llvm::all_of(vTy.getSizes(), [](int64_t d) { return d == 1;
+      //   });
+      // };
+      // if (!check(aScale) || !check(aZp) || !check(bScale) || !check(bZp) ||
+      //     !check(cScale) || !check(cScale))
+      //   return rewriter.notifyMatchFailure(
+      //       binder.op, "not supported for non per-tensor quantization");
 
       auto extract = [&rewriter, &binder](Value v) {
         auto vTy = cast<Torch::ValueTensorType>(v.getType());
         Type extractTy = rewriter.getType<Torch::FloatType>();
         if (isa<IntegerType>(vTy.getDtype()))
           extractTy = rewriter.getType<Torch::IntType>();
 
         return rewriter.create<Torch::AtenItemOp>(binder.getLoc(), extractTy,
                                                   v);
       };
 
-      aZp = extract(aZp);
-      bZp = extract(bZp);
-      cZp = extract(cZp);
-      aScale = extract(aScale);
-      bScale = extract(bScale);
-      cScale = extract(cScale);
-
-      auto make = [&rewriter, &binder](Value v, Value scale,
-                                       Value zp) -> Value {
+      inputZp = extract(inputZp);
+      outputZp = extract(outputZp);
+      inputScale = extract(inputScale);
+      outputScale = extract(outputScale);
+      auto makePerTensor = [&rewriter, &binder](Value v, Value scale,
+                                                Value zp) -> Value {
         auto ty = cast<Torch::ValueTensorType>(v.getType());
         auto newTy = getQTorchTypeFromTorchIntType(ty);
         return rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
             binder.getLoc(), newTy, v, scale, zp);
       };
 
-      a = make(a, aScale, aZp);
-      b = make(b, bScale, bZp);
+      auto makePerChannel = [&rewriter, &binder](Value v, Value scale,
+                                                 Value zp,
+                                                 Value axis) -> Value {
+        auto ty = cast<Torch::ValueTensorType>(v.getType());
+        auto newTy = getQTorchTypeFromTorchIntType(ty);
+        return rewriter.create<Torch::Aten_MakePerChannelQuantizedTensorOp>(
+            binder.getLoc(), newTy, v, scale, zp, axis);
+      };
 
-      auto cTy = rewriter.getType<Torch::ValueTensorType>(
+      input = makePerTensor(input, inputScale, inputZp);
+      // The onnx's QLinearConv op expects per channel quantization only for
+      // the weight tensor for axis = 0.
+      llvm::outs() << "I'm here\n";
+      auto weightTy = dyn_cast<Torch::ValueTensorType>(weight.getType());
+      auto weightScaleTy =
+          dyn_cast<Torch::ValueTensorType>(weightScale.getType());
+      if (!weightTy || !weightScaleTy || !weightTy.hasSizes() ||
+          !weightScaleTy.hasSizes())
+        return failure();
+      llvm::outs() << "I'm here 1\n";
+      auto weightShape = weightTy.getSizes();
+      auto weightScaleShape = weightScaleTy.getSizes();
+      Value weightScaleScalar = extract(weightScale);
+      if (weightScaleShape.size() == 1 &&
+          weightScaleShape[0] != Torch::kUnknownSize &&
+          weightScaleShape[0] == weightShape[0]) {
+        Value axis = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(0));
+        weight = makePerChannel(weight, weightScale, weightZp, axis);
+      } else {
+        weightZp = extract(weightZp);
+        weight = makePerTensor(weight, weightScaleScalar, weightZp);
+      }
+      weight = weightScaleScalar;
+
+      auto outputTy = rewriter.getType<Torch::ValueTensorType>(
           resultType.getOptionalSizes(),
           rewriter.getIntegerType(32, /*issigned=*/true));
+      llvm::outs() << "I'm here 2\n";
 
       // TODO(suderman): insert convolution operator.
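       // For reference, a sketch of the requantization math this lowering
       // relies on (per-tensor case): the integer convolution accumulates in
       // int32, and the accumulator relates to the real-valued convolution by
       //   real_out = (input_scale * weight_scale) * int32_acc
       // so the accumulator below is wrapped as a quantized tensor with
       // scale = input_scale * weight_scale and zero point 0, dequantized,
       // and then requantized with the output scale/zero point.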
-      llvm::SmallVector<Value> newOperands = {a, b};
-      if (c)
-        newOperands.push_back(c);
+      llvm::SmallVector<Value> newOperands = {input, weight};
+      if (output)
+        newOperands.push_back(output);
 
-      cTy = rewriter.getType<Torch::ValueTensorType>(
+      outputTy = rewriter.getType<Torch::ValueTensorType>(
           resultType.getOptionalSizes(),
           rewriter.getType<Torch::QInt32Type>());
 
@@ -402,36 +435,38 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         newAttributes.push_back(namedAttr);
       }
 
-      c = rewriter
-              .create<Torch::OperatorOp>(binder.getLoc(), cTy, newOperands,
-                                         newAttributes,
-                                         binder.op->getRegions().size())
-              .getResult(0);
+      output = rewriter
+                   .create<Torch::OperatorOp>(binder.getLoc(), outputTy,
+                                              newOperands, newAttributes,
+                                              binder.op->getRegions().size())
+                   .getResult(0);
 
       Value outScale = rewriter.create<Torch::AtenMulFloatOp>(
-          binder.getLoc(), rewriter.getType<Torch::FloatType>(), aScale,
-          bScale);
+          binder.getLoc(), rewriter.getType<Torch::FloatType>(), inputScale,
+          weightScale);
       Value outZp = rewriter.create<Torch::ConstantIntOp>(
           binder.getLoc(), rewriter.getType<Torch::IntType>(),
           rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
-      c = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
-          binder.getLoc(), cTy, c, outScale, outZp);
-      cTy = rewriter.getType<Torch::ValueTensorType>(
+      output = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
+          binder.getLoc(), outputTy, output, outScale, outZp);
+      outputTy = rewriter.getType<Torch::ValueTensorType>(
           resultType.getOptionalSizes(), rewriter.getF32Type());
 
-      c = rewriter.create<Torch::AtenDequantizeSelfOp>(binder.getLoc(), cTy,
-                                                       c);
-      cTy = getQTorchTypeFromTorchIntType(resultType);
+      llvm::outs() << "I'm here 3\n";
+      output = rewriter.create<Torch::AtenDequantizeSelfOp>(binder.getLoc(),
+                                                            outputTy, output);
+      outputTy = getQTorchTypeFromTorchIntType(resultType);
       Value dtyVal = rewriter.create<Torch::ConstantIntOp>(
           binder.getLoc(), rewriter.getType<Torch::IntType>(),
           rewriter.getIntegerAttr(
               rewriter.getIntegerType(64),
               static_cast<int64_t>(
-                  Torch::getScalarTypeForType(cTy.getDtype()))));
-      c = rewriter.create<Torch::AtenQuantizePerTensorOp>(
-          binder.getLoc(), cTy, c, cScale, cZp, dtyVal);
+                  Torch::getScalarTypeForType(outputTy.getDtype()))));
+      output = rewriter.create<Torch::AtenQuantizePerTensorOp>(
+          binder.getLoc(), outputTy, output, outputScale, outputZp, dtyVal);
       rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op, resultType,
-                                                        c);
+                                                        output);
+      llvm::outs() << "I'm here 4\n";
       return success();
     });
 patterns.onOp(

From aa4bd60b377a449406e9a10d2fb8cde722f3d79e Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Thu, 12 Dec 2024 17:55:54 +0530
Subject: [PATCH 02/11] More changes

---
 .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp    | 17 +----------------
 lib/Conversion/TorchToLinalg/Linear.cpp       | 15 +++++++++++++++
 2 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index a5bfa62e688b..6176165255a4 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -344,16 +344,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
       Value outputZp = operands[7];
       Value output = operands.size() == 9 ? operands[8] : nullptr;
 
-      // auto check = [](Value v) {
-      //   auto vTy = cast<Torch::ValueTensorType>(v.getType());
-      //   return llvm::all_of(vTy.getSizes(), [](int64_t d) { return d == 1;
-      //   });
-      // };
-      // if (!check(aScale) || !check(aZp) || !check(bScale) || !check(bZp) ||
-      //     !check(cScale) || !check(cScale))
-      //   return rewriter.notifyMatchFailure(
-      //       binder.op, "not supported for non per-tensor quantization");
-
       auto extract = [&rewriter, &binder](Value v) {
         auto vTy = cast<Torch::ValueTensorType>(v.getType());
         Type extractTy = rewriter.getType<Torch::FloatType>();
@@ -388,14 +378,12 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
       input = makePerTensor(input, inputScale, inputZp);
       // The onnx's QLinearConv op expects per channel quantization only for
       // the weight tensor for axis = 0.
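       // (Per the ONNX spec, w_scale/w_zero_point are either scalars, giving
       // per-tensor quantization, or 1-D tensors of length M, the number of
       // output channels of the M x C/group x kH x kW weight, giving
       // per-channel quantization along axis 0.)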
-      llvm::outs() << "I'm here\n";
       auto weightTy = dyn_cast<Torch::ValueTensorType>(weight.getType());
       auto weightScaleTy =
           dyn_cast<Torch::ValueTensorType>(weightScale.getType());
       if (!weightTy || !weightScaleTy || !weightTy.hasSizes() ||
           !weightScaleTy.hasSizes())
         return failure();
-      llvm::outs() << "I'm here 1\n";
       auto weightShape = weightTy.getSizes();
       auto weightScaleShape = weightScaleTy.getSizes();
       Value weightScaleScalar = extract(weightScale);
       if (weightScaleShape.size() == 1 &&
           weightScaleShape[0] != Torch::kUnknownSize &&
           weightScaleShape[0] == weightShape[0]) {
         Value axis = rewriter.create<Torch::ConstantIntOp>(
             binder.getLoc(), rewriter.getI64IntegerAttr(0));
         weight = makePerChannel(weight, weightScale, weightZp, axis);
       } else {
         weightZp = extract(weightZp);
         weight = makePerTensor(weight, weightScaleScalar, weightZp);
       }
-      weight = weightScaleScalar;
+      weightScale = weightScaleScalar;
 
       auto outputTy = rewriter.getType<Torch::ValueTensorType>(
           resultType.getOptionalSizes(),
           rewriter.getIntegerType(32, /*issigned=*/true));
 
-      llvm::outs() << "I'm here 2\n";
       // TODO(suderman): insert convolution operator.
       llvm::SmallVector<Value> newOperands = {input, weight};
       if (output)
@@ -452,7 +439,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
       outputTy = rewriter.getType<Torch::ValueTensorType>(
           resultType.getOptionalSizes(), rewriter.getF32Type());
 
-      llvm::outs() << "I'm here 3\n";
       output = rewriter.create<Torch::AtenDequantizeSelfOp>(binder.getLoc(),
                                                             outputTy, output);
       outputTy = getQTorchTypeFromTorchIntType(resultType);
@@ -466,7 +452,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
           binder.getLoc(), outputTy, output, outputScale, outputZp, dtyVal);
       rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op, resultType,
                                                         output);
-      llvm::outs() << "I'm here 4\n";
       return success();
     });
 patterns.onOp(
diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index 9945c52a1684..678a3f8ea20d 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -785,6 +785,21 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
       weight = make.getSelf();
       weightZp = make.getZeroPoint();
 
+      weight = typeConverter->materializeTargetConversion(
+          rewriter, loc, typeConverter->convertType(weight.getType()), weight);
+      weightZp = typeConverter->materializeTargetConversion(
+          rewriter, loc, typeConverter->convertType(weightZp.getType()),
+          weightZp);
+      weightZp = rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(),
+                                                  weightZp);
+      auto torchDtype = cast<ValueTensorType>(make.getType()).getDtype();
+      weightUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype);
+    } else if (auto make =
+                   op.getWeight()
+                       .getDefiningOp<Aten_MakePerChannelQuantizedTensorOp>()) {
+      weight = make.getSelf();
+      weightZp = make.getZeroPoint();
+
       weight = typeConverter->materializeTargetConversion(
           rewriter, loc, typeConverter->convertType(weight.getType()), weight);
       weightZp = typeConverter->materializeTargetConversion(

From b595e84e30d6640136b181a5ff1ececf4ed221ab Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Fri, 13 Dec 2024 12:01:54 +0530
Subject: [PATCH 03/11] Remove some code

---
 .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp    |  2 --
 lib/Conversion/TorchToLinalg/Linear.cpp       | 15 ---------------
 2 files changed, 17 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index 6176165255a4..76934994fb29 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -9,11 +9,9 @@
 #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
 #include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
-#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/raw_ostream.h"
 
 using namespace mlir;
 using namespace mlir::torch;
diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index 678a3f8ea20d..9945c52a1684 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -785,21 +785,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
       weight = make.getSelf();
       weightZp = make.getZeroPoint();
 
-      weight = typeConverter->materializeTargetConversion(
-          rewriter, loc, typeConverter->convertType(weight.getType()), weight);
-      weightZp = typeConverter->materializeTargetConversion(
-          rewriter, loc, typeConverter->convertType(weightZp.getType()),
-          weightZp);
-      weightZp = rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(),
-                                                  weightZp);
-      auto torchDtype = cast<ValueTensorType>(make.getType()).getDtype();
-      weightUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype);
-    } else if (auto make =
-                   op.getWeight()
-                       .getDefiningOp<Aten_MakePerChannelQuantizedTensorOp>()) {
-      weight = make.getSelf();
-      weightZp = make.getZeroPoint();
-
       weight = typeConverter->materializeTargetConversion(
           rewriter, loc, typeConverter->convertType(weight.getType()), weight);
       weightZp = typeConverter->materializeTargetConversion(

From b1319baa0938f58cd9095dcdbd5bba1ee59e23b1 Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Fri, 13 Dec 2024 13:03:33 +0530
Subject: [PATCH 04/11] Update lit test

---
 .../TorchOnnxToTorch/simple_ops_q_to_z.mlir   | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
index 2caddff9bc3b..53147a6bfab9 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -65,15 +65,15 @@ func.func @test_quantizelinear_f8(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch.
 // -----
 
 // CHECK-LABEL: @test_qlinearconv_nobias
-func.func @test_qlinearconv_nobias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1,1,1,1],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  %0 = torch.operator "onnx.QLinearConv"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!torch.vtensor<[1,1,7,7],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[1,1,1,1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8>
+func.func @test_qlinearconv_nobias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1,1,1,1],ui8>, %arg4: !torch.vtensor<[],f32>, %arg5: !torch.vtensor<[],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  %0 = torch.operator "onnx.QLinearConv"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!torch.vtensor<[1,1,7,7],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[1,1,1,1],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8>
   // CHECK: %[[aZp:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int
-  // CHECK: %[[bZp:.+]] = torch.aten.item %arg5 : !torch.vtensor<[1],ui8> -> !torch.int
   // CHECK: %[[cZp:.+]] = torch.aten.item %arg7 : !torch.vtensor<[],ui8> -> !torch.int
   // CHECK: %[[aScale:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
-  // CHECK: %[[bScale:.+]] = torch.aten.item %arg4 : !torch.vtensor<[1],f32> -> !torch.float
   // CHECK: %[[cScale:.+]] = torch.aten.item %arg6 : !torch.vtensor<[],f32> -> !torch.float
   // CHECK: %[[A:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[aScale]], %[[aZp]] : !torch.vtensor<[1,1,7,7],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8>
+  // CHECK: %[[bScale:.+]] = torch.aten.item %arg4 : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK: %[[bZp:.+]] = torch.aten.item %arg5 : !torch.vtensor<[],ui8> -> !torch.int
   // CHECK: %[[B:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[bScale]], %[[bZp]] : !torch.vtensor<[1,1,1,1],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,1,1],!torch.quint8>
   // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
   // CHECK: %[[INT0_1:.+]] = torch.constant.int 0
@@ -103,17 +103,17 @@ func.func @test_qlinearconv_nobias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1:
 
 // -----
 
-// CHECK-LABEL: @test_qlinearconv_bias
-func.func @test_qlinearconv_bias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1,1,1,1],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>, %arg8 : !torch.vtensor<[7],si32>) -> !torch.vtensor<[1,1,7,7],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+// CHECK-LABEL: @test_qlinearconv_bias_weight_per_channel
+func.func @test_qlinearconv_bias_weight_per_channel(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1,1,1,1],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>, %arg8 : !torch.vtensor<[7],si32>) -> !torch.vtensor<[1,1,7,7],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   %0 = torch.operator "onnx.QLinearConv"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.vtensor<[1,1,7,7],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[1,1,1,1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[7],si32>) -> !torch.vtensor<[1,1,7,7],ui8>
   // CHECK: %[[aZp:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int
-  // CHECK: %[[bZp:.+]] = torch.aten.item %arg5 : !torch.vtensor<[1],ui8> -> !torch.int
   // CHECK: %[[cZp:.+]] = torch.aten.item %arg7 : !torch.vtensor<[],ui8> -> !torch.int
   // CHECK: %[[aScale:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
-  // CHECK: %[[bScale:.+]] = torch.aten.item %arg4 : !torch.vtensor<[1],f32> -> !torch.float
   // CHECK: %[[cScale:.+]] = torch.aten.item %arg6 : !torch.vtensor<[],f32> -> !torch.float
   // CHECK: %[[A:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[aScale]], %[[aZp]] : !torch.vtensor<[1,1,7,7],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8>
-  // CHECK: %[[B:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[bScale]], %[[bZp]] : !torch.vtensor<[1,1,1,1],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,1,1],!torch.quint8>
+  // CHECK: %[[bScale:.+]] = torch.aten.item %arg4 : !torch.vtensor<[1],f32> -> !torch.float
+  // CHECK: %[[INT0:.+]] = torch.constant.int 0
+  // CHECK: %[[B:.+]] = torch.aten._make_per_channel_quantized_tensor %arg3, %arg4, %arg5, %[[INT0]] : !torch.vtensor<[1,1,1,1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.int -> !torch.vtensor<[1,1,1,1],!torch.quint8>
   // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
   // CHECK: %[[INT0_1:.+]] = torch.constant.int 0
   // CHECK: %[[PAD:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]]

From 45d2d70e1381af2afa9ba9c05a784314bbe5b668 Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Fri, 28 Feb 2025 10:37:51 +0530
Subject: [PATCH 05/11] Address PR comments

---
 .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp    | 58 +++++++++++--------
 1 file changed, 35 insertions(+), 23 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index 76934994fb29..ffc21c5b0c88 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -9,9 +9,12 @@
 #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
 #include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/ErrorHandling.h"
"llvm/Support/ErrorHandling.h" using namespace mlir; using namespace mlir::torch; @@ -376,6 +379,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( input = makePerTensor(input, inputScale, inputZp); // The onnx's QLinearConv op expects per channel quantization only for // the weight tensor for axis = 0. + bool isPerChannelQuantization = false; auto weightTy = dyn_cast(weight.getType()); auto weightScaleTy = dyn_cast(weightScale.getType()); @@ -384,32 +388,26 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return failure(); auto weightShape = weightTy.getSizes(); auto weightScaleShape = weightScaleTy.getSizes(); - Value weightScaleScalar = extract(weightScale); - if (weightScaleShape.size() == 1 && - weightScaleShape[0] != Torch::kUnknownSize && - weightScaleShape[0] == weightShape[0]) { + if (weightScaleShape.size() == 0 || llvm::all_of(weightScaleShape, 1)) { + weightZp = extract(weightZp); + weightScale = extract(weightScale); + weight = makePerTensor(weight, weightScale, weightZp); + } else if (weightScaleShape.size() == 1 && + weightScaleShape[0] != Torch::kUnknownSize && + weightScaleShape[0] == weightShape[0]) { Value axis = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(0)); weight = makePerChannel(weight, weightScale, weightZp, axis); + isPerChannelQuantization = true; } else { - weightZp = extract(weightZp); - weight = makePerTensor(weight, weightScaleScalar, weightZp); + llvm_unreachable("Unidentified case for weight quantization"); } - weightScale = weightScaleScalar; - - auto outputTy = rewriter.getType( - resultType.getOptionalSizes(), - rewriter.getIntegerType(32, /*issigned=*/true)); // TODO(suderman): insert convolution operator. llvm::SmallVector newOperands = {input, weight}; if (output) newOperands.push_back(output); - outputTy = rewriter.getType( - resultType.getOptionalSizes(), - rewriter.getType()); - llvm::SmallVector newAttributes; newAttributes.push_back( rewriter.getNamedAttr("name", rewriter.getStringAttr("onnx.Conv"))); @@ -420,20 +418,34 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( newAttributes.push_back(namedAttr); } + auto outputTy = rewriter.getType( + resultType.getOptionalSizes(), + rewriter.getType()); + output = rewriter .create(binder.getLoc(), outputTy, newOperands, newAttributes, binder.op->getRegions().size()) .getResult(0); - Value outScale = rewriter.create( - binder.getLoc(), rewriter.getType(), inputScale, - weightScale); - Value outZp = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); - output = rewriter.create( - binder.getLoc(), outputTy, output, outScale, outZp); + Value outScale, outZp; + if (isPerChannelQuantization) { + // outZp = rewriter.create( + binder.getLoc(), weightScaleTy, weightScale, inputScale); + output = rewriter.create( + binder.getLoc(), outputTy, output, outScale, outZp); + } else { + outZp = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + outScale = rewriter.create( + binder.getLoc(), rewriter.getType(), inputScale, + weightScale); + output = rewriter.create( + binder.getLoc(), outputTy, output, outScale, outZp); + } + outputTy = rewriter.getType( resultType.getOptionalSizes(), rewriter.getF32Type()); From 7d2e31ae44604fa7e79c9f95e342bb3646bad6f0 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Mon, 3 Mar 2025 17:07:38 +0530 Subject: [PATCH 06/11] Update the lowering and add test --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 40 
 +++----------
 .../TorchOnnxToTorch/simple_ops_q_to_z.mlir   | 60 +++++++++++++------
 2 files changed, 50 insertions(+), 50 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index ffc21c5b0c88..c3a35e1cd001 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -9,12 +9,9 @@
 #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
 #include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
-#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/ErrorHandling.h"
 
 using namespace mlir;
 using namespace mlir::torch;
@@ -379,16 +376,18 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
       input = makePerTensor(input, inputScale, inputZp);
       // The onnx's QLinearConv op expects per channel quantization only for
       // the weight tensor for axis = 0.
-      bool isPerChannelQuantization = false;
       auto weightTy = dyn_cast<Torch::ValueTensorType>(weight.getType());
       auto weightScaleTy =
           dyn_cast<Torch::ValueTensorType>(weightScale.getType());
       if (!weightTy || !weightScaleTy || !weightTy.hasSizes() ||
           !weightScaleTy.hasSizes())
-        return failure();
+        return rewriter.notifyMatchFailure(
+            binder.op,
+            "Expected weight and weight_scale arguments to have sizes");
       auto weightShape = weightTy.getSizes();
       auto weightScaleShape = weightScaleTy.getSizes();
-      if (weightScaleShape.size() == 0 || llvm::all_of(weightScaleShape, 1)) {
+      if (weightScaleShape.size() == 0 ||
+          llvm::all_of(weightScaleShape, [](int64_t s) { return s == 1; })) {
         weightZp = extract(weightZp);
         weightScale = extract(weightScale);
         weight = makePerTensor(weight, weightScale, weightZp);
@@ -398,9 +397,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         Value axis = rewriter.create<Torch::ConstantIntOp>(
             binder.getLoc(), rewriter.getI64IntegerAttr(0));
         weight = makePerChannel(weight, weightScale, weightZp, axis);
-        isPerChannelQuantization = true;
       } else {
-        llvm_unreachable("Unidentified case for weight quantization");
+        llvm_unreachable("Unidentified case for weight quantization for "
+                         "Onnx.QLinearConv op");
       }
 
       // TODO(suderman): insert convolution operator.
@@ -421,36 +420,12 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
       auto outputTy = rewriter.getType<Torch::ValueTensorType>(
           resultType.getOptionalSizes(),
           rewriter.getType<Torch::QInt32Type>());
-
       output = rewriter
                    .create<Torch::OperatorOp>(binder.getLoc(), outputTy,
                                               newOperands, newAttributes,
                                               binder.op->getRegions().size())
                    .getResult(0);
 
-      Value outScale, outZp;
-      if (isPerChannelQuantization) {
-        // outZp = rewriter.create(
-        binder.getLoc(), weightScaleTy, weightScale, inputScale);
-        output = rewriter.create(
-            binder.getLoc(), outputTy, output, outScale, outZp);
-      } else {
-        outZp = rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getType<Torch::IntType>(),
-            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
-        outScale = rewriter.create<Torch::AtenMulFloatOp>(
-            binder.getLoc(), rewriter.getType<Torch::FloatType>(), inputScale,
-            weightScale);
-        output = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
-            binder.getLoc(), outputTy, output, outScale, outZp);
-      }
-
-      outputTy = rewriter.getType<Torch::ValueTensorType>(
-          resultType.getOptionalSizes(), rewriter.getF32Type());
-
-      output = rewriter.create<Torch::AtenDequantizeSelfOp>(binder.getLoc(),
-                                                            outputTy, output);
       outputTy = getQTorchTypeFromTorchIntType(resultType);
       Value dtyVal = rewriter.create<Torch::ConstantIntOp>(
           binder.getLoc(), rewriter.getType<Torch::IntType>(),
               rewriter.getIntegerType(64),
               static_cast<int64_t>(
                   Torch::getScalarTypeForType(outputTy.getDtype()))));
+
       output = rewriter.create<Torch::AtenQuantizePerTensorOp>(
           binder.getLoc(), outputTy, output, outputScale, outputZp, dtyVal);
       rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op, resultType,
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
index 53147a6bfab9..cd34d49b5cdc 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -65,15 +65,15 @@ func.func @test_quantizelinear_f8(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch.
 // -----
 
 // CHECK-LABEL: @test_qlinearconv_nobias
-func.func @test_qlinearconv_nobias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1,1,1,1],ui8>, %arg4: !torch.vtensor<[],f32>, %arg5: !torch.vtensor<[],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  %0 = torch.operator "onnx.QLinearConv"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!torch.vtensor<[1,1,7,7],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[1,1,1,1],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8>
+func.func @test_qlinearconv_nobias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1,1,1,1],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  %0 = torch.operator "onnx.QLinearConv"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!torch.vtensor<[1,1,7,7],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[1,1,1,1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8>
   // CHECK: %[[aZp:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int
   // CHECK: %[[cZp:.+]] = torch.aten.item %arg7 : !torch.vtensor<[],ui8> -> !torch.int
   // CHECK: %[[aScale:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
   // CHECK: %[[cScale:.+]] = torch.aten.item %arg6 : !torch.vtensor<[],f32> -> !torch.float
   // CHECK: %[[A:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[aScale]], %[[aZp]] : !torch.vtensor<[1,1,7,7],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8>
-  // CHECK: %[[bScale:.+]] = torch.aten.item %arg4 : !torch.vtensor<[],f32> -> !torch.float
-  // CHECK: %[[bZp:.+]] = torch.aten.item %arg5 : !torch.vtensor<[],ui8> -> !torch.int
+  // CHECK: %[[bZp:.+]] = torch.aten.item %arg5 : !torch.vtensor<[1],ui8> -> !torch.int
+  // CHECK: %[[bScale:.+]] = torch.aten.item %arg4 : !torch.vtensor<[1],f32> -> !torch.float
   // CHECK: %[[B:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[bScale]], %[[bZp]] : !torch.vtensor<[1,1,1,1],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,1,1],!torch.quint8>
   // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
   // CHECK: %[[INT0_1:.+]] = torch.constant.int 0
@@ -90,12 +90,8 @@ func.func @test_qlinearconv_nobias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1:
   // CHECK: %[[NONE:.+]] = torch.constant.none
   // CHECK: %[[INT1_5:.+]] = torch.constant.int 1
   // CHECK: %[[CONV:.+]] = torch.aten.convolution %[[A]], %[[B]], %[[NONE]], %[[DILATION]], %[[PAD]], %[[KERNEL]], %[[FALSE]], %[[STRIDE]], %[[INT1_5]] : !torch.vtensor<[1,1,7,7],!torch.quint8>, !torch.vtensor<[1,1,1,1],!torch.quint8>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.qint32>
-  // CHECK: %[[convScale:.+]] =
 torch.aten.mul.float %[[aScale]], %[[bScale]] : !torch.float, !torch.float -> !torch.float
-  // CHECK: %[[INT0_6:.+]] = torch.constant.int 0
-  // CHECK: %[[C:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CONV]], %[[convScale]], %[[INT0_6]] : !torch.vtensor<[1,1,7,7],!torch.qint32>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.qint32>
-  // CHECK: %[[DEQ:.+]] = torch.aten.dequantize.self %[[C]] : !torch.vtensor<[1,1,7,7],!torch.qint32> -> !torch.vtensor<[1,1,7,7],f32>
   // CHECK: %[[INT13:.+]] = torch.constant.int 13
-  // CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %[[DEQ]], %[[cScale]], %[[cZp]], %[[INT13]] : !torch.vtensor<[1,1,7,7],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8>
+  // CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %[[CONV]], %[[cScale]], %[[cZp]], %[[INT13]] : !torch.vtensor<[1,1,7,7],!torch.qint32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8>
   // CHECK: %[[INT:.+]] = torch.aten.int_repr %[[QUANT]] : !torch.vtensor<[1,1,7,7],!torch.quint8> -> !torch.vtensor<[1,1,7,7],ui8>
   // CHECK: return %[[INT]] : !torch.vtensor<[1,1,7,7],ui8>
   return %0 : !torch.vtensor<[1,1,7,7],ui8>
@@ -103,17 +99,17 @@ func.func @test_qlinearconv_nobias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1:
 
 // -----
 
-// CHECK-LABEL: @test_qlinearconv_bias_weight_per_channel
-func.func @test_qlinearconv_bias_weight_per_channel(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1,1,1,1],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>, %arg8 : !torch.vtensor<[7],si32>) -> !torch.vtensor<[1,1,7,7],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+// CHECK-LABEL: @test_qlinearconv_bias
+func.func @test_qlinearconv_bias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1,1,1,1],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>, %arg8 : !torch.vtensor<[7],si32>) -> !torch.vtensor<[1,1,7,7],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   %0 = torch.operator "onnx.QLinearConv"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.vtensor<[1,1,7,7],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[1,1,1,1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[7],si32>) -> !torch.vtensor<[1,1,7,7],ui8>
   // CHECK: %[[aZp:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int
   // CHECK: %[[cZp:.+]] = torch.aten.item %arg7 : !torch.vtensor<[],ui8> -> !torch.int
   // CHECK: %[[aScale:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
   // CHECK: %[[cScale:.+]] = torch.aten.item %arg6 : !torch.vtensor<[],f32> -> !torch.float
   // CHECK: %[[A:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[aScale]], %[[aZp]] : !torch.vtensor<[1,1,7,7],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8>
+  // CHECK: %[[bZp:.+]] = torch.aten.item %arg5 : !torch.vtensor<[1],ui8> -> !torch.int
+  // CHECK: %[[bScale:.+]] = torch.aten.item %arg4 : !torch.vtensor<[1],f32> -> !torch.float
-  // CHECK: %[[INT0:.+]] = torch.constant.int 0
-  // CHECK: %[[B:.+]] = torch.aten._make_per_channel_quantized_tensor %arg3, %arg4, %arg5, %[[INT0]] : !torch.vtensor<[1,1,1,1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.int -> !torch.vtensor<[1,1,1,1],!torch.quint8>
+  // CHECK: %[[B:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[bScale]], %[[bZp]] : !torch.vtensor<[1,1,1,1],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,1,1],!torch.quint8>
   // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
   // CHECK: %[[INT0_1:.+]] = torch.constant.int 0
   // CHECK: %[[PAD:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]]
@@ -128,12 +124,8 @@ func.func @test_qlinearconv_bias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !t
   // CHECK: %[[FALSE:.+]] = torch.constant.bool false
   // CHECK: %[[INT1_5:.+]] = torch.constant.int 1
   // CHECK: %[[CONV:.+]] = torch.aten.convolution %[[A]], %[[B]], %arg8, %[[DILATION]], %[[PAD]], %[[KERNEL]], %[[FALSE]], %[[STRIDE]], %[[INT1_5]] : !torch.vtensor<[1,1,7,7],!torch.quint8>, !torch.vtensor<[1,1,1,1],!torch.quint8>, !torch.vtensor<[7],si32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.qint32>
-  // CHECK: %[[convScale:.+]] = torch.aten.mul.float %[[aScale]], %[[bScale]] : !torch.float, !torch.float -> !torch.float
-  // CHECK: %[[INT0_6:.+]] = torch.constant.int 0
-  // CHECK: %[[C:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CONV]], %[[convScale]], %[[INT0_6]] : !torch.vtensor<[1,1,7,7],!torch.qint32>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.qint32>
-  // CHECK: %[[DEQ:.+]] = torch.aten.dequantize.self %[[C]] : !torch.vtensor<[1,1,7,7],!torch.qint32> -> !torch.vtensor<[1,1,7,7],f32>
   // CHECK: %[[INT13:.+]] = torch.constant.int 13
-  // CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %[[DEQ]], %[[cScale]], %[[cZp]], %[[INT13]] : !torch.vtensor<[1,1,7,7],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8>
+  // CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %[[CONV]], %[[cScale]], %[[cZp]], %[[INT13]] : !torch.vtensor<[1,1,7,7],!torch.qint32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8>
   // CHECK: %[[INT:.+]] = torch.aten.int_repr %[[QUANT]] : !torch.vtensor<[1,1,7,7],!torch.quint8> -> !torch.vtensor<[1,1,7,7],ui8>
   // CHECK: return %[[INT]] : !torch.vtensor<[1,1,7,7],ui8>
   return %0 : !torch.vtensor<[1,1,7,7],ui8>
@@ -141,6 +133,38 @@ func.func @test_qlinearconv_bias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !t
 
 // -----
 
+func.func @test_qlinearconv_weight_per_channel_quantization(%arg0: !torch.vtensor<[?,3,224,224],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[64,3,7,7],si8>, %arg4: !torch.vtensor<[64],f32>, %arg5: !torch.vtensor<[64],si8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>, %arg8 : !torch.vtensor<[64],si32>) -> !torch.vtensor<[?,64,112,112],ui8> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.opset_versions = {ai.onnx.ml = 2 : si64, ai.onnx.preview.training = 1 : si64, ai.onnx.training = 1 : si64, com.microsoft = 1 : si64, com.microsoft.experimental = 1 : si64, com.microsoft.mlfeaturizers = 1 : si64, com.microsoft.nchwc = 1 : si64}, torch.onnx_meta.producer_name = "onnx.quantize", torch.onnx_meta.producer_version = "0.1.0"} {
+  %0 =
torch.operator "onnx.QLinearConv"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) {torch.onnx.auto_pad = "NOTSET", torch.onnx.dilations = [1 : si64, 1 : si64], torch.onnx.group = 1 : si64, torch.onnx.kernel_shape = [7 : si64, 7 : si64], torch.onnx.pads = [3 : si64, 3 : si64, 3 : si64, 3 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[?,3,224,224],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[64,3,7,7],si8>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],si8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[64],si32>) -> !torch.vtensor<[?,64,112,112],ui8> + // CHECK: %[[aZp:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int + // CHECK: %[[cZp:.+]] = torch.aten.item %arg7 : !torch.vtensor<[],ui8> -> !torch.int + // CHECK: %[[aScale:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[cScale:.+]] = torch.aten.item %arg6 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[A:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[aScale]], %[[aZp]] : !torch.vtensor<[?,3,224,224],ui8>, !torch.float, !torch.int -> !torch.vtensor<[?,3,224,224],!torch.quint8> + // CHECK: %[[AXIS:.+]] = torch.constant.int 0 + // CHECK: %[[B:.+]] = torch.aten._make_per_channel_quantized_tensor %arg3, %arg4, %arg5, %[[AXIS]] : !torch.vtensor<[64,3,7,7],si8>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],si8>, !torch.int -> !torch.vtensor<[64,3,7,7],!torch.qint8> + // CHECK: %[[INT3_0:.+]] = torch.constant.int 3 + // CHECK: %[[INT3_1:.+]] = torch.constant.int 3 + // CHECK: %[[PAD:.+]] = torch.prim.ListConstruct %[[INT3_0]], %[[INT3_1]] + // CHECK: %[[INT1_0:.+]] = torch.constant.int 1 + // CHECK: %[[INT1_1:.+]] = torch.constant.int 1 + // CHECK: %[[INT2_0:.+]] = torch.constant.int 2 + // CHECK: %[[INT2_1:.+]] = torch.constant.int 2 + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[KERNEL:.+]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]] + // CHECK: %[[DILATION:.+]] = torch.prim.ListConstruct %[[INT2_0]], %[[INT2_1]] + // CHECK: %[[STRIDE:.+]] = torch.prim.ListConstruct %[[INT0]], %[[INT0]] + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[INT1_2:.+]] = torch.constant.int 1 + // CHECK: %[[CONV:.+]] = torch.aten.convolution %[[A]], %[[B]], %arg8, %[[DILATION]], %[[PAD]], %[[KERNEL]], %[[FALSE]], %[[STRIDE]], %[[INT1_2]] : !torch.vtensor<[?,3,224,224],!torch.quint8>, !torch.vtensor<[64,3,7,7],!torch.qint8>, !torch.vtensor<[64],si32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[?,64,112,112],!torch.qint32> + // CHECK: %[[INT13:.+]] = torch.constant.int 13 + // CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %[[CONV]], %[[cScale]], %[[cZp]], %[[INT13]] : !torch.vtensor<[?,64,112,112],!torch.qint32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[?,64,112,112],!torch.quint8> + // CHECK: %[[INT:.+]] = torch.aten.int_repr %[[QUANT]] : !torch.vtensor<[?,64,112,112],!torch.quint8> -> !torch.vtensor<[?,64,112,112],ui8> + // CHECK: return %[[INT]] : !torch.vtensor<[?,64,112,112],ui8> + return %0 : !torch.vtensor<[?,64,112,112],ui8> +} + +// ----- + // CHECK-LABEL: @test_qlinearmatmul_2D func.func @test_qlinearmatmul_2D(%arg0: !torch.vtensor<[2,4],ui8>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],ui8>, %arg3: !torch.vtensor<[4,3],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[1],f32>, %arg7: !torch.vtensor<[1],ui8>) -> 
 !torch.vtensor<[2,3],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   %0 = torch.operator "onnx.QLinearMatMul"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!torch.vtensor<[2,4],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[4,3],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>) -> !torch.vtensor<[2,3],ui8>

From d81ef2d8047c4d0cb5b75f2ebfa810d074eb2d3d Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Wed, 5 Mar 2025 19:15:14 +0530
Subject: [PATCH 07/11] Handle non-floating point input for quantize_per_tensor

---
 lib/Conversion/TorchToLinalg/Uncategorized.cpp | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index e89056355785..47daea21d0d2 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -1497,6 +1497,13 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Value zp = quant.getZeroPoint();
     auto valueTy = value.getType();
 
+    // The `torch.quantize_per_tensor` op accepts only float tensors as input.
+    // Hence, convert a non-float value to float type first.
+    if (isa<mlir::IntegerType>(valueTy)) {
+      valueTy = b.getF32Type();
+      value = convertScalarToDtype(b, loc, value, valueTy);
+    }
+
     zp = converter->materializeTargetConversion(
         b, loc, converter->convertType(zp.getType()), zp);
     zp = b.create(loc, valueTy, zp);

From eb1ed60d2fb44f27c41109aa1cef17f26011e7c3 Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Thu, 6 Mar 2025 19:20:53 +0530
Subject: [PATCH 08/11] Add QlinearConv as Dequant(Input) + Conv

---
 .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp    | 158 ++++++++++++----
 .../TorchToLinalg/Uncategorized.cpp           |   7 -
 .../TorchOnnxToTorch/simple_ops_q_to_z.mlir   |  81 ++++++--
 3 files changed, 182 insertions(+), 64 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index c3a35e1cd001..affd41330154 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -326,6 +326,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
   patterns.onOp(
       "QLinearConv", 1,
       [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Location loc = binder.getLoc();
         Torch::ValueTensorType resultType;
         llvm::SmallVector<Value> operands;
         if ((binder.tensorOperands(operands, 8) &&
@@ -340,7 +341,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
       Value weightZp = operands[5];
       Value outputScale = operands[6];
       Value outputZp = operands[7];
-      Value output = operands.size() == 9 ? operands[8] : nullptr;
+      Value bias = operands.size() == 9 ? operands[8] : nullptr;
 
       auto extract = [&rewriter, &binder](Value v) {
         auto vTy = cast<Torch::ValueTensorType>(v.getType());
@@ -356,6 +357,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
       outputZp = extract(outputZp);
       inputScale = extract(inputScale);
       outputScale = extract(outputScale);
+
       auto makePerTensor = [&rewriter, &binder](Value v, Value scale,
                                                 Value zp) -> Value {
         auto ty = cast<Torch::ValueTensorType>(v.getType());
@@ -364,28 +366,22 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
             binder.getLoc(), newTy, v, scale, zp);
       };
 
-      auto makePerChannel = [&rewriter, &binder](Value v, Value scale,
-                                                 Value zp,
-                                                 Value axis) -> Value {
-        auto ty = cast<Torch::ValueTensorType>(v.getType());
-        auto newTy = getQTorchTypeFromTorchIntType(ty);
-        return rewriter.create<Torch::Aten_MakePerChannelQuantizedTensorOp>(
-            binder.getLoc(), newTy, v, scale, zp, axis);
-      };
-
-      input = makePerTensor(input, inputScale, inputZp);
-      // The onnx's QLinearConv op expects per channel quantization only for
+      // The onnx's QLinearConv op allows per channel quantization only for
       // the weight tensor for axis = 0.
+      bool isPerChannelQuantization = false;
       auto weightTy = dyn_cast<Torch::ValueTensorType>(weight.getType());
       auto weightScaleTy =
           dyn_cast<Torch::ValueTensorType>(weightScale.getType());
+      auto weightZpTy = dyn_cast<Torch::ValueTensorType>(weightZp.getType());
-      if (!weightTy || !weightScaleTy || !weightTy.hasSizes() ||
-          !weightScaleTy.hasSizes())
+      if (!weightTy || !weightScaleTy || !weightZpTy ||
+          !weightTy.hasSizes() || !weightScaleTy.hasSizes() ||
+          !weightZpTy.hasSizes())
         return rewriter.notifyMatchFailure(
-            binder.op,
-            "Expected weight and weight_scale arguments to have sizes");
+            binder.op, "Expected weight, weight_scale, and weight_zero_point "
+                       "arguments to have sizes");
-      auto weightShape = weightTy.getSizes();
-      auto weightScaleShape = weightScaleTy.getSizes();
+      ArrayRef<int64_t> weightShape(weightTy.getSizes());
+      SmallVector<int64_t> weightScaleShape(weightScaleTy.getSizes());
+      SmallVector<int64_t> weightZpShape(weightZpTy.getSizes());
       if (weightScaleShape.size() == 0 ||
           llvm::all_of(weightScaleShape, [](int64_t s) { return s == 1; })) {
         weightZp = extract(weightZp);
         weightScale = extract(weightScale);
         weight = makePerTensor(weight, weightScale, weightZp);
       } else if (weightScaleShape.size() == 1 &&
                  weightScaleShape[0] != Torch::kUnknownSize &&
                  weightScaleShape[0] == weightShape[0]) {
+        // Since the convolution operation in the downstream ("Linalg")
+        // pipeline does not support per-channel quantization, for this
+        // particular case we perform the convolution over the dequantized
+        // input and weight instead of relying on the downstream pipeline to
+        // handle it. This code can be removed and made similar to the other
+        // paths in this lowering once per-channel quantization support is
+        // added in the downstream pipeline.
+        isPerChannelQuantization = true;
+
+        auto inputTy = dyn_cast<Torch::ValueTensorType>(input.getType());
+        if (!inputTy || !inputTy.hasSizes())
+          return rewriter.notifyMatchFailure(
+              binder.op, "Expected input argument to have sizes");
+
+        // Dequantizing the input
+        // input = input.to(dtype=torch.float32)
+        // input_dequant = (input - input_zero_point) * input_scale
+
+        // Converting the input tensor to float32 type.
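+        // (Note: the torch ScalarType integer code for float32 is 6, which
+        // is what the /*float32Type*/ constant below encodes.)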
+        Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
+        Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+        Value float32Type = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(/*float32Type*/ 6));
+        Type f32InputType = rewriter.getType<Torch::ValueTensorType>(
+            inputTy.getSizes(), rewriter.getF32Type());
+        input = rewriter.create<Torch::AtenToDtypeOp>(
+            loc, f32InputType, input, float32Type,
+            /*non_blocking=*/cstFalse,
+            /*copy=*/cstFalse,
+            /*memory_format=*/none);
+
+        Value cstOne = rewriter.create<Torch::ConstantFloatOp>(
+            loc, rewriter.getF64FloatAttr(1.0));
+        input = rewriter.create<Torch::AtenSubScalarOp>(
+            loc, f32InputType, input, inputZp, cstOne);
+        input = rewriter.create<Torch::AtenMulScalarOp>(loc, f32InputType,
+                                                        input, inputScale);
+
+        // Dequantizing the weight
+        // Shapes of the inputs are as follows:
+        //   weight = (M x C/group x k1 x k2 x … x kn)
+        //   weight_scale = (M)
+        //   weight_zero_point = (M)
+        //
+        // We unsqueeze the weight_scale and weight_zero_point to match the
+        // rank of weight. After unsqueeze:
+        //   weight_scale = (M, 1, 1, ..., 1)
+        //   weight_zero_point = (M, 1, 1, ..., 1)
+        //
+        // Then, we compute the dequantized weight:
+        //   weight = weight.to(dtype=torch.float32)
+        //   weight_dequant = (weight - weight_zero_point) * weight_scale
+        int64_t diffRank = weightShape.size() - weightScaleShape.size();
+        for (int i = 1; i <= diffRank; i++) {
+          Value cstDim = rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getI64IntegerAttr(i));
+
+          weightScaleShape.push_back(1);
+          Type weightScaleUnsqueezeType = weightScaleTy.getWithSizesAndDtype(
+              weightScaleShape, weightScaleTy.getOptionalDtype());
+          weightScale = rewriter.create<Torch::AtenUnsqueezeOp>(
+              loc, weightScaleUnsqueezeType, weightScale, cstDim);
+
+          weightZpShape.push_back(1);
+          Type weightZpUnsqueezeType = weightZpTy.getWithSizesAndDtype(
+              weightZpShape, weightZpTy.getOptionalDtype());
+          weightZp = rewriter.create<Torch::AtenUnsqueezeOp>(
+              loc, weightZpUnsqueezeType, weightZp, cstDim);
+        }
+
+        // Converting the weight tensor to float32 type.
+        Type f32WeightType = rewriter.getType<Torch::ValueTensorType>(
+            weightShape, rewriter.getF32Type());
+        weight = rewriter.create<Torch::AtenToDtypeOp>(
+            loc, f32WeightType, weight, float32Type,
+            /*non_blocking=*/cstFalse,
+            /*copy=*/cstFalse,
+            /*memory_format=*/none);
+
+        weight = rewriter.create<Torch::AtenSubTensorOp>(
+            loc, f32WeightType, weight, weightZp, cstOne);
+        weight = rewriter.create<Torch::AtenMulTensorOp>(loc, f32WeightType,
+                                                         weight, weightScale);
+
+        // Converting the bias tensor to float32 type.
+        if (bias) {
+          auto biasTy = dyn_cast<Torch::ValueTensorType>(bias.getType());
+          if (!biasTy || !biasTy.hasSizes())
+            return rewriter.notifyMatchFailure(
+                binder.op, "Expected bias argument to have sizes");
+          Type f32BiasType = rewriter.getType<Torch::ValueTensorType>(
+              biasTy.getSizes(), rewriter.getF32Type());
+          bias = rewriter.create<Torch::AtenToDtypeOp>(
+              loc, f32BiasType, bias, float32Type,
+              /*non_blocking=*/cstFalse,
+              /*copy=*/cstFalse,
+              /*memory_format=*/none);
+        }
+
       } else {
         llvm_unreachable("Unidentified case for weight quantization for "
                          "Onnx.QLinearConv op");
       }
 
+      if (!isPerChannelQuantization)
+        input = makePerTensor(input, inputScale, inputZp);
+
       // TODO(suderman): insert convolution operator.
       llvm::SmallVector<Value> newOperands = {input, weight};
-      if (output)
-        newOperands.push_back(output);
+      if (bias)
+        newOperands.push_back(bias);
 
       llvm::SmallVector<NamedAttribute> newAttributes;
       newAttributes.push_back(
           rewriter.getNamedAttr("name", rewriter.getStringAttr("onnx.Conv")));
@@ -420,36 +420,12 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         newAttributes.push_back(namedAttr);
       }
 
+      Type convDtype =
+          isPerChannelQuantization
+              ?
 cast<Type>(rewriter.getF32Type())
+              : cast<Type>(rewriter.getType<Torch::QInt32Type>());
       auto outputTy = rewriter.getType<Torch::ValueTensorType>(
          resultType.getOptionalSizes(), convDtype);
-      output = rewriter
-                   .create<Torch::OperatorOp>(binder.getLoc(), outputTy,
-                                              newOperands, newAttributes,
-                                              binder.op->getRegions().size())
-                   .getResult(0);
+      Value output = rewriter
+                         .create<Torch::OperatorOp>(
+                             binder.getLoc(), outputTy, newOperands,
+                             newAttributes, binder.op->getRegions().size())
+                         .getResult(0);
 
       outputTy = getQTorchTypeFromTorchIntType(resultType);
       Value dtyVal = rewriter.create<Torch::ConstantIntOp>(
           binder.getLoc(), rewriter.getType<Torch::IntType>(),
               rewriter.getIntegerType(64),
               static_cast<int64_t>(
                   Torch::getScalarTypeForType(outputTy.getDtype()))));
+
       output = rewriter.create<Torch::AtenQuantizePerTensorOp>(
           binder.getLoc(), outputTy, output, outputScale, outputZp, dtyVal);
       rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op, resultType,
diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index 47daea21d0d2..e89056355785 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -1497,13 +1497,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Value zp = quant.getZeroPoint();
     auto valueTy = value.getType();
 
-    // The `torch.quantize_per_tensor` op accepts only float tensors as input.
-    // Hence, convert a non-float value to float type first.
-    if (isa<mlir::IntegerType>(valueTy)) {
-      valueTy = b.getF32Type();
-      value = convertScalarToDtype(b, loc, value, valueTy);
-    }
-
     zp = converter->materializeTargetConversion(
         b, loc, converter->convertType(zp.getType()), zp);
     zp = b.create(loc, valueTy, zp);
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
index cd34d49b5cdc..bcbb0d357e4a 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -71,10 +71,10 @@ func.func @test_qlinearconv_nobias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1:
   // CHECK: %[[cZp:.+]] = torch.aten.item %arg7 : !torch.vtensor<[],ui8> -> !torch.int
   // CHECK: %[[aScale:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
   // CHECK: %[[cScale:.+]] = torch.aten.item %arg6 : !torch.vtensor<[],f32> -> !torch.float
-  // CHECK: %[[A:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[aScale]], %[[aZp]] : !torch.vtensor<[1,1,7,7],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8>
   // CHECK: %[[bZp:.+]] = torch.aten.item %arg5 : !torch.vtensor<[1],ui8> -> !torch.int
   // CHECK: %[[bScale:.+]] = torch.aten.item %arg4 : !torch.vtensor<[1],f32> -> !torch.float
   // CHECK: %[[B:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[bScale]], %[[bZp]] : !torch.vtensor<[1,1,1,1],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,1,1],!torch.quint8>
+  // CHECK: %[[A:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[aScale]], %[[aZp]] : !torch.vtensor<[1,1,7,7],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8>
   // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
   // CHECK: %[[INT0_1:.+]] = torch.constant.int 0
@@ -106,10 +106,10 @@ func.func @test_qlinearconv_bias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !t
   // CHECK: %[[cZp:.+]] = torch.aten.item %arg7 : !torch.vtensor<[],ui8> -> !torch.int
   // CHECK: %[[aScale:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
   // CHECK: %[[cScale:.+]] = torch.aten.item %arg6 : !torch.vtensor<[],f32> -> !torch.float
-  // CHECK: %[[A:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[aScale]], %[[aZp]] : !torch.vtensor<[1,1,7,7],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8>
   // CHECK: %[[bZp:.+]] = torch.aten.item %arg5 :
 !torch.vtensor<[1],ui8> -> !torch.int
   // CHECK: %[[bScale:.+]] = torch.aten.item %arg4 : !torch.vtensor<[1],f32> -> !torch.float
   // CHECK: %[[B:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[bScale]], %[[bZp]] : !torch.vtensor<[1,1,1,1],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,1,1],!torch.quint8>
+  // CHECK: %[[A:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[aScale]], %[[aZp]] : !torch.vtensor<[1,1,7,7],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8>
   // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
   // CHECK: %[[INT0_1:.+]] = torch.constant.int 0
   // CHECK: %[[PAD:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]]
@@ -133,33 +133,60 @@ func.func @test_qlinearconv_bias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !t
 
 // -----
 
+// CHECK-LABEL: func.func @test_qlinearconv_weight_per_channel_quantization(
+// CHECK-SAME: %[[INPUT:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[?,3,224,224],ui8>,
+// CHECK-SAME: %[[IN_SCALE:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],f32>,
+// CHECK-SAME: %[[IN_ZP:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],ui8>,
+// CHECK-SAME: %[[W:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[64,3,7,7],si8>,
+// CHECK-SAME: %[[W_SCALE:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[64],f32>,
+// CHECK-SAME: %[[W_ZP:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[64],si8>,
+// CHECK-SAME: %[[OUT_SCALE:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],f32>,
+// CHECK-SAME: %[[OUT_ZP:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],ui8>,
+// CHECK-SAME: %[[BIAS:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[64],si32>) -> !torch.vtensor<[?,64,112,112],ui8> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.opset_versions = {ai.onnx.ml = 2 : si64, ai.onnx.preview.training = 1 : si64, ai.onnx.training = 1 : si64, com.microsoft = 1 : si64, com.microsoft.experimental = 1 : si64, com.microsoft.mlfeaturizers = 1 : si64, com.microsoft.nchwc = 1 : si64}, torch.onnx_meta.producer_name = "onnx.quantize", torch.onnx_meta.producer_version = "0.1.0"} {
 func.func @test_qlinearconv_weight_per_channel_quantization(%arg0: !torch.vtensor<[?,3,224,224],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[64,3,7,7],si8>, %arg4: !torch.vtensor<[64],f32>, %arg5: !torch.vtensor<[64],si8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>, %arg8 : !torch.vtensor<[64],si32>) -> !torch.vtensor<[?,64,112,112],ui8> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.opset_versions = {ai.onnx.ml = 2 : si64, ai.onnx.preview.training = 1 : si64, ai.onnx.training = 1 : si64, com.microsoft = 1 : si64, com.microsoft.experimental = 1 : si64, com.microsoft.mlfeaturizers = 1 : si64, com.microsoft.nchwc = 1 : si64}, torch.onnx_meta.producer_name = "onnx.quantize", torch.onnx_meta.producer_version = "0.1.0"} {
   %0 = torch.operator "onnx.QLinearConv"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) {torch.onnx.auto_pad = "NOTSET", torch.onnx.dilations = [1 : si64, 1 : si64], torch.onnx.group = 1 : si64, torch.onnx.kernel_shape = [7 : si64, 7 : si64], torch.onnx.pads = [3 : si64, 3 : si64, 3 : si64, 3 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[?,3,224,224],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[64,3,7,7],si8>, !torch.vtensor<[64],f32>,
!torch.vtensor<[64],si8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[64],si32>) -> !torch.vtensor<[?,64,112,112],ui8> - // CHECK: %[[aZp:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int - // CHECK: %[[cZp:.+]] = torch.aten.item %arg7 : !torch.vtensor<[],ui8> -> !torch.int - // CHECK: %[[aScale:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float - // CHECK: %[[cScale:.+]] = torch.aten.item %arg6 : !torch.vtensor<[],f32> -> !torch.float - // CHECK: %[[A:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[aScale]], %[[aZp]] : !torch.vtensor<[?,3,224,224],ui8>, !torch.float, !torch.int -> !torch.vtensor<[?,3,224,224],!torch.quint8> - // CHECK: %[[AXIS:.+]] = torch.constant.int 0 - // CHECK: %[[B:.+]] = torch.aten._make_per_channel_quantized_tensor %arg3, %arg4, %arg5, %[[AXIS]] : !torch.vtensor<[64,3,7,7],si8>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],si8>, !torch.int -> !torch.vtensor<[64,3,7,7],!torch.qint8> - // CHECK: %[[INT3_0:.+]] = torch.constant.int 3 - // CHECK: %[[INT3_1:.+]] = torch.constant.int 3 - // CHECK: %[[PAD:.+]] = torch.prim.ListConstruct %[[INT3_0]], %[[INT3_1]] - // CHECK: %[[INT1_0:.+]] = torch.constant.int 1 - // CHECK: %[[INT1_1:.+]] = torch.constant.int 1 - // CHECK: %[[INT2_0:.+]] = torch.constant.int 2 - // CHECK: %[[INT2_1:.+]] = torch.constant.int 2 - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[KERNEL:.+]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]] - // CHECK: %[[DILATION:.+]] = torch.prim.ListConstruct %[[INT2_0]], %[[INT2_1]] - // CHECK: %[[STRIDE:.+]] = torch.prim.ListConstruct %[[INT0]], %[[INT0]] - // CHECK: %[[FALSE:.+]] = torch.constant.bool false - // CHECK: %[[INT1_2:.+]] = torch.constant.int 1 - // CHECK: %[[CONV:.+]] = torch.aten.convolution %[[A]], %[[B]], %arg8, %[[DILATION]], %[[PAD]], %[[KERNEL]], %[[FALSE]], %[[STRIDE]], %[[INT1_2]] : !torch.vtensor<[?,3,224,224],!torch.quint8>, !torch.vtensor<[64,3,7,7],!torch.qint8>, !torch.vtensor<[64],si32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[?,64,112,112],!torch.qint32> - // CHECK: %[[INT13:.+]] = torch.constant.int 13 - // CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %[[CONV]], %[[cScale]], %[[cZp]], %[[INT13]] : !torch.vtensor<[?,64,112,112],!torch.qint32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[?,64,112,112],!torch.quint8> - // CHECK: %[[INT:.+]] = torch.aten.int_repr %[[QUANT]] : !torch.vtensor<[?,64,112,112],!torch.quint8> -> !torch.vtensor<[?,64,112,112],ui8> - // CHECK: return %[[INT]] : !torch.vtensor<[?,64,112,112],ui8> + // CHECK: %[[INPUT_ZP:.*]] = torch.aten.item %[[IN_ZP]] : !torch.vtensor<[],ui8> -> !torch.int + // CHECK: %[[OUTPUT_ZP:.*]] = torch.aten.item %[[OUT_ZP]] : !torch.vtensor<[],ui8> -> !torch.int + // CHECK: %[[INPUT_SCALE:.*]] = torch.aten.item %[[IN_SCALE]] : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[OUTPUT_SCALE:.*]] = torch.aten.item %[[OUT_SCALE]] : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[F32DTYPE:.*]] = torch.constant.int 6 + // CHECK: %[[F32_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[F32DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,3,224,224],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,3,224,224],f32> + // CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00 + // CHECK: %[[VAL_18:.*]] = torch.aten.sub.Scalar 
%[[F32_INPUT]], %[[INPUT_ZP]], %[[ALPHA]] : !torch.vtensor<[?,3,224,224],f32>, !torch.int, !torch.float -> !torch.vtensor<[?,3,224,224],f32> + // CHECK: %[[DEQUANT_INPUT:.*]] = torch.aten.mul.Scalar %[[VAL_18]], %[[INPUT_SCALE]] : !torch.vtensor<[?,3,224,224],f32>, !torch.float -> !torch.vtensor<[?,3,224,224],f32> + // CHECK: %[[VAL_20:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_21:.*]] = torch.aten.unsqueeze %[[W_SCALE]], %[[VAL_20]] : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64,1],f32> + // CHECK: %[[VAL_22:.*]] = torch.aten.unsqueeze %[[W_ZP]], %[[VAL_20]] : !torch.vtensor<[64],si8>, !torch.int -> !torch.vtensor<[64,1],si8> + // CHECK: %[[VAL_23:.*]] = torch.constant.int 2 + // CHECK: %[[VAL_24:.*]] = torch.aten.unsqueeze %[[VAL_21]], %[[VAL_23]] : !torch.vtensor<[64,1],f32>, !torch.int -> !torch.vtensor<[64,1,1],f32> + // CHECK: %[[VAL_25:.*]] = torch.aten.unsqueeze %[[VAL_22]], %[[VAL_23]] : !torch.vtensor<[64,1],si8>, !torch.int -> !torch.vtensor<[64,1,1],si8> + // CHECK: %[[VAL_26:.*]] = torch.constant.int 3 + // CHECK: %[[WEIGHT_SCALE:.*]] = torch.aten.unsqueeze %[[VAL_24]], %[[VAL_26]] : !torch.vtensor<[64,1,1],f32>, !torch.int -> !torch.vtensor<[64,1,1,1],f32> + // CHECK: %[[WEIGHT_ZP:.*]] = torch.aten.unsqueeze %[[VAL_25]], %[[VAL_26]] : !torch.vtensor<[64,1,1],si8>, !torch.int -> !torch.vtensor<[64,1,1,1],si8> + // CHECK: %[[F32_WEIGHT:.*]] = torch.aten.to.dtype %[[W]], %[[F32DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[64,3,7,7],si8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[64,3,7,7],f32> + // CHECK: %[[VAL_30:.*]] = torch.aten.sub.Tensor %[[F32_WEIGHT]], %[[WEIGHT_ZP]], %[[ALPHA]] : !torch.vtensor<[64,3,7,7],f32>, !torch.vtensor<[64,1,1,1],si8>, !torch.float -> !torch.vtensor<[64,3,7,7],f32> + // CHECK: %[[DEQUANT_WEIGHT:.*]] = torch.aten.mul.Tensor %[[VAL_30]], %[[WEIGHT_SCALE]] : !torch.vtensor<[64,3,7,7],f32>, !torch.vtensor<[64,1,1,1],f32> -> !torch.vtensor<[64,3,7,7],f32> + // CHECK: %[[F32_BIAS:.*]] = torch.aten.to.dtype %[[BIAS]], %[[F32DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[64],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[64],f32> + // CHECK: %[[VAL_33:.*]] = torch.constant.int 3 + // CHECK: %[[VAL_34:.*]] = torch.constant.int 3 + // CHECK: %[[PAD:.*]] = torch.prim.ListConstruct %[[VAL_33]], %[[VAL_34]] : (!torch.int, !torch.int) -> !torch.list<int> + // CHECK: %[[VAL_36:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_37:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_38:.*]] = torch.constant.int 2 + // CHECK: %[[VAL_39:.*]] = torch.constant.int 2 + // CHECK: %[[VAL_40:.*]] = torch.constant.int 0 + // CHECK: %[[KERNEL:.*]] = torch.prim.ListConstruct %[[VAL_36]], %[[VAL_37]] : (!torch.int, !torch.int) -> !torch.list<int> + // CHECK: %[[DILATION:.*]] = torch.prim.ListConstruct %[[VAL_38]], %[[VAL_39]] : (!torch.int, !torch.int) -> !torch.list<int> + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[VAL_40]], %[[VAL_40]] : (!torch.int, !torch.int) -> !torch.list<int> + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool false + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: %[[CONV:.*]] = torch.aten.convolution %[[DEQUANT_INPUT]], %[[DEQUANT_WEIGHT]], %[[F32_BIAS]], %[[DILATION]], %[[PAD]], %[[KERNEL]], %[[TRANSPOSED]], %[[STRIDE]], %[[GROUPS]] : !torch.vtensor<[?,3,224,224],f32>, !torch.vtensor<[64,3,7,7],f32>, !torch.vtensor<[64],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[?,64,112,112],f32> + // CHECK: %[[DTYPE:.*]] = torch.constant.int 13 + // CHECK: %[[QUANT:.*]] = torch.aten.quantize_per_tensor %[[CONV]], %[[OUTPUT_SCALE]], %[[OUTPUT_ZP]], %[[DTYPE]] : !torch.vtensor<[?,64,112,112],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[?,64,112,112],!torch.quint8> + // CHECK: %[[OUTPUT:.*]] = torch.aten.int_repr %[[QUANT]] : !torch.vtensor<[?,64,112,112],!torch.quint8> -> !torch.vtensor<[?,64,112,112],ui8> + // CHECK: return %[[OUTPUT]] : !torch.vtensor<[?,64,112,112],ui8> return %0 : !torch.vtensor<[?,64,112,112],ui8> }
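The CHECK lines above pin down the per-channel weight path: the [64] scale and zero-point vectors are unsqueezed to [64,1,1,1] so they broadcast along the output-channel axis (axis 0 for QLinearConv weights), and the weight is dequantized elementwise as (w - zp) * scale before the float convolution. A standalone sketch of that broadcast arithmetic for an [OC,IC,KH,KW] weight, written in plain C++ for illustration (the helper name and flat layout are assumptions, not code from the patch):

#include <cstdint>
#include <vector>

// Dequantize a flat [OC,IC,KH,KW] weight per output channel (axis = 0):
// each channel c owns one scale/zero point, broadcast over the remaining
// innerSize = IC*KH*KW elements -- the same effect as unsqueezing the [OC]
// vectors to [OC,1,1,1] in the IR above.
static std::vector<float>
dequantizeWeightPerChannel(const std::vector<int8_t> &w,
                           const std::vector<float> &scale,
                           const std::vector<int8_t> &zp, int64_t oc,
                           int64_t innerSize) {
  std::vector<float> out(w.size());
  for (int64_t c = 0; c < oc; ++c)
    for (int64_t i = 0; i < innerSize; ++i) {
      int64_t idx = c * innerSize + i;
      out[idx] = static_cast<float>(w[idx] - zp[c]) * scale[c];
    }
  return out;
}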
From 077995a22f927be0b20e0de681d8acfef702909b Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Thu, 6 Mar 2025 19:28:07 +0530 Subject: [PATCH 09/11] Revert some changes to original state --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 17 ++++++++++++++++- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 12 ++++++++++-- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index affd41330154..c3ba8fac6d99 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -497,7 +497,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( if (!isPerChannelQuantization) input = makePerTensor(input, inputScale, inputZp); - // TODO(suderman): insert convolution operator. llvm::SmallVector<Value> newOperands = {input, weight}; if (bias) newOperands.push_back(bias); @@ -524,6 +523,22 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( newAttributes, binder.op->getRegions().size()) .getResult(0); + if (!isPerChannelQuantization) { + Value outScale = rewriter.create<Torch::AtenMulFloatOp>( + binder.getLoc(), rewriter.getType<Torch::FloatType>(), inputScale, + weightScale); + Value outZp = rewriter.create<Torch::ConstantIntOp>( + binder.getLoc(), rewriter.getType<Torch::IntType>(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + output = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>( + binder.getLoc(), outputTy, output, outScale, outZp); + outputTy = rewriter.getType<Torch::ValueTensorType>( + resultType.getOptionalSizes(), rewriter.getF32Type()); + + output = rewriter.create<Torch::AtenDequantizeSelfOp>( + binder.getLoc(), outputTy, output); + } + outputTy = getQTorchTypeFromTorchIntType(resultType); Value dtyVal = rewriter.create<Torch::ConstantIntOp>( binder.getLoc(), rewriter.getType<Torch::IntType>(),
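The block added above restores the per-tensor epilogue: the int32 accumulator coming out of the quantized convolution has an effective scale of inputScale * weightScale and zero point 0, so the code tags it with exactly those parameters, dequantizes to f32, and leaves the final quantize_per_tensor to map the result onto the requested output scale and zero point. A numeric sketch of that requantization chain with made-up values (plain C++ for illustration, not code from the patch):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative per-tensor parameters; not taken from the patch or tests.
  float inputScale = 0.02f, weightScale = 0.5f;
  float outputScale = 0.1f;
  int32_t outputZp = 128;
  int32_t acc = 3000; // one int32 accumulator element from the conv

  // The accumulator's effective scale is the product of the operand scales,
  // with zero point 0 -- what _make_per_tensor_quantized_tensor encodes.
  float convScale = inputScale * weightScale;
  float real = static_cast<float>(acc) * convScale; // dequantize.self

  // Re-quantize to the requested ui8 output parameters.
  int32_t q = static_cast<int32_t>(std::lround(real / outputScale)) + outputZp;
  q = std::clamp(q, 0, 255);
  std::printf("real = %f, q = %d\n", real, q);
  return 0;
}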
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index bcbb0d357e4a..29d03c790fce 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -90,8 +90,12 @@ func.func @test_qlinearconv_nobias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: // CHECK: %[[NONE:.+]] = torch.constant.none // CHECK: %[[INT1_5:.+]] = torch.constant.int 1 // CHECK: %[[CONV:.+]] = torch.aten.convolution %[[A]], %[[B]], %[[NONE]], %[[DILATION]], %[[PAD]], %[[KERNEL]], %[[FALSE]], %[[STRIDE]], %[[INT1_5]] : !torch.vtensor<[1,1,7,7],!torch.quint8>, !torch.vtensor<[1,1,1,1],!torch.quint8>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.qint32> + // CHECK: %[[convScale:.+]] = torch.aten.mul.float %[[aScale]], %[[bScale]] : !torch.float, !torch.float -> !torch.float + // CHECK: %[[INT0_6:.+]] = torch.constant.int 0 + // CHECK: %[[C:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CONV]], %[[convScale]], %[[INT0_6]] : !torch.vtensor<[1,1,7,7],!torch.qint32>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.qint32> + // CHECK: %[[DEQ:.+]] = torch.aten.dequantize.self %[[C]] : !torch.vtensor<[1,1,7,7],!torch.qint32> -> !torch.vtensor<[1,1,7,7],f32> // CHECK: %[[INT13:.+]] = torch.constant.int 13 - // CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %[[CONV]], %[[cScale]], %[[cZp]], %[[INT13]] : !torch.vtensor<[1,1,7,7],!torch.qint32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8> + // CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %[[DEQ]], %[[cScale]], %[[cZp]], %[[INT13]] : !torch.vtensor<[1,1,7,7],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8> // CHECK: %[[INT:.+]] = torch.aten.int_repr %[[QUANT]] : !torch.vtensor<[1,1,7,7],!torch.quint8> -> !torch.vtensor<[1,1,7,7],ui8> // CHECK: return %[[INT]] : !torch.vtensor<[1,1,7,7],ui8> return %0 : !torch.vtensor<[1,1,7,7],ui8> @@ -124,8 +128,12 @@ func.func @test_qlinearconv_bias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !t // CHECK: %[[FALSE:.+]] = torch.constant.bool false // CHECK: %[[INT1_5:.+]] = torch.constant.int 1 // CHECK: %[[CONV:.+]] = torch.aten.convolution %[[A]], %[[B]], %arg8, %[[DILATION]], %[[PAD]], %[[KERNEL]], %[[FALSE]], %[[STRIDE]], %[[INT1_5]] : !torch.vtensor<[1,1,7,7],!torch.quint8>, !torch.vtensor<[1,1,1,1],!torch.quint8>, !torch.vtensor<[7],si32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.qint32> + // CHECK: %[[convScale:.+]] = torch.aten.mul.float %[[aScale]], %[[bScale]] : !torch.float, !torch.float -> !torch.float + // CHECK: %[[INT0_6:.+]] = torch.constant.int 0 + // CHECK: %[[C:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CONV]], %[[convScale]], %[[INT0_6]] : !torch.vtensor<[1,1,7,7],!torch.qint32>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.qint32> + // CHECK: %[[DEQ:.+]] = torch.aten.dequantize.self %[[C]] : !torch.vtensor<[1,1,7,7],!torch.qint32> -> !torch.vtensor<[1,1,7,7],f32> // CHECK: %[[INT13:.+]] = torch.constant.int 13 - // CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %[[CONV]], %[[cScale]], %[[cZp]], %[[INT13]] : !torch.vtensor<[1,1,7,7],!torch.qint32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8> + // CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %[[DEQ]], %[[cScale]], %[[cZp]], %[[INT13]] : !torch.vtensor<[1,1,7,7],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8> // CHECK: %[[INT:.+]] = torch.aten.int_repr %[[QUANT]] : !torch.vtensor<[1,1,7,7],!torch.quint8> -> !torch.vtensor<[1,1,7,7],ui8> // CHECK: return %[[INT]] : !torch.vtensor<[1,1,7,7],ui8> return %0 : !torch.vtensor<[1,1,7,7],ui8> From f51e83d43ee29f5df00539066afdd54fbbd36a93 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Thu, 6 Mar 2025 19:42:17 +0530 Subject: [PATCH 10/11] Fix typo --- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index c3ba8fac6d99..39613830f72c 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -390,8 +390,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( } else if (weightScaleShape.size() == 1 && weightScaleShape[0] != Torch::kUnknownSize && weightScaleShape[0] == weightShape[0]) { - // Since the convolution opertaion in the downstream pipeline - // ("Linalg") does not support the per channel quantization, hence 
for + // Since the convolution operation in the downstream pipeline + // ("Linalg") does not support the per-channel quantization, hence for // this particular case we perform the convolution over the // dequantized input and weight instead of relying on the downstream // pipeline to handle this. This code can be removed and made similar From 3d7924d43b0d22b907505014f87e6de4b5fb17e9 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Mon, 10 Mar 2025 11:33:24 +0530 Subject: [PATCH 11/11] Simplify test --- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 45 +++++++------------ 1 file changed, 16 insertions(+), 29 deletions(-) diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 29d03c790fce..b98c10792ecc 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -141,42 +141,29 @@ func.func @test_qlinearconv_bias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !t // ----- -// CHECK-LABEL: func.func @test_qlinearconv_weight_per_channel_quantization( -// CHECK-SAME: %[[INPUT:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[?,3,224,224],ui8>, -// CHECK-SAME: %[[IN_SCALE:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],f32>, -// CHECK-SAME: %[[IN_ZP:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],ui8>, -// CHECK-SAME: %[[W:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[64,3,7,7],si8>, -// CHECK-SAME: %[[W_SCALE:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[64],f32>, -// CHECK-SAME: %[[W_ZP:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[64],si8>, -// CHECK-SAME: %[[OUT_SCALE:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],f32>, -// CHECK-SAME: %[[OUT_ZP:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],ui8>, -// CHECK-SAME: %[[BIAS:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[64],si32>) -> !torch.vtensor<[?,64,112,112],ui8> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.opset_versions = {ai.onnx.ml = 2 : si64, ai.onnx.preview.training = 1 : si64, ai.onnx.training = 1 : si64, com.microsoft = 1 : si64, com.microsoft.experimental = 1 : si64, com.microsoft.mlfeaturizers = 1 : si64, com.microsoft.nchwc = 1 : si64}, torch.onnx_meta.producer_name = "onnx.quantize", torch.onnx_meta.producer_version = "0.1.0"} { +// CHECK-LABEL: func.func @test_qlinearconv_weight_per_channel_quantization func.func @test_qlinearconv_weight_per_channel_quantization(%arg0: !torch.vtensor<[?,3,224,224],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[64,3,7,7],si8>, %arg4: !torch.vtensor<[64],f32>, %arg5: !torch.vtensor<[64],si8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>, %arg8 : !torch.vtensor<[64],si32>) -> !torch.vtensor<[?,64,112,112],ui8> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.opset_versions = {ai.onnx.ml = 2 : si64, ai.onnx.preview.training = 1 : si64, ai.onnx.training = 1 : si64, com.microsoft = 1 : si64, com.microsoft.experimental = 1 : si64, com.microsoft.mlfeaturizers = 1 : si64, com.microsoft.nchwc = 1 : si64}, torch.onnx_meta.producer_name = "onnx.quantize", torch.onnx_meta.producer_version = "0.1.0"} { %0 = torch.operator "onnx.QLinearConv"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) {torch.onnx.auto_pad = "NOTSET", torch.onnx.dilations = [1 : si64, 1 : si64], torch.onnx.group = 1 
: si64, torch.onnx.kernel_shape = [7 : si64, 7 : si64], torch.onnx.pads = [3 : si64, 3 : si64, 3 : si64, 3 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[?,3,224,224],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[64,3,7,7],si8>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],si8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[64],si32>) -> !torch.vtensor<[?,64,112,112],ui8> - // CHECK: %[[INPUT_ZP:.*]] = torch.aten.item %[[IN_ZP]] : !torch.vtensor<[],ui8> -> !torch.int - // CHECK: %[[OUTPUT_ZP:.*]] = torch.aten.item %[[OUT_ZP]] : !torch.vtensor<[],ui8> -> !torch.int - // CHECK: %[[INPUT_SCALE:.*]] = torch.aten.item %[[IN_SCALE]] : !torch.vtensor<[],f32> -> !torch.float - // CHECK: %[[OUTPUT_SCALE:.*]] = torch.aten.item %[[OUT_SCALE]] : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[INPUT_ZP:.*]] = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int + // CHECK: %[[OUTPUT_ZP:.*]] = torch.aten.item %arg7 : !torch.vtensor<[],ui8> -> !torch.int + // CHECK: %[[INPUT_SCALE:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[OUTPUT_SCALE:.*]] = torch.aten.item %arg6 : !torch.vtensor<[],f32> -> !torch.float // CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[FALSE:.*]] = torch.constant.bool false // CHECK: %[[F32DTYPE:.*]] = torch.constant.int 6 - // CHECK: %[[F32_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[F32DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,3,224,224],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,3,224,224],f32> - // CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00 - // CHECK: %[[VAL_18:.*]] = torch.aten.sub.Scalar %[[F32_INPUT]], %[[INPUT_ZP]], %[[ALPHA]] : !torch.vtensor<[?,3,224,224],f32>, !torch.int, !torch.float -> !torch.vtensor<[?,3,224,224],f32> + // CHECK: %[[F32_INPUT:.*]] = torch.aten.to.dtype %arg0, %[[F32DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,3,224,224],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,3,224,224],f32> + // CHECK: %[[VAL_18:.*]] = torch.aten.sub.Scalar %[[F32_INPUT]], %[[INPUT_ZP]], %float1.000000e00 : !torch.vtensor<[?,3,224,224],f32>, !torch.int, !torch.float -> !torch.vtensor<[?,3,224,224],f32> // CHECK: %[[DEQUANT_INPUT:.*]] = torch.aten.mul.Scalar %[[VAL_18]], %[[INPUT_SCALE]] : !torch.vtensor<[?,3,224,224],f32>, !torch.float -> !torch.vtensor<[?,3,224,224],f32> - // CHECK: %[[VAL_20:.*]] = torch.constant.int 1 - // CHECK: %[[VAL_21:.*]] = torch.aten.unsqueeze %[[W_SCALE]], %[[VAL_20]] : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64,1],f32> - // CHECK: %[[VAL_22:.*]] = torch.aten.unsqueeze %[[W_ZP]], %[[VAL_20]] : !torch.vtensor<[64],si8>, !torch.int -> !torch.vtensor<[64,1],si8> - // CHECK: %[[VAL_23:.*]] = torch.constant.int 2 - // CHECK: %[[VAL_24:.*]] = torch.aten.unsqueeze %[[VAL_21]], %[[VAL_23]] : !torch.vtensor<[64,1],f32>, !torch.int -> !torch.vtensor<[64,1,1],f32> - // CHECK: %[[VAL_25:.*]] = torch.aten.unsqueeze %[[VAL_22]], %[[VAL_23]] : !torch.vtensor<[64,1],si8>, !torch.int -> !torch.vtensor<[64,1,1],si8> - // CHECK: %[[VAL_26:.*]] = torch.constant.int 3 - // CHECK: %[[WEIGHT_SCALE:.*]] = torch.aten.unsqueeze %[[VAL_24]], %[[VAL_26]] : !torch.vtensor<[64,1,1],f32>, !torch.int -> !torch.vtensor<[64,1,1,1],f32> - // CHECK: %[[WEIGHT_ZP:.*]] = torch.aten.unsqueeze %[[VAL_25]], %[[VAL_26]] : !torch.vtensor<[64,1,1],si8>, !torch.int -> !torch.vtensor<[64,1,1,1],si8> - // CHECK: %[[F32_WEIGHT:.*]] = 
torch.aten.to.dtype %[[W]], %[[F32DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[64,3,7,7],si8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[64,3,7,7],f32> - // CHECK: %[[VAL_30:.*]] = torch.aten.sub.Tensor %[[F32_WEIGHT]], %[[WEIGHT_ZP]], %[[ALPHA]] : !torch.vtensor<[64,3,7,7],f32>, !torch.vtensor<[64,1,1,1],si8>, !torch.float -> !torch.vtensor<[64,3,7,7],f32> + // CHECK: %[[VAL_21:.*]] = torch.aten.unsqueeze %arg4, %int1 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64,1],f32> + // CHECK: %[[VAL_22:.*]] = torch.aten.unsqueeze %arg5, %int1 : !torch.vtensor<[64],si8>, !torch.int -> !torch.vtensor<[64,1],si8> + // CHECK: %[[VAL_24:.*]] = torch.aten.unsqueeze %[[VAL_21]], %int2 : !torch.vtensor<[64,1],f32>, !torch.int -> !torch.vtensor<[64,1,1],f32> + // CHECK: %[[VAL_25:.*]] = torch.aten.unsqueeze %[[VAL_22]], %int2 : !torch.vtensor<[64,1],si8>, !torch.int -> !torch.vtensor<[64,1,1],si8> + // CHECK: %[[WEIGHT_SCALE:.*]] = torch.aten.unsqueeze %[[VAL_24]], %int3 : !torch.vtensor<[64,1,1],f32>, !torch.int -> !torch.vtensor<[64,1,1,1],f32> + // CHECK: %[[WEIGHT_ZP:.*]] = torch.aten.unsqueeze %[[VAL_25]], %int3 : !torch.vtensor<[64,1,1],si8>, !torch.int -> !torch.vtensor<[64,1,1,1],si8> + // CHECK: %[[F32_WEIGHT:.*]] = torch.aten.to.dtype %arg3, %[[F32DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[64,3,7,7],si8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[64,3,7,7],f32> + // CHECK: %[[VAL_30:.*]] = torch.aten.sub.Tensor %[[F32_WEIGHT]], %[[WEIGHT_ZP]], %float1.000000e00 : !torch.vtensor<[64,3,7,7],f32>, !torch.vtensor<[64,1,1,1],si8>, !torch.float -> !torch.vtensor<[64,3,7,7],f32> // CHECK: %[[DEQUANT_WEIGHT:.*]] = torch.aten.mul.Tensor %[[VAL_30]], %[[WEIGHT_SCALE]] : !torch.vtensor<[64,3,7,7],f32>, !torch.vtensor<[64,1,1,1],f32> -> !torch.vtensor<[64,3,7,7],f32> - // CHECK: %[[F32_BIAS:.*]] = torch.aten.to.dtype %[[BIAS]], %[[F32DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[64],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[64],f32> + // CHECK: %[[F32_BIAS:.*]] = torch.aten.to.dtype %arg8, %[[F32DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[64],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[64],f32> // CHECK: %[[VAL_33:.*]] = torch.constant.int 3 // CHECK: %[[VAL_34:.*]] = torch.constant.int 3 // CHECK: %[[PAD:.*]] = torch.prim.ListConstruct %[[VAL_33]], %[[VAL_34]] : (!torch.int, !torch.int) -> !torch.list<int>