Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 246 additions & 2 deletions src/enzyme_ad/jax/Passes/LowerEnzymeXLABLAS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,256 @@ using namespace mlir::enzyme;
using namespace mlir::enzymexla;
using namespace mlir::stablehlo;

// Lowers enzymexla::SymmOp (symmetric matrix-matrix multiply, BLAS ?symm:
// C = alpha*A*B + beta*C or C = alpha*B*A + beta*C depending on `side`) into a
// JIT call to an external BLAS routine. Only the CPU backend is implemented;
// the CUDA/TPU entry points are stubs.
struct SymmOpLowering : public OpRewritePattern<enzymexla::SymmOp> {

  using OpRewritePattern<enzymexla::SymmOp>::OpRewritePattern;

  // Target backend name: "cpu" (implemented), "cuda"/"tpu" (TODO).
  std::string backend;
  // Bit width of the BLAS integer type (32, or 64 for ILP64 builds).
  int64_t blasIntWidth;

  SymmOpLowering(std::string backend, int64_t blasIntWidth,
                 MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern(context, benefit), backend(std::move(backend)),
        blasIntWidth(blasIntWidth) {}

  LogicalResult matchAndRewrite(enzymexla::SymmOp op,
                                PatternRewriter &rewriter) const override {
    if (backend == "cpu")
      return matchAndRewriteCPU(op, rewriter);

    // else if (backend == "cuda")
    //   return matchAndRewriteCUDA(op, rewriter);

    // else if (backend == "tpu")
    //   return matchAndRewriteTPU(op, rewriter);

    else
      return rewriter.notifyMatchFailure(op, "Unknown backend: \"" + backend +
                                                 "\"");
  }

  LogicalResult matchAndRewriteCPU(enzymexla::SymmOp op,
                                   PatternRewriter &rewriter) const {

    auto ctx = op->getContext();
    LLVMTypeConverter typeConverter(ctx);

    Value a = op.getOperand(0);
    Value b = op.getOperand(1);
    Value c = op.getOperand(2);
    Value alpha_value = op.getAlpha();
    Value beta_value = op.getBeta();
    // BLAS character flags for the side/uplo arguments of ?symm.
    auto side_value = op.getSide() == enzymexla::LapackSide::left ? 'L' : 'R';
    auto uplo_value = op.getUplo() == enzymexla::LapackUplo::L ? 'L' : 'U';

    // Use dyn_cast so the failure path is actually reachable: cast<> asserts
    // on mismatch, which would make the notifyMatchFailure below dead code.
    auto aType = dyn_cast<RankedTensorType>(a.getType());
    auto bType = dyn_cast<RankedTensorType>(b.getType());
    auto cType = dyn_cast<RankedTensorType>(c.getType());
    if (!aType || !bType || !cType)
      return rewriter.notifyMatchFailure(
          op, "operand types not ranked tensor types");

    // A RankedTensorType always has a rank, so no separate hasRank() check is
    // needed; only the rank constraints matter here.
    if (aType.getRank() != 2 || bType.getRank() > 2 || cType.getRank() > 2)
      return rewriter.notifyMatchFailure(op,
                                         "only 2D matrices supported for symm");

    Type elementType = aType.getElementType();
    auto blasIntType = rewriter.getIntegerType(blasIntWidth);
    auto intType = RankedTensorType::get({}, blasIntType);
    auto uint8Type =
        RankedTensorType::get({}, rewriter.getIntegerType(8, false));
    auto llvmIntType = typeConverter.convertType(blasIntType);
    auto llvmPtrType = LLVM::LLVMPointerType::get(ctx);
    auto llvmVoidType = LLVM::LLVMVoidType::get(ctx);

    // e.g. "enzymexla_blas_dsymm_" for f64; bail out on element types with no
    // LAPACK precision prefix.
    std::string blasFn;
    if (auto prefix = lapackPrecisionPrefix(elementType)) {
      blasFn = "enzymexla_blas_" + *prefix + "symm_";
    } else {
      op->emitOpError() << "Unsupported element type: " << elementType;
      return rewriter.notifyMatchFailure(op, "unsupported element type");
    }
    std::string blasFnWrapper = blasFn + "wrapper";

    auto moduleOp = op->getParentOfType<ModuleOp>();

    // Declare the external BLAS entry point once per module. All matrix/scalar
    // arguments are passed by pointer (Fortran calling convention); the two
    // trailing integers are the hidden character-length arguments for
    // side/uplo.
    if (!moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(blasFn)) {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(moduleOp.getBody());
      auto funcType = LLVM::LLVMFunctionType::get(llvmVoidType,
                                                  {llvmPtrType, // side
                                                   llvmPtrType, // uplo
                                                   llvmPtrType, // m
                                                   llvmPtrType, // n
                                                   llvmPtrType, // alpha
                                                   llvmPtrType, // A
                                                   llvmPtrType, // lda
                                                   llvmPtrType, // B
                                                   llvmPtrType, // ldb
                                                   llvmPtrType, // beta
                                                   llvmPtrType, // C
                                                   llvmPtrType, // ldc
                                                   llvmIntType, llvmIntType},
                                                  false);
      rewriter.create<LLVM::LLVMFuncOp>(op.getLoc(), blasFn, funcType,
                                        LLVM::Linkage::External);
    }

    // Private wrapper that forwards its pointer arguments to the BLAS routine
    // and supplies the two hidden character-length arguments (both 1).
    if (!moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(blasFnWrapper)) {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(moduleOp.getBody());

      auto funcType = LLVM::LLVMFunctionType::get(llvmVoidType,
                                                  {
                                                      llvmPtrType, // side
                                                      llvmPtrType, // uplo
                                                      llvmPtrType, // m
                                                      llvmPtrType, // n
                                                      llvmPtrType, // alpha
                                                      llvmPtrType, // A
                                                      llvmPtrType, // lda
                                                      llvmPtrType, // B
                                                      llvmPtrType, // ldb
                                                      llvmPtrType, // beta
                                                      llvmPtrType, // C
                                                      llvmPtrType, // ldc
                                                  },
                                                  false);

      auto funcOp =
          LLVM::LLVMFuncOp::create(rewriter, op.getLoc(), blasFnWrapper,
                                   funcType, LLVM::Linkage::Private);
      rewriter.setInsertionPointToStart(funcOp.addEntryBlock(rewriter));

      SmallVector<Value> args(funcOp.getArguments().begin(),
                              funcOp.getArguments().end());
      auto const1 =
          LLVM::ConstantOp::create(rewriter, op.getLoc(), llvmIntType,
                                   rewriter.getIntegerAttr(llvmIntType, 1));
      args.push_back(const1);
      args.push_back(const1);

      // Result is unused (void call), so no named variable is needed.
      LLVM::CallOp::create(rewriter, op.getLoc(), TypeRange{},
                           SymbolRefAttr::get(ctx, blasFn), args);
      LLVM::ReturnOp::create(rewriter, op.getLoc(), ValueRange{});
    }

    // Pick a module-unique symbol name for the stablehlo-level wrapper. A
    // symbol-table lookup loop is deterministic and avoids the data race a
    // mutable function-local static counter would have under parallel pattern
    // application.
    std::string funcFnName;
    {
      int64_t counter = 0;
      do {
        funcFnName = blasFnWrapper + "_" + std::to_string(counter++);
      } while (moduleOp.lookupSymbol(funcFnName));
    }

    // Layouts for the 12 jit_call operands (column-major) and the one result.
    SmallVector<bool> isColMajorArr(12, true);
    SmallVector<int64_t> operandRanks = {
        0, 0, 0, 0, 0, 2, 0, op.getB().getType().getRank(), 0, 0, 2, 0};
    SmallVector<int64_t> outputRanks = {2};
    auto operandLayouts =
        getSHLOLayout(rewriter, operandRanks, isColMajorArr, 2);
    auto resultLayouts = getSHLOLayout(rewriter, outputRanks, isColMajorArr, 2);

    // BLAS ?symm updates C in place: alias the result with operand 10 (C).
    SmallVector<Attribute> aliases;
    aliases.push_back(
        stablehlo::OutputOperandAliasAttr::get(ctx, {}, 10, {})); /*C*/

    func::FuncOp shloFunc;

    {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(moduleOp.getBody());

      SmallVector<Type> argTypes = {
          op.getA().getType(),     // A
          op.getB().getType(),     // B
          op.getC().getType(),     // C
          op.getAlpha().getType(), // alpha
          op.getBeta().getType(),  // beta
      };
      SmallVector<Type> retTypes = {op.getC().getType()};

      auto calleeType = rewriter.getFunctionType(argTypes, retTypes);
      shloFunc =
          func::FuncOp::create(rewriter, op.getLoc(), funcFnName, calleeType);
      shloFunc.setPrivate();

      auto &entryBlock = *shloFunc.addEntryBlock();
      rewriter.setInsertionPointToStart(&entryBlock);

      auto A = entryBlock.getArgument(0);
      auto B = entryBlock.getArgument(1);
      auto C = entryBlock.getArgument(2);
      auto alpha = entryBlock.getArgument(3);
      auto beta = entryBlock.getArgument(4);

      auto side = rewriter.create<stablehlo::ConstantOp>(
          op.getLoc(), uint8Type,
          cast<ElementsAttr>(makeAttr(uint8Type, side_value)));
      auto uplo = rewriter.create<stablehlo::ConstantOp>(
          op.getLoc(), uint8Type,
          cast<ElementsAttr>(makeAttr(uint8Type, uplo_value)));

      // Leading dimensions: tensors are passed column-major, so ld = dim 0.
      auto lda = stablehlo::ConvertOp::create(
          rewriter, op.getLoc(), intType,
          stablehlo::GetDimensionSizeOp::create(rewriter, op.getLoc(), A, 0));
      auto ldb = stablehlo::ConvertOp::create(
          rewriter, op.getLoc(), intType,
          stablehlo::GetDimensionSizeOp::create(rewriter, op.getLoc(), B, 0));
      auto ldc = stablehlo::ConvertOp::create(
          rewriter, op.getLoc(), intType,
          stablehlo::GetDimensionSizeOp::create(rewriter, op.getLoc(), C, 0));
      // m and n are the dimensions of C.
      auto mSize = ldc;
      auto nSize = stablehlo::ConvertOp::create(
          rewriter, op.getLoc(), intType,
          stablehlo::GetDimensionSizeOp::create(rewriter, op.getLoc(), C, 1));

      auto jitCall = enzymexla::JITCallOp::create(
          rewriter, op.getLoc(), TypeRange{op.getC().getType()},
          mlir::FlatSymbolRefAttr::get(
              ctx, blasFnWrapper), // TODO CHECK blasFnWrapper vs fn
          ValueRange{side, uplo, mSize, nSize, alpha, A, lda, B, ldb, beta, C,
                     ldc},
          rewriter.getStringAttr(""),
          /*operand_layouts=*/operandLayouts,
          /*result_layouts=*/resultLayouts,
          /*arg_attrs=*/nullptr,
          /*res_attrs=*/nullptr,
          /*output_operand_aliases=*/rewriter.getArrayAttr(aliases),
          /*xla_side_effect_free=*/rewriter.getUnitAttr());

      func::ReturnOp::create(rewriter, op.getLoc(),
                             ValueRange{jitCall.getResult(0)});
    }

    auto callOp =
        func::CallOp::create(rewriter, op.getLoc(), shloFunc,
                             ValueRange{op.getA(), op.getB(), op.getC(),
                                        op.getAlpha(), op.getBeta()});

    // Replace AND erase the original op. Only forwarding the uses (without
    // erasing) would leave the dead SymmOp in the IR and make the greedy
    // rewrite driver re-match this pattern on it forever.
    rewriter.replaceOp(op, callOp.getResult(0));

    return success();
  }

  // TODO: cuBLAS-based lowering.
  LogicalResult matchAndRewriteCUDA(enzymexla::SymmOp op,
                                    PatternRewriter &rewriter) const {
    return failure();
  }
  // TODO: TPU lowering.
  LogicalResult matchAndRewriteTPU(enzymexla::SymmOp op,
                                   PatternRewriter &rewriter) const {
    return failure();
  }
};

struct SyrkOpLowering : public OpRewritePattern<enzymexla::SyrkOp> {
using OpRewritePattern<enzymexla::SyrkOp>::OpRewritePattern;

SyrkOpLowering(std::string backend, int64_t blasIntWidth,
MLIRContext *context, PatternBenefit benefit = 1)
: OpRewritePattern(context, benefit), backend(backend),
blasIntWidth(blasIntWidth){};
blasIntWidth(blasIntWidth) {};

LogicalResult matchAndRewrite(enzymexla::SyrkOp op,
PatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -325,7 +568,8 @@ struct LowerEnzymeXLABLASPass
auto context = getOperation()->getContext();
RewritePatternSet patterns(context);

patterns.add<SyrkOpLowering>(backend, blasIntWidth, context);
patterns.add<SyrkOpLowering, SymmOpLowering>(backend, blasIntWidth,
context);

GreedyRewriteConfig config;
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
Expand Down
Loading