Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions src/CodeGen_X86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,20 +277,20 @@ const x86Intrinsic intrinsic_defs[] = {
{"dpbf16psx4", Float(32, 4), "dot_product", {Float(32, 4), BFloat(16, 8), BFloat(16, 8)}, Target::AVX512_Zen4},

{"dpbusdx16", Int(32, 16), "dot_product", {Int(32, 16), UInt(8, 64), Int(8, 64)}, Target::AVX512_Zen4},
{"dpbusdx8", Int(32, 8), "dot_product", {Int(32, 8), UInt(8, 32), Int(8, 32)}, Target::AVX512_Zen4},
{"dpbusdx4", Int(32, 4), "dot_product", {Int(32, 4), UInt(8, 16), Int(8, 16)}, Target::AVX512_Zen4},
{"dpbusdx8", Int(32, 8), "dot_product", {Int(32, 8), UInt(8, 32), Int(8, 32)}, Target::AVXVNNI},
{"dpbusdx4", Int(32, 4), "dot_product", {Int(32, 4), UInt(8, 16), Int(8, 16)}, Target::AVXVNNI},

{"dpwssdx16", Int(32, 16), "dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_Zen4},
{"dpwssdx8", Int(32, 8), "dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_Zen4},
{"dpwssdx4", Int(32, 4), "dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_Zen4},
{"dpwssdx8", Int(32, 8), "dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVXVNNI},
{"dpwssdx4", Int(32, 4), "dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVXVNNI},

{"dpbusdsx16", Int(32, 16), "saturating_dot_product", {Int(32, 16), UInt(8, 64), Int(8, 64)}, Target::AVX512_Zen4},
{"dpbusdsx8", Int(32, 8), "saturating_dot_product", {Int(32, 8), UInt(8, 32), Int(8, 32)}, Target::AVX512_Zen4},
{"dpbusdsx4", Int(32, 4), "saturating_dot_product", {Int(32, 4), UInt(8, 16), Int(8, 16)}, Target::AVX512_Zen4},
{"dpbusdsx8", Int(32, 8), "saturating_dot_product", {Int(32, 8), UInt(8, 32), Int(8, 32)}, Target::AVXVNNI},
{"dpbusdsx4", Int(32, 4), "saturating_dot_product", {Int(32, 4), UInt(8, 16), Int(8, 16)}, Target::AVXVNNI},

{"dpwssdsx16", Int(32, 16), "saturating_dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_Zen4},
{"dpwssdsx8", Int(32, 8), "saturating_dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_Zen4},
{"dpwssdsx4", Int(32, 4), "saturating_dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_Zen4},
{"dpwssdsx8", Int(32, 8), "saturating_dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVXVNNI},
{"dpwssdsx4", Int(32, 4), "saturating_dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVXVNNI},

{"tileloadd64_i8", Int(8, 1024), "tile_load", {Int(16), Int(16), Handle(), Int(64), Int(64)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory},
{"tileloadd64_i8", UInt(8, 1024), "tile_load", {Int(16), Int(16), Handle(), Int(64), Int(64)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory},
Expand Down Expand Up @@ -1063,6 +1063,9 @@ string CodeGen_X86::mattrs() const {
if (target.has_feature(Target::F16C)) {
attrs.emplace_back("+f16c");
}
if (target.has_feature(Target::AVXVNNI)) {
attrs.emplace_back("+avxvnni");
}
if (target.has_feature(Target::AVX512) ||
target.has_feature(Target::AVX512_KNL) ||
target.has_feature(Target::AVX512_Skylake) ||
Expand All @@ -1089,9 +1092,6 @@ string CodeGen_X86::mattrs() const {
attrs.emplace_back("+avx512bitalg");
attrs.emplace_back("+avx512vbmi2");
}
if (target.has_feature(Target::AVXVNNI)) {
attrs.emplace_back("+avxvnni");
}
if (target.has_feature(Target::AVX512_SapphireRapids)) {
attrs.emplace_back("+amx-int8");
attrs.emplace_back("+amx-bf16");
Expand Down
15 changes: 8 additions & 7 deletions src/Target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,12 @@ Target calculate_host_target() {
const uint32_t avx512_cannonlake = avx512_skylake | avx512ifma; // Assume ifma => vbmi
if ((info2[1] & avx2) == avx2) {
initial_features.push_back(Target::AVX2);
// avxvnni (note, not avx512vnni) result in eax
const uint32_t avxvnni = 1U << 4;
// TODO: port to family/model-based detection.
if ((info3[0] & avxvnni) == avxvnni) {
initial_features.push_back(Target::AVXVNNI);
}
}
if ((info2[1] & avx512) == avx512) {
initial_features.push_back(Target::AVX512);
Expand All @@ -415,14 +421,9 @@ Target calculate_host_target() {
if ((info2[1] & avx512_cannonlake) == avx512_cannonlake) {
initial_features.push_back(Target::AVX512_Cannonlake);

const uint32_t avxvnni = 1U << 4; // avxvnni (note, not avx512vnni) result in eax
const uint32_t avx512bf16 = 1U << 5; // bf16 result in eax, with cpuid(eax=7, ecx=1)
// TODO: port to family/model-based detection.
if ((info3[0] & avxvnni) == avxvnni) {
initial_features.push_back(Target::AVXVNNI);
if ((info3[0] & avx512bf16) == avx512bf16) {
initial_features.push_back(Target::AVX512_SapphireRapids);
}
if ((info3[0] & avx512bf16) == avx512bf16) {
initial_features.push_back(Target::AVX512_SapphireRapids);
}
}
}
Expand Down
63 changes: 63 additions & 0 deletions src/runtime/x86_avx2.ll
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,66 @@ define weak_odr <8 x i32> @hadd_pmadd_i16_avx2(<16 x i16> %a) nounwind alwaysinl

declare <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16>, <16 x i16>) nounwind readnone

; dpbusdx8: 8-lane i32 wrapper around the 256-bit vpdpbusd intrinsic.
;   %init - <8 x i32> passed through as the intrinsic's first operand.
;   %a/%b - <32 x i8> inputs, reinterpreted as <8 x i32> because that is the
;           operand type the intrinsic declaration below takes.
; Returns the intrinsic's <8 x i32> result unchanged.
; NOTE(review): the intrinsic is in the llvm.x86.avx512.* namespace even though
; this wrapper is meant to be usable with only +avxvnni; presumably the backend
; selects the VEX encoding in that case - confirm against the LLVM version used.
define weak_odr <8 x i32> @dpbusdx8(<8 x i32> %init, <32 x i8> %a, <32 x i8> %b) nounwind alwaysinline {
; Pure bit reinterpretation: same 256 bits, viewed as 32-bit lanes.
%1 = bitcast <32 x i8> %a to <8 x i32>
%2 = bitcast <32 x i8> %b to <8 x i32>
%3 = tail call <8 x i32> @llvm.x86.avx512.vpdpbusd.256(<8 x i32> %init, <8 x i32> %1, <8 x i32> %2)
ret <8 x i32> %3
}
declare <8 x i32> @llvm.x86.avx512.vpdpbusd.256(<8 x i32>, <8 x i32>, <8 x i32>)

; dpbusdx4: 4-lane i32 wrapper around the 128-bit vpdpbusd intrinsic.
;   %init - <4 x i32> passed through as the intrinsic's first operand.
;   %a/%b - <16 x i8> inputs, reinterpreted as <4 x i32> because that is the
;           operand type the intrinsic declaration below takes.
; Returns the intrinsic's <4 x i32> result unchanged.
; NOTE(review): the intrinsic is in the llvm.x86.avx512.* namespace; presumably
; it lowers to the VEX form when only +avxvnni is enabled - confirm.
define weak_odr <4 x i32> @dpbusdx4(<4 x i32> %init, <16 x i8> %a, <16 x i8> %b) nounwind alwaysinline {
; Pure bit reinterpretation: same 128 bits, viewed as 32-bit lanes.
%1 = bitcast <16 x i8> %a to <4 x i32>
%2 = bitcast <16 x i8> %b to <4 x i32>
%3 = tail call <4 x i32> @llvm.x86.avx512.vpdpbusd.128(<4 x i32> %init, <4 x i32> %1, <4 x i32> %2)
ret <4 x i32> %3
}
declare <4 x i32> @llvm.x86.avx512.vpdpbusd.128(<4 x i32>, <4 x i32>, <4 x i32>)

; dpbusdsx8: 8-lane i32 wrapper around the 256-bit vpdpbusds intrinsic
; (the trailing "s" variant of vpdpbusd; presumably the saturating form -
; confirm against the instruction reference).
;   %init - <8 x i32> passed through as the intrinsic's first operand.
;   %a/%b - <32 x i8> inputs, reinterpreted as <8 x i32> because that is the
;           operand type the intrinsic declaration below takes.
; Returns the intrinsic's <8 x i32> result unchanged.
; NOTE(review): the intrinsic is in the llvm.x86.avx512.* namespace; presumably
; it lowers to the VEX form when only +avxvnni is enabled - confirm.
define weak_odr <8 x i32> @dpbusdsx8(<8 x i32> %init, <32 x i8> %a, <32 x i8> %b) nounwind alwaysinline {
; Pure bit reinterpretation: same 256 bits, viewed as 32-bit lanes.
%1 = bitcast <32 x i8> %a to <8 x i32>
%2 = bitcast <32 x i8> %b to <8 x i32>
%3 = tail call <8 x i32> @llvm.x86.avx512.vpdpbusds.256(<8 x i32> %init, <8 x i32> %1, <8 x i32> %2)
ret <8 x i32> %3
}
declare <8 x i32> @llvm.x86.avx512.vpdpbusds.256(<8 x i32>, <8 x i32>, <8 x i32>)

; dpbusdsx4: 4-lane i32 wrapper around the 128-bit vpdpbusds intrinsic
; (the trailing "s" variant of vpdpbusd; presumably the saturating form -
; confirm against the instruction reference).
;   %init - <4 x i32> passed through as the intrinsic's first operand.
;   %a/%b - <16 x i8> inputs, reinterpreted as <4 x i32> because that is the
;           operand type the intrinsic declaration below takes.
; Returns the intrinsic's <4 x i32> result unchanged.
; NOTE(review): the intrinsic is in the llvm.x86.avx512.* namespace; presumably
; it lowers to the VEX form when only +avxvnni is enabled - confirm.
define weak_odr <4 x i32> @dpbusdsx4(<4 x i32> %init, <16 x i8> %a, <16 x i8> %b) nounwind alwaysinline {
; Pure bit reinterpretation: same 128 bits, viewed as 32-bit lanes.
%1 = bitcast <16 x i8> %a to <4 x i32>
%2 = bitcast <16 x i8> %b to <4 x i32>
%3 = tail call <4 x i32> @llvm.x86.avx512.vpdpbusds.128(<4 x i32> %init, <4 x i32> %1, <4 x i32> %2)
ret <4 x i32> %3
}
declare <4 x i32> @llvm.x86.avx512.vpdpbusds.128(<4 x i32>, <4 x i32>, <4 x i32>)

; dpwssdx8: 8-lane i32 wrapper around the 256-bit vpdpwssd intrinsic.
;   %init - <8 x i32> passed through as the intrinsic's first operand.
;   %a/%b - <16 x i16> inputs, reinterpreted as <8 x i32> because that is the
;           operand type the intrinsic declaration below takes.
; Returns the intrinsic's <8 x i32> result unchanged.
; NOTE(review): the intrinsic is in the llvm.x86.avx512.* namespace; presumably
; it lowers to the VEX form when only +avxvnni is enabled - confirm.
define weak_odr <8 x i32> @dpwssdx8(<8 x i32> %init, <16 x i16> %a, <16 x i16> %b) nounwind alwaysinline {
; Pure bit reinterpretation: same 256 bits, viewed as 32-bit lanes.
%1 = bitcast <16 x i16> %a to <8 x i32>
%2 = bitcast <16 x i16> %b to <8 x i32>
%3 = tail call <8 x i32> @llvm.x86.avx512.vpdpwssd.256(<8 x i32> %init, <8 x i32> %1, <8 x i32> %2)
ret <8 x i32> %3
}
declare <8 x i32> @llvm.x86.avx512.vpdpwssd.256(<8 x i32>, <8 x i32>, <8 x i32>)

; dpwssdx4: 4-lane i32 wrapper around the 128-bit vpdpwssd intrinsic.
;   %init - <4 x i32> passed through as the intrinsic's first operand.
;   %a/%b - <8 x i16> inputs, reinterpreted as <4 x i32> because that is the
;           operand type the intrinsic declaration below takes.
; Returns the intrinsic's <4 x i32> result unchanged.
; NOTE(review): the intrinsic is in the llvm.x86.avx512.* namespace; presumably
; it lowers to the VEX form when only +avxvnni is enabled - confirm.
define weak_odr <4 x i32> @dpwssdx4(<4 x i32> %init, <8 x i16> %a, <8 x i16> %b) nounwind alwaysinline {
; Pure bit reinterpretation: same 128 bits, viewed as 32-bit lanes.
%1 = bitcast <8 x i16> %a to <4 x i32>
%2 = bitcast <8 x i16> %b to <4 x i32>
%3 = tail call <4 x i32> @llvm.x86.avx512.vpdpwssd.128(<4 x i32> %init, <4 x i32> %1, <4 x i32> %2)
ret <4 x i32> %3
}
declare <4 x i32> @llvm.x86.avx512.vpdpwssd.128(<4 x i32>, <4 x i32>, <4 x i32>)

; dpwssdsx8: 8-lane i32 wrapper around the 256-bit vpdpwssds intrinsic
; (the trailing "s" variant of vpdpwssd; presumably the saturating form -
; confirm against the instruction reference).
;   %init - <8 x i32> passed through as the intrinsic's first operand.
;   %a/%b - <16 x i16> inputs, reinterpreted as <8 x i32> because that is the
;           operand type the intrinsic declaration below takes.
; Returns the intrinsic's <8 x i32> result unchanged.
; NOTE(review): the intrinsic is in the llvm.x86.avx512.* namespace; presumably
; it lowers to the VEX form when only +avxvnni is enabled - confirm.
define weak_odr <8 x i32> @dpwssdsx8(<8 x i32> %init, <16 x i16> %a, <16 x i16> %b) nounwind alwaysinline {
; Pure bit reinterpretation: same 256 bits, viewed as 32-bit lanes.
%1 = bitcast <16 x i16> %a to <8 x i32>
%2 = bitcast <16 x i16> %b to <8 x i32>
%3 = tail call <8 x i32> @llvm.x86.avx512.vpdpwssds.256(<8 x i32> %init, <8 x i32> %1, <8 x i32> %2)
ret <8 x i32> %3
}
declare <8 x i32> @llvm.x86.avx512.vpdpwssds.256(<8 x i32>, <8 x i32>, <8 x i32>)

; dpwssdsx4: 4-lane i32 wrapper around the 128-bit vpdpwssds intrinsic
; (the trailing "s" variant of vpdpwssd; presumably the saturating form -
; confirm against the instruction reference).
;   %init - <4 x i32> passed through as the intrinsic's first operand.
;   %a/%b - <8 x i16> inputs, reinterpreted as <4 x i32> because that is the
;           operand type the intrinsic declaration below takes.
; Returns the intrinsic's <4 x i32> result unchanged.
; NOTE(review): the intrinsic is in the llvm.x86.avx512.* namespace; presumably
; it lowers to the VEX form when only +avxvnni is enabled - confirm.
define weak_odr <4 x i32> @dpwssdsx4(<4 x i32> %init, <8 x i16> %a, <8 x i16> %b) nounwind alwaysinline {
; Pure bit reinterpretation: same 128 bits, viewed as 32-bit lanes.
%1 = bitcast <8 x i16> %a to <4 x i32>
%2 = bitcast <8 x i16> %b to <4 x i32>
%3 = tail call <4 x i32> @llvm.x86.avx512.vpdpwssds.128(<4 x i32> %init, <4 x i32> %1, <4 x i32> %2)
ret <4 x i32> %3
}
declare <4 x i32> @llvm.x86.avx512.vpdpwssds.128(<4 x i32>, <4 x i32>, <4 x i32>)
14 changes: 7 additions & 7 deletions src/runtime/x86_cpu_features.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ extern "C" WEAK int halide_get_cpu_features(CpuFeatures *features) {
if (use_64_bits && have_avx && have_f16c && have_rdrand) {
int info2[4];
cpuid(info2, 7);
int32_t info3[4];
cpuid(info3, 7, 1);
constexpr uint32_t avx2 = 1U << 5;
constexpr uint32_t avx512f = 1U << 16;
constexpr uint32_t avx512dq = 1U << 17;
Expand All @@ -126,6 +128,9 @@ extern "C" WEAK int halide_get_cpu_features(CpuFeatures *features) {
constexpr uint32_t avx512_cannonlake = avx512_skylake | avx512ifma; // Assume ifma => vbmi
if ((info2[1] & avx2) == avx2) {
halide_set_available_cpu_feature(features, halide_target_feature_avx2);
if ((info3[0] & avxvnni) == avxvnni) {
halide_set_available_cpu_feature(features, halide_target_feature_avxvnni);
}
}
if ((info2[1] & avx512) == avx512) {
halide_set_available_cpu_feature(features, halide_target_feature_avx512);
Expand All @@ -138,13 +143,8 @@ extern "C" WEAK int halide_get_cpu_features(CpuFeatures *features) {
if ((info2[1] & avx512_cannonlake) == avx512_cannonlake) {
halide_set_available_cpu_feature(features, halide_target_feature_avx512_cannonlake);

int32_t info3[4];
cpuid(info3, 7, 1);
if ((info3[0] & avxvnni) == avxvnni) {
halide_set_available_cpu_feature(features, halide_target_feature_avxvnni);
if ((info3[0] & avx512bf16) == avx512bf16) {
halide_set_available_cpu_feature(features, halide_target_feature_avx512_sapphirerapids);
}
if ((info3[0] & avx512bf16) == avx512bf16) {
halide_set_available_cpu_feature(features, halide_target_feature_avx512_sapphirerapids);
}
}
}
Expand Down
1 change: 1 addition & 0 deletions test/correctness/simd_op_check.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ class SimdOpCheckTest {
Target::ARMv89a,
Target::AVX,
Target::AVX2,
Target::AVXVNNI,
Target::AVX512,
Target::AVX512_Cannonlake,
Target::AVX512_KNL,
Expand Down
1 change: 1 addition & 0 deletions test/correctness/simd_op_check_x86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,7 @@ int main(int argc, char **argv) {
// real reason to test avx without it.
Target("x86-64-linux-sse41-avx-f16c-fma"),
Target("x86-64-linux-sse41-avx-f16c-fma-avx2"),
Target("x86-64-linux-sse41-avx-f16c-fma-avx2-avxvnni"),
// See above: don't test avx512 without extra features, the test
// isn't yet set up to test it properly.
// Target("x86-64-linux-sse41-avx-avx2-avx512"),
Expand Down
Loading