From 5941c847d5633dc2ed3f432dabc14ea6645b42ec Mon Sep 17 00:00:00 2001 From: Paul Stark Date: Tue, 2 Dec 2025 14:44:44 -0800 Subject: [PATCH 1/2] Avoid assert with convolutions with no strides or dilates. Issue #2829 --- .../StablehloToLinalgConvolution.cpp | 42 +++++++++++++------ 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/stablehlo/conversions/linalg/transforms/StablehloToLinalgConvolution.cpp b/stablehlo/conversions/linalg/transforms/StablehloToLinalgConvolution.cpp index 4967d522fc..5e843d879f 100644 --- a/stablehlo/conversions/linalg/transforms/StablehloToLinalgConvolution.cpp +++ b/stablehlo/conversions/linalg/transforms/StablehloToLinalgConvolution.cpp @@ -268,24 +268,42 @@ struct NormalConvolutionOpConversion final break; } case 3: { - res = linalg::Conv1DNwcWcfOp::create( - rewriter, loc, resultType, ValueRange{input, filter}, - ValueRange{zeroTensor}, strides, dilations, - linalg::getPrunedAttributeList(op)); + if (strides && dilations) { + res = linalg::Conv1DNwcWcfOp::create( + rewriter, loc, resultType, ValueRange{input, filter}, + ValueRange{zeroTensor}, strides, dilations, + linalg::getPrunedAttributeList(op)); + } else { + res = linalg::Conv1DNwcWcfOp::create( + rewriter, loc, resultType, ValueRange{input, filter}, + ValueRange{zeroTensor}, linalg::getPrunedAttributeList(op)); + } break; } case 4: { - res = linalg::Conv2DNhwcHwcfOp::create( - rewriter, loc, resultType, ValueRange{input, filter}, - ValueRange{zeroTensor}, strides, dilations, - linalg::getPrunedAttributeList(op)); + if (strides && dilations) { + res = linalg::Conv2DNhwcHwcfOp::create( + rewriter, loc, resultType, ValueRange{input, filter}, + ValueRange{zeroTensor}, strides, dilations, + linalg::getPrunedAttributeList(op)); + } else { + res = linalg::Conv2DNhwcHwcfOp::create( + rewriter, loc, resultType, ValueRange{input, filter}, + ValueRange{zeroTensor}, linalg::getPrunedAttributeList(op)); + } break; } case 5: { - res = linalg::Conv3DNdhwcDhwcfOp::create( - rewriter, loc, resultType, ValueRange{input, filter}, - ValueRange{zeroTensor}, strides, dilations, - linalg::getPrunedAttributeList(op)); + if (strides && dilations) { + res = linalg::Conv3DNdhwcDhwcfOp::create( + rewriter, loc, resultType, ValueRange{input, filter}, + ValueRange{zeroTensor}, strides, dilations, + linalg::getPrunedAttributeList(op)); + } else { + res = linalg::Conv3DNdhwcDhwcfOp::create( + rewriter, loc, resultType, ValueRange{input, filter}, + ValueRange{zeroTensor}, linalg::getPrunedAttributeList(op)); + } break; } default: { From 8c6ca7363a418be8bd5e9cca7a376d686e81826d Mon Sep 17 00:00:00 2001 From: Paul Stark Date: Fri, 5 Dec 2025 13:16:48 -0800 Subject: [PATCH 2/2] Change the way null convolution dilates and strides are handled to deal with only one of them being null. Signed-off-by: Paul Stark --- .../StablehloToLinalgConvolution.cpp | 52 +++++++------------ 1 file changed, 18 insertions(+), 34 deletions(-) diff --git a/stablehlo/conversions/linalg/transforms/StablehloToLinalgConvolution.cpp b/stablehlo/conversions/linalg/transforms/StablehloToLinalgConvolution.cpp index 5e843d879f..ccbf43fb2d 100644 --- a/stablehlo/conversions/linalg/transforms/StablehloToLinalgConvolution.cpp +++ b/stablehlo/conversions/linalg/transforms/StablehloToLinalgConvolution.cpp @@ -248,10 +248,12 @@ struct NormalConvolutionOpConversion final resultType.getElementType(), dynSizes); Value zeroTensor = fillTensorWithZeros(rewriter, loc, emptyTensor); linalg::LinalgOp res; - Attribute strides; - if (auto s = op.getWindowStrides()) strides = rewriter.getI64TensorAttr(*s); - Attribute dilations; - if (auto d = op.getRhsDilation()) dilations = rewriter.getI64TensorAttr(*d); + auto s = op.getWindowStrides(); + Attribute strides = + s ? rewriter.getI64TensorAttr(*s) : rewriter.getI64TensorAttr({1}); + auto d = op.getRhsDilation(); + Attribute dilations = + d ? rewriter.getI64TensorAttr(*d) : rewriter.getI64TensorAttr({1}); // Apply padding and input dilation. llvm::SmallVector spatialDimMapping(rank - 2); @@ -268,42 +270,24 @@ struct NormalConvolutionOpConversion final break; } case 3: { - if (strides && dilations) { - res = linalg::Conv1DNwcWcfOp::create( - rewriter, loc, resultType, ValueRange{input, filter}, - ValueRange{zeroTensor}, strides, dilations, - linalg::getPrunedAttributeList(op)); - } else { - res = linalg::Conv1DNwcWcfOp::create( - rewriter, loc, resultType, ValueRange{input, filter}, - ValueRange{zeroTensor}, linalg::getPrunedAttributeList(op)); - } + res = linalg::Conv1DNwcWcfOp::create( + rewriter, loc, resultType, ValueRange{input, filter}, + ValueRange{zeroTensor}, strides, dilations, + linalg::getPrunedAttributeList(op)); break; } case 4: { - if (strides && dilations) { - res = linalg::Conv2DNhwcHwcfOp::create( - rewriter, loc, resultType, ValueRange{input, filter}, - ValueRange{zeroTensor}, strides, dilations, - linalg::getPrunedAttributeList(op)); - } else { - res = linalg::Conv2DNhwcHwcfOp::create( - rewriter, loc, resultType, ValueRange{input, filter}, - ValueRange{zeroTensor}, linalg::getPrunedAttributeList(op)); - } + res = linalg::Conv2DNhwcHwcfOp::create( + rewriter, loc, resultType, ValueRange{input, filter}, + ValueRange{zeroTensor}, strides, dilations, + linalg::getPrunedAttributeList(op)); break; } case 5: { - if (strides && dilations) { - res = linalg::Conv3DNdhwcDhwcfOp::create( - rewriter, loc, resultType, ValueRange{input, filter}, - ValueRange{zeroTensor}, strides, dilations, - linalg::getPrunedAttributeList(op)); - } else { - res = linalg::Conv3DNdhwcDhwcfOp::create( - rewriter, loc, resultType, ValueRange{input, filter}, - ValueRange{zeroTensor}, linalg::getPrunedAttributeList(op)); - } + res = linalg::Conv3DNdhwcDhwcfOp::create( + rewriter, loc, resultType, ValueRange{input, filter}, + ValueRange{zeroTensor}, strides, dilations, + linalg::getPrunedAttributeList(op)); break; } default: {