Skip to content

Conversation

Wolfram70
Copy link
Contributor

This change adds the following NVVM dialect Ops for converting fp4/6/8
to fp16x2:

  • convert.f4x2.to.f16x2
  • convert.f6x2.to.f16x2
  • convert.f8x2.to.f16x2
  • convert.f8x2.to.bf16x2

Tests are added in convert_fp4x2.mlir, convert_fp6x2.mlir, and
convert_fp8x2.mlir.

PTX Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt

This change adds the following NVVM dialect Ops for converting fp4/6/8
to fp16x2:
- convert.f4x2.to.f16x2
- convert.f6x2.to.f16x2
- convert.f8x2.to.f16x2
- convert.f8x2.to.bf16x2

Tests are added in `convert_fp4x2.mlir`, `convert_fp6x2.mlir`, and
`convert_fp8x2.mlir`.

PTX Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt
@Wolfram70 Wolfram70 requested a review from durga4github October 8, 2025 08:22
@Wolfram70 Wolfram70 self-assigned this Oct 8, 2025
@Wolfram70 Wolfram70 requested a review from grypp as a code owner October 8, 2025 08:22
@llvmbot
Copy link
Member

llvmbot commented Oct 8, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Srinivasa Ravi (Wolfram70)

Changes

This change adds the following NVVM dialect Ops for converting fp4/6/8
to fp16x2:

  • convert.f4x2.to.f16x2
  • convert.f6x2.to.f16x2
  • convert.f8x2.to.f16x2
  • convert.f8x2.to.bf16x2

Tests are added in convert_fp4x2.mlir, convert_fp6x2.mlir, and
convert_fp8x2.mlir.

PTX Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt


Full diff: https://github.com/llvm/llvm-project/pull/162439.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+105)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+139)
  • (added) mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir (+14)
  • (modified) mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir (+24)
  • (modified) mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir (+34)
  • (modified) mlir/test/Target/LLVMIR/nvvmir-invalid.mlir (+32)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index f8e3167b42c35..3d1aab02d2f75 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1855,6 +1855,111 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> {
   }];
 }
 
+class NVVM_ConvertF8x2ToFP16x2Op_Base <string dstType>
+: NVVM_Op<"convert.f8x2.to." # !tolower(dstType) # "x2"> {
+  let summary = "Convert a pair of f8 inputs to " # !tolower(dstType) # "x2";
+  let description = [{
+    This Op converts the given f8 inputs in a i8x2 vector to }] # !tolower(dstType) # [{.
+
+    The result `dst` is represented as a vector of }] # !tolower(dstType) # [{ elements.
+
+    The `relu` attribute, when set, lowers to the '.relu' variant of
+    the cvt instruction.
+    
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+  }];
+  let results = (outs VectorOfLengthAndType<[2], [!cast<Type>(dstType)]>:$dst);
+  let arguments = !if(!eq(dstType, "F16"),
+    (ins VectorOfLengthAndType<[2], [I8]>:$src,
+         DefaultValuedAttr<BoolAttr, "false">:$relu,
+         TypeAttr:$srcType),
+    (ins VectorOfLengthAndType<[2], [I8]>:$src,
+         TypeAttr:$srcType));
+  let assemblyFormat = "$src attr-dict `:` type($src) `(` $srcType `)` `->` type($dst)";
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    static IDArgPair
+    getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, 
+                          llvm::IRBuilderBase &builder);
+  }];
+
+  string llvmBuilder = [{
+    auto [intId, args] =
+    NVVM::ConvertF8x2To}] # dstType # [{x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
+    $dst = createIntrinsicCall(builder, intId, args);
+  }];
+}
+def NVVM_ConvertF8x2ToF16x2Op : NVVM_ConvertF8x2ToFP16x2Op_Base<"F16">;
+def NVVM_ConvertF8x2ToBF16x2Op : NVVM_ConvertF8x2ToFP16x2Op_Base<"BF16">;
+
+def NVVM_ConvertF6x2ToF16x2Op : NVVM_Op<"convert.f6x2.to.f16x2"> {
+  let summary = "Convert a pair of f6 inputs to f16x2";
+  let description = [{
+    This Op converts the given f6 inputs in a i8x2 vector to f16x2.
+
+    The result `dst` is represented as a vector of f16 elements.
+    
+    The `relu` attribute, when set, lowers to the '.relu' variant of
+    the cvt instruction.
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+  }];
+  let results = (outs VectorOfLengthAndType<[2], [F16]>:$dst);
+  let arguments = (ins VectorOfLengthAndType<[2], [I8]>:$src,
+         DefaultValuedAttr<BoolAttr, "false">:$relu,
+         TypeAttr:$srcType);
+  let assemblyFormat = "$src attr-dict `:` type($src) `(` $srcType `)` `->` type($dst)";
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    static IDArgPair
+    getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, 
+                          llvm::IRBuilderBase &builder);
+  }];
+
+  string llvmBuilder = [{
+    auto [intId, args] =
+    NVVM::ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
+    $dst = createIntrinsicCall(builder, intId, args);
+  }];
+}
+
+def NVVM_ConvertF4x2ToF16x2Op : NVVM_Op<"convert.f4x2.to.f16x2"> {
+  let summary = "Convert a pair of f4 inputs to f16x2";
+  let description = [{
+    This Op converts the given f4 inputs packed in an i8 to f16x2.
+
+    The result `dst` is represented as a vector of f16 elements. The value 
+    converted from the lower 4 bits of `src` is stored in the first element of 
+    `dst` and the value converted from the upper 4 bits of `src` is stored in 
+    the second element of `dst`.
+
+    The `relu` attribute, when set, lowers to the '.relu' variant of
+    the cvt instruction.
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+  }];
+  let results = (outs VectorOfLengthAndType<[2], [F16]>:$dst);
+  let arguments = (ins I8:$src,
+         DefaultValuedAttr<BoolAttr, "false">:$relu,
+         TypeAttr:$srcType);
+  let assemblyFormat = "$src attr-dict `:` type($src) `(` $srcType `)` `->` type($dst)";
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    static IDArgPair
+    getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, 
+                          llvm::IRBuilderBase &builder);
+  }];
+
+  string llvmBuilder = [{
+    auto [intId, args] =
+    NVVM::ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
+    $dst = createIntrinsicCall(builder, intId, args);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM MMA Ops
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index e8f8824d47de0..d1bc969b667d3 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -271,6 +271,51 @@ LogicalResult ConvertBF16x2ToF8x2Op::verify() {
   return success();
 }
 
+LogicalResult ConvertF8x2ToF16x2Op::verify() {
+  mlir::MLIRContext *ctx = getContext();
+
+  if (!llvm::isa<Float8E4M3FNType, Float8E5M2Type>(getSrcType()))
+    return emitOpError("Only ")
+           << mlir::Float8E4M3FNType::get(ctx) << " and "
+           << mlir::Float8E5M2Type::get(ctx)
+           << " types are supported for conversions from f8x2 to f16x2.";
+
+  return success();
+}
+
+LogicalResult ConvertF8x2ToBF16x2Op::verify() {
+  mlir::MLIRContext *ctx = getContext();
+  if (!llvm::isa<Float8E8M0FNUType>(getSrcType()))
+    return emitOpError("Only ")
+           << mlir::Float8E8M0FNUType::get(ctx)
+           << " type is supported for conversions from f8x2 to bf16x2.";
+
+  return success();
+}
+
+LogicalResult ConvertF6x2ToF16x2Op::verify() {
+  mlir::MLIRContext *ctx = getContext();
+
+  if (!llvm::isa<Float6E2M3FNType, Float6E3M2FNType>(getSrcType()))
+    return emitOpError("Only ")
+           << mlir::Float6E2M3FNType::get(ctx) << " and "
+           << mlir::Float6E3M2FNType::get(ctx)
+           << " types are supported for conversions from f6x2 to f16x2.";
+
+  return success();
+}
+
+LogicalResult ConvertF4x2ToF16x2Op::verify() {
+  mlir::MLIRContext *ctx = getContext();
+
+  if (!llvm::isa<Float4E2M1FNType>(getSrcType()))
+    return emitOpError("Only ")
+           << mlir::Float4E2M1FNType::get(ctx)
+           << " type is supported for conversions from f4x2 to f16x2.";
+
+  return success();
+}
+
 LogicalResult BulkStoreOp::verify() {
   if (getInitVal() != 0)
     return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -2055,6 +2100,100 @@ ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
   }
 }
 
+NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto curOp = cast<NVVM::ConvertF8x2ToF16x2Op>(op);
+
+  bool hasRelu = curOp.getRelu();
+
+  llvm::Intrinsic::ID intId =
+      llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
+          .Case<Float8E4M3FNType>([&](Float8E4M3FNType type) {
+            return hasRelu
+                       ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu
+                       : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn;
+          })
+          .Case<Float8E5M2Type>([&](Float8E5M2Type type) {
+            return hasRelu
+                       ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu
+                       : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn;
+          })
+          .Default([](mlir::Type type) {
+            llvm_unreachable("Invalid type for ConvertF8x2ToF16x2Op");
+            return llvm::Intrinsic::not_intrinsic;
+          });
+
+  llvm::Value *packedI16 =
+      builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
+                            llvm::Type::getInt16Ty(builder.getContext()));
+
+  return {intId, {packedI16}};
+}
+
+NVVM::IDArgPair ConvertF8x2ToBF16x2Op::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto curOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op);
+
+  llvm::Intrinsic::ID intId = llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2;
+  llvm::Value *packedI16 =
+      builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
+                            llvm::Type::getInt16Ty(builder.getContext()));
+
+  return {intId, {packedI16}};
+}
+
+NVVM::IDArgPair ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto curOp = cast<NVVM::ConvertF6x2ToF16x2Op>(op);
+
+  bool hasRelu = curOp.getRelu();
+
+  llvm::Intrinsic::ID intId =
+      llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
+          .Case<Float6E2M3FNType>([&](Float6E2M3FNType type) {
+            return hasRelu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu
+                           : llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn;
+          })
+          .Case<Float6E3M2FNType>([&](Float6E3M2FNType type) {
+            return hasRelu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu
+                           : llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn;
+          })
+          .Default([](mlir::Type type) {
+            llvm_unreachable("Invalid type for ConvertF6x2ToF16x2Op");
+            return llvm::Intrinsic::not_intrinsic;
+          });
+
+  llvm::Value *packedI16 =
+      builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
+                            llvm::Type::getInt16Ty(builder.getContext()));
+
+  return {intId, {packedI16}};
+}
+
+NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto curOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op);
+
+  bool hasRelu = curOp.getRelu();
+
+  llvm::Intrinsic::ID intId =
+      llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
+          .Case<Float4E2M1FNType>([&](Float4E2M1FNType type) {
+            return hasRelu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu
+                           : llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn;
+          })
+          .Default([](mlir::Type type) {
+            llvm_unreachable("Invalid type for ConvertF4x2ToF16x2Op");
+            return llvm::Intrinsic::not_intrinsic;
+          });
+
+  llvm::Value *extendedI16 =
+      builder.CreateZExt(mt.lookupValue(curOp.getSrc()),
+                         llvm::Type::getInt16Ty(builder.getContext()));
+
+  return {intId, {extendedI16}};
+}
+
 llvm::Intrinsic::ID
 Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
                                       LLVM::ModuleTranslation &mt,
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
new file mode 100644
index 0000000000000..e43dea4065c08
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: @convert_f4x2_to_f16x2
+llvm.func @convert_f4x2_to_f16x2(%src : i8) {
+  // CHECK: %[[res1:.*]] = zext i8 %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn(i16 %[[res1]])
+  %res1 = nvvm.convert.f4x2.to.f16x2 %src : i8 (f4E2M1FN)-> vector<2xf16>
+  // CHECK: %[[res2:.*]] = zext i8 %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn.relu(i16 %[[res2]])
+  %res2 = nvvm.convert.f4x2.to.f16x2 %src {relu = true} : i8 (f4E2M1FN)-> vector<2xf16>
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
index 04163b578aa02..939037d9db066 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
@@ -19,3 +19,27 @@ llvm.func @convert_f32x2_to_fp6x2_vector(%srcA : f32, %srcB : f32) {
   %res2 = nvvm.convert.f32x2.to.f6x2 <e3m2> %srcA, %srcB : vector<2xi8>
   llvm.return
 }
+
+// -----
+
+// CHECK-LABEL: @convert_f6x2_to_f16x2_e2m3
+llvm.func @convert_f6x2_to_f16x2_e2m3(%src : vector<2xi8>) {
+  // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e2m3x2.to.f16x2.rn(i16 %[[res1]])
+  %res1 = nvvm.convert.f6x2.to.f16x2 %src : vector<2xi8> (f6E2M3FN)-> vector<2xf16>
+  // CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e2m3x2.to.f16x2.rn.relu(i16 %[[res2]])
+  %res2 = nvvm.convert.f6x2.to.f16x2 %src {relu = true} : vector<2xi8> (f6E2M3FN)-> vector<2xf16>
+  llvm.return
+}
+
+// CHECK-LABEL: @convert_f6x2_to_f16x2_e3m2
+llvm.func @convert_f6x2_to_f16x2_e3m2(%src : vector<2xi8>) {
+  // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e3m2x2.to.f16x2.rn(i16 %[[res1]])
+  %res1 = nvvm.convert.f6x2.to.f16x2 %src : vector<2xi8> (f6E3M2FN)-> vector<2xf16>
+  // CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e3m2x2.to.f16x2.rn.relu(i16 %[[res2]])
+  %res2 = nvvm.convert.f6x2.to.f16x2 %src {relu = true} : vector<2xi8> (f6E3M2FN)-> vector<2xf16>
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
index 4a15efb9e805c..52078742d740b 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
@@ -100,3 +100,37 @@ llvm.func @convert_bf16x2_to_f8x2_vector_return(%src : vector<2xbf16>) {
   %res2 = nvvm.convert.bf16x2.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> vector<2xi8>
   llvm.return
 }
+
+// -----
+
+// CHECK-LABEL: @convert_f8x2_to_f16x2
+llvm.func @convert_f8x2_to_f16x2_e4m3(%src : vector<2xi8>) {
+  // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e4m3x2.to.f16x2.rn(i16 %[[res1]])
+  %res1 = nvvm.convert.f8x2.to.f16x2 %src : vector<2xi8> (f8E4M3FN)-> vector<2xf16>
+  // CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e4m3x2.to.f16x2.rn.relu(i16 %[[res2]])
+  %res2 = nvvm.convert.f8x2.to.f16x2 %src {relu = true} : vector<2xi8> (f8E4M3FN)-> vector<2xf16>
+  llvm.return
+}
+
+// CHECK-LABEL: @convert_f8x2_to_f16x2_e5m2
+llvm.func @convert_f8x2_to_f16x2_e5m2(%src : vector<2xi8>) {
+  // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn(i16 %[[res1]])
+  %res1 = nvvm.convert.f8x2.to.f16x2 %src : vector<2xi8> (f8E5M2)-> vector<2xf16>
+  // CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn.relu(i16 %[[res2]])
+  %res2 = nvvm.convert.f8x2.to.f16x2 %src {relu = true} : vector<2xi8> (f8E5M2)-> vector<2xf16>
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @convert_f8x2_to_bf16x2_ue8m0
+llvm.func @convert_f8x2_to_bf16x2_ue8m0(%src : vector<2xi8>) {
+  // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ue8m0x2.to.bf16x2(i16 %[[res1]])
+  %res1 = nvvm.convert.f8x2.to.bf16x2 %src : vector<2xi8> (f8E8M0FNU)-> vector<2xbf16>
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 383f4829f3287..b5e6ba5aa1406 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -246,6 +246,38 @@ llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding(%src : vector<2xbf16>) {
 
 // -----
 
+llvm.func @nvvm_cvt_f8x2_to_f16x2_invalid_type(%src : vector<2xi8>) {
+  // expected-error @below {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f8x2 to f16x2.}}
+  %res = nvvm.convert.f8x2.to.f16x2 %src : vector<2xi8> (f8E4M3) -> vector<2xf16>
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f8x2_to_bf16x2_invalid_type(%src : vector<2xi8>) {
+  // expected-error @below {{Only 'f8E8M0FNU' type is supported for conversions from f8x2 to bf16x2.}}
+  %res = nvvm.convert.f8x2.to.bf16x2 %src : vector<2xi8> (f8E4M3FN) -> vector<2xbf16>
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f6x2_to_f16x2_invalid_type(%src : vector<2xi8>) {
+  // expected-error @below {{Only 'f6E2M3FN' and 'f6E3M2FN' types are supported for conversions from f6x2 to f16x2.}}
+  %res = nvvm.convert.f6x2.to.f16x2 %src : vector<2xi8> (f8E4M3FN) -> vector<2xf16>
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f4x2_to_f16x2_invalid_type(%src : i8) {
+  // expected-error @below {{Only 'f4E2M1FN' type is supported for conversions from f4x2 to f16x2.}}
+  %res = nvvm.convert.f4x2.to.f16x2 %src : i8 (f6E2M3FN) -> vector<2xf16>
+  llvm.return
+}
+
+// -----
+
 llvm.func @nvvm_prefetch_L1_with_evict_priority(%global_ptr: !llvm.ptr<1>) {
   // expected-error @below {{cache eviction priority supported only for cache level L2}}
   nvvm.prefetch level = L1, evict_priority = evict_last, %global_ptr : !llvm.ptr<1>

Copy link

github-actions bot commented Oct 8, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

let arguments = !if(!eq(dstType, "F16"),
(ins VectorOfLengthAndType<[2], [I8]>:$src,
DefaultValuedAttr<BoolAttr, "false">:$relu,
TypeAttr:$srcType),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am wondering do we need TypeAttr. Isn't that clear from the name of the OP?

Copy link
Contributor Author

@Wolfram70 Wolfram70 Oct 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case, this TypeAttr is for the source type since it can be either e4m3 (f8E4M3FN) or e5m2 (f8E5M2).

}];
}

def NVVM_ConvertF4x2ToF16x2Op : NVVM_Op<"convert.f4x2.to.f16x2"> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We've many compiler that targets NVVM dialect. They will replicate convert OP selection logic.

As we add more exotic hardware types, the number of ops keeps growing even further.

IMHO, we need an NVGPU dialect OP where can bake the NVVM OP selection.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, Guray, agreed. Once these narrow-fp types conversions are implemented in NVVM, we could add an NVGPU Op to handle the any -> any conversions, which can dispatch to the right NVVM Op depending on the data-types etc.

}

def NVVM_ConvertF8x2ToF16x2Op :
NVVM_ConvertToFP16x2Op_Base<"f8", VectorOfLengthAndType<[2], [I8]>, "F16">;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious,
would it look better if we move the "x2" semantics here itself? like "f8x2" and "F16x2"..
I believe it may not be possible due to cast<Type> on the dstType?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, we'd have to do some string truncation in that case which I think will not be as clean as it currently.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants