diff --git a/include/triton-shared/Conversion/TritonToUnstructured/AtomicOpsConverter.h b/include/triton-shared/Conversion/TritonToUnstructured/AtomicOpsConverter.h
new file mode 100644
index 0000000..f2b5be2
--- /dev/null
+++ b/include/triton-shared/Conversion/TritonToUnstructured/AtomicOpsConverter.h
@@ -0,0 +1,175 @@
+//===- AtomicOpsConverter.h -----------------------------------------------===//
+//
+// Atomic op lowering patterns for the triton-shared pipeline.
+//
+// Pipeline placement
+// ------------------
+// Canonicalizers (Phase-1):
+//   Registered in TritonToStructured or as a standalone pre-pass.
+//   These rewrites run BEFORE type conversion, so tt.ptr types are still
+//   intact.
+//
+//   ScalarAtomicRMWCanonicalizer – normalise single-element tensor masks
+//   ScalarAtomicCASCanonicalizer – same for CAS
+//   AtomicMaxMinCanonicalizer    – insert type-promotion casts for MAX/MIN
+//
+// Converters (Phase-2):
+//   Registered in TritonToUnstructured (or UnstructuredToMemref).
+//   These run AFTER TritonTypeConverter has rewritten tt.ptr → memref.
+//
+//   AtomicRMWConverter – tt.atomic_rmw → load + arith-op + store
+//   AtomicCASConverter – tt.atomic_cas → load + cmpi/cmpf + scf.if + store
+//
+// Both converters emit *software* (non-hardware) atomic sequences that are
+// correct for single-core or UB-local execution. AND/OR/XOR are always
+// software; FADD/ADD/XCHG/MAX/MIN/UMAX/UMIN follow the same path.
+//
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+namespace mlir {
+namespace triton {
+namespace shared {
+
+//===----------------------------------------------------------------------===//
+// Phase-1 canonicalizers (run before type conversion, on tt.ptr IR)
+//===----------------------------------------------------------------------===//
+
+/// Normalise a scalar AtomicRMWOp whose mask is a rank-1 tensor<1xi1> into an
+/// op carrying a scalar i1 mask, so the Phase-2 converter only sees i1 masks.
+///
+///   tt.atomic_rmw ..., %ptr, %val, %mask_tensor   (mask : tensor<1xi1>)
+///     →
+///   %idx = arith.constant 0
+///   %m   = tensor.extract %mask_tensor[%idx]
+///   tt.atomic_rmw ..., %ptr, %val, %m             (mask : i1)
+class ScalarAtomicRMWCanonicalizer
+    : public OpRewritePattern<triton::AtomicRMWOp> {
+public:
+  using OpRewritePattern<triton::AtomicRMWOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(triton::AtomicRMWOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Same normalisation for AtomicCASOp.
+class ScalarAtomicCASCanonicalizer
+    : public OpRewritePattern<triton::AtomicCASOp> {
+public:
+  using OpRewritePattern<triton::AtomicCASOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(triton::AtomicCASOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Insert arith extension/truncation casts so that the `val` operand of a
+/// MAX/MIN AtomicRMWOp has the same type as the pointee type.
+/// Example: fmax on an f32 pointer with an f16 val → insert arith.extf
+/// before the op.
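+///
+/// Illustrative IR for that example (value names are for exposition only;
+/// note the pattern casts only the operand – the result type is left
+/// unchanged):
+///
+///   %r = tt.atomic_rmw max, ..., %ptr, %v, %m
+///      : (!tt.ptr<f32>, f16, i1) -> f16
+///     →
+///   %v32 = arith.extf %v : f16 to f32
+///   %r   = tt.atomic_rmw max, ..., %ptr, %v32, %m
+///        : (!tt.ptr<f32>, f32, i1) -> f16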
+class AtomicMaxMinCanonicalizer : public OpRewritePattern<triton::AtomicRMWOp> {
+public:
+  using OpRewritePattern<triton::AtomicRMWOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(triton::AtomicRMWOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
+//===----------------------------------------------------------------------===//
+// Phase-2 converters (run after TritonTypeConverter: tt.ptr → memref)
+//===----------------------------------------------------------------------===//
+
+/// Lower triton::AtomicRMWOp to a software read-modify-write sequence.
+///
+/// ── Scalar case (ptr already converted to memref) ────────────────────────
+///
+///   %c0  = arith.constant 0 : index
+///   %old = memref.load %ptr[%c0]
+///   %new = arith.{op} %old, %val
+///   memref.store %new, %ptr[%c0]   ← may be inside scf.if when masked
+///   // replace op result with %old
+///
+/// ── Tensor case (ptr already converted to memref) ────────────────────────
+///
+///   %result_buf = memref.alloc() : memref<...>
+///   linalg.generic ins(%ptr_memref, %val_memref [, %mask_memref])
+///                  outs(%ptr_memref, %result_buf)
+///   {
+///   ^bb0(%ptr_elem, %val_elem, [%mask_elem,] %ptr_out, %res_init):
+///     %new = arith.{op} %ptr_elem, %val_elem
+///     // if mask: yield (%new, %ptr_elem) else yield (%ptr_elem, %ptr_elem)
+///     linalg.yield %selected_new, %ptr_elem
+///   }
+///   %result = bufferization.to_tensor %result_buf
+///
+class AtomicRMWConverter : public OpRewritePattern<triton::AtomicRMWOp> {
+public:
+  using OpRewritePattern<triton::AtomicRMWOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(triton::AtomicRMWOp op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// Build the scalar arith binary op for the given rmwOp.
+  /// lhs = old value loaded from memory; rhs = atomic operand.
+  Value buildBinaryOp(OpBuilder &b, Location loc, triton::RMWOp rmwOp,
+                      Type elemTy, Value lhs, Value rhs) const;
+
+  static bool isSplatTrue(Value mask);
+  static bool isSplatFalse(Value mask);
+};
+
+/// Lower triton::AtomicCASOp to a software compare-and-swap.
+///
+/// ── Scalar case ──────────────────────────────────────────────────────────
+///
+///   %c0  = arith.constant 0 : index
+///   %old = memref.load %ptr[%c0]
+///   %eq  = arith.cmpi eq, %old, %cmp   (arith.cmpf oeq for float)
+///   scf.if %eq {
+///     memref.store %val, %ptr[%c0]
+///   }
+///   // replace op result with %old
+///
+class AtomicCASConverter : public OpRewritePattern<triton::AtomicCASOp> {
+public:
+  using OpRewritePattern<triton::AtomicCASOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(triton::AtomicCASOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
+//===----------------------------------------------------------------------===//
+// Registration helpers
+//===----------------------------------------------------------------------===//
+
+/// Populate Phase-1 (pre-conversion) canonicalization patterns.
+/// Intended to be called from populateCanonicalizationPatterns() in the
+/// TritonToStructured pass or a dedicated pre-pass.
+inline void
+populateAtomicCanonicalizationPatterns(RewritePatternSet &patterns) {
+  patterns.add<ScalarAtomicRMWCanonicalizer, ScalarAtomicCASCanonicalizer,
+               AtomicMaxMinCanonicalizer>(patterns.getContext());
+}
+
+/// Populate Phase-2 (post-conversion) conversion patterns.
+/// Intended to be called from populateConversionPatterns() in the
+/// TritonToUnstructured or UnstructuredToMemref pass.
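+///
+/// Typical call site (a sketch, mirroring Step 3 of the
+/// TritonToUnstructuredPass changes in this patch):
+///
+///   RewritePatternSet patterns(&getContext());
+///   mlir::triton::shared::populateAtomicConversionPatterns(patterns);
+///   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+///     signalPassFailure();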
+inline void populateAtomicConversionPatterns(RewritePatternSet &patterns) {
+  patterns.add<AtomicRMWConverter, AtomicCASConverter>(patterns.getContext());
+}
+
+} // namespace shared
+} // namespace triton
+} // namespace mlir
diff --git a/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp b/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp
index 4c8ce0c..11773d0 100644
--- a/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp
+++ b/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp
@@ -18,6 +18,7 @@
 #include "triton-shared/Analysis/OpFoldResultUtils.h"
 #include "triton-shared/AnalysisStructured/PtrAnalysis.h"
 #include "triton-shared/Conversion/TritonToStructured/TritonToStructured.h"
+#include "triton-shared/Conversion/TritonToUnstructured/AtomicOpsConverter.h"
 #include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
@@ -331,6 +332,11 @@ class TritonToStructuredPass
         op.emitWarning("Rewriting GetStructuredStateOp failed.");
       }
     });
+    RewritePatternSet canonPatterns(&getContext());
+    mlir::triton::shared::populateAtomicCanonicalizationPatterns(canonPatterns);
+    if (failed(applyPatternsGreedily(moduleOp, std::move(canonPatterns)))) {
+      moduleOp->emitWarning("AtomicOps canonicalization failed");
+    }
   }
 };
 } // namespace
diff --git a/lib/Conversion/TritonToUnstructured/AtomicOpsConverter.cpp b/lib/Conversion/TritonToUnstructured/AtomicOpsConverter.cpp
new file mode 100644
index 0000000..f1b8b26
--- /dev/null
+++ b/lib/Conversion/TritonToUnstructured/AtomicOpsConverter.cpp
@@ -0,0 +1,460 @@
+//===- AtomicOpsConverter.cpp ---------------------------------------------===//
+//
+// See AtomicOpsConverter.h for design notes and pipeline placement.
+//
+//===----------------------------------------------------------------------===//
+
+#include "triton-shared/Conversion/TritonToUnstructured/AtomicOpsConverter.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "atomic-ops-converter"
+
+using namespace mlir;
+using namespace mlir::triton::shared;
+
+//===----------------------------------------------------------------------===//
+// Internal helpers
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Return true iff `v` is an arith.constant holding an all-`expected` boolean
+/// splat (or a plain scalar i1 constant).
+static bool isBoolSplat(Value v, bool expected) {
+  if (!v)
+    return false;
+  auto *def = v.getDefiningOp();
+  if (!def)
+    return false;
+  auto constOp = dyn_cast<arith::ConstantOp>(def);
+  if (!constOp)
+    return false;
+  // Scalar i1 constants (arith.constant true/false) carry a BoolAttr rather
+  // than a DenseElementsAttr; handle them first so scalar masks fold too.
+  if (auto boolAttr = dyn_cast<BoolAttr>(constOp.getValue()))
+    return boolAttr.getValue() == expected;
+  auto dense = dyn_cast<DenseElementsAttr>(constOp.getValue());
+  if (!dense || !dense.isSplat())
+    return false;
+  // Accept both i1 tensors and 0-d i1.
+  auto iTy = dyn_cast<IntegerType>(dense.getType().getElementType());
+  if (!iTy || iTy.getWidth() != 1)
+    return false;
+  return dense.getSplatValue<bool>() == expected;
+}
+
+/// Extract the pointee type from a tt.ptr or tensor<!tt.ptr>.
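+///
+///   !tt.ptr<f32>            → f32
+///   tensor<16x!tt.ptr<f32>> → f32
+///   anything else           → null Type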
+static Type getPointeeType(Type ptrOrTensorOfPtr) {
+  if (auto ptrTy = dyn_cast<triton::PointerType>(ptrOrTensorOfPtr))
+    return ptrTy.getPointeeType();
+  if (auto tensorTy = dyn_cast<RankedTensorType>(ptrOrTensorOfPtr))
+    if (auto ptrElem =
+            dyn_cast<triton::PointerType>(tensorTy.getElementType()))
+      return ptrElem.getPointeeType();
+  return {};
+}
+
+/// Ensure `val` is a memref with `shape` and `elemTy`.
+/// If it is still a ranked tensor, emit a bufferization.to_memref.
+static Value ensureMemRef(OpBuilder &b, Location loc, Value val,
+                          ArrayRef<int64_t> shape, Type elemTy) {
+  if (isa<MemRefType>(val.getType()))
+    return val;
+  auto mrt = MemRefType::get(shape, elemTy);
+  return b.create<bufferization::ToMemrefOp>(loc, mrt, val);
+}
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// AtomicRMWConverter – helpers
+//===----------------------------------------------------------------------===//
+
+bool AtomicRMWConverter::isSplatTrue(Value mask) {
+  return isBoolSplat(mask, true);
+}
+bool AtomicRMWConverter::isSplatFalse(Value mask) {
+  return isBoolSplat(mask, false);
+}
+
+Value AtomicRMWConverter::buildBinaryOp(OpBuilder &b, Location loc,
+                                        triton::RMWOp kind, Type elemTy,
+                                        Value lhs, Value rhs) const {
+  switch (kind) {
+  case triton::RMWOp::FADD:
+    return b.create<arith::AddFOp>(loc, lhs, rhs);
+  case triton::RMWOp::ADD:
+    return b.create<arith::AddIOp>(loc, lhs, rhs);
+  case triton::RMWOp::XOR:
+    return b.create<arith::XOrIOp>(loc, lhs, rhs);
+  case triton::RMWOp::OR:
+    return b.create<arith::OrIOp>(loc, lhs, rhs);
+  case triton::RMWOp::AND:
+    return b.create<arith::AndIOp>(loc, lhs, rhs);
+  case triton::RMWOp::MAX:
+    return isa<FloatType>(elemTy)
+               ? b.create<arith::MaximumFOp>(loc, lhs, rhs).getResult()
+               : b.create<arith::MaxSIOp>(loc, lhs, rhs).getResult();
+  case triton::RMWOp::MIN:
+    return isa<FloatType>(elemTy)
+               ? b.create<arith::MinimumFOp>(loc, lhs, rhs).getResult()
+               : b.create<arith::MinSIOp>(loc, lhs, rhs).getResult();
+  case triton::RMWOp::UMAX:
+    return b.create<arith::MaxUIOp>(loc, lhs, rhs);
+  case triton::RMWOp::UMIN:
+    return b.create<arith::MinUIOp>(loc, lhs, rhs);
+  case triton::RMWOp::XCHG:
+    return rhs; // exchange: new value is simply rhs
+  default:
+    break;
+  }
+  llvm_unreachable("unhandled RMWOp in buildBinaryOp");
+}
+
+//===----------------------------------------------------------------------===//
+// AtomicRMWConverter::matchAndRewrite
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+AtomicRMWConverter::matchAndRewrite(triton::AtomicRMWOp op,
+                                    PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  // ── Take ptr straight from the original op and do the type conversion
+  // ourselves. We do not use adaptor.getPtr(): the framework's TypeConverter
+  // may not have converted function argument types yet (tt.func signatures
+  // are unchanged), so we cast !tt.ptr → memref manually.
+  Value rawPtr = op.getPtr();
+  Value val = op.getVal();
+  Value mask = op.getMask();
+  auto rmwKind = op.getAtomicRmwOp();
+  Type resultTy = op.getResult().getType();
+  Value ptr;
+  Type elemTy;
+
+  // ── Fast path: all-false mask → the op is a no-op. Its result is never
+  // meaningful, so replace it with a zero constant of the result type.
+  if (mask && isSplatFalse(mask)) {
+    Value zero = rewriter.create<arith::ConstantOp>(
+        loc, resultTy, rewriter.getZeroAttr(resultTy));
+    rewriter.replaceOp(op, zero);
+    return success();
+  }
+
+  bool isTensor = isa<RankedTensorType>(resultTy);
+
+  // =========================================================================
+  // (A) Scalar path
+  // =========================================================================
+  if (!isTensor) {
+    if (auto ptrTy = dyn_cast<triton::PointerType>(rawPtr.getType())) {
+      elemTy = ptrTy.getPointeeType();
+      auto memrefTy = MemRefType::get({ShapedType::kDynamic}, elemTy);
+      ptr = rewriter
+                .create<UnrealizedConversionCastOp>(loc, TypeRange{memrefTy},
+                                                    ValueRange{rawPtr})
+                .getResult(0);
+    } else if (auto memrefTy = dyn_cast<MemRefType>(rawPtr.getType())) {
+      ptr = rawPtr;
+      elemTy = memrefTy.getElementType();
+    } else {
+      return rewriter.notifyMatchFailure(op, "scalar: unexpected ptr type");
+    }
+
+    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value old = rewriter.create<memref::LoadOp>(loc, ptr, ValueRange{c0});
+
+    bool unconditional = (!mask || isSplatTrue(mask));
+
+    auto emitStore = [&](OpBuilder &b, Location l) {
+      Value newVal = buildBinaryOp(b, l, rmwKind, elemTy, old, val);
+      b.create<memref::StoreOp>(l, newVal, ptr, ValueRange{c0});
+    };
+
+    if (unconditional) {
+      emitStore(rewriter, loc);
+    } else {
+      Value maskScalar = mask;
+      // If the mask itself arrived as a memref (converted from tt.ptr),
+      // load the scalar i1 out of it first.
+      if (isa<MemRefType>(mask.getType())) {
+        Value mc0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+        maskScalar =
+            rewriter.create<memref::LoadOp>(loc, mask, ValueRange{mc0});
+      }
+      rewriter.create<scf::IfOp>(loc, maskScalar,
+                                 [&](OpBuilder &b, Location l) {
+                                   emitStore(b, l);
+                                   b.create<scf::YieldOp>(l);
+                                 });
+    }
+
+    rewriter.replaceOp(op, old);
+    return success();
+  }
+
+  // =========================================================================
+  // (B) Tensor path
+  // =========================================================================
+  auto tensorResultTy = cast<RankedTensorType>(resultTy);
+  ArrayRef<int64_t> shape = tensorResultTy.getShape();
+  unsigned rank = shape.size();
+
+  // A tensor of pointers (tensor<!tt.ptr>) means every element is its own
+  // ptr, so we process element-wise inside a linalg.generic; the backing
+  // ptr memref is obtained through a cast.
+  Value ptrMemRef;
+  if (auto tensorPtrTy = dyn_cast<RankedTensorType>(rawPtr.getType())) {
+    auto ptrElemTy =
+        dyn_cast<triton::PointerType>(tensorPtrTy.getElementType());
+    if (!ptrElemTy)
+      return rewriter.notifyMatchFailure(op,
+                                         "tensor ptr element is not tt.ptr");
+    elemTy = ptrElemTy.getPointeeType();
+    auto memrefTy = MemRefType::get(shape, elemTy);
+    ptrMemRef = rewriter
+                    .create<UnrealizedConversionCastOp>(
+                        loc, TypeRange{memrefTy}, ValueRange{rawPtr})
+                    .getResult(0);
+  } else if (auto memrefTy = dyn_cast<MemRefType>(rawPtr.getType())) {
+    ptrMemRef = rawPtr;
+    elemTy = memrefTy.getElementType();
+  } else {
+    return rewriter.notifyMatchFailure(op, "unexpected tensor ptr type");
+  }
+
+  // val memref
+  Value valMR;
+  if (isa<RankedTensorType>(val.getType())) {
+    auto valMemRefTy = MemRefType::get(shape, elemTy);
+    valMR = rewriter
+                .create<UnrealizedConversionCastOp>(
+                    loc, TypeRange{valMemRefTy}, ValueRange{val})
+                .getResult(0);
+  } else {
+    valMR = val;
+  }
+
+  // Result buffer (holds the old values).
+  Value resultBuf =
+      rewriter.create<memref::AllocOp>(loc, MemRefType::get(shape, elemTy));
+
+  auto idMap = rewriter.getMultiDimIdentityMap(rank);
+  SmallVector<utils::IteratorType> iters(rank, utils::IteratorType::parallel);
+
+  bool needMask = mask && !isSplatTrue(mask);
+  Value maskMR;
+  if (needMask) {
+    auto maskMemRefTy = MemRefType::get(shape, rewriter.getI1Type());
+    if (isa<RankedTensorType>(mask.getType())) {
+      maskMR = rewriter
+                   .create<UnrealizedConversionCastOp>(
+                       loc, TypeRange{maskMemRefTy}, ValueRange{mask})
+                   .getResult(0);
+    } else {
+      maskMR = mask;
+    }
+  }
+
+  SmallVector<Value> inputs = {ptrMemRef, valMR};
+  SmallVector<Value> outputs = {ptrMemRef, resultBuf};
+  SmallVector<AffineMap> maps = {idMap, idMap, idMap, idMap};
+  if (needMask) {
+    inputs.push_back(maskMR);
+    maps.push_back(idMap);
+  }
+
+  auto genericOp = rewriter.create<linalg::GenericOp>(
+      loc, TypeRange{}, inputs, outputs, maps, iters,
+      [&](OpBuilder &b, Location l, ValueRange args) {
+        Value ptrElem = args[0];
+        Value valElem = args[1];
+        Value computed = buildBinaryOp(b, l, rmwKind, elemTy, ptrElem, valElem);
+        Value writeBack = computed;
+        if (needMask) {
+          Value maskElem = args[2];
+          writeBack = b.create<arith::SelectOp>(l, maskElem, computed, ptrElem);
+        }
+        b.create<linalg::YieldOp>(l, ValueRange{writeBack, ptrElem});
+      });
+
+  MLIRContext *context = rewriter.getContext();
+  const StringRef genericAtomicRMW = "GenericAtomicRMW";
+  const StringRef memSemantic = "MemSemantic";
+  const StringRef memSyncScope = "MemSyncScope";
+
+  auto rmwKindStr = [](triton::RMWOp kind) -> StringRef {
+    switch (kind) {
+    case triton::RMWOp::FADD: return "fadd";
+    case triton::RMWOp::ADD:  return "add";
+    case triton::RMWOp::XCHG: return "xchg";
+    case triton::RMWOp::AND:  return "and";
+    case triton::RMWOp::OR:   return "or";
+    case triton::RMWOp::XOR:  return "xor";
+    case triton::RMWOp::MAX:  return "max";
+    case triton::RMWOp::MIN:  return "min";
+    case triton::RMWOp::UMAX: return "umax";
+    case triton::RMWOp::UMIN: return "umin";
+    default:                  return "unknown";
+    }
+  };
+
+  genericOp->setAttr(genericAtomicRMW,
+                     mlir::StringAttr::get(context, rmwKindStr(rmwKind)));
+  genericOp->setAttr(memSemantic,
+                     rewriter.getStringAttr(stringifyEnum(op.getSem())));
+  genericOp->setAttr(memSyncScope,
+                     rewriter.getStringAttr(stringifyEnum(op.getScope())));
+  genericOp->setAttr("Software", rewriter.getUnitAttr());
+
+  Value resultTensor = rewriter.create<bufferization::ToTensorOp>(
+      loc, tensorResultTy, resultBuf);
+  rewriter.replaceOp(op, resultTensor);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// AtomicCASConverter::matchAndRewrite
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+AtomicCASConverter::matchAndRewrite(triton::AtomicCASOp op,
+                                    PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  Value ptr = op.getPtr();
+  Value cmp = op.getCmp();
+  Value val = op.getVal();
+
+  auto ptrMRT = dyn_cast<MemRefType>(ptr.getType());
+  if (!ptrMRT)
+    return op.emitOpError(
+        "[AtomicCASConverter] ptr must be MemRefType after type conversion");
+
+  Type elemTy = ptrMRT.getElementType();
+  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+
+  // Load old value.
+  Value old = rewriter.create<memref::LoadOp>(loc, ptr, ValueRange{c0});
+
+  // Compare old == cmp.
+  Value eq;
+  if (isa<FloatType>(elemTy)) {
+    eq = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ, old,
+                                        cmp);
+  } else {
+    eq =
+        rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, old, cmp);
+  }
+
+  // Conditionally store new value.
+  rewriter.create<scf::IfOp>(loc, eq, [&](OpBuilder &b, Location l) {
+    b.create<memref::StoreOp>(l, val, ptr, ValueRange{c0});
+    b.create<scf::YieldOp>(l);
+  });
+
+  // Result is the old value before the potential swap.
+  rewriter.replaceOp(op, old);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Phase-1 canonicalizers
+//===----------------------------------------------------------------------===//
+
+// ── ScalarAtomicRMWCanonicalizer ─────────────────────────────────────────
+
+LogicalResult
+ScalarAtomicRMWCanonicalizer::matchAndRewrite(triton::AtomicRMWOp op,
+                                              PatternRewriter &rewriter) const {
+  // Only touch scalar (non-tensor) ptr ops.
+  if (isa<RankedTensorType>(op.getPtr().getType()))
+    return failure();
+
+  Value mask = op.getMask();
+  if (!mask)
+    return failure();
+
+  // Only rewrite if mask is a rank-1 tensor<1xi1>.
+  auto maskTy = dyn_cast<RankedTensorType>(mask.getType());
+  if (!maskTy || maskTy.getRank() != 1 || maskTy.getDimSize(0) != 1)
+    return failure();
+
+  Location loc = op.getLoc();
+  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  Value scalarMask =
+      rewriter.create<tensor::ExtractOp>(loc, mask, ValueRange{c0});
+
+  rewriter.replaceOpWithNewOp<triton::AtomicRMWOp>(
+      op, op.getResult().getType(), op.getAtomicRmwOpAttr(), op.getPtr(),
+      op.getVal(), scalarMask, op.getSemAttr(), op.getScopeAttr());
+  return success();
+}
+
+// ── ScalarAtomicCASCanonicalizer ─────────────────────────────────────────
+
+LogicalResult
+ScalarAtomicCASCanonicalizer::matchAndRewrite(triton::AtomicCASOp op,
+                                              PatternRewriter &rewriter) const {
+  // AtomicCASOp does not have a mask operand in current Triton; nothing to do.
+  // This canonicalizer is kept as a placeholder for future extension.
+  return failure();
+}
+
+// ── AtomicMaxMinCanonicalizer ────────────────────────────────────────────
+
+LogicalResult
+AtomicMaxMinCanonicalizer::matchAndRewrite(triton::AtomicRMWOp op,
+                                           PatternRewriter &rewriter) const {
+  auto kind = op.getAtomicRmwOp();
+  bool isMaxMin = (kind == triton::RMWOp::MAX || kind == triton::RMWOp::MIN ||
+                   kind == triton::RMWOp::UMAX || kind == triton::RMWOp::UMIN);
+  if (!isMaxMin)
+    return failure();
+
+  // Determine the pointee type from the (still-triton) ptr operand.
+  Type pointeeTy = getPointeeType(op.getPtr().getType());
+  if (!pointeeTy)
+    return failure();
+
+  Value val = op.getVal();
+  Type valTy = getElementTypeOrSelf(val.getType());
+
+  if (valTy == pointeeTy)
+    return failure(); // already matching – nothing to do
+
+  Location loc = op.getLoc();
+
+  // Compute the destination type (preserve the tensor wrapper if present).
+  Type dstTy;
+  if (auto valTensorTy = dyn_cast<RankedTensorType>(val.getType())) {
+    dstTy = RankedTensorType::get(valTensorTy.getShape(), pointeeTy);
+  } else {
+    dstTy = pointeeTy;
+  }
+
+  Value casted;
+  if (isa<FloatType>(pointeeTy) && isa<FloatType>(valTy)) {
+    unsigned dstW = cast<FloatType>(pointeeTy).getWidth();
+    unsigned srcW = cast<FloatType>(valTy).getWidth();
+    if (dstW > srcW)
+      casted = rewriter.create<arith::ExtFOp>(loc, dstTy, val);
+    else
+      casted = rewriter.create<arith::TruncFOp>(loc, dstTy, val);
+  } else if (isa<IntegerType>(pointeeTy) && isa<IntegerType>(valTy)) {
+    unsigned dstW = cast<IntegerType>(pointeeTy).getWidth();
+    unsigned srcW = cast<IntegerType>(valTy).getWidth();
+    if (dstW > srcW)
+      casted = rewriter.create<arith::ExtSIOp>(loc, dstTy, val);
+    else
+      casted = rewriter.create<arith::TruncIOp>(loc, dstTy, val);
+  } else {
+    // Mixed float/int – unsupported here.
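+    // (e.g. an i32 value on an f32 pointee would need a value conversion
+    // with rounding semantics, not a bit-width cast, so we conservatively
+    // bail out and leave the op untouched.)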
+    return failure();
+  }
+
+  rewriter.replaceOpWithNewOp<triton::AtomicRMWOp>(
+      op, op.getResult().getType(), op.getAtomicRmwOpAttr(), op.getPtr(),
+      casted, op.getMask(), op.getSemAttr(), op.getScopeAttr());
+  return success();
+}
diff --git a/lib/Conversion/TritonToUnstructured/CMakeLists.txt b/lib/Conversion/TritonToUnstructured/CMakeLists.txt
index d665048..2399d2b 100644
--- a/lib/Conversion/TritonToUnstructured/CMakeLists.txt
+++ b/lib/Conversion/TritonToUnstructured/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_triton_library(TritonToUnstructured
   TritonToUnstructuredPass.cpp
+  AtomicOpsConverter.cpp
 
   DEPENDS
   TritonStructuredTableGen
@@ -20,4 +21,8 @@ add_triton_library(TritonToUnstructured
   TritonSharedAnalysisStructured
   TritonStructuredIR
   TritonSharedUtils
+  MLIRMemRefDialect
+  MLIRLinalgDialect
+  MLIRBufferizationDialect
+  MLIRLinalgTransforms
 )
diff --git a/lib/Conversion/TritonToUnstructured/TritonToUnstructuredPass.cpp b/lib/Conversion/TritonToUnstructured/TritonToUnstructuredPass.cpp
index 0a9ce4d..9343495 100644
--- a/lib/Conversion/TritonToUnstructured/TritonToUnstructuredPass.cpp
+++ b/lib/Conversion/TritonToUnstructured/TritonToUnstructuredPass.cpp
@@ -136,6 +136,9 @@
 // approach, we will only sign-extend where necessary.
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -151,6 +154,7 @@
 #include "mlir/Transforms/Passes.h"
 #include "triton-shared/Analysis/OpFoldResultUtils.h"
 #include "triton-shared/AnalysisStructured/PtrAnalysis.h"
+#include "triton-shared/Conversion/TritonToUnstructured/AtomicOpsConverter.h"
 #include "triton-shared/Conversion/TritonToUnstructured/TritonToUnstructured.h"
 #include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"
 #include "triton-shared/Utils/Utils.h"
@@ -224,7 +228,8 @@ class TritonToUnstructuredPass
     registry
         .insert<triton::TritonDialect,
-                tts::TritonStructuredDialect>();
+                tts::TritonStructuredDialect, memref::MemRefDialect,
+                linalg::LinalgDialect, bufferization::BufferizationDialect>();
   }
 
   struct PtrOffset {
@@ -383,7 +388,8 @@ class TritonToUnstructuredPass
               return success();
             })
-            .Case<tts::MakeTensorPtrOp>([&](Operation *op) {
+            .Case<tts::MakeTensorPtrOp, triton::AtomicRMWOp,
+                  triton::AtomicCASOp>([&](Operation *op) {
               // Special case:
               // We do not want to create "unstructured tensor pointer" into
               // tts.make_tptr if the base pointer is directly from the
@@ -527,6 +533,32 @@ class TritonToUnstructuredPass
               store->erase();
               return success();
             })
+            .Case([&](triton::AtomicRMWOp atomicOp) {
+              // The ptr operand of this AtomicRMWOp has already been analysed
+              // by processUnstructuredPtrs; we only need to rebase it onto
+              // the base ptr (the offsets are accumulated into `offset`) and
+              // let the Phase-2 AtomicRMWConverter handle the rest.
+              //
+              // Strategy: the same base/offset split used for tts::GatherOp –
+              // replace the atomic's ptr operand with the base ptr recorded
+              // in offsetMap, and insert a tt.addptr in front that adds the
+              // offset back, so the converter sees the correct ptr.
+              auto offsetInfo = offsetMap.at(atomicOp.getPtr());
+              // Rebuild an addptr from the accumulated offset and use it as
+              // the atomic's new ptr. (The converter will later turn this
+              // addptr's result into a memref.)
+              auto newPtr = b.create<triton::AddPtrOp>(
+                  loc, atomicOp.getPtr().getType(), offsetInfo.ptr,
+                  offsetInfo.offset);
+              atomicOp.getPtrMutable().assign(newPtr);
+              return success();
+            })
+            .Case([&](triton::AtomicCASOp casOp) {
+              auto offsetInfo = offsetMap.at(casOp.getPtr());
+              auto newPtr = b.create<triton::AddPtrOp>(
+                  loc, casOp.getPtr().getType(), offsetInfo.ptr,
+                  offsetInfo.offset);
+              casOp.getPtrMutable().assign(newPtr);
+              return success();
+            })
             .Case<tts::MakeTensorPtrOp>([&](auto makeTensorPtr) {
               // For block pointers, the base could come from a sequence of
@@ -606,6 +638,30 @@ class TritonToUnstructuredPass
     if (failed(runPipeline(pm, getOperation()))) {
       signalPassFailure();
     }
+
+    // ── Step 3: atomic op conversion.
+    // Use a greedy pattern rewrite rather than applyPartialConversion, so
+    // the framework does not run legality checks against the (unchanged)
+    // tt.func signatures and fail.
+    auto moduleOp = getOperation();
+    RewritePatternSet patterns(&getContext());
+    mlir::triton::shared::populateAtomicConversionPatterns(patterns);
+
+    if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) {
+      moduleOp->emitWarning("AtomicOps rewrite did not converge");
+    }
+
+    // Check for residual atomic ops and fail the pass if any survived.
+    bool hasResidual = false;
+    moduleOp.walk([&](triton::AtomicRMWOp op) {
+      op->emitError("tt.atomic_rmw was not lowered");
+      hasResidual = true;
+    });
+    moduleOp.walk([&](triton::AtomicCASOp op) {
+      op->emitError("tt.atomic_cas was not lowered");
+      hasResidual = true;
+    });
+    if (hasResidual)
+      signalPassFailure();
   }
 };
 } // namespace
diff --git a/test/Conversion/TritonToUnstructured/test_atomic_add.mlir b/test/Conversion/TritonToUnstructured/test_atomic_add.mlir
new file mode 100644
index 0000000..ff99600
--- /dev/null
+++ b/test/Conversion/TritonToUnstructured/test_atomic_add.mlir
@@ -0,0 +1,82 @@
+// RUN: triton-shared-opt \
+// RUN:   --triton-to-structured \
+// RUN:   --triton-to-unstructured \
+// RUN:   %s | FileCheck %s
+
+// ============================================================
+// (1) Scalar FADD – all-true mask → runs unconditionally, no scf.if
+// ============================================================
+// CHECK-LABEL: func @scalar_fadd_true_mask
+// CHECK-NOT:   tt.atomic_rmw
+// CHECK:       memref.load
+// CHECK:       arith.addf
+// CHECK:       memref.store
+// CHECK-NOT:   scf.if
+tt.func @scalar_fadd_true_mask(%ptr: !tt.ptr<f32>, %val: f32) -> f32 {
+  %mask = arith.constant true
+  %old = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %val, %mask
+       : (!tt.ptr<f32>, f32, i1) -> f32
+  tt.return %old : f32
+}
+
+// ============================================================
+// (2) Scalar FADD – all-false mask → op eliminated, no memory access
+// ============================================================
+// CHECK-LABEL: func @scalar_fadd_false_mask
+// CHECK-NOT:   tt.atomic_rmw
+// CHECK-NOT:   memref.load
+// CHECK-NOT:   memref.store
+tt.func @scalar_fadd_false_mask(%ptr: !tt.ptr<f32>, %val: f32) -> f32 {
+  %mask = arith.constant false
+  %old = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %val, %mask
+       : (!tt.ptr<f32>, f32, i1) -> f32
+  tt.return %old : f32
+}
+
+// ============================================================
+// (3) Scalar FADD – runtime mask → store guarded by scf.if
+// ============================================================
+// CHECK-LABEL: func @scalar_fadd_runtime_mask
+// CHECK-NOT:   tt.atomic_rmw
+// CHECK:       memref.load
+// CHECK:       scf.if
+// CHECK:       arith.addf
+// CHECK:       memref.store
+tt.func @scalar_fadd_runtime_mask(%ptr: !tt.ptr<f32>, %val: f32, %mask: i1) -> f32 {
+  %old = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %val, %mask
+       : (!tt.ptr<f32>, f32, i1) -> f32
+  tt.return %old : f32
+}
+
+// ============================================================
+// (4) Tensor FADD – all-true mask → linalg.generic
+// ============================================================
+// CHECK-LABEL: func @tensor_fadd
+// CHECK-NOT:   tt.atomic_rmw
+// CHECK:       linalg.generic
+// CHECK:       arith.addf
+// CHECK:       linalg.yield
+tt.func @tensor_fadd(%ptr: tensor<16x!tt.ptr<f32>>,
+                     %val: tensor<16xf32>) -> tensor<16xf32> {
+  %mask = arith.constant dense<true> : tensor<16xi1>
+  %old = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %val, %mask
+       : (tensor<16x!tt.ptr<f32>>, tensor<16xf32>, tensor<16xi1>)
+       -> tensor<16xf32>
+  tt.return %old : tensor<16xf32>
+}
+
+// ============================================================
+// (5) Tensor FADD – partial mask → linalg.generic + arith.select
+// ============================================================
+// CHECK-LABEL: func @tensor_fadd_partial_mask
+// CHECK-NOT:   tt.atomic_rmw
+// CHECK:       linalg.generic
+// CHECK:       arith.addf
+// CHECK:       arith.select
+// CHECK:       linalg.yield
+tt.func @tensor_fadd_partial_mask(%ptr: tensor<16x!tt.ptr<f32>>,
+                                  %val: tensor<16xf32>,
+                                  %mask: tensor<16xi1>) -> tensor<16xf32> {
+  %old = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %val, %mask
+       : (tensor<16x!tt.ptr<f32>>, tensor<16xf32>, tensor<16xi1>) -> tensor<16xf32>
+  tt.return %old : tensor<16xf32>
+}
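+
+// ============================================================
+// For reference: case (3) is expected to lower to IR shaped roughly like
+// the sketch below (illustrative only – names and the exact cast op are
+// not CHECK-verified above):
+//
+//   %m   = builtin.unrealized_conversion_cast %ptr : !tt.ptr<f32> to memref<?xf32>
+//   %c0  = arith.constant 0 : index
+//   %old = memref.load %m[%c0] : memref<?xf32>
+//   scf.if %mask {
+//     %new = arith.addf %old, %val : f32
+//     memref.store %new, %m[%c0] : memref<?xf32>
+//   }
+// ============================================================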