Skip to content

Commit

Permalink
Introduce triton_cpu.DotOp.
Browse files Browse the repository at this point in the history
Signed-off-by: Ilya Enkovich <[email protected]>
  • Loading branch information
ienkovich committed Dec 6, 2024
1 parent 0389f38 commit 02c9a81
Show file tree
Hide file tree
Showing 11 changed files with 229 additions and 107 deletions.
24 changes: 24 additions & 0 deletions include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -166,5 +166,29 @@ def TTC_AssertOp : TTC_Op<"assert", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
let assemblyFormat = "$condition `,` $message attr-dict `:` type($condition)";
}

// Vector counterpart of tt.dot for the CPU backend: d = matmul(a, b) + c.
// Result type is inferred (InferTypeOpInterface) and constrained to match
// the accumulator $c via TypesMatchWith.
def TTC_DotOp : TTC_Op<"dot", [Pure,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
TypesMatchWith<"result's type matches accumulator's type",
"d", "c", "$_self">]> {
let summary = "dot";

let description = [{Same as tt.dot but on vectors.}];

// Operands mirror tt.dot: $a and $b are the multiplicands, $c is the
// accumulator; precision attributes default to IEEE / no imprecise acc.
let arguments = (
ins
TTC_Vector:$a,
TTC_Vector:$b,
TTC_Vector:$c,
DefaultValuedAttr<TT_InputPrecisionAttr, "::mlir::triton::InputPrecision::IEEE">:$inputPrecision,
DefaultValuedAttr<I32Attr, "0">:$maxNumImpreciseAcc
);

let results = (outs TTC_Vector:$d);

// Example: triton_cpu.dot %a, %b, %c, inputPrecision = ieee
//          : vector<MxKxT> * vector<KxNxT> -> vector<MxNxU>
let assemblyFormat = [{
$a`,` $b`,` $c (`,` `inputPrecision` `=` $inputPrecision^)? attr-dict `:`
type($a) `*` type($b) `->` type($d)
}];
}

#endif
11 changes: 11 additions & 0 deletions lib/Dialect/TritonCPU/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,15 @@ void ExternElementwiseOp::getEffects(
SideEffects::DefaultResource::get());
}

/// Infer the result type of triton_cpu.dot.
///
/// The result always has exactly the vector type of the accumulator
/// (operand #2), matching the TypesMatchWith constraint in the op
/// definition, so inference never fails.
LogicalResult
DotOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
                        ValueRange operands, DictionaryAttr attributes,
                        OpaqueProperties properties, RegionRange regions,
                        SmallVectorImpl<Type> &inferredReturnTypes) {
  // Operand order is (a, b, c); the accumulator c dictates the result type.
  inferredReturnTypes.push_back(cast<VectorType>(operands[2].getType()));
  return success();
}

} // namespace mlir::triton::cpu
8 changes: 4 additions & 4 deletions test/TritonCPU/dot-to-amx.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ module {
%6 = triton_cpu.extract_memref %1 : <tensor<32x16xbf16>> -> memref<32x16xbf16, strided<[16, 1]>> loc(#loc)
%7:2 = triton_cpu.extract_indices %1 : <tensor<32x16xbf16>> -> index, index loc(#loc)
%8 = vector.transfer_read %6[%7#0, %7#1], %cst {in_bounds = [true, true]} : memref<32x16xbf16, strided<[16, 1]>>, vector<32x16xbf16> loc(#loc)
%9 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %5, %8, %cst_0 : vector<16x32xbf16>, vector<32x16xbf16> into vector<16x16xf32> loc(#loc)
%9 = triton_cpu.dot %5, %8, %cst_0, inputPrecision = ieee : vector<16x32xbf16> * vector<32x16xbf16> -> vector<16x16xf32> loc(#loc)
%10 = triton_cpu.extract_memref %2 : <tensor<16x16xf32>> -> memref<16x16xf32, strided<[16, 1]>> loc(#loc)
%11:2 = triton_cpu.extract_indices %2 : <tensor<16x16xf32>> -> index, index loc(#loc)
vector.transfer_write %9, %10[%11#0, %11#1] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32, strided<[16, 1]>> loc(#loc)
Expand Down Expand Up @@ -80,7 +80,7 @@ module {
%6 = triton_cpu.extract_memref %1 : <tensor<128x16xi8>> -> memref<128x16xi8, strided<[16, 1]>> loc(#loc)
%7:2 = triton_cpu.extract_indices %1 : <tensor<128x16xi8>> -> index, index loc(#loc)
%8 = vector.transfer_read %6[%7#0, %7#1], %c0_i8 {in_bounds = [true, true]} : memref<128x16xi8, strided<[16, 1]>>, vector<128x16xi8> loc(#loc)
%9 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %5, %8, %cst : vector<16x128xi8>, vector<128x16xi8> into vector<16x16xi32> loc(#loc)
%9 = triton_cpu.dot %5, %8, %cst, inputPrecision = ieee : vector<16x128xi8> * vector<128x16xi8> -> vector<16x16xi32> loc(#loc)
%10 = triton_cpu.extract_memref %2 : <tensor<16x16xi32>> -> memref<16x16xi32, strided<[16, 1]>> loc(#loc)
%11:2 = triton_cpu.extract_indices %2 : <tensor<16x16xi32>> -> index, index loc(#loc)
vector.transfer_write %9, %10[%11#0, %11#1] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32, strided<[16, 1]>> loc(#loc)
Expand Down Expand Up @@ -136,7 +136,7 @@ module {
%6 = triton_cpu.extract_memref %1 : <tensor<64x32xbf16>> -> memref<64x32xbf16, strided<[32, 1]>> loc(#loc)
%7:2 = triton_cpu.extract_indices %1 : <tensor<64x32xbf16>> -> index, index loc(#loc)
%8 = vector.transfer_read %6[%7#0, %7#1], %cst {in_bounds = [true, true]} : memref<64x32xbf16, strided<[32, 1]>>, vector<64x32xbf16> loc(#loc)
%9 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %5, %8, %cst_0 : vector<16x64xbf16>, vector<64x32xbf16> into vector<16x32xf32> loc(#loc)
%9 = triton_cpu.dot %5, %8, %cst_0, inputPrecision = ieee : vector<16x64xbf16> * vector<64x32xbf16> -> vector<16x32xf32> loc(#loc)
%10 = triton_cpu.extract_memref %2 : <tensor<16x32xf32>> -> memref<16x32xf32, strided<[32, 1]>> loc(#loc)
%11:2 = triton_cpu.extract_indices %2 : <tensor<16x32xf32>> -> index, index loc(#loc)
vector.transfer_write %9, %10[%11#0, %11#1] {in_bounds = [true, true]} : vector<16x32xf32>, memref<16x32xf32, strided<[32, 1]>> loc(#loc)
Expand Down Expand Up @@ -237,7 +237,7 @@ module {
%9 = triton_cpu.extract_memref %arg6 : <tensor<64x32xf8E5M2>> -> memref<128x32xf8E5M2, strided<[32, 1]>> loc(#loc)
%10:2 = triton_cpu.extract_indices %arg6 : <tensor<64x32xf8E5M2>> -> index, index loc(#loc)
%11 = vector.transfer_read %9[%10#0, %10#1], %cst {in_bounds = [true, true]} : memref<128x32xf8E5M2, strided<[32, 1]>>, vector<64x32xf8E5M2> loc(#loc)
%12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %8, %11, %arg4 : vector<64x64xf8E5M2>, vector<64x32xf8E5M2> into vector<64x32xf32> loc(#loc)
%12 = triton_cpu.dot %8, %11, %arg4, inputPrecision = ieee : vector<64x64xf8E5M2> * vector<64x32xf8E5M2> -> vector<64x32xf32> loc(#loc)
%13 = tt.advance %arg5, [%c0_i32, %c64_i32] : <tensor<64x64xf8E5M2>> loc(#loc)
%14 = tt.advance %arg6, [%c64_i32, %c0_i32] : <tensor<64x32xf8E5M2>> loc(#loc)
scf.yield %12, %13, %14 : vector<64x32xf32>, !tt.ptr<tensor<64x64xf8E5M2>>, !tt.ptr<tensor<64x32xf8E5M2>> loc(#loc)
Expand Down
1 change: 1 addition & 0 deletions third_party/cpu/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def make_tttcir(self, mod, metadata, opt):
amx_fp16 = False
amx_bf16 = 'amx-bf16' in self.cpu_features
cpu.passes.ttcpuir.add_convert_dot_to_amx(pm, amx_int8, amx_fp16, amx_bf16)
cpu.passes.ttcpuir.add_convert_dot_generic(pm)
promote_bf16_to_fp32 = self.cpu_arch == "x86_64" and "avx512bf16" not in self.cpu_features
# We don't have any lowering for mixed precision matmuls, so always use casts for now
convert_mixed_precision_matmul = True
Expand Down
1 change: 1 addition & 0 deletions third_party/cpu/include/TritonCPUTransforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ createConvertDotProduct(bool useHorizontalSum);
std::unique_ptr<OperationPass<ModuleOp>> createConvertDotToAMX();
std::unique_ptr<OperationPass<ModuleOp>>
createConvertDotToAMX(bool convertInt8, bool convertFp16, bool convertBf16);
std::unique_ptr<OperationPass<ModuleOp>> createConvertDotGeneric();

#define GEN_PASS_REGISTRATION
#include "cpu/include/TritonCPUTransforms/Passes.h.inc"
Expand Down
12 changes: 12 additions & 0 deletions third_party/cpu/include/TritonCPUTransforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,16 @@ def ConvertDotToAMX : Pass<"triton-cpu-convert-dot-to-amx", "mlir::ModuleOp"> {
"mlir::triton::cpu::TritonCPUDialect"];
}

// Fallback lowering for triton_cpu.dot: rewrites it to generic vector code.
// In the CPU pipeline it runs after target-specific conversions (e.g. the
// AMX pass), picking up any dot ops they left behind.
def ConvertDotGeneric : Pass<"triton-cpu-convert-dot-generic", "mlir::ModuleOp"> {
  // NOTE(review): fixed typo "convertion" -> "conversion" in the summary.
  let summary = "Generic conversion of dot product op.";
  let description = [{
    This pass is used to lower matmul operations to generic vector code.
  }];

  let constructor = "mlir::triton::cpu::createConvertDotGeneric()";

  let dependentDialects = ["mlir::vector::VectorDialect",
                           "mlir::triton::cpu::TritonCPUDialect"];
}

#endif
3 changes: 2 additions & 1 deletion third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_triton_library(TritonCPUTransforms
ConvertDotOp/ConvertDotGeneric.cpp
ConvertDotOp/ConvertDotToAMX.cpp
ConvertDotProduct.cpp
ConvertDotToAMX.cpp
ConvertUnsupportedOps.cpp
DecomposeFpConversions.cpp
OptimizeMasks.cpp
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
#include "cpu/include/TritonCPUTransforms/OptCommon.h"

#include "cpu/include/TritonCPUTransforms/Passes.h"

#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonCPU/IR/Dialect.h"
#include <iostream>
#include <utility>

namespace mlir {
namespace triton {
namespace cpu {
#define GEN_PASS_DEF_CONVERTDOTGENERIC
#include "cpu/include/TritonCPUTransforms/Passes.h.inc"
} // namespace cpu
} // namespace triton
} // namespace mlir

using namespace mlir;
using namespace mlir::triton;
using namespace mlir::triton::cpu;

namespace {

// Conversion target for the generic dot lowering: every dialect the lowered
// code may use is legal, and the only illegal op is triton_cpu.dot itself,
// so partial conversion succeeds exactly when all dot ops are rewritten.
class DotConversionTarget : public ConversionTarget {
public:
  explicit DotConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) {
    // Dialects the rewritten code is allowed to contain.
    addLegalDialect<arith::ArithDialect>();
    addLegalDialect<vector::VectorDialect>();
    addLegalDialect<TritonCPUDialect>();
    addLegalDialect<TritonDialect>();
    // The single op this pass exists to eliminate.
    addIllegalOp<cpu::DotOp>();
  }
};

// Rewrites triton_cpu.dot into vector.contract with standard matmul
// indexing maps.
//
// Supported shapes:
//   rank 2: (M x K) * (K x N) + (M x N)       — plain matmul
//   rank 3: (B x M x K) * (B x K x N) + (B x M x N) — batched matmul
// Any other rank makes the pattern fail, which in turn fails the pass
// (cpu::DotOp is illegal in the conversion target).
//
// NOTE(review): removed the unused locals (`loc`, `bType`, `cType`), the
// never-called `deinterleave` helper, and folded the duplicated rank-2 /
// rank-3 branches into a single construction path.
struct DotOpConversion : public OpConversionPattern<cpu::DotOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(cpu::DotOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = op.getContext();
    Value a = op.getA();
    Value b = op.getB();
    Value c = op.getC();

    int64_t rank = cast<VectorType>(a.getType()).getRank();
    if (rank != 2 && rank != 3)
      return failure();

    // One extra iteration dimension for the K reduction; in the rank-3
    // case dim 0 is the parallel batch dimension.
    unsigned numDims = static_cast<unsigned>(rank) + 1;
    SmallVector<unsigned> aDims, bDims, cDims;
    if (rank == 2) {
      aDims = {0, 2}; // A[m, k]
      bDims = {2, 1}; // B[k, n]
      cDims = {0, 1}; // C[m, n]
    } else {
      aDims = {0, 1, 3}; // A[b, m, k]
      bDims = {0, 3, 2}; // B[b, k, n]
      cDims = {0, 1, 2}; // C[b, m, n]
    }
    auto aMap = AffineMap::getMultiDimMapWithTargets(numDims, aDims, ctx);
    auto bMap = AffineMap::getMultiDimMapWithTargets(numDims, bDims, ctx);
    auto cMap = AffineMap::getMultiDimMapWithTargets(numDims, cDims, ctx);

    // All leading dims are parallel; the trailing dim (K) is the reduction.
    SmallVector<Attribute> iteratorTypes(
        numDims - 1,
        vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel));
    iteratorTypes.push_back(
        vector::IteratorTypeAttr::get(ctx, vector::IteratorType::reduction));

    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, a, b, c, rewriter.getAffineMapArrayAttr({aMap, bMap, cMap}),
        rewriter.getArrayAttr(iteratorTypes));
    return success();
  }
};

// Pass driver: applies DotOpConversion across the module until no
// triton_cpu.dot ops remain; fails the pass if any dot op cannot be
// lowered (unsupported rank).
//
// NOTE(review): dropped the explicit default constructor — it merely
// duplicated the implicit one that the inheriting-constructor declaration
// already leaves available.
struct ConvertDotGeneric
    : public triton::cpu::impl::ConvertDotGenericBase<ConvertDotGeneric> {
  using ConvertDotGenericBase::ConvertDotGenericBase;

  void runOnOperation() override {
    MLIRContext *context = &getContext();

    // Only cpu::DotOp is illegal, so a partial-conversion failure means a
    // dot op survived that DotOpConversion could not handle.
    DotConversionTarget convTarget(*context);
    RewritePatternSet patterns(context);
    patterns.add<DotOpConversion>(context);

    if (failed(applyPartialConversion(getOperation(), convTarget,
                                      std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

namespace mlir {
namespace triton {
namespace cpu {

/// Factory for the generic dot-lowering pass
/// (registered as "triton-cpu-convert-dot-generic" in Passes.td).
std::unique_ptr<OperationPass<ModuleOp>> createConvertDotGeneric() {
return std::make_unique<ConvertDotGeneric>();
}

} // namespace cpu
} // namespace triton
} // namespace mlir
Loading

0 comments on commit 02c9a81

Please sign in to comment.