From 24c0355d327e090b51fca2a58b1a6d5e296a5163 Mon Sep 17 00:00:00 2001 From: Fabian Schuetze Date: Sat, 5 Jul 2025 17:12:21 +0200 Subject: [PATCH 1/5] vnni codegen on with avx2 extension --- src/CodeGen_X86.cpp | 17 ++++++++++------- src/Target.cpp | 15 ++++++++------- src/runtime/x86_avx2.ll | 31 +++++++++++++++++++++++++++++++ src/runtime/x86_cpu_features.cpp | 14 +++++++------- 4 files changed, 56 insertions(+), 21 deletions(-) diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index 08d52587b57f..6cdcb80c7b18 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -277,16 +277,16 @@ const x86Intrinsic intrinsic_defs[] = { {"dpbf16psx4", Float(32, 4), "dot_product", {Float(32, 4), BFloat(16, 8), BFloat(16, 8)}, Target::AVX512_Zen4}, {"dpbusdx16", Int(32, 16), "dot_product", {Int(32, 16), UInt(8, 64), Int(8, 64)}, Target::AVX512_Zen4}, - {"dpbusdx8", Int(32, 8), "dot_product", {Int(32, 8), UInt(8, 32), Int(8, 32)}, Target::AVX512_Zen4}, - {"dpbusdx4", Int(32, 4), "dot_product", {Int(32, 4), UInt(8, 16), Int(8, 16)}, Target::AVX512_Zen4}, + {"dpbusdx8", Int(32, 8), "dot_product", {Int(32, 8), UInt(8, 32), Int(8, 32)}, Target::AVXVNNI}, + {"dpbusdx4", Int(32, 4), "dot_product", {Int(32, 4), UInt(8, 16), Int(8, 16)}, Target::AVXVNNI}, {"dpwssdx16", Int(32, 16), "dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_Zen4}, {"dpwssdx8", Int(32, 8), "dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_Zen4}, {"dpwssdx4", Int(32, 4), "dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_Zen4}, {"dpbusdsx16", Int(32, 16), "saturating_dot_product", {Int(32, 16), UInt(8, 64), Int(8, 64)}, Target::AVX512_Zen4}, - {"dpbusdsx8", Int(32, 8), "saturating_dot_product", {Int(32, 8), UInt(8, 32), Int(8, 32)}, Target::AVX512_Zen4}, - {"dpbusdsx4", Int(32, 4), "saturating_dot_product", {Int(32, 4), UInt(8, 16), Int(8, 16)}, Target::AVX512_Zen4}, + {"dpbusdsx8", Int(32, 8), "saturating_dot_product", {Int(32, 8), UInt(8, 32), Int(8, 32)}, Target::AVXVNNI}, + {"dpbusdsx4", Int(32, 4), "saturating_dot_product", {Int(32, 4), UInt(8, 16), Int(8, 16)}, Target::AVXVNNI}, {"dpwssdsx16", Int(32, 16), "saturating_dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_Zen4}, {"dpwssdsx8", Int(32, 8), "saturating_dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_Zen4}, @@ -1063,6 +1063,9 @@ string CodeGen_X86::mattrs() const { if (target.has_feature(Target::F16C)) { attrs.emplace_back("+f16c"); } + if (target.has_feature(Target::AVXVNNI)) { + attrs.emplace_back("+avxvnni"); + } if (target.has_feature(Target::AVX512) || target.has_feature(Target::AVX512_KNL) || target.has_feature(Target::AVX512_Skylake) || @@ -1089,9 +1092,9 @@ string CodeGen_X86::mattrs() const { attrs.emplace_back("+avx512bitalg"); attrs.emplace_back("+avx512vbmi2"); } - if (target.has_feature(Target::AVXVNNI)) { - attrs.emplace_back("+avxvnni"); - } + //if (target.has_feature(Target::AVXVNNI)) { + //attrs.emplace_back("+avxvnni"); + //} if (target.has_feature(Target::AVX512_SapphireRapids)) { attrs.emplace_back("+amx-int8"); attrs.emplace_back("+amx-bf16"); diff --git a/src/Target.cpp b/src/Target.cpp index c5d47bcdf43b..6be97e7bd32f 100644 --- a/src/Target.cpp +++ b/src/Target.cpp @@ -400,6 +400,12 @@ Target calculate_host_target() { const uint32_t avx512_cannonlake = avx512_skylake | avx512ifma; // Assume ifma => vbmi if ((info2[1] & avx2) == avx2) { initial_features.push_back(Target::AVX2); + // avxvnni (note, not avx512vnni) result in eax + const uint32_t avxvnni = 1U << 4; + // TODO: port to family/model -based detection. + if ((info3[0] & avxvnni) == avxvnni) { + initial_features.push_back(Target::AVXVNNI); + } } if ((info2[1] & avx512) == avx512) { initial_features.push_back(Target::AVX512); @@ -415,14 +421,9 @@ Target calculate_host_target() { if ((info2[1] & avx512_cannonlake) == avx512_cannonlake) { initial_features.push_back(Target::AVX512_Cannonlake); - const uint32_t avxvnni = 1U << 4; // avxvnni (note, not avx512vnni) result in eax const uint32_t avx512bf16 = 1U << 5; // bf16 result in eax, with cpuid(eax=7, ecx=1) - // TODO: port to family/model -based detection. - if ((info3[0] & avxvnni) == avxvnni) { - initial_features.push_back(Target::AVXVNNI); - if ((info3[0] & avx512bf16) == avx512bf16) { - initial_features.push_back(Target::AVX512_SapphireRapids); - } + if ((info3[0] & avx512bf16) == avx512bf16) { + initial_features.push_back(Target::AVX512_SapphireRapids); } } } diff --git a/src/runtime/x86_avx2.ll b/src/runtime/x86_avx2.ll index 3407c03c7029..afcbcf8c1933 100644 --- a/src/runtime/x86_avx2.ll +++ b/src/runtime/x86_avx2.ll @@ -76,3 +76,34 @@ define weak_odr <8 x i32> @hadd_pmadd_i16_avx2(<16 x i16> %a) nounwind alwaysinl declare <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16>, <16 x i16>) nounwind readnone +define weak_odr <8 x i32> @dpbusdx8(<8 x i32> %init, <32 x i8> %a, <32 x i8> %b) nounwind alwaysinline { + %1 = bitcast <32 x i8> %a to <8 x i32> + %2 = bitcast <32 x i8> %b to <8 x i32> + %3 = tail call <8 x i32> @llvm.x86.avx512.vpdpbusd.256(<8 x i32> %init, <8 x i32> %1, <8 x i32> %2) + ret <8 x i32> %3 +} +declare <8 x i32> @llvm.x86.avx512.vpdpbusd.256(<8 x i32>, <8 x i32>, <8 x i32>) + +define weak_odr <4 x i32> @dpbusdx4(<4 x i32> %init, <16 x i8> %a, <16 x i8> %b) nounwind alwaysinline { + %1 = bitcast <16 x i8> %a to <4 x i32> + %2 = bitcast <16 x i8> %b to <4 x i32> + %3 = tail call <4 x i32> @llvm.x86.avx512.vpdpbusd.128(<4 x i32> %init, <4 x i32> %1, <4 x i32> %2) + ret <4 x i32> %3 +} +declare <4 x i32> @llvm.x86.avx512.vpdpbusd.128(<4 x i32>, <4 x i32>, <4 x i32>) + +define weak_odr <8 x i32> @dpbusdsx8(<8 x i32> %init, <32 x i8> %a, <32 x i8> %b) nounwind alwaysinline { + %1 = bitcast <32 x i8> %a to <8 x i32> + %2 = bitcast <32 x i8> %b to <8 x i32> + %3 = tail call <8 x i32> @llvm.x86.avx512.vpdpbusds.256(<8 x i32> %init, <8 x i32> %1, <8 x i32> %2) + ret <8 x i32> %3 +} +declare <8 x i32> @llvm.x86.avx512.vpdpbusds.256(<8 x i32>, <8 x i32>, <8 x i32>) + +define weak_odr <4 x i32> @dpbusdsx4(<4 x i32> %init, <16 x i8> %a, <16 x i8> %b) nounwind alwaysinline { + %1 = bitcast <16 x i8> %a to <4 x i32> + %2 = bitcast <16 x i8> %b to <4 x i32> + %3 = tail call <4 x i32> @llvm.x86.avx512.vpdpbusds.128(<4 x i32> %init, <4 x i32> %1, <4 x i32> %2) + ret <4 x i32> %3 +} +declare <4 x i32> @llvm.x86.avx512.vpdpbusds.128(<4 x i32>, <4 x i32>, <4 x i32>) diff --git a/src/runtime/x86_cpu_features.cpp b/src/runtime/x86_cpu_features.cpp index 8e63c2495394..0cb021046d07 100644 --- a/src/runtime/x86_cpu_features.cpp +++ b/src/runtime/x86_cpu_features.cpp @@ -109,6 +109,8 @@ extern "C" WEAK int halide_get_cpu_features(CpuFeatures *features) { if (use_64_bits && have_avx && have_f16c && have_rdrand) { int info2[4]; cpuid(info2, 7); + int32_t info3[4]; + cpuid(info3, 7, 1); constexpr uint32_t avx2 = 1U << 5; constexpr uint32_t avx512f = 1U << 16; constexpr uint32_t avx512dq = 1U << 17; @@ -126,6 +128,9 @@ extern "C" WEAK int halide_get_cpu_features(CpuFeatures *features) { constexpr uint32_t avx512_cannonlake = avx512_skylake | avx512ifma; // Assume ifma => vbmi if ((info2[1] & avx2) == avx2) { halide_set_available_cpu_feature(features, halide_target_feature_avx2); + if ((info3[0] & avxvnni) == avxvnni) { + halide_set_available_cpu_feature(features, halide_target_feature_avxvnni); + } } if ((info2[1] & avx512) == avx512) { halide_set_available_cpu_feature(features, halide_target_feature_avx512); @@ -138,13 +143,8 @@ extern "C" WEAK int halide_get_cpu_features(CpuFeatures *features) { if ((info2[1] & avx512_cannonlake) == avx512_cannonlake) { halide_set_available_cpu_feature(features, halide_target_feature_avx512_cannonlake); - int32_t info3[4]; - cpuid(info3, 7, 1); - if ((info3[0] & avxvnni) == avxvnni) { - halide_set_available_cpu_feature(features, halide_target_feature_avxvnni); - if ((info3[0] & avx512bf16) == avx512bf16) { - halide_set_available_cpu_feature(features, halide_target_feature_avx512_sapphirerapids); - } + if ((info3[0] & avx512bf16) == avx512bf16) { + halide_set_available_cpu_feature(features, halide_target_feature_avx512_sapphirerapids); } } } From 4167458eea7874a07a37dbe6805f797ee362a489 Mon Sep 17 00:00:00 2001 From: Fabian Schuetze Date: Sat, 5 Jul 2025 18:45:18 +0200 Subject: [PATCH 2/5] correctness x86 test pass --- src/CodeGen_X86.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index 6cdcb80c7b18..6f23071c5f70 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -46,6 +46,7 @@ Target complete_x86_target(Target t) { t.set_feature(Target::AVXVNNI); } if (t.has_feature(Target::AVX512_Zen4)) { + t.set_feature(Target::AVXVNNI); t.set_feature(Target::AVX512_Cannonlake); } if (t.has_feature(Target::AVX512_Cannonlake)) { @@ -1092,9 +1093,6 @@ string CodeGen_X86::mattrs() const { attrs.emplace_back("+avx512bitalg"); attrs.emplace_back("+avx512vbmi2"); } - //if (target.has_feature(Target::AVXVNNI)) { - //attrs.emplace_back("+avxvnni"); - //} if (target.has_feature(Target::AVX512_SapphireRapids)) { attrs.emplace_back("+amx-int8"); attrs.emplace_back("+amx-bf16"); From 1593a55a9f445a55d3c56dff4162aafdb3397b9c Mon Sep 17 00:00:00 2001 From: Fabian Schuetze Date: Wed, 9 Jul 2025 17:38:16 +0200 Subject: [PATCH 3/5] add test --- src/CodeGen_X86.cpp | 8 +++---- src/runtime/x86_avx2.ll | 32 ++++++++++++++++++++++++++ test/correctness/simd_op_check_x86.cpp | 1 + 3 files changed, 37 insertions(+), 4 deletions(-) diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index 6f23071c5f70..3b6ec95246aa 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -282,16 +282,16 @@ const x86Intrinsic intrinsic_defs[] = { {"dpbusdx4", Int(32, 4), "dot_product", {Int(32, 4), UInt(8, 16), Int(8, 16)}, Target::AVXVNNI}, {"dpwssdx16", Int(32, 16), "dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_Zen4}, - {"dpwssdx8", Int(32, 8), "dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_Zen4}, - {"dpwssdx4", Int(32, 4), "dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_Zen4}, + {"dpwssdx8", Int(32, 8), "dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVXVNNI}, + {"dpwssdx4", Int(32, 4), "dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVXVNNI}, {"dpbusdsx16", Int(32, 16), "saturating_dot_product", {Int(32, 16), UInt(8, 64), Int(8, 64)}, Target::AVX512_Zen4}, {"dpbusdsx8", Int(32, 8), "saturating_dot_product", {Int(32, 8), UInt(8, 32), Int(8, 32)}, Target::AVXVNNI}, {"dpbusdsx4", Int(32, 4), "saturating_dot_product", {Int(32, 4), UInt(8, 16), Int(8, 16)}, Target::AVXVNNI}, {"dpwssdsx16", Int(32, 16), "saturating_dot_product", {Int(32, 16), Int(16, 32), Int(16, 32)}, Target::AVX512_Zen4}, - {"dpwssdsx8", Int(32, 8), "saturating_dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVX512_Zen4}, - {"dpwssdsx4", Int(32, 4), "saturating_dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVX512_Zen4}, + {"dpwssdsx8", Int(32, 8), "saturating_dot_product", {Int(32, 8), Int(16, 16), Int(16, 16)}, Target::AVXVNNI}, + {"dpwssdsx4", Int(32, 4), "saturating_dot_product", {Int(32, 4), Int(16, 8), Int(16, 8)}, Target::AVXVNNI}, {"tileloadd64_i8", Int(8, 1024), "tile_load", {Int(16), Int(16), Handle(), Int(64), Int(64)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory}, {"tileloadd64_i8", UInt(8, 1024), "tile_load", {Int(16), Int(16), Handle(), Int(64), Int(64)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory}, diff --git a/src/runtime/x86_avx2.ll b/src/runtime/x86_avx2.ll index afcbcf8c1933..54801c973981 100644 --- a/src/runtime/x86_avx2.ll +++ b/src/runtime/x86_avx2.ll @@ -107,3 +107,35 @@ define weak_odr <4 x i32> @dpbusdsx4(<4 x i32> %init, <16 x i8> %a, <16 x i8> % ret <4 x i32> %3 } declare <4 x i32> @llvm.x86.avx512.vpdpbusds.128(<4 x i32>, <4 x i32>, <4 x i32>) + +define weak_odr <8 x i32> @dpwssdx8(<8 x i32> %init, <16 x i16> %a, <16 x i16> %b) nounwind alwaysinline { + %1 = bitcast <16 x i16> %a to <8 x i32> + %2 = bitcast <16 x i16> %b to <8 x i32> + %3 = tail call <8 x i32> @llvm.x86.avx512.vpdpwssd.256(<8 x i32> %init, <8 x i32> %1, <8 x i32> %2) + ret <8 x i32> %3 +} +declare <8 x i32> @llvm.x86.avx512.vpdpwssd.256(<8 x i32>, <8 x i32>, <8 x i32>) + +define weak_odr <4 x i32> @dpwssdx4(<4 x i32> %init, <8 x i16> %a, <8 x i16> %b) nounwind alwaysinline { + %1 = bitcast <8 x i16> %a to <4 x i32> + %2 = bitcast <8 x i16> %b to <4 x i32> + %3 = tail call <4 x i32> @llvm.x86.avx512.vpdpwssd.128(<4 x i32> %init, <4 x i32> %1, <4 x i32> %2) + ret <4 x i32> %3 +} +declare <4 x i32> @llvm.x86.avx512.vpdpwssd.128(<4 x i32>, <4 x i32>, <4 x i32>) + +define weak_odr <8 x i32> @dpwssdsx8(<8 x i32> %init, <16 x i16> %a, <16 x i16> %b) nounwind alwaysinline { + %1 = bitcast <16 x i16> %a to <8 x i32> + %2 = bitcast <16 x i16> %b to <8 x i32> + %3 = tail call <8 x i32> @llvm.x86.avx512.vpdpwssds.256(<8 x i32> %init, <8 x i32> %1, <8 x i32> %2) + ret <8 x i32> %3 +} +declare <8 x i32> @llvm.x86.avx512.vpdpwssds.256(<8 x i32>, <8 x i32>, <8 x i32>) + +define weak_odr <4 x i32> @dpwssdsx4(<4 x i32> %init, <8 x i16> %a, <8 x i16> %b) nounwind alwaysinline { + %1 = bitcast <8 x i16> %a to <4 x i32> + %2 = bitcast <8 x i16> %b to <4 x i32> + %3 = tail call <4 x i32> @llvm.x86.avx512.vpdpwssds.128(<4 x i32> %init, <4 x i32> %1, <4 x i32> %2) + ret <4 x i32> %3 +} +declare <4 x i32> @llvm.x86.avx512.vpdpwssds.128(<4 x i32>, <4 x i32>, <4 x i32>) diff --git a/test/correctness/simd_op_check_x86.cpp b/test/correctness/simd_op_check_x86.cpp index 0b2b3a8455fa..108760548385 100644 --- a/test/correctness/simd_op_check_x86.cpp +++ b/test/correctness/simd_op_check_x86.cpp @@ -709,6 +709,7 @@ int main(int argc, char **argv) { // real reason to test avx without it. Target("x86-64-linux-sse41-avx-f16c-fma"), Target("x86-64-linux-sse41-avx-f16c-fma-avx2"), + Target("x86-64-linux-sse41-avx-f16c-fma-avx2-avxvnni"), // See above: don't test avx512 without extra features, the test // isn't yet set up to test it properly. // Target("x86-64-linux-sse41-avx-avx2-avx512"), From 57ff867e3a94d0e2978dd00677f659c5d146a363 Mon Sep 17 00:00:00 2001 From: Fabian Schuetze Date: Thu, 24 Jul 2025 08:35:01 +0200 Subject: [PATCH 4/5] check feasible ops --- test/correctness/simd_op_check.h | 1 + 1 file changed, 1 insertion(+) diff --git a/test/correctness/simd_op_check.h b/test/correctness/simd_op_check.h index 25b641800987..154f0d5df8c4 100644 --- a/test/correctness/simd_op_check.h +++ b/test/correctness/simd_op_check.h @@ -132,6 +132,7 @@ class SimdOpCheckTest { Target::ARMv89a, Target::AVX, Target::AVX2, + Target::AVXVNNI, Target::AVX512, Target::AVX512_Cannonlake, Target::AVX512_KNL, From 97ebc77e91e89122ebbed2dc3fc098d831f58c34 Mon Sep 17 00:00:00 2001 From: Fabian Schuetze Date: Mon, 28 Jul 2025 08:00:39 +0200 Subject: [PATCH 5/5] remove vnni for zen4 --- src/CodeGen_X86.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index 3b6ec95246aa..f111a7d2607b 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -46,7 +46,6 @@ Target complete_x86_target(Target t) { t.set_feature(Target::AVXVNNI); } if (t.has_feature(Target::AVX512_Zen4)) { - t.set_feature(Target::AVXVNNI); t.set_feature(Target::AVX512_Cannonlake); } if (t.has_feature(Target::AVX512_Cannonlake)) {