
Commit a765903

[mlir][bufferization] Support custom types at function boundaries (#159766)
Support custom types (3/N): allow custom tensor and buffer types in function signatures and at call sites. This is one of the major building blocks for module-level one-shot bufferization support. To achieve this, the `BufferizationOptions::FunctionArgTypeConverterFn` callback is changed to work with tensor-like and buffer-like types instead of the builtin counterparts. The default behavior for builtin types remains unchanged, while custom types by default go through `TensorLikeType::getBufferType()`, the generic conversion interface.
1 parent 8417590 commit a765903
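For downstream users, the visible change is the callback's signature. Below is a minimal sketch (not from the patch; the helper name `installCustomFunctionArgConverter` is ours) of installing a converter under the new interface. It mirrors the default dispatch introduced in this commit, using the `BufferizationOptions` fields shown in the header diff below:

#include <cassert>
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"

using namespace mlir;
using namespace mlir::bufferization;

static void installCustomFunctionArgConverter(BufferizationOptions &options) {
  options.bufferizeFunctionBoundaries = true;
  // New signature: tensor-like in, buffer-like out.
  options.functionArgTypeConverterFn =
      [](TensorLikeType type, Attribute memorySpace, func::FuncOp funcOp,
         const BufferizationOptions &opts) -> BufferLikeType {
    // Builtin tensors keep the memref-based conversion.
    if (auto tensorType = dyn_cast<TensorType>(type))
      return cast<BufferLikeType>(
          getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace));
    // Custom tensor-like types convert through their own interface.
    auto buffer =
        type.getBufferType(opts, [&]() { return funcOp->emitError(); });
    assert(succeeded(buffer) && "expected a bufferizable custom type");
    return *buffer;
  };
}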

5 files changed: +154 −59 lines


mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 13 additions & 12 deletions
@@ -260,12 +260,12 @@ struct BufferizationOptions {
       std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
   /// Initializer function for analysis state.
   using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
-  /// Tensor -> MemRef type converter.
-  /// Parameters: tensor type, memory space, func op, bufferization options
+  /// Tensor-like -> Buffer-like type conversion.
+  /// Parameters: tensor-like type, memory space, func op, bufferization options
   using FunctionArgTypeConverterFn =
-      std::function<BaseMemRefType(TensorType, Attribute memorySpace,
+      std::function<BufferLikeType(TensorLikeType, Attribute memorySpace,
                                    func::FuncOp, const BufferizationOptions &)>;
-  /// Tensor -> MemRef type converter.
+  /// Tensor -> MemRef type conversion.
   /// Parameters: tensor type, memory space, bufferization options
   using UnknownTypeConverterFn = std::function<BaseMemRefType(
       TensorType, Attribute memorySpace, const BufferizationOptions &)>;
@@ -345,10 +345,12 @@ struct BufferizationOptions {
   /// predictable.
   void setFunctionBoundaryTypeConversion(LayoutMapOption layoutMapOption);

-  /// Type converter from tensors to memrefs. This type converter is used to
-  /// determine bufferized function argument and result types. By default, a
-  /// type converter that returns a memref type with a fully dynamic layout map
-  /// is used.
+  /// Type conversion from tensors to buffers. This type conversion is used to
+  /// determine bufferized function argument and result types.
+  ///
+  /// By default, if tensor is a (builtin) tensor type, it is converted to a
+  /// memref type with a fully dynamic layout map; if tensor is a (generic)
+  /// tensor-like type, it is converted using TensorLikeType::getBufferType().
   ///
   /// If `bufferizeFunctionBoundaries` is not set, this function isn't used.
   FunctionArgTypeConverterFn functionArgTypeConverterFn = nullptr;
@@ -360,10 +362,9 @@ struct BufferizationOptions {
   /// If `bufferizeFunctionBoundaries` is not set, this flag has no effect.
   bool inferFunctionResultLayout = true;

-  /// Type converter from tensors to memrefs. This type converter is used if no
-  /// memref type could be inferred during bufferization. By default, a type
-  /// converter that returns a memref type with a fully dynamic layout map is
-  /// used.
+  /// Type conversion from tensors to memrefs. This type conversion is used if
+  /// no memref type could be inferred during bufferization. By default, returns
+  /// a memref type with a fully dynamic layout map.
   UnknownTypeConverterFn unknownTypeConverterFn = nullptr;

   // Use during type conversion to determine the memory space for memref based

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 30 additions & 9 deletions
@@ -314,11 +314,21 @@ bool OpFilter::isOpAllowed(Operation *op) const {
 namespace {

 /// Default function arg type converter: Use a fully dynamic layout map.
-BaseMemRefType
-defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
+BufferLikeType
+defaultFunctionArgTypeConverter(TensorLikeType type, Attribute memorySpace,
                                 func::FuncOp funcOp,
                                 const BufferizationOptions &options) {
-  return getMemRefTypeWithFullyDynamicLayout(type, memorySpace);
+  if (auto tensorType = mlir::dyn_cast<TensorType>(type)) {
+    return cast<BufferLikeType>(
+        getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace));
+  }
+
+  // If not builtin, fall back to TensorLikeType::getBufferType().
+  auto bufferType =
+      type.getBufferType(options, [&]() { return funcOp->emitError(); });
+  assert(succeeded(bufferType) &&
+         "a valid buffer is always expected at function boundary");
+  return *bufferType;
 }
 /// Default unknown type converter: Use a fully dynamic layout map.
 BaseMemRefType
@@ -361,14 +371,25 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const {

 void BufferizationOptions::setFunctionBoundaryTypeConversion(
     LayoutMapOption layoutMapOption) {
-  functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace,
+  functionArgTypeConverterFn = [=](TensorLikeType type, Attribute memorySpace,
                                    func::FuncOp funcOp,
                                    const BufferizationOptions &options) {
-    if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
-      return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
-                                                                  memorySpace);
-    return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
-                                                              memorySpace);
+    if (auto tensorType = mlir::dyn_cast<TensorType>(type)) {
+      if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
+        return cast<BufferLikeType>(
+            bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
+                                                                 memorySpace));
+      return cast<BufferLikeType>(
+          bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
+                                                             memorySpace));
+    }
+
+    // If not builtin, fall back to TensorLikeType::getBufferType().
+    auto bufferType =
+        type.getBufferType(options, [&]() { return funcOp->emitError(); });
+    assert(succeeded(bufferType) &&
+           "a valid buffer is always expected at function boundary");
+    return *bufferType;
   };
   inferFunctionResultLayout =
       layoutMapOption == LayoutMapOption::InferLayoutMap;
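The fallback above goes through `TensorLikeType::getBufferType()`, which each custom type implements itself. As a rough illustration (not part of this patch; `MyTensorType`/`MyBufferType` are hypothetical stand-ins for a pair like the test dialect's `!test.test_tensor`/`!test.test_memref`, and `getShape()`/`getElementType()` are assumed accessors on that type), such an implementation might look like:

// Sketch of a TensorLikeType interface method for a hypothetical custom
// type that maps 1:1 onto its buffer counterpart.
FailureOr<bufferization::BufferLikeType> MyTensorType::getBufferType(
    const bufferization::BufferizationOptions &options,
    llvm::function_ref<InFlightDiagnostic()> emitError) {
  // Shape and element type carry over unchanged; this simple mapping cannot
  // fail, so emitError is never invoked here.
  return llvm::cast<bufferization::BufferLikeType>(
      MyBufferType::get(getContext(), getShape(), getElementType()));
}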

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 1 addition & 1 deletion
@@ -425,7 +425,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
   // Compute the new signature.
   SmallVector<Type> newTypes;
   for (BlockArgument &bbArg : block->getArguments()) {
-    auto tensorType = dyn_cast<TensorType>(bbArg.getType());
+    auto tensorType = dyn_cast<TensorLikeType>(bbArg.getType());
     if (!tensorType) {
       newTypes.push_back(bbArg.getType());
       continue;

mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Lines changed: 55 additions & 37 deletions
@@ -50,29 +50,47 @@ void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
 #endif // NDEBUG
 }

+// Note: this is a local adaptor to unify TensorType and TensorLikeType code
+// paths that both work with BufferizationOptions.
+static mlir::Attribute
+getDefaultMemorySpace(const BufferizationOptions &options,
+                      TensorLikeType type) {
+  if (auto tensorType = dyn_cast<TensorType>(type)) {
+    return *options.defaultMemorySpaceFn(tensorType);
+  }
+  return nullptr;
+}
+
 /// Return the index-th bufferized function argument type. This assumes that the
 /// specified argument is a tensor. If the tensor is ranked, a layout map may be
 /// specified by the user (as per `options.functionArgTypeConverterFn`).
-static BaseMemRefType
+static BufferLikeType
 getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
                              const BufferizationOptions &options) {
-  auto tensorType =
-      dyn_cast<TensorType>(funcOp.getFunctionType().getInput(index));
-  assert(tensorType && "expected TensorType");
-
-  BaseMemRefType memrefType = options.functionArgTypeConverterFn(
-      tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
-
-  auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
-      index, BufferizationDialect::kBufferLayoutAttrName);
-  if (!layoutAttr)
-    return memrefType;
-
-  auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
-  assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
-  return MemRefType::get(
-      rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
-      layoutAttr.getValue(), rankedMemrefType.getMemorySpace());
+  auto type =
+      dyn_cast<TensorLikeType>(funcOp.getFunctionType().getInput(index));
+  assert(type && "expected TensorLikeType");
+
+  // Note: For builtin tensors there is additional logic related to layout.
+  if (auto tensorType = dyn_cast<TensorType>(type)) {
+    BufferLikeType memrefType = options.functionArgTypeConverterFn(
+        type, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
+
+    auto layoutAttr = funcOp.getArgAttrOfType<MemRefLayoutAttrInterface>(
+        index, BufferizationDialect::kBufferLayoutAttrName);
+    if (!layoutAttr)
+      return memrefType;
+
+    auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
+    assert(rankedMemrefType &&
+           "buffer layout not supported on unranked tensors");
+    return cast<BufferLikeType>(MemRefType::get(
+        rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
+        layoutAttr, rankedMemrefType.getMemorySpace()));
+  }
+
+  return options.functionArgTypeConverterFn(type, /*memSpace=*/nullptr, funcOp,
+                                            options);
 }

 /// Return the FuncOp called by `callOp`.
@@ -207,13 +225,13 @@ struct CallOpInterface
     FunctionType funcType = funcOp.getFunctionType();
     Type resultType =
         funcType.getResult(cast<OpResult>(value).getResultNumber());
-    if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
-      return cast<BufferLikeType>(bufferizedType);
+    if (auto bufferizedType = dyn_cast<BufferLikeType>(resultType))
+      return bufferizedType;

     // Otherwise, call the type converter to compute the bufferized type.
-    auto tensorType = cast<TensorType>(resultType);
+    auto tensorType = cast<TensorLikeType>(resultType);
     return cast<BufferLikeType>(options.functionArgTypeConverterFn(
-        tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
+        tensorType, getDefaultMemorySpace(options, tensorType), funcOp,
         options));
   }

@@ -227,7 +245,7 @@ struct CallOpInterface
     SmallVector<Type> resultTypes;
     for (Value result : callOp.getResults()) {
       Type returnType = result.getType();
-      if (!isa<TensorType>(returnType)) {
+      if (!isa<TensorLikeType>(returnType)) {
         // Non-tensor values are returned.
         resultTypes.push_back(returnType);
         continue;
@@ -250,7 +268,7 @@ struct CallOpInterface

     for (OpOperand &opOperand : callOp->getOpOperands()) {
       // Non-tensor operands are just copied.
-      if (!isa<TensorType>(opOperand.get().getType())) {
+      if (!isa<TensorLikeType>(opOperand.get().getType())) {
         newOperands.push_back(opOperand.get());
         continue;
       }
@@ -263,8 +281,8 @@ struct CallOpInterface
       Value buffer = *maybeBuffer;

       // Caller / callee type mismatch is handled with castOrReallocMemRefValue.
-      auto memRefType = funcType.getInput(opOperand.getOperandNumber());
-      if (!isa<BaseMemRefType>(memRefType)) {
+      auto bufferType = funcType.getInput(opOperand.getOperandNumber());
+      if (!isa<BufferLikeType>(bufferType)) {
         // The called function was not bufferized yet. This can happen when
         // there are cycles in the function call graph. Compute the bufferized
         // result type.
@@ -273,7 +291,7 @@ struct CallOpInterface
             funcOp.getArgument(opOperand.getOperandNumber()), options);
         if (failed(maybeBufferType))
           return failure();
-        memRefType = *maybeBufferType;
+        bufferType = *maybeBufferType;
       }

       // Since we don't yet have a clear layout story, to_buffer may
@@ -282,8 +300,8 @@ struct CallOpInterface
       // that will either canonicalize away or fail compilation until we can do
       // something better. Insert a reallocation + copy if it cannot be
      // statically guaranteed that a direct cast would be valid.
-      if (buffer.getType() != memRefType) {
-        auto memrefDstType = dyn_cast<MemRefType>(memRefType);
+      if (buffer.getType() != bufferType) {
+        auto memrefDstType = dyn_cast<MemRefType>(bufferType);
         assert(memrefDstType &&
                "buffer layout not supported on unranked tensors");
         FailureOr<Value> replacement = bufferization::castOrReallocMemRefValue(
@@ -345,7 +363,7 @@ struct FuncOpInterface
   static bool supportsUnstructuredControlFlow() { return true; }

   bool hasTensorSemantics(Operation *op) const {
-    auto isaTensor = llvm::IsaPred<TensorType>;
+    auto isaTensor = llvm::IsaPred<TensorLikeType>;

     // A function has tensor semantics if it has tensor arguments/results.
     auto funcOp = cast<FuncOp>(op);
@@ -380,8 +398,8 @@ struct FuncOpInterface

     // Function arguments are special.
     if (bbArg.getOwner() == &funcOp.getBody().front())
-      return cast<BufferLikeType>(
-          getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options));
+      return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(),
+                                          options);

     return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
         getBufferType(op, value, options, invocationStack);
@@ -403,7 +421,7 @@ struct FuncOpInterface
     SmallVector<Type> argTypes;
     for (const auto &it : llvm::enumerate(funcType.getInputs())) {
       Type argType = it.value();
-      if (isa<TensorType>(argType)) {
+      if (isa<TensorLikeType>(argType)) {
         argTypes.push_back(
             getBufferizedFunctionArgType(funcOp, it.index(), options));
         continue;
@@ -414,9 +432,9 @@ struct FuncOpInterface
     // Compute the result types.
     SmallVector<Type> retTypes;
     for (Type resultType : funcType.getResults()) {
-      if (auto tensorType = dyn_cast<TensorType>(resultType)) {
-        BaseMemRefType resultType = options.functionArgTypeConverterFn(
-            tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
+      if (auto tensorType = dyn_cast<TensorLikeType>(resultType)) {
+        BufferLikeType resultType = options.functionArgTypeConverterFn(
+            tensorType, getDefaultMemorySpace(options, tensorType), funcOp,
             options);
         retTypes.push_back(resultType);
         continue;
@@ -446,7 +464,7 @@ struct FuncOpInterface
     SmallVector<Value> returnValues;
     for (auto [returnVal, bufferizedType] :
          llvm::zip_equal(returnOp->getOperands(), retTypes)) {
-      auto tensorType = dyn_cast<TensorType>(returnVal.getType());
+      auto tensorType = dyn_cast<TensorLikeType>(returnVal.getType());
       rewriter.setInsertionPoint(returnOp);

       // If not a tensor type just forward it.
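One design choice worth noting in the new `getDefaultMemorySpace` adaptor: `options.defaultMemorySpaceFn` is defined on builtin `TensorType` only, so custom tensor-like types receive a null memory-space attribute and must model any memory-space notion themselves. A hedged sketch of the builtin-side hook for contrast (the helper name and the space value `1` are arbitrary examples, not from the patch):

static void pinBuiltinTensorsToSpace1(BufferizationOptions &options) {
  // Affects builtin tensors only; custom tensor-like types still receive
  // /*memorySpace=*/nullptr via the getDefaultMemorySpace adaptor above.
  options.defaultMemorySpaceFn =
      [](TensorType tensorType) -> std::optional<Attribute> {
    return IntegerAttr::get(IntegerType::get(tensorType.getContext(), 64), 1);
  };
}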

mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir

Lines changed: 55 additions & 0 deletions
@@ -796,3 +796,58 @@ func.func @result_type_mismatch(%c: i1) -> tensor<5xf32> {
   return %1 : tensor<5xf32>
 }

+// -----
+
+// CHECK: func.func @custom_types(
+// CHECK-SAME: %[[arg:.*]]: !test.test_memref<[4, 4], f64>
+// CHECK-SAME: ) -> (!test.test_memref<[4, 8], f64>,
+// CHECK-SAME: !test.test_memref<[4, 8], f64>)
+func.func @custom_types(%arg: !test.test_tensor<[4, 4], f64>)
+    -> (!test.test_tensor<[4, 8], f64>, !test.test_tensor<[4, 8], f64>) {
+  // CHECK: %[[out1:.*]] = "test.dummy_memref_op"(%[[arg]]) :
+  // CHECK-SAME: (!test.test_memref<[4, 4], f64>) -> !test.test_memref<[4, 8], f64>
+  %out1 = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>)
+      -> !test.test_tensor<[4, 8], f64>
+
+  // CHECK: %[[alloc:.*]] = "test.create_memref_op"
+  // CHECK: %[[out2:.*]] = "test.dummy_memref_op"(%[[alloc]])
+  // CHECK-SAME: (!test.test_memref<[4, 4], f64>) -> !test.test_memref<[4, 8], f64>
+  %alloc = "test.create_tensor_op"() : () -> !test.test_tensor<[4, 4], f64>
+  %out2 = "test.dummy_tensor_op"(%alloc) : (!test.test_tensor<[4, 4], f64>)
+      -> !test.test_tensor<[4, 8], f64>
+
+  // CHECK: return %[[out1]], %[[out2]]
+  return %out1, %out2 :
+      !test.test_tensor<[4, 8], f64>, !test.test_tensor<[4, 8], f64>
+}
+
+// -----
+
+// CHECK: func.func @custom_types_foo(
+// CHECK-SAME: %[[arg:.*]]: !test.test_memref<[4, 4], f64>
+// CHECK-SAME: ) -> !test.test_memref<[4, 4], f64>
+func.func @custom_types_foo(%arg: !test.test_tensor<[4, 4], f64>)
+    -> !test.test_tensor<[4, 4], f64> {
+  // CHECK: %[[out:.*]] = "test.dummy_memref_op"(%[[arg]])
+  %out = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>)
+      -> !test.test_tensor<[4, 4], f64>
+  // CHECK: return %[[out]]
+  return %out : !test.test_tensor<[4, 4], f64>
+}
+
+// CHECK: func.func @custom_types_bar(
+// CHECK-SAME: %[[arg:.*]]: !test.test_memref<[4, 4], f64>
+// CHECK-SAME: ) -> !test.test_memref<[4, 8], f64>
+func.func @custom_types_bar(%arg: !test.test_tensor<[4, 4], f64>)
+    -> !test.test_tensor<[4, 8], f64> {
+  // CHECK: %[[call:.*]] = call @custom_types_foo(%[[arg]])
+  %call = func.call @custom_types_foo(%arg) : (!test.test_tensor<[4, 4], f64>)
+      -> !test.test_tensor<[4, 4], f64>
+
+  // CHECK: %[[out:.*]] = "test.dummy_memref_op"(%[[call]])
+  %out = "test.dummy_tensor_op"(%call) : (!test.test_tensor<[4, 4], f64>)
+      -> !test.test_tensor<[4, 8], f64>
+
+  // CHECK: return %[[out]]
+  return %out : !test.test_tensor<[4, 8], f64>
+}
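Taken together, the new tests cover both halves of the change: @custom_types exercises custom tensor-like arguments and results within a single function, while @custom_types_foo/@custom_types_bar exercise call-site bufferization, including the path where the callee's bufferized signature drives the caller's buffer types. The !test.test_tensor/!test.test_memref pair is the test dialect's implementation of the TensorLikeType/BufferLikeType interfaces; like the rest of this file, the new cases presumably run under one-shot module bufferization with function-boundary bufferization enabled.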
