From 5a2ef29d93ede034b0e651bba032e78cd42c62f0 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Tue, 30 Sep 2025 21:10:00 +0000 Subject: [PATCH] Use operand layouts for store scatter --- .../Transforms/XeGPUWgToSgDistribute.cpp | 24 +++++++++++-------- .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 19 +++++++++------ 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 9413a9296b184..784e5d68ce885 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -824,7 +824,7 @@ struct WgToSgStoreScatterOpWithOffset return failure(); xegpu::DistributeLayoutAttr layout = - xegpu::getDistributeLayoutAttr(op.getValue()); + xegpu::getDistributeLayoutAttr(op.getOperand(0)); if (!layout || !layout.isForWorkgroup()) return failure(); @@ -844,12 +844,19 @@ struct WgToSgStoreScatterOpWithOffset auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize); for (auto [val, offs, mask] : llvm::zip( adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) { - xegpu::StoreScatterOp::create(rewriter, loc, val, op.getDest(), offs, - mask, chunkSizeAttr, op.getL1HintAttr(), - op.getL2HintAttr(), op.getL3HintAttr()); + auto store = xegpu::StoreScatterOp::create( + rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr, + op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); // Update the layout attribute to drop sg_layout and sg_data. - if (auto newLayout = layout.dropSgLayoutAndData()) - op->setAttr("layout", newLayout); + if (!layout.getEffectiveLaneLayoutAsInt().empty() || + !layout.getEffectiveInstDataAsInt().empty()) { + for (OpOperand &operand : store->getOpOperands()) { + // Skip for operand one (memref) + if (operand.getOperandNumber() == 1) + continue; + xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData()); + } + } } rewriter.eraseOp(op); return success(); @@ -1247,10 +1254,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() { target.addDynamicallyLegalOp( [=](xegpu::StoreScatterOp op) -> bool { - // Check if the layout attribute is present on the result. - auto layout = op->getAttrOfType("layout"); - if (!layout) - return true; + auto layout = xegpu::getDistributeLayoutAttr(op.getOperand(0)); return isLegal(layout); }); diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir index 03c63861705d9..38392fd10b742 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir @@ -282,15 +282,20 @@ gpu.module @test_distribution { // CHECK-LABEL: @store_scatter // CHECK-SAME: %[[ARG0:.*]]: memref<256xf16> gpu.func @store_scatter(%dest : memref<256xf16>) { - // CHECK: %[[VAL:.*]] = arith.constant dense<2.550000e+01> : vector<8xf16> - // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xindex> - // CHECK: %[[MASK:.*]] = arith.constant dense : vector<8xi1> + // CHECK: %[[VAL:.*]] = arith.constant {layout_result_0 = #xegpu.layout} dense<2.550000e+01> : vector<8xf16> + // CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout} dense<0> : vector<8xindex> + // CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout} dense : vector<8xi1> // CHECK: xegpu.store %[[VAL]], %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint}> + // CHECK-SAME: {layout_operand_0 = #xegpu.layout, layout_operand_2 = #xegpu.layout, + // CHECK-SAME: layout_operand_3 = #xegpu.layout} // CHECK-SAME: : vector<8xf16>, memref<256xf16>, vector<8xindex>, vector<8xi1> - %val = arith.constant {layout_result_0 = #xegpu.layout} dense<25.5> : vector<256xf16> - %offset = arith.constant {layout_result_0 = #xegpu.layout} dense<0> : vector<256xindex> - %mask = arith.constant {layout_result_0 = #xegpu.layout} dense<1> : vector<256xi1> - xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout = #xegpu.layout, l1_hint = #xegpu.cache_hint} + %val = arith.constant {layout_result_0 = #xegpu.layout} dense<25.5> : vector<256xf16> + %offset = arith.constant {layout_result_0 = #xegpu.layout} dense<0> : vector<256xindex> + %mask = arith.constant {layout_result_0 = #xegpu.layout} dense<1> : vector<256xi1> + xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout_operand_0 = #xegpu.layout, + layout_operand_2 = #xegpu.layout, + layout_operand_3 = #xegpu.layout, + l1_hint = #xegpu.cache_hint} : vector<256xf16>, memref<256xf16>, vector<256xindex>, vector<256xi1> gpu.return }