Skip to content

Conversation

nbpatel
Copy link
Contributor

@nbpatel nbpatel commented Sep 30, 2025

The PR adds a change to use the layouts from the operands since store doesn't have a result

@llvmbot
Copy link
Member

llvmbot commented Sep 30, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Nishant Patel (nbpatel)

Changes

The PR adds a change to use the layouts from the operands since store doesn't have a result


Full diff: https://github.com/llvm/llvm-project/pull/161447.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+14-10)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir (+12-7)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 9413a9296b184..784e5d68ce885 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -824,7 +824,7 @@ struct WgToSgStoreScatterOpWithOffset
       return failure();
 
     xegpu::DistributeLayoutAttr layout =
-        xegpu::getDistributeLayoutAttr(op.getValue());
+        xegpu::getDistributeLayoutAttr(op.getOperand(0));
     if (!layout || !layout.isForWorkgroup())
       return failure();
 
@@ -844,12 +844,19 @@ struct WgToSgStoreScatterOpWithOffset
     auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
     for (auto [val, offs, mask] : llvm::zip(
              adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
-      xegpu::StoreScatterOp::create(rewriter, loc, val, op.getDest(), offs,
-                                    mask, chunkSizeAttr, op.getL1HintAttr(),
-                                    op.getL2HintAttr(), op.getL3HintAttr());
+      auto store = xegpu::StoreScatterOp::create(
+          rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr,
+          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
       // Update the layout attribute to drop sg_layout and sg_data.
-      if (auto newLayout = layout.dropSgLayoutAndData())
-        op->setAttr("layout", newLayout);
+      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+          !layout.getEffectiveInstDataAsInt().empty()) {
+        for (OpOperand &operand : store->getOpOperands()) {
+          // Skip for operand one (memref)
+          if (operand.getOperandNumber() == 1)
+            continue;
+          xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData());
+        }
+      }
     }
     rewriter.eraseOp(op);
     return success();
@@ -1247,10 +1254,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
 
   target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
       [=](xegpu::StoreScatterOp op) -> bool {
-        // Check if the layout attribute is present on the result.
-        auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout");
-        if (!layout)
-          return true;
+        auto layout = xegpu::getDistributeLayoutAttr(op.getOperand(0));
         return isLegal(layout);
       });
 
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 03c63861705d9..38392fd10b742 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -282,15 +282,20 @@ gpu.module @test_distribution {
   // CHECK-LABEL: @store_scatter
   // CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>
   gpu.func @store_scatter(%dest : memref<256xf16>) {
-    // CHECK: %[[VAL:.*]] = arith.constant dense<2.550000e+01> : vector<8xf16>
-    // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xindex>
-    // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8xi1>
+    // CHECK: %[[VAL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8]>} dense<2.550000e+01> : vector<8xf16>
+    // CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8]>} dense<0> : vector<8xindex>
+    // CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8]>} dense<true> : vector<8xi1>
     // CHECK: xegpu.store %[[VAL]], %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}>
+    // CHECK-SAME: {layout_operand_0 = #xegpu.layout<inst_data = [8]>, layout_operand_2 = #xegpu.layout<inst_data = [8]>,
+    // CHECK-SAME: layout_operand_3 = #xegpu.layout<inst_data = [8]>}
     // CHECK-SAME: : vector<8xf16>, memref<256xf16>, vector<8xindex>, vector<8xi1>
-    %val = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<25.5> : vector<256xf16>
-    %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<0> : vector<256xindex>
-    %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<1> : vector<256xi1>
-    xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout = #xegpu.layout<sg_layout = [32], sg_data = [8]>, l1_hint = #xegpu.cache_hint<cached>}
+    %val = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<25.5> : vector<256xf16>
+    %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<0> : vector<256xindex>
+    %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<1> : vector<256xi1>
+    xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout_operand_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>, 
+                                             layout_operand_2 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>,
+                                             layout_operand_3 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>,
+                                             l1_hint = #xegpu.cache_hint<cached>}
       : vector<256xf16>, memref<256xf16>, vector<256xindex>, vector<256xi1>
     gpu.return
   }

return failure();

xegpu::DistributeLayoutAttr layout =
xegpu::getDistributeLayoutAttr(op.getValue());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should also work. why this change is needed?

i.e, these two calls should return the same layout.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, it should work as well, but I changed it just to be consistent to call getter and setter on operand/s for store..

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants