diff --git a/BUILD.bazel b/BUILD.bazel
index b7d1ab59c9..c1c964c963 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -1183,6 +1183,7 @@ cc_library(
         ":chlo_ops",
         ":chlo_rewriters_inc_gen",
         ":stablehlo_aggressive_simplification_inc_gen",
+        ":stablehlo_broadcast_lowering",
         ":stablehlo_create_compatibility_expander_inc_gen",
         ":stablehlo_create_complex_math_expander_inc_gen",
         ":stablehlo_legalize_deprecated_ops_inc_gen",
@@ -1922,6 +1923,24 @@ cc_library(
     ],
 )
 
+cc_test(
+    name = "chlo_builder_test",
+    srcs = ["stablehlo/integrations/cpp/builder/ChloBuilderTest.cpp"],
+    deps = [
+        ":attr_type_builder_util",
+        ":chlo_builder",
+        ":func_builder",
+        ":mlir_builder",
+        ":register",
+        ":stablehlo_builder",
+        ":stablehlo_ops",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//third-party/unittest:gmock",
+        "@llvm-project//third-party/unittest:gtest",
+    ],
+)
+
 gentbl_cc_library(
     name = "func_builder_inc",
     tbl_outs = {
diff --git a/stablehlo/dialect/Base.cpp b/stablehlo/dialect/Base.cpp
index 5b6d19a776..13bf183897 100644
--- a/stablehlo/dialect/Base.cpp
+++ b/stablehlo/dialect/Base.cpp
@@ -29,6 +29,7 @@ limitations under the License.
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "mlir/Dialect/Quant/IR/QuantTypes.h"
@@ -781,6 +782,14 @@ bool isValidQuantizedDimension(Type type) {
           numScales == rankedType.getDimSize(quantDim));
 }
 
+bool isBoundedDynamic(Type type) {
+  RankedTensorType rankedType = dyn_cast<RankedTensorType>(type);
+  if (!rankedType) return false;
+  auto boundedAttr =
+      mlir::dyn_cast_if_present(rankedType.getEncoding());
+  return boundedAttr != nullptr;
+}
+
 bool hasSingleBoundedDimension(Type type) {
   RankedTensorType rankedType = dyn_cast<RankedTensorType>(type);
   auto boundedAttr =
diff --git a/stablehlo/dialect/Base.h b/stablehlo/dialect/Base.h
index e32def8026..05868dbb01 100644
--- a/stablehlo/dialect/Base.h
+++ b/stablehlo/dialect/Base.h
@@ -101,6 +101,9 @@ bool isValidStablehloQuantizedElementType(Type elementType);
 // mentioned in the StableHLO specification.
 bool isValidQuantizedDimension(Type type);
 
+// Returns true if the given type is a bounded dynamic tensor.
+bool isBoundedDynamic(Type type);
+
 // Returns true if the given type has a single bounded dimension.
 bool hasSingleBoundedDimension(Type type);
 
@@ -135,19 +138,19 @@ FailureOr inferMostSpecificType(std::optional location,
 LogicalResult inferMostSpecificTypeComponents(
     std::optional location, TypeRange inputTypes,
-    SmallVectorImpl &inferredReturnShapes);
+    SmallVectorImpl& inferredReturnShapes);
 
 // Matches a constant with integer value into int64_t.
-LogicalResult matchInt(Value value, int64_t &result);
+LogicalResult matchInt(Value value, int64_t& result);
 
 // Matches a constant tensor with integer values into a 1-dimensional vector.
 // Doesn't preserve the bitness or the signedness of the underlying values,
 // extracting them into int64_t.
-LogicalResult matchInts(Value value, SmallVector &result);
+LogicalResult matchInts(Value value, SmallVector& result);
 
 // Matches a constant tensor with integer values into a 1-dimensional vector.
 // Preserves the bitness and the signedness of the underlying values.
-LogicalResult matchInts(Value value, SmallVector &result);
+LogicalResult matchInts(Value value, SmallVector& result);
 
 // Matches a constant tensor with integer values.
// Unlike the functions above, it doesn't return these values - it just checks @@ -166,8 +169,8 @@ LogicalResult matchInts(Value value); // // and returns %4 as the shape value. LogicalResult deriveShapeFromOperand( - OpBuilder *builder, Operation *op, Value operand, - SmallVectorImpl *reifiedReturnShapes); + OpBuilder* builder, Operation* op, Value operand, + SmallVectorImpl* reifiedReturnShapes); // Type derivation function that returns a tensor type with a new element type. ShapedType getSameShapeTensorType(ShapedType shapedType, Type elementType); @@ -199,15 +202,15 @@ Attribute boundsToEncoding(Attribute prototype, ArrayRef bounds); // If the attribute is valid but not all shape operands are constants, // returns failure. LogicalResult getShapeRefinements( - std::optional location, Operation *operation, - SmallVector &refinements); + std::optional location, Operation* operation, + SmallVector& refinements); // For each type in `types`, recursively flatten tuple types into `result`. // Result is populated via in-order traversal of tuple types in `types`, i.e.: // * Flattenings of individual types from `types` follow one another in the // same order as `types`. // * Same for flattenings of element types of tuple types. -void flattenTupleTypes(TypeRange types, SmallVector &result); +void flattenTupleTypes(TypeRange types, SmallVector& result); // Does the inverse of `flattenTupleTypes` - takes `types` and recursively // unflattens it, creating tuple types as needed to exactly match the structure @@ -215,7 +218,7 @@ void flattenTupleTypes(TypeRange types, SmallVector &result); // Fails if the number of elements in flattened prototype is different from // the number of elements in types. LogicalResult unflattenTupleTypes(TypeRange prototype, TypeRange types, - SmallVector &result); + SmallVector& result); ShapedType createShapedType(ShapedTypeComponents components); @@ -224,7 +227,7 @@ ShapedType createShapedType(ShapedTypeComponents components); // prettyprinting logic between them. class HloDialectInterface : public DialectInterface::Base { public: - HloDialectInterface(Dialect *dialect) : Base(dialect) {} + HloDialectInterface(Dialect* dialect) : Base(dialect) {} // Creates a TokenType type, specific to this dialect. // See docs for the particular type in the corresponding dialect. @@ -283,8 +286,8 @@ namespace bytecode { // Note this may cause issues if enums use an int64_t and have a large value. // All enums in StableHLO and CHLO currently use uint32_t. template -EnumTypeAttr readEnumAttribute(DialectBytecodeReader &reader, - MLIRContext *context, SymbolizeFn symbolizeFn) { +EnumTypeAttr readEnumAttribute(DialectBytecodeReader& reader, + MLIRContext* context, SymbolizeFn symbolizeFn) { uint64_t code; if (failed(reader.readVarInt(code))) return EnumTypeAttr(); @@ -295,7 +298,7 @@ EnumTypeAttr readEnumAttribute(DialectBytecodeReader &reader, } template -void writeEnumAttribute(EnumTypeAttr val, DialectBytecodeWriter &writer) { +void writeEnumAttribute(EnumTypeAttr val, DialectBytecodeWriter& writer) { static_assert( std::is_same::type, uint32_t>::value, @@ -311,7 +314,7 @@ void writeEnumAttribute(EnumTypeAttr val, DialectBytecodeWriter &writer) { // shape operands. The last `count` operands are assumed to be shape operands. // To be speculatable, such an op must have only static inputs and constant // shape operands. 
-mlir::Speculation::Speculatability getShapedSpeculatability(Operation *op, +mlir::Speculation::Speculatability getShapedSpeculatability(Operation* op, int64_t shapeCount); // Applies `fn` to `type` if it is not a `tuple` type. Otherwise, applies `fn` @@ -334,7 +337,7 @@ class PairwiseSameOperandAndResultType : public mlir::OpTrait::TraitBase { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyTrait(Operation* op) { const int numOperands = op->getNumOperands(); const int numResults = op->getNumResults(); if (numOperands != numResults) { @@ -358,7 +361,7 @@ class PairwiseSameOperandAndResultElementType : public mlir::OpTrait::TraitBase { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyTrait(Operation* op) { const int numOperands = op->getNumOperands(); const int numResults = op->getNumResults(); if (numOperands != numResults) { @@ -383,7 +386,7 @@ class CompatibleOperandsAndResultElementType : public mlir::OpTrait::TraitBase { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyTrait(Operation* op) { Type expected; if (op->getNumResults() != 0) expected = op->getResult(0).getType(); if (op->getNumOperands() != 0) expected = op->getOperand(0).getType(); @@ -408,7 +411,7 @@ class CompatibleOperandsElementType : public mlir::OpTrait::TraitBase { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyTrait(Operation* op) { if (failed(mlir::OpTrait::impl::verifyAtLeastNOperands(op, 1))) return failure(); @@ -431,7 +434,7 @@ class CompatibleOperandsAndResultType : public mlir::OpTrait::TraitBase { public: - static LogicalResult verifyTrait(Operation *op) { + static LogicalResult verifyTrait(Operation* op) { Type expected; if (op->getNumResults() != 0) expected = op->getResult(0).getType(); if (op->getNumOperands() != 0) expected = op->getOperand(0).getType(); @@ -451,10 +454,10 @@ class CompatibleOperandsAndResultType } static LogicalResult inferReturnTypes( - MLIRContext * /*context*/, std::optional location, + MLIRContext* /*context*/, std::optional location, ValueRange operands, DictionaryAttr /*attributes*/, OpaqueProperties /*properties*/, RegionRange /*regions*/, - SmallVectorImpl &inferredReturnTypes) { + SmallVectorImpl& inferredReturnTypes) { // TODO(b/231358795): Review the use of InferTypeOpInterface for ops that // support quantization or sparsity. if (operands.empty()) @@ -473,10 +476,10 @@ class CompatibleOperandsAndResultType // It needs to be paired with INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS // (see examples in StablehloOps.cpp). static LogicalResult inferReturnTypeComponentsFromOperands( - MLIRContext *context, std::optional location, + MLIRContext* context, std::optional location, ValueShapeRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnShapes) { + SmallVectorImpl& inferredReturnShapes) { SmallVector inferredReturnTypes; if (failed(inferReturnTypes(context, location, operands.getValues(), attributes, properties, regions, diff --git a/stablehlo/dialect/ChloOps.cpp b/stablehlo/dialect/ChloOps.cpp index 8a6d7c6cf1..fbe15b1e66 100644 --- a/stablehlo/dialect/ChloOps.cpp +++ b/stablehlo/dialect/ChloOps.cpp @@ -365,11 +365,14 @@ LogicalResult ConstantLikeOp::inferReturnTypeComponents( Type elementType = op.getValue().getType(); Type operandType = op.getOperand().getType(); if (isa(operandType)) { + // TODO(b/326463552): Remove unranked dynamism from CHLO. 
     inferredReturnShapes.emplace_back(elementType);
-  } else {
-    const auto& shape = cast(operandType).getShape();
-    inferredReturnShapes.emplace_back(shape, elementType);
+    return success();
   }
+  auto rankedType = cast<RankedTensorType>(operandType);
+  const auto& shape = rankedType.getShape();
+  Attribute encoding = rankedType.getEncoding();
+  inferredReturnShapes.emplace_back(shape, elementType, encoding);
   return success();
 }
diff --git a/stablehlo/integrations/cpp/builder/CMakeLists.txt b/stablehlo/integrations/cpp/builder/CMakeLists.txt
index 28444f4ca3..cda63bf274 100644
--- a/stablehlo/integrations/cpp/builder/CMakeLists.txt
+++ b/stablehlo/integrations/cpp/builder/CMakeLists.txt
@@ -137,6 +137,7 @@ if (TARGET llvm_gtest)
   set_target_properties(check-stablehlo-ci PROPERTIES FOLDER "Tests")
   add_unittest(check-stablehlo-ci "unittests"
     MlirBuilderTest.cpp
+    ChloBuilderTest.cpp
     StablehloBuilderTest.cpp
     AttrTypeBuilderUtilTest.cpp
   )
diff --git a/stablehlo/integrations/cpp/builder/ChloBuilder.cpp b/stablehlo/integrations/cpp/builder/ChloBuilder.cpp
index 761b1a9c44..9f5d0c0d5d 100644
--- a/stablehlo/integrations/cpp/builder/ChloBuilder.cpp
+++ b/stablehlo/integrations/cpp/builder/ChloBuilder.cpp
@@ -31,5 +31,15 @@ namespace chlo {
 
 #include "stablehlo/integrations/cpp/builder/ChloBuilder.cpp.inc"
 
+/////////////////
+// MANUAL APIs
+/////////////////
+
+MlirOp ConstantLike(MlirOp input, DenseElementsAttr val) {
+  MlirBuilder& builder = input.getBuilder();
+  auto splat_val = val.getSplatValue();
+  return builder.create(splat_val, input.getValue());
+}
+
 } // namespace chlo
 } // namespace mlir
diff --git a/stablehlo/integrations/cpp/builder/ChloBuilder.h b/stablehlo/integrations/cpp/builder/ChloBuilder.h
index d21aefe048..328ea5dd59 100644
--- a/stablehlo/integrations/cpp/builder/ChloBuilder.h
+++ b/stablehlo/integrations/cpp/builder/ChloBuilder.h
@@ -19,6 +19,7 @@ limitations under the License.
 #include 
 
 #include "llvm/ADT/SmallVector.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "stablehlo/dialect/ChloOps.h"
 #include "stablehlo/integrations/cpp/builder/MlirBuilder.h"
 
@@ -31,6 +32,12 @@ namespace chlo {
 
 #include "stablehlo/integrations/cpp/builder/ChloBuilder.h.inc"
 
+/////////////////
+// MANUAL APIs
+/////////////////
+
+MlirOp ConstantLike(MlirOp input, DenseElementsAttr val);
+
 } // namespace chlo
 } // namespace mlir
diff --git a/stablehlo/integrations/cpp/builder/ChloBuilderTest.cpp b/stablehlo/integrations/cpp/builder/ChloBuilderTest.cpp
new file mode 100644
index 0000000000..408b21dcde
--- /dev/null
+++ b/stablehlo/integrations/cpp/builder/ChloBuilderTest.cpp
@@ -0,0 +1,141 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include + +#include "gtest/gtest.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Support/DebugStringHelper.h" +#include "mlir/Support/LLVM.h" +#include "stablehlo/dialect/Register.h" +#include "stablehlo/integrations/cpp/builder/AttrTypeBuilderUtil.h" +#include "stablehlo/integrations/cpp/builder/ChloBuilder.h" +#include "stablehlo/integrations/cpp/builder/FuncBuilder.h" +#include "stablehlo/integrations/cpp/builder/MlirBuilder.h" +#include "stablehlo/integrations/cpp/builder/StablehloBuilder.h" + +namespace mlir { +namespace chlo { + +namespace { + +// Wrap a module builder and register the classes needed +class ChloModuleBuilder { + public: + ChloModuleBuilder() + : context_(), module_builder_(context_, mlir::unknownLoc(context_)) { + DialectRegistry registry; + stablehlo::registerAllDialects(registry); + context_.appendDialectRegistry(registry); + context_.loadAllAvailableDialects(); + } + + ModuleBuilder& get() { return module_builder_; } + ModuleBuilder* operator->() { return &module_builder_; } + + private: + MLIRContext context_; + ModuleBuilder module_builder_; +}; + +// TODO: Make a FileCheck matcher + +} // namespace + +TEST(ChloBuilderTest, SmokeTest) { + std::string expected = R"mlir(module { + func.func @main(%arg0: tensor<2xi64>) -> tensor<2xi64> { + %0 = chlo.constant dense<1> : tensor + %1 = chlo.broadcast_add %arg0, %0 : (tensor<2xi64>, tensor) -> tensor<2xi64> + return %1 : tensor<2xi64> + } +})mlir"; + + ChloModuleBuilder mb; + { // Build Main Func + Location funcLoc = fileLineColLoc(mb->getContext(), "main.mlir", 1, 1); + func::FunctionBuilder fb(mb.get(), "main", funcLoc); + auto type2xi64 = makeTensorType(mb->getContext(), {2}, ElementType::I64); + auto typeScalari64 = makeTensorType(mb->getContext(), {}, ElementType::I64); + auto arg0 = func::Argument(fb, type2xi64); + auto cst = Constant(fb, mlir::makeConstant(1L, typeScalari64)); + auto add = BroadcastAdd(arg0, cst); + func::Return(fb, {add}); + } + + OwningOpRef module = mb->build(); + EXPECT_TRUE(succeeded(mlir::verify(*module))); + EXPECT_EQ(expected, debugString(*module)); +} + +TEST(MlirBuilderTest, ConstantLike) { + std::string expected = R"mlir(module { + func.func @main(%arg0: tensor<2xi64>) -> tensor<2xi64> { + %0 = "chlo.constant_like"(%arg0) <{value = 1 : i64}> : (tensor<2xi64>) -> tensor<2xi64> + return %0 : tensor<2xi64> + } +})mlir"; + + ChloModuleBuilder mb; + { // Build Main Func + Location funcLoc = fileLineColLoc(mb->getContext(), "main.mlir", 1, 1); + func::FunctionBuilder fb(mb.get(), "main", funcLoc); + auto type2xi64 = makeTensorType(mb->getContext(), {2}, ElementType::I64); + auto typeScalari64 = makeTensorType(mb->getContext(), {}, ElementType::I64); + auto arg0 = func::Argument(fb, type2xi64); + auto cst = ConstantLike(arg0, mlir::makeConstant(1L, typeScalari64)); + func::Return(fb, {cst}); + } + + OwningOpRef module = mb->build(); + EXPECT_TRUE(succeeded(mlir::verify(*module))); + EXPECT_EQ(expected, debugString(*module)); +} + +TEST(MlirBuilderTest, ConstantLikeBounded) { + std::string expected = R"mlir(module { + func.func @main(%arg0: tensor<2xi64>, %arg1: tensor) -> tensor> { + %0 = stablehlo.set_dimension_size %arg0, %arg1, dim = 0 : (tensor<2xi64>, tensor) -> tensor> + %1 = 
"chlo.constant_like"(%0) <{value = 1 : i32}> : (tensor>) -> tensor> + return %1 : tensor> + } +})mlir"; + + ChloModuleBuilder mb; + { // Build Main Func + Location funcLoc = fileLineColLoc(mb->getContext(), "main.mlir", 1, 1); + func::FunctionBuilder fb(mb.get(), "main", funcLoc); + auto type2xi64 = makeTensorType(mb->getContext(), {2}, ElementType::I64); + auto typei32 = makeTensorType(mb->getContext(), {}, ElementType::I32); + auto arg0 = func::Argument(fb, type2xi64); + auto arg1 = func::Argument(fb, typei32); + auto sds = stablehlo::SetDimensionSize(arg0, arg1, 0); + auto cst = ConstantLike(sds, mlir::makeConstant(1L, typei32)); + func::Return(fb, {cst}); + } + + OwningOpRef module = mb->build(); + EXPECT_TRUE(succeeded(mlir::verify(*module))); + EXPECT_EQ(expected, debugString(*module)); +} + +} // namespace chlo +} // namespace mlir diff --git a/stablehlo/integrations/cpp/builder/StablehloBuilder.cpp b/stablehlo/integrations/cpp/builder/StablehloBuilder.cpp index 56c7e80453..49f7ec0c55 100644 --- a/stablehlo/integrations/cpp/builder/StablehloBuilder.cpp +++ b/stablehlo/integrations/cpp/builder/StablehloBuilder.cpp @@ -67,6 +67,7 @@ MlirOp ConvertElementType(MlirOp input, Type resultElementType) { MlirOp operand = input; auto inputType = mlir::cast(input.getType()); auto resultType = inputType.clone(resultElementType); + if (inputType == resultType) return input; // skip no-op convert if (isa(inputType.getElementType()) && !isa(resultElementType)) { operand = stablehlo::Real(operand); diff --git a/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir b/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir index e86bbc5a52..f314a28d76 100644 --- a/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir +++ b/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir @@ -622,6 +622,10 @@ func.func @asinh_complex_f32(%arg : tensor>) -> tensor func.return %result : tensor> } +////// +// Broadcast binary elementwise ops tests are located in +// chlo_legalize_to_stablehlo_broadcast.mlir + // ----- // Lower statically shaped `constant_like` to constant. @@ -636,6 +640,24 @@ func.func @constant_like_static_shape(%arg : tensor<1x2xi64>) -> tensor<1x2xf32> // ----- +// Lower dynamically shaped `constant_like` to broadcasted constant. +// CHECK-LABEL: constant_like_bounded_dynamic_shape +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2xi64>, %[[ARG1:.*]]: tensor) +func.func @constant_like_bounded_dynamic_shape(%arg0: tensor<2xi64>, %arg1: tensor) -> tensor> { + %0 = stablehlo.set_dimension_size %arg0, %arg1, dim = 0 : (tensor<2xi64>, tensor) -> tensor> + // CHECK-NOT: chlo.constant_like + // CHECK: %[[ARG0_DYN:.*]] = stablehlo.set_dimension_size %[[ARG0]], %[[ARG1]], dim = 0 : (tensor<2xi64>, tensor) -> tensor> + // CHECK: %[[CST:.*]] = stablehlo.constant dense<1> : tensor + // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %[[CST]], dims = [] : (tensor) -> tensor<2xi32> + // CHECK-NEXT: %[[GDS:.*]] = stablehlo.get_dimension_size %[[ARG0_DYN]], dim = 0 : (tensor>) -> tensor + // CHECK-NEXT: %[[SDS:.*]] = stablehlo.set_dimension_size %[[BCAST]], %[[GDS]], dim = 0 : (tensor<2xi32>, tensor) -> tensor> + // CHECK-NEXT: return %[[SDS]] : tensor> + %1 = "chlo.constant_like"(%0) <{value = 1 : i32}> : (tensor>) -> tensor> + return %1 : tensor> +} + +// ----- + // Lower dynamically shaped `constant_like` to broadcasted constant. 
// CHECK-LABEL: constant_like_dynamic_shape // CHECK-SAME: (%[[ARG:.*]]: tensor) diff --git a/stablehlo/tests/chlo/chlo_legalize_to_stablehlo_broadcast.mlir b/stablehlo/tests/chlo/chlo_legalize_to_stablehlo_broadcast.mlir index e19269ef3f..d6381919d8 100644 --- a/stablehlo/tests/chlo/chlo_legalize_to_stablehlo_broadcast.mlir +++ b/stablehlo/tests/chlo/chlo_legalize_to_stablehlo_broadcast.mlir @@ -3,8 +3,8 @@ // Check the non-broadcast case for each registered op, then just check a // representative op for detailed broadcast semantics. -// CHECK-LABEL: @addWithoutBroadcast -func.func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { +// CHECK-LABEL: @add_no_broadcast +func.func @add_no_broadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK: stablehlo.add %arg0, %arg1 %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> func.return %0 : tensor<4xf32> @@ -12,8 +12,8 @@ func.func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> te // ----- -// CHECK-LABEL: @addStaticBroadcastExpanding -func.func @addStaticBroadcastExpanding(%arg0: tensor<4xf32>, %arg1: tensor) -> tensor<4xf32> { +// CHECK-LABEL: @add_static_broadcast_expanding +func.func @add_static_broadcast_expanding(%arg0: tensor<4xf32>, %arg1: tensor) -> tensor<4xf32> { // CHECK: %[[BROADCAST:.+]] = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor) -> tensor<4xf32> // CHECK-NEXT: stablehlo.add %arg0, %[[BROADCAST]] // CHECK-NOT: shape @@ -23,8 +23,8 @@ func.func @addStaticBroadcastExpanding(%arg0: tensor<4xf32>, %arg1: tensor) // ----- -// CHECK-LABEL: @addStaticBroadcastSameRank -func.func @addStaticBroadcastSameRank(%arg0: tensor<1x4xf32>, %arg1: tensor<4x1xf32>) -> tensor<4x4xf32> { +// CHECK-LABEL: @add_static_broadcast_same_rank +func.func @add_static_broadcast_same_rank(%arg0: tensor<1x4xf32>, %arg1: tensor<4x1xf32>) -> tensor<4x4xf32> { // CHECK: %[[ARG0_B:.+]] = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<1x4xf32>) -> tensor<4x4xf32> // CHECK-NEXT: %[[ARG1_B:.+]] = stablehlo.broadcast_in_dim %arg1, dims = [0, 1] : (tensor<4x1xf32>) -> tensor<4x4xf32> // CHECK-NEXT: stablehlo.add %[[ARG0_B]], %[[ARG1_B]] : tensor<4x4xf32> @@ -35,11 +35,33 @@ func.func @addStaticBroadcastSameRank(%arg0: tensor<1x4xf32>, %arg1: tensor<4x1x // ----- +// [<=10] x [<=10] => [<=10] +// CHECK-LABEL: func @add_bounded_dynamic_no_broadcast +func.func @add_bounded_dynamic_no_broadcast(%arg0: tensor>, %arg1: tensor>) -> tensor> { + // CHECK-NEXT: stablehlo.add %arg0, %arg1 + %0 = chlo.broadcast_add %arg0, %arg1 : (tensor>, tensor>) -> tensor> + return %0 : tensor> +} + +// ----- + +// [<=10] x [] => [<=10] +// CHECK-LABEL: func @add_bounded_dynamic_expanding +func.func @add_bounded_dynamic_expanding(%arg0: tensor>, %arg1: tensor) -> tensor> { + // CHECK: %[[RHS_BCAST:.+]] = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor) -> tensor<10xf64> + // CHECK: %[[DIM_SIZE:.+]] = stablehlo.get_dimension_size %arg0, dim = 0 + // CHECK: %[[RHS_BCAST_DYN:.+]] = stablehlo.set_dimension_size %[[RHS_BCAST]], %[[DIM_SIZE]], dim = 0 + // CHECK-NEXT: stablehlo.add %arg0, %[[RHS_BCAST_DYN]] + %0 = chlo.broadcast_add %arg0, %arg1 : (tensor>, tensor) -> tensor> + return %0 : tensor> +} + +// ----- -// CHECK-LABEL: @dynamicBroadcast +// CHECK-LABEL: @add_dynamic_broadcast // CHECK-SAME: %[[ARG0:.+]]: tensor // CHECK-SAME: %[[ARG1:.+]]: tensor -func.func @dynamicBroadcast(%arg0: tensor, %arg1: tensor) -> tensor { +func.func 
@add_dynamic_broadcast(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]] @@ -57,10 +79,10 @@ func.func @dynamicBroadcast(%arg0: tensor, %arg1: tensor) -> ten // ----- -// CHECK-LABEL: @dynamicBroadcastComplex +// CHECK-LABEL: @dynamic_broadcast_complex // CHECK-SAME: %[[ARG0:.+]]: tensor // CHECK-SAME: %[[ARG1:.+]]: tensor -func.func @dynamicBroadcastComplex(%arg0: tensor, %arg1: tensor) -> tensor> { +func.func @dynamic_broadcast_complex(%arg0: tensor, %arg1: tensor) -> tensor> { // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]] @@ -78,10 +100,10 @@ func.func @dynamicBroadcastComplex(%arg0: tensor, %arg1: tensor) // ----- -// CHECK-LABEL: @dynamicBroadcastCompare +// CHECK-LABEL: @compare_dynamic_broadcast // CHECK-SAME: %[[ARG0:.+]]: tensor // CHECK-SAME: %[[ARG1:.+]]: tensor -func.func @dynamicBroadcastCompare(%arg0: tensor, %arg1: tensor) -> tensor { +func.func @compare_dynamic_broadcast(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] // CHECK: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]] @@ -191,8 +213,8 @@ func.func @selectv2_dynamic_ranked(%arg0: tensor<1xi1>, %arg1: tensor<2x?x8xi32> // ----- // Verifies that broadcast_dimensions validity checks are valid. -// CHECK-LABEL: @dynamicNonScalarBroadcastDimensions -func.func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { +// CHECK-LABEL: @dynamic_non_scalar_broadcast_dimensions +func.func @dynamic_non_scalar_broadcast_dimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { // CHECK: stablehlo.add %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = array } : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> func.return %0 : tensor<1x4xf32> @@ -201,8 +223,8 @@ func.func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: te // ----- // Verifies that broadcast_dimensions validity checks are valid. -// CHECK-LABEL: @dynamicNonScalarByScalarBroadcastDimensions -func.func @dynamicNonScalarByScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor) -> tensor<1x4xf32> { +// CHECK-LABEL: @dynamic_non_scalar_by_scalar_broadcast_dimensions +func.func @dynamic_non_scalar_by_scalar_broadcast_dimensions(%arg0: tensor<1x4xf32>, %arg1: tensor) -> tensor<1x4xf32> { // CHECK: stablehlo.add %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<1x4xf32>, tensor) -> tensor<1x4xf32> func.return %0 : tensor<1x4xf32> @@ -211,7 +233,7 @@ func.func @dynamicNonScalarByScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, % // ----- // Verifies that invalid broadcast dimensions are rejected. 
-func.func @dynamicNonScalarBroadcastDimensionsSizeMismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { +func.func @dynamic_non_scalar_broadcast_dimensions_size_mismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { // expected-warning @+2 {{unsupported non prefix-padded dynamic rank broadcast_dimensions}} // expected-error @+1 {{failed to legalize operation}} %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = array} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> @@ -221,7 +243,7 @@ func.func @dynamicNonScalarBroadcastDimensionsSizeMismatch(%arg0: tensor<1x4xf32 // ----- // Verifies that invalid broadcast dimensions are rejected. -func.func @dynamicNonScalarBroadcastDimensionsMismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { +func.func @dynamic_non_scalar_broadcast_dimensions_mismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { // expected-warning @+2 {{unsupported non prefix-padded dynamic rank broadcast_dimensions}} // expected-error @+1 {{failed to legalize operation}} %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = array} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> @@ -232,8 +254,8 @@ func.func @dynamicNonScalarBroadcastDimensionsMismatch(%arg0: tensor<1x4xf32>, % // Note that broadcast_add is used as a proxy for all of the template // expansions. Tests below merely verify that the op has an expansion. -// CHECK-LABEL: @andWithoutBroadcast -func.func @andWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { +// CHECK-LABEL: @and_no_broadcast +func.func @and_no_broadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { // CHECK: stablehlo.and %arg0, %arg1 %0 = chlo.broadcast_and %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> func.return %0 : tensor<4xi1> @@ -241,8 +263,8 @@ func.func @andWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tens // ----- -// CHECK-LABEL: @atan2WithoutBroadcast -func.func @atan2WithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { +// CHECK-LABEL: @atan2_no_broadcast +func.func @atan2_no_broadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK: stablehlo.atan2 %arg0, %arg1 %0 = chlo.broadcast_atan2 %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> func.return %0 : tensor<4xf32> @@ -250,8 +272,8 @@ func.func @atan2WithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> // ----- -// CHECK-LABEL: @compareWithoutBroadcast -func.func @compareWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xi1> { +// CHECK-LABEL: @compare_no_broadcast +func.func @compare_no_broadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xi1> { // CHECK: stablehlo.compare EQ, %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> %0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = #chlo} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> func.return %0 : tensor<4xi1> @@ -259,8 +281,8 @@ func.func @compareWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) - // ----- -// CHECK-LABEL: @complexWithoutBroadcast -func.func @complexWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xcomplex> { +// CHECK-LABEL: @complex_no_broadcast +func.func @complex_no_broadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xcomplex> { // CHECK: stablehlo.complex %arg0, %arg1 : tensor<4xcomplex> %0 = chlo.broadcast_complex %arg0, %arg1 : (tensor<4xf32>, 
tensor<4xf32>) -> tensor<4xcomplex> func.return %0 : tensor<4xcomplex> @@ -268,8 +290,8 @@ func.func @complexWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) - // ----- -// CHECK-LABEL: @divideWithoutBroadcast -func.func @divideWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { +// CHECK-LABEL: @divide_no_broadcast +func.func @divide_no_broadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK: stablehlo.divide %arg0, %arg1 %0 = chlo.broadcast_divide %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> func.return %0 : tensor<4xf32> @@ -277,8 +299,8 @@ func.func @divideWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> // ----- -// CHECK-LABEL: @maximumWithoutBroadcast -func.func @maximumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { +// CHECK-LABEL: @maximum_no_broadcast +func.func @maximum_no_broadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK: stablehlo.maximum %arg0, %arg1 %0 = chlo.broadcast_maximum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> func.return %0 : tensor<4xf32> @@ -286,8 +308,8 @@ func.func @maximumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) - // ----- -// CHECK-LABEL: @minimumWithoutBroadcast -func.func @minimumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { +// CHECK-LABEL: @minimum_no_broadcast +func.func @minimum_no_broadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK: stablehlo.minimum %arg0, %arg1 %0 = chlo.broadcast_minimum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> func.return %0 : tensor<4xf32> @@ -295,8 +317,8 @@ func.func @minimumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) - // ----- -// CHECK-LABEL: @multiplyWithoutBroadcast -func.func @multiplyWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { +// CHECK-LABEL: @multiply_no_broadcast +func.func @multiply_no_broadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK: stablehlo.multiply %arg0, %arg1 %0 = chlo.broadcast_multiply %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> func.return %0 : tensor<4xf32> @@ -304,8 +326,8 @@ func.func @multiplyWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) // ----- -// CHECK-LABEL: @orWithoutBroadcast -func.func @orWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { +// CHECK-LABEL: @or_no_broadcast +func.func @or_no_broadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { // CHECK: stablehlo.or %arg0, %arg1 %0 = chlo.broadcast_or %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> func.return %0 : tensor<4xi1> @@ -313,8 +335,8 @@ func.func @orWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tenso // ----- -// CHECK-LABEL: @powerWithoutBroadcast -func.func @powerWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { +// CHECK-LABEL: @power_no_broadcast +func.func @power_no_broadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK: stablehlo.power %arg0, %arg1 %0 = chlo.broadcast_power %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> func.return %0 : tensor<4xf32> @@ -322,8 +344,8 @@ func.func @powerWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> // ----- -// CHECK-LABEL: @remainderWithoutBroadcast -func.func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> 
tensor<4xf32> { +// CHECK-LABEL: @remainder_no_broadcast +func.func @remainder_no_broadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK: stablehlo.remainder %arg0, %arg1 %0 = chlo.broadcast_remainder %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> func.return %0 : tensor<4xf32> @@ -331,8 +353,8 @@ func.func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) // ----- -// CHECK-LABEL: @shift_leftWithoutBroadcast -func.func @shift_leftWithoutBroadcast(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { +// CHECK-LABEL: @shift_left_no_broadcast +func.func @shift_left_no_broadcast(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { // CHECK: stablehlo.shift_left %arg0, %arg1 %0 = chlo.broadcast_shift_left %arg0, %arg1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> func.return %0 : tensor<4xi32> @@ -340,8 +362,8 @@ func.func @shift_leftWithoutBroadcast(%arg0: tensor<4xi32>, %arg1: tensor<4xi32> // ----- -// CHECK-LABEL: @shift_right_arithmeticWithoutBroadcast -func.func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { +// CHECK-LABEL: @shift_right_arithmetic_no_broadcast +func.func @shift_right_arithmetic_no_broadcast(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { // CHECK: stablehlo.shift_right_arithmetic %arg0, %arg1 %0 = chlo.broadcast_shift_right_arithmetic %arg0, %arg1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> func.return %0 : tensor<4xi32> @@ -349,8 +371,8 @@ func.func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xi32>, %arg1: t // ----- -// CHECK-LABEL: @shift_right_logicalWithoutBroadcast -func.func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { +// CHECK-LABEL: @shift_right_logical_no_broadcast +func.func @shift_right_logical_no_broadcast(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { // CHECK: stablehlo.shift_right_logical %arg0, %arg1 %0 = chlo.broadcast_shift_right_logical %arg0, %arg1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> func.return %0 : tensor<4xi32> @@ -358,8 +380,8 @@ func.func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xi32>, %arg1: tens // ----- -// CHECK-LABEL: @subWithoutBroadcast -func.func @subWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { +// CHECK-LABEL: @sub_no_broadcast +func.func @sub_no_broadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK: stablehlo.subtract %arg0, %arg1 %0 = chlo.broadcast_subtract %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> func.return %0 : tensor<4xf32> @@ -367,16 +389,16 @@ func.func @subWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> te // ----- -// CHECK-LABEL: @xorWithoutBroadcast -func.func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { +// CHECK-LABEL: @xor_no_broadcast +func.func @xor_no_broadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { // CHECK: stablehlo.xor %arg0, %arg1 %0 = chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> func.return %0 : tensor<4xi1> } // ----- -// CHECK-LABEL: @NextAfterWithoutBroadcast -func.func @NextAfterWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) +// CHECK-LABEL: @next_after_no_broadcast +func.func @next_after_no_broadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK-NOT: chlo.broadcast_next_after %0 = chlo.broadcast_next_after 
%arg0, %arg1 @@ -386,8 +408,8 @@ func.func @NextAfterWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) // ----- -// CHECK-LABEL: @PolygammaWithoutBroadcast -func.func @PolygammaWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) +// CHECK-LABEL: @Polygamma_no_broadcast +func.func @Polygamma_no_broadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK-NOT: chlo.broadcast_polygamma // CHECK-NOT: chlo.polygamma @@ -398,8 +420,8 @@ func.func @PolygammaWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) // ----- -// CHECK-LABEL: @ZetaWithoutBroadcast -func.func @ZetaWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) +// CHECK-LABEL: @Zeta_no_broadcast +func.func @Zeta_no_broadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK-NOT: chlo.broadcast_zeta // CHECK-NOT: chlo.zeta diff --git a/stablehlo/transforms/CMakeLists.txt b/stablehlo/transforms/CMakeLists.txt index 646a64b421..8daa634098 100644 --- a/stablehlo/transforms/CMakeLists.txt +++ b/stablehlo/transforms/CMakeLists.txt @@ -113,6 +113,7 @@ add_mlir_dialect_library(StablehloPasses MLIRTransformUtils StablehloBase StablehloBroadcastUtils + StablehloBroadcastLowering StablehloLinalgTransforms StablehloOps StablehloOptimizationPasses diff --git a/stablehlo/transforms/ChloLegalizeToStablehlo.cpp b/stablehlo/transforms/ChloLegalizeToStablehlo.cpp index 6dae4370d5..54163cf6ba 100644 --- a/stablehlo/transforms/ChloLegalizeToStablehlo.cpp +++ b/stablehlo/transforms/ChloLegalizeToStablehlo.cpp @@ -35,7 +35,6 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" @@ -51,6 +50,7 @@ #include "stablehlo/transforms/ChloDecompositionUtils.h" #include "stablehlo/transforms/PassUtils.h" #include "stablehlo/transforms/Passes.h" +#include "stablehlo/transforms/StablehloBroadcastLowering.h" // This must precede all other headers, otherwise during Windows cross // compilation, M_PI will not be defined. @@ -201,35 +201,14 @@ static Value getConstantLikeSmallestNormalizedValue(OpBuilder& b, Location loc, val); } -// Broadcast using numpy-style broadcasting semantics. -// This is only valid if the CHLO op has static shaped operands, and no -// explicitly specified broadcast_dimensions. -// -// Asserts that input is ranked tensor type. -Value numpyBroadcastIfNeeded(Value op, RankedTensorType opResultType, - PatternRewriter& rewriter) { - RankedTensorType inputType = cast(op.getType()); - RankedTensorType broadcastedResultType = - opResultType.clone(inputType.getElementType()); - - // No broadcasting needed if input type matches broadcasted result type. - if (inputType == broadcastedResultType) return op; - - // broadcast dims are the last dims for numpy style broadcasting. - int64_t inputRank = inputType.getRank(); - int64_t resultRank = opResultType.getRank(); - auto broadcastDimensions = - llvm::to_vector(llvm::seq(resultRank - inputRank, resultRank)); - return stablehlo::BroadcastInDimOp::create(rewriter, op.getLoc(), - broadcastedResultType, op, - broadcastDimensions) - .getResult(); -} - //===----------------------------------------------------------------------===// // Broadcasting Patterns. 
 //===----------------------------------------------------------------------===//
 
+bool isStaticOrBoundedDynamicTensor(RankedTensorType type) {
+  return type.hasStaticShape() || hlo::isBoundedDynamic(type);
+}
+
 // Converts binary ops that statically are determined to not broadcast directly
 // to the corresponding stablehlo non-broadcasting op.
 template
@@ -243,12 +222,14 @@ struct ConvertTrivialNonBroadcastBinaryOp final
     // Only rewrite for statically determinable non-broadcasting cases.
     auto lhsType = dyn_cast(adaptor.getLhs().getType());
     auto rhsType = dyn_cast(adaptor.getRhs().getType());
-    if (!lhsType || !rhsType || lhsType.getShape() != rhsType.getShape() ||
-        !lhsType.hasStaticShape() || !rhsType.hasStaticShape())
+    if (!lhsType || !rhsType || !isStaticOrBoundedDynamicTensor(lhsType) ||
+        !isStaticOrBoundedDynamicTensor(rhsType) ||
+        lhsType.getShape() != rhsType.getShape() ||
+        lhsType.getEncoding() != rhsType.getEncoding())
       return rewriter.notifyMatchFailure(
           op,
           "expected LHS and RHS to be ranked tensors with matching shapes that "
-          "are all static");
+          "are all static or bounded dynamic");
 
     rewriter.replaceOp(
         op, ValueRange{Adaptor::createOp(op, op.getType(),
@@ -270,41 +251,46 @@ struct ConvertTrivialNumpyBroadcastBinaryOp final
     // Only rewrite for statically determinable non-broadcasting cases.
     auto lhsType = dyn_cast(adaptor.getLhs().getType());
     auto rhsType = dyn_cast(adaptor.getRhs().getType());
-    if (!lhsType || !rhsType || !lhsType.hasStaticShape() ||
-        !rhsType.hasStaticShape())
+    if (!lhsType || !rhsType || !isStaticOrBoundedDynamicTensor(lhsType) ||
+        !isStaticOrBoundedDynamicTensor(rhsType))
       return rewriter.notifyMatchFailure(
          op,
-          "expected LHS and RHS to be ranked tensor types with static "
-          "shape");
+          "expected LHS and RHS to be ranked tensor types with static or "
+          "bounded dynamic shape");
 
     // Rely on CHLO type inference to figure out the proper broadcasted shape.
     auto resultType = dyn_cast(op.getResult().getType());
-    if (!resultType || !resultType.hasStaticShape())
+    if (!resultType || !isStaticOrBoundedDynamicTensor(resultType))
       return rewriter.notifyMatchFailure(
-          op, "expected result to be a ranked tensor type with static shape");
+          op,
+          "expected result to be a ranked tensor type with static or bounded "
+          "dynamic shape");
 
     auto lhs = adaptor.getLhs();
     auto rhs = adaptor.getRhs();
     auto broadcastDimensions = adaptor.getBroadcastDimensions();
     if (broadcastDimensions &&
-        !hlo::isLegalNumpyRankedBroadcast(lhs, rhs, *broadcastDimensions))
+        !hlo::isLegalNumpyRankedBroadcast(lhs, rhs, *broadcastDimensions)) {
      return rewriter.notifyMatchFailure(
          op,
          "expected implicit broadcast_dimensions or numpy-style broadcasting");
+    }
 
     LLVM_DEBUG(llvm::dbgs() << "CHLO Decomposing " << op->getName()
                             << " with broadcast " << lhsType << " x " << rhsType
                             << " -> " << resultType << "\n");
 
-    // If operands are static directly create stablehlo broadcasting ops.
-    // Use numpy-style broadcasting with using StableHLO broadcast ops,
-    // when user didn't specify broadcast_dimensions.
-    auto lhsBroadcast =
-        numpyBroadcastIfNeeded(adaptor.getLhs(), resultType, rewriter);
-    auto rhsBroadcast =
-        numpyBroadcastIfNeeded(adaptor.getRhs(), resultType, rewriter);
-    auto result = Adaptor::createOp(op, resultType,
-                                    {lhsBroadcast, rhsBroadcast}, rewriter);
+    // If operands are static or bounded dynamic, directly create StableHLO
+    // broadcast ops using numpy-style broadcast semantics.
+    // broadcast_dimensions can be omitted since the logic above verifies that
+    // it matches the numpy-style default.
+    mlir::SmallVector broadcastOperands = {lhs, rhs};
+    auto broadcasted_values =
+        stablehlo::numpyBroadcastIfNeeded(rewriter, broadcastOperands);
+    if (failed(broadcasted_values)) return failure();
+
+    auto result =
+        Adaptor::createOp(op, resultType, *broadcasted_values, rewriter);
     rewriter.replaceOp(op, {result.getResult()});
     return success();
   }
@@ -425,7 +411,21 @@ struct ConvertConstantLikeOp final
       return success();
     }
 
-    // Lower to broadcasted constant.
+    // Lower to cst -> broadcast -> set_dimension_size if bounded dynamic.
+    if (hlo::isBoundedDynamic(resultTy)) {
+      Value constant = mlir::stablehlo::ConstantOp::create(
+          rewriter, op.getLoc(), op.getValue());
+      mlir::FailureOr operandDims =
+          getDimensions(adaptor.getOperand());
+      if (failed(operandDims)) return failure();
+      mlir::FailureOr broadcast =
+          stablehlo::numpyBroadcastIfNeeded(rewriter, constant, *operandDims);
+      if (failed(broadcast)) return failure();
+      rewriter.replaceOp(op, *broadcast);
+      return success();
+    }
+
+    // Lower unbounded dynamic to broadcasted constant.
     Location loc = op.getLoc();
     Value constant =
         mlir::stablehlo::ConstantOp::create(rewriter, loc, op.getValue());
diff --git a/stablehlo/transforms/StablehloBroadcastLowering.cpp b/stablehlo/transforms/StablehloBroadcastLowering.cpp
index d44c09a0a6..e876eb98ff 100644
--- a/stablehlo/transforms/StablehloBroadcastLowering.cpp
+++ b/stablehlo/transforms/StablehloBroadcastLowering.cpp
@@ -59,26 +59,6 @@ DimensionInfo getDimensionInfo(Value op, mlir::RankedTensorType tensorType,
   };
 }
 
-FailureOr getDimensions(Value op) {
-  // Get tensor type
-  mlir::RankedTensorType tensor_type = dyn_cast(op.getType());
-  if (!tensor_type)
-    return emitError(op.getLoc(),
-                     "expected ranked tensor type for broadcast inputs");
-
-  auto encoding =
-      mlir::dyn_cast_if_present(
-          tensor_type.getEncoding());
-
-  Dimensions dimensions;
-  dimensions.reserve(tensor_type.getRank());
-  for (int64_t idx = 0; idx < tensor_type.getRank(); ++idx) {
-    auto dimInfo = getDimensionInfo(op, tensor_type, encoding, idx);
-    dimensions.push_back(dimInfo);
-  }
-  return dimensions;
-}
-
 FailureOr getNumpyBroadcastShapeWithBounds(Value op,
                                            const Dimensions& a,
                                            const Dimensions& b) {
@@ -132,6 +112,28 @@ FailureOr getNumpyBroadcastShapeWithBounds(Value op,
   return result;
 }
 
+} // namespace
+
+FailureOr getDimensions(Value op) {
+  // Get tensor type
+  mlir::RankedTensorType tensor_type = dyn_cast(op.getType());
+  if (!tensor_type)
+    return emitError(op.getLoc(),
+                     "expected ranked tensor type for broadcast inputs");
+
+  auto encoding =
+      mlir::dyn_cast_if_present(
+          tensor_type.getEncoding());
+
+  Dimensions dimensions;
+  dimensions.reserve(tensor_type.getRank());
+  for (int64_t idx = 0; idx < tensor_type.getRank(); ++idx) {
+    auto dimInfo = getDimensionInfo(op, tensor_type, encoding, idx);
+    dimensions.push_back(dimInfo);
+  }
+  return dimensions;
+}
+
 mlir::RankedTensorType getRankedTensorType(const Dimensions& dims,
                                            mlir::Type element_type) {
   mlir::SmallVector shape;
@@ -155,8 +157,6 @@ mlir::RankedTensorType getRankedTensorType(const Dimensions& dims,
   return mlir::RankedTensorType::get(shape, element_type, encoding);
 }
 
-} // namespace
-
 FailureOr getNumpyBroadcastShape(OpBuilder& builder,
                                  ArrayRef ops) {
   if (ops.empty())
diff --git a/stablehlo/transforms/StablehloBroadcastLowering.h b/stablehlo/transforms/StablehloBroadcastLowering.h
index 56de1bc1d5..b31b8fc7ef 100644
--- a/stablehlo/transforms/StablehloBroadcastLowering.h
+++ b/stablehlo/transforms/StablehloBroadcastLowering.h
@@ -47,6 +47,14 @@ struct DimensionInfo {
 using Dimensions = SmallVector;
 std::string toString(const Dimensions& dims);
 
+// Returns the dimensions of the given op, or failure if the op's type is not a
+// ranked tensor.
+FailureOr getDimensions(Value op);
+
+// Returns the ranked tensor type with the given dimensions and element type.
+mlir::RankedTensorType getRankedTensorType(const Dimensions& dims,
+                                           mlir::Type element_type);
+
 // Returns the common shape these ops would broadcast to, or an error if the
 // ops are not broadcastable.
 FailureOr getNumpyBroadcastShape(OpBuilder& builder,