diff --git a/src/enzyme_ad/jax/Passes/LowerEnzymeXLABLAS.cpp b/src/enzyme_ad/jax/Passes/LowerEnzymeXLABLAS.cpp
index 0b2979408..fd0e96895 100644
--- a/src/enzyme_ad/jax/Passes/LowerEnzymeXLABLAS.cpp
+++ b/src/enzyme_ad/jax/Passes/LowerEnzymeXLABLAS.cpp
@@ -29,13 +29,242 @@ using namespace mlir::enzyme;
 using namespace mlir::enzymexla;
 using namespace mlir::stablehlo;
 
+/// Rewrites `enzymexla::SymmOp` into a call to a BLAS `symm` routine.
+///
+/// For the "cpu" backend the op becomes a `func.call` to a freshly created
+/// private function. That function wraps an `enzymexla::JITCallOp` which
+/// targets an LLVM wrapper around the external
+/// `enzymexla_blas_<prefix>symm_` symbol. Every other backend is reported as
+/// a match failure.
+struct SymmOpLowering : public OpRewritePattern<enzymexla::SymmOp> {
+  using OpRewritePattern<enzymexla::SymmOp>::OpRewritePattern;
+
+  /// Name of the target backend; only "cpu" is lowered.
+  std::string backend;
+  /// Bit width of the integer type used for the BLAS integer arguments.
+  int64_t blasIntWidth;
+
+  SymmOpLowering(std::string backend, int64_t blasIntWidth,
+                 MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), backend(backend),
+        blasIntWidth(blasIntWidth) {}
+
+  /// Dispatches to the lowering for the backend chosen at construction.
+  LogicalResult matchAndRewrite(enzymexla::SymmOp op,
+                                PatternRewriter &rewriter) const override {
+    if (backend == "cpu")
+      return matchAndRewriteCPU(op, rewriter);
+
+    // TODO: dispatch "cuda" to matchAndRewriteCUDA and "tpu" to
+    // matchAndRewriteTPU once those lowerings exist.
+    return rewriter.notifyMatchFailure(op, "Unknown backend: \"" + backend +
+                                               "\"");
+  }
+
+  /// CPU lowering. Declares the external BLAS symbol and its LLVM wrapper if
+  /// the module does not have them yet, outlines the call into a private
+  /// function, and replaces `op` with a call to that function.
+  LogicalResult matchAndRewriteCPU(enzymexla::SymmOp op,
+                                   PatternRewriter &rewriter) const {
+    auto ctx = op->getContext();
+    auto loc = op.getLoc();
+    LLVMTypeConverter typeConverter(ctx);
+
+    auto side_value = op.getSide() == enzymexla::LapackSide::left ? 'L' : 'R';
+    auto uplo_value = op.getUplo() == enzymexla::LapackUplo::L ? 'L' : 'U';
+
+    // dyn_cast (not cast): a non-ranked operand must become a match failure
+    // rather than trip the assertion inside cast.
+    auto aType = dyn_cast<RankedTensorType>(op.getOperand(0).getType());
+    auto bType = dyn_cast<RankedTensorType>(op.getOperand(1).getType());
+    auto cType = dyn_cast<RankedTensorType>(op.getOperand(2).getType());
+    if (!aType || !bType || !cType)
+      return rewriter.notifyMatchFailure(
+          op, "operand types not ranked tensor types");
+
+    // A and C are handled as matrices below (rank-2 layouts, and dimension 1
+    // of C is queried), so any other rank is rejected up front.
+    if (aType.getRank() != 2 || bType.getRank() > 2 || cType.getRank() != 2)
+      return rewriter.notifyMatchFailure(op,
+                                         "only 2D matrices supported for symm");
+
+    Type elementType = aType.getElementType();
+    auto blasIntType = rewriter.getIntegerType(blasIntWidth);
+    auto intType = RankedTensorType::get({}, blasIntType);
+    auto uint8Type =
+        RankedTensorType::get({}, rewriter.getIntegerType(8, false));
+    auto llvmIntType = typeConverter.convertType(blasIntType);
+    auto llvmPtrType = LLVM::LLVMPointerType::get(ctx);
+    auto llvmVoidType = LLVM::LLVMVoidType::get(ctx);
+
+    auto prefix = lapackPrecisionPrefix(elementType);
+    if (!prefix) {
+      op->emitOpError() << "Unsupported element type: " << elementType;
+      return rewriter.notifyMatchFailure(op, "unsupported element type");
+    }
+    std::string blasFn = "enzymexla_blas_" + *prefix + "symm_";
+    std::string blasFnWrapper = blasFn + "wrapper";
+
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+
+    // Parameter order: side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc.
+    // All twelve are passed by pointer; the raw BLAS symbol additionally
+    // takes two trailing integers.
+    SmallVector<Type> wrapperArgTypes(12, llvmPtrType);
+    SmallVector<Type> blasArgTypes(wrapperArgTypes);
+    blasArgTypes.append(2, llvmIntType);
+
+    if (!moduleOp.lookupSymbol(blasFn)) {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(moduleOp.getBody());
+      auto funcType = LLVM::LLVMFunctionType::get(llvmVoidType, blasArgTypes,
+                                                  /*isVarArg=*/false);
+      LLVM::LLVMFuncOp::create(rewriter, loc, blasFn, funcType,
+                               LLVM::Linkage::External);
+    }
+
+    if (!moduleOp.lookupSymbol(blasFnWrapper)) {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(moduleOp.getBody());
+
+      auto funcType = LLVM::LLVMFunctionType::get(
+          llvmVoidType, wrapperArgTypes, /*isVarArg=*/false);
+      auto funcOp = LLVM::LLVMFuncOp::create(rewriter, loc, blasFnWrapper,
+                                             funcType, LLVM::Linkage::Private);
+      rewriter.setInsertionPointToStart(funcOp.addEntryBlock(rewriter));
+
+      // Forward the wrapper arguments unchanged, then pass the constant 1
+      // for both trailing integers (presumably the hidden Fortran string
+      // lengths of `side` and `uplo` -- TODO confirm).
+      SmallVector<Value> args(funcOp.getArguments().begin(),
+                              funcOp.getArguments().end());
+      auto const1 =
+          LLVM::ConstantOp::create(rewriter, loc, llvmIntType,
+                                   rewriter.getIntegerAttr(llvmIntType, 1));
+      args.push_back(const1);
+      args.push_back(const1);
+
+      LLVM::CallOp::create(rewriter, loc, TypeRange{},
+                           SymbolRefAttr::get(ctx, blasFn), args);
+      LLVM::ReturnOp::create(rewriter, loc, ValueRange{});
+    }
+
+    // Choose a name for the outlined function. The counter is shared by all
+    // rewrites in the process; names already defined in this module are
+    // skipped so that a pre-existing symbol is never redefined.
+    static int64_t fn_counter = 0;
+    std::string funcFnName;
+    do {
+      funcFnName = blasFnWrapper + "_" + std::to_string(fn_counter++);
+    } while (moduleOp.lookupSymbol(funcFnName));
+
+    SmallVector<bool> isColMajorArr(12, true);
+    SmallVector<int64_t> operandRanks = {
+        0, 0, 0, 0, 0, 2, 0, op.getB().getType().getRank(), 0, 0, 2, 0};
+    SmallVector<int64_t> outputRanks = {2};
+    auto operandLayouts =
+        getSHLOLayout(rewriter, operandRanks, isColMajorArr, 2);
+    auto resultLayouts = getSHLOLayout(rewriter, outputRanks, isColMajorArr, 2);
+
+    // The result of the jit call aliases its operand 10, which is C.
+    SmallVector<Attribute> aliases;
+    aliases.push_back(stablehlo::OutputOperandAliasAttr::get(ctx, {}, 10, {}));
+
+    func::FuncOp shloFunc;
+    {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(moduleOp.getBody());
+
+      SmallVector<Type> argTypes = {
+          op.getA().getType(),     // A
+          op.getB().getType(),     // B
+          op.getC().getType(),     // C
+          op.getAlpha().getType(), // alpha
+          op.getBeta().getType(),  // beta
+      };
+      SmallVector<Type> retTypes = {op.getC().getType()};
+
+      auto calleeType = rewriter.getFunctionType(argTypes, retTypes);
+      shloFunc = func::FuncOp::create(rewriter, loc, funcFnName, calleeType);
+      shloFunc.setPrivate();
+
+      auto &entryBlock = *shloFunc.addEntryBlock();
+      rewriter.setInsertionPointToStart(&entryBlock);
+
+      auto A = entryBlock.getArgument(0);
+      auto B = entryBlock.getArgument(1);
+      auto C = entryBlock.getArgument(2);
+      auto alpha = entryBlock.getArgument(3);
+      auto beta = entryBlock.getArgument(4);
+
+      auto side = stablehlo::ConstantOp::create(
+          rewriter, loc, uint8Type,
+          cast<ElementsAttr>(makeAttr(uint8Type, side_value)));
+      auto uplo = stablehlo::ConstantOp::create(
+          rewriter, loc, uint8Type,
+          cast<ElementsAttr>(makeAttr(uint8Type, uplo_value)));
+
+      // Emits dimension `dim` of `tensor` as a scalar of the BLAS int type.
+      auto dimSize = [&](Value tensor, int64_t dim) {
+        return stablehlo::ConvertOp::create(
+            rewriter, loc, intType,
+            stablehlo::GetDimensionSizeOp::create(rewriter, loc, tensor, dim));
+      };
+      auto lda = dimSize(A, 0);
+      auto ldb = dimSize(B, 0);
+      auto ldc = dimSize(C, 0);
+      auto mSize = ldc; // m and ldc are both dimension 0 of C
+      auto nSize = dimSize(C, 1);
+
+      auto jitCall = enzymexla::JITCallOp::create(
+          rewriter, loc, TypeRange{op.getC().getType()},
+          mlir::FlatSymbolRefAttr::get(
+              ctx, blasFnWrapper), // TODO CHECK blasFnWrapper vs fn
+          ValueRange{side, uplo, mSize, nSize, alpha, A, lda, B, ldb, beta, C,
+                     ldc},
+          rewriter.getStringAttr(""),
+          /*operand_layouts=*/operandLayouts,
+          /*result_layouts=*/resultLayouts,
+          /*arg_attrs=*/nullptr,
+          /*res_attrs=*/nullptr,
+          /*output_operand_aliases=*/rewriter.getArrayAttr(aliases),
+          /*xla_side_effect_free=*/rewriter.getUnitAttr());
+
+      func::ReturnOp::create(rewriter, loc, ValueRange{jitCall.getResult(0)});
+    }
+
+    auto callOp = func::CallOp::create(
+        rewriter, loc, shloFunc,
+        ValueRange{op.getA(), op.getB(), op.getC(), op.getAlpha(),
+                   op.getBeta()});
+
+    // replaceOp both redirects the uses and erases `op`. Redirecting the uses
+    // alone would leave the matched op in the IR, where this pattern could
+    // fire on it again.
+    rewriter.replaceOp(op, callOp.getResults());
+    return success();
+  }
+
+  /// CUDA lowering; not implemented, always reports failure.
+  LogicalResult matchAndRewriteCUDA(enzymexla::SymmOp op,
+                                    PatternRewriter &rewriter) const {
+    return failure();
+  }
+
+  /// TPU lowering; not implemented, always reports failure.
+  LogicalResult matchAndRewriteTPU(enzymexla::SymmOp op,
+                                   PatternRewriter &rewriter) const {
+    return failure();
+  }
+};
+
 struct SyrkOpLowering : public OpRewritePattern {
   using OpRewritePattern::OpRewritePattern;
 
   SyrkOpLowering(std::string backend, int64_t blasIntWidth,
                  MLIRContext *context, PatternBenefit benefit = 1)
       : OpRewritePattern(context, benefit), backend(backend),
-        blasIntWidth(blasIntWidth){};
+        blasIntWidth(blasIntWidth) {};
 
   LogicalResult matchAndRewrite(enzymexla::SyrkOp op,
                                 PatternRewriter &rewriter) const override {
@@ -325,7 +554,8 @@ struct LowerEnzymeXLABLASPass
     auto context = getOperation()->getContext();
     RewritePatternSet patterns(context);
 
-    patterns.add(backend, blasIntWidth, context);
+    patterns.add(backend, blasIntWidth,
+                 context);
 
     GreedyRewriteConfig config;
     if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),