
Commit 18e6b7f

[TorchToLinalg] Fix multi batch matmul conversion to Linalg (#4319)

Improve usage of static shape information to avoid an unnecessary broadcast.

Co-authored-by: [email protected]
1 parent fe14bf7 · commit 18e6b7f
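What the patch changes: the matmul lowering previously built lhsTargetShape/rhsTargetShape as all-dynamic (every entry ShapedType::kDynamic), so the broadcast target type discarded static sizes that were actually known. The new code folds each broadcast-shape Value back to a constant where possible, which lets the broadcast collapse to a no-op tensor.cast in the fully static case exercised by the new test. A minimal, self-contained sketch of that pattern follows; the helper name getStaticTargetShape is illustrative only and not part of the patch, which inlines the same logic at the two call sites in the diff below.

    // Sketch of the static-shape extraction pattern used by the patch: map each
    // broadcast-dimension Value to its constant extent when it is known,
    // otherwise fall back to ShapedType::kDynamic.
    #include "mlir/Dialect/Utils/StaticValueUtils.h" // mlir::getConstantIntValue
    #include "mlir/IR/BuiltinTypes.h"                // mlir::ShapedType
    #include "mlir/IR/Value.h"                       // mlir::Value
    #include "llvm/ADT/ArrayRef.h"                   // llvm::ArrayRef
    #include "llvm/ADT/STLExtras.h"                  // llvm::map_range
    #include "llvm/ADT/SmallVector.h"                // llvm::to_vector

    // Illustrative helper (hypothetical name): turn the broadcast-to shape, given
    // as index Values, into a static shape vector usable by RankedTensorType::get.
    static llvm::SmallVector<int64_t>
    getStaticTargetShape(llvm::ArrayRef<mlir::Value> broadcastToShape) {
      return llvm::to_vector(llvm::map_range(broadcastToShape, [](mlir::Value v) {
        // Known constant dimension -> static size; anything else stays dynamic.
        return mlir::getConstantIntValue(v).value_or(mlir::ShapedType::kDynamic);
      }));
    }

With fully static target shapes, the broadcast type becomes something like tensor<1x2x32x400xf32> rather than tensor<?x?x?x?xf32>, and when the operand already has that shape the broadcast reduces to a trivial tensor.cast, as the new lit test below checks.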

File tree

2 files changed: +25 −5 lines changed


lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 9 additions & 5 deletions
@@ -505,9 +505,11 @@ class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
 
       // Broadcast the batch dimensions of both the matrices.
       Value broadcastedLhs, broadcastedRhs;
-      // TODO: Improve usage of static shape information.
-      SmallVector<int64_t> lhsTargetShape(lhsBroadcastToShape.size(),
-                                          ShapedType::kDynamic);
+      SmallVector<int64_t> lhsTargetShape =
+          llvm::to_vector(llvm::map_range(lhsBroadcastToShape, [](Value v) {
+            return getConstantIntValue(v).value_or(ShapedType::kDynamic);
+          }));
+
       auto lhsBroadcastType = RankedTensorType::get(
           lhsTargetShape, lhsType.getElementType(), lhsType.getEncoding());
       if (failed(torch_to_linalg::broadcastToGivenShape(
@@ -516,8 +518,10 @@ class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
         return rewriter.notifyMatchFailure(
             op, "unable to perform broadcast operation");
       }
-      SmallVector<int64_t> rhsTargetShape(rhsBroadcastToShape.size(),
-                                          ShapedType::kDynamic);
+      SmallVector<int64_t> rhsTargetShape =
+          llvm::to_vector(llvm::map_range(rhsBroadcastToShape, [](Value v) {
+            return getConstantIntValue(v).value_or(ShapedType::kDynamic);
+          }));
       auto rhsBroadcastType = RankedTensorType::get(
           rhsTargetShape, rhsType.getElementType(), rhsType.getEncoding());
       if (failed(torch_to_linalg::broadcastToGivenShape(

test/Conversion/TorchToLinalg/basic.mlir

Lines changed: 16 additions & 0 deletions
@@ -43,6 +43,22 @@ func.func @torch.aten.matmul.2d(%arg0: !torch.vtensor<[8,16],f32>, %arg1: !torch
 
 // -----
 
+// CHECK-LABEL: func.func @torch.aten.matmul.4d
+// CHECK-DAG: %[[LHS:.+]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,2,32,400],f32> -> tensor<1x2x32x400xf32>
+// CHECK-DAG: %[[RHS:.+]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[1,2,400,32],f32> -> tensor<1x2x400x32xf32>
+// CHECK-DAG: %[[LHS_CAST:.*]] = tensor.cast %[[LHS]] : tensor<1x2x32x400xf32> to tensor<1x2x32x400xf32>
+// CHECK-DAG: %[[RHS_CAST:.*]] = tensor.cast %[[RHS]] : tensor<1x2x400x32xf32> to tensor<1x2x400x32xf32>
+// CHECK-DAG: %[[COLLAPSED_LHS:.+]] = tensor.collapse_shape %[[LHS_CAST]] {{\[\[}}0, 1], [2], [3]] : tensor<1x2x32x400xf32> into tensor<2x32x400xf32>
+// CHECK-DAG: %[[COLLAPSED_RHS:.+]] = tensor.collapse_shape %[[RHS_CAST]] {{\[\[}}0, 1], [2], [3]] : tensor<1x2x400x32xf32> into tensor<2x400x32xf32>
+// CHECK: %[[MATMUL:.+]] = linalg.batch_matmul ins(%[[COLLAPSED_RHS]], %[[COLLAPSED_LHS]] : tensor<2x400x32xf32>, tensor<2x32x400xf32>) outs(%{{.*}} : tensor<2x400x400xf32>) -> tensor<2x400x400xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[MATMUL]] {{\[\[}}0, 1], [2], [3]] output_shape [1, 2, 400, 400] : tensor<2x400x400xf32> into tensor<1x2x400x400xf32>
+func.func @torch.aten.matmul.4d(%arg0: !torch.vtensor<[1,2,32,400],f32>, %arg1: !torch.vtensor<[1,2,400,32],f32>) -> !torch.vtensor<[1,2,400,400],f32> {
+  %0 = torch.aten.matmul %arg1, %arg0 : !torch.vtensor<[1,2,400,32],f32>, !torch.vtensor<[1,2,32,400],f32> -> !torch.vtensor<[1,2,400,400],f32>
+  return %0 : !torch.vtensor<[1,2,400,400],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @torch.aten.mm$basic_strict(
 // CHECK-NOT: assert
 func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32>
