-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR][NVVM] Add support for converting fp4/6/8 to fp16x2 #162439
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
This change adds the following NVVM dialect Ops for converting fp4/6/8 to fp16x2: - convert.f4x2.to.f16x2 - convert.f6x2.to.f16x2 - convert.f8x2.to.f16x2 - convert.f8x2.to.bf16x2 Tests are added in `convert_fp4x2.mlir`, `convert_fp6x2.mlir`, and `convert_fp8x2.mlir`. PTX Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm Author: Srinivasa Ravi (Wolfram70) ChangesThis change adds the following NVVM dialect Ops for converting fp4/6/8
Tests are added in PTX Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt Full diff: https://github.com/llvm/llvm-project/pull/162439.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index f8e3167b42c35..3d1aab02d2f75 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1855,6 +1855,111 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> {
}];
}
+class NVVM_ConvertF8x2ToFP16x2Op_Base <string dstType>
+: NVVM_Op<"convert.f8x2.to." # !tolower(dstType) # "x2"> {
+ let summary = "Convert a pair of f8 inputs to " # !tolower(dstType) # "x2";
+ let description = [{
+ This Op converts the given f8 inputs in a i8x2 vector to }] # !tolower(dstType) # [{.
+
+ The result `dst` is represented as a vector of }] # !tolower(dstType) # [{ elements.
+
+ The `relu` attribute, when set, lowers to the '.relu' variant of
+ the cvt instruction.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+ }];
+ let results = (outs VectorOfLengthAndType<[2], [!cast<Type>(dstType)]>:$dst);
+ let arguments = !if(!eq(dstType, "F16"),
+ (ins VectorOfLengthAndType<[2], [I8]>:$src,
+ DefaultValuedAttr<BoolAttr, "false">:$relu,
+ TypeAttr:$srcType),
+ (ins VectorOfLengthAndType<[2], [I8]>:$src,
+ TypeAttr:$srcType));
+ let assemblyFormat = "$src attr-dict `:` type($src) `(` $srcType `)` `->` type($dst)";
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ static IDArgPair
+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder);
+ }];
+
+ string llvmBuilder = [{
+ auto [intId, args] =
+ NVVM::ConvertF8x2To}] # dstType # [{x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
+ $dst = createIntrinsicCall(builder, intId, args);
+ }];
+}
+def NVVM_ConvertF8x2ToF16x2Op : NVVM_ConvertF8x2ToFP16x2Op_Base<"F16">;
+def NVVM_ConvertF8x2ToBF16x2Op : NVVM_ConvertF8x2ToFP16x2Op_Base<"BF16">;
+
+def NVVM_ConvertF6x2ToF16x2Op : NVVM_Op<"convert.f6x2.to.f16x2"> {
+ let summary = "Convert a pair of f6 inputs to f16x2";
+ let description = [{
+ This Op converts the given f6 inputs in a i8x2 vector to f16x2.
+
+ The result `dst` is represented as a vector of f16 elements.
+
+ The `relu` attribute, when set, lowers to the '.relu' variant of
+ the cvt instruction.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+ }];
+ let results = (outs VectorOfLengthAndType<[2], [F16]>:$dst);
+ let arguments = (ins VectorOfLengthAndType<[2], [I8]>:$src,
+ DefaultValuedAttr<BoolAttr, "false">:$relu,
+ TypeAttr:$srcType);
+ let assemblyFormat = "$src attr-dict `:` type($src) `(` $srcType `)` `->` type($dst)";
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ static IDArgPair
+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder);
+ }];
+
+ string llvmBuilder = [{
+ auto [intId, args] =
+ NVVM::ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
+ $dst = createIntrinsicCall(builder, intId, args);
+ }];
+}
+
+def NVVM_ConvertF4x2ToF16x2Op : NVVM_Op<"convert.f4x2.to.f16x2"> {
+ let summary = "Convert a pair of f4 inputs to f16x2";
+ let description = [{
+ This Op converts the given f4 inputs packed in an i8 to f16x2.
+
+ The result `dst` is represented as a vector of f16 elements. The value
+ converted from the lower 4 bits of `src` is stored in the first element of
+ `dst` and the value converted from the upper 4 bits of `src` is stored in
+ the second element of `dst`.
+
+ The `relu` attribute, when set, lowers to the '.relu' variant of
+ the cvt instruction.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+ }];
+ let results = (outs VectorOfLengthAndType<[2], [F16]>:$dst);
+ let arguments = (ins I8:$src,
+ DefaultValuedAttr<BoolAttr, "false">:$relu,
+ TypeAttr:$srcType);
+ let assemblyFormat = "$src attr-dict `:` type($src) `(` $srcType `)` `->` type($dst)";
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ static IDArgPair
+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder);
+ }];
+
+ string llvmBuilder = [{
+ auto [intId, args] =
+ NVVM::ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
+ $dst = createIntrinsicCall(builder, intId, args);
+ }];
+}
+
//===----------------------------------------------------------------------===//
// NVVM MMA Ops
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index e8f8824d47de0..d1bc969b667d3 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -271,6 +271,51 @@ LogicalResult ConvertBF16x2ToF8x2Op::verify() {
return success();
}
+LogicalResult ConvertF8x2ToF16x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<Float8E4M3FNType, Float8E5M2Type>(getSrcType()))
+ return emitOpError("Only ")
+ << mlir::Float8E4M3FNType::get(ctx) << " and "
+ << mlir::Float8E5M2Type::get(ctx)
+ << " types are supported for conversions from f8x2 to f16x2.";
+
+ return success();
+}
+
+LogicalResult ConvertF8x2ToBF16x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+ if (!llvm::isa<Float8E8M0FNUType>(getSrcType()))
+ return emitOpError("Only ")
+ << mlir::Float8E8M0FNUType::get(ctx)
+ << " type is supported for conversions from f8x2 to bf16x2.";
+
+ return success();
+}
+
+LogicalResult ConvertF6x2ToF16x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<Float6E2M3FNType, Float6E3M2FNType>(getSrcType()))
+ return emitOpError("Only ")
+ << mlir::Float6E2M3FNType::get(ctx) << " and "
+ << mlir::Float6E3M2FNType::get(ctx)
+ << " types are supported for conversions from f6x2 to f16x2.";
+
+ return success();
+}
+
+LogicalResult ConvertF4x2ToF16x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<Float4E2M1FNType>(getSrcType()))
+ return emitOpError("Only ")
+ << mlir::Float4E2M1FNType::get(ctx)
+ << " type is supported for conversions from f4x2 to f16x2.";
+
+ return success();
+}
+
LogicalResult BulkStoreOp::verify() {
if (getInitVal() != 0)
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -2055,6 +2100,100 @@ ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
}
}
+NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto curOp = cast<NVVM::ConvertF8x2ToF16x2Op>(op);
+
+ bool hasRelu = curOp.getRelu();
+
+ llvm::Intrinsic::ID intId =
+ llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
+ .Case<Float8E4M3FNType>([&](Float8E4M3FNType type) {
+ return hasRelu
+ ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu
+ : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn;
+ })
+ .Case<Float8E5M2Type>([&](Float8E5M2Type type) {
+ return hasRelu
+ ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu
+ : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn;
+ })
+ .Default([](mlir::Type type) {
+ llvm_unreachable("Invalid type for ConvertF8x2ToF16x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+
+ llvm::Value *packedI16 =
+ builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
+ llvm::Type::getInt16Ty(builder.getContext()));
+
+ return {intId, {packedI16}};
+}
+
+NVVM::IDArgPair ConvertF8x2ToBF16x2Op::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto curOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op);
+
+ llvm::Intrinsic::ID intId = llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2;
+ llvm::Value *packedI16 =
+ builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
+ llvm::Type::getInt16Ty(builder.getContext()));
+
+ return {intId, {packedI16}};
+}
+
+NVVM::IDArgPair ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto curOp = cast<NVVM::ConvertF6x2ToF16x2Op>(op);
+
+ bool hasRelu = curOp.getRelu();
+
+ llvm::Intrinsic::ID intId =
+ llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
+ .Case<Float6E2M3FNType>([&](Float6E2M3FNType type) {
+ return hasRelu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu
+ : llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn;
+ })
+ .Case<Float6E3M2FNType>([&](Float6E3M2FNType type) {
+ return hasRelu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu
+ : llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn;
+ })
+ .Default([](mlir::Type type) {
+ llvm_unreachable("Invalid type for ConvertF6x2ToF16x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+
+ llvm::Value *packedI16 =
+ builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
+ llvm::Type::getInt16Ty(builder.getContext()));
+
+ return {intId, {packedI16}};
+}
+
+NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto curOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op);
+
+ bool hasRelu = curOp.getRelu();
+
+ llvm::Intrinsic::ID intId =
+ llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
+ .Case<Float4E2M1FNType>([&](Float4E2M1FNType type) {
+ return hasRelu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu
+ : llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn;
+ })
+ .Default([](mlir::Type type) {
+ llvm_unreachable("Invalid type for ConvertF4x2ToF16x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+
+ llvm::Value *extendedI16 =
+ builder.CreateZExt(mt.lookupValue(curOp.getSrc()),
+ llvm::Type::getInt16Ty(builder.getContext()));
+
+ return {intId, {extendedI16}};
+}
+
llvm::Intrinsic::ID
Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
LLVM::ModuleTranslation &mt,
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
new file mode 100644
index 0000000000000..e43dea4065c08
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: @convert_f4x2_to_f16x2
+llvm.func @convert_f4x2_to_f16x2(%src : i8) {
+ // CHECK: %[[res1:.*]] = zext i8 %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn(i16 %[[res1]])
+ %res1 = nvvm.convert.f4x2.to.f16x2 %src : i8 (f4E2M1FN)-> vector<2xf16>
+ // CHECK: %[[res2:.*]] = zext i8 %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn.relu(i16 %[[res2]])
+ %res2 = nvvm.convert.f4x2.to.f16x2 %src {relu = true} : i8 (f4E2M1FN)-> vector<2xf16>
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
index 04163b578aa02..939037d9db066 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
@@ -19,3 +19,27 @@ llvm.func @convert_f32x2_to_fp6x2_vector(%srcA : f32, %srcB : f32) {
%res2 = nvvm.convert.f32x2.to.f6x2 <e3m2> %srcA, %srcB : vector<2xi8>
llvm.return
}
+
+// -----
+
+// CHECK-LABEL: @convert_f6x2_to_f16x2_e2m3
+llvm.func @convert_f6x2_to_f16x2_e2m3(%src : vector<2xi8>) {
+ // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e2m3x2.to.f16x2.rn(i16 %[[res1]])
+ %res1 = nvvm.convert.f6x2.to.f16x2 %src : vector<2xi8> (f6E2M3FN)-> vector<2xf16>
+ // CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e2m3x2.to.f16x2.rn.relu(i16 %[[res2]])
+ %res2 = nvvm.convert.f6x2.to.f16x2 %src {relu = true} : vector<2xi8> (f6E2M3FN)-> vector<2xf16>
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_f6x2_to_f16x2_e3m2
+llvm.func @convert_f6x2_to_f16x2_e3m2(%src : vector<2xi8>) {
+ // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e3m2x2.to.f16x2.rn(i16 %[[res1]])
+ %res1 = nvvm.convert.f6x2.to.f16x2 %src : vector<2xi8> (f6E3M2FN)-> vector<2xf16>
+ // CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e3m2x2.to.f16x2.rn.relu(i16 %[[res2]])
+ %res2 = nvvm.convert.f6x2.to.f16x2 %src {relu = true} : vector<2xi8> (f6E3M2FN)-> vector<2xf16>
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
index 4a15efb9e805c..52078742d740b 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
@@ -100,3 +100,37 @@ llvm.func @convert_bf16x2_to_f8x2_vector_return(%src : vector<2xbf16>) {
%res2 = nvvm.convert.bf16x2.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> vector<2xi8>
llvm.return
}
+
+// -----
+
+// CHECK-LABEL: @convert_f8x2_to_f16x2
+llvm.func @convert_f8x2_to_f16x2_e4m3(%src : vector<2xi8>) {
+ // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e4m3x2.to.f16x2.rn(i16 %[[res1]])
+ %res1 = nvvm.convert.f8x2.to.f16x2 %src : vector<2xi8> (f8E4M3FN)-> vector<2xf16>
+ // CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e4m3x2.to.f16x2.rn.relu(i16 %[[res2]])
+ %res2 = nvvm.convert.f8x2.to.f16x2 %src {relu = true} : vector<2xi8> (f8E4M3FN)-> vector<2xf16>
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_f8x2_to_f16x2_e5m2
+llvm.func @convert_f8x2_to_f16x2_e5m2(%src : vector<2xi8>) {
+ // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn(i16 %[[res1]])
+ %res1 = nvvm.convert.f8x2.to.f16x2 %src : vector<2xi8> (f8E5M2)-> vector<2xf16>
+ // CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn.relu(i16 %[[res2]])
+ %res2 = nvvm.convert.f8x2.to.f16x2 %src {relu = true} : vector<2xi8> (f8E5M2)-> vector<2xf16>
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @convert_f8x2_to_bf16x2_ue8m0
+llvm.func @convert_f8x2_to_bf16x2_ue8m0(%src : vector<2xi8>) {
+ // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ue8m0x2.to.bf16x2(i16 %[[res1]])
+ %res1 = nvvm.convert.f8x2.to.bf16x2 %src : vector<2xi8> (f8E8M0FNU)-> vector<2xbf16>
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 383f4829f3287..b5e6ba5aa1406 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -246,6 +246,38 @@ llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding(%src : vector<2xbf16>) {
// -----
+llvm.func @nvvm_cvt_f8x2_to_f16x2_invalid_type(%src : vector<2xi8>) {
+ // expected-error @below {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f8x2 to f16x2.}}
+ %res = nvvm.convert.f8x2.to.f16x2 %src : vector<2xi8> (f8E4M3) -> vector<2xf16>
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f8x2_to_bf16x2_invalid_type(%src : vector<2xi8>) {
+ // expected-error @below {{Only 'f8E8M0FNU' type is supported for conversions from f8x2 to bf16x2.}}
+ %res = nvvm.convert.f8x2.to.bf16x2 %src : vector<2xi8> (f8E4M3FN) -> vector<2xbf16>
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f6x2_to_f16x2_invalid_type(%src : vector<2xi8>) {
+ // expected-error @below {{Only 'f6E2M3FN' and 'f6E3M2FN' types are supported for conversions from f6x2 to f16x2.}}
+ %res = nvvm.convert.f6x2.to.f16x2 %src : vector<2xi8> (f8E4M3FN) -> vector<2xf16>
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f4x2_to_f16x2_invalid_type(%src : i8) {
+ // expected-error @below {{Only 'f4E2M1FN' type is supported for conversions from f4x2 to f16x2.}}
+ %res = nvvm.convert.f4x2.to.f16x2 %src : i8 (f6E2M3FN) -> vector<2xf16>
+ llvm.return
+}
+
+// -----
+
llvm.func @nvvm_prefetch_L1_with_evict_priority(%global_ptr: !llvm.ptr<1>) {
// expected-error @below {{cache eviction priority supported only for cache level L2}}
nvvm.prefetch level = L1, evict_priority = evict_last, %global_ptr : !llvm.ptr<1>
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
let arguments = !if(!eq(dstType, "F16"), | ||
(ins VectorOfLengthAndType<[2], [I8]>:$src, | ||
DefaultValuedAttr<BoolAttr, "false">:$relu, | ||
TypeAttr:$srcType), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am wondering do we need TypeAttr
. Isn't that clear from the name of the OP?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this case, this TypeAttr
is for the source type since it can be either e4m3
(f8E4M3FN
) or e5m2
(f8E5M2
).
}]; | ||
} | ||
|
||
def NVVM_ConvertF4x2ToF16x2Op : NVVM_Op<"convert.f4x2.to.f16x2"> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We've many compiler that targets NVVM dialect. They will replicate convert OP selection logic.
As we add more exotic hardware types, the number of ops keeps growing even further.
IMHO, we need an NVGPU dialect OP where can bake the NVVM OP selection.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, Guray, agreed. Once these narrow-fp types conversions are implemented in NVVM, we could add an NVGPU Op to handle the any -> any
conversions, which can dispatch to the right NVVM Op depending on the data-types etc.
} | ||
|
||
def NVVM_ConvertF8x2ToF16x2Op : | ||
NVVM_ConvertToFP16x2Op_Base<"f8", VectorOfLengthAndType<[2], [I8]>, "F16">; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just curious,
would it look better if we move the "x2" semantics here itself? like "f8x2" and "F16x2"..
I believe it may not be possible due to cast<Type>
on the dstType?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, we'd have to do some string truncation in that case which I think will not be as clean as it currently.
This change adds the following NVVM dialect Ops for converting fp4/6/8
to fp16x2:
convert.f4x2.to.f16x2
convert.f6x2.to.f16x2
convert.f8x2.to.f16x2
convert.f8x2.to.bf16x2
Tests are added in
convert_fp4x2.mlir
,convert_fp6x2.mlir
, andconvert_fp8x2.mlir
.PTX Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt