Skip to content

Fix broken transpose match in OptimizeDotOperands #6819

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 78 additions & 51 deletions lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,67 +18,94 @@ namespace mlir::triton::gpu {

namespace {
// Given
// dot(convert(trans(src)) #dot_operand) ->
// dot(convert(trans(src)) #dot_operand) | dot(trans(src) #dot_operand) ->
// dot(convert(local_load(trans(alloc(src)))))
// change the encoding of the inner convert to a special, swizzled shared
// encoding.
class SwizzleShmemConvert : public OpRewritePattern<ConvertLayoutOp> {
class SwizzleShmemConvert : public OpRewritePattern<DotOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(ConvertLayoutOp cvtOp,
LogicalResult matchAndRewrite(DotOp dotOp,
PatternRewriter &rewriter) const override {
if (!cvtOp->hasOneUse() ||
!isa<triton::DotOp>(cvtOp->use_begin()->getOwner()))
return failure();
// Match outerCvt(trans(innerCvt(x))).
auto trans = cvtOp.getSrc().getDefiningOp<TransOp>();
if (!trans || trans.getOrder() != ArrayRef<int32_t>{1, 0})
return failure();
bool rewrite = false;

RankedTensorType srcTy = trans.getSrc().getType();
for (auto value : dotOp.getOperands()) {
if (!value.getDefiningOp()) {
continue;
}

if (auto srcCvt = trans.getSrc().getDefiningOp<ConvertLayoutOp>()) {
srcTy = srcCvt.getSrc().getType();
}
RankedTensorType sharedLoadTy = cvtOp.getType();
auto cvtEncoding =
dyn_cast<DotOperandEncodingAttr>(sharedLoadTy.getEncoding());
if (!cvtEncoding)
return failure();
RankedTensorType sharedLoadTy;
TransOp trans;

// Set needTrans to true here. newInnerCvtEnc is computed based on
// argEncoding, which is the encoding before the transpose. Without needTrans
// we would compute vec and maxPhase based on incorrect m, n, and k sizes of
// the mma. The type inference of MemDescTransOp simply swaps the order but
// doesn't fix the vec and maxPhase for the YType; hence it would cause
// incorrect swizzling code.
auto ctx = getContext();
auto oldCTALayout = triton::gpu::getCTALayout(srcTy.getEncoding());
auto newCTALayout = permuteCTALayout(ctx, oldCTALayout, trans.getOrder());
assert(succeeded(newCTALayout));
auto newInnerCvtEnc =
SwizzledSharedEncodingAttr::get(ctx, cvtEncoding, srcTy.getShape(),
/*order=*/getOrderForMemory(srcTy),
*newCTALayout, srcTy.getElementType(),
/*needTrans=*/true);
if (newInnerCvtEnc == cvtEncoding)
return failure();
rewriter.setInsertionPoint(trans);
auto sharedMemorySpace = SharedMemorySpaceAttr::get(getContext());
auto alloc = rewriter.create<LocalAllocOp>(
trans.getLoc(),
MemDescType::get(srcTy.getShape(), srcTy.getElementType(),
newInnerCvtEnc, sharedMemorySpace),
trans.getSrc());
auto newTrans = rewriter.create<MemDescTransOp>(trans.getLoc(), alloc,
ArrayRef<int32_t>({1, 0}));
auto localLoadOp =
rewriter.create<LocalLoadOp>(trans.getLoc(), sharedLoadTy, newTrans);
rewriter.modifyOpInPlace(cvtOp, [&]() {
cvtOp.getSrcMutable().assign(localLoadOp.getResult());
});
return success();
if ((trans = dyn_cast<TransOp>(value.getDefiningOp()))) {
sharedLoadTy = trans.getType();
} else if (auto cvt = dyn_cast<ConvertLayoutOp>(value.getDefiningOp())) {
if (!(trans = cvt.getSrc().getDefiningOp<TransOp>()) ||
trans.getOrder() != ArrayRef<int32_t>{1, 0}) {
continue;
}

sharedLoadTy = cvt.getType();
} else {
continue;
}

auto feedsDot = llvm::all_of(value.getUsers(),
[](auto user) { return isa<DotOp>(user); });
if (!feedsDot) {
continue;
}

if (!value.getDefiningOp()->hasOneUse() ||
!isa<triton::DotOp>(value.getDefiningOp()->use_begin()->getOwner()))
continue;

// Match trans(innerCvt(x)).
RankedTensorType srcTy = trans.getSrc().getType();

if (auto srcCvt = trans.getSrc().getDefiningOp<ConvertLayoutOp>()) {
srcTy = srcCvt.getSrc().getType();
}
auto opEncoding =
dyn_cast<DotOperandEncodingAttr>(sharedLoadTy.getEncoding());
if (!opEncoding)
continue;

      // Set needTrans to true here. newInnerCvtEnc is computed based on
      // argEncoding, which is the encoding before the transpose. Without
      // needTrans we would compute vec and maxPhase based on incorrect m, n,
      // and k sizes of the mma. The type inference of MemDescTransOp simply
      // swaps the order but doesn't fix the vec and maxPhase for the YType;
      // hence it would cause incorrect swizzling code.
auto ctx = getContext();
auto oldCTALayout = triton::gpu::getCTALayout(srcTy.getEncoding());
auto newCTALayout = permuteCTALayout(ctx, oldCTALayout, trans.getOrder());
assert(succeeded(newCTALayout));
auto newInnerCvtEnc =
SwizzledSharedEncodingAttr::get(ctx, opEncoding, srcTy.getShape(),
/*order=*/getOrderForMemory(srcTy),
*newCTALayout, srcTy.getElementType(),
/*needTrans=*/true);
if (newInnerCvtEnc == opEncoding)
continue;

rewriter.setInsertionPoint(trans);
auto sharedMemorySpace = SharedMemorySpaceAttr::get(getContext());
auto alloc = rewriter.create<LocalAllocOp>(
trans.getLoc(),
MemDescType::get(srcTy.getShape(), srcTy.getElementType(),
newInnerCvtEnc, sharedMemorySpace),
trans.getSrc());
auto newTrans = rewriter.create<MemDescTransOp>(
trans.getLoc(), alloc, ArrayRef<int32_t>({1, 0}));
auto localLoadOp =
rewriter.create<LocalLoadOp>(trans.getLoc(), sharedLoadTy, newTrans);

trans.replaceAllUsesWith(localLoadOp.getResult());
rewrite = true;
}
return rewrite ? success() : failure();
}
};

Expand Down