gpu.module @create_nd_tdesc {
// CHECK-LABEL: gpu.func @create_nd_tdesc
- // CHECK-SAME: %[[ARG0:.*]]: memref<8x16xf32, 1>, %[[ARG1:.*]]: ui64,
+ // CHECK-SAME: %[[ARG0:.*]]: memref<16x32xf32, 1>, %[[ARG1:.*]]: ui64,
// CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index, %[[ARG7:.*]]: index
- gpu.func @create_nd_tdesc(%src: memref<8x16xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
+ gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
%stride1: index, %stride2: index, %offset1: index, %offset2: index) kernel {
// CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
// CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
@@ -23,35 +23,35 @@ gpu.module @create_nd_tdesc {
%ptr_tdesc = xegpu.create_nd_tdesc %ptr, shape:[%shape1, %shape2], strides:[%stride1, %stride2]
: ui64 -> !xegpu.tensor_desc<8x16xf32>

- // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<8x16xf32, 1> to memref<8x16xf32>
- %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
+ // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<16x32xf32, 1> to memref<16x32xf32>
+ %srcce = memref.memory_space_cast %src : memref<16x32xf32, 1> to memref<16x32xf32>

// CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
- // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
// CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
// CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
+ // CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64
+ // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %[[C32_I64]] : i64 to i32
// CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
- // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %c16_i64 : i64 to i32
- // CHECK: %[[C8_I64:.*]] = arith.constant 8 : i64
- // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %c8_i64 : i64 to i32
+ // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
// CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
// CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
// CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
// CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
// CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
// CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
- // CHECK: %[[VAR20:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
- %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+ // CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
+ %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>

// CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
- // CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
+ // CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
// CHECK: %[[OFFSET_W3:.*]] = arith.index_cast %[[ARG7]] : index to i32
// CHECK: %[[OFFSET_H3:.*]] = arith.index_cast %[[ARG6]] : index to i32
- // CHECK: %[[C16_I64_6:.*]] = arith.constant 16 : i64
- // CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C16_I64_6]] : i64 to i32
- // CHECK: %[[C8_I64_7:.*]] = arith.constant 8 : i64
- // CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C8_I64_7]] : i64 to i32
+ // CHECK: %[[C32_I64_6:.*]] = arith.constant 32 : i64
+ // CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C32_I64_6]] : i64 to i32
+ // CHECK: %[[C16_I64_7:.*]] = arith.constant 16 : i64
+ // CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C16_I64_7]] : i64 to i32
// CHECK: %[[BASE_ADDR3:.*]] = arith.index_castui %[[INTPTR_2]] : index to i64
// CHECK: %[[VAR26:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[VAR27:.*]] = vector.insert %[[BASE_ADDR3]], %[[VAR26]] [0] : i64 into vector<4xi64>
@@ -60,7 +60,21 @@ gpu.module @create_nd_tdesc {
// CHECK: %[[VAR30:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR29]] [3] : i32 into vector<8xi32>
// CHECK: %[[VAR31:.*]] = vector.insert %[[OFFSET_W3]], %[[VAR30]] [4] : i32 into vector<8xi32>
// CHECK: %[[VAR32:.*]] = vector.insert %[[OFFSET_H3]], %[[VAR31]] [5] : i32 into vector<8xi32>
- %src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+ %src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+ // CHECK: %[[C8:.*]] = arith.constant 8 : index
+ %c8 = arith.constant 8 : index
+ // CHECK: %[[C16:.*]] = arith.constant 16 : index
+ %c16 = arith.constant 16 : index
+ // CHECK: %[[VAR33:.*]] = arith.index_cast %[[C8]] : index to i32
+ // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[PAYLOAD]][5] : i32 from vector<8xi32>
+ // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR33]] : i32
+ // CHECK: %[[NEW_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[PAYLOAD]] [5] : i32 into vector<8xi32>
+ // CHECK: %[[VAR37:.*]] = arith.index_cast %[[C16]] : index to i32
+ // CHECK: %[[OLD_OFFSET_W:.*]] = vector.extract %[[NEW_PAYLOAD]][4] : i32 from vector<8xi32>
+ // CHECK: %[[NEW_OFFSET_W:.*]] = arith.addi %[[OLD_OFFSET_W]], %[[VAR37]] : i32
+ // CHECK: %[[FINAL_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_W]], %[[NEW_PAYLOAD]] [4] : i32 into vector<8xi32>
+ %updated_tdesc = xegpu.update_nd_offset %src_tdesc, [%c8, %c16] : !xegpu.tensor_desc<8x16xf32>
gpu.return
}
}