Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18759,6 +18759,77 @@ bool isAxisFusible(int dimension, ArrayRef<Value> vals) {
return false;
}

// slice(extend x) -> extend(slice x)
// This pattern pushes a slice operation through an extend operation.
// slice(extend x) -> extend(slice x)
// Pushes a stablehlo.slice through an enzymexla.extend so the slice operates
// on the smaller original tensor. Any part of the slice that overlaps the
// extend's replicated padding is re-expressed as a (smaller) extend of the
// new slice.
struct ExtendSlice final
    : CheckedOpRewritePattern<stablehlo::SliceOp, ExtendSlice> {
  using CheckedOpRewritePattern::CheckedOpRewritePattern;

  LogicalResult matchAndRewriteImpl(stablehlo::SliceOp op,
                                    PatternRewriter &rewriter) const {
    auto extendOp = op.getOperand().getDefiningOp<enzymexla::ExtendOp>();
    if (!extendOp)
      return rewriter.notifyMatchFailure(op, "Operand is not an ExtendOp");

    // This transformation is simplified if strides are 1.
    if (llvm::any_of(op.getStrides(), [](int64_t s) { return s != 1; }))
      return rewriter.notifyMatchFailure(op, "Requires strides of 1");

    Value operand = extendOp.getOperand();
    auto originalShape = cast<RankedTensorType>(operand.getType()).getShape();
    int64_t d = extendOp.getDimension();
    int64_t lhs = extendOp.getLhs();
    int64_t rhs = extendOp.getRhs();

    auto starts = op.getStartIndices();
    auto limits = op.getLimitIndices();

    SmallVector<int64_t> newStarts = llvm::to_vector(starts);
    SmallVector<int64_t> newLimits = llvm::to_vector(limits);
    SmallVector<int64_t> newStrides = llvm::to_vector(op.getStrides());

    int64_t startD = starts[d];
    int64_t limitD = limits[d];
    int64_t sizeD = originalShape[d];

    // Calculate the parameters for the new slice operation on the original
    // operand. The new slice covers the part of the original tensor that is
    // visible in the final output, clamped to [0, sizeD].
    newStarts[d] = std::max<int64_t>(0, startD - lhs);
    newLimits[d] = std::min(sizeD, limitD - lhs);

    // If the slice lies entirely inside the padded region, the clamped slice
    // of the operand is empty (or negative-sized) and would be invalid; the
    // result is pure padding, which this pattern cannot express. Bail out.
    if (newLimits[d] <= newStarts[d])
      return rewriter.notifyMatchFailure(
          op, "Slice does not overlap the extended operand");

    // newLhs is the size of the overlap between the slice and the prepended
    // padding.
    int64_t newLhs = std::max<int64_t>(0, std::min(limitD, lhs) - startD);
    // newRhs is the size of the overlap between the slice and the appended
    // padding.
    int64_t newRhs =
        std::max<int64_t>(0, limitD - std::max(startD, lhs + sizeD));

    // When no padding survives, the slice of the original operand already IS
    // the result; this is profitable even if the extend has other users, so
    // do it before the single-use check.
    if (newLhs == 0 && newRhs == 0) {
      rewriter.replaceOpWithNewOp<stablehlo::SliceOp>(
          op, op.getType(), operand, newStarts, newLimits, newStrides);
      return success();
    }

    // Otherwise a new extend is created; only worthwhile when the original
    // extend dies afterwards (single use).
    if (!extendOp.getResult().hasOneUse())
      return rewriter.notifyMatchFailure(
          op, "ExtendOp result is used multiple times");

    // Create the new slice on the original tensor.
    auto newSlice = rewriter.create<stablehlo::SliceOp>(
        op.getLoc(), operand, newStarts, newLimits, newStrides);

    // Create the new extend on the newly sliced tensor.
    rewriter.replaceOpWithNewOp<enzymexla::ExtendOp>(op, op.getType(), newSlice,
                                                     newLhs, newRhs, d);

    return success();
  }
};

struct SliceExtend final
: CheckedOpRewritePattern<enzymexla::ExtendOp, SliceExtend> {
using CheckedOpRewritePattern::CheckedOpRewritePattern;
Expand Down Expand Up @@ -22285,6 +22356,7 @@ struct EnzymeHLOOptPass
mlir::enzyme::populateWithGenerated(patterns);

patterns.add<SliceExtend>(context);
patterns.add<ExtendSlice>(context);
patterns.add<SliceRotate>(context);
patterns.add<SliceWrap>(context);
patterns.add<ReshapeWrap>(context);
Expand Down
5 changes: 5 additions & 0 deletions src/enzyme_ad/jax/TransformOps/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1868,6 +1868,11 @@ def SliceExtend : EnzymeHLOPatternOp<
let patterns = ["SliceExtend"];
}

// Exposes the ExtendSlice C++ rewrite (slice(extend x) -> extend(slice x))
// to the transform dialect under the pattern name "extend_slice".
def ExtendSlice : EnzymeHLOPatternOp<
    "extend_slice"> {
  let patterns = ["ExtendSlice"];
}

def SliceRotate : EnzymeHLOPatternOp<
"slice_rotate"> {
let patterns = ["SliceRotate"];
Expand Down
39 changes: 39 additions & 0 deletions test/lit_tests/extendslice.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=extend_slice" --transform-interpreter --enzyme-hlo-remove-transform %s | FileCheck %s

// Single use: the slice [0:3] of the 5-row extended result keeps one row of
// lhs padding, so it becomes slice [0:2] of the operand followed by a
// smaller extend (lhs = 1).
// CHECK: func.func @f_single_use(%arg0: tensor<4x1520x3056xf64>) -> tensor<3x1520x3056xf64> {
// CHECK-NEXT: %0 = stablehlo.slice %arg0 [0:2, 0:1520, 0:3056] : (tensor<4x1520x3056xf64>) -> tensor<2x1520x3056xf64>
// CHECK-NEXT: %1 = "enzymexla.extend"(%0) <{dimension = 0 : i64, lhs = 1 : i64, rhs = 0 : i64}> : (tensor<2x1520x3056xf64>) -> tensor<3x1520x3056xf64>
// CHECK-NEXT: return %1 : tensor<3x1520x3056xf64>
// CHECK-NEXT: }
func.func @f_single_use(%a: tensor<4x1520x3056xf64>) -> (tensor<3x1520x3056xf64>) {
%b = "enzymexla.extend"(%a) <{dimension = 0 : i64, lhs = 1 : i64, rhs = 0 : i64}> : (tensor<4x1520x3056xf64>) -> tensor<5x1520x3056xf64>
%c = stablehlo.slice %b [0:3, 0:1520, 0:3056] : (tensor<5x1520x3056xf64>) -> tensor<3x1520x3056xf64>
return %c : tensor<3x1520x3056xf64>
}


// Multiple uses: both slices retain lhs padding, so each rewrite would need a
// new extend; with the original extend still used elsewhere that is not
// profitable, and the IR is left unchanged.
// CHECK: func.func @f_multiple_uses(%arg0: tensor<4x1520x3056xf64>) -> (tensor<3x1520x3056xf64>, tensor<4x1520x3056xf64>) {
// CHECK-NEXT: %0 = "enzymexla.extend"(%arg0) <{dimension = 0 : i64, lhs = 1 : i64, rhs = 0 : i64}> : (tensor<4x1520x3056xf64>) -> tensor<5x1520x3056xf64>
// CHECK-NEXT: %1 = stablehlo.slice %0 [0:3, 0:1520, 0:3056] : (tensor<5x1520x3056xf64>) -> tensor<3x1520x3056xf64>
// CHECK-NEXT: %2 = stablehlo.slice %0 [0:4, 0:1520, 0:3056] : (tensor<5x1520x3056xf64>) -> tensor<4x1520x3056xf64>
// CHECK-NEXT: return %1, %2 : tensor<3x1520x3056xf64>, tensor<4x1520x3056xf64>
// CHECK-NEXT: }
func.func @f_multiple_uses(%a: tensor<4x1520x3056xf64>) -> (tensor<3x1520x3056xf64>, tensor<4x1520x3056xf64>) {
%b = "enzymexla.extend"(%a) <{dimension = 0 : i64, lhs = 1 : i64, rhs = 0 : i64}> : (tensor<4x1520x3056xf64>) -> tensor<5x1520x3056xf64>
%c = stablehlo.slice %b [0:3, 0:1520, 0:3056] : (tensor<5x1520x3056xf64>) -> tensor<3x1520x3056xf64>
%d = stablehlo.slice %b [0:4, 0:1520, 0:3056] : (tensor<5x1520x3056xf64>) -> tensor<4x1520x3056xf64>
return %c, %d : tensor<3x1520x3056xf64>, tensor<4x1520x3056xf64>
}

// Multiple uses where one slice ([1:4]) overlaps no padding: that slice is
// rewritten to a direct slice of the operand regardless of use count; the
// extend then has a single remaining use, allowing the other slice ([0:4])
// to be rewritten as slice + smaller extend as well.
// CHECK: func.func @f_multiple_uses_superfluous_extend(%arg0: tensor<4x1520x3056xf64>) -> (tensor<3x1520x3056xf64>, tensor<4x1520x3056xf64>) {
// CHECK-NEXT: %0 = stablehlo.slice %arg0 [0:3, 0:1520, 0:3056] : (tensor<4x1520x3056xf64>) -> tensor<3x1520x3056xf64>
// CHECK-NEXT: %1 = "enzymexla.extend"(%0) <{dimension = 0 : i64, lhs = 1 : i64, rhs = 0 : i64}> : (tensor<3x1520x3056xf64>) -> tensor<4x1520x3056xf64>
// CHECK-NEXT: %2 = stablehlo.slice %arg0 [0:3, 0:1520, 0:3056] : (tensor<4x1520x3056xf64>) -> tensor<3x1520x3056xf64>
// CHECK-NEXT: return %2, %1 : tensor<3x1520x3056xf64>, tensor<4x1520x3056xf64>
// CHECK-NEXT: }
func.func @f_multiple_uses_superfluous_extend(%a: tensor<4x1520x3056xf64>) -> (tensor<3x1520x3056xf64>, tensor<4x1520x3056xf64>) {
%b = "enzymexla.extend"(%a) <{dimension = 0 : i64, lhs = 1 : i64, rhs = 0 : i64}> : (tensor<4x1520x3056xf64>) -> tensor<5x1520x3056xf64>
%c = stablehlo.slice %b [0:4, 0:1520, 0:3056] : (tensor<5x1520x3056xf64>) -> tensor<4x1520x3056xf64>
%d = stablehlo.slice %b [1:4, 0:1520, 0:3056] : (tensor<5x1520x3056xf64>) -> tensor<3x1520x3056xf64>
return %d, %c : tensor<3x1520x3056xf64>, tensor<4x1520x3056xf64>
}
Loading