diff --git a/llvm/cmake/modules/HandleLLVMOptions.cmake b/llvm/cmake/modules/HandleLLVMOptions.cmake index c126b0d073322..c59911c52c625 100644 --- a/llvm/cmake/modules/HandleLLVMOptions.cmake +++ b/llvm/cmake/modules/HandleLLVMOptions.cmake @@ -797,6 +797,16 @@ if (MSVC) # Enable warnings if (LLVM_ENABLE_WARNINGS) + # Remove all -wd flag to enable warnings + if (NOT CLANG_CL) + set(msvc_warning_flags + # Promoted warnings. + -w14062 # Promote 'enumerator in switch of enum is not handled' to level 1 warning. + + # Promoted warnings to errors. + -we4238 # Promote 'nonstandard extension used : class rvalue used as lvalue' to error. + ) + endif(NOT CLANG_CL) # Put /W4 in front of all the -we flags. cl.exe doesn't care, but for # clang-cl having /W4 after the -we flags will re-enable the warnings # disabled by -we. diff --git a/llvm/include/llvm/ADT/FunctionExtras.h b/llvm/include/llvm/ADT/FunctionExtras.h index d92868e3715f4..33ad1db48716f 100644 --- a/llvm/include/llvm/ADT/FunctionExtras.h +++ b/llvm/include/llvm/ADT/FunctionExtras.h @@ -153,7 +153,7 @@ template class UniqueFunctionBase { void *StoragePtr; size_t Size; size_t Alignment; - } OutOfLineStorage; + } OutOfLineStorage = {}; static_assert( sizeof(OutOfLineStorageT) <= InlineStorageSize, "Should always use all of the out-of-line storage for inline storage!"); diff --git a/llvm/lib/DebugInfo/DWARF/DWARFVerifier.cpp b/llvm/lib/DebugInfo/DWARF/DWARFVerifier.cpp index 8ec3f1729b974..04f5fc0ebae67 100644 --- a/llvm/lib/DebugInfo/DWARF/DWARFVerifier.cpp +++ b/llvm/lib/DebugInfo/DWARF/DWARFVerifier.cpp @@ -1444,8 +1444,8 @@ void DWARFVerifier::verifyNameIndexAttribute( } if (AttrEnc.Index == dwarf::DW_IDX_parent) { - constexpr static auto AllowedForms = {dwarf::Form::DW_FORM_flag_present, - dwarf::Form::DW_FORM_ref4}; + static constexpr dwarf::Form AllowedForms[] = { + dwarf::Form::DW_FORM_flag_present, dwarf::Form::DW_FORM_ref4}; if (!is_contained(AllowedForms, AttrEnc.Form)) { ErrorCategory.Report("Unexpected NameIndex Abbreviation", [&]() { error() << formatv( diff --git a/llvm/lib/Support/Windows/Threading.inc b/llvm/lib/Support/Windows/Threading.inc index d862dbd7f71c9..9c584415e0cbe 100644 --- a/llvm/lib/Support/Windows/Threading.inc +++ b/llvm/lib/Support/Windows/Threading.inc @@ -42,6 +42,9 @@ void llvm_thread_join_impl(HANDLE hThread) { if (::WaitForSingleObject(hThread, INFINITE) == WAIT_FAILED) { ReportLastErrorFatal("WaitForSingleObject failed"); } + if (::CloseHandle(hThread) == FALSE) { + ReportLastErrorFatal("CloseHandle failed"); + } } void llvm_thread_detach_impl(HANDLE hThread) { diff --git a/mlir/docs/DefiningDialects/AttributesAndTypes.md b/mlir/docs/DefiningDialects/AttributesAndTypes.md index 022bdad9fe512..fde737f51356a 100644 --- a/mlir/docs/DefiningDialects/AttributesAndTypes.md +++ b/mlir/docs/DefiningDialects/AttributesAndTypes.md @@ -565,6 +565,11 @@ For Attributes, these methods will have the form: - `void MyAttr::print(AsmPrinter &p) const` +It is possible to use newlines and indents in custom `print` methods. +However, multiline Types or Attributes are not recommended nor allowed in the upstream MLIR dialects. +They can be used in custom dialects to improve flexibility and readability, e.g. in cases of +multiple nested Types and Attributes. + #### Using `assemblyFormat` Attributes and types defined in ODS with a mnemonic can define an diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h index 6245f88db3d19..6ecf4a291cfc6 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -260,12 +260,12 @@ struct BufferizationOptions { std::function; /// Initializer function for analysis state. using AnalysisStateInitFn = std::function; - /// Tensor -> MemRef type converter. - /// Parameters: tensor type, memory space, func op, bufferization options + /// Tensor-like -> Buffer-like type conversion. + /// Parameters: tensor-like type, memory space, func op, bufferization options using FunctionArgTypeConverterFn = - std::function; - /// Tensor -> MemRef type converter. + /// Tensor -> MemRef type conversion. /// Parameters: tensor type, memory space, bufferization options using UnknownTypeConverterFn = std::function; @@ -335,10 +335,12 @@ struct BufferizationOptions { /// predictable. void setFunctionBoundaryTypeConversion(LayoutMapOption layoutMapOption); - /// Type converter from tensors to memrefs. This type converter is used to - /// determine bufferized function argument and result types. By default, a - /// type converter that returns a memref type with a fully dynamic layout map - /// is used. + /// Type conversion from tensors to buffers. This type conversion is used to + /// determine bufferized function argument and result types. + /// + /// By default, if tensor is a (builtin) tensor type, it is converted to a + /// memref type with a fully dynamic layout map; if tensor is a (generic) + /// tensor-like type, it is converted using TensorLikeType::getBufferType(). /// /// If `bufferizeFunctionBoundaries` is not set, this function isn't used. FunctionArgTypeConverterFn functionArgTypeConverterFn = nullptr; @@ -350,10 +352,9 @@ struct BufferizationOptions { /// If `bufferizeFunctionBoundaries` is not set, this flag has no effect. bool inferFunctionResultLayout = true; - /// Type converter from tensors to memrefs. This type converter is used if no - /// memref type could be inferred during bufferization. By default, a type - /// converter that returns a memref type with a fully dynamic layout map is - /// used. + /// Type conversion from tensors to memrefs. This type conversion is used if + /// no memref type could be inferred during bufferization. By default, returns + /// a memref type with a fully dynamic layout map. UnknownTypeConverterFn unknownTypeConverterFn = nullptr; // Use during type conversion to determine the memory space for memref based diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td index 23bf5cf15e256..62ddd5193dde2 100644 --- a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td +++ b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td @@ -402,6 +402,26 @@ class quant_ScalarOrTensorOf : def quant_QuantizedType : Type($_self)">, "quantized type">; +// UniformQuantizedPerAxisType +def quant_UniformQuantizedPerAxisType : + DialectType($_self)">, + "UniformQuantizedPerAxisType">; + +// QuantileQuantizedPerAxisType +def quant_QuantileQuantizedPerAxisType : + DialectType($_self)">, + "QuantileQuantizedPerAxisType">; + +// Predicate for detecting a container or primitive of UniformQuantizedPerAxisType. +def quant_UniformQuantizedPerAxisValueType : + quant_ScalarOrTensorOf; + +// Predicate for detecting a container or primitive of QuantileQuantizedPerAxisType. +def quant_QuantileQuantizedPerAxisValueType : + quant_ScalarOrTensorOf; + def quant_ScalarType : Type:$quantiles, + DoubleAPFloat:$scale, + SignedVarInt:$zeroPoint, + SignedVarInt:$storageTypeMin, + SignedVarInt:$storageTypeMax +)>; + +def QuantileQuantizedPerAxisType: DialectType<(type + VarInt:$flags, + Type:$storageType, + Type:$quantileType, + Type:$expressedType, + VarInt:$quantizedDimension, + SignedVarInt:$storageTypeMin, + SignedVarInt:$storageTypeMax, + Array:$quantiles, + Array:$scales, + Array:$zeroPoints +)> { + // Note: builder order differs from bytecode. + let cBuilder = [{ + get<$_resultType>(context, flags, storageType, quantileType, expressedType, quantiles, scales, + zeroPoints, quantizedDimension, storageTypeMin, storageTypeMax) + }]; +} + /// This enum contains marker codes used to indicate which attribute is /// currently being decoded, and how it should be decoded. The order of these /// codes should generally be unchanged, as any changes will inevitably break @@ -106,7 +137,8 @@ def QuantDialectTypes : DialectTypes<"Quant"> { let elems = [ReservedOrDead, AnyQuantizedType, AnyQuantizedTypeWithExpressedType, CalibratedQuantizedType, UniformQuantizedType, UniformQuantizedPerAxisType, - UniformQuantizedSubChannelType]; + UniformQuantizedSubChannelType, + QuantileQuantizedType, QuantileQuantizedPerAxisType]; } #endif // QUANT_BYTECODE diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h index 44062fe376ec0..911076abdbaab 100644 --- a/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h +++ b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h @@ -26,6 +26,8 @@ struct AnyQuantizedTypeStorage; struct UniformQuantizedSubChannelTypeStorage; struct UniformQuantizedTypeStorage; struct UniformQuantizedPerAxisTypeStorage; +struct QuantileQuantizedTypeStorage; +struct QuantileQuantizedPerAxisTypeStorage; struct CalibratedQuantizedTypeStorage; } // namespace detail @@ -83,6 +85,24 @@ class QuantizedType : public Type { return llvm::maxUIntN(integralWidth); } + static constexpr int64_t getDefaultMaximumForF8E4M3FN() { return 448; } + + static constexpr int64_t getDefaultMinimumForF8E4M3FN() { + return -getDefaultMaximumForF8E4M3FN(); + } + + static constexpr int64_t getDefaultMaximumForF8E5M2() { return 57344; } + + static constexpr int64_t getDefaultMinimumForF8E5M2() { + return -getDefaultMaximumForF8E5M2(); + } + + static constexpr int64_t getDefaultMaximumForF4E2M1FN() { return 6; } + + static constexpr int64_t getDefaultMinimumForF4E2M1FN() { + return -getDefaultMaximumForF4E2M1FN(); + } + /// Gets the original expressed type that this quantized type approximates. /// Note that this presumes that the quantized type was always derived from /// a floating point type, which in the broadest definition, is not true (i.e. @@ -253,7 +273,7 @@ class AnyQuantizedType /// Per-layer, optional parameters omitted: /// !quant /// -/// StorageType: 'i'|'u' NumBits +/// StorageType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8' /// ExpressedType: 'f16', 'f32', 'bf16', 'f64' /// Scale: A legal double value /// ZeroPoint: An integer value @@ -287,6 +307,8 @@ class UniformQuantizedType int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax); + static bool classof(mlir::Type type); + /// Gets the scale term. The scale designates the difference between the real /// values corresponding to consecutive quantized values differing by 1. double getScale() const; @@ -311,7 +333,7 @@ class UniformQuantizedType /// Per-axis, optional parameters omitted: /// !quant /// -/// StorageType: 'i'|'u' NumBits +/// StorageType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8' /// ExpressedType: 'f16', 'f32', 'bf16', 'f64' /// QuantizedDim: An integer value /// QuantParams: (Scale ':' ZeroPoint)+ @@ -350,6 +372,8 @@ class UniformQuantizedPerAxisType int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax); + static bool classof(mlir::Type type); + /// Gets the quantization scales. The scales designate the difference between /// the real values corresponding to consecutive quantized values differing /// by 1. The ith scale corresponds to the ith slice in the @@ -383,6 +407,153 @@ class UniformQuantizedPerAxisType } }; +// clang-format off +/// QuantileQuantizedType derives from UniformQuantizedType and adds to it a +/// look up table array of quantile values. The type of the data in the look up +/// table is determined by the quantileType member: supported quantileType types +/// are integer/unsigned/f4/hf8/bf8/f16/bf16/f32/f64. +/// +/// Syntax synopsis: +/// Per-layer, all parameters expressed: +/// !quant +/// Per-layer, optional parameters omitted: +/// !quant +/// +/// StorageType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8' +/// QuantileType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8', 'f16', 'bf16', 'f32', 'f64' +/// ExpressedType: 'f16', 'f32', 'bf16', 'f64' +/// Quantiles: Quantile+ +/// Quantile: A legal double value +/// Scale: A legal double value +/// ZeroPoint: An integer value +// clang-format on +class QuantileQuantizedType + : public Type::TypeBase { +public: + using Base::Base; + using Base::getChecked; + + static constexpr StringLiteral name = "quant.quantile"; + + /// Gets an instance of the type with all parameters specified but not + /// checked. + static QuantileQuantizedType get(unsigned flags, Type storageType, + Type quantileType, Type expressedType, + ArrayRef quantiles, double scale, + int64_t zeroPoint, int64_t storageTypeMin, + int64_t storageTypeMax); + + static QuantileQuantizedType + getChecked(function_ref emitError, unsigned flags, + Type storageType, Type quantileType, Type expressedType, + ArrayRef quantiles, double scale, int64_t zeroPoint, + int64_t storageTypeMin, int64_t storageTypeMax); + + /// Verifies construction invariants and issues errors/warnings. + static LogicalResult + verifyInvariants(function_ref emitError, unsigned flags, + Type storageType, Type quantileType, Type expressedType, + ArrayRef quantiles, double scale, int64_t zeroPoint, + int64_t storageTypeMin, int64_t storageTypeMax); + + static bool classof(mlir::Type type); + + /// Gets the quantileType + Type getQuantileType() const; + + /// Gets the quantileType bit width + unsigned getQuantileTypeIntegralWidth() const; + + /// Gets the quantile values + ArrayRef getQuantiles() const; + + // Fixed point values are real numbers divided by a scale. + // Currently, only signed storage types are treated as fixed point. + // A fixed point value can be obtained from an affine value by subtracting + // the zeroPoint. + // In the future, this may be explicit versus implied by type and zeroPoint. + bool isFixedPoint() const { return isSigned() && getZeroPoint() == 0; } +}; + +// clang-format off +/// Represents per-axis QuantileQuantizedType (also known as per-channel +/// quantization). The type of the data in the look up table is determined by +/// the quantileType member: supported quantileType types are +/// integer/unsigned/f4/hf8/bf8/f16/bf16/f32/f64. +/// +/// Syntax synopsis: +/// Per-axis, all parameters expressed: +/// !quant +/// Per-axis, optional parameters omitted: +/// !quant +/// +/// StorageType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8' +/// QuantileType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8', 'f16', 'bf16', 'f32', 'f64' +/// ExpressedType: 'f16', 'f32', 'bf16', 'f64' +/// QuantizedDim: An integer value +/// Quantiles: Quantile+ +/// Quantile: A legal double value +/// QuantParams: (Scale ':' ZeroPoint)+ +/// Scale: A legal double value +/// ZeroPoint: An integer value +// clang-format on +class QuantileQuantizedPerAxisType + : public Type::TypeBase { +public: + using Base::Base; + using Base::getChecked; + + static constexpr StringLiteral name = "quant.quantile_per_axis"; + + /// Gets an instance of the type with all parameters specified but not + /// checked. + static QuantileQuantizedPerAxisType + get(unsigned flags, Type storageType, Type quantileType, Type expressedType, + ArrayRef quantiles, ArrayRef scales, + ArrayRef zeroPoints, int32_t quantizedDimension, + int64_t storageTypeMin, int64_t storageTypeMax); + + /// Gets an instance of the type with all specified parameters checked. + /// Returns a nullptr convertible type on failure. + static QuantileQuantizedPerAxisType + getChecked(function_ref emitError, unsigned flags, + Type storageType, Type quantileType, Type expressedType, + ArrayRef quantiles, ArrayRef scales, + ArrayRef zeroPoints, int32_t quantizedDimension, + int64_t storageTypeMin, int64_t storageTypeMax); + + /// Verifies construction invariants and issues errors/warnings. + static LogicalResult + verifyInvariants(function_ref emitError, unsigned flags, + Type storageType, Type quantileType, Type expressedType, + ArrayRef quantiles, ArrayRef scales, + ArrayRef zeroPoints, int32_t quantizedDimension, + int64_t storageTypeMin, int64_t storageTypeMax); + + static bool classof(mlir::Type type); + + /// Gets the quantileType + Type getQuantileType() const; + + /// Gets the quantileType bit width + unsigned getQuantileTypeIntegralWidth() const; + + /// Gets the quantile values + ArrayRef getQuantiles() const; + + /// Fixed point values are real numbers divided by a scale. + /// Currently, only signed storage types are treated as fixed point. + /// A fixed point value can be obtained from an affine value by subtracting + /// the zeroPoint. + /// In the future, this may be explicit versus implied by type and zeroPoint. + bool isFixedPoint() const { + return isSigned() && !llvm::is_contained(getZeroPoints(), 0); + } +}; + /// Represents sub-channel (also known as blockwise quantization). /// /// Syntax synopsis: @@ -396,7 +567,7 @@ class UniformQuantizedPerAxisType /// ScaleZeroList ::= ScaleZero (',' ScaleZero)* /// ScaleZero ::= Scale (':' ZeroPoint)? /// -/// StorageType: 'i'|'u' NumBits +/// StorageType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8' /// ExpressedType: 'f16', 'f32', 'bf16', 'f64' /// AxisSpec: An integer value /// BlockSizeSpec: An integer value diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index d70aa346eaa1f..a2a4c01f0d2a0 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -130,6 +130,18 @@ class AsmPrinter { /// Return the raw output stream used by this printer. virtual raw_ostream &getStream() const; + /// Print a newline and indent the printer to the start of the current + /// operation/attribute/type. + /// Note: For attributes and types this method should only be used in + /// custom dialects. Usage in upstream MLIR dialects is currently disallowed. + virtual void printNewline(); + + /// Increase indentation. + virtual void increaseIndent(); + + /// Decrease indentation. + virtual void decreaseIndent(); + /// Print the given floating point value in a stabilized form that can be /// roundtripped through the IR. This is the companion to the 'parseFloat' /// hook on the AsmParser. @@ -448,16 +460,6 @@ class OpAsmPrinter : public AsmPrinter { /// Print a loc(...) specifier if printing debug info is enabled. virtual void printOptionalLocationSpecifier(Location loc) = 0; - /// Print a newline and indent the printer to the start of the current - /// operation. - virtual void printNewline() = 0; - - /// Increase indentation. - virtual void increaseIndent() = 0; - - /// Decrease indentation. - virtual void decreaseIndent() = 0; - /// Print a block argument in the usual format of: /// %ssaName : type {attr1=42} loc("here") /// where location printing is controlled by the standard internal option. @@ -1654,16 +1656,22 @@ class OpAsmParser : public AsmParser { SmallVectorImpl &result) { size_t operandSize = llvm::range_size(operands); size_t typeSize = llvm::range_size(types); - if (operandSize != typeSize) { + if (typeSize != 0 && operandSize != typeSize) { // If no location was provided, report errors at the beginning of the op. return emitError(loc.isValid() ? loc : getNameLoc()) << "number of operands and types do not match: got " << operandSize << " operands and " << typeSize << " types"; } - for (auto [operand, type] : llvm::zip_equal(operands, types)) - if (resolveOperand(operand, type, result)) - return failure(); + if (typeSize == 0) { + for (auto it : operands) + if (resolveOperand(it, Type(), result)) + return failure(); + } else { + for (auto [operand, type] : llvm::zip_equal(operands, types)) + if (resolveOperand(operand, type, result)) + return failure(); + } return success(); } diff --git a/mlir/include/mlir/IR/Properties.td b/mlir/include/mlir/IR/Properties.td index a6221f9aaaef9..9be84405c936f 100644 --- a/mlir/include/mlir/IR/Properties.td +++ b/mlir/include/mlir/IR/Properties.td @@ -468,7 +468,7 @@ class ArrayProp, string newSummary = ""> : return $_diag() << "expected array attribute"; for (::mlir::Attribute elemAttr : arrayAttr) { }] # _makePropStorage.ret # [{ - auto elemRes = [&](Attribute propAttr, }] # elem.storageType # [{& propStorage) -> ::mlir::LogicalResult { + auto elemRes = [&](::mlir::Attribute propAttr, }] # elem.storageType # [{& propStorage) -> ::mlir::LogicalResult { }] # !subst("$_attr", "propAttr", !subst("$_storage", "propStorage", elem.convertFromAttribute)) # [{ }(elemAttr, elemVal); @@ -480,7 +480,7 @@ class ArrayProp, string newSummary = ""> : }]; let convertToAttribute = [{ - SmallVector elems; + SmallVector<::mlir::Attribute> elems; for (const auto& elemVal : $_storage) { auto elemAttr = [&](const }] # elem.storageType #[{& propStorage) -> ::mlir::Attribute { }] # !subst("$_storage", "propStorage", elem.convertToAttribute) # [{ @@ -647,7 +647,7 @@ class OptionalProp } ::mlir::Attribute presentAttr = arrayAttr[0]; }] # _makePropStorage.ret # [{ - auto presentRes = [&](Attribute propAttr, }] # p.storageType # [{& propStorage) -> ::mlir::LogicalResult { + auto presentRes = [&](::mlir::Attribute propAttr, }] # p.storageType # [{& propStorage) -> ::mlir::LogicalResult { }] # !subst("$_storage", "propStorage", !subst("$_attr", "propAttr", p.convertFromAttribute)) # [{ }(presentAttr, presentVal); diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h index 2162a74a51580..8959dab047103 100644 --- a/mlir/include/mlir/IR/StorageUniquerSupport.h +++ b/mlir/include/mlir/IR/StorageUniquerSupport.h @@ -200,7 +200,7 @@ class StorageUserBase : public BaseT, public Traits... { // If the construction invariants fail then we return a null attribute. if (failed(ConcreteT::verifyInvariants(emitErrorFn, args...))) return ConcreteT(); - return UniquerT::template get(ctx, args...); + return UniquerT::template get(ctx, std::forward(args)...); } /// Get an instance of the concrete type from a void pointer. diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td index a8b04d0453110..f715d82199bd9 100644 --- a/mlir/include/mlir/IR/SymbolInterfaces.td +++ b/mlir/include/mlir/IR/SymbolInterfaces.td @@ -171,11 +171,6 @@ def Symbol : OpInterface<"SymbolOpInterface"> { if (concreteOp.isDeclaration() && concreteOp.isPublic()) return concreteOp.emitOpError("symbol declaration cannot have public " "visibility"); - auto parent = $_op->getParentOp(); - if (parent && !parent->hasTrait() && parent->isRegistered()) { - return concreteOp.emitOpError("symbol's parent must have the SymbolTable " - "trait"); - } return success(); }]; @@ -227,4 +222,6 @@ def SymbolUserOpInterface : OpInterface<"SymbolUserOpInterface"> { // Op defines a symbol table. def SymbolTable : NativeOpTrait<"SymbolTable">; +def SymbolContainer : NativeOpTrait<"SymbolContainer">; + #endif // MLIR_IR_SYMBOLINTERFACES diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h index e4622354b8980..007648e0e9b6e 100644 --- a/mlir/include/mlir/IR/SymbolTable.h +++ b/mlir/include/mlir/IR/SymbolTable.h @@ -259,6 +259,13 @@ class SymbolTable { StringAttr newSymbolName, Region *from); + static LogicalResult replaceAllSymbolUses(Operation *oldSymbol, + SymbolRefAttr newSymbolName, + Operation *from); + static LogicalResult replaceAllSymbolUses(Operation *oldSymbol, + SymbolRefAttr newSymbolName, + Region *from); + private: Operation *symbolTableOp; @@ -422,6 +429,7 @@ class SymbolUserMap { /// Replace all of the uses of the given symbol with `newSymbolName`. void replaceAllUsesWith(Operation *symbol, StringAttr newSymbolName); + void replaceAllUsesWith(Operation *symbol, SymbolRefAttr newSymbolName); private: /// A reference to the symbol table used to construct this map. @@ -482,6 +490,40 @@ class SymbolTable : public TraitBase { } }; +template +class SymbolContainer : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return mlir::success(); // TODO::implement + } + + /// Look up a symbol with the specified name, returning null if no such + /// name exists. Symbol names never include the @ on them. Note: This + /// performs a linear scan of held symbols. + Operation *lookupSymbol(StringAttr name) { + return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), name); + } + template + T lookupSymbol(StringAttr name) { + return dyn_cast_or_null(lookupSymbol(name)); + } + Operation *lookupSymbol(SymbolRefAttr symbol) { + return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), symbol); + } + template + T lookupSymbol(SymbolRefAttr symbol) { + return dyn_cast_or_null(lookupSymbol(symbol)); + } + + Operation *lookupSymbol(StringRef name) { + return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), name); + } + template + T lookupSymbol(StringRef name) { + return dyn_cast_or_null(lookupSymbol(name)); + } +}; + } // namespace OpTrait //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h index aef7ec622fe4f..f2a7d27f5c78e 100644 --- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h +++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h @@ -146,17 +146,17 @@ class EffectInstance { : effect(effect), resource(resource), stage(stage), effectOnFullRegion(effectOnFullRegion) {} template ::value, - bool> = true> + std::enable_if_t::value, + bool> = true> EffectInstance(EffectT *effect, T value, Resource *resource = DefaultResource::get()) : effect(effect), resource(resource), value(value), stage(0), effectOnFullRegion(false) {} template ::value, - bool> = true> + std::enable_if_t::value, + bool> = true> EffectInstance(EffectT *effect, T value, int stage, bool effectOnFullRegion, Resource *resource = DefaultResource::get()) : effect(effect), resource(resource), value(value), stage(stage), @@ -223,6 +223,9 @@ class EffectInstance { if (OpResult result = llvm::dyn_cast_if_present(value)) { return result; } + if (Value result = llvm::dyn_cast_if_present(value)) { + return result; + } return cast_if_present(value); } @@ -264,7 +267,8 @@ class EffectInstance { /// The Symbol, OpOperand, OpResult or BlockArgument that the effect applies /// to. This is optionally null. - PointerUnion value; + PointerUnion + value; /// Additional parameters of the effect instance. An attribute is used for /// type-safe structured storage and context-based uniquing. Concrete effects diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h index 16893c6db87b1..c882ab14f78ab 100644 --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -298,6 +298,10 @@ class Pass { /// pass. void copyOptionValuesFrom(const Pass *other); + /// Copy the option values from 'other', which are PassPipeline options. + /// Here we copy only those options that have the same argument name. + void copyOptionValuesFrom(const detail::PassOptions &other); + private: /// Out of line virtual method to ensure vtables and metadata are emitted to a /// single .o file. diff --git a/mlir/include/mlir/Pass/PassOptions.h b/mlir/include/mlir/Pass/PassOptions.h index 0c71f78b52d3d..318ea120a0413 100644 --- a/mlir/include/mlir/Pass/PassOptions.h +++ b/mlir/include/mlir/Pass/PassOptions.h @@ -340,6 +340,9 @@ class PassOptions : protected llvm::cl::SubCommand { /// same options as 'this'. void copyOptionValuesFrom(const PassOptions &other); + /// Copy only those options that have the same argument name. + void matchAndCopyOptionValuesFrom(const PassOptions &otherPassOptions); + /// Parse options out as key=value pairs that can then be handed off to the /// `llvm::cl` command line passing infrastructure. Everything is space /// separated. diff --git a/mlir/include/mlir/Support/Timing.h b/mlir/include/mlir/Support/Timing.h index a8a4bfd1c6cf1..3707e9219b6d7 100644 --- a/mlir/include/mlir/Support/Timing.h +++ b/mlir/include/mlir/Support/Timing.h @@ -473,6 +473,11 @@ void registerDefaultTimingManagerCLOptions(); /// 'registerDefaultTimingManagerOptions' to a `DefaultTimingManager`. void applyDefaultTimingManagerCLOptions(DefaultTimingManager &tm); +/// Create an output strategy for the specified format, to be passed to +/// DefaultTimingManager::setOutput(). +std::unique_ptr +createOutputStrategy(DefaultTimingManager::OutputFormat fmt, raw_ostream &os); + } // namespace mlir #endif // MLIR_SUPPORT_TIMING_H diff --git a/mlir/include/mlir/Transforms/InliningUtils.h b/mlir/include/mlir/Transforms/InliningUtils.h index ed6413d8cd44c..3fb877c900598 100644 --- a/mlir/include/mlir/Transforms/InliningUtils.h +++ b/mlir/include/mlir/Transforms/InliningUtils.h @@ -172,6 +172,14 @@ class DialectInlinerInterface return result; } + /// Hook to cleanup IR before erase call op + virtual void eraseCall(Operation *call) const; + + /// Hook to get proper place where callable region will be inlined + /// By default returns block of the call operation + virtual std::tuple + getInlineBlockAndPoint(Operation *call) const; + /// Process a set of blocks that have been inlined for a call. This callback /// is invoked before inlined terminator operations have been processed. virtual void processInlinedCallBlocks( diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp index 756d3d01a4534..b444181c30251 100644 --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -1093,16 +1093,24 @@ Value OperationParser::resolveSSAUse(UnresolvedOperand useInfo, Type type) { // If we have already seen a value of this name, return it. if (useInfo.number < entries.size() && entries[useInfo.number].value) { Value result = entries[useInfo.number].value; + // Check that the type matches the other uses. - if (result.getType() == type) - return maybeRecordUse(result); - - emitError(useInfo.location, "use of value '") - .append(useInfo.name, - "' expects different type than prior uses: ", type, " vs ", - result.getType()) - .attachNote(getEncodedSourceLocation(entries[useInfo.number].loc)) - .append("prior use here"); + if (type && result.getType() != type) { + emitError(useInfo.location, "use of value '") + .append(useInfo.name, + "' expects different type than prior uses: ", type, " vs ", + result.getType()) + .attachNote(getEncodedSourceLocation(entries[useInfo.number].loc)) + .append("prior use here"); + return nullptr; + } + + return maybeRecordUse(result); + } + + if (!type) { + emitError(useInfo.location, "forward reference of value '") + .append(useInfo.name, "' requires explicit type specification"); return nullptr; } diff --git a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt index 4beb99ccfdfba..ffcdf46d7a801 100644 --- a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt @@ -22,6 +22,7 @@ add_mlir_dialect_library(MLIRArithDialect MLIRArithOpsInterfacesIncGen LINK_LIBS PUBLIC + MLIRTransformUtils MLIRCastInterfaces MLIRDialect MLIRInferIntRangeCommon diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index 8f17a82fabe03..3eef4d673d771 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -341,11 +341,21 @@ bool OpFilter::isOpAllowed(Operation *op) const { namespace { /// Default function arg type converter: Use a fully dynamic layout map. -BaseMemRefType -defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace, +BufferLikeType +defaultFunctionArgTypeConverter(TensorLikeType type, Attribute memorySpace, func::FuncOp funcOp, const BufferizationOptions &options) { - return getMemRefTypeWithFullyDynamicLayout(type, memorySpace); + if (auto tensorType = mlir::dyn_cast(type)) { + return cast( + getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace)); + } + + // If not builtin, fallback to TensorLikeType::getBufferType() + auto bufferType = + type.getBufferType(options, [&]() { return funcOp->emitError(); }); + assert(succeeded(bufferType) && + "a valid buffer is always expected at function boundary"); + return *bufferType; } /// Default unknown type converter: Use a fully dynamic layout map. BaseMemRefType @@ -388,14 +398,25 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const { void BufferizationOptions::setFunctionBoundaryTypeConversion( LayoutMapOption layoutMapOption) { - functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace, + functionArgTypeConverterFn = [=](TensorLikeType type, Attribute memorySpace, func::FuncOp funcOp, const BufferizationOptions &options) { - if (layoutMapOption == LayoutMapOption::IdentityLayoutMap) - return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType, - memorySpace); - return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType, - memorySpace); + if (auto tensorType = mlir::dyn_cast(type)) { + if (layoutMapOption == LayoutMapOption::IdentityLayoutMap) + return cast( + bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType, + memorySpace)); + return cast( + bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType, + memorySpace)); + } + + // If not builtin, fallback to TensorLikeType::getBufferType() + auto bufferType = + type.getBufferType(options, [&]() { return funcOp->emitError(); }); + assert(succeeded(bufferType) && + "a valid buffer is always expected at function boundary"); + return *bufferType; }; inferFunctionResultLayout = layoutMapOption == LayoutMapOption::InferLayoutMap; @@ -662,16 +683,6 @@ bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const { return false; } -// bufferization.to_buffer is not allowed to change the rank. -static void ensureToBufferOpIsValid(Value tensor, Type memrefType) { -#ifndef NDEBUG - auto rankedTensorType = llvm::dyn_cast(tensor.getType()); - assert((!rankedTensorType || llvm::cast(memrefType).getRank() == - rankedTensorType.getRank()) && - "to_buffer would be invalid: mismatching ranks"); -#endif -} - FailureOr bufferization::getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options, const BufferizationState &state) { @@ -690,7 +701,7 @@ FailureOr bufferization::getBuffer(RewriterBase &rewriter, Value value, FailureOr bufferType = getBufferType(value, options, state); if (failed(bufferType)) return failure(); - ensureToBufferOpIsValid(value, *bufferType); + return rewriter .create(value.getLoc(), *bufferType, value) .getResult(); diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp index 6c08cdfb669f3..71d5fcb233f41 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp @@ -73,9 +73,6 @@ struct BuiltinTensorExternalModel mlir::LogicalResult verifyCompatibleBufferType( mlir::Type tensor, BufferLikeType bufferType, llvm::function_ref emitError) const { - assert(isa(tensor) && "expected tensor type"); - assert(isa(bufferType) && "expected memref type"); - auto tensorType = cast(tensor); auto memrefType = cast(bufferType); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp index 246555dc8c699..37509e561f401 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -404,7 +404,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter, // Compute the new signature. SmallVector newTypes; for (BlockArgument &bbArg : block->getArguments()) { - auto tensorType = dyn_cast(bbArg.getType()); + auto tensorType = dyn_cast(bbArg.getType()); if (!tensorType) { newTypes.push_back(bbArg.getType()); continue; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp index 2a98203da9d7d..4c29d657d443b 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -49,29 +49,47 @@ void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) { #endif // NDEBUG } +// Note: this is a local adaptor to unify TensorType and TensorLikeType code +// paths that both work with BufferizationOptions. +static mlir::Attribute +getDefaultMemorySpace(const BufferizationOptions &options, + TensorLikeType type) { + if (auto tensorType = dyn_cast(type)) { + return *options.defaultMemorySpaceFn(tensorType); + } + return nullptr; +} + /// Return the index-th bufferized function argument type. This assumes that the /// specified argument is a tensor. If the tensor is ranked, a layout map may be /// specified by the user (as per `options.functionArgTypeConverterFn`). -static BaseMemRefType +static BufferLikeType getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, const BufferizationOptions &options) { - auto tensorType = - dyn_cast(funcOp.getFunctionType().getInput(index)); - assert(tensorType && "expected TensorType"); - - BaseMemRefType memrefType = options.functionArgTypeConverterFn( - tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options); - - auto layoutAttr = funcOp.getArgAttrOfType( - index, BufferizationDialect::kBufferLayoutAttrName); - if (!layoutAttr) - return memrefType; - - auto rankedMemrefType = dyn_cast(memrefType); - assert(rankedMemrefType && "buffer layout not supported on unranked tensors"); - return MemRefType::get(rankedMemrefType.getShape(), - rankedMemrefType.getElementType(), layoutAttr, - rankedMemrefType.getMemorySpace()); + auto type = + dyn_cast(funcOp.getFunctionType().getInput(index)); + assert(type && "expected TensorLikeType"); + + // Note: For builtin tensors there is additional logic related to layout. + if (auto tensorType = dyn_cast(type)) { + BufferLikeType memrefType = options.functionArgTypeConverterFn( + type, *options.defaultMemorySpaceFn(tensorType), funcOp, options); + + auto layoutAttr = funcOp.getArgAttrOfType( + index, BufferizationDialect::kBufferLayoutAttrName); + if (!layoutAttr) + return memrefType; + + auto rankedMemrefType = dyn_cast(memrefType); + assert(rankedMemrefType && + "buffer layout not supported on unranked tensors"); + return cast(MemRefType::get( + rankedMemrefType.getShape(), rankedMemrefType.getElementType(), + layoutAttr, rankedMemrefType.getMemorySpace())); + } + + return options.functionArgTypeConverterFn(type, /*memSpace=*/nullptr, funcOp, + options); } /// Return the FuncOp called by `callOp`. @@ -227,13 +245,13 @@ struct CallOpInterface FunctionType funcType = funcOp.getFunctionType(); Type resultType = funcType.getResult(cast(value).getResultNumber()); - if (auto bufferizedType = dyn_cast(resultType)) - return cast(bufferizedType); + if (auto bufferizedType = dyn_cast(resultType)) + return bufferizedType; // Otherwise, call the type converter to compute the bufferized type. - auto tensorType = cast(resultType); + auto tensorType = cast(resultType); return cast(options.functionArgTypeConverterFn( - tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, + tensorType, getDefaultMemorySpace(options, tensorType), funcOp, options)); } @@ -248,7 +266,7 @@ struct CallOpInterface SmallVector resultTypes; for (Value result : callOp.getResults()) { Type returnType = result.getType(); - if (!isa(returnType)) { + if (!isa(returnType)) { // Non-tensor values are returned. resultTypes.push_back(returnType); continue; @@ -272,7 +290,7 @@ struct CallOpInterface for (OpOperand &opOperand : callOp->getOpOperands()) { // Non-tensor operands are just copied. - if (!isa(opOperand.get().getType())) { + if (!isa(opOperand.get().getType())) { newOperands.push_back(opOperand.get()); continue; } @@ -285,8 +303,8 @@ struct CallOpInterface Value buffer = *maybeBuffer; // Caller / callee type mismatch is handled with castOrReallocMemRefValue. - auto memRefType = funcType.getInput(opOperand.getOperandNumber()); - if (!isa(memRefType)) { + auto bufferType = funcType.getInput(opOperand.getOperandNumber()); + if (!isa(bufferType)) { // The called function was not bufferized yet. This can happen when // there cycles in the function call graph. Compute the bufferized // result type. @@ -296,7 +314,7 @@ struct CallOpInterface state); if (failed(maybeBufferType)) return failure(); - memRefType = *maybeBufferType; + bufferType = *maybeBufferType; } // Since we don't yet have a clear layout story, to_buffer may @@ -305,8 +323,8 @@ struct CallOpInterface // that will either canonicalize away or fail compilation until we can do // something better. Insert a reallocation + copy if it cannot be // statically guaranteed that a direct cast would be valid. - if (buffer.getType() != memRefType) { - auto memrefDstType = dyn_cast(memRefType); + if (buffer.getType() != bufferType) { + auto memrefDstType = dyn_cast(bufferType); assert(memrefDstType && "buffer layout not supported on unranked tensors"); FailureOr replacement = bufferization::castOrReallocMemRefValue( @@ -369,7 +387,7 @@ struct FuncOpInterface static bool supportsUnstructuredControlFlow() { return true; } bool hasTensorSemantics(Operation *op) const { - auto isaTensor = llvm::IsaPred; + auto isaTensor = llvm::IsaPred; // A function has tensor semantics if it has tensor arguments/results. auto funcOp = cast(op); @@ -405,8 +423,8 @@ struct FuncOpInterface // Function arguments are special. if (bbArg.getOwner() == &funcOp.getBody().front()) - return cast( - getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options)); + return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), + options); return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel:: getBufferType(op, value, options, state, invocationStack); @@ -429,7 +447,7 @@ struct FuncOpInterface SmallVector argTypes; for (const auto &it : llvm::enumerate(funcType.getInputs())) { Type argType = it.value(); - if (isa(argType)) { + if (isa(argType)) { argTypes.push_back( getBufferizedFunctionArgType(funcOp, it.index(), options)); continue; @@ -440,9 +458,9 @@ struct FuncOpInterface // Compute the result types. SmallVector retTypes; for (Type resultType : funcType.getResults()) { - if (auto tensorType = dyn_cast(resultType)) { - BaseMemRefType resultType = options.functionArgTypeConverterFn( - tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, + if (auto tensorType = dyn_cast(resultType)) { + BufferLikeType resultType = options.functionArgTypeConverterFn( + tensorType, getDefaultMemorySpace(options, tensorType), funcOp, options); retTypes.push_back(resultType); continue; @@ -472,7 +490,7 @@ struct FuncOpInterface SmallVector returnValues; for (auto [returnVal, bufferizedType] : llvm::zip_equal(returnOp->getOperands(), retTypes)) { - auto tensorType = dyn_cast(returnVal.getType()); + auto tensorType = dyn_cast(returnVal.getType()); rewriter.setInsertionPoint(returnOp); // If not a tensor type just forward it. diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt index d987b72e98354..c58ff5673bf08 100644 --- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -28,6 +28,7 @@ add_mlir_dialect_library(MLIRLLVMDialect Core LINK_LIBS PUBLIC + MLIRTransformUtils MLIRCallInterfaces MLIRControlFlowInterfaces MLIRDataLayoutInterfaces diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp index e23a0d6aba825..e8fc6784b0be5 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp @@ -191,7 +191,8 @@ LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType, void QuantDialect::initialize() { addTypes(); + UniformQuantizedPerAxisType, UniformQuantizedSubChannelType, + QuantileQuantizedType, QuantileQuantizedPerAxisType>(); addOperations< #define GET_OP_LIST #include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc" diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp index 8122c4db684e5..d3968eb152894 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp @@ -22,7 +22,8 @@ namespace { // Return the minimum scale representable in a given float type double getMinScale(Type expressedType) { auto floatType = cast(expressedType); - return APFloat::getSmallest(floatType.getFloatSemantics()).convertToDouble(); + return APFloat::getLargest(floatType.getFloatSemantics(), /*Negative=*/true) + .convertToDouble(); } // Return the maximum scale representable in a given float type @@ -46,28 +47,42 @@ QuantizedType::verifyInvariants(function_ref emitError, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax) { - // Verify that the storage type is integral. - // This restriction may be lifted at some point in favor of using bf16 - // or f16 as exact representations on hardware where that is advantageous. - auto intStorageType = llvm::dyn_cast(storageType); - if (!intStorageType) - return emitError() << "storage type must be integral"; - unsigned integralWidth = intStorageType.getWidth(); - - // Verify storage width. - if (integralWidth == 0 || integralWidth > MaxStorageBits) - return emitError() << "illegal storage type size: " << integralWidth; - // Verify storageTypeMin and storageTypeMax. bool isSigned = (flags & QuantizationFlags::Signed) == QuantizationFlags::Signed; - int64_t defaultIntegerMin = - getDefaultMinimumForInteger(isSigned, integralWidth); - int64_t defaultIntegerMax = - getDefaultMaximumForInteger(isSigned, integralWidth); - if (storageTypeMax - storageTypeMin <= 0 || - storageTypeMin < defaultIntegerMin || - storageTypeMax > defaultIntegerMax) { + + // Integral storage type width checks + if (mlir::isa(storageType)) { + unsigned integralWidth = + llvm::dyn_cast(storageType).getWidth(); + + if (integralWidth == 0 || integralWidth > MaxStorageBits) + return emitError() << "illegal storage type size: " << integralWidth; + } + + int64_t defaultMin, defaultMax; + if (mlir::isa(storageType)) { + const auto width = llvm::dyn_cast(storageType).getWidth(); + defaultMin = QuantizedType::getDefaultMinimumForInteger(isSigned, width); + defaultMax = QuantizedType::getDefaultMaximumForInteger(isSigned, width); + } else if (mlir::isa(storageType)) { + defaultMin = QuantizedType::getDefaultMinimumForF8E5M2(); + defaultMax = QuantizedType::getDefaultMaximumForF8E5M2(); + } else if (mlir::isa(storageType)) { + defaultMin = QuantizedType::getDefaultMinimumForF8E4M3FN(); + defaultMax = QuantizedType::getDefaultMaximumForF8E4M3FN(); + } else if (mlir::isa(storageType)) { + defaultMin = QuantizedType::getDefaultMinimumForF4E2M1FN(); + defaultMax = QuantizedType::getDefaultMaximumForF4E2M1FN(); + } else { + return emitError() + << "illegal storage type, supported types are: integral " + "types, Float8E4M3FNType, Float8E5M2Type and Float4E2M1FNType "; + } + + // Verify storageTypeMin and storageTypeMax. + if (storageTypeMax - storageTypeMin <= 0 || storageTypeMin < defaultMin || + storageTypeMax > defaultMax) { return emitError() << "illegal storage min and storage max: (" << storageTypeMin << ":" << storageTypeMax << ")"; } @@ -319,13 +334,18 @@ LogicalResult UniformQuantizedType::verifyInvariants( // Verify scale. double minScale = getMinScale(expressedType); double maxScale = getMaxScale(expressedType); - if (scale < minScale || scale > maxScale) + if (scale < minScale || scale > maxScale || std::isnan(scale)) return emitError() << "scale out of expressed type range [" << minScale << ", " << maxScale << "]"; return success(); } +bool UniformQuantizedType::classof(mlir::Type type) { + return type.getTypeID() == mlir::TypeID::get() || + type.getTypeID() == mlir::TypeID::get(); +} + double UniformQuantizedType::getScale() const { return getImpl()->scale; } int64_t UniformQuantizedType::getZeroPoint() const { @@ -383,7 +403,7 @@ LogicalResult UniformQuantizedPerAxisType::verifyInvariants( double minScale = getMinScale(expressedType); double maxScale = getMaxScale(expressedType); for (double scale : scales) { - if (scale < minScale || scale > maxScale) + if (scale < minScale || scale > maxScale || std::isnan(scale)) return emitError() << "scale out of expressed type range [" << minScale << ", " << maxScale << "]"; } @@ -395,6 +415,11 @@ LogicalResult UniformQuantizedPerAxisType::verifyInvariants( return success(); } +bool UniformQuantizedPerAxisType::classof(mlir::Type type) { + return type.getTypeID() == mlir::TypeID::get() || + type.getTypeID() == mlir::TypeID::get(); +} + ArrayRef UniformQuantizedPerAxisType::getScales() const { return getImpl()->getScales(); } @@ -407,6 +432,182 @@ int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const { return getImpl()->quantizedDimension; } +QuantileQuantizedType +QuantileQuantizedType::get(unsigned flags, Type storageType, Type quantileType, + Type expressedType, ArrayRef quantiles, + double scale, int64_t zeroPoint, + int64_t storageTypeMin, int64_t storageTypeMax) { + return Base::get(storageType.getContext(), flags, storageType, quantileType, + expressedType, quantiles, scale, zeroPoint, storageTypeMin, + storageTypeMax); +} + +QuantileQuantizedType QuantileQuantizedType::getChecked( + function_ref emitError, unsigned flags, + Type storageType, Type quantileType, Type expressedType, + ArrayRef quantiles, double scale, int64_t zeroPoint, + int64_t storageTypeMin, int64_t storageTypeMax) { + return Base::getChecked(emitError, storageType.getContext(), flags, + storageType, quantileType, expressedType, quantiles, + scale, zeroPoint, storageTypeMin, storageTypeMax); +} +LogicalResult QuantileQuantizedType::verifyInvariants( + function_ref emitError, unsigned flags, + Type storageType, Type quantileType, Type expressedType, + ArrayRef quantiles, double scale, int64_t zeroPoint, + int64_t storageTypeMin, int64_t storageTypeMax) { + if (failed(UniformQuantizedType::verifyInvariants( + emitError, flags, storageType, expressedType, scale, zeroPoint, + storageTypeMin, storageTypeMax))) { + return failure(); + } + + unsigned typeWidth{}; + if (mlir::isa(storageType)) { + typeWidth = llvm::dyn_cast(storageType).getWidth(); + } else if (mlir::isa( + storageType)) { + // Float8E5M2Type, Float8E4M3FNType and Float4E2M1FNType derive from + // FloatType. + typeWidth = llvm::dyn_cast(storageType).getWidth(); + } else { + return emitError() + << "illegal storage type, supported types are: integral " + "types, Float8E4M3FNType, Float8E5M2Type and Float4E2M1FNType "; + } + + const size_t storageTypeRange = storageTypeMax - storageTypeMin + 1; + const size_t typeWidthSize = 1 << typeWidth; + const size_t expectedSize = + (storageTypeRange < typeWidthSize) && !mlir::isa(storageType) + ? storageTypeRange + : typeWidthSize; + + const auto quantileArraySize = quantiles.size(); + if (quantileArraySize != expectedSize) { + return emitError() << "quantiles array size needs to be equal to " + "2^(bit_size(storageType)), or (storageTypeMax - " + "storageTypeMin + 1) when max and min differ from " + "the type limits; expected: " + << expectedSize << ", found: " << quantileArraySize; + } + + // Verify quantiles + for (double quantile : quantiles) { + if (std::isinf(quantile) || std::isnan(quantile)) { + return emitError() << "illegal quantile value: " << quantile; + } + } + + return success(); +} + +bool QuantileQuantizedType::classof(mlir::Type type) { + return type.getTypeID() == mlir::TypeID::get(); +} + +Type QuantileQuantizedType::getQuantileType() const { + return getImpl()->quantileType; +} + +unsigned QuantileQuantizedType::getQuantileTypeIntegralWidth() const { + return getImpl()->getQuantileType().getIntOrFloatBitWidth(); +} + +ArrayRef QuantileQuantizedType::getQuantiles() const { + return getImpl()->getQuantiles(); +} + +QuantileQuantizedPerAxisType QuantileQuantizedPerAxisType::get( + unsigned flags, Type storageType, Type quantileType, Type expressedType, + ArrayRef quantiles, ArrayRef scales, + ArrayRef zeroPoints, int32_t quantizedDimension, + int64_t storageTypeMin, int64_t storageTypeMax) { + return Base::get(storageType.getContext(), flags, storageType, quantileType, + expressedType, quantiles, scales, zeroPoints, + quantizedDimension, storageTypeMin, storageTypeMax); +} + +QuantileQuantizedPerAxisType QuantileQuantizedPerAxisType::getChecked( + function_ref emitError, unsigned flags, + Type storageType, Type quantileType, Type expressedType, + ArrayRef quantiles, ArrayRef scales, + ArrayRef zeroPoints, int32_t quantizedDimension, + int64_t storageTypeMin, int64_t storageTypeMax) { + return Base::getChecked(emitError, storageType.getContext(), flags, + storageType, quantileType, expressedType, quantiles, + scales, zeroPoints, quantizedDimension, + storageTypeMin, storageTypeMax); +} + +LogicalResult QuantileQuantizedPerAxisType::verifyInvariants( + function_ref emitError, unsigned flags, + Type storageType, Type quantileType, Type expressedType, + ArrayRef quantiles, ArrayRef scales, + ArrayRef zeroPoints, int32_t quantizedDimension, + int64_t storageTypeMin, int64_t storageTypeMax) { + if (failed(UniformQuantizedPerAxisType::verifyInvariants( + emitError, flags, storageType, expressedType, scales, zeroPoints, + quantizedDimension, storageTypeMin, storageTypeMax))) { + return failure(); + } + + unsigned typeWidth{}; + if (mlir::isa(storageType)) { + typeWidth = llvm::dyn_cast(storageType).getWidth(); + } else if (mlir::isa( + storageType)) { + // Float8E5M2Type, Float8E4M3FNType and Float4E2M1FNType derive from + // FloatType. + typeWidth = llvm::dyn_cast(storageType).getWidth(); + } else { + return emitError() + << "illegal storage type, supported types are: integral " + "types, Float8E4M3FNType, Float8E5M2Type and Float4E2M1FNType "; + } + + const size_t storageTypeRange = storageTypeMax - storageTypeMin + 1; + const size_t typeWidthSize = 1 << typeWidth; + const size_t expectedSize = + (storageTypeRange < typeWidthSize) && !mlir::isa(storageType) + ? storageTypeRange + : typeWidthSize; + + const auto quantileArraySize = quantiles.size(); + if (quantileArraySize != expectedSize) { + return emitError() << "quantiles array size needs to be equal to " + "2^(bit_size(storageType)), or (storageTypeMax - " + "storageTypeMin + 1) when max and min differ from " + "the type limits; expected: " + << expectedSize << ", found: " << quantileArraySize; + } + + // Verify quantiles + for (double quantile : quantiles) { + if (std::isinf(quantile) || std::isnan(quantile)) { + return emitError() << "illegal quantile value: " << quantile; + } + } + + return success(); +} + +bool QuantileQuantizedPerAxisType::classof(mlir::Type type) { + return type.getTypeID() == mlir::TypeID::get(); +} + +Type QuantileQuantizedPerAxisType::getQuantileType() const { + return getImpl()->quantileType; +} + +unsigned QuantileQuantizedPerAxisType::getQuantileTypeIntegralWidth() const { + return getImpl()->getQuantileType().getIntOrFloatBitWidth(); +} + +ArrayRef QuantileQuantizedPerAxisType::getQuantiles() const { + return getImpl()->getQuantiles(); +} + UniformQuantizedSubChannelType UniformQuantizedSubChannelType::get( unsigned flags, Type storageType, Type expressedType, DenseElementsAttr scales, DenseElementsAttr zeroPoints, diff --git a/mlir/lib/Dialect/Quant/IR/TypeDetail.h b/mlir/lib/Dialect/Quant/IR/TypeDetail.h index a43bce354c324..65f4ecdcec85e 100644 --- a/mlir/lib/Dialect/Quant/IR/TypeDetail.h +++ b/mlir/lib/Dialect/Quant/IR/TypeDetail.h @@ -253,6 +253,160 @@ struct UniformQuantizedPerAxisTypeStorage : public QuantizedTypeStorage { int32_t quantizedDimension; }; +struct QuantileQuantizedTypeStorage : public UniformQuantizedTypeStorage { + struct KeyTy : public UniformQuantizedTypeStorage::KeyTy { + KeyTy(unsigned flags, Type storageType, Type quantileType, + Type expressedType, ArrayRef quantiles, double scale, + int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax) + : UniformQuantizedTypeStorage::KeyTy(flags, storageType, expressedType, + scale, zeroPoint, storageTypeMin, + storageTypeMax), + quantileType(quantileType), quantiles(quantiles) {} + + Type quantileType; + ArrayRef quantiles; + Type getQuantileType() const { return quantileType; } + ArrayRef getQuantiles() const { return quantiles; } + + // Check for equality of two structures that share KeyTy data members + // (by name). + template + static bool genericIsEqual(const T &lhs, const U &rhs) { + return UniformQuantizedTypeStorage::KeyTy::genericIsEqual(lhs, rhs) && + lhs.getQuantileType() == rhs.getQuantileType() && + lhs.getQuantiles() == rhs.getQuantiles(); + } + + bool operator==(const KeyTy &other) const { + return genericIsEqual(*this, other); + } + + unsigned getHashValue() const { + int64_t scaleBits = llvm::bit_cast(scale); + int64_t *quantilesCast = llvm::bit_cast(quantiles.data()); + ArrayRef quantilesBits(quantilesCast, quantiles.size()); + return llvm::hash_combine( + flags, storageType, quantileType, expressedType, + llvm::hash_combine_range(quantilesBits.begin(), quantilesBits.end()), + scaleBits, zeroPoint, storageTypeMin, storageTypeMax); + } + }; + + QuantileQuantizedTypeStorage(const KeyTy &key, ArrayRef quantiles) + : UniformQuantizedTypeStorage(key), quantileType(key.getQuantileType()), + quantilesElements(quantiles.data()), + quantilesParamsSize(quantiles.size()) {} + + bool operator==(const KeyTy &key) const { + return KeyTy::genericIsEqual(*this, key); + } + + /// Construction. + static QuantileQuantizedTypeStorage * + construct(TypeStorageAllocator &allocator, const KeyTy &key) { + ArrayRef quantiles = allocator.copyInto(key.quantiles); + return new (allocator.allocate()) + QuantileQuantizedTypeStorage(key, quantiles); + } + + static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); } + + Type getQuantileType() const { return quantileType; } + + ArrayRef getQuantiles() const { + return ArrayRef(quantilesElements, quantilesParamsSize); + } + + Type quantileType; + const double *quantilesElements; + unsigned quantilesParamsSize; +}; + +struct QuantileQuantizedPerAxisTypeStorage + : public UniformQuantizedPerAxisTypeStorage { + struct KeyTy : public UniformQuantizedPerAxisTypeStorage::KeyTy { + KeyTy(unsigned flags, Type storageType, Type quantileType, + Type expressedType, ArrayRef quantiles, + ArrayRef scales, ArrayRef zeroPoints, + int32_t quantizedDimension, int64_t storageTypeMin, + int64_t storageTypeMax) + : UniformQuantizedPerAxisTypeStorage::KeyTy( + flags, storageType, expressedType, scales, zeroPoints, + quantizedDimension, storageTypeMin, storageTypeMax), + quantileType(quantileType), quantiles(quantiles) {} + + Type quantileType; + ArrayRef quantiles; + Type getQuantileType() const { return quantileType; } + ArrayRef getQuantiles() const { return quantiles; } + + // Check for equality of two structures that share KeyTy data members + // (by name). + template + static bool genericIsEqual(const T &lhs, const U &rhs) { + return UniformQuantizedPerAxisTypeStorage::KeyTy::genericIsEqual(lhs, + rhs) && + lhs.getQuantileType() == rhs.getQuantileType() && + lhs.getQuantiles() == rhs.getQuantiles(); + } + + bool operator==(const KeyTy &other) const { + return genericIsEqual(*this, other); + } + + unsigned getHashValue() const { + int64_t *scalesCast = llvm::bit_cast(scales.data()); + ArrayRef scalesBits(scalesCast, scales.size()); + int64_t *quantilesCast = llvm::bit_cast(quantiles.data()); + ArrayRef quantilesBits(quantilesCast, quantiles.size()); + return llvm::hash_combine( + flags, storageType, quantileType, expressedType, + llvm::hash_combine_range(quantilesBits.begin(), quantilesBits.end()), + llvm::hash_combine_range(scalesBits.begin(), scalesBits.end()), + llvm::hash_combine_range(zeroPoints.begin(), zeroPoints.end()), + storageTypeMin, storageTypeMax); + } + }; + + // We pass quantiles, scales and zeroPoints in directly rather than relying on + // KeyTy because we have to create new reallocated versions in `construct` + // below. + QuantileQuantizedPerAxisTypeStorage(const KeyTy &key, + ArrayRef quantiles, + ArrayRef scales, + ArrayRef zeroPoints) + : UniformQuantizedPerAxisTypeStorage(key, scales, zeroPoints), + quantileType(key.getQuantileType()), + quantilesElements(quantiles.data()), + quantilesParamsSize(quantiles.size()) {} + + bool operator==(const KeyTy &key) const { + return KeyTy::genericIsEqual(*this, key); + } + + /// Construction. + static QuantileQuantizedPerAxisTypeStorage * + construct(TypeStorageAllocator &allocator, const KeyTy &key) { + ArrayRef quantiles = allocator.copyInto(key.quantiles); + ArrayRef scales = allocator.copyInto(key.scales); + ArrayRef zeroPoints = allocator.copyInto(key.zeroPoints); + return new (allocator.allocate()) + QuantileQuantizedPerAxisTypeStorage(key, quantiles, scales, zeroPoints); + } + + static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); } + + Type getQuantileType() const { return quantileType; } + + ArrayRef getQuantiles() const { + return ArrayRef(quantilesElements, quantilesParamsSize); + } + + Type quantileType; + const double *quantilesElements; + unsigned quantilesParamsSize; +}; // namespace detail + struct UniformQuantizedSubChannelTypeStorage : public QuantizedTypeStorage { struct KeyTy { KeyTy(unsigned flags, Type storageType, Type expressedType, diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp index c6a6881b46f26..82185c531d016 100644 --- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp +++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp @@ -21,9 +21,9 @@ using namespace mlir; using namespace quant; -static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) { +static Type parseStorageType(DialectAsmParser &parser, bool &isSigned) { auto typeLoc = parser.getCurrentLocation(); - IntegerType type; + Type type; // Parse storage type (alpha_ident, integer_literal). StringRef identifier; @@ -32,20 +32,32 @@ static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) { if (result.has_value()) { if (!succeeded(*result)) return nullptr; - isSigned = !type.isUnsigned(); - storageTypeWidth = type.getWidth(); - } else if (succeeded(parser.parseKeyword(&identifier))) { - // Otherwise, this must be an unsigned integer (`u` integer-literal). - if (!identifier.consume_front("u")) { - parser.emitError(typeLoc, "illegal storage type prefix"); + if (auto intType = llvm::dyn_cast(type)) { + isSigned = !intType.isUnsigned(); + storageTypeWidth = intType.getWidth(); + } else if (mlir::isa( + type)) { + storageTypeWidth = llvm::dyn_cast(type).getWidth(); + isSigned = true; + } else { + parser.emitError(typeLoc, "illegal quantized storage type alias"); return nullptr; } - if (identifier.getAsInteger(10, storageTypeWidth)) { - parser.emitError(typeLoc, "expected storage type width"); + } else if (succeeded(parser.parseKeyword(&identifier))) { + // Otherwise, this must be an unsigned integer (`u` integer-literal) + if (identifier.consume_front("u")) { + if (identifier.getAsInteger(10, storageTypeWidth)) { + parser.emitError(typeLoc, "expected storage type width"); + return nullptr; + } + isSigned = false; + type = parser.getBuilder().getIntegerType(storageTypeWidth); + + } else { + parser.emitError(typeLoc, "illegal quantized storage type alias"); return nullptr; } - isSigned = false; - type = parser.getBuilder().getIntegerType(storageTypeWidth); + } else { return nullptr; } @@ -60,35 +72,96 @@ static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) { return type; } -static ParseResult parseStorageRange(DialectAsmParser &parser, - IntegerType storageType, bool isSigned, - int64_t &storageTypeMin, +static Type parseQuantileType(DialectAsmParser &parser) { + auto typeLoc = parser.getCurrentLocation(); + Type type; + + // Parse storage type (alpha_ident, integer_literal). + StringRef identifier; + unsigned storageTypeWidth = 0; + OptionalParseResult result = parser.parseOptionalType(type); + if (result.has_value()) { + if (!succeeded(*result)) + return nullptr; + + if (!mlir::isa(type) && !mlir::isa(type)) { + parser.emitError(typeLoc, "illegal quantile type alias"); + return nullptr; + } + } else if (succeeded(parser.parseKeyword(&identifier))) { + // Otherwise, this must be an unsigned integer (`u` integer-literal) + if (identifier.consume_front("u")) { + if (identifier.getAsInteger(10, storageTypeWidth)) { + parser.emitError(typeLoc, "expected quantile type width"); + return nullptr; + } + constexpr bool isSigned = false; + type = parser.getBuilder().getIntegerType(storageTypeWidth, isSigned); + + } else { + parser.emitError(typeLoc, "illegal quantile type alias"); + return nullptr; + } + } else { + return nullptr; + } + + return type; +} + +static ParseResult +checkStorageRange(DialectAsmParser &parser, int64_t storageTypeMin, + int64_t storageTypeMax, int64_t defaultStorageTypeMin, + int64_t defaultStorageTypeMax, SMLoc minLoc, SMLoc maxLoc) { + if (storageTypeMin < defaultStorageTypeMin) { + return parser.emitError(minLoc, "illegal storage type minimum: ") + << storageTypeMin; + } + if (storageTypeMax > defaultStorageTypeMax) { + return parser.emitError(maxLoc, "illegal storage type maximum: ") + << storageTypeMax; + } + return success(); +} + +static ParseResult parseStorageRange(DialectAsmParser &parser, Type storageType, + bool isSigned, int64_t &storageTypeMin, int64_t &storageTypeMax) { - int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger( - isSigned, storageType.getWidth()); - int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger( - isSigned, storageType.getWidth()); + int64_t defaultMin, defaultMax; + if (mlir::isa(storageType)) { + const auto width = llvm::dyn_cast(storageType).getWidth(); + defaultMin = QuantizedType::getDefaultMinimumForInteger(isSigned, width); + defaultMax = QuantizedType::getDefaultMaximumForInteger(isSigned, width); + } else if (mlir::isa(storageType)) { + defaultMin = QuantizedType::getDefaultMinimumForF8E5M2(); + defaultMax = QuantizedType::getDefaultMaximumForF8E5M2(); + } else if (mlir::isa(storageType)) { + defaultMin = QuantizedType::getDefaultMinimumForF8E4M3FN(); + defaultMax = QuantizedType::getDefaultMaximumForF8E4M3FN(); + } else if (mlir::isa(storageType)) { + defaultMin = QuantizedType::getDefaultMinimumForF4E2M1FN(); + defaultMax = QuantizedType::getDefaultMaximumForF4E2M1FN(); + } else { + defaultMin = std::numeric_limits::max(); + defaultMax = std::numeric_limits::min(); + } + if (failed(parser.parseOptionalLess())) { - storageTypeMin = defaultIntegerMin; - storageTypeMax = defaultIntegerMax; + storageTypeMin = defaultMin; + storageTypeMax = defaultMax; return success(); } // Explicit storage min and storage max. + // F8 and F4 min and max values are integers, so parseInteger() is used. SMLoc minLoc = parser.getCurrentLocation(), maxLoc; if (parser.parseInteger(storageTypeMin) || parser.parseColon() || parser.getCurrentLocation(&maxLoc) || parser.parseInteger(storageTypeMax) || parser.parseGreater()) return failure(); - if (storageTypeMin < defaultIntegerMin) { - return parser.emitError(minLoc, "illegal storage type minimum: ") - << storageTypeMin; - } - if (storageTypeMax > defaultIntegerMax) { - return parser.emitError(maxLoc, "illegal storage type maximum: ") - << storageTypeMax; - } - return success(); + + return checkStorageRange(parser, storageTypeMin, storageTypeMax, defaultMin, + defaultMax, minLoc, maxLoc); } static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser, @@ -118,7 +191,7 @@ static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser, /// storage-type ::= (`i` | `u`) integer-literal /// expressed-type-spec ::= `:` `f` integer-literal static Type parseAnyType(DialectAsmParser &parser) { - IntegerType storageType; + Type storageType; FloatType expressedType; unsigned typeFlags = 0; int64_t storageTypeMin; @@ -167,10 +240,11 @@ isScaleInExpressedTypeRange(function_ref emitError, Type expressedType, double scale) { auto floatType = cast(expressedType); double minScale = - APFloat::getSmallest(floatType.getFloatSemantics()).convertToDouble(); + APFloat::getLargest(floatType.getFloatSemantics(), /*Negative=*/true) + .convertToDouble(); double maxScale = APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble(); - if (scale < minScale || scale > maxScale) + if (scale < minScale || scale > maxScale || std::isnan(scale)) return emitError() << "scale " << scale << " out of expressed type range [" << minScale << ", " << maxScale << "]"; return success(); @@ -299,7 +373,7 @@ parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType, return success(); } -/// Parses a UniformQuantizedType. +/// Parses a UniformQuantizedType or a QuantileQuantizedType. /// /// uniform_type ::= uniform_per_layer /// | uniform_per_axis @@ -312,7 +386,8 @@ parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType, /// block-size-info `,` scale-zero-tensor `>` /// storage-spec ::= storage-type (`<` storage-range `>`)? /// storage-range ::= integer-literal `:` integer-literal -/// storage-type ::= (`i` | `u`) integer-literal +/// storage-type ::= (`i` | `u`) integer-literal | `f8E5M2` | `f8E4M3FN` +// | `f4E2M1FN` /// expressed-type-spec ::= `:` `f` integer-literal /// axis-spec ::= `:` integer-literal /// scale-zero ::= scale (`:` zero-point)? @@ -326,8 +401,38 @@ parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType, /// scale-zero-dense-exp ::= `{` /// scale-zero-tensor (`,` scale-zero-tensor)* /// `}` -static Type parseUniformType(DialectAsmParser &parser) { - IntegerType storageType; +/// +/// quantile_type ::= quantile_per_layer +/// | quantile_per_axis +/// quantile_per_layer ::= `quantile<` storage-spec quantile-type-spec +/// expressed-type-spec `,` quantiles-list `,` +/// scale-zero `>` +/// quantile_per_axis ::= `quantile<` storage-spec quantile-type-spec +/// expressed-type-spec axis-spec `,` quantiles-list +/// scale-zero-list `>` +/// storage-spec ::= storage-type (`<` storage-range `>`)? +/// storage-range ::= integer-literal `:` integer-literal +/// storage-type ::= (`i` | `u`) integer-literal | `f8E5M2` | `f8E4M3FN` +// | `f4E2M1FN` +/// quantile-type-spec ::= `:` ((`i` | `u` | `f`) integer-literal | `f8E5M2` +// | `f8E4M3FN` | `f4E2M1FN`) +/// expressed-type-spec ::= `:` `f` integer-literal +/// axis-spec ::= `:` integer-literal +/// quantiles-list ::= `{` quantile (`,` quantile)* `}` +/// scale-zero ::= scale (`:` zero-point)? +/// scale ::= float-literal +/// zero-point ::= integer-literal +/// scale-zero-list ::= scale-zero (`,` scale-zero)* +/// block-size-info ::= `{` `}` | `{` axis-block `:` (`,` axis-block)* `}` +/// axis-block ::= axis-spec `:` block-size-spec +/// block-size-spec ::= integer-literal +/// scale-zero-tensor ::= scale-zero-dense-exp | scale-zero-list +/// scale-zero-dense-exp ::= `{` +/// scale-zero-tensor (`,` scale-zero-tensor)* +/// `}` +static Type parseUniformType(DialectAsmParser &parser, bool isQuantile) { + Type storageType; + Type quantileType; FloatType expressedType; unsigned typeFlags = 0; int64_t storageTypeMin; @@ -336,6 +441,7 @@ static Type parseUniformType(DialectAsmParser &parser) { bool isSubChannel = false; SmallVector quantizedDimensions; SmallVector blockSizes; + SmallVector quantiles; SmallVector scales; SmallVector zeroPoints; @@ -360,6 +466,17 @@ static Type parseUniformType(DialectAsmParser &parser) { return nullptr; } + // Quantile type. + if (isQuantile) { + if (parser.parseColon()) { + return nullptr; + } + quantileType = parseQuantileType(parser); + if (!quantileType) { + return nullptr; + } + } + // Expressed type. if (parser.parseColon() || parser.parseType(expressedType)) { return nullptr; @@ -388,6 +505,28 @@ static Type parseUniformType(DialectAsmParser &parser) { return nullptr; } + // Quantile list. + if (isQuantile) { + if (parser.parseLBrace()) { + return nullptr; + } + + do { + quantiles.emplace_back(); + if (parser.parseFloat(quantiles.back())) { + return nullptr; + } + } while (succeeded(parser.parseOptionalComma())); + + if (parser.parseRBrace()) { + return nullptr; + } + + if (parser.parseColon()) { + return nullptr; + } + } + // Quantization parameter (scales/zeroPoints) specification. bool isPerTensor = !isPerAxis && !isSubChannel; SmallVector dims; @@ -411,6 +550,20 @@ static Type parseUniformType(DialectAsmParser &parser) { return nullptr; } + if (isQuantile) { + if (isPerAxis) { + return parser.getChecked( + typeFlags, storageType, quantileType, expressedType, quantiles, + scales, zeroPoints, quantizedDimensions[0], storageTypeMin, + storageTypeMax); + } + + assert(!isSubChannel && "Sub-channel quantile types are not supported"); + return parser.getChecked( + typeFlags, storageType, quantileType, expressedType, quantiles, + scales.front(), zeroPoints.front(), storageTypeMin, storageTypeMax); + } + if (isPerAxis) { return parser.getChecked( typeFlags, storageType, expressedType, scales, zeroPoints, @@ -478,7 +631,9 @@ Type QuantDialect::parseType(DialectAsmParser &parser) const { return nullptr; if (typeNameSpelling == "uniform") - return parseUniformType(parser); + return parseUniformType(parser, false); + if (typeNameSpelling == "quantile") + return parseUniformType(parser, true); if (typeNameSpelling == "any") return parseAnyType(parser); if (typeNameSpelling == "calibrated") @@ -493,19 +648,68 @@ static void printStorageType(QuantizedType type, DialectAsmPrinter &out) { // storage type unsigned storageWidth = type.getStorageTypeIntegralWidth(); bool isSigned = type.isSigned(); - if (isSigned) { + if (mlir::isa(type.getStorageType())) { + out << "f8E5M2"; + } else if (mlir::isa(type.getStorageType())) { + out << "f8E4M3FN"; + } else if (mlir::isa(type.getStorageType())) { + out << "f4E2M1FN"; + } else if (isSigned) { out << "i" << storageWidth; } else { out << "u" << storageWidth; } // storageTypeMin and storageTypeMax if not default. - if (type.hasStorageTypeBounds()) { + int64_t defaultMin = + mlir::isa(type.getStorageType()) + ? QuantizedType::getDefaultMinimumForInteger(isSigned, storageWidth) + : mlir::isa(type.getStorageType()) + ? QuantizedType::getDefaultMinimumForF8E5M2() + : mlir::isa(type.getStorageType()) + ? QuantizedType::getDefaultMinimumForF8E4M3FN() + : mlir::isa(type.getStorageType()) + ? QuantizedType::getDefaultMinimumForF4E2M1FN() + : std::numeric_limits::max(); + + int64_t defaultMax = + mlir::isa(type.getStorageType()) + ? QuantizedType::getDefaultMaximumForInteger(isSigned, storageWidth) + : mlir::isa(type.getStorageType()) + ? QuantizedType::getDefaultMaximumForF8E5M2() + : mlir::isa(type.getStorageType()) + ? QuantizedType::getDefaultMaximumForF8E4M3FN() + : mlir::isa(type.getStorageType()) + ? QuantizedType::getDefaultMaximumForF4E2M1FN() + : std::numeric_limits::min(); + + if (defaultMin != type.getStorageTypeMin() || + defaultMax != type.getStorageTypeMax()) { out << "<" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax() << ">"; } } +static void printQuantileType(Type quantileType, DialectAsmPrinter &out) { + if (auto intType = llvm::dyn_cast(quantileType)) { + const unsigned storageTypeWidth = intType.getWidth(); + if (intType.isUnsigned()) { + out << ":u" << storageTypeWidth; + } else { + out << ":i" << storageTypeWidth; + } + } else if (mlir::isa(quantileType)) { + out << ":f8E5M2"; + } else if (mlir::isa(quantileType)) { + out << ":f8E4M3FN"; + } else if (mlir::isa(quantileType)) { + out << ":f4E2M1FN"; + } else { + // Float types + out << ":" << quantileType; + } +} + static void printQuantParams(double scale, int64_t zeroPoint, DialectAsmPrinter &out) { out << scale; @@ -638,6 +842,56 @@ printUniformQuantizedSubChannelType(UniformQuantizedSubChannelType type, out << ">"; } +/// Helper that prints a QuantileQuantizedType. +static void printQuantileQuantizedType(QuantileQuantizedType type, + DialectAsmPrinter &out) { + out << "quantile<"; + printStorageType(type, out); + printQuantileType(type.getQuantileType(), out); + out << ":" << type.getExpressedType() << ", "; + + // scheme specific parameters + ArrayRef quantiles = type.getQuantiles(); + out << "{"; + llvm::interleave( + llvm::seq(0, quantiles.size()), out, + [&](size_t index) { out << quantiles[index]; }, ","); + out << "}:"; + + printQuantParams(type.getScale(), type.getZeroPoint(), out); + out << ">"; +} + +/// Helper that prints a QuantileQuantizedPerAxisType. +static void printQuantileQuantizedPerAxisType(QuantileQuantizedPerAxisType type, + DialectAsmPrinter &out) { + out << "quantile<"; + printStorageType(type, out); + printQuantileType(type.getQuantileType(), out); + out << ":" << type.getExpressedType() << ":"; + out << type.getQuantizedDimension(); + out << ", "; + + // scheme specific parameters + ArrayRef quantiles = type.getQuantiles(); + out << "{"; + llvm::interleave( + llvm::seq(0, quantiles.size()), out, + [&](size_t index) { out << quantiles[index]; }, ","); + out << "}:"; + + ArrayRef scales = type.getScales(); + ArrayRef zeroPoints = type.getZeroPoints(); + out << "{"; + llvm::interleave( + llvm::seq(0, scales.size()), out, + [&](size_t index) { + printQuantParams(scales[index], zeroPoints[index], out); + }, + ","); + out << "}>"; +} + /// Helper that prints a CalibratedQuantizedType. static void printCalibratedQuantizedType(CalibratedQuantizedType type, DialectAsmPrinter &out) { @@ -650,15 +904,20 @@ static void printCalibratedQuantizedType(CalibratedQuantizedType type, void QuantDialect::printType(Type type, DialectAsmPrinter &os) const { if (auto anyType = llvm::dyn_cast(type)) printAnyQuantizedType(anyType, os); + else if (auto uniformType = llvm::dyn_cast(type)) + printQuantileQuantizedType(uniformType, os); + else if (auto perAxisType = + llvm::dyn_cast(type)) + printQuantileQuantizedPerAxisType(perAxisType, os); else if (auto uniformType = llvm::dyn_cast(type)) printUniformQuantizedType(uniformType, os); else if (auto perAxisType = llvm::dyn_cast(type)) printUniformQuantizedPerAxisType(perAxisType, os); + else if (auto calibratedType = llvm::dyn_cast(type)) + printCalibratedQuantizedType(calibratedType, os); else if (auto perAxisType = llvm::dyn_cast(type)) printUniformQuantizedSubChannelType(perAxisType, os); - else if (auto calibratedType = llvm::dyn_cast(type)) - printCalibratedQuantizedType(calibratedType, os); else llvm_unreachable("Unhandled quantized type"); } diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp index 09326242eec2a..91ce8bb0c0ea2 100644 --- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp @@ -185,6 +185,30 @@ class ConvertWhileOpTypes }; } // namespace +namespace { +class ConvertIndexSwitchOpTypes + : public Structural1ToNConversionPattern { +public: + using Structural1ToNConversionPattern::Structural1ToNConversionPattern; + + std::optional + convertSourceOp(IndexSwitchOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + TypeRange dstTypes) const { + auto newOp = + IndexSwitchOp::create(rewriter, op.getLoc(), dstTypes, op.getArg(), + op.getCases(), op.getNumCases()); + + for (unsigned i = 0u; i < op.getNumRegions(); i++) { + auto &dstRegion = newOp.getRegion(i); + rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end()); + } + return newOp; + } +}; +} // namespace + namespace { // When the result types of a ForOp/IfOp get changed, the operand types of the // corresponding yield op need to be changed. In order to trigger the @@ -219,19 +243,19 @@ class ConvertConditionOpTypes : public OpConversionPattern { void mlir::scf::populateSCFStructuralTypeConversions( const TypeConverter &typeConverter, RewritePatternSet &patterns) { patterns.add( - typeConverter, patterns.getContext()); + ConvertWhileOpTypes, ConvertConditionOpTypes, + ConvertIndexSwitchOpTypes>(typeConverter, patterns.getContext()); } void mlir::scf::populateSCFStructuralTypeConversionTarget( const TypeConverter &typeConverter, ConversionTarget &target) { - target.addDynamicallyLegalOp([&](Operation *op) { + target.addDynamicallyLegalOp([&](Operation *op) { return typeConverter.isLegal(op->getResultTypes()); }); target.addDynamicallyLegalOp([&](scf::YieldOp op) { // We only have conversions for a subset of ops that use scf.yield // terminators. - if (!isa(op->getParentOp())) + if (!isa(op->getParentOp())) return true; return typeConverter.isLegal(op.getOperandTypes()); }); diff --git a/mlir/lib/Dialect/UB/IR/CMakeLists.txt b/mlir/lib/Dialect/UB/IR/CMakeLists.txt index 84125ea0b5718..ef9cf277b8ba5 100644 --- a/mlir/lib/Dialect/UB/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/UB/IR/CMakeLists.txt @@ -9,5 +9,6 @@ add_mlir_dialect_library(MLIRUBDialect MLIRUBOpsInterfacesIncGen LINK_LIBS PUBLIC + MLIRTransformUtils MLIRIR ) diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index f95ad290a1981..67b000bc3a5b5 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -418,6 +418,21 @@ class AsmPrinter::Impl { /// Returns the output stream of the printer. raw_ostream &getStream() { return os; } + /// Print a newline and indent the printer to the start of the current + /// operation/attribute/type. + /// Note: For attributes and types this method should only be used in + /// custom dialects. Usage in MLIR dialects is disallowed. + void printNewline() { + os << newLine; + os.indent(currentIndent); + } + + /// Increase indentation. + void increaseIndent() { currentIndent += indentWidth; } + + /// Decrease indentation. + void decreaseIndent() { currentIndent -= indentWidth; } + template inline void interleaveComma(const Container &c, UnaryFunctor eachFn) const { llvm::interleaveComma(c, os, eachFn); @@ -532,6 +547,12 @@ class AsmPrinter::Impl { /// A tracker for the number of new lines emitted during printing. NewLineCounter newLine; + + /// The number of spaces used as an indent. + const static unsigned indentWidth = 2; + + /// This is the current indentation level for nested structures. + unsigned currentIndent = 0; }; } // namespace mlir @@ -1004,6 +1025,9 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter { /// The following are hooks of `DialectAsmPrinter` that are not necessary for /// determining potential aliases. + void printNewline() override {} + void increaseIndent() override {} + void decreaseIndent() override {} void printFloat(const APFloat &) override {} void printKeywordOrString(StringRef) override {} void printString(StringRef) override {} @@ -2893,6 +2917,13 @@ void AsmPrinter::Impl::printDialectAttribute(Attribute attr) { { llvm::raw_string_ostream attrNameStr(attrName); Impl subPrinter(attrNameStr, state); + + // The values of currentIndent and newLine are assigned to the created + // subprinter, so that the indent level and number of printed lines can be + // tracked. + subPrinter.currentIndent = currentIndent; + subPrinter.newLine = newLine; + DialectAsmPrinter printer(subPrinter); dialect.printAttribute(attr, printer); } @@ -2907,6 +2938,13 @@ void AsmPrinter::Impl::printDialectType(Type type) { { llvm::raw_string_ostream typeNameStr(typeName); Impl subPrinter(typeNameStr, state); + + // The values of currentIndent and newLine are assigned to the created + // subprinter, so that the indent level and number of printed lines can be + // tracked. + subPrinter.currentIndent = currentIndent; + subPrinter.newLine = newLine; + DialectAsmPrinter printer(subPrinter); dialect.printType(type, printer); } @@ -2947,6 +2985,21 @@ raw_ostream &AsmPrinter::getStream() const { return impl->getStream(); } +void AsmPrinter::printNewline() { + assert(impl && "expected AsmPrinter::printNewLine to be overriden"); + impl->printNewline(); +} + +void AsmPrinter::increaseIndent() { + assert(impl && "expected AsmPrinter::increaseIndent to be overriden"); + impl->increaseIndent(); +} + +void AsmPrinter::decreaseIndent() { + assert(impl && "expected AsmPrinter::decreaseIndent to be overriden"); + impl->decreaseIndent(); +} + /// Print the given floating point value in a stablized form. void AsmPrinter::printFloat(const APFloat &value) { assert(impl && "expected AsmPrinter::printFloat to be overriden"); @@ -3277,19 +3330,6 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter { printTrailingLocation(loc); } - /// Print a newline and indent the printer to the start of the current - /// operation. - void printNewline() override { - os << newLine; - os.indent(currentIndent); - } - - /// Increase indentation. - void increaseIndent() override { currentIndent += indentWidth; } - - /// Decrease indentation. - void decreaseIndent() override { currentIndent -= indentWidth; } - /// Print a block argument in the usual format of: /// %ssaName : type {attr1=42} loc("here") /// where location printing is controlled by the standard internal option. @@ -3415,12 +3455,6 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter { // top-level we start with "builtin" as the default, so that the top-level // `module` operation prints as-is. SmallVector defaultDialectStack{"builtin"}; - - /// The number of spaces used for indenting nested operations. - const static unsigned indentWidth = 2; - - // This is the current indentation level for nested structures. - unsigned currentIndent = 0; }; } // namespace diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp index 1604ebba190a1..e2624cb7f3eb2 100644 --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -486,7 +486,8 @@ bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) { return true; // Supported built-in attributes. - if (llvm::isa(memorySpace)) + if (llvm::isa(memorySpace)) return true; // Allow custom dialect attributes. diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp index aaa4d5617eb4f..fa028b905f960 100644 --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -56,8 +56,10 @@ collectValidReferencesFor(Operation *symbol, StringAttr symbolName, StringAttr symbolNameId = StringAttr::get(ctx, SymbolTable::getSymbolAttrName()); do { - // Each parent of 'symbol' should define a symbol table. - if (!symbolTableOp->hasTrait()) + // Each parent of 'symbol' should define a symbol table or be a symbol + // container + if (!symbolTableOp->hasTrait() && + !symbolTableOp->hasTrait()) return failure(); // Each parent of 'symbol' should also be a symbol. StringAttr symbolTableName = getNameIfSymbol(symbolTableOp, symbolNameId); @@ -117,7 +119,8 @@ walkSymbolTable(Operation *op, /// Build a symbol table with the symbols within the given operation. SymbolTable::SymbolTable(Operation *symbolTableOp) : symbolTableOp(symbolTableOp) { - assert(symbolTableOp->hasTrait() && + assert((symbolTableOp->hasTrait() || + symbolTableOp->hasTrait()) && "expected operation to have SymbolTable trait"); assert(symbolTableOp->getNumRegions() == 1 && "expected operation to have a single region"); @@ -384,7 +387,8 @@ void SymbolTable::walkSymbolTables( /// was found. Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol) { - assert(symbolTableOp->hasTrait()); + assert(symbolTableOp->hasTrait() || + symbolTableOp->hasTrait()); Region ®ion = symbolTableOp->getRegion(0); if (region.empty()) return nullptr; @@ -425,7 +429,8 @@ static LogicalResult lookupSymbolInImpl( return success(); // Verify that the root is also a symbol table. - if (!symbolTableOp->hasTrait()) + if (!symbolTableOp->hasTrait() && + !symbolTableOp->hasTrait()) return failure(); // Otherwise, lookup each of the nested non-leaf references and ensure that @@ -702,7 +707,8 @@ static SmallVector collectSymbolScopes(Operation *symbol, Operation *limitIt = symbol->getParentOp(); for (size_t i = 0, e = references.size(); i != e; ++i, limitIt = limitIt->getParentOp()) { - assert(limitIt->hasTrait()); + assert(limitIt->hasTrait() || + limitIt->hasTrait()); scopes.push_back({references[i], &limitIt->getRegion(0)}); } return scopes; @@ -870,23 +876,27 @@ bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Region *from) { /// Generates a new symbol reference attribute with a new leaf reference. static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr, - FlatSymbolRefAttr newLeafAttr) { + SymbolRefAttr newLeafAttr) { if (llvm::isa(oldAttr)) return newLeafAttr; auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences()); - nestedRefs.back() = newLeafAttr; + nestedRefs.back() = FlatSymbolRefAttr::get(newLeafAttr.getRootReference()); + + nestedRefs.append(newLeafAttr.getNestedReferences().begin(), + newLeafAttr.getNestedReferences().end()); + return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs); } /// The implementation of SymbolTable::replaceAllSymbolUses below. template -static LogicalResult -replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit) { +static LogicalResult replaceAllSymbolUsesImpl(SymbolT symbol, + SymbolRefAttr newSymbol, + IRUnitT *limit) { // Generate a new attribute to replace the given attribute. - FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol); for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) { SymbolRefAttr oldAttr = scope.symbol; - SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr); + SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newSymbol); AttrTypeReplacer replacer; replacer.addReplacement( [&](SymbolRefAttr attr) -> std::pair { @@ -899,11 +909,13 @@ replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit) { auto oldNestedRefs = oldAttr.getNestedReferences(); auto nestedRefs = attr.getNestedReferences(); if (oldNestedRefs.empty()) - return {SymbolRefAttr::get(newSymbol, nestedRefs), - WalkResult::skip()}; + return {newAttr, WalkResult::skip()}; auto newNestedRefs = llvm::to_vector<4>(nestedRefs); - newNestedRefs[oldNestedRefs.size() - 1] = newLeafAttr; + newNestedRefs[oldNestedRefs.size() - 1] = + FlatSymbolRefAttr::get(newAttr.getRootReference()); + newNestedRefs.append(newAttr.getNestedReferences().begin(), + newAttr.getNestedReferences().end()); return {SymbolRefAttr::get(attr.getRootReference(), newNestedRefs), WalkResult::skip()}; } @@ -928,21 +940,37 @@ replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit) { LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from) { - return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from); + auto newSymRef = mlir::FlatSymbolRefAttr::get(newSymbol); + return replaceAllSymbolUsesImpl(oldSymbol, newSymRef, from); } LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol, StringAttr newSymbol, Operation *from) { - return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from); + auto newSymRef = mlir::FlatSymbolRefAttr::get(newSymbol); + return replaceAllSymbolUsesImpl(oldSymbol, newSymRef, from); } LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Region *from) { - return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from); + auto newSymRef = mlir::FlatSymbolRefAttr::get(newSymbol); + return replaceAllSymbolUsesImpl(oldSymbol, newSymRef, from); } LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol, StringAttr newSymbol, Region *from) { + auto newSymRef = mlir::FlatSymbolRefAttr::get(newSymbol); + return replaceAllSymbolUsesImpl(oldSymbol, newSymRef, from); +} + +LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol, + SymbolRefAttr newSymbol, + Operation *from) { + return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from); +} + +LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol, + SymbolRefAttr newSymbol, + Region *from) { return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from); } @@ -1084,7 +1112,7 @@ SymbolUserMap::SymbolUserMap(SymbolTableCollection &symbolTable, } void SymbolUserMap::replaceAllUsesWith(Operation *symbol, - StringAttr newSymbolName) { + SymbolRefAttr newSymbolName) { auto it = symbolToUsers.find(symbol); if (it == symbolToUsers.end()) return; @@ -1111,6 +1139,11 @@ void SymbolUserMap::replaceAllUsesWith(Operation *symbol, } } +void SymbolUserMap::replaceAllUsesWith(Operation *symbol, + StringAttr newSymbolName) { + replaceAllUsesWith(symbol, mlir::FlatSymbolRefAttr::get(newSymbolName)); +} + //===----------------------------------------------------------------------===// // Visibility parsing implementation. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index dafec39fd5eb0..8ba352d2b1ab0 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -80,6 +80,12 @@ void Pass::copyOptionValuesFrom(const Pass *other) { passOptions.copyOptionValuesFrom(other->passOptions); } +/// Copy the option values from 'other', which are PassPipeline options. +/// Here we copy only those options that have the same argument name. +void Pass::copyOptionValuesFrom(const PassOptions &other) { + passOptions.matchAndCopyOptionValuesFrom(other); +} + /// Prints out the pass in the textual representation of pipelines. If this is /// an adaptor pass, print its pass managers. When `pretty` is true, the /// printed pipeline is formatted for readability. diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp index 7c294f08a32bb..ed5ab073b9159 100644 --- a/mlir/lib/Pass/PassRegistry.cpp +++ b/mlir/lib/Pass/PassRegistry.cpp @@ -277,6 +277,20 @@ void detail::PassOptions::copyOptionValuesFrom(const PassOptions &other) { std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt)); } +/// Copy only those options that have the same argument name. +void detail::PassOptions::matchAndCopyOptionValuesFrom( + const PassOptions &other) { + for (auto *optionsIt : other.options) { + const auto &it = llvm::find_if(options, [&](OptionBase *option) { + return option->getArgStr() == optionsIt->getArgStr(); + }); + + if (it != options.end()) { + (*it)->copyValueFrom(*optionsIt); + } + } +} + /// Parse in the next argument from the given options string. Returns a tuple /// containing [the key of the option, the value of the option, updated /// `options` string pointing after the parsed option]. diff --git a/mlir/lib/Support/Timing.cpp b/mlir/lib/Support/Timing.cpp index ac16eb7d224c9..e1e51965354b0 100644 --- a/mlir/lib/Support/Timing.cpp +++ b/mlir/lib/Support/Timing.cpp @@ -621,11 +621,17 @@ void mlir::applyDefaultTimingManagerCLOptions(DefaultTimingManager &tm) { return; tm.setEnabled(options->timing); tm.setDisplayMode(options->displayMode); + tm.setOutput(createOutputStrategy(options->outputFormat, llvm::errs())); +} - std::unique_ptr printer; - if (options->outputFormat == OutputFormat::Text) - printer = std::make_unique(llvm::errs()); - else if (options->outputFormat == OutputFormat::Json) - printer = std::make_unique(llvm::errs()); - tm.setOutput(std::move(printer)); +std::unique_ptr +mlir::createOutputStrategy(DefaultTimingManager::OutputFormat fmt, + raw_ostream &os) { + switch (fmt) { + case OutputFormat::Text: + return std::make_unique(os); + case OutputFormat::Json: + return std::make_unique(os); + } + llvm_unreachable("Invalid output format"); } diff --git a/mlir/lib/Transforms/Utils/Inliner.cpp b/mlir/lib/Transforms/Utils/Inliner.cpp index b639e87f52744..6ee71de0f6fdb 100644 --- a/mlir/lib/Transforms/Utils/Inliner.cpp +++ b/mlir/lib/Transforms/Utils/Inliner.cpp @@ -687,7 +687,9 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode); // then erase the call. - call.erase(); + const auto *callInterface = + inlinerIface.getInterfaceFor(call->getDialect()); + callInterface->eraseCall(call); // If we inlined in place, mark the node for deletion. if (inlineInPlace) { diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp index eeb40529cc2fe..2b0a7e277aace 100644 --- a/mlir/lib/Transforms/Utils/InliningUtils.cpp +++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp @@ -91,6 +91,19 @@ static void remapInlinedOperands(iterator_range inlinedBlocks, block.walk(remapOperands); } +//===----------------------------------------------------------------------===// +// DialectInlinerInterface +//===----------------------------------------------------------------------===// + +void DialectInlinerInterface::eraseCall(Operation *call) const { + call->erase(); +} + +std::tuple +DialectInlinerInterface::getInlineBlockAndPoint(Operation *call) const { + return std::make_tuple(call->getBlock(), std::next(call->getIterator())); +} + //===----------------------------------------------------------------------===// // InlinerInterface //===----------------------------------------------------------------------===// @@ -548,9 +561,11 @@ LogicalResult mlir::inlineCall( if (!interface.isLegalToInline(call, callable, shouldCloneInlinedRegion)) return cleanupState(); + auto [inlineBlock, inlinePoint] = callInterface->getInlineBlockAndPoint(call); + // Attempt to inline the call. - if (failed(inlineRegionImpl(interface, cloneCallback, src, call->getBlock(), - ++call->getIterator(), mapper, callResults, + if (failed(inlineRegionImpl(interface, cloneCallback, src, inlineBlock, + inlinePoint, mapper, callResults, callableResultTypes, call.getLoc(), shouldCloneInlinedRegion, call))) return cleanupState(); diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir index 2efb5893c8511..eb0093106dc11 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir @@ -810,3 +810,59 @@ module @inner_module { return %t : tensor<5xf32> } } + +// ----- + +// CHECK: func.func @custom_types( +// CHECK-SAME: %[[arg:.*]]: !test.test_memref<[4, 4], f64> +// CHECK-SAME: ) -> (!test.test_memref<[4, 8], f64>, +// CHECK-SAME: !test.test_memref<[4, 8], f64>) +func.func @custom_types(%arg: !test.test_tensor<[4, 4], f64>) + -> (!test.test_tensor<[4, 8], f64>, !test.test_tensor<[4, 8], f64>) { + // CHECK: %[[out1:.*]] = "test.dummy_memref_op"(%[[arg]]) : + // CHECK-SAME: (!test.test_memref<[4, 4], f64>) -> !test.test_memref<[4, 8], f64> + %out1 = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>) + -> !test.test_tensor<[4, 8], f64> + + // CHECK: %[[alloc:.*]] = "test.create_memref_op" + // CHECK: %[[out2:.*]] = "test.dummy_memref_op"(%[[alloc]]) + // CHECK-SAME: (!test.test_memref<[4, 4], f64>) -> !test.test_memref<[4, 8], f64> + %alloc = "test.create_tensor_op"() : () -> !test.test_tensor<[4, 4], f64> + %out2 = "test.dummy_tensor_op"(%alloc) : (!test.test_tensor<[4, 4], f64>) + -> !test.test_tensor<[4, 8], f64> + + // CHECK: return %[[out1]], %[[out2]] + return %out1, %out2 : + !test.test_tensor<[4, 8], f64>, !test.test_tensor<[4, 8], f64> +} + +// ----- + +// CHECK: func.func @custom_types_foo( +// CHECK-SAME: %[[arg:.*]]: !test.test_memref<[4, 4], f64> +// CHECK-SAME: ) -> !test.test_memref<[4, 4], f64> +func.func @custom_types_foo(%arg: !test.test_tensor<[4, 4], f64>) + -> !test.test_tensor<[4, 4], f64> { + // CHECK: %[[out:.*]] = "test.dummy_memref_op"(%[[arg]]) + %out = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>) + -> !test.test_tensor<[4, 4], f64> + // CHECK: return %[[out]] + return %out : !test.test_tensor<[4, 4], f64> +} + +// CHECK: func.func @custom_types_bar( +// CHECK-SAME: %[[arg:.*]]: !test.test_memref<[4, 4], f64> +// CHECK-SAME: ) -> !test.test_memref<[4, 8], f64> +func.func @custom_types_bar(%arg: !test.test_tensor<[4, 4], f64>) + -> !test.test_tensor<[4, 8], f64> { + // CHECK: %[[call:.*]] = call @custom_types_foo(%[[arg]]) + %call = func.call @custom_types_foo(%arg) : (!test.test_tensor<[4, 4], f64>) + -> !test.test_tensor<[4, 4], f64> + + // CHECK: %[[out:.*]] = "test.dummy_memref_op"(%[[call]]) + %out = "test.dummy_tensor_op"(%call) : (!test.test_tensor<[4, 4], f64>) + -> !test.test_tensor<[4, 8], f64> + + // CHECK: return %[[out]] + return %out : !test.test_tensor<[4, 8], f64> +} diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir index 2c8807b66de74..9884b040119d0 100644 --- a/mlir/test/Dialect/Bufferization/invalid.mlir +++ b/mlir/test/Dialect/Bufferization/invalid.mlir @@ -127,3 +127,63 @@ func.func @invalid_manual_deallocation() { // expected-error @below{{op attribute 'bufferization.manual_deallocation' can be used only on ops that have an allocation and/or free side effect}} arith.constant {bufferization.manual_deallocation} 0 : index } + +// ----- + +func.func @invalid_rank_to_buffer(%t: tensor<1x2x3x4xf32>) { + // expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}} + // expected-error @below{{shapes do not match}} + %b = bufferization.to_buffer %t + : tensor<1x2x3x4xf32> to memref<1x2x3xf32> + return +} + +// ----- + +func.func @invalid_rank_to_tensor(%b: memref<1x2x3xf32>) { + // expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}} + // expected-error @below{{shapes do not match}} + %t = bufferization.to_tensor %b + : memref<1x2x3xf32> to tensor<1x2x3x4xf32> + return +} + +// ----- + +func.func @invalid_shape_to_buffer(%t: tensor<1x2x3x4xf32>) { + // expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}} + // expected-error @below{{shapes do not match}} + %b = bufferization.to_buffer %t + : tensor<1x2x3x4xf32> to memref<1x2x4x3xf32> + return +} + +// ----- + +func.func @invalid_shape_to_tensor(%b: memref<1x2x4x3xf32>) { + // expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}} + // expected-error @below{{shapes do not match}} + %t = bufferization.to_tensor %b + : memref<1x2x4x3xf32> to tensor<1x2x3x4xf32> + return +} + +// ----- + +func.func @invalid_type_to_buffer(%t: tensor<1x2x3x4xf32>) { + // expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}} + // expected-error @below{{element types do not match}} + %b = bufferization.to_buffer %t + : tensor<1x2x3x4xf32> to memref<1x2x3x4xf16> + return +} + +// ----- + +func.func @invalid_type_to_tensor(%b: memref<1x2x3x4xf16>) { + // expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}} + // expected-error @below{{element types do not match}} + %t2 = bufferization.to_tensor %b + : memref<1x2x3x4xf16> to tensor<1x2x3x4xf32> + return +} diff --git a/mlir/test/Dialect/Bufferization/ops.mlir b/mlir/test/Dialect/Bufferization/ops.mlir index fc6df4a09f706..b0db1bb2d0389 100644 --- a/mlir/test/Dialect/Bufferization/ops.mlir +++ b/mlir/test/Dialect/Bufferization/ops.mlir @@ -83,3 +83,40 @@ func.func @test_dealloc_op(%arg0: memref<2xf32>, %arg1: memref<4xi32>, bufferization.dealloc return %0#0, %0#1 : i1, i1 } + +// CHECK: func.func @test_builtin_custom_builtin_type_conversion +// CHECK-SAME: (%[[t:.*]]: tensor<42xf32>) -> tensor<42xf32> +func.func @test_builtin_custom_builtin_type_conversion(%t: tensor<42xf32>) + -> tensor<42xf32> { + // CHECK: %[[buffer:.*]] = bufferization.to_buffer %[[t]] + // CHECK-SAME: to !test.test_memref<[42], f32> + %buffer = bufferization.to_buffer %t + : tensor<42xf32> to !test.test_memref<[42], f32> + + // CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[buffer]] + // CHECK-SAME: to tensor<42xf32> + %tensor = bufferization.to_tensor %buffer + : !test.test_memref<[42], f32> to tensor<42xf32> + + // CHECK: return %[[tensor]] + return %tensor : tensor<42xf32> +} + +// CHECK: func.func @test_custom_builtin_custom_type_conversion +// CHECK-SAME: (%[[t:.*]]: !test.test_tensor<[42], f32>) +// CHECK-SAME: -> !test.test_tensor<[42], f32> +func.func @test_custom_builtin_custom_type_conversion(%t: !test.test_tensor<[42], f32>) + -> !test.test_tensor<[42], f32> { + // CHECK: %[[buffer:.*]] = bufferization.to_buffer %[[t]] + // CHECK-SAME: to memref<42xf32> + %buffer = bufferization.to_buffer %t + : !test.test_tensor<[42], f32> to memref<42xf32> + + // CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[buffer]] + // CHECK-SAME: to !test.test_tensor<[42], f32> + %tensor = bufferization.to_tensor %buffer + : memref<42xf32> to !test.test_tensor<[42], f32> + + // CHECK: return %[[tensor]] + return %tensor : !test.test_tensor<[42], f32> +} diff --git a/mlir/test/Dialect/LLVMIR/global.mlir b/mlir/test/Dialect/LLVMIR/global.mlir index 193ab7987a2b6..19a65be62dc92 100644 --- a/mlir/test/Dialect/LLVMIR/global.mlir +++ b/mlir/test/Dialect/LLVMIR/global.mlir @@ -132,7 +132,7 @@ llvm.mlir.global internal constant @constant(37.0) : !llvm.label // ----- func.func @foo() { - // expected-error @+1 {{op symbol's parent must have the SymbolTable trait}} + // expected-error @+1 {{must appear at the module level}} llvm.mlir.global internal @bar(42) : i32 return diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index bd1106e304c60..62cb562e513e8 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -99,20 +99,6 @@ func.func @alloca_non_integer_alignment() { // ----- -func.func @gep_missing_input_result_type(%pos : i64, %base : !llvm.ptr) { - // expected-error@+1 {{number of operands and types do not match: got 2 operands and 0 types}} - llvm.getelementptr %base[%pos] : () -> (), i64 -} - -// ----- - -func.func @gep_missing_input_type(%pos : i64, %base : !llvm.ptr) { - // expected-error@+1 {{number of operands and types do not match: got 2 operands and 0 types}} - llvm.getelementptr %base[%pos] : () -> (!llvm.ptr), i64 -} - -// ----- - func.func @gep_missing_result_type(%pos : i64, %base : !llvm.ptr) { // expected-error@+1 {{op requires one result}} llvm.getelementptr %base[%pos] : (!llvm.ptr, i64) -> (), i64 diff --git a/mlir/test/Dialect/Linalg/transform-op-replace.mlir b/mlir/test/Dialect/Linalg/transform-op-replace.mlir index 1a40912977dec..2801522e81ac2 100644 --- a/mlir/test/Dialect/Linalg/transform-op-replace.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-replace.mlir @@ -12,10 +12,8 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op transform.structured.replace %0 { - builtin.module { - func.func @foo() { - "dummy_op"() : () -> () - } + func.func @foo() { + "dummy_op"() : () -> () } } : (!transform.any_op) -> !transform.any_op transform.yield diff --git a/mlir/test/Dialect/Quant/Bytecode/types.mlir b/mlir/test/Dialect/Quant/Bytecode/types.mlir index 8c79b757eeb19..b2fb2ea9f0e91 100644 --- a/mlir/test/Dialect/Quant/Bytecode/types.mlir +++ b/mlir/test/Dialect/Quant/Bytecode/types.mlir @@ -73,3 +73,59 @@ module @parseUniformSubChannel attributes { // CHECK: !quant.uniform bytecode.test = !quant.uniform } {} + +//===----------------------------------------------------------------------===// +// QuantileQuantized +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: parseQuantilePerLayerFp16 +module @parseQuantilePerLayerFp16 attributes { + // CHECK: !quant.quantile + bytecode.test = !quant.quantile +} {} + +// CHECK-LABEL: parseQuantilePerLayerBf16 +module @parseQuantilePerLayerBf16 attributes { + // CHECK: !quant.quantile + bytecode.test = !quant.quantile +} {} + +// CHECK-LABEL: parseQuantilePerLayerI8 +module @parseQuantilePerLayerI8 attributes { + // CHECK: !quant.quantile + bytecode.test = !quant.quantile +} {} + +// CHECK-LABEL: parseQuantilePerLayerU8 +module @parseQuantilePerLayerU8 attributes { + // CHECK: !quant.quantile + bytecode.test = !quant.quantile +} {} + +//===----------------------------------------------------------------------===// +// QuantileQuantizedPerAxis +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: parseQuantilePerAxisScaleZero +module @parseQuantilePerAxisScaleZero attributes { + // CHECK: !quant.quantile + bytecode.test = !quant.quantile +} {} + +// CHECK-LABEL: parseQuantilePerAxisScaleZeroU4 +module @parseQuantilePerAxisScaleZeroU4 attributes { + // CHECK: !quant.quantile + bytecode.test = !quant.quantile:f16:f32:1, {-1.0000,-0.8667,-0.7333,-0.6000,-0.4667,-0.3333,-0.2000,-0.0667,0.0667,0.2000,0.3333,0.4667,0.6000,0.7333,0.8667,1.0000}:{2.000000e+02:-120,9.987200e-01:127}> +} {} + +// CHECK-LABEL: parseQuantilePerAxisScaleNoZero +module @parseQuantilePerAxisScaleNoZero attributes { + // CHECK: !quant.quantile + bytecode.test = !quant.quantile +} {} + +// CHECK-LABEL: parseQuantilePerAxisMixed +module @parseQuantilePerAxisMixed attributes { + // CHECK: !quant.quantile + bytecode.test = !quant.quantile +} {} diff --git a/mlir/test/Dialect/Quant/parse-any-invalid.mlir b/mlir/test/Dialect/Quant/parse-any-invalid.mlir index 41c5f93070717..7ea4ddc61db89 100644 --- a/mlir/test/Dialect/Quant/parse-any-invalid.mlir +++ b/mlir/test/Dialect/Quant/parse-any-invalid.mlir @@ -17,12 +17,12 @@ // ----- // Unrecognized storage type: illegal prefix -// expected-error@+1 {{illegal storage type prefix}} +// expected-error@+1 {{illegal quantized storage type alias}} !qalias = !quant.any:f32> // ----- // Unrecognized storage type: no width -// expected-error@+1 {{illegal storage type prefix}} +// expected-error@+1 {{illegal quantized storage type alias}} !qalias = !quant.any:f32> // ----- diff --git a/mlir/test/Dialect/Quant/parse-quantile-invalid.mlir b/mlir/test/Dialect/Quant/parse-quantile-invalid.mlir new file mode 100644 index 0000000000000..6e1bacd7d4c4f --- /dev/null +++ b/mlir/test/Dialect/Quant/parse-quantile-invalid.mlir @@ -0,0 +1,198 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +// ----- +// Illegal missing quantileType +// expected-error@+1 {{expected ':'}} +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Illegal quantileType value +// expected-error@+1 {{illegal quantile type alias}} +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Illegal quantile array size +// expected-error@+1 {{quantiles array size needs to be equal to 2^(bit_size(storageType)), or (storageTypeMax - storageTypeMin + 1) when max and min differ from the type limits; expected: 256, found: 2}} +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Illegal quantile array size (per axis type) +// expected-error@+1 {{quantiles array size needs to be equal to 2^(bit_size(storageType)), or (storageTypeMax - storageTypeMin + 1) when max and min differ from the type limits; expected: 256, found: 2}} +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Unrecognized token: trailing +// expected-error@+1 {{expected '>'}} +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.99872:127 23> + +// ----- +// Unrecognized token: missing storage type maximum +// expected-error@+1 {{expected ':'}} +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.99872:127> + +// ----- +// Unrecognized token: missing closing angle bracket +// expected-error@+1 {{unbalanced '<' character in pretty dialect name}} +!qalias = !quant> + +// ----- +// Unrecognized token: missing type colon +// expected-error@+1 {{expected ':'}} +!qalias = !quant.quantilef16:f32, {-1.0,1.0}:0.99872:127> + +// ----- +// Unrecognized token: missing comma +// expected-error@+1 {{expected ','}} +!qalias = !quant.quantile + +// ----- +// Unrecognized storage type: illegal prefix +// expected-error@+1 {{illegal quantized storage type alias}} +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.99872:127> + +// ----- +// Unrecognized storage type: no width +// expected-error@+1 {{illegal quantized storage type alias}} +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.99872:127> + +// ----- +// Unrecognized storage type: storage size > 32 +// expected-error@+1 {{illegal storage type size: 33}} +!qalias = !quant.quantile + +// ----- +// Unrecognized storage type: storage size < 0 +// expected-error@+1 {{illegal quantized storage type alias}} +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.99872:127> + +// ----- +// Unrecognized storage type: storage size +// expected-error@+1 {{invalid integer width}} +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.99872:127> + +// ----- +// Illegal storage min/max: max - min < 0 +// expected-error@+1 {{illegal storage min and storage max: (2:1)}} +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.99872:127> + +// ----- +// Illegal storage min/max: max - min == 0 +// expected-error@+1 {{illegal storage min and storage max: (1:1)}} +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.99872:127> + +// ----- +// Illegal storage min/max: max > defaultMax +// expected-error@+1 {{illegal storage type maximum: 9}} +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.99872:127> + +// ----- +// Illegal storage min/max: min < defaultMin +// expected-error@+1 {{illegal storage type minimum: -9}} +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.99872:127> + +// ----- +// Illegal storage min/max: max > defaultMax +// expected-error@+1 {{illegal storage type maximum: 60000}} +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.99872:127> + +// ----- +// Illegal storage min/max: min < defaultMin +// expected-error@+1 {{illegal storage type minimum: -60000}} +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.99872:127> + +// ----- +// Illegal storage min/max: max > defaultMax +// expected-error@+1 {{illegal storage type maximum: 500}} +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.99872:127> + +// ----- +// Illegal storage min/max: min < defaultMin +// expected-error@+1 {{illegal storage type minimum: -500}} +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.99872:127> + +// ----- +// Illegal storage min/max: max > defaultMax +// expected-error@+1 {{illegal storage type maximum: 10}} +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.99872:127> + +// ----- +// Illegal storage min/max: min < defaultMin +// expected-error@+1 {{illegal storage type minimum: -10}} +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.99872:127> + +// ----- +// Illegal uniform params: invalid scale +// expected-error@+1 {{expected floating point literal}} +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:abc:127> + +// ----- +// Illegal uniform params: invalid zero point separator +// expected-error@+1 {{expected '>'}} +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.1abc> + +// ----- +// Illegal uniform params: missing zero point +// expected-error@+1 {{expected integer value}} +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.1:> + +// ----- +// Illegal uniform params: invalid zero point +// expected-error@+1 {{expected integer value}} +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.1:abc> + +// ----- +// Illegal expressed type: f33 +// expected-error@+1 {{expected non-function type}} +!qalias = !quant.quantile:f16:f33, {-1.0,1.0}:0.99872:127> + +// ----- +// Illegal uniform params: missing quantized dimension +// expected-error@+1 {{expected integer value}} +!qalias = !quant.quantile:f16:f32:, {-1.0,1.0}:{2.000000e+02:-19.987200e-01:1}> + +// ----- +// Illegal uniform params: unspecified quantized dimension, when multiple scales +// provided. +// expected-error@+1 {{expected floating point literal}} +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:{2.000000e+02,-19.987200e-01:1}> + +// ----- +// Illegal quantile params: unspecified quantile values +// expected-error@+1 {{expected floating point literal}} +!qalias = !quant.quantile:f16:f32, {}:0.99872:127> + +// ----- +// Illegal quantile params: missing quantile values +// expected-error@+1 {{expected floating point literal}} +!qalias = !quant.quantile:f16:f32, {-1.0,}:0.99872:127> + +// ----- +// Illegal quantile params: missing colon separator +// expected-error@+1 {{expected ':'}} +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}0.99872:127> + +// ----- +// Illegal quantile params: unbalanced } +// expected-error@+1 {{unbalanced '{' character in pretty dialect name}} +!qalias = !quant.quantile:f16:f32, {-1.0,1.0:0.99872:127> + +// ----- +// Illegal quantile params: missing { +// expected-error@+1 {{unbalanced '<' character in pretty dialect name}} +!qalias = !quant.quantile:f16:f32, -1.0,1.0}:0.99872:127> diff --git a/mlir/test/Dialect/Quant/parse-quantile.mlir b/mlir/test/Dialect/Quant/parse-quantile.mlir new file mode 100644 index 0000000000000..bb20499eb74d8 --- /dev/null +++ b/mlir/test/Dialect/Quant/parse-quantile.mlir @@ -0,0 +1,201 @@ +// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file | FileCheck %s + +// ----- +// All per-layer params specified: +// [signed] storageType, storageTypeMin, storageTypeMax, expressedType, scale, zeroPoint +// CHECK: !quant.quantile:f16:f32, {-1.000000e+00,-8.667000e-01,-7.333000e-01,-6.000000e-01,-4.667000e-01,-3.333000e-01,-2.000000e-01,-0.066699999999999995,0.066699999999999995,2.000000e-01,3.333000e-01,4.667000e-01,6.000000e-01,7.333000e-01,8.667000e-01,1.000000e+00}:9.987200e-01:127> +!qalias = !quant.quantile:f16:f32, {-1.0000,-0.8667,-0.7333,-0.6000,-0.4667,-0.3333,-0.2000,-0.0667,0.0667,0.2000,0.3333,0.4667,0.6000,0.7333,0.8667,1.0000}:0.99872:127> +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Trailing whitespace. +// CHECK: !quant.quantile +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Default min/max value optimization for integers. +// CHECK: !quant.quantile +!qalias = !quant.quantile:f16:f32, {-1.0000,-0.8667,-0.7333,-0.6000,-0.4667,-0.3333,-0.2000,-0.0667,0.0667,0.2000,0.3333,0.4667,0.6000,0.7333,0.8667,1.0000}:0.99872:127 > +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Default min/max value optimization for f8E5M2. +// CHECK: !quant.quantile +!qalias = !quant.quantile:f16:f32, {-1.0000,-0.9922,-0.9843,-0.9765,-0.9686,-0.9608,-0.9529,-0.9451,-0.9373,-0.9294,-0.9216,-0.9137,-0.9059,-0.8980,-0.8902,-0.8824,-0.8745,-0.8667,-0.8588,-0.8510,-0.8431,-0.8353,-0.8275,-0.8196,-0.8118,-0.8039,-0.7961,-0.7882,-0.7804,-0.7725,-0.7647,-0.7569,-0.7490,-0.7412,-0.7333,-0.7255,-0.7176,-0.7098,-0.7020,-0.6941,-0.6863,-0.6784,-0.6706,-0.6627,-0.6549,-0.6471,-0.6392,-0.6314,-0.6235,-0.6157,-0.6078,-0.6000,-0.5922,-0.5843,-0.5765,-0.5686,-0.5608,-0.5529,-0.5451,-0.5373,-0.5294,-0.5216,-0.5137,-0.5059,-0.4980,-0.4902,-0.4824,-0.4745,-0.4667,-0.4588,-0.4510,-0.4431,-0.4353,-0.4275,-0.4196,-0.4118,-0.4039,-0.3961,-0.3882,-0.3804,-0.3725,-0.3647,-0.3569,-0.3490,-0.3412,-0.3333,-0.3255,-0.3176,-0.3098,-0.3020,-0.2941,-0.2863,-0.2784,-0.2706,-0.2627,-0.2549,-0.2471,-0.2392,-0.2314,-0.2235,-0.2157,-0.2078,-0.2000,-0.1922,-0.1843,-0.1765,-0.1686,-0.1608,-0.1529,-0.1451,-0.1373,-0.1294,-0.1216,-0.1137,-0.1059,-0.0980,-0.0902,-0.0824,-0.0745,-0.0667,-0.0588,-0.0510,-0.0431,-0.0353,-0.0275,-0.0196,-0.0118,-0.0039,0.0039,0.0118,0.0196,0.0275,0.0353,0.0431,0.0510,0.0588,0.0667,0.0745,0.0824,0.0902,0.0980,0.1059,0.1137,0.1216,0.1294,0.1373,0.1451,0.1529,0.1608,0.1686,0.1765,0.1843,0.1922,0.2000,0.2078,0.2157,0.2235,0.2314,0.2392,0.2471,0.2549,0.2627,0.2706,0.2784,0.2863,0.2941,0.3020,0.3098,0.3176,0.3255,0.3333,0.3412,0.3490,0.3569,0.3647,0.3725,0.3804,0.3882,0.3961,0.4039,0.4118,0.4196,0.4275,0.4353,0.4431,0.4510,0.4588,0.4667,0.4745,0.4824,0.4902,0.4980,0.5059,0.5137,0.5216,0.5294,0.5373,0.5451,0.5529,0.5608,0.5686,0.5765,0.5843,0.5922,0.6000,0.6078,0.6157,0.6235,0.6314,0.6392,0.6471,0.6549,0.6627,0.6706,0.6784,0.6863,0.6941,0.7020,0.7098,0.7176,0.7255,0.7333,0.7412,0.7490,0.7569,0.7647,0.7725,0.7804,0.7882,0.7961,0.8039,0.8118,0.8196,0.8275,0.8353,0.8431,0.8510,0.8588,0.8667,0.8745,0.8824,0.8902,0.8980,0.9059,0.9137,0.9216,0.9294,0.9373,0.9451,0.9529,0.9608,0.9686,0.9765,0.9843,0.9922,1.0000}:0.99872:127 > +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Default min/max value optimization for f8E4M3FN. +// CHECK: !quant.quantile +!qalias = !quant.quantile:f16:f32, {-1.0000,-0.9922,-0.9843,-0.9765,-0.9686,-0.9608,-0.9529,-0.9451,-0.9373,-0.9294,-0.9216,-0.9137,-0.9059,-0.8980,-0.8902,-0.8824,-0.8745,-0.8667,-0.8588,-0.8510,-0.8431,-0.8353,-0.8275,-0.8196,-0.8118,-0.8039,-0.7961,-0.7882,-0.7804,-0.7725,-0.7647,-0.7569,-0.7490,-0.7412,-0.7333,-0.7255,-0.7176,-0.7098,-0.7020,-0.6941,-0.6863,-0.6784,-0.6706,-0.6627,-0.6549,-0.6471,-0.6392,-0.6314,-0.6235,-0.6157,-0.6078,-0.6000,-0.5922,-0.5843,-0.5765,-0.5686,-0.5608,-0.5529,-0.5451,-0.5373,-0.5294,-0.5216,-0.5137,-0.5059,-0.4980,-0.4902,-0.4824,-0.4745,-0.4667,-0.4588,-0.4510,-0.4431,-0.4353,-0.4275,-0.4196,-0.4118,-0.4039,-0.3961,-0.3882,-0.3804,-0.3725,-0.3647,-0.3569,-0.3490,-0.3412,-0.3333,-0.3255,-0.3176,-0.3098,-0.3020,-0.2941,-0.2863,-0.2784,-0.2706,-0.2627,-0.2549,-0.2471,-0.2392,-0.2314,-0.2235,-0.2157,-0.2078,-0.2000,-0.1922,-0.1843,-0.1765,-0.1686,-0.1608,-0.1529,-0.1451,-0.1373,-0.1294,-0.1216,-0.1137,-0.1059,-0.0980,-0.0902,-0.0824,-0.0745,-0.0667,-0.0588,-0.0510,-0.0431,-0.0353,-0.0275,-0.0196,-0.0118,-0.0039,0.0039,0.0118,0.0196,0.0275,0.0353,0.0431,0.0510,0.0588,0.0667,0.0745,0.0824,0.0902,0.0980,0.1059,0.1137,0.1216,0.1294,0.1373,0.1451,0.1529,0.1608,0.1686,0.1765,0.1843,0.1922,0.2000,0.2078,0.2157,0.2235,0.2314,0.2392,0.2471,0.2549,0.2627,0.2706,0.2784,0.2863,0.2941,0.3020,0.3098,0.3176,0.3255,0.3333,0.3412,0.3490,0.3569,0.3647,0.3725,0.3804,0.3882,0.3961,0.4039,0.4118,0.4196,0.4275,0.4353,0.4431,0.4510,0.4588,0.4667,0.4745,0.4824,0.4902,0.4980,0.5059,0.5137,0.5216,0.5294,0.5373,0.5451,0.5529,0.5608,0.5686,0.5765,0.5843,0.5922,0.6000,0.6078,0.6157,0.6235,0.6314,0.6392,0.6471,0.6549,0.6627,0.6706,0.6784,0.6863,0.6941,0.7020,0.7098,0.7176,0.7255,0.7333,0.7412,0.7490,0.7569,0.7647,0.7725,0.7804,0.7882,0.7961,0.8039,0.8118,0.8196,0.8275,0.8353,0.8431,0.8510,0.8588,0.8667,0.8745,0.8824,0.8902,0.8980,0.9059,0.9137,0.9216,0.9294,0.9373,0.9451,0.9529,0.9608,0.9686,0.9765,0.9843,0.9922,1.0000}:0.99872:127 > +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Default min/max value optimization for f4E2M1FN. +// CHECK: !quant.quantile +!qalias = !quant.quantile:f16:f32, {-1.0000,-0.8667,-0.7333,-0.6000,-0.4667,-0.3333,-0.2000,-0.0667,0.0667,0.2000,0.3333,0.4667,0.6000,0.7333,0.8667,1.0000}:0.99872:127 > +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Required per-layer params specified: +// [unsigned] storageType, expressedType, scale +// CHECK: !quant.quantile +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Exponential scale (-) +// CHECK: !quant.quantile +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Exponential scale (+) +// CHECK: !quant.quantile +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Storage type: f8E5M2 +// CHECK: !quant.quantile +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Storage type: f8E4M3FN +// CHECK: !quant.quantile +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Storage type: f4E2M1FN +// CHECK: !quant.quantile +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Expressed type: f32 +// CHECK: !quant.quantile +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Expressed type: f32 +// CHECK: !quant.quantile +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Expressed type: f16 +// CHECK: !quant.quantile +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Expressed type: f64 +// CHECK: !quant.quantile +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Expressed type: bf16 +// CHECK: !quant.quantile +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Per-axis scales and zero points (affine) +// CHECK: !quant.quantile +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Per-axis scales and no zero points (fixedpoint) +// CHECK: !quant.quantile +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Per-axis scales and zero points (mixed affine and fixedpoint) +// CHECK: !quant.quantile +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Negative scale checking +// CHECK: !quant.quantile +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Per-axis negative scale checking +// CHECK: !quant.quantile +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} diff --git a/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir b/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir index 3b358443e43f2..03e929ff58d8e 100644 --- a/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir +++ b/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir @@ -37,12 +37,12 @@ // ----- // Unrecognized storage type: illegal prefix -// expected-error@+1 {{illegal storage type prefix}} +// expected-error@+1 {{illegal quantized storage type alias}} !qalias = !quant.uniform:f32, 0.99872:127> // ----- // Unrecognized storage type: no width -// expected-error@+1 {{illegal storage type prefix}} +// expected-error@+1 {{illegal quantized storage type alias}} !qalias = !quant.uniform:f32, 0.99872:127> // ----- @@ -52,7 +52,7 @@ // ----- // Unrecognized storage type: storage size < 0 -// expected-error@+1 {{illegal storage type prefix}} +// expected-error@+1 {{illegal quantized storage type alias}} !qalias = !quant.uniform:f32, 0.99872:127> // ----- @@ -80,6 +80,36 @@ // expected-error@+1 {{illegal storage type minimum: -9}} !qalias = !quant.uniform:f32, 0.99872:127> +// ----- +// Illegal storage min/max: max > defaultMax +// expected-error@+1 {{illegal storage type maximum: 60000}} +!qalias = !quant.uniform:f32, 0.99872:127> + +// ----- +// Illegal storage min/max: min < defaultMin +// expected-error@+1 {{illegal storage type minimum: -60000}} +!qalias = !quant.uniform:f32, 0.99872:127> + +// ----- +// Illegal storage min/max: max > defaultMax +// expected-error@+1 {{illegal storage type maximum: 500}} +!qalias = !quant.uniform:f32, 0.99872:127> + +// ----- +// Illegal storage min/max: min < defaultMin +// expected-error@+1 {{illegal storage type minimum: -500}} +!qalias = !quant.uniform:f32, 0.99872:127> + +// ----- +// Illegal storage min/max: max > defaultMax +// expected-error@+1 {{illegal storage type maximum: 10}} +!qalias = !quant.uniform:f32, 0.99872:127> + +// ----- +// Illegal storage min/max: min < defaultMin +// expected-error@+1 {{illegal storage type minimum: -10}} +!qalias = !quant.uniform:f32, 0.99872:127> + // ----- // Illegal uniform params: invalid scale // expected-error@+1 {{expected floating point literal}} @@ -105,11 +135,6 @@ // expected-error@+1 {{expected non-function type}} !qalias = !quant.uniform:f33, 0.99872:127> -// ----- -// Illegal scale: negative -// expected-error@+1 {{scale -1.000000e+00 out of expressed type range}} -!qalias = !quant.uniform:f32, -1.0:127> - // ----- // Illegal uniform params: missing quantized dimension // expected-error@+1 {{expected integer value}} @@ -126,26 +151,26 @@ // expected-error@+1 {{illegal quantized dimension: -1}} !qalias = !quant.uniform -// ----- -// Scale f16 underflow -// expected-error@+1 {{scale 5.800000e-08 out of expressed type range}} -!qalias = !quant.uniform - // ----- // Scale f16 overflow // expected-error@+1 {{scale 6.600000e+04 out of expressed type range}} !qalias = !quant.uniform // ----- -// Scale f16 underflow in per-axis quantization -// expected-error@+1 {{scale 5.800000e-08 out of expressed type range}} -!qalias = !quant.uniform +// Scale f16 negative overflow +// expected-error@+1 {{scale -6.600000e+04 out of expressed type range}} +!qalias = !quant.uniform // ----- // Scale f16 overflow in per-axis quantization // expected-error@+1 {{scale 6.600000e+04 out of expressed type range}} !qalias = !quant.uniform +// ----- +// Scale f16 negative overflow in per-axis quantization +// expected-error@+1 {{scale -6.600000e+04 out of expressed type range}} +!qalias = !quant.uniform + // ----- // Illegal negative axis in sub-channel quantization // expected-error@+1 {{illegal quantized dimension: -1}} @@ -233,5 +258,5 @@ // Scale out of expressed type range in sub-channel quantization // expected-error@+2 {{scale 6.600000e+04 out of expressed type range}} !qalias = !quant.uniform + {{6.6e4:120,9.987200e-01:127}, {2.000000e+02:255,9.987200e-01}}> diff --git a/mlir/test/Dialect/Quant/parse-uniform.mlir b/mlir/test/Dialect/Quant/parse-uniform.mlir index 80a6621ed6979..e7cee25c6413d 100644 --- a/mlir/test/Dialect/Quant/parse-uniform.mlir +++ b/mlir/test/Dialect/Quant/parse-uniform.mlir @@ -19,6 +19,42 @@ func.func @parse() -> !qalias { return %0 : !qalias } +// ----- +// Default min/max value optimization for integers. +// CHECK: !quant.uniform +!qalias = !quant.uniform:f32, 0.99872:127 > +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Default min/max value optimization for f8E5M2. +// CHECK: !quant.uniform +!qalias = !quant.uniform:f32, 0.99872:127 > +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Default min/max value optimization for f8E4M3FN. +// CHECK: !quant.uniform +!qalias = !quant.uniform:f32, 0.99872:127 > +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Default min/max value optimization for f4E2M1FN. +// CHECK: !quant.uniform +!qalias = !quant.uniform:f32, 0.99872:127 > +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + // ----- // Required per-layer params specified: // [unsigned] storageType, expressedType, scale @@ -47,6 +83,33 @@ func.func @parse() -> !qalias { return %0 : !qalias } +// ----- +// Storage type: f8E5M2 +// CHECK: !quant.uniform +!qalias = !quant.uniform +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Storage type: f8E4M3FN +// CHECK: !quant.uniform +!qalias = !quant.uniform +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Storage type: f4E2M1FN +// CHECK: !quant.uniform +!qalias = !quant.uniform +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + // ----- // Storage type: i16 // CHECK: !quant.uniform @@ -172,3 +235,21 @@ func.func @parse() -> !qalias { %0 = "foo"() : () -> !qalias return %0 : !qalias } + +// ----- +// Negative scale checking +// CHECK: !quant.uniform +!qalias = !quant.uniform +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Per axis negative scale checking +// CHECK: !quant.uniform +!qalias = !quant.uniform +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} diff --git a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir index f5d6a08b7de31..515de5502f322 100644 --- a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir +++ b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir @@ -86,3 +86,47 @@ func.func @while(%arg0: tensor<1024xf32, #SparseVector>, %c: i1) -> tensor<1024x } return %0: tensor<1024xf32, #SparseVector> } + +// CHECK-LABEL: func.func @index_switch( +// CHECK-SAME: %[[PRED:.*0]]: index, +// CHECK-SAME: %[[VAL_A_1:.*1]]: memref, +// CHECK-SAME: %[[VAL_A_2:.*2]]: memref, +// CHECK-SAME: %[[VAL_A_3:.*3]]: memref, +// CHECK-SAME: %[[VAL_A_4:.*4]]: !sparse_tensor.storage_specifier +// CHECK-SAME: %[[VAL_B_1:.*5]]: memref, +// CHECK-SAME: %[[VAL_B_2:.*6]]: memref, +// CHECK-SAME: %[[VAL_B_3:.*7]]: memref, +// CHECK-SAME: %[[VAL_B_4:.*8]]: !sparse_tensor.storage_specifier +// CHECK-SAME: %[[VAL_C_1:.*9]]: memref, +// CHECK-SAME: %[[VAL_C_2:.*10]]: memref, +// CHECK-SAME: %[[VAL_C_3:.*11]]: memref, +// CHECK-SAME: %[[VAL_C_4:.*12]]: !sparse_tensor.storage_specifier + +// CHECK: %[[RES:.*]]:4 = scf.index_switch %[[PRED]] +// CHECK-SAME: -> memref, memref, memref, !sparse_tensor.storage_specifier +// CHECK: case 1 { +// CHECK: scf.yield %[[VAL_A_1]], %[[VAL_A_2]], %[[VAL_A_3]], %[[VAL_A_4]] +// CHECK: case 2 { +// CHECK: scf.yield %[[VAL_B_1]], %[[VAL_B_2]], %[[VAL_B_3]], %[[VAL_B_4]] +// CHECK: default { +// CHECK: scf.yield %[[VAL_C_1]], %[[VAL_C_2]], %[[VAL_C_3]], %[[VAL_C_4]] + +// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3 : +// CHECK-SAME: memref, memref, memref, !sparse_tensor.storage_specifier + +func.func @index_switch(%pred: index, %a: tensor<5xf32, #SparseVector>, + %b: tensor<5xf32, #SparseVector>, + %c: tensor<5xf32, #SparseVector>) -> tensor<5xf32, #SparseVector> { + %0 = scf.index_switch %pred -> tensor<5xf32, #SparseVector> + case 1 { + scf.yield %a : tensor<5xf32, #SparseVector> + } + case 2 { + scf.yield %b : tensor<5xf32, #SparseVector> + } + default { + scf.yield %c : tensor<5xf32, #SparseVector> + } + + return %0 : tensor<5xf32, #SparseVector> +} diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir index 71a260f1196e9..1e2f7ad69ab8b 100644 --- a/mlir/test/Dialect/Transform/ops-invalid.mlir +++ b/mlir/test/Dialect/Transform/ops-invalid.mlir @@ -482,9 +482,10 @@ module { // ----- module attributes { transform.with_named_sequence} { + // expected-note @below {{ancestor transform op}} transform.sequence failures(suppress) { ^bb0(%arg0: !transform.any_op): - // expected-error @below {{op symbol's parent must have the SymbolTable trai}} + // expected-error @below {{cannot be defined inside another transform op}} transform.named_sequence @nested() { transform.yield } diff --git a/mlir/test/IR/invalid-func-op.mlir b/mlir/test/IR/invalid-func-op.mlir index 8fd7af22e9598..d995689ebb8d0 100644 --- a/mlir/test/IR/invalid-func-op.mlir +++ b/mlir/test/IR/invalid-func-op.mlir @@ -31,7 +31,7 @@ func.func @func_op() { // ----- func.func @func_op() { - // expected-error@+1 {{op symbol's parent must have the SymbolTable trait}} + // expected-error@+1 {{entry block must have 1 arguments to match function signature}} func.func @mixed_named_arguments(f32) { ^entry: return @@ -42,7 +42,7 @@ func.func @func_op() { // ----- func.func @func_op() { - // expected-error@+1 {{op symbol's parent must have the SymbolTable trait}} + // expected-error@+1 {{type of entry block argument #0('i32') must match the type of the corresponding argument in function signature('f32')}} func.func @mixed_named_arguments(f32) { ^entry(%arg : i32): return diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index d0bbf8669b63d..d4a035654d34a 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -699,3 +699,19 @@ func.func @drop_references_on_block_parse_error(){ }) : () -> () return } + +// ----- + +func.func @foo() { + cf.br ^bb2 + + ^bb1: + // expected-error@+1 {{forward reference of value '%1' requires explicit type specification}} + test.format_operand_optional_type_op %0, %1 + return + + ^bb2: + %0 = arith.constant 0 : i64 + %1 = memref.alloc() : memref<1xf64> + cf.br ^bb1 +} diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index 3bb6e38b4d613..a99f1cab01354 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -1484,3 +1484,13 @@ test.dialect_custom_format_fallback custom_format_fallback // Check that an op with an optional result parses f80 as type. // CHECK: test.format_optional_result_d_op : f80 test.format_optional_result_d_op : f80 + +// Can skip type definition for operands, if they are already defined in the same block +// CHECK-LABEL: func @optional_operand_types +func.func @optional_operand_types(%arg0: i64, %arg1: memref<1xf64>) { + // CHECK: test.format_operand_optional_type_op %arg0, %arg1 + test.format_operand_optional_type_op %arg0, %arg1 + // CHECK: test.format_operand_optional_type_op %arg0, %arg1 + test.format_operand_optional_type_op %arg0, %arg1 : memref<1xf64> + return +} diff --git a/mlir/test/IR/region.mlir b/mlir/test/IR/region.mlir index e2088817c5204..58e410af41c72 100644 --- a/mlir/test/IR/region.mlir +++ b/mlir/test/IR/region.mlir @@ -87,17 +87,18 @@ func.func @named_region_has_wrong_number_of_blocks() { // CHECK: test.single_no_terminator_op "test.single_no_terminator_op"() ( { - %foo = arith.constant 1 : i32 + func.func @foo1() { return } + func.func @foo2() { return } } ) : () -> () // CHECK: test.variadic_no_terminator_op "test.variadic_no_terminator_op"() ( { - %foo = arith.constant 1 : i32 + func.func @foo1() { return } }, { - %bar = arith.constant 1 : i32 + func.func @foo2() { return } } ) : () -> () diff --git a/mlir/test/IR/test-symbol-rauw.mlir b/mlir/test/IR/test-symbol-rauw.mlir index ba17cf9d10426..1acd834fcde8e 100644 --- a/mlir/test/IR/test-symbol-rauw.mlir +++ b/mlir/test/IR/test-symbol-rauw.mlir @@ -51,11 +51,12 @@ module { } } - // CHECK: func @symbol_bar + // FIXME:#73140 + // DISABLED-CHECK: func @symbol_bar func.func @symbol_bar() { - // CHECK: foo.op - // CHECK-SAME: use_1 = @module_a::@replaced_foo - // CHECK-SAME: use_2 = @replaced_module_b::@replaced_module_c::@replaced_foo + // DISABLED-CHECK: foo.op + // DISABLED-CHECK-SAME: use_1 = @module_a::@replaced_foo + // DISABLED-CHECK-SAME: use_2 = @replaced_module_b::@replaced_module_c::@replaced_foo "foo.op"() { use_1 = @module_a::@foo, use_2 = @module_b::@module_c::@foo @@ -97,15 +98,16 @@ module { // ----- +// FIXME:#73140 module { - // CHECK: module @replaced_foo + // DISABLED-CHECK: module @replaced_foo module @foo attributes {sym.new_name = "replaced_foo" } { - // CHECK: func.func private @foo + // DISABLED-CHECK: func.func private @foo func.func private @foo() } - // CHECK: foo.op - // CHECK-SAME: use = @replaced_foo::@foo + // DISABLED-CHECK: foo.op + // DISABLED-CHECK-SAME: use = @replaced_foo::@foo "foo.op"() { use = @foo::@foo } : () -> () diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir index 49cfd7e496746..85deff0386900 100644 --- a/mlir/test/IR/traits.mlir +++ b/mlir/test/IR/traits.mlir @@ -591,13 +591,15 @@ func.func @failedHasDominanceScopeOutsideDominanceFreeScope() -> () { // Ensure that SSACFG regions of operations in GRAPH regions are // checked for dominance -func.func @illegalInsideDominanceFreeScope(%cond: i1) -> () { +func.func @illegalInsideDominanceFreeScope() -> () { test.graph_region { - scf.if %cond { + func.func @test() -> i1 { + ^bb1: // expected-error @+1 {{operand #0 does not dominate this use}} %2:3 = "bar"(%1) : (i64) -> (i1,i1,i1) // expected-note @+1 {{operand defined here}} - %1 = "baz"(%2#0) : (i1) -> (i64) + %1 = "baz"(%2#0) : (i1) -> (i64) + return %2#1 : i1 } "terminator"() : () -> () } @@ -608,21 +610,20 @@ func.func @illegalInsideDominanceFreeScope(%cond: i1) -> () { // Ensure that SSACFG regions of operations in GRAPH regions are // checked for dominance -func.func @illegalCFGInsideDominanceFreeScope(%cond: i1) -> () { +func.func @illegalCDFGInsideDominanceFreeScope() -> () { test.graph_region { - scf.if %cond { - "test.ssacfg_region"() ({ - ^bb1: - // expected-error @+1 {{operand #0 does not dominate this use}} - %2:3 = "bar"(%1) : (i64) -> (i1,i1,i1) - cf.br ^bb4 - ^bb2: - cf.br ^bb2 - ^bb4: - %1 = "foo"() : ()->i64 // expected-note {{operand defined here}} - }) : () -> () + func.func @test() -> i1 { + ^bb1: + // expected-error @+1 {{operand #0 does not dominate this use}} + %2:3 = "bar"(%1) : (i64) -> (i1,i1,i1) + cf.br ^bb4 + ^bb2: + cf.br ^bb2 + ^bb4: + %1 = "foo"() : ()->i64 // expected-note {{operand defined here}} + return %2#1 : i1 } - "terminator"() : () -> () + "terminator"() : () -> () } return } diff --git a/mlir/test/Transforms/canonicalize-dce.mlir b/mlir/test/Transforms/canonicalize-dce.mlir index 84631947970de..e8608a3aa3e69 100644 --- a/mlir/test/Transforms/canonicalize-dce.mlir +++ b/mlir/test/Transforms/canonicalize-dce.mlir @@ -77,15 +77,15 @@ func.func @f(%arg0: f32, %pred: i1) { // Test case: Recursively DCE into enclosed regions. -// CHECK: func.func @f(%arg0: f32) -// CHECK-NOT: arith.addf +// CHECK: func @f(%arg0: f32) +// CHECK-NEXT: func @g(%arg1: f32) +// CHECK-NEXT: return func.func @f(%arg0: f32) { - "test.region"() ( - { - %0 = "arith.addf"(%arg0, %arg0) : (f32, f32) -> f32 - } - ) : () -> () + func.func @g(%arg1: f32) { + %0 = "arith.addf"(%arg1, %arg1) : (f32, f32) -> f32 + return + } return } diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir index 8e02c06a0a293..eff2cd609b51c 100644 --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -429,15 +429,16 @@ func.func @write_only_alloca_fold(%v: f32) { // CHECK-LABEL: func @dead_block_elim func.func @dead_block_elim() { // CHECK-NOT: ^bb - builtin.module { - func.func @nested() { - return + func.func @nested() { + return - ^bb1: - return - } + ^bb1: + return } return + +^bb1: + return } // CHECK-LABEL: func @dyn_shape_fold(%arg0: index, %arg1: index) diff --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir index 0b393bf0556b9..eb530dace319f 100644 --- a/mlir/test/Transforms/constant-fold.mlir +++ b/mlir/test/Transforms/constant-fold.mlir @@ -794,15 +794,12 @@ func.func @cmpf_inf() -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1 // CHECK-LABEL: func @nested_isolated_region func.func @nested_isolated_region() { - // CHECK-NEXT: builtin.module { // CHECK-NEXT: func @isolated_op // CHECK-NEXT: arith.constant 2 - builtin.module { - func.func @isolated_op() { - %0 = arith.constant 1 : i32 - %2 = arith.addi %0, %0 : i32 - "foo.yield"(%2) : (i32) -> () - } + func.func @isolated_op() { + %0 = arith.constant 1 : i32 + %2 = arith.addi %0, %0 : i32 + "foo.yield"(%2) : (i32) -> () } // CHECK: "foo.unknown_region" diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir index b447094874d01..84fb9c4591de5 100644 --- a/mlir/test/Transforms/cse.mlir +++ b/mlir/test/Transforms/cse.mlir @@ -252,14 +252,11 @@ func.func @nested_isolated() -> i32 { // CHECK-NEXT: arith.constant 1 %0 = arith.constant 1 : i32 - // CHECK-NEXT: builtin.module // CHECK-NEXT: @nested_func - builtin.module { - func.func @nested_func() { - // CHECK-NEXT: arith.constant 1 - %foo = arith.constant 1 : i32 - "foo.yield"(%foo) : (i32) -> () - } + func.func @nested_func() { + // CHECK-NEXT: arith.constant 1 + %foo = arith.constant 1 : i32 + "foo.yield"(%foo) : (i32) -> () } // CHECK: "foo.region" diff --git a/mlir/test/Transforms/test-legalizer-full.mlir b/mlir/test/Transforms/test-legalizer-full.mlir index 5f1148cac6501..74f312e8144a0 100644 --- a/mlir/test/Transforms/test-legalizer-full.mlir +++ b/mlir/test/Transforms/test-legalizer-full.mlir @@ -37,11 +37,9 @@ func.func @recursively_legal_invalid_op() { } /// Operation that is dynamically legal, i.e. the function has a pattern /// applied to legalize the argument type before it becomes recursively legal. - builtin.module { - func.func @dynamic_func(%arg: i64) attributes {test.recursively_legal} { - %ignored = "test.illegal_op_f"() : () -> (i32) - "test.return"() : () -> () - } + func.func @dynamic_func(%arg: i64) attributes {test.recursively_legal} { + %ignored = "test.illegal_op_f"() : () -> (i32) + "test.return"() : () -> () } "test.return"() : () -> () diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td index 382da592d0079..f98250c40f384 100644 --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -347,6 +347,7 @@ def TestCopyCount : Test_Attr<"TestCopyCount"> { let mnemonic = "copy_count"; let parameters = (ins TestParamCopyCount:$copy_count); let assemblyFormat = "`<` $copy_count `>`"; + let genVerifyDecl = 1; } def TestConditionalAliasAttr : Test_Attr<"TestConditionalAlias"> { @@ -438,4 +439,10 @@ def TestCustomStorageCtorAttr : Test_Attr<"TestCustomStorageCtorAttr"> { let hasStorageCustomConstructor = 1; } +def TestAttrNewlineAndIndent : Test_Attr<"TestAttrNewlineAndIndent"> { + let mnemonic = "newline_and_indent"; + let parameters = (ins "::mlir::Type":$indentType); + let hasCustomAssemblyFormat = 1; +} + #endif // TEST_ATTRDEFS diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp index b31e90fc9ca91..8e28447a6ec61 100644 --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -213,6 +213,16 @@ static void printTrueFalse(AsmPrinter &p, std::optional result) { p << (*result ? "true" : "false"); } +//===----------------------------------------------------------------------===// +// TestCopyCountAttr Implementation +//===----------------------------------------------------------------------===// + +LogicalResult TestCopyCountAttr::verify( + llvm::function_ref<::mlir::InFlightDiagnostic()> /*emitError*/, + CopyCount /*copy_count*/) { + return success(); +} + //===----------------------------------------------------------------------===// // CopyCountAttr Implementation //===----------------------------------------------------------------------===// @@ -420,6 +430,30 @@ bool TestConstMemorySpaceAttr::isValidPtrIntCast( return false; } +//===----------------------------------------------------------------------===// +// TestAttrNewlineAndIndent +//===----------------------------------------------------------------------===// + +Attribute TestAttrNewlineAndIndentAttr::parse(::mlir::AsmParser &parser, + ::mlir::Type type) { + Type indentType; + if (parser.parseLess() || parser.parseType(indentType) || + parser.parseGreater()) { + return Attribute(); + } + return get(parser.getContext(), indentType); +} + +void TestAttrNewlineAndIndentAttr::print(::mlir::AsmPrinter &printer) const { + printer << "<"; + printer.increaseIndent(); + printer.printNewline(); + printer << getIndentType(); + printer.decreaseIndent(); + printer.printNewline(); + printer << ">"; +} + //===----------------------------------------------------------------------===// // Tablegen Generated Definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp b/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp index 6d4e5e3cb4401..e8ec888502e53 100644 --- a/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp +++ b/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp @@ -153,6 +153,12 @@ static ParseResult parseCustomDirectiveOptionalOperandRef( bool expectedOptionalOperand = operandCount == 0; return success(expectedOptionalOperand != optOperand.has_value()); } +static ParseResult parseOptionalType(OpAsmParser &parser, Type &type) { + if (parser.parseOptionalColon()) + return success(); + + return parser.parseType(type); +} //===----------------------------------------------------------------------===// // Printing @@ -234,6 +240,21 @@ static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, Value optOperand) { printer << (optOperand ? "1" : "0"); } +static bool isDefinedAbove(Value val, Operation *op) { + if (mlir::isa(val)) + return true; + + return val.getDefiningOp()->getBlock() == op->getBlock() && + val.getDefiningOp()->isBeforeInBlock(op); +} +static void printOptionalType(OpAsmPrinter &printer, + FormatOperandOptionalTypeOp op, Type type) { + if (isDefinedAbove(op.getOperand(), op)) + return; + + printer << ":"; + printer.printType(type); +} //===----------------------------------------------------------------------===// // Test parser. //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOpsSyntax.td b/mlir/test/lib/Dialect/Test/TestOpsSyntax.td index 9a96fe8b11c14..3af6f187c4f3b 100644 --- a/mlir/test/lib/Dialect/Test/TestOpsSyntax.td +++ b/mlir/test/lib/Dialect/Test/TestOpsSyntax.td @@ -309,6 +309,9 @@ def FormatOperandDOp : FormatOperandBase<"d", [{ def FormatOperandEOp : FormatOperandBase<"e", [{ $buildable `,` $operand `:` type($buildable) `,` type($operand) attr-dict }]>; +def FormatOperandOptionalTypeOp : FormatOperandBase<"optional_type", [{ + $buildable `,` $operand custom(type($operand)) attr-dict +}]>; def FormatSuccessorAOp : TEST_Op<"format_successor_a_op", [Terminator]> { let successors = (successor VariadicSuccessor:$targets); diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td index ea20597231d58..070ad0e6e56c4 100644 --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -470,4 +470,9 @@ def TestMemrefType : Test_Type<"TestMemref", }]; } +def TestTypeNewlineAndIndent : Test_Type<"TestTypeNewlineAndIndent"> { + let mnemonic = "newline_and_indent"; + let hasCustomAssemblyFormat = 1; +} + #endif // TEST_TYPEDEFS diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp index bea043f56fe21..7d33b89e76760 100644 --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -565,11 +565,39 @@ TestTensorType::getBufferType( ::mlir::LogicalResult TestTensorType::verifyCompatibleBufferType( ::mlir::bufferization::BufferLikeType bufferType, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { - auto testMemref = dyn_cast(bufferType); - if (!testMemref) - return emitError() << "expected TestMemrefType"; + if (auto testMemref = dyn_cast(bufferType)) { + const bool valid = getShape() == testMemref.getShape() && + getElementType() == testMemref.getElementType(); + return mlir::success(valid); + } + + if (auto builtinMemref = dyn_cast(bufferType)) { + const bool valid = getShape() == builtinMemref.getShape() && + getElementType() == builtinMemref.getElementType(); + return mlir::success(valid); + } - const bool valid = getShape() == testMemref.getShape() && - getElementType() == testMemref.getElementType(); - return mlir::success(valid); + return emitError() << "expected MemRefType or TestMemrefType"; +} + +//===----------------------------------------------------------------------===// +// TestTypeNewlineAndIndent +//===----------------------------------------------------------------------===// + +Type TestTypeNewlineAndIndentType::parse(::mlir::AsmParser &parser) { + if (parser.parseLess() || parser.parseKeyword("indented_content") || + parser.parseGreater()) { + return Type(); + } + return get(parser.getContext()); +} + +void TestTypeNewlineAndIndentType::print(::mlir::AsmPrinter &printer) const { + printer << "<"; + printer.increaseIndent(); + printer.printNewline(); + printer << "indented_content"; + printer.decreaseIndent(); + printer.printNewline(); + printer << ">"; } diff --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td index d47411d6e860a..a809611fd0aec 100644 --- a/mlir/test/mlir-tblgen/attrdefs.td +++ b/mlir/test/mlir-tblgen/attrdefs.td @@ -115,6 +115,11 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> { // DEF: return new (allocator.allocate()) // DEF-SAME: CompoundAAttrStorage(std::move(widthOfSomething), std::move(exampleTdType), std::move(apFloat), std::move(dims), std::move(inner)); +// DEF: CompoundAAttr CompoundAAttr::getChecked( +// DEF-SAME: int widthOfSomething, ::test::SimpleTypeA exampleTdType, ::llvm::APFloat apFloat, ::llvm::ArrayRef dims, ::mlir::Type inner +// DEF-SAME: ) +// DEF-NEXT: return Base::getChecked(emitError, context, std::move(widthOfSomething), std::move(exampleTdType), std::move(apFloat), std::move(dims), std::move(inner)); + // DEF: ::mlir::Type CompoundAAttr::getInner() const { // DEF-NEXT: return getImpl()->inner; } diff --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td index f213f50ae2f39..351f11f8c9082 100644 --- a/mlir/test/mlir-tblgen/op-decl-and-defs.td +++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td @@ -236,14 +236,14 @@ def NS_FOp : NS_Op<"op_with_all_types_constraint", // DEFS: FOp FOp::create(::mlir::OpBuilder &builder, ::mlir::Location location, ::mlir::Value a) { // DEFS: ::mlir::OperationState __state__(location, getOperationName()); -// DEFS: build(builder, __state__, a); +// DEFS: build(builder, __state__, std::forward(a)); // DEFS: auto __res__ = ::llvm::dyn_cast(builder.create(__state__)); // DEFS: assert(__res__ && "builder didn't return the right type"); // DEFS: return __res__; // DEFS: } // DEFS: FOp FOp::create(::mlir::ImplicitLocOpBuilder &builder, ::mlir::Value a) { -// DEFS: return create(builder, builder.getLoc(), a); +// DEFS: return create(builder, builder.getLoc(), std::forward(a)); // DEFS: } def NS_GOp : NS_Op<"op_with_fixed_return_type", []> { diff --git a/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir index 89ad3594eebd8..dca46a30e37f1 100644 --- a/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir +++ b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s +// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s --strict-whitespace // CHECK-LABEL: func private @compoundA() // CHECK-SAME: #test.cmpnd_a<1, !test.smpla, [5, 6]> @@ -44,3 +44,19 @@ func.func private @hexdecimalInteger() attributes { // expected-error @below {{expected an integer}} sdg = #test.decimal_shape<1x0xb> } + +// ----- + +// CHECK-LABEL: @newlineAndIndent +// CHECK-SAME: indent = #test.newline_and_indent< +// CHECK-NEXT: {{^ }}!test.newline_and_indent< +// CHECK-NEXT: {{^ }}indented_content +// CHECK-NEXT: {{^ }}> +// CHECK-NEXT: {{^ }}> +func.func private @newlineAndIndent() attributes { + indent = #test.newline_and_indent< + !test.newline_and_indent< + indented_content + > + > +} diff --git a/mlir/test/mlir-tblgen/testdialect-typedefs.mlir b/mlir/test/mlir-tblgen/testdialect-typedefs.mlir index 18175edc81cf0..c00d368fdab8b 100644 --- a/mlir/test/mlir-tblgen/testdialect-typedefs.mlir +++ b/mlir/test/mlir-tblgen/testdialect-typedefs.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s | mlir-opt -verify-diagnostics | FileCheck %s +// RUN: mlir-opt %s | mlir-opt -verify-diagnostics | FileCheck %s --strict-whitespace ////////////// // Tests the types in the 'Test' dialect, not the ones in 'typedefs.mlir' @@ -42,3 +42,13 @@ func.func @testInt(%A : !test.int, %B : !test.int, %C : !test func.func @structTest (%A : !test.struct< {field1, !test.smpla}, {field2, !test.int} > ) { return } + +// CHECK-LABEL: @newlineAndIndent +// CHECK-SAME: !test.newline_and_indent< +// CHECK-NEXT: {{^ }}indented_content +// CHECK-NEXT: {{^ }}> +func.func @newlineAndIndent(%A : !test.newline_and_indent< + indented_content +>) { + return +} diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py index 4a241afb8e89d..5f30da5a61f63 100644 --- a/mlir/test/python/ir/value.py +++ b/mlir/test/python/ir/value.py @@ -237,15 +237,28 @@ def testValuePrintAsOperand(): print(value2) topFn = func.FuncOp("test", ([i32, i32], [])) - entry_block = Block.create_at_start(topFn.operation.regions[0], [i32, i32]) + entry_block1 = Block.create_at_start(topFn.operation.regions[0], [i32, i32]) - with InsertionPoint(entry_block): + with InsertionPoint(entry_block1): value3 = Operation.create("custom.op3", results=[i32]).results[0] # CHECK: Value(%[[VAL3:.*]] = "custom.op3"() : () -> i32) print(value3) value4 = Operation.create("custom.op4", results=[i32]).results[0] # CHECK: Value(%[[VAL4:.*]] = "custom.op4"() : () -> i32) print(value4) + + f = func.FuncOp("test", ([i32, i32], [])) + entry_block2 = Block.create_at_start(f.operation.regions[0], [i32, i32]) + with InsertionPoint(entry_block2): + value5 = Operation.create("custom.op5", results=[i32]).results[0] + # CHECK: Value(%[[VAL5:.*]] = "custom.op5"() : () -> i32) + print(value5) + value6 = Operation.create("custom.op6", results=[i32]).results[0] + # CHECK: Value(%[[VAL6:.*]] = "custom.op6"() : () -> i32) + print(value6) + + func.ReturnOp([]) + func.ReturnOp([]) # CHECK: %[[VAL1]] @@ -272,10 +285,20 @@ def testValuePrintAsOperand(): # CHECK: %1 print(value4.get_name(use_local_scope=True)) + # CHECK: %[[VAL5]] + print(value5.get_name()) + # CHECK: %[[VAL6]] + print(value6.get_name()) + # CHECK: %[[ARG0:.*]] - print(entry_block.arguments[0].get_name()) + print(entry_block1.arguments[0].get_name()) # CHECK: %[[ARG1:.*]] - print(entry_block.arguments[1].get_name()) + print(entry_block1.arguments[1].get_name()) + + # CHECK: %[[ARG2:.*]] + print(entry_block2.arguments[0].get_name()) + # CHECK: %[[ARG3:.*]] + print(entry_block2.arguments[1].get_name()) # CHECK: module { # CHECK: %[[VAL1]] = "custom.op1"() : () -> i32 @@ -283,6 +306,11 @@ def testValuePrintAsOperand(): # CHECK: func.func @test(%[[ARG0]]: i32, %[[ARG1]]: i32) { # CHECK: %[[VAL3]] = "custom.op3"() : () -> i32 # CHECK: %[[VAL4]] = "custom.op4"() : () -> i32 + # CHECK: func @test(%[[ARG2]]: i32, %[[ARG3]]: i32) { + # CHECK: %[[VAL5]] = "custom.op5"() : () -> i32 + # CHECK: %[[VAL6]] = "custom.op6"() : () -> i32 + # CHECK: return + # CHECK: } # CHECK: return # CHECK: } # CHECK: } diff --git a/mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp b/mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp index 9ec9b72150207..9271bc4bd7dcf 100644 --- a/mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp +++ b/mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp @@ -72,6 +72,12 @@ int main(int argc, char **argv) { "Name of the macro to be defined -- ignored by mlir-src-sharder"), llvm::cl::value_desc("macro name"), llvm::cl::Prefix); + // CMake/TableGen pass this flag, re-registering after ResetCommandLineParser + // avoids "unknown argument" errors. + llvm::cl::opt noWarnOnUnusedTemplateArg( + "no-warn-on-unused-template-args", + llvm::cl::desc("Disable unused template argument warnings.")); + llvm::InitLLVM y(argc, argv); llvm::cl::ParseCommandLineOptions(argc, argv); diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp index dbae2143b920a..3140f12c0b7e8 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -495,7 +495,7 @@ void DefGen::emitCheckedBuilder() { MethodBody &body = m->body().indent(); auto scope = body.scope("return Base::getChecked(emitError, context", ");"); for (const auto ¶m : params) - body << ", " << param.getName(); + body << ", std::move(" << param.getName() << ")"; } static SmallVector diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp index 10a162f81ba13..ace73db03f569 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -958,9 +958,7 @@ void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx, void DefFormat::genWhitespacePrinter(WhitespaceElement *el, FmtContext &ctx, MethodBody &os) { if (el->getValue() == "\\n") { - // FIXME: The newline should be `printer.printNewLine()`, i.e., handled by - // the printer. - os << tgfmt("$_printer << '\\n';\n", &ctx); + os << tgfmt("$_printer.printNewline();\n", &ctx); } else if (!el->getValue().empty()) { os << tgfmt("$_printer << \"$0\";\n", &ctx, el->getValue()); } else { diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index f35cfa6826388..b5dc6202724b6 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -2616,7 +2616,14 @@ void OpEmitter::genInlineCreateBody( std::string nonBuilderStateArgs = ""; if (!nonBuilderStateArgsList.empty()) { llvm::raw_string_ostream nonBuilderStateArgsOS(nonBuilderStateArgs); - interleaveComma(nonBuilderStateArgsList, nonBuilderStateArgsOS); + interleave( + nonBuilderStateArgsList, + [&](StringRef name) { + nonBuilderStateArgsOS << "std::forward(" + << name << ')'; + }, + [&] { nonBuilderStateArgsOS << ", "; }); + nonBuilderStateArgs = ", " + nonBuilderStateArgs; } cWithLoc->body() << llvm::formatv(inlineCreateBody, locParamName, diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp index a55592db7132d..fd40404bf3008 100644 --- a/mlir/unittests/IR/AttributeTest.cpp +++ b/mlir/unittests/IR/AttributeTest.cpp @@ -477,8 +477,9 @@ TEST(SubElementTest, Nested) { {strAttr, trueAttr, falseAttr, boolArrayAttr, dictAttr})); } -// Test how many times we call copy-ctor when building an attribute. -TEST(CopyCountAttr, CopyCount) { +// Test how many times we call copy-ctor when building an attribute with the +// 'get' method. +TEST(CopyCountAttr, CopyCountGet) { MLIRContext context; context.loadDialect(); @@ -489,15 +490,35 @@ TEST(CopyCountAttr, CopyCount) { test::CopyCount::counter = 0; test::TestCopyCountAttr::get(&context, std::move(copyCount)); #ifndef NDEBUG - // One verification enabled only in assert-mode requires a copy. - EXPECT_EQ(counter1, 1); - EXPECT_EQ(test::CopyCount::counter, 1); + // One verification enabled only in assert-mode requires two copies: one for + // calling 'verifyInvariants' and one for calling 'verify' inside + // 'verifyInvariants'. + EXPECT_EQ(counter1, 2); + EXPECT_EQ(test::CopyCount::counter, 2); #else EXPECT_EQ(counter1, 0); EXPECT_EQ(test::CopyCount::counter, 0); #endif } +// Test how many times we call copy-ctor when building an attribute with the +// 'getChecked' method. +TEST(CopyCountAttr, CopyCountGetChecked) { + MLIRContext context; + context.loadDialect(); + test::CopyCount::counter = 0; + test::CopyCount copyCount("hello"); + auto loc = UnknownLoc::get(&context); + test::TestCopyCountAttr::getChecked(loc, &context, std::move(copyCount)); + int counter1 = test::CopyCount::counter; + test::CopyCount::counter = 0; + test::TestCopyCountAttr::getChecked(loc, &context, std::move(copyCount)); + // The verifiers require two copies: one for calling 'verifyInvariants' and + // one for calling 'verify' inside 'verifyInvariants'. + EXPECT_EQ(counter1, 2); + EXPECT_EQ(test::CopyCount::counter, 2); +} + // Test stripped printing using test dialect attribute. TEST(CopyCountAttr, PrintStripped) { MLIRContext context; diff --git a/mlir/unittests/Pass/CMakeLists.txt b/mlir/unittests/Pass/CMakeLists.txt index a47d2eead6180..e5dfd75606458 100644 --- a/mlir/unittests/Pass/CMakeLists.txt +++ b/mlir/unittests/Pass/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_unittest(MLIRPassTests AnalysisManagerTest.cpp PassManagerTest.cpp PassPipelineParserTest.cpp + PassPipelineOptionsTest.cpp ) mlir_target_link_libraries(MLIRPassTests PRIVATE diff --git a/mlir/unittests/Pass/PassPipelineOptionsTest.cpp b/mlir/unittests/Pass/PassPipelineOptionsTest.cpp new file mode 100644 index 0000000000000..2e586c3c40887 --- /dev/null +++ b/mlir/unittests/Pass/PassPipelineOptionsTest.cpp @@ -0,0 +1,136 @@ +//===- PassPipelineParserTest.cpp - Pass Parser unit tests ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "gtest/gtest.h" + +#include + +using namespace mlir; +using namespace mlir::detail; + +namespace { + +// these types are used for automatically generated code of pass +using StrPassOpt = ::mlir::Pass::Option; +using IntPassOpt = ::mlir::Pass::Option; +using BoolPassOpt = ::mlir::Pass::Option; + +// these types are used for pipeline options that we manually pass to the +// constructor +using StrOption = mlir::detail::PassOptions::Option; +using IntOption = mlir::detail::PassOptions::Option; +using BoolOption = mlir::detail::PassOptions::Option; + +const int intOptDefaultVal = 5; +const bool boolOptDefaultVal = true; + +struct SimplePassWithOptions + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SimplePassWithOptions) + + SimplePassWithOptions() = default; + SimplePassWithOptions(const SimplePassWithOptions &other) + : PassWrapper(other) {} + + SimplePassWithOptions(const detail::PassOptions &options) { + copyOptionValuesFrom(options); + } + + LogicalResult initialize(MLIRContext *ctx) final { return success(); } + + void runOnOperation() override {} + +public: + StrPassOpt strOpt{*this, "str-opt", ::llvm::cl::desc("string test option"), + llvm::cl::init("")}; + IntPassOpt intOpt{*this, "int-opt", ::llvm::cl::desc("int test option"), + llvm::cl::init(intOptDefaultVal)}; + BoolPassOpt boolOpt{*this, "bool-opt", ::llvm::cl::desc("bool test option"), + llvm::cl::init(boolOptDefaultVal)}; +}; + +TEST(PassPipelineOptionsTest, CopyAllOptions) { + struct DuplicatedOtions : ::mlir::PassPipelineOptions { + StrOption strOpt{*this, "str-opt", ::llvm::cl::desc("string test option")}; + IntOption intOpt{*this, "int-opt", ::llvm::cl::desc("int test option"), + llvm::cl::init(intOptDefaultVal)}; + BoolOption boolOpt{*this, "bool-opt", ::llvm::cl::desc("bool test option"), + llvm::cl::init(boolOptDefaultVal)}; + }; + + const auto expectedStrVal = "test1"; + const auto expectedIntVal = -intOptDefaultVal; + const auto expectedBoolVal = !boolOptDefaultVal; + + DuplicatedOtions options; + options.strOpt.setValue(expectedStrVal); + options.intOpt.setValue(expectedIntVal); + options.boolOpt.setValue(expectedBoolVal); + + const auto &pass = std::make_unique(options); + + EXPECT_EQ(pass->strOpt.getValue(), expectedStrVal); + EXPECT_EQ(pass->intOpt.getValue(), expectedIntVal); + EXPECT_EQ(pass->boolOpt.getValue(), expectedBoolVal); +} + +TEST(PassPipelineOptionsTest, CopyMatchedOptions) { + struct SomePipelineOptions + : ::mlir::PassPipelineOptions { + StrOption strOpt{*this, "str-opt", ::llvm::cl::desc("string test option")}; + IntOption intOpt{*this, "int-opt", ::llvm::cl::desc("int test option")}; + StrOption anotherStrOpt{ + *this, "another-str-pipeline-opt", + ::llvm::cl::desc("there is no such option in SimplePassWithOptions"), + llvm::cl::init("anotherOptVal")}; + IntOption anotherIntOpt{ + *this, "another-int-pipeline-opt", + ::llvm::cl::desc("there is no such option in SimplePassWithOptions"), + llvm::cl::init(10)}; + }; + + const auto expectedStrVal = "test2"; + const auto expectedIntVal = -intOptDefaultVal; + + SomePipelineOptions options; + options.strOpt.setValue(expectedStrVal); + options.intOpt.setValue(expectedIntVal); + + const auto pass = std::make_unique(options); + + EXPECT_EQ(pass->strOpt.getValue(), expectedStrVal); + EXPECT_EQ(pass->intOpt.getValue(), expectedIntVal); + EXPECT_EQ(pass->boolOpt.getValue(), boolOptDefaultVal); +} + +TEST(PassPipelineOptionsTest, NoMatchedOptions) { + struct SomePipelineOptions + : ::mlir::PassPipelineOptions { + StrOption anotherStrOpt{ + *this, "another-str-pipeline-opt", + ::llvm::cl::desc("there is no such option in SimplePassWithOptions"), + llvm::cl::init("anotherOptVal")}; + IntOption anotherIntOpt{ + *this, "another-int-pipeline-opt", + ::llvm::cl::desc("there is no such option in SimplePassWithOptions"), + llvm::cl::init(10)}; + }; + + SomePipelineOptions options; + const auto pass = std::make_unique(options); + + EXPECT_EQ(pass->strOpt.getValue(), ""); + EXPECT_EQ(pass->intOpt.getValue(), intOptDefaultVal); + EXPECT_EQ(pass->boolOpt.getValue(), boolOptDefaultVal); +} + +} // namespace