-
Notifications
You must be signed in to change notification settings - Fork 518
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ONNX] Add support for asymmetric padding for Onnx.AveragePool op #3923
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -456,107 +456,47 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( | |
patterns.onOp( | ||
"AveragePool", 11, | ||
[](OpBinder binder, ConversionPatternRewriter &rewriter) { | ||
std::string autoPad; | ||
SmallVector<int64_t> dilations; | ||
if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) | ||
return failure(); | ||
if (autoPad != "NOTSET") { | ||
// TODO: Add support for `auto_pad` != "NOTSET" | ||
return rewriter.notifyMatchFailure( | ||
binder.op, "unsupported conversion: auto_pad != NOTSET"); | ||
} | ||
|
||
Torch::ValueTensorType resultType; | ||
Value operand; | ||
bool ceilMode, countIncludePad; | ||
int64_t ceilMode, countIncludePad; | ||
std::string autoPad; | ||
if (binder.tensorOperand(operand) || | ||
binder.s64BoolAttr(ceilMode, "ceil_mode", false) || | ||
binder.s64BoolAttr(countIncludePad, "count_include_pad", false) || | ||
binder.s64IntegerAttr(ceilMode, "ceil_mode", 0) || | ||
binder.s64IntegerAttr(countIncludePad, "count_include_pad", 0) || | ||
binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET") || | ||
binder.tensorResultType(resultType)) | ||
return failure(); | ||
return rewriter.notifyMatchFailure( | ||
binder.op, "operand/ceil_mode/count_include_pad/auto_pad/" | ||
"resultType bind failure"); | ||
|
||
// Determine the rank of input tensor. | ||
std::optional<unsigned> maybeRank = Torch::getTensorRank(operand); | ||
if (!maybeRank) | ||
return rewriter.notifyMatchFailure(binder.op, | ||
"Unimplemented: unranked tensor"); | ||
unsigned rank = *maybeRank; | ||
|
||
SmallVector<int64_t> kernel, padding, strides; | ||
if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {})) { | ||
return failure(); | ||
} | ||
if (kernel.size() != rank - 2) { | ||
return rewriter.notifyMatchFailure( | ||
binder.op, "kernel list size does not match the number of axes"); | ||
} | ||
SmallVector<int64_t> defaultPadding(2 * (rank - 2), 0); | ||
if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) { | ||
return failure(); | ||
} | ||
if (padding.size() != 2 * (rank - 2)) { | ||
return rewriter.notifyMatchFailure( | ||
binder.op, | ||
"padding list size does not match twice the number of axes"); | ||
} | ||
if (binder.s64IntegerArrayAttr( | ||
strides, "strides", llvm::SmallVector<int64_t>(rank - 2, 1))) { | ||
return failure(); | ||
} | ||
if (strides.size() != 1 && strides.size() != rank - 2) { | ||
return rewriter.notifyMatchFailure( | ||
binder.op, "strides list size does not match the number of axes"); | ||
} | ||
|
||
SmallVector<Value> cstKernel, cstPadding, cstStridesDilations; | ||
for (int64_t i : kernel) { | ||
cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>( | ||
binder.getLoc(), rewriter.getI64IntegerAttr(i))); | ||
} | ||
// Onnx pads format: [x1_begin, x2_begin…x1_end, x2_end,…] | ||
// Pytorch pads format: [x1, x2,...] or [x], assume begin==end for all | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add e2e tests in shark-testsuite if the change work. I don't think torch to linalg support this pattern. |
||
// axes x. | ||
int64_t paddingSizeHalf = padding.size() / 2; | ||
for (int64_t i = 0; i < paddingSizeHalf; ++i) { | ||
// Check if onnx padding attribute is symmetric. | ||
if (padding[i] != padding[i + paddingSizeHalf]) | ||
return rewriter.notifyMatchFailure( | ||
binder.op, "onnx padding attribute is not symmetric"); | ||
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>( | ||
binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); | ||
} | ||
for (int64_t i : strides) { | ||
cstStridesDilations.push_back(rewriter.create<Torch::ConstantIntOp>( | ||
binder.getLoc(), rewriter.getI64IntegerAttr(i))); | ||
} | ||
SmallVector<int64_t> kernel, padding, strides, dilations, | ||
stridesDilations; | ||
if (failed(checkAndGetOnnxPoolingOpParameters( | ||
binder, rewriter, resultType.getDtype(), autoPad, | ||
/*spatialRank=*/rank - 2, | ||
/*input=*/operand, kernel, strides, padding, dilations))) | ||
return rewriter.notifyMatchFailure(binder.op, | ||
"invalid pooling parameters"); | ||
|
||
// No dilations attribute in pytorch avgpool op, so use this trick to | ||
// encode dilation into strides. Then in the following torchtolinalg | ||
// lowering, decode strides into strides + dilation. | ||
// Since the PyTorch AvgPool op does not contain the `dilation` arg, | ||
// hence we use the trick of encoding dilation into strides. Then, | ||
// during the torch->linalg lowering of the `AvgPool` op we decode the | ||
// `strides` arg into strides values followed by dilation like: | ||
// [strideDim1,strideDim2,...,dilationDim1,dilationDim2,...] | ||
if (binder.s64IntegerArrayAttr( | ||
dilations, "dilations", | ||
llvm::SmallVector<int64_t>(rank - 2, 1))) { | ||
return failure(); | ||
} | ||
for (auto dilation : dilations) { | ||
cstStridesDilations.push_back(rewriter.create<Torch::ConstantIntOp>( | ||
binder.getLoc(), rewriter.getI64IntegerAttr(dilation))); | ||
} | ||
stridesDilations = strides; | ||
stridesDilations.append(dilations); | ||
|
||
Value kernelSizeList = rewriter.create<Torch::PrimListConstructOp>( | ||
binder.getLoc(), | ||
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), | ||
cstKernel); | ||
Value paddingList = rewriter.create<Torch::PrimListConstructOp>( | ||
binder.getLoc(), | ||
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), | ||
cstPadding); | ||
Value kernelSizeList = createConstantIntList(binder, rewriter, kernel); | ||
Value paddingList = createConstantIntList(binder, rewriter, padding); | ||
Value stridesDilationsList = | ||
rewriter.create<Torch::PrimListConstructOp>( | ||
binder.getLoc(), | ||
Torch::ListType::get( | ||
Torch::IntType::get(binder.op->getContext())), | ||
cstStridesDilations); | ||
createConstantIntList(binder, rewriter, stridesDilations); | ||
Value cstCeilMode = | ||
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), ceilMode); | ||
Value cstCountIncludePad = rewriter.create<Torch::ConstantBoolOp>( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1124,138 +1124,38 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( | |
}); | ||
patterns.onOp( | ||
"MaxPool", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) { | ||
std::string autoPad; | ||
if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) | ||
return rewriter.notifyMatchFailure(binder.op, | ||
"auto_pad bind failure"); | ||
|
||
Torch::ValueTensorType resultTypeOut; | ||
Value operand; | ||
int64_t ceilMode, storageOrder; | ||
// TODO: Add support for indices output and storage_order | ||
std::string autoPad; | ||
if (binder.tensorOperand(operand) || | ||
binder.s64IntegerAttr(ceilMode, "ceil_mode", 0) || | ||
binder.s64IntegerAttr(storageOrder, "storage_order", 0) || | ||
binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET") || | ||
binder.tensorResultTypeAtIndex(resultTypeOut, 0)) | ||
return rewriter.notifyMatchFailure( | ||
binder.op, | ||
"operand/ceil_mode/storage_order/resultType bind failure"); | ||
binder.op, "operand/ceil_mode/storage_order/auto_pad/resultType " | ||
"bind failure"); | ||
// TODO: Add support for storage_order | ||
if (storageOrder != 0) | ||
return rewriter.notifyMatchFailure( | ||
binder.op, "storage_order setting is not supported."); | ||
|
||
// Determine the rank of input tensor. | ||
std::optional<unsigned> maybeRank = Torch::getTensorRank(operand); | ||
if (!maybeRank) | ||
return rewriter.notifyMatchFailure(binder.op, | ||
"Unimplemented: unranked tensor"); | ||
int64_t rank = *maybeRank; | ||
int64_t spatial = rank - 2; | ||
unsigned rank = *maybeRank; | ||
|
||
SmallVector<int64_t> kernel, padding, strides, dilations; | ||
if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {})) | ||
SmallVector<int64_t> kernel, padding, strides, dilations, | ||
stridesDilations; | ||
if (failed(checkAndGetOnnxPoolingOpParameters( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about |
||
binder, rewriter, resultTypeOut.getDtype(), autoPad, | ||
/*spatialRank=*/rank - 2, | ||
/*input=*/operand, kernel, strides, padding, dilations))) | ||
return rewriter.notifyMatchFailure(binder.op, | ||
"kernel_shape bind failure"); | ||
if (kernel.size() != static_cast<size_t>(spatial)) | ||
return rewriter.notifyMatchFailure( | ||
binder.op, "kernel list size does not match the number of axes"); | ||
if (binder.s64IntegerArrayAttr(padding, "pads", {})) | ||
return rewriter.notifyMatchFailure(binder.op, "pads bind failure"); | ||
if (!padding.empty() && | ||
padding.size() != static_cast<size_t>(2 * spatial)) | ||
return rewriter.notifyMatchFailure( | ||
binder.op, "padding list must contain (begin,end) pair for each " | ||
"spatial axis"); | ||
if (binder.s64IntegerArrayAttr(strides, "strides", {})) | ||
return rewriter.notifyMatchFailure(binder.op, "strides bind failure"); | ||
if (!strides.empty() && strides.size() != static_cast<size_t>(spatial)) | ||
return rewriter.notifyMatchFailure( | ||
binder.op, "strides list size does not match the number of axes"); | ||
if (binder.s64IntegerArrayAttr(dilations, "dilations", {})) | ||
return rewriter.notifyMatchFailure(binder.op, | ||
"dilations bind failure"); | ||
|
||
// set default padding | ||
if (padding.empty()) | ||
padding.resize(spatial, 0); | ||
if (strides.empty()) | ||
strides.resize(spatial, 1); | ||
if (dilations.empty()) | ||
dilations.resize(spatial, 1); | ||
|
||
auto inputTensorType = cast<Torch::ValueTensorType>(operand.getType()); | ||
|
||
// Padding for the beginning and ending along each spatial axis, it can | ||
// take any value greater than or equal to 0. The value represent the | ||
// number of pixels added to the beginning and end part of the | ||
// corresponding axis. pads format should be as follow [x1_begin, | ||
// x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added | ||
// at the beginning of axis i and xi_end, the number of pixels added at | ||
// the end of axis i. | ||
if (autoPad != "NOTSET" && autoPad != "VALID") { | ||
const bool isSameLower = autoPad == "SAME_LOWER"; | ||
ArrayRef<int64_t> inputShape = inputTensorType.getSizes(); | ||
padding.resize_for_overwrite(2 * spatial); | ||
for (unsigned dimIdx = 0; dimIdx < spatial; dimIdx++) { | ||
const int64_t dilatedKernelSize = | ||
dilations[dimIdx] * (kernel[dimIdx] - 1) + 1; | ||
int64_t totalPad = ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) / | ||
strides[dimIdx] - | ||
1) * | ||
strides[dimIdx] + | ||
dilatedKernelSize - inputShape[dimIdx + 2]; | ||
totalPad = totalPad >= 0 ? totalPad : 0; | ||
padding[dimIdx] = | ||
isSameLower ? ((totalPad + 1) / 2) : (totalPad / 2); | ||
padding[spatial + dimIdx] = totalPad - padding[dimIdx]; | ||
} | ||
} | ||
|
||
// If the padding is symmetric we can push the padding operation to the | ||
// torch operator. | ||
if (padding.size() == static_cast<size_t>(2 * spatial)) { | ||
bool equal = true; | ||
for (int i = 0; i < spatial; ++i) { | ||
equal = equal && (padding[i] == padding[i + spatial]); | ||
} | ||
if (equal) | ||
padding.resize(spatial); | ||
} | ||
|
||
// Torch pool operators require equal padding on each size of each | ||
// dimension so we materialize the padding behavior explicitly and set | ||
// the padding to 0. | ||
if (padding.size() == static_cast<size_t>(2 * spatial)) { | ||
auto operandTy = cast<Torch::ValueTensorType>(operand.getType()); | ||
llvm::SmallVector<int64_t> shuffledPadding(spatial * 2); | ||
llvm::SmallVector<int64_t> paddedShape(operandTy.getSizes()); | ||
for (int i = 0; i < spatial; ++i) { | ||
paddedShape[i + 2] += padding[i] + padding[i + spatial]; | ||
shuffledPadding[2 * i] = padding[spatial - i - 1]; | ||
shuffledPadding[2 * i + 1] = padding[2 * spatial - i - 1]; | ||
} | ||
|
||
Value shuffledPaddingList = | ||
createConstantIntList(binder, rewriter, shuffledPadding); | ||
Value zero; | ||
if (isa<FloatType>(resultTypeOut.getDtype())) { | ||
zero = rewriter.create<Torch::ConstantFloatOp>( | ||
binder.getLoc(), rewriter.getType<Torch::FloatType>(), | ||
rewriter.getF64FloatAttr( | ||
std::numeric_limits<double>::lowest())); | ||
} else if (isa<IntegerType>(resultTypeOut.getDtype())) { | ||
zero = rewriter.create<Torch::ConstantIntOp>( | ||
binder.getLoc(), rewriter.getI64IntegerAttr( | ||
std::numeric_limits<int64_t>::lowest())); | ||
} | ||
|
||
auto paddedInputTy = rewriter.getType<Torch::ValueTensorType>( | ||
paddedShape, operandTy.getDtype()); | ||
operand = rewriter.create<Torch::AtenConstantPadNdOp>( | ||
binder.getLoc(), paddedInputTy, operand, shuffledPaddingList, | ||
zero); | ||
padding.clear(); | ||
padding.resize(spatial, 0); | ||
} | ||
"invalid pooling parameters"); | ||
|
||
Value kernelSizeList = createConstantIntList(binder, rewriter, kernel); | ||
Value paddingList = createConstantIntList(binder, rewriter, padding); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why change
ceilMode
andcountIncludePad
from bool to int64_t?