diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index 08d52587b57f..f111a7d2607b 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -277,20 +277,20 @@ const x86Intrinsic intrinsic_defs[] = { {"dpbf16psx4", Float(32, 4), "dot_product", {Float(32, 4), BFloat(16, 8), BFloat(16, 8)}, Target::AVX512_Zen4}, {"dpbusdx16", Int(32, 16), "dot_product", {Int(32, 16), UInt(8, 64), Int(8, 64)}, Target::AVX512_Zen4}, - {"dpbusdx8", Int(32, 8), "dot_product", {Int(32, 8), UInt(8, 32), Int(8, 32)}, Target::AVX512_Zen4}, - {"dpbusdx4", Int(32, 4), "dot_product", {Int(32, 4), UInt(8, 16), Int(8, 16)}, Target::AVX512_Zen4}, + {"dpbusdx8", Int(32, 8), "dot_product", {Int(32, 8), UInt(8, 32), Int(8, 32)}, Target::AVXVNNI}, + {"dpbusdx4", Int(32, 4), "dot_product", {Int(32, 4), UInt(8, 16), Int(8, 16)}, Target::AVXVNNI}, {"dpwssdx16", Int(32, 16), "dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_Zen4}, - {"dpwssdx8", Int(32, 8), "dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_Zen4}, - {"dpwssdx4", Int(32, 4), "dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_Zen4}, + {"dpwssdx8", Int(32, 8), "dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVXVNNI}, + {"dpwssdx4", Int(32, 4), "dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVXVNNI}, {"dpbusdsx16", Int(32, 16), "saturating_dot_product", {Int(32, 16), UInt(8, 64), Int(8, 64)}, Target::AVX512_Zen4}, - {"dpbusdsx8", Int(32, 8), "saturating_dot_product", {Int(32, 8), UInt(8, 32), Int(8, 32)}, Target::AVX512_Zen4}, - {"dpbusdsx4", Int(32, 4), "saturating_dot_product", {Int(32, 4), UInt(8, 16), Int(8, 16)}, Target::AVX512_Zen4}, + {"dpbusdsx8", Int(32, 8), "saturating_dot_product", {Int(32, 8), UInt(8, 32), Int(8, 32)}, Target::AVXVNNI}, + {"dpbusdsx4", Int(32, 4), "saturating_dot_product", {Int(32, 4), UInt(8, 16), Int(8, 16)}, Target::AVXVNNI}, {"dpwssdsx16", Int(32, 16), "saturating_dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_Zen4}, - {"dpwssdsx8", Int(32, 8), "saturating_dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_Zen4}, - {"dpwssdsx4", Int(32, 4), "saturating_dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_Zen4}, + {"dpwssdsx8", Int(32, 8), "saturating_dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVXVNNI}, + {"dpwssdsx4", Int(32, 4), "saturating_dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVXVNNI}, {"tileloadd64_i8", Int(8, 1024), "tile_load", {Int(16), Int(16), Handle(), Int(64), Int(64)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory}, {"tileloadd64_i8", UInt(8, 1024), "tile_load", {Int(16), Int(16), Handle(), Int(64), Int(64)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory}, @@ -1063,6 +1063,9 @@ string CodeGen_X86::mattrs() const { if (target.has_feature(Target::F16C)) { attrs.emplace_back("+f16c"); } + if (target.has_feature(Target::AVXVNNI)) { + attrs.emplace_back("+avxvnni"); + } if (target.has_feature(Target::AVX512) || target.has_feature(Target::AVX512_KNL) || target.has_feature(Target::AVX512_Skylake) || @@ -1089,9 +1092,6 @@ string CodeGen_X86::mattrs() const { attrs.emplace_back("+avx512bitalg"); attrs.emplace_back("+avx512vbmi2"); } - if (target.has_feature(Target::AVXVNNI)) { - attrs.emplace_back("+avxvnni"); - } if (target.has_feature(Target::AVX512_SapphireRapids)) { attrs.emplace_back("+amx-int8"); attrs.emplace_back("+amx-bf16"); diff --git a/src/Target.cpp b/src/Target.cpp index c5d47bcdf43b..6be97e7bd32f 100644 --- a/src/Target.cpp +++ b/src/Target.cpp @@ -400,6 +400,12 @@ Target calculate_host_target() { const uint32_t avx512_cannonlake = avx512_skylake | avx512ifma; // Assume ifma => vbmi if ((info2[1] & avx2) == avx2) { initial_features.push_back(Target::AVX2); + // avxvnni (note, not avx512vnni) result in eax + const uint32_t avxvnni = 1U << 4; + // TODO: port to family/model -based detection. + if ((info3[0] & avxvnni) == avxvnni) { + initial_features.push_back(Target::AVXVNNI); + } } if ((info2[1] & avx512) == avx512) { initial_features.push_back(Target::AVX512); @@ -415,14 +421,9 @@ Target calculate_host_target() { if ((info2[1] & avx512_cannonlake) == avx512_cannonlake) { initial_features.push_back(Target::AVX512_Cannonlake); - const uint32_t avxvnni = 1U << 4; // avxvnni (note, not avx512vnni) result in eax const uint32_t avx512bf16 = 1U << 5; // bf16 result in eax, with cpuid(eax=7, ecx=1) - // TODO: port to family/model -based detection. - if ((info3[0] & avxvnni) == avxvnni) { - initial_features.push_back(Target::AVXVNNI); - if ((info3[0] & avx512bf16) == avx512bf16) { - initial_features.push_back(Target::AVX512_SapphireRapids); - } + if ((info3[0] & avx512bf16) == avx512bf16) { + initial_features.push_back(Target::AVX512_SapphireRapids); } } } diff --git a/src/runtime/x86_avx2.ll b/src/runtime/x86_avx2.ll index 3407c03c7029..54801c973981 100644 --- a/src/runtime/x86_avx2.ll +++ b/src/runtime/x86_avx2.ll @@ -76,3 +76,66 @@ define weak_odr <8 x i32> @hadd_pmadd_i16_avx2(<16 x i16> %a) nounwind alwaysinl declare <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16>, <16 x i16>) nounwind readnone +define weak_odr <8 x i32> @dpbusdx8(<8 x i32> %init, <32 x i8> %a, <32 x i8> %b) nounwind alwaysinline { + %1 = bitcast <32 x i8> %a to <8 x i32> + %2 = bitcast <32 x i8> %b to <8 x i32> + %3 = tail call <8 x i32> @llvm.x86.avx512.vpdpbusd.256(<8 x i32> %init, <8 x i32> %1, <8 x i32> %2) + ret <8 x i32> %3 +} +declare <8 x i32> @llvm.x86.avx512.vpdpbusd.256(<8 x i32>, <8 x i32>, <8 x i32>) + +define weak_odr <4 x i32> @dpbusdx4(<4 x i32> %init, <16 x i8> %a, <16 x i8> %b) nounwind alwaysinline { + %1 = bitcast <16 x i8> %a to <4 x i32> + %2 = bitcast <16 x i8> %b to <4 x i32> + %3 = tail call <4 x i32> @llvm.x86.avx512.vpdpbusd.128(<4 x i32> %init, <4 x i32> %1, <4 x i32> %2) + ret <4 x i32> %3 +} +declare <4 x i32> @llvm.x86.avx512.vpdpbusd.128(<4 x i32>, <4 x i32>, <4 x i32>) + +define weak_odr <8 x i32> @dpbusdsx8(<8 x i32> %init, <32 x i8> %a, <32 x i8> %b) nounwind alwaysinline { + %1 = bitcast <32 x i8> %a to <8 x i32> + %2 = bitcast <32 x i8> %b to <8 x i32> + %3 = tail call <8 x i32> @llvm.x86.avx512.vpdpbusds.256(<8 x i32> %init, <8 x i32> %1, <8 x i32> %2) + ret <8 x i32> %3 +} +declare <8 x i32> @llvm.x86.avx512.vpdpbusds.256(<8 x i32>, <8 x i32>, <8 x i32>) + +define weak_odr <4 x i32> @dpbusdsx4(<4 x i32> %init, <16 x i8> %a, <16 x i8> %b) nounwind alwaysinline { + %1 = bitcast <16 x i8> %a to <4 x i32> + %2 = bitcast <16 x i8> %b to <4 x i32> + %3 = tail call <4 x i32> @llvm.x86.avx512.vpdpbusds.128(<4 x i32> %init, <4 x i32> %1, <4 x i32> %2) + ret <4 x i32> %3 +} +declare <4 x i32> @llvm.x86.avx512.vpdpbusds.128(<4 x i32>, <4 x i32>, <4 x i32>) + +define weak_odr <8 x i32> @dpwssdx8(<8 x i32> %init, <16 x i16> %a, <16 x i16> %b) nounwind alwaysinline { + %1 = bitcast <16 x i16> %a to <8 x i32> + %2 = bitcast <16 x i16> %b to <8 x i32> + %3 = tail call <8 x i32> @llvm.x86.avx512.vpdpwssd.256(<8 x i32> %init, <8 x i32> %1, <8 x i32> %2) + ret <8 x i32> %3 +} +declare <8 x i32> @llvm.x86.avx512.vpdpwssd.256(<8 x i32>, <8 x i32>, <8 x i32>) + +define weak_odr <4 x i32> @dpwssdx4(<4 x i32> %init, <8 x i16> %a, <8 x i16> %b) nounwind alwaysinline { + %1 = bitcast <8 x i16> %a to <4 x i32> + %2 = bitcast <8 x i16> %b to <4 x i32> + %3 = tail call <4 x i32> @llvm.x86.avx512.vpdpwssd.128(<4 x i32> %init, <4 x i32> %1, <4 x i32> %2) + ret <4 x i32> %3 +} +declare <4 x i32> @llvm.x86.avx512.vpdpwssd.128(<4 x i32>, <4 x i32>, <4 x i32>) + +define weak_odr <8 x i32> @dpwssdsx8(<8 x i32> %init, <16 x i16> %a, <16 x i16> %b) nounwind alwaysinline { + %1 = bitcast <16 x i16> %a to <8 x i32> + %2 = bitcast <16 x i16> %b to <8 x i32> + %3 = tail call <8 x i32> @llvm.x86.avx512.vpdpwssds.256(<8 x i32> %init, <8 x i32> %1, <8 x i32> %2) + ret <8 x i32> %3 +} +declare <8 x i32> @llvm.x86.avx512.vpdpwssds.256(<8 x i32>, <8 x i32>, <8 x i32>) + +define weak_odr <4 x i32> @dpwssdsx4(<4 x i32> %init, <8 x i16> %a, <8 x i16> %b) nounwind alwaysinline { + %1 = bitcast <8 x i16> %a to <4 x i32> + %2 = bitcast <8 x i16> %b to <4 x i32> + %3 = tail call <4 x i32> @llvm.x86.avx512.vpdpwssds.128(<4 x i32> %init, <4 x i32> %1, <4 x i32> %2) + ret <4 x i32> %3 +} +declare <4 x i32> @llvm.x86.avx512.vpdpwssds.128(<4 x i32>, <4 x i32>, <4 x i32>) diff --git a/src/runtime/x86_cpu_features.cpp b/src/runtime/x86_cpu_features.cpp index 8e63c2495394..0cb021046d07 100644 --- a/src/runtime/x86_cpu_features.cpp +++ b/src/runtime/x86_cpu_features.cpp @@ -109,6 +109,8 @@ extern "C" WEAK int halide_get_cpu_features(CpuFeatures *features) { if (use_64_bits && have_avx && have_f16c && have_rdrand) { int info2[4]; cpuid(info2, 7); + int32_t info3[4]; + cpuid(info3, 7, 1); constexpr uint32_t avx2 = 1U << 5; constexpr uint32_t avx512f = 1U << 16; constexpr uint32_t avx512dq = 1U << 17; @@ -126,6 +128,9 @@ extern "C" WEAK int halide_get_cpu_features(CpuFeatures *features) { constexpr uint32_t avx512_cannonlake = avx512_skylake | avx512ifma; // Assume ifma => vbmi if ((info2[1] & avx2) == avx2) { halide_set_available_cpu_feature(features, halide_target_feature_avx2); + if ((info3[0] & avxvnni) == avxvnni) { + halide_set_available_cpu_feature(features, halide_target_feature_avxvnni); + } } if ((info2[1] & avx512) == avx512) { halide_set_available_cpu_feature(features, halide_target_feature_avx512); @@ -138,13 +143,8 @@ extern "C" WEAK int halide_get_cpu_features(CpuFeatures *features) { if ((info2[1] & avx512_cannonlake) == avx512_cannonlake) { halide_set_available_cpu_feature(features, halide_target_feature_avx512_cannonlake); - int32_t info3[4]; - cpuid(info3, 7, 1); - if ((info3[0] & avxvnni) == avxvnni) { - halide_set_available_cpu_feature(features, halide_target_feature_avxvnni); - if ((info3[0] & avx512bf16) == avx512bf16) { - halide_set_available_cpu_feature(features, halide_target_feature_avx512_sapphirerapids); - } + if ((info3[0] & avx512bf16) == avx512bf16) { + halide_set_available_cpu_feature(features, halide_target_feature_avx512_sapphirerapids); } } } diff --git a/test/correctness/simd_op_check.h b/test/correctness/simd_op_check.h index 25b641800987..154f0d5df8c4 100644 --- a/test/correctness/simd_op_check.h +++ b/test/correctness/simd_op_check.h @@ -132,6 +132,7 @@ class SimdOpCheckTest { Target::ARMv89a, Target::AVX, Target::AVX2, + Target::AVXVNNI, Target::AVX512, Target::AVX512_Cannonlake, Target::AVX512_KNL, diff --git a/test/correctness/simd_op_check_x86.cpp b/test/correctness/simd_op_check_x86.cpp index 0b2b3a8455fa..108760548385 100644 --- a/test/correctness/simd_op_check_x86.cpp +++ b/test/correctness/simd_op_check_x86.cpp @@ -709,6 +709,7 @@ int main(int argc, char **argv) { // real reason to test avx without it. Target("x86-64-linux-sse41-avx-f16c-fma"), Target("x86-64-linux-sse41-avx-f16c-fma-avx2"), + Target("x86-64-linux-sse41-avx-f16c-fma-avx2-avxvnni"), // See above: don't test avx512 without extra features, the test // isn't yet set up to test it properly. // Target("x86-64-linux-sse41-avx-avx2-avx512"),