Commit d943efb

[MLIR] Fix issues with XeGPU to XeVM pass. (#155946)
Fixes two issues with the XeGPU to XeVM pass:
1. Lowering of the xegpu.update_nd_offset op generated an incorrect code sequence.
2. Lowering of xegpu.store_nd did not handle a single-element vector store value.
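
For orientation, a minimal sketch of the kind of IR touched by these two fixes; the value names, the vector<1xf32>/tensor_desc<1xf32> types, and the standalone store below are illustrative assumptions, not code from this patch:

    // Issue 1: advance an existing tensor descriptor by (8, 16) elements.
    %c8 = arith.constant 8 : index
    %c16 = arith.constant 16 : index
    %updated = xegpu.update_nd_offset %tdesc, [%c8, %c16] : !xegpu.tensor_desc<8x16xf32>

    // Issue 2 (assumed shape): store a single-element vector through a descriptor;
    // the type converter may narrow the vector<1xf32> operand to a plain f32 scalar.
    xegpu.store_nd %val, %desc : vector<1xf32>, !xegpu.tensor_desc<1xf32>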
1 parent 288442f commit d943efb

File tree: 2 files changed, 45 additions (+) and 26 deletions (-)

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp (15 additions, 10 deletions)

@@ -259,23 +259,23 @@ class UpdateNdOffsetToXeVMPattern
     // Only 2D offsets are supported for now.
     if (mixedOffsets.size() != 2)
       return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
-    auto tdesc = adaptor.getTensorDesc();
+    auto payload = adaptor.getTensorDesc();
     // Utility for updating payload offset values from op fold result.
     auto updateOffset = [&](unsigned idx, int payloadPos) -> Value {
       Value offset =
           getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[idx]);
       offset = getValueOrCreateCastToIndexLike(rewriter, loc,
                                                rewriter.getI32Type(), offset);
       Value oldOffset =
-          vector::ExtractOp::create(rewriter, loc, tdesc, payloadPos);
+          vector::ExtractOp::create(rewriter, loc, payload, payloadPos);
       Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
-      return vector::InsertOp::create(rewriter, loc, newOffset, tdesc,
+      return vector::InsertOp::create(rewriter, loc, newOffset, payload,
                                       payloadPos);
     };
     // Update offsets in the payload.
-    auto val = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
-    val = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
-    rewriter.replaceOp(op, val);
+    payload = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
+    payload = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
+    rewriter.replaceOp(op, payload);
     return success();
   }
 };
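
The effect of the change above: with the old code, the second call to updateOffset read from and inserted into the original tensor-descriptor payload, so the first offset update was dropped; reassigning payload makes the second update chain on the first. A sketch of the resulting lowered sequence, with SSA names chosen for illustration (%step_h and %step_w stand for the i32-cast step values; the test below checks the same shape of IR with FileCheck captures):

    // Payload update for xegpu.update_nd_offset %tdesc, [%c8, %c16]:
    // update the H offset, then update the W offset on the *updated* payload.
    %off_h_old = vector.extract %payload[5] : i32 from vector<8xi32>
    %off_h_new = arith.addi %off_h_old, %step_h : i32
    %payload_1 = vector.insert %off_h_new, %payload [5] : i32 into vector<8xi32>
    %off_w_old = vector.extract %payload_1[4] : i32 from vector<8xi32>
    %off_w_new = arith.addi %off_w_old, %step_w : i32
    %payload_2 = vector.insert %off_w_new, %payload_1 [4] : i32 into vector<8xi32>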

@@ -354,18 +354,23 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
     auto tileH = tdescTy.getDimSize(0);
     int32_t vblocks = tdescTy.getArrayLength();
     if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
-      VectorType srcVecTy = dyn_cast<VectorType>(adaptor.getValue().getType());
+      Value src = adaptor.getValue();
+      // If store value is a scalar, get value from op instead of adaptor.
+      // Adaptor might have optimized away single element vector.
+      if (src.getType().isIntOrFloat()) {
+        src = op.getValue();
+      }
+      VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
       if (!srcVecTy)
         return rewriter.notifyMatchFailure(
             op, "Expected store value to be a vector type.");
-      auto storeCacheControl =
-          translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
-      Value src = adaptor.getValue();
       // Get flat vector type of integer type with matching element bit size.
       VectorType newSrcVecTy =
           encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
       if (srcVecTy != newSrcVecTy)
         src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+      auto storeCacheControl =
+          translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
       xevm::BlockStore2dOp::create(
           rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
           offsetH, elemBitSize, tileW, tileH, src,

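A hedged illustration of the single-element case the xegpu.store_nd change addresses; the vector<1xf32> shape and the value names are assumptions, not taken from this patch:

    // Store of a single-element vector (assumed shape). During conversion the
    // adaptor may hand the pattern a plain f32 for %v, so the old
    // dyn_cast<VectorType> on the adaptor value's type returned null and the
    // lowering bailed out; the pattern now falls back to op.getValue() to
    // recover the vector-typed store value.
    xegpu.store_nd %v, %desc : vector<1xf32>, !xegpu.tensor_desc<1xf32>
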
mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir (30 additions, 16 deletions)

@@ -2,9 +2,9 @@

 gpu.module @create_nd_tdesc {
   // CHECK-LABEL: gpu.func @create_nd_tdesc
-  // CHECK-SAME: %[[ARG0:.*]]: memref<8x16xf32, 1>, %[[ARG1:.*]]: ui64,
+  // CHECK-SAME: %[[ARG0:.*]]: memref<16x32xf32, 1>, %[[ARG1:.*]]: ui64,
   // CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index, %[[ARG7:.*]]: index
-  gpu.func @create_nd_tdesc(%src: memref<8x16xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
+  gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
       %stride1: index, %stride2: index, %offset1: index, %offset2: index) kernel {
     // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
     // CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64

@@ -23,35 +23,35 @@ gpu.module @create_nd_tdesc {
     %ptr_tdesc = xegpu.create_nd_tdesc %ptr, shape:[%shape1, %shape2], strides:[%stride1, %stride2]
         : ui64 -> !xegpu.tensor_desc<8x16xf32>

-    // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<8x16xf32, 1> to memref<8x16xf32>
-    %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
+    // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<16x32xf32, 1> to memref<16x32xf32>
+    %srcce = memref.memory_space_cast %src : memref<16x32xf32, 1> to memref<16x32xf32>

     // CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
-    // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
+    // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
     // CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
     // CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
+    // CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64
+    // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %[[C32_I64]] : i64 to i32
     // CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
-    // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %c16_i64 : i64 to i32
-    // CHECK: %[[C8_I64:.*]] = arith.constant 8 : i64
-    // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %c8_i64 : i64 to i32
+    // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
     // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
     // CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
     // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
     // CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
     // CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
     // CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
     // CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
-    // CHECK: %[[VAR20:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
-    %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+    // CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
+    %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>

     // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
-    // CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
+    // CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
     // CHECK: %[[OFFSET_W3:.*]] = arith.index_cast %[[ARG7]] : index to i32
     // CHECK: %[[OFFSET_H3:.*]] = arith.index_cast %[[ARG6]] : index to i32
-    // CHECK: %[[C16_I64_6:.*]] = arith.constant 16 : i64
-    // CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C16_I64_6]] : i64 to i32
-    // CHECK: %[[C8_I64_7:.*]] = arith.constant 8 : i64
-    // CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C8_I64_7]] : i64 to i32
+    // CHECK: %[[C32_I64_6:.*]] = arith.constant 32 : i64
+    // CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C32_I64_6]] : i64 to i32
+    // CHECK: %[[C16_I64_7:.*]] = arith.constant 16 : i64
+    // CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C16_I64_7]] : i64 to i32
     // CHECK: %[[BASE_ADDR3:.*]] = arith.index_castui %[[INTPTR_2]] : index to i64
     // CHECK: %[[VAR26:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
     // CHECK: %[[VAR27:.*]] = vector.insert %[[BASE_ADDR3]], %[[VAR26]] [0] : i64 into vector<4xi64>

@@ -60,7 +60,21 @@ gpu.module @create_nd_tdesc {
     // CHECK: %[[VAR30:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR29]] [3] : i32 into vector<8xi32>
     // CHECK: %[[VAR31:.*]] = vector.insert %[[OFFSET_W3]], %[[VAR30]] [4] : i32 into vector<8xi32>
     // CHECK: %[[VAR32:.*]] = vector.insert %[[OFFSET_H3]], %[[VAR31]] [5] : i32 into vector<8xi32>
-    %src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+    %src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+    // CHECK: %[[C8:.*]] = arith.constant 8 : index
+    %c8 = arith.constant 8 : index
+    // CHECK: %[[C16:.*]] = arith.constant 16 : index
+    %c16 = arith.constant 16 : index
+    // CHECK: %[[VAR33:.*]] = arith.index_cast %[[C8]] : index to i32
+    // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[PAYLOAD]][5] : i32 from vector<8xi32>
+    // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR33]] : i32
+    // CHECK: %[[NEW_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[PAYLOAD]] [5] : i32 into vector<8xi32>
+    // CHECK: %[[VAR37:.*]] = arith.index_cast %[[C16]] : index to i32
+    // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[NEW_PAYLOAD]][4] : i32 from vector<8xi32>
+    // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR37]] : i32
+    // CHECK: %[[FINAL_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[NEW_PAYLOAD]] [4] : i32 into vector<8xi32>
+    %updated_tdesc = xegpu.update_nd_offset %src_tdesc, [%c8, %c16] : !xegpu.tensor_desc<8x16xf32>
     gpu.return
   }
 }
