Skip to content

Commit 8417590

Browse files
[mlir][bufferization] Return BufferLikeType in BufferizableOpInterface (#144867)
Support custom types (2/N): allow value-owning operations (e.g. allocation ops) to bufferize custom tensors into custom buffers. This requires BufferizableOpInterface::getBufferType() to return BufferLikeType instead of BaseMemRefType. Affected implementors of the interface are updated accordingly. Relates to ee070d0.
1 parent 90e46a4 commit 8417590

File tree

14 files changed

+173
-99
lines changed

14 files changed

+173
-99
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -695,7 +695,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
695695
/// This is the default implementation of
696696
/// BufferizableOpInterface::getBufferType. Should not be called from other
697697
/// places.
698-
FailureOr<BaseMemRefType>
698+
FailureOr<BufferLikeType>
699699
defaultGetBufferType(Value value, const BufferizationOptions &options,
700700
SmallVector<Value> &invocationStack);
701701

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
518518
Note: This interface method should never be called directly from user
519519
code. Always use `bufferization::getBufferType`.
520520
}],
521-
/*retType=*/"::mlir::FailureOr<::mlir::BaseMemRefType>",
521+
/*retType=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
522522
/*methodName=*/"getBufferType",
523523
/*args=*/(ins "::mlir::Value":$value,
524524
"const ::mlir::bufferization::BufferizationOptions &":$options,

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
110110
AliasingValueList getAliasingValues(
111111
OpOperand &opOperand, const AnalysisState &state);
112112

113-
FailureOr<BaseMemRefType> getBufferType(
113+
FailureOr<BufferLikeType> getBufferType(
114114
Value value, const BufferizationOptions &options,
115115
SmallVector<Value> &invocationStack);
116116

@@ -474,10 +474,10 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
474474

475475
bool isWritable(Value value, const AnalysisState &state);
476476

477-
FailureOr<BaseMemRefType> getBufferType(
477+
FailureOr<BufferLikeType> getBufferType(
478478
Value value, const BufferizationOptions &options,
479479
SmallVector<Value> &invocationStack) {
480-
return ::llvm::cast<BaseMemRefType>(getBuffer().getType());
480+
return getBuffer().getType();
481481
}
482482
}];
483483

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
namespace mlir::bufferization {
2020
struct BufferizationOptions;
21-
class BufferizationState;
2221
class BufferLikeType;
2322
} // namespace mlir::bufferization
2423

mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ template <typename ConcreteModel, typename ConcreteOp>
3232
struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
3333
: public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
3434

35-
FailureOr<BaseMemRefType>
35+
FailureOr<BufferLikeType>
3636
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
3737
SmallVector<Value> &invocationStack) const {
3838
// Note: The user may want to override this function for OpResults in
@@ -110,7 +110,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
110110
if (!bufferType)
111111
return op->emitOpError("could not infer buffer type of block argument");
112112

113-
return bufferType;
113+
return cast<BufferLikeType>(bufferType);
114114
}
115115

116116
protected:

mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ struct SelectOpInterface
176176
return success();
177177
}
178178

179-
FailureOr<BaseMemRefType>
179+
FailureOr<BufferLikeType>
180180
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
181181
SmallVector<Value> &invocationStack) const {
182182
auto selectOp = cast<arith::SelectOp>(op);
@@ -190,17 +190,17 @@ struct SelectOpInterface
190190
if (failed(trueType) || failed(falseType))
191191
return failure();
192192
if (*trueType == *falseType)
193-
return *trueType;
193+
return cast<BufferLikeType>(*trueType);
194194
if (trueType->getMemorySpace() != falseType->getMemorySpace())
195195
return op->emitError("inconsistent memory space on true/false operands");
196196

197197
// If the buffers have different types, they differ only in their layout
198198
// map.
199199
auto memrefType = llvm::cast<MemRefType>(*trueType);
200-
return getMemRefTypeWithFullyDynamicLayout(
200+
return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
201201
RankedTensorType::get(memrefType.getShape(),
202202
memrefType.getElementType()),
203-
memrefType.getMemorySpace());
203+
memrefType.getMemorySpace()));
204204
}
205205
};
206206

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -917,15 +917,17 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
917917
return AliasingOpOperandList(std::move(result));
918918
}
919919

920-
FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
920+
FailureOr<BufferLikeType> bufferization::detail::defaultGetBufferType(
921921
Value value, const BufferizationOptions &options,
922922
SmallVector<Value> &invocationStack) {
923923
assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");
924924
auto tensorType = cast<TensorType>(value.getType());
925925

926926
// No further analysis is possible for a block argument.
927-
if (llvm::isa<BlockArgument>(value))
928-
return bufferization::getMemRefType(tensorType, options);
927+
if (llvm::isa<BlockArgument>(value)) {
928+
return cast<BufferLikeType>(
929+
bufferization::getMemRefType(tensorType, options));
930+
}
929931

930932
// Value is an OpResult.
931933
Operation *op = getOwnerOfValue(value);
@@ -937,8 +939,7 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
937939
// If the OpResult has an equivalent OpOperand, both OpResult and
938940
// OpOperand bufferize to the exact same buffer type.
939941
Value equivalentOperand = aliases.getAliases().front().opOperand->get();
940-
return asMemRefType(
941-
getBufferType(equivalentOperand, options, invocationStack));
942+
return getBufferType(equivalentOperand, options, invocationStack);
942943
}
943944

944945
// If we do not know the memory space and there is no default memory space,
@@ -948,7 +949,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
948949
if (!memSpace.has_value())
949950
return op->emitError("could not infer memory space");
950951

951-
return getMemRefType(tensorType, options, /*layout=*/{}, *memSpace);
952+
return cast<BufferLikeType>(
953+
getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
952954
}
953955

954956
bool bufferization::detail::defaultIsRepetitiveRegion(

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
222222
return {};
223223
}
224224

225-
FailureOr<BaseMemRefType>
225+
FailureOr<BufferLikeType>
226226
AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
227227
SmallVector<Value> &invocationStack) {
228228
assert(value == getResult() && "invalid value");
@@ -243,7 +243,8 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
243243
return getOperation()->emitError("could not infer memory space");
244244
}
245245

246-
return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
246+
return cast<BufferLikeType>(
247+
getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
247248
}
248249

249250
LogicalResult AllocTensorOp::verify() {

mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ struct CallOpInterface
195195
return result;
196196
}
197197

198-
FailureOr<BaseMemRefType>
198+
FailureOr<BufferLikeType>
199199
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
200200
SmallVector<Value> &invocationStack) const {
201201
auto callOp = cast<func::CallOp>(op);
@@ -208,12 +208,13 @@ struct CallOpInterface
208208
Type resultType =
209209
funcType.getResult(cast<OpResult>(value).getResultNumber());
210210
if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
211-
return bufferizedType;
211+
return cast<BufferLikeType>(bufferizedType);
212212

213213
// Otherwise, call the type converter to compute the bufferized type.
214214
auto tensorType = cast<TensorType>(resultType);
215-
return options.functionArgTypeConverterFn(
216-
tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
215+
return cast<BufferLikeType>(options.functionArgTypeConverterFn(
216+
tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
217+
options));
217218
}
218219

219220
/// All function arguments are writable. It is the responsibility of the
@@ -371,16 +372,16 @@ struct FuncOpInterface
371372
return getAliasingBranchOpOperands(op, cast<BlockArgument>(value), state);
372373
}
373374

374-
FailureOr<BaseMemRefType>
375+
FailureOr<BufferLikeType>
375376
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
376377
SmallVector<Value> &invocationStack) const {
377378
auto funcOp = cast<FuncOp>(op);
378379
auto bbArg = cast<BlockArgument>(value);
379380

380381
// Function arguments are special.
381382
if (bbArg.getOwner() == &funcOp.getBody().front())
382-
return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(),
383-
options);
383+
return cast<BufferLikeType>(
384+
getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options));
384385

385386
return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
386387
getBufferType(op, value, options, invocationStack);

mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 26 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ struct IfOpInterface
270270
return success();
271271
}
272272

273-
FailureOr<BaseMemRefType>
273+
FailureOr<BufferLikeType>
274274
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
275275
SmallVector<Value> &invocationStack) const {
276276
auto ifOp = cast<scf::IfOp>(op);
@@ -306,15 +306,15 @@ struct IfOpInterface
306306

307307
// Best case: Both branches have the exact same buffer type.
308308
if (thenBufferType == elseBufferType)
309-
return thenBufferType;
309+
return cast<BufferLikeType>(thenBufferType);
310310

311311
// Memory space mismatch.
312312
if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
313313
return op->emitError("inconsistent memory space on then/else branches");
314314

315315
// Layout maps are different: Promote to fully dynamic layout map.
316-
return getMemRefTypeWithFullyDynamicLayout(
317-
cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
316+
return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
317+
cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace()));
318318
}
319319
};
320320

@@ -384,7 +384,7 @@ struct IndexSwitchOpInterface
384384
return success();
385385
}
386386

387-
FailureOr<BaseMemRefType>
387+
FailureOr<BufferLikeType>
388388
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
389389
SmallVector<Value> &invocationStack) const {
390390
auto switchOp = cast<scf::IndexSwitchOp>(op);
@@ -428,7 +428,7 @@ struct IndexSwitchOpInterface
428428
cast<TensorType>(value.getType()), bufferType.getMemorySpace());
429429
}
430430

431-
return bufferType;
431+
return cast<BufferLikeType>(bufferType);
432432
}
433433
};
434434

@@ -514,12 +514,12 @@ getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
514514
/// If both buffer types are equal, no casts are needed the computed buffer type
515515
/// can be used directly. Otherwise, the buffer types can only differ in their
516516
/// layout map and a cast must be inserted.
517-
static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
517+
static FailureOr<BufferLikeType> computeLoopRegionIterArgBufferType(
518518
Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
519519
const BufferizationOptions &options, SmallVector<Value> &invocationStack) {
520520
// Determine the buffer type of the init_arg.
521-
auto initArgBufferType = bufferization::detail::asMemRefType(
522-
bufferization::getBufferType(initArg, options, invocationStack));
521+
auto initArgBufferType =
522+
bufferization::getBufferType(initArg, options, invocationStack);
523523
if (failed(initArgBufferType))
524524
return failure();
525525

@@ -538,15 +538,15 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
538538
}
539539

540540
// Compute the buffer type of the yielded value.
541-
BaseMemRefType yieldedValueBufferType;
541+
BufferLikeType yieldedValueBufferType;
542542
if (isa<BaseMemRefType>(yieldedValue.getType())) {
543543
// scf.yield was already bufferized.
544-
yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
544+
yieldedValueBufferType = cast<BufferLikeType>(yieldedValue.getType());
545545
} else {
546546
// Note: This typically triggers a recursive call for the buffer type of
547547
// the iter_arg.
548-
auto maybeBufferType = bufferization::detail::asMemRefType(
549-
bufferization::getBufferType(yieldedValue, options, invocationStack));
548+
auto maybeBufferType =
549+
bufferization::getBufferType(yieldedValue, options, invocationStack);
550550
if (failed(maybeBufferType))
551551
return failure();
552552
yieldedValueBufferType = *maybeBufferType;
@@ -574,8 +574,8 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
574574
"expected same shape");
575575
}
576576
#endif // NDEBUG
577-
return getMemRefTypeWithFullyDynamicLayout(
578-
iterTensorType, yieldedBufferType.getMemorySpace());
577+
return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
578+
iterTensorType, yieldedBufferType.getMemorySpace()));
579579
}
580580

581581
/// Return `true` if the given loop may have 0 iterations.
@@ -695,7 +695,7 @@ struct ForOpInterface
695695
return success();
696696
}
697697

698-
FailureOr<BaseMemRefType>
698+
FailureOr<BufferLikeType>
699699
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
700700
SmallVector<Value> &invocationStack) const {
701701
auto forOp = cast<scf::ForOp>(op);
@@ -705,12 +705,7 @@ struct ForOpInterface
705705
if (auto opResult = dyn_cast<OpResult>(value)) {
706706
// The type of an OpResult must match the corresponding iter_arg type.
707707
BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
708-
auto bufferType =
709-
bufferization::getBufferType(bbArg, options, invocationStack);
710-
if (failed(bufferType))
711-
return failure();
712-
assert(isa<BaseMemRefType>(*bufferType) && "expected memref type");
713-
return cast<BaseMemRefType>(*bufferType);
708+
return bufferization::getBufferType(bbArg, options, invocationStack);
714709
}
715710

716711
// Compute result/argument number.
@@ -1027,7 +1022,7 @@ struct WhileOpInterface
10271022
return success();
10281023
}
10291024

1030-
FailureOr<BaseMemRefType>
1025+
FailureOr<BufferLikeType>
10311026
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
10321027
SmallVector<Value> &invocationStack) const {
10331028
auto whileOp = cast<scf::WhileOp>(op);
@@ -1060,10 +1055,10 @@ struct WhileOpInterface
10601055
Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
10611056
if (!isa<TensorType>(conditionYieldedVal.getType())) {
10621057
// scf.condition was already bufferized.
1063-
return cast<BaseMemRefType>(conditionYieldedVal.getType());
1058+
return cast<BufferLikeType>(conditionYieldedVal.getType());
10641059
}
1065-
return bufferization::detail::asMemRefType(bufferization::getBufferType(
1066-
conditionYieldedVal, options, invocationStack));
1060+
return bufferization::getBufferType(conditionYieldedVal, options,
1061+
invocationStack);
10671062
}
10681063

10691064
/// Assert that yielded values of an scf.while op are equivalent to their
@@ -1297,22 +1292,22 @@ struct ForallOpInterface
12971292
return success();
12981293
}
12991294

1300-
FailureOr<BaseMemRefType>
1295+
FailureOr<BufferLikeType>
13011296
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
13021297
SmallVector<Value> &invocationStack) const {
13031298
auto forallOp = cast<ForallOp>(op);
13041299

13051300
if (auto bbArg = dyn_cast<BlockArgument>(value))
13061301
// A tensor block argument has the same bufferized type as the
13071302
// corresponding output operand.
1308-
return bufferization::detail::asMemRefType(bufferization::getBufferType(
1309-
forallOp.getTiedOpOperand(bbArg)->get(), options, invocationStack));
1303+
return bufferization::getBufferType(
1304+
forallOp.getTiedOpOperand(bbArg)->get(), options, invocationStack);
13101305

13111306
// The bufferized result type is the same as the bufferized type of the
13121307
// corresponding output operand.
1313-
return bufferization::detail::asMemRefType(bufferization::getBufferType(
1308+
return bufferization::getBufferType(
13141309
forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
1315-
invocationStack));
1310+
invocationStack);
13161311
}
13171312

13181313
bool isRepetitiveRegion(Operation *op, unsigned index) const {

0 commit comments

Comments
 (0)