-
Notifications
You must be signed in to change notification settings - Fork 14.6k
[mlir] add optional type functor to call and function interfaces #146979
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
b4d8671
to
ef05b42
Compare
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Gibran Essa (gibrane) Changesadds type parsing functors to Patch is 31.77 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/146979.diff 10 Files Affected:
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index d70aa346eaa1f..5232a31ff5c77 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -463,10 +463,11 @@ class OpAsmPrinter : public AsmPrinter {
/// where location printing is controlled by the standard internal option.
/// You may pass omitType=true to not print a type, and pass an empty
/// attribute list if you don't care for attributes.
- virtual void printRegionArgument(BlockArgument arg,
- ArrayRef<NamedAttribute> argAttrs = {},
- bool omitType = false) = 0;
-
+ /// You can override default type printing behavior with the typePrinter arg.
+ virtual void printRegionArgument(
+ BlockArgument arg, ArrayRef<NamedAttribute> argAttrs = {},
+ bool omitType = false,
+ function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr) = 0;
/// Print implementations for various things an operation contains.
virtual void printOperand(Value value) = 0;
virtual void printOperand(Value value, raw_ostream &os) = 0;
@@ -1701,13 +1702,18 @@ class OpAsmParser : public AsmParser {
///
/// If `allowType` is false or `allowAttrs` are false then the respective
/// parts of the grammar are not parsed.
- virtual ParseResult parseArgument(Argument &result, bool allowType = false,
- bool allowAttrs = false) = 0;
+ /// You can override default type parsing behavior with the typeParser arg.
+ virtual ParseResult
+ parseArgument(Argument &result, bool allowType = false,
+ bool allowAttrs = false,
+ function_ref<ParseResult(OpAsmParser &, Type &)> typeParser =
+ nullptr) = 0;
/// Parse a single argument if present.
- virtual OptionalParseResult
- parseOptionalArgument(Argument &result, bool allowType = false,
- bool allowAttrs = false) = 0;
+ virtual OptionalParseResult parseOptionalArgument(
+ Argument &result, bool allowType = false, bool allowAttrs = false,
+ function_ref<ParseResult(OpAsmParser &, Type &)> typeParser =
+ nullptr) = 0;
/// Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseArgumentList(SmallVectorImpl<Argument> &result,
diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.h b/mlir/include/mlir/Interfaces/CallInterfaces.h
index 2bf3a3ca5f8a8..66f0287471da5 100644
--- a/mlir/include/mlir/Interfaces/CallInterfaces.h
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.h
@@ -37,6 +37,8 @@ Operation *resolveCallable(CallOpInterface call,
SymbolTableCollection *symbolTable = nullptr);
/// Parse a function or call result list.
+/// You can override the default type parsing behavior using the typeParser
+/// parameter.
///
/// function-result-list ::= function-result-list-parens
/// | non-function-type
@@ -45,31 +47,39 @@ Operation *resolveCallable(CallOpInterface call,
/// function-result-list-no-parens ::= function-result (`,` function-result)*
/// function-result ::= type attribute-dict?
///
-ParseResult
-parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
- SmallVectorImpl<DictionaryAttr> &resultAttrs);
+ParseResult parseFunctionResultList(
+ OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
+ SmallVectorImpl<DictionaryAttr> &resultAttrs,
+ function_ref<ParseResult(OpAsmParser &, Type &)> typeParser = nullptr);
/// Parses a function signature using `parser`. This does not deal with function
/// signatures containing SSA region arguments (to parse these signatures, use
/// function_interface_impl::parseFunctionSignature). When
/// `mustParseEmptyResult`, `-> ()` is expected when there is no result type.
+/// You can override the default type parsing behavior using the typeParser
+/// parameter.
+///
///
/// no-ssa-function-signature ::= `(` no-ssa-function-arg-list `)`
/// -> function-result-list
/// no-ssa-function-arg-list ::= no-ssa-function-arg
/// (`,` no-ssa-function-arg)*
/// no-ssa-function-arg ::= type attribute-dict?
-ParseResult parseFunctionSignature(OpAsmParser &parser,
- SmallVectorImpl<Type> &argTypes,
- SmallVectorImpl<DictionaryAttr> &argAttrs,
- SmallVectorImpl<Type> &resultTypes,
- SmallVectorImpl<DictionaryAttr> &resultAttrs,
- bool mustParseEmptyResult = true);
+ParseResult parseFunctionSignature(
+ OpAsmParser &parser, SmallVectorImpl<Type> &argTypes,
+ SmallVectorImpl<DictionaryAttr> &argAttrs,
+ SmallVectorImpl<Type> &resultTypes,
+ SmallVectorImpl<DictionaryAttr> &resultAttrs,
+ bool mustParseEmptyResult = true,
+ function_ref<ParseResult(OpAsmParser &, Type &)> typeParser = nullptr);
/// Print a function signature for a call or callable operation. If a body
/// region is provided, the SSA arguments are printed in the signature. When
/// `printEmptyResult` is false, `-> function-result-list` is omitted when
/// `resultTypes` is empty.
+/// You can override the default type printing behavior using the typePrinter
+/// parameter.
+///
///
/// function-signature ::= ssa-function-signature
/// | no-ssa-function-signature
@@ -77,11 +87,11 @@ ParseResult parseFunctionSignature(OpAsmParser &parser,
/// -> function-result-list
/// ssa-function-arg-list ::= ssa-function-arg (`,` ssa-function-arg)*
/// ssa-function-arg ::= `%`name `:` type attribute-dict?
-void printFunctionSignature(OpAsmPrinter &p, TypeRange argTypes,
- ArrayAttr argAttrs, bool isVariadic,
- TypeRange resultTypes, ArrayAttr resultAttrs,
- Region *body = nullptr,
- bool printEmptyResult = true);
+void printFunctionSignature(
+ OpAsmPrinter &p, TypeRange argTypes, ArrayAttr argAttrs, bool isVariadic,
+ TypeRange resultTypes, ArrayAttr resultAttrs, Region *body = nullptr,
+ bool printEmptyResult = true,
+ function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr);
/// Adds argument and result attributes, provided as `argAttrs` and
/// `resultAttrs` arguments, to the list of operation attributes in `result`.
diff --git a/mlir/include/mlir/Interfaces/FunctionImplementation.h b/mlir/include/mlir/Interfaces/FunctionImplementation.h
index 374c2c534f87d..de89a6bc0d50a 100644
--- a/mlir/include/mlir/Interfaces/FunctionImplementation.h
+++ b/mlir/include/mlir/Interfaces/FunctionImplementation.h
@@ -45,11 +45,14 @@ using FuncTypeBuilder = function_ref<Type(
/// indicates whether functions with variadic arguments are supported. The
/// trailing arguments are populated by this function with names, types,
/// attributes and locations of the arguments and those of the results.
+/// You can override the default type parsing behavior using the typeParser
+/// parameter.
ParseResult parseFunctionSignatureWithArguments(
OpAsmParser &parser, bool allowVariadic,
SmallVectorImpl<OpAsmParser::Argument> &arguments, bool &isVariadic,
SmallVectorImpl<Type> &resultTypes,
- SmallVectorImpl<DictionaryAttr> &resultAttrs);
+ SmallVectorImpl<DictionaryAttr> &resultAttrs,
+ function_ref<ParseResult(OpAsmParser &, Type &)> typeParser = nullptr);
/// Parser implementation for function-like operations. Uses
/// `funcTypeBuilder` to construct the custom function type given lists of
@@ -59,25 +62,32 @@ ParseResult parseFunctionSignatureWithArguments(
/// whether the function is variadic. If the builder returns a null type,
/// `result` will not contain the `type` attribute. The caller can then add a
/// type, report the error or delegate the reporting to the op's verifier.
-ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result,
- bool allowVariadic, StringAttr typeAttrName,
- FuncTypeBuilder funcTypeBuilder,
- StringAttr argAttrsName, StringAttr resAttrsName);
+/// You can override the default type parsing behavior using the typeParser
+/// parameter.
+ParseResult parseFunctionOp(
+ OpAsmParser &parser, OperationState &result, bool allowVariadic,
+ StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder,
+ StringAttr argAttrsName, StringAttr resAttrsName,
+ function_ref<ParseResult(OpAsmParser &, Type &)> typeParser = nullptr);
/// Printer implementation for function-like operations.
-void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
- StringRef typeAttrName, StringAttr argAttrsName,
- StringAttr resAttrsName);
+/// You can override the default type printing behavior using the typePrinter
+/// parameter.
+void printFunctionOp(
+ OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
+ StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName,
+ function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr);
/// Prints the signature of the function-like operation `op`. Assumes `op` has
/// is a FunctionOpInterface and has passed verification.
-inline void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op,
- ArrayRef<Type> argTypes, bool isVariadic,
- ArrayRef<Type> resultTypes) {
+inline void printFunctionSignature(
+ OpAsmPrinter &p, FunctionOpInterface op, ArrayRef<Type> argTypes,
+ bool isVariadic, ArrayRef<Type> resultTypes,
+ function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr) {
call_interface_impl::printFunctionSignature(
p, argTypes, op.getArgAttrsAttr(), isVariadic, resultTypes,
op.getResAttrsAttr(), &op->getRegion(0),
- /*printEmptyResult=*/false);
+ /*printEmptyResult=*/false, typePrinter);
}
/// Prints the list of function prefixed with the "attributes" keyword. The
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 756d3d01a4534..06282f648549f 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -1830,10 +1830,14 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
/// If `allowType` is false or `allowAttrs` are false then the respective
/// parts of the grammar are not parsed.
ParseResult parseArgument(Argument &result, bool allowType = false,
- bool allowAttrs = false) override {
+ bool allowAttrs = false,
+ function_ref<ParseResult(OpAsmParser &, Type &)>
+ typeParser = nullptr) override {
NamedAttrList attrs;
if (parseOperand(result.ssaName, /*allowResultNumber=*/false) ||
- (allowType && parseColonType(result.type)) ||
+ (allowType && !typeParser && parseColonType(result.type)) ||
+ (allowType && typeParser &&
+ (parseColon() || typeParser(*this, result.type))) ||
(allowAttrs && parseOptionalAttrDict(attrs)) ||
parseOptionalLocationSpecifier(result.sourceLoc))
return failure();
@@ -1842,10 +1846,12 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
}
/// Parse a single argument if present.
- OptionalParseResult parseOptionalArgument(Argument &result, bool allowType,
- bool allowAttrs) override {
+ OptionalParseResult parseOptionalArgument(
+ Argument &result, bool allowType, bool allowAttrs,
+ function_ref<ParseResult(OpAsmParser &, Type &)> typeParser =
+ nullptr) override {
if (parser.getToken().is(Token::percent_identifier))
- return parseArgument(result, allowType, allowAttrs);
+ return parseArgument(result, allowType, allowAttrs, typeParser);
return std::nullopt;
}
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index f95ad290a1981..71a372369b4bf 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -783,9 +783,10 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
print(&b);
}
- void printRegionArgument(BlockArgument arg, ArrayRef<NamedAttribute> argAttrs,
- bool omitType) override {
- printType(arg.getType());
+ void printRegionArgument(
+ BlockArgument arg, ArrayRef<NamedAttribute> argAttrs, bool omitType,
+ function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr) override {
+ typePrinter ? typePrinter(*this, arg.getType()) : printType(arg.getType());
// Visit the argument location.
if (printerFlags.shouldPrintDebugInfo())
// TODO: Allow deferring argument locations.
@@ -3295,9 +3296,10 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
/// where location printing is controlled by the standard internal option.
/// You may pass omitType=true to not print a type, and pass an empty
/// attribute list if you don't care for attributes.
- void printRegionArgument(BlockArgument arg,
- ArrayRef<NamedAttribute> argAttrs = {},
- bool omitType = false) override;
+ void printRegionArgument(
+ BlockArgument arg, ArrayRef<NamedAttribute> argAttrs = {},
+ bool omitType = false,
+ function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr) override;
/// Print the ID for the given value.
void printOperand(Value value) override { printValueID(value); }
@@ -3545,13 +3547,13 @@ void OperationPrinter::printResourceFileMetadata(
/// where location printing is controlled by the standard internal option.
/// You may pass omitType=true to not print a type, and pass an empty
/// attribute list if you don't care for attributes.
-void OperationPrinter::printRegionArgument(BlockArgument arg,
- ArrayRef<NamedAttribute> argAttrs,
- bool omitType) {
+void OperationPrinter::printRegionArgument(
+ BlockArgument arg, ArrayRef<NamedAttribute> argAttrs, bool omitType,
+ function_ref<void(OpAsmPrinter &, Type)> typePrinter) {
printOperand(arg);
if (!omitType) {
os << ": ";
- printType(arg.getType());
+ typePrinter ? typePrinter(*this, arg.getType()) : printType(arg.getType());
}
printOptionalAttrDict(argAttrs);
// TODO: We should allow location aliases on block arguments.
diff --git a/mlir/lib/Interfaces/CallInterfaces.cpp b/mlir/lib/Interfaces/CallInterfaces.cpp
index e8ed4b339a0cb..a08338e514a08 100644
--- a/mlir/lib/Interfaces/CallInterfaces.cpp
+++ b/mlir/lib/Interfaces/CallInterfaces.cpp
@@ -15,15 +15,24 @@ using namespace mlir;
// Argument and result attributes utilities
//===----------------------------------------------------------------------===//
-static ParseResult
-parseTypeAndAttrList(OpAsmParser &parser, SmallVectorImpl<Type> &types,
- SmallVectorImpl<DictionaryAttr> &attrs) {
+static inline ParseResult defaultTypeParser(OpAsmParser &parser, Type &ty) {
+ return parser.parseType(ty);
+}
+
+static inline void defaultTypePrinter(OpAsmPrinter &printer, Type ty) {
+ printer << ty;
+}
+
+static ParseResult parseTypeAndAttrList(
+ OpAsmParser &parser, SmallVectorImpl<Type> &types,
+ SmallVectorImpl<DictionaryAttr> &attrs,
+ function_ref<ParseResult(OpAsmParser &, Type &)> typeParser) {
// Parse individual function results.
return parser.parseCommaSeparatedList([&]() -> ParseResult {
types.emplace_back();
attrs.emplace_back();
NamedAttrList attrList;
- if (parser.parseType(types.back()) ||
+ if (typeParser(parser, types.back()) ||
parser.parseOptionalAttrDict(attrList))
return failure();
attrs.back() = attrList.getDictionary(parser.getContext());
@@ -33,12 +42,16 @@ parseTypeAndAttrList(OpAsmParser &parser, SmallVectorImpl<Type> &types,
ParseResult call_interface_impl::parseFunctionResultList(
OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
- SmallVectorImpl<DictionaryAttr> &resultAttrs) {
+ SmallVectorImpl<DictionaryAttr> &resultAttrs,
+ function_ref<ParseResult(OpAsmParser &, Type &)> typeParser) {
+ if (!typeParser)
+ typeParser = defaultTypeParser;
+
if (failed(parser.parseOptionalLParen())) {
// We already know that there is no `(`, so parse a type.
// Because there is no `(`, it cannot be a function type.
Type ty;
- if (parser.parseType(ty))
+ if (typeParser(parser, ty))
return failure();
resultTypes.push_back(ty);
resultAttrs.emplace_back();
@@ -48,7 +61,7 @@ ParseResult call_interface_impl::parseFunctionResultList(
// Special case for an empty set of parens.
if (succeeded(parser.parseOptionalRParen()))
return success();
- if (parseTypeAndAttrList(parser, resultTypes, resultAttrs))
+ if (parseTypeAndAttrList(parser, resultTypes, resultAttrs, typeParser))
return failure();
return parser.parseRParen();
}
@@ -57,20 +70,24 @@ ParseResult call_interface_impl::parseFunctionSignature(
OpAsmParser &parser, SmallVectorImpl<Type> &argTypes,
SmallVectorImpl<DictionaryAttr> &argAttrs,
SmallVectorImpl<Type> &resultTypes,
- SmallVectorImpl<DictionaryAttr> &resultAttrs, bool mustParseEmptyResult) {
+ SmallVectorImpl<DictionaryAttr> &resultAttrs, bool mustParseEmptyResult,
+ function_ref<ParseResult(OpAsmParser &, Type &)> typeParser) {
+ if (!typeParser)
+ typeParser = defaultTypeParser;
+
// Parse arguments.
if (parser.parseLParen())
return failure();
if (failed(parser.parseOptionalRParen())) {
- if (parseTypeAndAttrList(parser, argTypes, argAttrs))
+ if (parseTypeAndAttrList(parser, argTypes, argAttrs, typeParser))
return failure();
if (parser.parseRParen())
return failure();
}
// Parse results.
if (succeeded(parser.parseOptionalArrow()))
- return call_interface_impl::parseFunctionResultList(parser, resultTypes,
- resultAttrs);
+ return call_interface_impl::parseFunctionResultList(
+ parser, resultTypes, resultAttrs, typeParser);
if (mustParseEmptyResult)
return failure();
return success();
@@ -78,8 +95,12 @@ ParseResult call_interface_impl::parseFunctionSignature(
/// Print a function result list. The provided `attrs` must either be null, or
/// contain a set of DictionaryAttrs of the same arity as `types`.
-static void printFunctionResultList(OpAsmPrinter &p, TypeRange types,
- ArrayAttr attrs) {
+static void
+printFunctionResultList(OpAsmPrinter &p, TypeRange types, ArrayAttr attrs,
+ function_ref<void(OpAsmPrinter &, Type)> typePrinter) {
+ if (!typePrinter)
+ typePrinter = defaultTypePrinter;
+
assert(!types.empty() && "Should not be called for empty result list.");
assert((!attrs || attrs.size() == types.size()) &&
"Invalid number of attributes.");
@@ -90,22 +111,41 @@ static void printFunctionResultList(OpAsmPrinter &p, TypeRange types,
if (needsParens)
os << '(';
llvm::interleaveComma(llvm::seq<size_t>(0, types.size()), os, [&](size_t i) {
- p.printType(types[i]);
+ typePrinter(p, types[i]);
if (attrs)
p.printOptionalAttrDict(llvm::cast<DictionaryAttr>(attrs[i]).getValue());
});
if (needsParens)
os << ')';
}
+static void
+printFunctionalType(OpAsmPrinter &p, TypeRange &inputs, TypeRange &results,
+ function_ref<void(OpAsmPrinter &, Type)> typePrinter) {
+ p << '(';
+ llvm::interleaveComma(inputs, p, [&](Type ty) { typePrinter(p, ty); });
+ p << ')';
+
+ bool wrapped = !llvm::hasSingleElement(results) ||
+ llvm::isa<FunctionType>((*results.begin()));
+ if (wrapped)
+ p << '(';
+ llvm::interleaveComma(results, p, [&](Type ty) { typePrinter(p, ty); });
+ if (wrapped)
+ p << ')';
+}
void call_interface_impl::printFunctionSignature(
OpAsmPrinter &p, TypeRange argTypes, ArrayAttr argAttrs, bool isVariadic,
T...
[truncated]
|
ef05b42
to
60d163b
Compare
adds type parsing functors to
call_interface_impl
andfunction_interface_impl