Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions stablehlo/conversions/tosa/tests/unary.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,28 @@ func.func @slice_rank_seven(%arg : tensor<2x3x4x5x6x7x8xf32>) -> tensor<1x2x3x4x
return %0 : tensor<1x2x3x4x5x6x7xf32>
}

// CHECK-LABEL: @dynamic_slice_constant_start
// Verifies that a stablehlo.dynamic_slice whose start indices are all
// stablehlo.constant scalars is folded into a static tosa.slice: the constant
// starts (1, 2) and the slice_sizes attribute (2, 3) become tosa.const_shape
// operands of the resulting tosa.slice.
func.func @dynamic_slice_constant_start(%arg : tensor<10x10xf32>) -> tensor<2x3xf32> {
// CHECK-DAG: %[[START:.*]] = tosa.const_shape {values = dense<[1, 2]> : tensor<2xindex>} : () -> !tosa.shape<2>
// CHECK-DAG: %[[SIZE:.*]] = tosa.const_shape {values = dense<[2, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
// CHECK: tosa.slice %arg0, %[[START]], %[[SIZE]]
  // Scalar i32 constants feeding the start-index operands below.
  %start0 = "stablehlo.constant"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  %start1 = "stablehlo.constant"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
  %0 = "stablehlo.dynamic_slice"(%arg, %start0, %start1) {
    slice_sizes = array<i64: 2, 3>
  } : (tensor<10x10xf32>, tensor<i32>, tensor<i32>) -> tensor<2x3xf32>
  return %0 : tensor<2x3xf32>
}

// CHECK-LABEL: @dynamic_slice_runtime_start
// Negative test: when the start indices are runtime values (block arguments
// rather than constants), the op cannot be lowered to a static tosa.slice and
// must be left as stablehlo.dynamic_slice.
func.func @dynamic_slice_runtime_start(%arg0 : tensor<10x10xf32>, %arg1 : tensor<i32>, %arg2 : tensor<i32>) -> tensor<2x3xf32> {
// CHECK: stablehlo.dynamic_slice
  %0 = "stablehlo.dynamic_slice"(%arg0, %arg1, %arg2) {
    slice_sizes = array<i64: 2, 3>
  } : (tensor<10x10xf32>, tensor<i32>, tensor<i32>) -> tensor<2x3xf32>
  return %0 : tensor<2x3xf32>
}

// CHECK-LABEL: @tanh
func.func @tanh(%arg : tensor<10xf32>) -> tensor<10xf32> {
// CHECK: tosa.tanh
Expand Down
52 changes: 42 additions & 10 deletions stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -505,16 +505,7 @@ struct ConvertStablehloReshapeOp
auto resultShape = resultType.getShape();
SmallVector<int64_t, 8> dimensions(resultShape.begin(), resultShape.end());

RankedTensorType shapeTensorType = RankedTensorType::get(
{static_cast<int64_t>(dimensions.size())}, rewriter.getIndexType());

auto denseAttr = DenseIntElementsAttr::get(shapeTensorType, dimensions);
auto shapeType =
tosa::shapeType::get(rewriter.getContext(), dimensions.size());

auto constShapeOp =
tosa::ConstShapeOp::create(rewriter, op.getLoc(), shapeType, denseAttr);

auto constShapeOp = getTosaConstShape(rewriter, op.getLoc(), dimensions);
auto reshapeOp = tosa::ReshapeOp::create(rewriter, op.getLoc(), resultType,
op.getOperand(), constShapeOp);

Expand Down Expand Up @@ -558,6 +549,45 @@ struct ConvertStablehloFloatDivideOp
}
};

// Lowers stablehlo.dynamic_slice to tosa.slice when every start index is a
// compile-time constant scalar.
//
// StableHLO defines the effective start index of dynamic_slice as
// clamp(0, start_index, dim_size - slice_size). tosa.slice performs no such
// clamping, so the clamp is applied here while the constants are folded;
// otherwise a negative or too-large constant start would yield an
// out-of-bounds tosa.slice.
struct ConvertStablehloDynamicSliceOp
    : public OpRewritePattern<stablehlo::DynamicSliceOp> {
  using OpRewritePattern<stablehlo::DynamicSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(stablehlo::DynamicSliceOp op,
                                PatternRewriter& rewriter) const override {
    auto operandType = dyn_cast<RankedTensorType>(op.getOperand().getType());
    if (!operandType) {
      return rewriter.notifyMatchFailure(op, "expected ranked tensor type");
    }

    if (operandType.getRank() < 1) {
      return rewriter.notifyMatchFailure(
          op, "tosa.slice requires input tensor of at least rank 1");
    }

    auto sliceSizes = op.getSliceSizes();

    SmallVector<int64_t> startIndices;
    startIndices.reserve(operandType.getRank());
    int64_t dim = 0;
    for (Value startIndex : op.getStartIndices()) {
      // Only start indices produced directly by stablehlo.constant can be
      // folded into a static tosa.slice.
      DenseIntElementsAttr startAttr;
      if (auto constOp = startIndex.getDefiningOp<stablehlo::ConstantOp>()) {
        startAttr = dyn_cast<DenseIntElementsAttr>(constOp.getValue());
      }

      if (!startAttr || startAttr.getNumElements() != 1) {
        return rewriter.notifyMatchFailure(
            op, "tosa.slice requires constant start indices");
      }

      int64_t start = (*startAttr.value_begin<APInt>()).getSExtValue();
      // Mirror StableHLO's clamping of start indices into
      // [0, dim_size - slice_size]. A dynamic dimension has no known upper
      // bound at compile time, so only the lower bound applies there.
      if (start < 0) start = 0;
      if (!operandType.isDynamicDim(dim)) {
        int64_t maxStart = operandType.getDimSize(dim) - sliceSizes[dim];
        if (start > maxStart) start = maxStart;
      }
      startIndices.push_back(start);
      ++dim;
    }

    rewriter.replaceOpWithNewOp<tosa::SliceOp>(
        op, op.getType(), op.getOperand(),
        getTosaConstShape(rewriter, op.getLoc(), startIndices),
        getTosaConstShape(rewriter, op.getLoc(), sliceSizes));
    return success();
  }
};

LogicalResult StablehloLegalizeToTosaPass::initialize(MLIRContext* ctx) {
RewritePatternSet patternList(ctx);
populateGeneratedPDLLPatterns(patternList);
Expand All @@ -573,6 +603,8 @@ LogicalResult StablehloLegalizeToTosaPass::initialize(MLIRContext* ctx) {
patternList.addWithLabel<ConvertStablehloReduceOp>({"StablehloReduce"}, ctx);
patternList.addWithLabel<ConvertStablehloReturnOp>({"StablehloReturn"}, ctx);
patternList.addWithLabel<ConvertStablehloSliceOp>({"StablehloSlice"}, ctx);
patternList.addWithLabel<ConvertStablehloDynamicSliceOp>(
{"StablehloDynamicSlice"}, ctx);
patternList.addWithLabel<ConvertStablehloTransposeOp>({"StablehloTranspose"},
ctx);
patternList.addWithLabel<ConvertStablehloWhileOp>({"StablehloWhile"}, ctx);
Expand Down
Loading