From cf49ea0aa50aac019431b373e77c5f1fe7f08599 Mon Sep 17 00:00:00 2001
From: mingzheTerapines
Date: Mon, 22 Sep 2025 13:59:16 +0800
Subject: [PATCH 1/3] [TorchToLinalg] Fix multi batch matmul conversion to Linalg

Improve usage of static shape information to avoid unnecessary broadcast.

Co-authored-by: chao.mei@terapines.com
---
 lib/Conversion/TorchToLinalg/Linear.cpp | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index 9ffb7c1dc0f3..0014f0fc60b3 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -76,6 +76,15 @@ static Value transposeValue(Location loc, Value value, ArrayRef<int64_t> perms,
   return transpose;
 }
 
+static int64_t getDimFromValue(Value dimValue) {
+  if (auto constOp = dimValue.getDefiningOp<arith::ConstantOp>()) {
+    if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue())) {
+      return intAttr.getInt();
+    }
+  }
+  return ShapedType::kDynamic;
+}
+
 class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -505,9 +514,9 @@ class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
 
       // Broadcast the batch dimensions of both the matrices.
       Value broadcastedLhs, broadcastedRhs;
-      // TODO: Improve usage of static shape information.
-      SmallVector<int64_t> lhsTargetShape(lhsBroadcastToShape.size(),
-                                          ShapedType::kDynamic);
+      SmallVector<int64_t> lhsTargetShape = llvm::to_vector(
+          llvm::map_range(lhsBroadcastToShape, getDimFromValue));
+
       auto lhsBroadcastType = RankedTensorType::get(
           lhsTargetShape, lhsType.getElementType(), lhsType.getEncoding());
       if (failed(torch_to_linalg::broadcastToGivenShape(
@@ -516,8 +525,8 @@ class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
         return rewriter.notifyMatchFailure(
             op, "unable to perform broadcast operation");
       }
-      SmallVector<int64_t> rhsTargetShape(rhsBroadcastToShape.size(),
-                                          ShapedType::kDynamic);
+      SmallVector<int64_t> rhsTargetShape = llvm::to_vector(
+          llvm::map_range(rhsBroadcastToShape, getDimFromValue));
       auto rhsBroadcastType = RankedTensorType::get(
           rhsTargetShape, rhsType.getElementType(), rhsType.getEncoding());
       if (failed(torch_to_linalg::broadcastToGivenShape(

From d35aadf5710d07f0cd7bc6b60ce362dffebd12f5 Mon Sep 17 00:00:00 2001
From: mingzheTerapines
Date: Fri, 26 Sep 2025 10:26:01 +0800
Subject: [PATCH 2/3] use upstream utils functions and add lit tests

---
 lib/Conversion/TorchToLinalg/Linear.cpp  | 21 +++-----
 test/Conversion/TorchToLinalg/basic.mlir | 64 ++++++++++++++++++++++++
 2 files changed, 72 insertions(+), 13 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index 0014f0fc60b3..89ec9a599e4b 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -76,15 +76,6 @@ static Value transposeValue(Location loc, Value value, ArrayRef<int64_t> perms,
   return transpose;
 }
 
-static int64_t getDimFromValue(Value dimValue) {
-  if (auto constOp = dimValue.getDefiningOp<arith::ConstantOp>()) {
-    if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue())) {
-      return intAttr.getInt();
-    }
-  }
-  return ShapedType::kDynamic;
-}
-
 class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -514,8 +505,10 @@ class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
 
       // Broadcast the batch dimensions of both the matrices.
       Value broadcastedLhs, broadcastedRhs;
-      SmallVector<int64_t> lhsTargetShape = llvm::to_vector(
-          llvm::map_range(lhsBroadcastToShape, getDimFromValue));
+      SmallVector<int64_t> lhsTargetShape =
+          llvm::to_vector(llvm::map_range(lhsBroadcastToShape, [](Value v) {
+            return getConstantIntValue(v).value_or(ShapedType::kDynamic);
+          }));
 
       auto lhsBroadcastType = RankedTensorType::get(
           lhsTargetShape, lhsType.getElementType(), lhsType.getEncoding());
@@ -525,8 +518,10 @@ class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
         return rewriter.notifyMatchFailure(
             op, "unable to perform broadcast operation");
       }
-      SmallVector<int64_t> rhsTargetShape = llvm::to_vector(
-          llvm::map_range(rhsBroadcastToShape, getDimFromValue));
+      SmallVector<int64_t> rhsTargetShape =
+          llvm::to_vector(llvm::map_range(rhsBroadcastToShape, [](Value v) {
+            return getConstantIntValue(v).value_or(ShapedType::kDynamic);
+          }));
       auto rhsBroadcastType = RankedTensorType::get(
           rhsTargetShape, rhsType.getElementType(), rhsType.getEncoding());
       if (failed(torch_to_linalg::broadcastToGivenShape(
diff --git a/test/Conversion/TorchToLinalg/basic.mlir b/test/Conversion/TorchToLinalg/basic.mlir
index 262f6e646bdd..5895d19a2ec1 100644
--- a/test/Conversion/TorchToLinalg/basic.mlir
+++ b/test/Conversion/TorchToLinalg/basic.mlir
@@ -43,6 +43,70 @@ func.func @torch.aten.matmul.2d(%arg0: !torch.vtensor<[8,16],f32>, %arg1: !torch
 
 // -----
 
+// CHECK-LABEL: func.func @torch.aten.matmul.4d
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,2,32,400],f32>,
+// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[1,2,400,32],f32>) -> !torch.vtensor<[1,2,400,400],f32> {
+// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1,2,32,400],f32> -> tensor<1x2x32x400xf32>
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[1,2,400,32],f32> -> tensor<1x2x400x32xf32>
+// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_6:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_9:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_10:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_11:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_12:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_13:.*]] = arith.constant 400 : index
+// CHECK: %[[VAL_14:.*]] = arith.constant 3 : index
+// CHECK: %[[VAL_15:.*]] = arith.constant 32 : index
+// CHECK: %[[VAL_16:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_17:.*]] = arith.constant 32 : index
+// CHECK: %[[VAL_18:.*]] = arith.constant 3 : index
+// CHECK: %[[VAL_19:.*]] = arith.constant 400 : index
+// CHECK: %[[VAL_20:.*]] = arith.constant 32 : i64
+// CHECK: %[[VAL_21:.*]] = arith.constant 32 : i64
+// CHECK: %[[VAL_22:.*]] = arith.cmpi eq, %[[VAL_20]], %[[VAL_21]] : i64
+// CHECK: cf.assert %[[VAL_22]], "mismatching contracting dimension"
+// CHECK: %[[VAL_23:.*]] = arith.constant 1 : i64
+// CHECK: %[[VAL_24:.*]] = arith.constant 1 : i64
+// CHECK: %[[VAL_25:.*]] = arith.constant 2 : i64
+// CHECK: %[[VAL_26:.*]] = arith.constant 2 : i64
+// CHECK: %[[VAL_27:.*]] = arith.constant 400 : i64
+// CHECK: %[[VAL_28:.*]] = arith.constant 32 : i64
+// CHECK: %[[VAL_29:.*]] = arith.constant 32 : i64
+// CHECK: %[[VAL_30:.*]] = arith.constant 400 : i64
+// CHECK: %[[VAL_31:.*]] = arith.constant 0 : i64
+// CHECK: %[[VAL_32:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_33:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_34:.*]] = tensor.empty() : tensor<1x2x400x32xf32>
+// CHECK: %[[VAL_35:.*]] = tensor.cast %[[VAL_1]] : tensor<1x2x400x32xf32> to tensor<1x2x400x32xf32>
+// CHECK: %[[VAL_36:.*]] = arith.constant 0 : i64
+// CHECK: %[[VAL_37:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_38:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_39:.*]] = tensor.empty() : tensor<1x2x32x400xf32>
+// CHECK: %[[VAL_40:.*]] = tensor.cast %[[VAL_0]] : tensor<1x2x32x400xf32> to tensor<1x2x32x400xf32>
+// CHECK: %[[VAL_41:.*]] = tensor.collapse_shape %[[VAL_35]] {{\[\[}}0, 1], [2], [3]] : tensor<1x2x400x32xf32> into tensor<2x400x32xf32>
+// CHECK: %[[VAL_42:.*]] = tensor.collapse_shape %[[VAL_40]] {{\[\[}}0, 1], [2], [3]] : tensor<1x2x32x400xf32> into tensor<2x32x400xf32>
+// CHECK: %[[VAL_43:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_44:.*]] = tensor.empty() : tensor<2x400x400xf32>
+// CHECK: %[[VAL_45:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_46:.*]] = linalg.fill ins(%[[VAL_45]] : f32) outs(%[[VAL_44]] : tensor<2x400x400xf32>) -> tensor<2x400x400xf32>
+// CHECK: %[[VAL_47:.*]] = linalg.batch_matmul ins(%[[VAL_41]], %[[VAL_42]] : tensor<2x400x32xf32>, tensor<2x32x400xf32>) outs(%[[VAL_46]] : tensor<2x400x400xf32>) -> tensor<2x400x400xf32>
+// CHECK: %[[VAL_48:.*]] = tensor.expand_shape %[[VAL_47]] {{\[\[}}0, 1], [2], [3]] output_shape [1, 2, 400, 400] : tensor<2x400x400xf32> into tensor<1x2x400x400xf32>
+// CHECK: %[[VAL_49:.*]] = tensor.cast %[[VAL_48]] : tensor<1x2x400x400xf32> to tensor<1x2x400x400xf32>
+// CHECK: %[[VAL_50:.*]] = torch_c.from_builtin_tensor %[[VAL_49]] : tensor<1x2x400x400xf32> -> !torch.vtensor<[1,2,400,400],f32>
+// CHECK: return %[[VAL_50]] : !torch.vtensor<[1,2,400,400],f32>
+// CHECK: }
+func.func @torch.aten.matmul.4d(%arg0: !torch.vtensor<[1,2,32,400],f32>, %arg1: !torch.vtensor<[1,2,400,32],f32>) -> !torch.vtensor<[1,2,400,400],f32> {
+    %0 = torch.aten.matmul %arg1, %arg0 : !torch.vtensor<[1,2,400,32],f32>, !torch.vtensor<[1,2,32,400],f32> -> !torch.vtensor<[1,2,400,400],f32>
+    return %0 : !torch.vtensor<[1,2,400,400],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @torch.aten.mm$basic_strict(
 // CHECK-NOT: assert
 func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32>

From ec5ca993c33bf02f15b4a4e693d515c390ef6e6a Mon Sep 17 00:00:00 2001
From: mingzheTerapines
Date: Fri, 10 Oct 2025 09:43:22 +0800
Subject: [PATCH 3/3] simplify test

---
 test/Conversion/TorchToLinalg/basic.mlir | 68 ++++--------------------
 1 file changed, 10 insertions(+), 58 deletions(-)

diff --git a/test/Conversion/TorchToLinalg/basic.mlir b/test/Conversion/TorchToLinalg/basic.mlir
index 5895d19a2ec1..0caa6e0c6980 100644
--- a/test/Conversion/TorchToLinalg/basic.mlir
+++ b/test/Conversion/TorchToLinalg/basic.mlir
@@ -44,65 +44,17 @@ func.func @torch.aten.matmul.2d(%arg0: !torch.vtensor<[8,16],f32>, %arg1: !torch
 // -----
 
 // CHECK-LABEL: func.func @torch.aten.matmul.4d
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,2,32,400],f32>,
-// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[1,2,400,32],f32>) -> !torch.vtensor<[1,2,400,400],f32> {
-// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1,2,32,400],f32> -> tensor<1x2x32x400xf32>
-// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[1,2,400,32],f32> -> tensor<1x2x400x32xf32>
-// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_3:.*]] = arith.constant 2 : index
-// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_5:.*]] = arith.constant 2 : index
-// CHECK: %[[VAL_6:.*]] = arith.constant 2 : index
-// CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_9:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_10:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_11:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_12:.*]] = arith.constant 2 : index
-// CHECK: %[[VAL_13:.*]] = arith.constant 400 : index
-// CHECK: %[[VAL_14:.*]] = arith.constant 3 : index
-// CHECK: %[[VAL_15:.*]] = arith.constant 32 : index
-// CHECK: %[[VAL_16:.*]] = arith.constant 2 : index
-// CHECK: %[[VAL_17:.*]] = arith.constant 32 : index
-// CHECK: %[[VAL_18:.*]] = arith.constant 3 : index
-// CHECK: %[[VAL_19:.*]] = arith.constant 400 : index
-// CHECK: %[[VAL_20:.*]] = arith.constant 32 : i64
-// CHECK: %[[VAL_21:.*]] = arith.constant 32 : i64
-// CHECK: %[[VAL_22:.*]] = arith.cmpi eq, %[[VAL_20]], %[[VAL_21]] : i64
-// CHECK: cf.assert %[[VAL_22]], "mismatching contracting dimension"
-// CHECK: %[[VAL_23:.*]] = arith.constant 1 : i64
-// CHECK: %[[VAL_24:.*]] = arith.constant 1 : i64
-// CHECK: %[[VAL_25:.*]] = arith.constant 2 : i64
-// CHECK: %[[VAL_26:.*]] = arith.constant 2 : i64
-// CHECK: %[[VAL_27:.*]] = arith.constant 400 : i64
-// CHECK: %[[VAL_28:.*]] = arith.constant 32 : i64
-// CHECK: %[[VAL_29:.*]] = arith.constant 32 : i64
-// CHECK: %[[VAL_30:.*]] = arith.constant 400 : i64
-// CHECK: %[[VAL_31:.*]] = arith.constant 0 : i64
-// CHECK: %[[VAL_32:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_33:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_34:.*]] = tensor.empty() : tensor<1x2x400x32xf32>
-// CHECK: %[[VAL_35:.*]] = tensor.cast %[[VAL_1]] : tensor<1x2x400x32xf32> to tensor<1x2x400x32xf32>
-// CHECK: %[[VAL_36:.*]] = arith.constant 0 : i64
-// CHECK: %[[VAL_37:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_38:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_39:.*]] = tensor.empty() : tensor<1x2x32x400xf32>
-// CHECK: %[[VAL_40:.*]] = tensor.cast %[[VAL_0]] : tensor<1x2x32x400xf32> to tensor<1x2x32x400xf32>
-// CHECK: %[[VAL_41:.*]] = tensor.collapse_shape %[[VAL_35]] {{\[\[}}0, 1], [2], [3]] : tensor<1x2x400x32xf32> into tensor<2x400x32xf32>
-// CHECK: %[[VAL_42:.*]] = tensor.collapse_shape %[[VAL_40]] {{\[\[}}0, 1], [2], [3]] : tensor<1x2x32x400xf32> into tensor<2x32x400xf32>
-// CHECK: %[[VAL_43:.*]] = arith.constant 2 : index
-// CHECK: %[[VAL_44:.*]] = tensor.empty() : tensor<2x400x400xf32>
-// CHECK: %[[VAL_45:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_46:.*]] = linalg.fill ins(%[[VAL_45]] : f32) outs(%[[VAL_44]] : tensor<2x400x400xf32>) -> tensor<2x400x400xf32>
-// CHECK: %[[VAL_47:.*]] = linalg.batch_matmul ins(%[[VAL_41]], %[[VAL_42]] : tensor<2x400x32xf32>, tensor<2x32x400xf32>) outs(%[[VAL_46]] : tensor<2x400x400xf32>) -> tensor<2x400x400xf32>
-// CHECK: %[[VAL_48:.*]] = tensor.expand_shape %[[VAL_47]] {{\[\[}}0, 1], [2], [3]] output_shape [1, 2, 400, 400] : tensor<2x400x400xf32> into tensor<1x2x400x400xf32>
-// CHECK: %[[VAL_49:.*]] = tensor.cast %[[VAL_48]] : tensor<1x2x400x400xf32> to tensor<1x2x400x400xf32>
-// CHECK: %[[VAL_50:.*]] = torch_c.from_builtin_tensor %[[VAL_49]] : tensor<1x2x400x400xf32> -> !torch.vtensor<[1,2,400,400],f32>
-// CHECK: return %[[VAL_50]] : !torch.vtensor<[1,2,400,400],f32>
-// CHECK: }
+// CHECK-DAG: %[[LHS:.+]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,2,32,400],f32> -> tensor<1x2x32x400xf32>
+// CHECK-DAG: %[[RHS:.+]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[1,2,400,32],f32> -> tensor<1x2x400x32xf32>
+// CHECK-DAG: %[[LHS_CAST:.*]] = tensor.cast %[[LHS]] : tensor<1x2x32x400xf32> to tensor<1x2x32x400xf32>
+// CHECK-DAG: %[[RHS_CAST:.*]] = tensor.cast %[[RHS]] : tensor<1x2x400x32xf32> to tensor<1x2x400x32xf32>
+// CHECK-DAG: %[[COLLAPSED_LHS:.+]] = tensor.collapse_shape %[[LHS_CAST]] {{\[\[}}0, 1], [2], [3]] : tensor<1x2x32x400xf32> into tensor<2x32x400xf32>
+// CHECK-DAG: %[[COLLAPSED_RHS:.+]] = tensor.collapse_shape %[[RHS_CAST]] {{\[\[}}0, 1], [2], [3]] : tensor<1x2x400x32xf32> into tensor<2x400x32xf32>
+// CHECK: %[[MATMUL:.+]] = linalg.batch_matmul ins(%[[COLLAPSED_RHS]], %[[COLLAPSED_LHS]] : tensor<2x400x32xf32>, tensor<2x32x400xf32>) outs(%{{.*}} : tensor<2x400x400xf32>) -> tensor<2x400x400xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[MATMUL]] {{\[\[}}0, 1], [2], [3]] output_shape [1, 2, 400, 400] : tensor<2x400x400xf32> into tensor<1x2x400x400xf32>
 func.func @torch.aten.matmul.4d(%arg0: !torch.vtensor<[1,2,32,400],f32>, %arg1: !torch.vtensor<[1,2,400,32],f32>) -> !torch.vtensor<[1,2,400,400],f32> {
-    %0 = torch.aten.matmul %arg1, %arg0 : !torch.vtensor<[1,2,400,32],f32>, !torch.vtensor<[1,2,32,400],f32> -> !torch.vtensor<[1,2,400,400],f32>
-    return %0 : !torch.vtensor<[1,2,400,400],f32>
+  %0 = torch.aten.matmul %arg1, %arg0 : !torch.vtensor<[1,2,400,32],f32>, !torch.vtensor<[1,2,32,400],f32> -> !torch.vtensor<[1,2,400,400],f32>
+  return %0 : !torch.vtensor<[1,2,400,400],f32>
 }
 
 // -----
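
Note (illustrative, not part of the patches above): the upstream utility adopted in PATCH 2 is getConstantIntValue from mlir/Dialect/Utils/StaticValueUtils.h, which returns std::optional<int64_t> when a Value folds to an integer constant. A minimal sketch of the mapping the conversion relies on, with hypothetical helper and variable names:

  // Sketch only: map each dimension Value of a broadcast target shape to its
  // static extent when it folds to a constant, and to ShapedType::kDynamic
  // otherwise, so the resulting RankedTensorType keeps static sizes where
  // they are known.
  #include "llvm/ADT/ArrayRef.h"
  #include "llvm/ADT/STLExtras.h"
  #include "llvm/ADT/SmallVector.h"
  #include "mlir/Dialect/Utils/StaticValueUtils.h"
  #include "mlir/IR/BuiltinTypes.h"
  #include "mlir/IR/Value.h"

  static llvm::SmallVector<int64_t>
  getStaticTargetShape(llvm::ArrayRef<mlir::Value> dimValues) {
    return llvm::to_vector(llvm::map_range(dimValues, [](mlir::Value v) {
      return mlir::getConstantIntValue(v).value_or(mlir::ShapedType::kDynamic);
    }));
  }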