@@ -50,29 +50,47 @@ void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
5050#endif // NDEBUG
5151}
5252
53+ // Note: this is a local adaptor to unify TensorType and TensorLikeType code
54+ // paths that both work with BufferizationOptions.
55+ static mlir::Attribute
56+ getDefaultMemorySpace (const BufferizationOptions &options,
57+ TensorLikeType type) {
58+ if (auto tensorType = dyn_cast<TensorType>(type)) {
59+ return *options.defaultMemorySpaceFn (tensorType);
60+ }
61+ return nullptr ;
62+ }
63+
5364// / Return the index-th bufferized function argument type. This assumes that the
5465// / specified argument is a tensor. If the tensor is ranked, a layout map may be
5566// / specified by the user (as per `options.functionArgTypeConverterFn`).
56- static BaseMemRefType
67+ static BufferLikeType
5768getBufferizedFunctionArgType (FuncOp funcOp, int64_t index,
5869 const BufferizationOptions &options) {
59- auto tensorType =
60- dyn_cast<TensorType>(funcOp.getFunctionType ().getInput (index));
61- assert (tensorType && " expected TensorType" );
62-
63- BaseMemRefType memrefType = options.functionArgTypeConverterFn (
64- tensorType, *options.defaultMemorySpaceFn (tensorType), funcOp, options);
65-
66- auto layoutAttr = funcOp.getArgAttrOfType <AffineMapAttr>(
67- index, BufferizationDialect::kBufferLayoutAttrName );
68- if (!layoutAttr)
69- return memrefType;
70-
71- auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
72- assert (rankedMemrefType && " buffer layout not supported on unranked tensors" );
73- return MemRefType::get (
74- rankedMemrefType.getShape (), rankedMemrefType.getElementType (),
75- layoutAttr.getValue (), rankedMemrefType.getMemorySpace ());
70+ auto type =
71+ dyn_cast<TensorLikeType>(funcOp.getFunctionType ().getInput (index));
72+ assert (type && " expected TensorLikeType" );
73+
74+ // Note: For builtin tensors there is additional logic related to layout.
75+ if (auto tensorType = dyn_cast<TensorType>(type)) {
76+ BufferLikeType memrefType = options.functionArgTypeConverterFn (
77+ type, *options.defaultMemorySpaceFn (tensorType), funcOp, options);
78+
79+ auto layoutAttr = funcOp.getArgAttrOfType <MemRefLayoutAttrInterface>(
80+ index, BufferizationDialect::kBufferLayoutAttrName );
81+ if (!layoutAttr)
82+ return memrefType;
83+
84+ auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
85+ assert (rankedMemrefType &&
86+ " buffer layout not supported on unranked tensors" );
87+ return cast<BufferLikeType>(MemRefType::get (
88+ rankedMemrefType.getShape (), rankedMemrefType.getElementType (),
89+ layoutAttr, rankedMemrefType.getMemorySpace ()));
90+ }
91+
92+ return options.functionArgTypeConverterFn (type, /* memSpace=*/ nullptr , funcOp,
93+ options);
7694}
7795
7896// / Return the FuncOp called by `callOp`.
@@ -207,13 +225,13 @@ struct CallOpInterface
207225 FunctionType funcType = funcOp.getFunctionType ();
208226 Type resultType =
209227 funcType.getResult (cast<OpResult>(value).getResultNumber ());
210- if (auto bufferizedType = dyn_cast<BaseMemRefType >(resultType))
211- return cast<BufferLikeType>( bufferizedType) ;
228+ if (auto bufferizedType = dyn_cast<BufferLikeType >(resultType))
229+ return bufferizedType;
212230
213231 // Otherwise, call the type converter to compute the bufferized type.
214- auto tensorType = cast<TensorType >(resultType);
232+ auto tensorType = cast<TensorLikeType >(resultType);
215233 return cast<BufferLikeType>(options.functionArgTypeConverterFn (
216- tensorType, *options. defaultMemorySpaceFn ( tensorType), funcOp,
234+ tensorType, getDefaultMemorySpace (options, tensorType), funcOp,
217235 options));
218236 }
219237
@@ -227,7 +245,7 @@ struct CallOpInterface
227245 SmallVector<Type> resultTypes;
228246 for (Value result : callOp.getResults ()) {
229247 Type returnType = result.getType ();
230- if (!isa<TensorType >(returnType)) {
248+ if (!isa<TensorLikeType >(returnType)) {
231249 // Non-tensor values are returned.
232250 resultTypes.push_back (returnType);
233251 continue ;
@@ -250,7 +268,7 @@ struct CallOpInterface
250268
251269 for (OpOperand &opOperand : callOp->getOpOperands ()) {
252270 // Non-tensor operands are just copied.
253- if (!isa<TensorType >(opOperand.get ().getType ())) {
271+ if (!isa<TensorLikeType >(opOperand.get ().getType ())) {
254272 newOperands.push_back (opOperand.get ());
255273 continue ;
256274 }
@@ -263,8 +281,8 @@ struct CallOpInterface
263281 Value buffer = *maybeBuffer;
264282
265283 // Caller / callee type mismatch is handled with castOrReallocMemRefValue.
266- auto memRefType = funcType.getInput (opOperand.getOperandNumber ());
267- if (!isa<BaseMemRefType>(memRefType )) {
284+ auto bufferType = funcType.getInput (opOperand.getOperandNumber ());
285+ if (!isa<BufferLikeType>(bufferType )) {
268286 // The called function was not bufferized yet. This can happen when
269287 // there cycles in the function call graph. Compute the bufferized
270288 // result type.
@@ -273,7 +291,7 @@ struct CallOpInterface
273291 funcOp.getArgument (opOperand.getOperandNumber ()), options);
274292 if (failed (maybeBufferType))
275293 return failure ();
276- memRefType = *maybeBufferType;
294+ bufferType = *maybeBufferType;
277295 }
278296
279297 // Since we don't yet have a clear layout story, to_buffer may
@@ -282,8 +300,8 @@ struct CallOpInterface
282300 // that will either canonicalize away or fail compilation until we can do
283301 // something better. Insert a reallocation + copy if it cannot be
284302 // statically guaranteed that a direct cast would be valid.
285- if (buffer.getType () != memRefType ) {
286- auto memrefDstType = dyn_cast<MemRefType>(memRefType );
303+ if (buffer.getType () != bufferType ) {
304+ auto memrefDstType = dyn_cast<MemRefType>(bufferType );
287305 assert (memrefDstType &&
288306 " buffer layout not supported on unranked tensors" );
289307 FailureOr<Value> replacement = bufferization::castOrReallocMemRefValue (
@@ -345,7 +363,7 @@ struct FuncOpInterface
345363 static bool supportsUnstructuredControlFlow () { return true ; }
346364
347365 bool hasTensorSemantics (Operation *op) const {
348- auto isaTensor = llvm::IsaPred<TensorType >;
366+ auto isaTensor = llvm::IsaPred<TensorLikeType >;
349367
350368 // A function has tensor semantics if it has tensor arguments/results.
351369 auto funcOp = cast<FuncOp>(op);
@@ -380,8 +398,8 @@ struct FuncOpInterface
380398
381399 // Function arguments are special.
382400 if (bbArg.getOwner () == &funcOp.getBody ().front ())
383- return cast<BufferLikeType>(
384- getBufferizedFunctionArgType (funcOp, bbArg. getArgNumber (), options) );
401+ return getBufferizedFunctionArgType (funcOp, bbArg. getArgNumber (),
402+ options);
385403
386404 return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
387405 getBufferType (op, value, options, invocationStack);
@@ -403,7 +421,7 @@ struct FuncOpInterface
403421 SmallVector<Type> argTypes;
404422 for (const auto &it : llvm::enumerate (funcType.getInputs ())) {
405423 Type argType = it.value ();
406- if (isa<TensorType >(argType)) {
424+ if (isa<TensorLikeType >(argType)) {
407425 argTypes.push_back (
408426 getBufferizedFunctionArgType (funcOp, it.index (), options));
409427 continue ;
@@ -414,9 +432,9 @@ struct FuncOpInterface
414432 // Compute the result types.
415433 SmallVector<Type> retTypes;
416434 for (Type resultType : funcType.getResults ()) {
417- if (auto tensorType = dyn_cast<TensorType >(resultType)) {
418- BaseMemRefType resultType = options.functionArgTypeConverterFn (
419- tensorType, *options. defaultMemorySpaceFn ( tensorType), funcOp,
435+ if (auto tensorType = dyn_cast<TensorLikeType >(resultType)) {
436+ BufferLikeType resultType = options.functionArgTypeConverterFn (
437+ tensorType, getDefaultMemorySpace (options, tensorType), funcOp,
420438 options);
421439 retTypes.push_back (resultType);
422440 continue ;
@@ -446,7 +464,7 @@ struct FuncOpInterface
446464 SmallVector<Value> returnValues;
447465 for (auto [returnVal, bufferizedType] :
448466 llvm::zip_equal (returnOp->getOperands (), retTypes)) {
449- auto tensorType = dyn_cast<TensorType >(returnVal.getType ());
467+ auto tensorType = dyn_cast<TensorLikeType >(returnVal.getType ());
450468 rewriter.setInsertionPoint (returnOp);
451469
452470 // If not a tensor type just forward it.
0 commit comments