diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
index 77fbb8ed9..0d5acb644 100644
--- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
+++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
@@ -3841,9 +3841,6 @@ struct ReduceConcat final
     if (!llvm::is_contained(op.getDimensions(), dim))
       return failure();
 
-    if (!isEligibleForCompactPrint(op))
-      return failure();
-
     Value prev = op.getInitValues()[0];
 
     auto checkCommonReduce = mlir::stablehlo::CheckCommonReduceOp(op);
@@ -4094,84 +4091,93 @@ bool canMergeSlicesAlongAxis(int dimension, stablehlo::SliceOp slice,
                              otherSliceStrides);
 }
 
-struct ConcatSlice final
-    : CheckedOpRewritePattern<stablehlo::ConcatenateOp, ConcatSlice> {
-  using CheckedOpRewritePattern::CheckedOpRewritePattern;
-
-  LogicalResult matchAndRewriteImpl(stablehlo::ConcatenateOp op,
-                                    PatternRewriter &rewriter) const {
-    auto dim = op.getDimension();
-
-    SmallVector<Value> newOperands;
-
-    bool changed = false;
-
-    for (int i = 0, e = op->getNumOperands(); i < e; ++i) {
-      auto operand = op->getOperand(i);
-      auto slice = operand.getDefiningOp<stablehlo::SliceOp>();
-
-      if (!slice) {
-        newOperands.push_back(operand);
-        continue;
-      }
-
-      while (i + 1 < e) {
-        if (auto otherSlice =
-                op->getOperand(i + 1).getDefiningOp<stablehlo::SliceOp>()) {
-          if (canMergeSlicesAlongAxis(op.getDimension(), slice, otherSlice)) {
-            slice = stablehlo::SliceOp::create(
-                rewriter, slice->getLoc(), slice.getOperand(),
-                slice.getStartIndices(), otherSlice.getLimitIndices(),
-                slice.getStrides());
-            i++;
-            continue;
-          } else
-            break;
-        }
-        if (auto otherWrap =
-                op->getOperand(i + 1).getDefiningOp<enzymexla::WrapOp>()) {
-          auto wrapSlice =
-              otherWrap.getOperand().getDefiningOp<stablehlo::SliceOp>();
-          if (wrapSlice && wrapSlice.getOperand() == slice.getOperand() &&
-              otherWrap.getLhs() != 0) {
-            SmallVector<int64_t> wrapStarts =
-                llvm::to_vector(wrapSlice.getStartIndices());
-            SmallVector<int64_t> wrapLimits =
-                llvm::to_vector(wrapSlice.getLimitIndices());
-            if (wrapSlice.getStrides()[dim] == 1) {
-              wrapStarts[dim] = wrapLimits[dim] - otherWrap.getLhs();
-            }
-            if (canMergeSlicesAlongAxis(
-                    op.getDimension(), slice.getStartIndices(), wrapStarts,
-                    slice.getLimitIndices(), wrapLimits, slice.getStrides(),
-                    wrapSlice.getStrides())) {
-
-              changed = true;
-              auto c2 = lowerWrap(otherWrap, rewriter, /*replace*/ false);
-              auto newSlice = stablehlo::SliceOp::create(
-                  rewriter, slice->getLoc(), slice.getOperand(),
-                  slice.getStartIndices(), wrapLimits, slice.getStrides());
-              newOperands.push_back(newSlice);
-              for (int i = 1; i < c2.getOperands().size(); i++) {
-                newOperands.push_back(c2.getOperands()[i]);
-              }
-              i++;
-              slice = nullptr;
-              break;
-            } else
-              break;
-          }
-        }
-        break;
-      }
-
-      if (slice)
-        newOperands.push_back(slice.getResult());
-    }
-
-    if (!changed)
-      return failure();
+LogicalResult concatSliceSimplify(PatternRewriter &rewriter,
+                                  SmallVectorImpl<Value> &operands, int64_t dim,
+                                  SmallVectorImpl<Value> &newOperands) {
+  bool changed = false;
+  for (size_t i = 0, e = operands.size(); i < e; ++i) {
+    auto operand = operands[i];
+    auto slice = operand.getDefiningOp<stablehlo::SliceOp>();
+
+    if (!slice) {
+      newOperands.push_back(operand);
+      continue;
+    }
+
+    while (i + 1 < e) {
+      if (auto otherSlice =
+              operands[i + 1].getDefiningOp<stablehlo::SliceOp>()) {
+        if (canMergeSlicesAlongAxis(dim, slice, otherSlice)) {
+          slice = stablehlo::SliceOp::create(
+              rewriter, slice->getLoc(), slice.getOperand(),
+              slice.getStartIndices(), otherSlice.getLimitIndices(),
+              slice.getStrides());
+          changed = true;
+          i++;
+          continue;
+        } else {
+          break;
+        }
+      }
+      if (auto otherWrap = operands[i + 1].getDefiningOp<enzymexla::WrapOp>()) {
+        auto wrapSlice =
+            otherWrap.getOperand().getDefiningOp<stablehlo::SliceOp>();
+        if (wrapSlice && wrapSlice.getOperand() == slice.getOperand() &&
+            otherWrap.getLhs() != 0) {
+          SmallVector<int64_t> wrapStarts =
+              llvm::to_vector(wrapSlice.getStartIndices());
+          SmallVector<int64_t> wrapLimits =
+              llvm::to_vector(wrapSlice.getLimitIndices());
+          if (wrapSlice.getStrides()[dim] == 1) {
+            wrapStarts[dim] = wrapLimits[dim] - otherWrap.getLhs();
+          }
+          if (canMergeSlicesAlongAxis(dim, slice.getStartIndices(), wrapStarts,
+                                      slice.getLimitIndices(), wrapLimits,
+                                      slice.getStrides(),
+                                      wrapSlice.getStrides())) {
+            changed = true;
+            auto c2 = lowerWrap(otherWrap, rewriter, /*replace*/ false);
+            auto newSlice = stablehlo::SliceOp::create(
+                rewriter, slice->getLoc(), slice.getOperand(),
+                slice.getStartIndices(), wrapLimits, slice.getStrides());
+            newOperands.push_back(newSlice);
+            for (int i = 1; i < c2.getOperands().size(); i++) {
+              newOperands.push_back(c2.getOperands()[i]);
+            }
+            i++;
+            slice = nullptr;
+            break;
+          } else {
+            break;
+          }
+        }
+      }
+      break;
+    }
+
+    if (slice) {
+      newOperands.push_back(slice.getResult());
+    }
+  }
+
+  return changed ? success() : failure();
+}
+
+struct ConcatSlice final
+    : CheckedOpRewritePattern<stablehlo::ConcatenateOp, ConcatSlice> {
+  using CheckedOpRewritePattern::CheckedOpRewritePattern;
+
+  LogicalResult matchAndRewriteImpl(stablehlo::ConcatenateOp op,
+                                    PatternRewriter &rewriter) const {
+    auto dim = op.getDimension();
+    SmallVector<Value> newOperands;
+
+    auto oldOperands = llvm::to_vector(op.getOperands());
+    auto res = concatSliceSimplify(rewriter, oldOperands, dim, newOperands);
+    if (!res.succeeded())
+      return res;
 
     rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(op, newOperands, dim);
     return success();
@@ -11768,7 +11774,7 @@ struct DUSSliceSimplify final
           dusOp, "DUS indices must be constant scalars");
       dusStartIndices.push_back((*idxAttr.begin()).getSExtValue());
     }
-    SmallVector<int64_t> dusEndIndices = llvm::map_to_vector(
+    SmallVector<int64_t> duslimitIndices = llvm::map_to_vector(
         llvm::zip(dusStartIndices, updateShape),
         [](auto p) { return std::get<0>(p) + std::get<1>(p); });
 
@@ -11786,7 +11792,7 @@ struct DUSSliceSimplify final
           return 0;
         });
     auto ignoredUpdateEnd =
-        llvm::map_to_vector(llvm::zip(ignoredEnd, dusEndIndices, updateShape),
+        llvm::map_to_vector(llvm::zip(ignoredEnd, duslimitIndices, updateShape),
                             [](auto p) -> int64_t {
                               auto &[ignoredEnd, dusEnd, updateShape] = p;
                               if (ignoredEnd < dusEnd)
@@ -24047,49 +24053,78 @@ struct ReduceSliceFusionBase
 
   LogicalResult matchAndRewriteImpl(BinaryOpType binaryOp,
                                     PatternRewriter &rewriter) const {
-    SmallVector slices;
+    SmallVector slices;
     SmallVector<Value> extraValues;
 
     if (!collectSlicesInChain(binaryOp, slices, extraValues))
      return failure();
 
-    if (!areSlicesContiguous(slices))
+    // ensure all slices are along the same dimension
+    auto sliceDims = findIntersection(slices);
+    if (sliceDims.size() != 1)
       return failure();
+    int64_t sliceDim = sliceDims[0];
 
-    return createFusedReduce(rewriter, binaryOp, slices, extraValues);
+    // Sort slices by start index
+    llvm::sort(slices, [sliceDim](const SliceInfo &a, const SliceInfo &b) {
+      return a.startIndices[sliceDim] < b.startIndices[sliceDim];
+    });
+
+    // overlapping slices are optimized with a different fusion
+    for (int i = 0; i < slices.size() - 1; i++) {
+      auto info1 = slices[i];
+      auto info2 = slices[i + 1];
+      if (info2.startIndices[sliceDim] < info1.limitIndices[sliceDim]) {
+        return failure();
+      }
+    }
+
+    return createFusedReduce(rewriter, binaryOp, slices, extraValues, sliceDim);
   };
 
 private:
   struct SliceInfo {
     stablehlo::SliceOp sliceOp;
     SmallVector<int64_t> startIndices;
-    SmallVector<int64_t> endIndices;
+    SmallVector<int64_t> limitIndices;
     SmallVector<int64_t> strides;
-    int64_t sliceDim;
-    int64_t sliceStart;
-    int64_t sliceEnd;
+    llvm::SmallSetVector sliceDims;
     Value initValue;
+
+    SliceInfo() = default;
+    SliceInfo(stablehlo::SliceOp sliceOp) : sliceOp(sliceOp) {
+      startIndices = llvm::to_vector(sliceOp.getStartIndices());
+      limitIndices = llvm::to_vector(sliceOp.getLimitIndices());
+      strides = llvm::to_vector(sliceOp.getStrides());
+    }
   };
 
+  llvm::SmallSetVector
+  findIntersection(SmallVector<SliceInfo> &slices) const {
+    auto intersection = slices.front().sliceDims;
+    for (size_t i = 1; i < slices.size(); i++) {
+      llvm::set_intersect(intersection, slices[i].sliceDims);
+    }
+    return intersection;
+  }
+
   std::tuple<bool, SliceInfo>
   matchReduceSlice(stablehlo::ReduceOp reduceOp) const {
     if (reduceOp.getInputs().size() != 1 ||
-        reduceOp.getDimensions().size() != 1) {
-      return std::make_tuple(false, SliceInfo());
-    }
-
-    if (!((Child *)this)->isCompatibleReduction(reduceOp)) {
-      return std::make_tuple(false, SliceInfo());
+        reduceOp.getDimensions().size() != 1 ||
+        !((Child *)this)->isCompatibleReduction(reduceOp)) {
+      return {false, SliceInfo()};
     }
 
     auto slice =
         reduceOp.getInputs()[0].template getDefiningOp<stablehlo::SliceOp>();
     if (!slice) {
-      return std::make_tuple(false, SliceInfo());
+      return {false, SliceInfo()};
     }
 
-    SliceInfo info = extractSliceInfo(slice);
-    info.initValue = reduceOp.getInitValues()[0];
-    return std::make_tuple(info.sliceDim == reduceOp.getDimensions()[0], info);
+    SmallVector<int64_t> reductionDims = {reduceOp.getDimensions()[0]};
+    SliceInfo info = extractSliceInfo(slice, reductionDims);
+    info.initValue = reduceOp.getInitValues()[0];
+    return {true, info};
   }
 
   std::tuple<bool, SliceInfo>
@@ -24097,14 +24132,14 @@ struct ReduceSliceFusionBase
     auto reduce =
         reshapeOp.getOperand().template getDefiningOp<stablehlo::ReduceOp>();
     if (!reduce) {
-      return std::make_tuple(false, SliceInfo());
+      return {false, SliceInfo()};
     }
 
     auto inputType = cast<RankedTensorType>(reshapeOp.getOperand().getType());
     auto outputType = cast<RankedTensorType>(reshapeOp.getType());
     auto reduceDim = reduce.getDimensions()[0];
     if (!areValidInsertionDims(inputType, outputType, {reduceDim})) {
-      return std::make_tuple(false, SliceInfo());
+      return {false, SliceInfo()};
     }
 
     return matchReduceSlice(reduce);
@@ -24118,15 +24153,15 @@ struct ReduceSliceFusionBase
     auto slice =
         reshapeOp.getOperand().template getDefiningOp<stablehlo::SliceOp>();
     if (!slice) {
-      return std::make_tuple(false, SliceInfo());
+      return {false, SliceInfo()};
     }
 
-    SliceInfo info = extractSliceInfo(slice);
-    return std::make_tuple(
-        areValidInsertionDims(outputType, inputType, {info.sliceDim}), info);
+    auto insertionDims = findReshapeInsertionDims(outputType, inputType);
+    SliceInfo info = extractSliceInfo(slice, insertionDims);
+    return {true, info};
   }
 
-  bool matchingSourceOperand(SmallVector<SliceInfo> &slices,
+  bool matchingSourceOperand(SmallVectorImpl<SliceInfo> &slices,
                              stablehlo::SliceOp slice) const {
     if (!slice)
       return false;
@@ -24137,8 +24172,8 @@ struct ReduceSliceFusionBase
 
   // Collect all slices in the binary operation chain
   bool collectSlicesInChain(BinaryOpType startOp,
-                            SmallVector<SliceInfo> &slices,
-                            SmallVector<Value> &extraValues) const {
+                            SmallVectorImpl<SliceInfo> &slices,
+                            SmallVectorImpl<Value> &extraValues) const {
    // Use a worklist to traverse the binary operation chain
    SmallVector worklist;
    DenseSet visited;
@@ -24203,98 +24238,30 @@ struct ReduceSliceFusionBase
 
   // Extract slice information
   SliceInfo extractSliceInfo(stablehlo::SliceOp slice) const {
-    SliceInfo info;
-    info.sliceOp = slice;
-
-    auto startIndices = llvm::to_vector(slice.getStartIndices());
-    auto limitIndices = llvm::to_vector(slice.getLimitIndices());
-    auto strides = llvm::to_vector(slice.getStrides());
-
-    for (auto [start, end, stride] :
-         llvm::zip(startIndices, limitIndices, strides)) {
-      info.startIndices.push_back(start);
-      info.endIndices.push_back(end);
-      info.strides.push_back(stride);
-    }
-
-    // Find the dimension being sliced (where start != 0 or end != full size)
-    auto inputType = cast<RankedTensorType>(slice.getOperand().getType());
-    ArrayRef<int64_t> inputShape = inputType.getShape();
-
-    for (size_t i = 0; i < info.startIndices.size(); ++i) {
-      if (info.startIndices[i] != 0 || info.endIndices[i] != inputShape[i]) {
-        info.sliceDim = i;
-        info.sliceStart = info.startIndices[i];
-        info.sliceEnd = info.endIndices[i];
-        break;
+    SliceInfo info(slice);
+    auto outputType = cast<RankedTensorType>(slice.getType());
+    for (size_t i = 0; i < outputType.getRank(); ++i) {
+      if (outputType.getDimSize(i) == 1) {
+        info.sliceDims.insert(i);
       }
     }
-
     return info;
   }
 
-  bool areSlicesContiguous(SmallVector<SliceInfo> &slices) const {
-    if (slices.empty())
-      return false;
-
-    // All slices should be on the same dimension
-    int64_t sliceDim = slices[0].sliceDim;
-    for (const auto &slice : slices) {
-      if (slice.sliceDim != sliceDim) {
-        return false;
-      }
-    }
-
-    // Sort slices by start index
-    llvm::sort(slices, [](const SliceInfo &a, const SliceInfo &b) {
-      return a.sliceStart < b.sliceStart;
-    });
-
-    // Check contiguity
-    int64_t expectedStart = slices[0].sliceStart;
-    for (const auto &slice : slices) {
-      if (slice.sliceStart != expectedStart) {
-        return false;
-      }
-      expectedStart = slice.sliceEnd;
-    }
-
-    return true;
+  SliceInfo extractSliceInfo(stablehlo::SliceOp slice,
+                             SmallVectorImpl<int64_t> &potentialDims) const {
+    SliceInfo info(slice);
+    for (auto dim : potentialDims)
+      info.sliceDims.insert(dim);
+    return info;
   }
 
   // Create the fused reduce operation
   LogicalResult createFusedReduce(PatternRewriter &rewriter,
                                   BinaryOpType binaryOp,
-                                  SmallVector<SliceInfo> &slices,
-                                  SmallVector<Value> &extraValues) const {
-    Value sourceOperand = slices[0].sliceOp.getOperand();
-
-    auto sliceInfo = slices[0];
-    auto commonStartIndices = sliceInfo.sliceOp.getStartIndices();
-    auto commonLimitIndices = sliceInfo.sliceOp.getLimitIndices();
-    int64_t sliceDim = sliceInfo.sliceDim;
-    int64_t minStart = slices[0].sliceStart;
-    int64_t maxEnd = slices.back().sliceEnd;
-
-    auto sourceType = cast<RankedTensorType>(sourceOperand.getType());
-
-    // insert the slice always. if not needed we will remove it later
-    {
-      SmallVector<int64_t> newStartIndices =
-          llvm::to_vector(commonStartIndices);
-      newStartIndices[sliceDim] = minStart;
-      SmallVector<int64_t> newLimitIndices =
-          llvm::to_vector(commonLimitIndices);
-      newLimitIndices[sliceDim] = maxEnd;
-      SmallVector<int64_t> newStrides(sourceType.getRank(), 1);
-      sourceOperand = stablehlo::SliceOp::create(rewriter, binaryOp.getLoc(),
-                                                 sourceOperand, newStartIndices,
-                                                 newLimitIndices, newStrides);
-    }
-
-    ArrayRef<int64_t> sourceShape =
-        cast<RankedTensorType>(sourceOperand.getType()).getShape();
-
+                                  SmallVectorImpl<SliceInfo> &slices,
+                                  SmallVectorImpl<Value> &extraValues,
+                                  int64_t sliceDim) const {
     Value initValue;
     for (auto sliceInfo : slices) {
       if (sliceInfo.initValue) {
@@ -24319,13 +24286,30 @@ struct ReduceSliceFusionBase
       }
     }
 
-    SmallVector<int64_t> newShape;
-    for (int64_t i = 0; i < sourceShape.size(); i++) {
-      if (i != sliceDim) {
-        newShape.push_back(sourceShape[i]);
+    Value sourceOperand = slices[0].sliceOp.getOperand();
+
+    {
+      SmallVector<Value> newConcatInputs;
+      for (auto sliceInfo : slices) {
+        newConcatInputs.push_back(sliceInfo.sliceOp);
+      }
+      SmallVector<Value> finalConcatInputs;
+      auto result = concatSliceSimplify(rewriter, newConcatInputs, sliceDim,
+                                        finalConcatInputs);
+      if (!result.succeeded())
+        return result;
+      if (finalConcatInputs.size() == 1) {
+        sourceOperand = finalConcatInputs[0];
+      } else {
+        sourceOperand = stablehlo::ConcatenateOp::create(
+            rewriter, binaryOp.getLoc(), finalConcatInputs, sliceDim);
       }
     }
 
+    SmallVector<int64_t> newShape = llvm::to_vector(
+        cast<RankedTensorType>(sourceOperand.getType()).getShape());
+    newShape.erase(newShape.begin() + sliceDim);
+
     auto newReduce = stablehlo::ReduceOp::create(
         rewriter, binaryOp.getLoc(),
         TypeRange{RankedTensorType::get(newShape, elemType)},
@@ -24349,16 +24333,14 @@ struct ReduceSliceFusionBase
 
     rewriter.setInsertionPointAfter(newReduce);
     auto binaryResultType = cast<RankedTensorType>(binaryOp.getResult().getType());
-    Value result =
-        stablehlo::ReshapeOp::create(rewriter, binaryOp.getLoc(),
-                                     binaryResultType, newReduce.getResult(0))
-            .getResult();
+    Value result = stablehlo::ReshapeOp::create(
+        rewriter, binaryOp.getLoc(), binaryResultType, newReduce.getResult(0));
 
     for (auto &value : extraValues) {
      result = rewriter.template create<BinaryOpType>(binaryOp.getLoc(),
                                                      result, value);
     }
-    rewriter.replaceOp(binaryOp, result);
+    rewriter.replaceAllUsesWith(binaryOp.getResult(), result);
     return success();
   }
 };
diff --git a/test/lit_tests/elementwise_reduce_slice_fuse.mlir b/test/lit_tests/elementwise_reduce_slice_fuse.mlir
index 87afc628d..bd9bcf64f 100644
--- a/test/lit_tests/elementwise_reduce_slice_fuse.mlir
+++ b/test/lit_tests/elementwise_reduce_slice_fuse.mlir
@@ -57,6 +57,35 @@ func.func @main_add2(%arg0: tensor<6x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<4xf32> {
 // CHECK-NEXT: return %1 : tensor<4xf32>
 // CHECK-NEXT: }
 
+func.func @main_add3(%arg0: tensor<16x2xf64>) -> tensor<2xf64> {
+  %cst = stablehlo.constant dense<0.000000e+00> : tensor<f64>
+  %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<16x2xf64>) -> tensor<2x16xf64>
+  %1 = stablehlo.slice %0 [0:2, 0:4:2] : (tensor<2x16xf64>) -> tensor<2x2xf64>
+  %4 = stablehlo.slice %0 [0:2, 4:6:2] : (tensor<2x16xf64>) -> tensor<2x1xf64>
+  %6 = stablehlo.slice %0 [0:2, 6:8:2] : (tensor<2x16xf64>) -> tensor<2x1xf64>
+  %8 = stablehlo.slice %0 [0:2, 8:10:2] : (tensor<2x16xf64>) -> tensor<2x1xf64>
+  %10 = stablehlo.slice %0 [0:2, 10:12:2] : (tensor<2x16xf64>) -> tensor<2x1xf64>
+  %12 = stablehlo.slice %0 [0:2, 12:14:2] : (tensor<2x16xf64>) -> tensor<2x1xf64>
+  %14 = stablehlo.slice %0 [0:2, 14:16:2] : (tensor<2x16xf64>) -> tensor<2x1xf64>
+  %2 = stablehlo.reduce(%1 init: %cst) applies stablehlo.add across dimensions = [1] : (tensor<2x2xf64>, tensor<f64>) -> tensor<2xf64>
+  %3 = stablehlo.reshape %2 : (tensor<2xf64>) -> tensor<2x1xf64>
+  %5 = stablehlo.add %3, %4 : tensor<2x1xf64>
+  %7 = stablehlo.add %5, %6 : tensor<2x1xf64>
+  %9 = stablehlo.add %8, %12 : tensor<2x1xf64>
+  %11 = stablehlo.add %10, %14 : tensor<2x1xf64>
+  %13 = stablehlo.add %11, %9 : tensor<2x1xf64>
+  %15 = stablehlo.add %13, %7 : tensor<2x1xf64>
+  %16 = stablehlo.reshape %15 : (tensor<2x1xf64>) -> tensor<2xf64>
+  return %16 : tensor<2xf64>
+}
+
+// CHECK: func.func @main_add3(%arg0: tensor<16x2xf64>) -> tensor<2xf64> {
+// CHECK-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<f64>
+// CHECK-NEXT: %0 = stablehlo.slice %arg0 [0:16:2, 0:2] : (tensor<16x2xf64>) -> tensor<8x2xf64>
+// CHECK-NEXT: %1 = stablehlo.reduce(%0 init: %cst) applies stablehlo.add across dimensions = [0] : (tensor<8x2xf64>, tensor<f64>) -> tensor<2xf64>
+// CHECK-NEXT: return %1 : tensor<2xf64>
+// CHECK-NEXT: }
+
 func.func @main_mul1(%arg0: tensor<8x2xf64>) -> tensor<2xf64> {
   %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<8x2xf64>) -> tensor<2x8xf64>
   %1 = stablehlo.slice %0 [0:2, 0:1] : (tensor<2x8xf64>) -> tensor<2x1xf64>
diff --git a/test/lit_tests/elementwise_reduce_slice_fuse2.mlir b/test/lit_tests/elementwise_reduce_slice_fuse2.mlir
new file mode 100644
index 000000000..29393b6cd
--- /dev/null
+++ b/test/lit_tests/elementwise_reduce_slice_fuse2.mlir
@@ -0,0 +1,18 @@
+// RUN: enzymexlamlir-opt %s --enzyme-hlo-generate-td="patterns=add_reduce_slice_fusion" --transform-interpreter --enzyme-hlo-remove-transform | FileCheck %s
+
+func.func @main_partial_add(%arg0: tensor<16x4xf64>) -> tensor<1x3xf64> {
+  %0 = stablehlo.slice %arg0 [0:1, 1:4] : (tensor<16x4xf64>) -> tensor<1x3xf64>
+  %1 = stablehlo.slice %arg0 [1:2, 1:4] : (tensor<16x4xf64>) -> tensor<1x3xf64>
+  %2 = stablehlo.slice %arg0 [2:3, 1:4] : (tensor<16x4xf64>) -> tensor<1x3xf64>
+  %3 = stablehlo.add %0, %1 : tensor<1x3xf64>
+  %4 = stablehlo.add %3, %2 : tensor<1x3xf64>
+  return %4 : tensor<1x3xf64>
+}
+
+// CHECK: func.func @main_partial_add(%arg0: tensor<16x4xf64>) -> tensor<1x3xf64> {
+// CHECK-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<f64>
+// CHECK-NEXT: %0 = stablehlo.slice %arg0 [0:3, 1:4] : (tensor<16x4xf64>) -> tensor<3x3xf64>
+// CHECK-NEXT: %1 = stablehlo.reduce(%0 init: %cst) applies stablehlo.add across dimensions = [0] : (tensor<3x3xf64>, tensor<f64>) -> tensor<3xf64>
+// CHECK-NEXT: %2 = stablehlo.reshape %1 : (tensor<3xf64>) -> tensor<1x3xf64>
+// CHECK-NEXT: return %2 : tensor<1x3xf64>
+// CHECK-NEXT: }