diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 53394ecec42a..ca364b4e603b 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -19,66 +19,93 @@ namespace mlir::triton::gpu { namespace { // Given -// dot(convert(trans(src)) #dot_operand) -> +// dot(convert(trans(src)) #dot_operand) | dot(trans(src) #dot_operand) -> // dot(convert(local_load(trans(alloc(src))))) // change the encoding of the inner convert to a special, swizzled shared // encoding. -class SwizzleShmemConvert : public OpRewritePattern { +class SwizzleShmemConvert : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(ConvertLayoutOp cvtOp, + LogicalResult matchAndRewrite(DotOp dotOp, PatternRewriter &rewriter) const override { - if (!cvtOp->hasOneUse() || - !isa(cvtOp->use_begin()->getOwner())) - return failure(); - // Match outerCvt(trans(innerCvt(x))). - auto trans = cvtOp.getSrc().getDefiningOp(); - if (!trans || trans.getOrder() != ArrayRef{1, 0}) - return failure(); + bool rewrite = false; - RankedTensorType srcTy = trans.getSrc().getType(); + for (auto value : dotOp.getOperands()) { + if (!value.getDefiningOp()) { + continue; + } - if (auto srcCvt = trans.getSrc().getDefiningOp()) { - srcTy = srcCvt.getSrc().getType(); - } - RankedTensorType sharedLoadTy = cvtOp.getType(); - auto cvtEncoding = - dyn_cast(sharedLoadTy.getEncoding()); - if (!cvtEncoding) - return failure(); + RankedTensorType sharedLoadTy; + TransOp trans; - // Set needTrans to true here. newInnerCvtEnc is computed based on - // argEncoding which is before the transpose. Without needTrans we will - // compute vec and maxPhase based on incorrect m, n and k size of mma. The - // type inference of MemDescTransOp simply swap the order but doesn't fix - // the vec and maxPhase for the YType, hence it would causing incorrect - // swizzling code. - auto ctx = getContext(); - auto oldCTALayout = triton::gpu::getCTALayout(srcTy.getEncoding()); - auto newCTALayout = permuteCTALayout(ctx, oldCTALayout, trans.getOrder()); - auto newInnerCvtEnc = - SwizzledSharedEncodingAttr::get(ctx, cvtEncoding, srcTy.getShape(), - /*order=*/getOrderForMemory(srcTy), - newCTALayout, srcTy.getElementType(), - /*needTrans=*/true); - if (newInnerCvtEnc == cvtEncoding) - return failure(); - rewriter.setInsertionPoint(trans); - auto sharedMemorySpace = SharedMemorySpaceAttr::get(getContext()); - auto alloc = rewriter.create( - trans.getLoc(), - MemDescType::get(srcTy.getShape(), srcTy.getElementType(), - newInnerCvtEnc, sharedMemorySpace), - trans.getSrc()); - auto newTrans = rewriter.create(trans.getLoc(), alloc, - ArrayRef({1, 0})); - auto localLoadOp = - rewriter.create(trans.getLoc(), sharedLoadTy, newTrans); - rewriter.modifyOpInPlace(cvtOp, [&]() { - cvtOp.getSrcMutable().assign(localLoadOp.getResult()); - }); - return success(); + if ((trans = dyn_cast(value.getDefiningOp()))) { + sharedLoadTy = trans.getType(); + } else if (auto cvt = dyn_cast(value.getDefiningOp())) { + if (!(trans = cvt.getSrc().getDefiningOp()) || + trans.getOrder() != ArrayRef{1, 0}) { + continue; + } + + sharedLoadTy = cvt.getType(); + } else { + continue; + } + + auto feedsDot = llvm::all_of(value.getUsers(), + [](auto user) { return isa(user); }); + if (!feedsDot) { + continue; + } + + if (!value.getDefiningOp()->hasOneUse() || + !isa(value.getDefiningOp()->use_begin()->getOwner())) + continue; + + // Match trans(innerCvt(x)). + RankedTensorType srcTy = trans.getSrc().getType(); + + if (auto srcCvt = trans.getSrc().getDefiningOp()) { + srcTy = srcCvt.getSrc().getType(); + } + auto opEncoding = + dyn_cast(sharedLoadTy.getEncoding()); + if (!opEncoding) + continue; + + // Set needTrans to true here. newInnerCvtEnc is computed based on + // argEncoding which is before the transpose. Without needTrans we will + // compute vec and maxPhase based on incorrect m, n and k size of mma. + // The type inference of MemDescTransOp simply swap the order but + // doesn't fix the vec and maxPhase for the YType, hence it would + // causing incorrect swizzling code. + auto ctx = getContext(); + auto oldCTALayout = triton::gpu::getCTALayout(srcTy.getEncoding()); + auto newCTALayout = permuteCTALayout(ctx, oldCTALayout, trans.getOrder()); + auto newInnerCvtEnc = + SwizzledSharedEncodingAttr::get(ctx, opEncoding, srcTy.getShape(), + /*order=*/getOrderForMemory(srcTy), + newCTALayout, srcTy.getElementType(), + /*needTrans=*/true); + if (newInnerCvtEnc == opEncoding) + continue; + + rewriter.setInsertionPoint(trans); + auto sharedMemorySpace = SharedMemorySpaceAttr::get(getContext()); + auto alloc = rewriter.create( + trans.getLoc(), + MemDescType::get(srcTy.getShape(), srcTy.getElementType(), + newInnerCvtEnc, sharedMemorySpace), + trans.getSrc()); + auto newTrans = rewriter.create( + trans.getLoc(), alloc, ArrayRef({1, 0})); + auto localLoadOp = + rewriter.create(trans.getLoc(), sharedLoadTy, newTrans); + + trans.replaceAllUsesWith(localLoadOp.getResult()); + rewrite = true; + } + return rewrite ? success() : failure(); } };