diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp index 792b0f16ca..20d712adb9 100644 --- a/source/opt/folding_rules.cpp +++ b/source/opt/folding_rules.cpp @@ -153,7 +153,7 @@ std::vector GetWordsFromScalarFloatConstant( const analysis::FloatConstant* c) { assert(c != nullptr); uint32_t width = c->type()->AsFloat()->width(); - assert(width == 16 || width == 32 || width == 64); + assert(width == 8 || width == 16 || width == 32 || width == 64); if (width == 64) { utils::FloatProxy result(c->GetDouble()); return result.GetWords(); diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp index 8db13505ae..d54c331c93 100644 --- a/test/opt/fold_test.cpp +++ b/test/opt/fold_test.cpp @@ -226,6 +226,7 @@ OpCapability Int16 OpCapability Int64 OpCapability CooperativeMatrixKHR OpExtension "SPV_KHR_cooperative_matrix" +OpExtension "SPV_EXT_float8" %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 OpEntryPoint Fragment %main "main" @@ -238,6 +239,8 @@ OpName %main "main" %float = OpTypeFloat 32 %double = OpTypeFloat 64 %half = OpTypeFloat 16 +%float8e4m3 = OpTypeFloat 8 Float8E4M3EXT +%float8e5m2 = OpTypeFloat 8 Float8E5M2EXT %101 = OpConstantTrue %bool ; Need a def with an numerical id to define id maps. %true = OpConstantTrue %bool %false = OpConstantFalse %bool @@ -444,6 +447,10 @@ OpName %main "main" %half_null = OpConstantNull %half %half_0_1 = OpConstantComposite %v2half %108 %half_1 %v4half_0_1_0_0 = OpConstantComposite %v4half %108 %half_1 %108 %108 +%float8e4m3_1 = OpConstant %float8e4m3 1.0 +%float8e4m3_0x1_7p_0 = OpConstant %float8e4m3 0x1.7p+0 ; float(1.4375) -> float8e4m3(1.5) (RTNE) +%float8e5m2_1 = OpConstant %float8e5m2 1.0 +%float8e5m2_0x1_7p_0 = OpConstant %float8e5m2 0x1.7p+0 ; float(1.4375) -> float8e5m2(1.5) (RTNE) %106 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 %v4float_0_0_0_0 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 %v4float_0_0_0_1 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_1 @@ -1346,7 +1353,43 @@ INSTANTIATE_TEST_SUITE_P(TestCase, IntegerInstructionFoldingTest, "%2 = OpBitReverse %uint %111\n" + "OpReturn\n" + "OpFunctionEnd", - 2, 0) + 2, 0), + // TODO: hex_float.h contains some errors when converting float32 to + // smaller floating point types, which causes incorrect results. Please see + // https://github.com/KhronosGroup/glslang/issues/4241 and + // https://godbolt.org/z/684sEjzGY for details. + // Test case 98: Bit-cast float8e4m3 1 to float8e4m3 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpBitcast %float8e4m3 %float8e4m3_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0x38), + // Test case 99: Bit-cast float8e4m3 1.4375 to ubyte + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpBitcast %ubyte %float8e4m3_0x1_7p_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0x3b), + // Test case 100: Bit-cast float8e5m2 1 to float8e5m2 + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpBitcast %float8e5m2 %float8e5m2_1\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0x3c), + // Test case 101: Bit-cast float8e5m2 1.4375 to ubyte + InstructionFoldingCase( + Header() + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%2 = OpBitcast %ubyte %float8e5m2_0x1_7p_0\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, 0x3d) )); // clang-format on