Commit ec5ca99

simplify test

1 parent: d35aadf

File tree

1 file changed: +10 additions, -58 deletions

test/Conversion/TorchToLinalg/basic.mlir

Lines changed: 10 additions & 58 deletions
@@ -44,65 +44,17 @@ func.func @torch.aten.matmul.2d(%arg0: !torch.vtensor<[8,16],f32>, %arg1: !torch
 // -----
 
 // CHECK-LABEL: func.func @torch.aten.matmul.4d
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,2,32,400],f32>,
-// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[1,2,400,32],f32>) -> !torch.vtensor<[1,2,400,400],f32> {
-// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1,2,32,400],f32> -> tensor<1x2x32x400xf32>
-// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[1,2,400,32],f32> -> tensor<1x2x400x32xf32>
-// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_3:.*]] = arith.constant 2 : index
-// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_5:.*]] = arith.constant 2 : index
-// CHECK: %[[VAL_6:.*]] = arith.constant 2 : index
-// CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_9:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_10:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_11:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_12:.*]] = arith.constant 2 : index
-// CHECK: %[[VAL_13:.*]] = arith.constant 400 : index
-// CHECK: %[[VAL_14:.*]] = arith.constant 3 : index
-// CHECK: %[[VAL_15:.*]] = arith.constant 32 : index
-// CHECK: %[[VAL_16:.*]] = arith.constant 2 : index
-// CHECK: %[[VAL_17:.*]] = arith.constant 32 : index
-// CHECK: %[[VAL_18:.*]] = arith.constant 3 : index
-// CHECK: %[[VAL_19:.*]] = arith.constant 400 : index
-// CHECK: %[[VAL_20:.*]] = arith.constant 32 : i64
-// CHECK: %[[VAL_21:.*]] = arith.constant 32 : i64
-// CHECK: %[[VAL_22:.*]] = arith.cmpi eq, %[[VAL_20]], %[[VAL_21]] : i64
-// CHECK: cf.assert %[[VAL_22]], "mismatching contracting dimension"
-// CHECK: %[[VAL_23:.*]] = arith.constant 1 : i64
-// CHECK: %[[VAL_24:.*]] = arith.constant 1 : i64
-// CHECK: %[[VAL_25:.*]] = arith.constant 2 : i64
-// CHECK: %[[VAL_26:.*]] = arith.constant 2 : i64
-// CHECK: %[[VAL_27:.*]] = arith.constant 400 : i64
-// CHECK: %[[VAL_28:.*]] = arith.constant 32 : i64
-// CHECK: %[[VAL_29:.*]] = arith.constant 32 : i64
-// CHECK: %[[VAL_30:.*]] = arith.constant 400 : i64
-// CHECK: %[[VAL_31:.*]] = arith.constant 0 : i64
-// CHECK: %[[VAL_32:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_33:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_34:.*]] = tensor.empty() : tensor<1x2x400x32xf32>
-// CHECK: %[[VAL_35:.*]] = tensor.cast %[[VAL_1]] : tensor<1x2x400x32xf32> to tensor<1x2x400x32xf32>
-// CHECK: %[[VAL_36:.*]] = arith.constant 0 : i64
-// CHECK: %[[VAL_37:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_38:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_39:.*]] = tensor.empty() : tensor<1x2x32x400xf32>
-// CHECK: %[[VAL_40:.*]] = tensor.cast %[[VAL_0]] : tensor<1x2x32x400xf32> to tensor<1x2x32x400xf32>
-// CHECK: %[[VAL_41:.*]] = tensor.collapse_shape %[[VAL_35]] {{\[\[}}0, 1], [2], [3]] : tensor<1x2x400x32xf32> into tensor<2x400x32xf32>
-// CHECK: %[[VAL_42:.*]] = tensor.collapse_shape %[[VAL_40]] {{\[\[}}0, 1], [2], [3]] : tensor<1x2x32x400xf32> into tensor<2x32x400xf32>
-// CHECK: %[[VAL_43:.*]] = arith.constant 2 : index
-// CHECK: %[[VAL_44:.*]] = tensor.empty() : tensor<2x400x400xf32>
-// CHECK: %[[VAL_45:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_46:.*]] = linalg.fill ins(%[[VAL_45]] : f32) outs(%[[VAL_44]] : tensor<2x400x400xf32>) -> tensor<2x400x400xf32>
-// CHECK: %[[VAL_47:.*]] = linalg.batch_matmul ins(%[[VAL_41]], %[[VAL_42]] : tensor<2x400x32xf32>, tensor<2x32x400xf32>) outs(%[[VAL_46]] : tensor<2x400x400xf32>) -> tensor<2x400x400xf32>
-// CHECK: %[[VAL_48:.*]] = tensor.expand_shape %[[VAL_47]] {{\[\[}}0, 1], [2], [3]] output_shape [1, 2, 400, 400] : tensor<2x400x400xf32> into tensor<1x2x400x400xf32>
-// CHECK: %[[VAL_49:.*]] = tensor.cast %[[VAL_48]] : tensor<1x2x400x400xf32> to tensor<1x2x400x400xf32>
-// CHECK: %[[VAL_50:.*]] = torch_c.from_builtin_tensor %[[VAL_49]] : tensor<1x2x400x400xf32> -> !torch.vtensor<[1,2,400,400],f32>
-// CHECK: return %[[VAL_50]] : !torch.vtensor<[1,2,400,400],f32>
-// CHECK: }
+// CHECK-DAG: %[[LHS:.+]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,2,32,400],f32> -> tensor<1x2x32x400xf32>
+// CHECK-DAG: %[[RHS:.+]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[1,2,400,32],f32> -> tensor<1x2x400x32xf32>
+// CHECK-DAG: %[[LHS_CAST:.*]] = tensor.cast %[[LHS]] : tensor<1x2x32x400xf32> to tensor<1x2x32x400xf32>
+// CHECK-DAG: %[[RHS_CAST:.*]] = tensor.cast %[[RHS]] : tensor<1x2x400x32xf32> to tensor<1x2x400x32xf32>
+// CHECK-DAG: %[[COLLAPSED_LHS:.+]] = tensor.collapse_shape %[[LHS_CAST]] {{\[\[}}0, 1], [2], [3]] : tensor<1x2x32x400xf32> into tensor<2x32x400xf32>
+// CHECK-DAG: %[[COLLAPSED_RHS:.+]] = tensor.collapse_shape %[[RHS_CAST]] {{\[\[}}0, 1], [2], [3]] : tensor<1x2x400x32xf32> into tensor<2x400x32xf32>
+// CHECK: %[[MATMUL:.+]] = linalg.batch_matmul ins(%[[COLLAPSED_RHS]], %[[COLLAPSED_LHS]] : tensor<2x400x32xf32>, tensor<2x32x400xf32>) outs(%{{.*}} : tensor<2x400x400xf32>) -> tensor<2x400x400xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[MATMUL]] {{\[\[}}0, 1], [2], [3]] output_shape [1, 2, 400, 400] : tensor<2x400x400xf32> into tensor<1x2x400x400xf32>
 func.func @torch.aten.matmul.4d(%arg0: !torch.vtensor<[1,2,32,400],f32>, %arg1: !torch.vtensor<[1,2,400,32],f32>) -> !torch.vtensor<[1,2,400,400],f32> {
-  %0 = torch.aten.matmul %arg1, %arg0 : !torch.vtensor<[1,2,400,32],f32>, !torch.vtensor<[1,2,32,400],f32> -> !torch.vtensor<[1,2,400,400],f32>
-  return %0 : !torch.vtensor<[1,2,400,400],f32>
+  %0 = torch.aten.matmul %arg1, %arg0 : !torch.vtensor<[1,2,400,32],f32>, !torch.vtensor<[1,2,32,400],f32> -> !torch.vtensor<[1,2,400,400],f32>
+  return %0 : !torch.vtensor<[1,2,400,400],f32>
 }
 
 // -----
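For context on the simplified CHECK lines: plain CHECK directives must match in order, and the old test captured every intermediate SSA value the lowering emits, including dozens of arith.constant ops, so unrelated changes to constant materialization would break it. CHECK-DAG directives in a group may match in any order, and the rewritten test only names the values that carry the lowering's semantics: the torch_c.to_builtin_tensor conversions, the tensor.cast and tensor.collapse_shape ops, the linalg.batch_matmul, and the tensor.expand_shape back to rank 4. Below is a minimal sketch of the CHECK-DAG behavior against hypothetical pass output; the RUN line is an assumption about how this test file is typically driven, not something taken from this commit.

// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s

// Hypothetical output: both constants must appear, but in either order,
// so the check survives reorderings of value materialization.
// CHECK-DAG: %[[C400:.+]] = arith.constant 400 : index
// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
// CHECK: linalg.batch_matmul

Matching only the named values from the diff (%[[LHS]], %[[RHS]], %[[MATMUL]], %[[EXPANDED]]) keeps the test focused on the shape of the matmul lowering rather than on incidental constants.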
