@@ -43,6 +43,22 @@ func.func @torch.aten.matmul.2d(%arg0: !torch.vtensor<[8,16],f32>, %arg1: !torch
 
 // -----
 
+// CHECK-LABEL: func.func @torch.aten.matmul.4d
+// CHECK-DAG: %[[LHS:.+]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,2,32,400],f32> -> tensor<1x2x32x400xf32>
+// CHECK-DAG: %[[RHS:.+]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[1,2,400,32],f32> -> tensor<1x2x400x32xf32>
+// CHECK-DAG: %[[LHS_CAST:.*]] = tensor.cast %[[LHS]] : tensor<1x2x32x400xf32> to tensor<1x2x32x400xf32>
+// CHECK-DAG: %[[RHS_CAST:.*]] = tensor.cast %[[RHS]] : tensor<1x2x400x32xf32> to tensor<1x2x400x32xf32>
+// CHECK-DAG: %[[COLLAPSED_LHS:.+]] = tensor.collapse_shape %[[LHS_CAST]] {{\[\[}}0, 1], [2], [3]] : tensor<1x2x32x400xf32> into tensor<2x32x400xf32>
+// CHECK-DAG: %[[COLLAPSED_RHS:.+]] = tensor.collapse_shape %[[RHS_CAST]] {{\[\[}}0, 1], [2], [3]] : tensor<1x2x400x32xf32> into tensor<2x400x32xf32>
+// CHECK: %[[MATMUL:.+]] = linalg.batch_matmul ins(%[[COLLAPSED_RHS]], %[[COLLAPSED_LHS]] : tensor<2x400x32xf32>, tensor<2x32x400xf32>) outs(%{{.*}} : tensor<2x400x400xf32>) -> tensor<2x400x400xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[MATMUL]] {{\[\[}}0, 1], [2], [3]] output_shape [1, 2, 400, 400] : tensor<2x400x400xf32> into tensor<1x2x400x400xf32>
+func.func @torch.aten.matmul.4d(%arg0: !torch.vtensor<[1,2,32,400],f32>, %arg1: !torch.vtensor<[1,2,400,32],f32>) -> !torch.vtensor<[1,2,400,400],f32> {
+  %0 = torch.aten.matmul %arg1, %arg0 : !torch.vtensor<[1,2,400,32],f32>, !torch.vtensor<[1,2,32,400],f32> -> !torch.vtensor<[1,2,400,400],f32>
+  return %0 : !torch.vtensor<[1,2,400,400],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @torch.aten.mm$basic_strict(
 // CHECK-NOT: assert
 func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32>
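For reference, the lowering shape the new CHECK lines encode is: collapse the two leading batch dimensions into a single batch dimension so the rank-4 matmul can be expressed as a rank-3 linalg.batch_matmul, then expand the result back to rank 4. Below is a hand-written sketch with the shapes taken from the test above; the function name and the tensor.empty/linalg.fill accumulator setup are assumptions for illustration, not output copied from the conversion pass:

func.func @matmul_4d_sketch(%lhs: tensor<1x2x400x32xf32>, %rhs: tensor<1x2x32x400xf32>) -> tensor<1x2x400x400xf32> {
  // Fold the [1, 2] batch dims into a single batch dim of size 2.
  %lhs3 = tensor.collapse_shape %lhs [[0, 1], [2], [3]] : tensor<1x2x400x32xf32> into tensor<2x400x32xf32>
  %rhs3 = tensor.collapse_shape %rhs [[0, 1], [2], [3]] : tensor<1x2x32x400xf32> into tensor<2x32x400xf32>
  // Assumed zero-initialized accumulator for the batched product.
  %cst = arith.constant 0.0 : f32
  %empty = tensor.empty() : tensor<2x400x400xf32>
  %init = linalg.fill ins(%cst : f32) outs(%empty : tensor<2x400x400xf32>) -> tensor<2x400x400xf32>
  // Rank-3 batched matmul: (2x400x32) x (2x32x400) -> 2x400x400.
  %mm = linalg.batch_matmul ins(%lhs3, %rhs3 : tensor<2x400x32xf32>, tensor<2x32x400xf32>)
                            outs(%init : tensor<2x400x400xf32>) -> tensor<2x400x400xf32>
  // Restore the original leading dims on the rank-3 result.
  %out = tensor.expand_shape %mm [[0, 1], [2], [3]] output_shape [1, 2, 400, 400] : tensor<2x400x400xf32> into tensor<1x2x400x400xf32>
  return %out : tensor<1x2x400x400xf32>
}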