From 6057373db5853d2b0199a1a7e9ec683f6289625e Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Wed, 28 Jun 2023 23:19:06 +0300 Subject: [PATCH] +add AVX-512VNNI optimizations of functions DescrIntCosineDistancesMxNp, DescrIntCosineDistancesMxNa. --- .github/workflows/cmake.yml | 2 +- docs/2023.html | 4 +- prj/vs2019/Avx512vnni.vcxproj | 6 + prj/vs2019/Avx512vnni.vcxproj.filters | 18 +++ prj/vs2022/Avx512vnni.vcxproj | 6 + prj/vs2022/Avx512vnni.vcxproj.filters | 18 +++ src/Simd/SimdAvx512vnniDescrInt.cpp | 59 ++++++++ src/Simd/SimdAvx512vnniDescrIntCdu.cpp | 191 +++++++++++++++++++++++++ src/Simd/SimdDescrInt.h | 19 +++ src/Simd/SimdLib.cpp | 2 +- src/Test/TestDescrInt.cpp | 10 ++ 11 files changed, 331 insertions(+), 4 deletions(-) create mode 100644 src/Simd/SimdAvx512vnniDescrInt.cpp create mode 100644 src/Simd/SimdAvx512vnniDescrIntCdu.cpp diff --git a/.github/workflows/cmake.yml b/.github/workflows/cmake.yml index cf06c53e07..74bcfa6b62 100644 --- a/.github/workflows/cmake.yml +++ b/.github/workflows/cmake.yml @@ -21,7 +21,7 @@ jobs: run: lscpu - name: Configure CMake - run: cmake ./prj/cmake -B ${{github.workspace}}/build -DCMAKE_BUILD_TYPE=${{matrix.build_type}} -DSIMD_TEST_FLAGS="-mavx2" + run: cmake ./prj/cmake -B ${{github.workspace}}/build -DCMAKE_BUILD_TYPE=${{matrix.build_type}} -DSIMD_AVX512VNNI=ON -DSIMD_TEST_FLAGS="-mavx2" - name: Build run: cmake --build ${{github.workspace}}/build --config ${{matrix.build_type}} --parallel$(nproc) diff --git a/docs/2023.html b/docs/2023.html index 9563fe0c7b..b8708c65ef 100644 --- a/docs/2023.html +++ b/docs/2023.html @@ -44,8 +44,8 @@
New features
  • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntDecode32f.
  • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntDecode16f.
  • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntCosineDistance.
  • -
  • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntCosineDistancesMxNp.
  • -
  • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntCosineDistancesMxNa.
  • +
  • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW, AVX-512VNNI optimizations of function DescrIntCosineDistancesMxNp.
  • +
  • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW, AVX-512VNNI optimizations of function DescrIntCosineDistancesMxNa.
  • Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function SynetNormalizeLayerForwardV3.
  • Improving
    diff --git a/prj/vs2019/Avx512vnni.vcxproj b/prj/vs2019/Avx512vnni.vcxproj index 9a9b9be900..7df71eec09 100644 --- a/prj/vs2019/Avx512vnni.vcxproj +++ b/prj/vs2019/Avx512vnni.vcxproj @@ -18,7 +18,10 @@ + + + @@ -32,6 +35,7 @@ + @@ -43,6 +47,8 @@ + + diff --git a/prj/vs2019/Avx512vnni.vcxproj.filters b/prj/vs2019/Avx512vnni.vcxproj.filters index d4c2d4a11a..e8424fccb0 100644 --- a/prj/vs2019/Avx512vnni.vcxproj.filters +++ b/prj/vs2019/Avx512vnni.vcxproj.filters @@ -99,6 +99,18 @@ Inc + + Inc + + + Inc + + + Inc + + + Inc + @@ -125,5 +137,11 @@ Avx512vnni + + Avx512vnni + + + Avx512vnni + \ No newline at end of file diff --git a/prj/vs2022/Avx512vnni.vcxproj b/prj/vs2022/Avx512vnni.vcxproj index 9a9b9be900..7df71eec09 100644 --- a/prj/vs2022/Avx512vnni.vcxproj +++ b/prj/vs2022/Avx512vnni.vcxproj @@ -18,7 +18,10 @@ + + + @@ -32,6 +35,7 @@ + @@ -43,6 +47,8 @@ + + diff --git a/prj/vs2022/Avx512vnni.vcxproj.filters b/prj/vs2022/Avx512vnni.vcxproj.filters index d4c2d4a11a..e8424fccb0 100644 --- a/prj/vs2022/Avx512vnni.vcxproj.filters +++ b/prj/vs2022/Avx512vnni.vcxproj.filters @@ -99,6 +99,18 @@ Inc + + Inc + + + Inc + + + Inc + + + Inc + @@ -125,5 +137,11 @@ Avx512vnni + + Avx512vnni + + + Avx512vnni + \ No newline at end of file diff --git a/src/Simd/SimdAvx512vnniDescrInt.cpp b/src/Simd/SimdAvx512vnniDescrInt.cpp new file mode 100644 index 0000000000..f3fd1897d7 --- /dev/null +++ b/src/Simd/SimdAvx512vnniDescrInt.cpp @@ -0,0 +1,59 @@ +/* +* Simd Library (http://ermig1979.github.io/Simd). +* +* Copyright (c) 2011-2023 Yermalayeu Ihar. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +* SOFTWARE. +*/ +#include "Simd/SimdMemory.h" +#include "Simd/SimdStore.h" +#include "Simd/SimdExtract.h" +#include "Simd/SimdArray.h" +#include "Simd/SimdUnpack.h" +#include "Simd/SimdDescrInt.h" +#include "Simd/SimdDescrIntCommon.h" +#include "Simd/SimdCpu.h" + +namespace Simd +{ +#ifdef SIMD_AVX512VNNI_ENABLE + namespace Avx512vnni + { + DescrInt::DescrInt(size_t size, size_t depth) + : Avx512bw::DescrInt(size, depth) + { + if (_depth != 8) + { + _macroCosineDistancesUnpack = GetMacroCosineDistancesUnpack(_depth); + _microMu = 12; + _microNu = 32; + } + } + + //------------------------------------------------------------------------------------------------- + + void* DescrIntInit(size_t size, size_t depth) + { + if (!Base::DescrInt::Valid(size, depth)) + return NULL; + return new Avx512vnni::DescrInt(size, depth); + } + } +#endif +} diff --git a/src/Simd/SimdAvx512vnniDescrIntCdu.cpp b/src/Simd/SimdAvx512vnniDescrIntCdu.cpp new file mode 100644 index 0000000000..76aad377a4 --- /dev/null +++ b/src/Simd/SimdAvx512vnniDescrIntCdu.cpp @@ -0,0 +1,191 @@ +/* +* Simd Library (http://ermig1979.github.io/Simd). +* +* Copyright (c) 2011-2023 Yermalayeu Ihar. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +* SOFTWARE. +*/ +#include "Simd/SimdMemory.h" +#include "Simd/SimdStore.h" +#include "Simd/SimdExtract.h" +#include "Simd/SimdArray.h" +#include "Simd/SimdUnpack.h" +#include "Simd/SimdDescrInt.h" +#include "Simd/SimdDescrIntCommon.h" +#include "Simd/SimdCpu.h" +#include "Simd/SimdSynet.h" + +namespace Simd +{ +#ifdef SIMD_AVX512VNNI_ENABLE + namespace Avx512vnni + { + template void Correlation8_2xM(size_t N, size_t K, const uint8_t* ad0, const uint8_t* bd, const float* an, const float* bn, size_t bnStride, float* distances, size_t stride) + { + __m512i ab00, ab01, ab10, ab11, ab20, ab21, ab30, ab31, ab40, ab41, ab50, ab51, ab60, ab61, ab70, ab71, ab80, ab81, ab90, ab91, abA0, abA1, abB0, abB1, a0, b0, b1; + const uint8_t* ad1 = ad0 + 1 * K; + const uint8_t* ad2 = ad0 + 2 * K; + const uint8_t* ad3 = ad0 + 3 * K; + const uint8_t* ad4 = ad0 + 4 * K; + const uint8_t* ad5 = ad0 + 5 * K; + if (N > F) + { + if (M > 0x0) ab00 = _mm512_setzero_si512(), ab01 = _mm512_setzero_si512(); + if (M > 0x1) ab10 = _mm512_setzero_si512(), ab11 = _mm512_setzero_si512(); + if (M > 0x2) ab20 = _mm512_setzero_si512(), ab21 = _mm512_setzero_si512(); + if (M > 0x3) ab30 = _mm512_setzero_si512(), ab31 = _mm512_setzero_si512(); + if (M > 0x4) ab40 = _mm512_setzero_si512(), ab41 = _mm512_setzero_si512(); + if (M > 0x5) ab50 = _mm512_setzero_si512(), ab51 = _mm512_setzero_si512(); + if (M > 0x6) ab60 = _mm512_setzero_si512(), ab61 = _mm512_setzero_si512(); + if (M > 0x7) ab70 = _mm512_setzero_si512(), ab71 = _mm512_setzero_si512(); + if (M > 0x8) ab80 = _mm512_setzero_si512(), ab81 = _mm512_setzero_si512(); + if (M > 0x9) ab90 = _mm512_setzero_si512(), ab91 = _mm512_setzero_si512(); + if (M > 0xA) abA0 = _mm512_setzero_si512(), abA1 = _mm512_setzero_si512(); + if (M > 0xB) abB0 = _mm512_setzero_si512(), abB1 = _mm512_setzero_si512(); + for (size_t k0 = 0, k6 = K * 6; k0 < K; k0 += 4, k6 += 4) + { + b0 = _mm512_loadu_si512((__m512i*)bd + 0); + b1 = _mm512_loadu_si512((__m512i*)bd + 1); + if (M > 0x0) a0 = Set4(ad0 + k0), Madd4(ab00, a0, b0), Madd4(ab01, a0, b1); + if (M > 0x1) a0 = Set4(ad1 + k0), Madd4(ab10, a0, b0), Madd4(ab11, a0, b1); + if (M > 0x2) a0 = Set4(ad2 + k0), Madd4(ab20, a0, b0), Madd4(ab21, a0, b1); + if (M > 0x3) a0 = Set4(ad3 + k0), Madd4(ab30, a0, b0), Madd4(ab31, a0, b1); + if (M > 0x4) a0 = Set4(ad4 + k0), Madd4(ab40, a0, b0), Madd4(ab41, a0, b1); + if (M > 0x5) a0 = Set4(ad5 + k0), Madd4(ab50, a0, b0), Madd4(ab51, a0, b1); + if (M > 0x6) a0 = Set4(ad0 + k6), Madd4(ab60, a0, b0), Madd4(ab61, a0, b1); + if (M > 0x7) a0 = Set4(ad1 + k6), Madd4(ab70, a0, b0), Madd4(ab71, a0, b1); + if (M > 0x8) a0 = Set4(ad2 + k6), Madd4(ab80, a0, b0), Madd4(ab81, a0, b1); + if (M > 0x9) a0 = Set4(ad3 + k6), Madd4(ab90, a0, b0), Madd4(ab91, a0, b1); + if (M > 0xA) a0 = Set4(ad4 + k6), Madd4(abA0, a0, b0), Madd4(abA1, a0, b1); + if (M > 0xB) a0 = Set4(ad5 + k6), Madd4(abB0, a0, b0), Madd4(abB1, a0, b1); + bd += DA; + } + __mmask16 tail = TailMask16(N - F); + if (M > 0x0) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab00, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab01, distances + F, tail), an += 4, distances += stride; + if (M > 0x1) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab10, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab11, distances + F, tail), an += 4, distances += stride; + if (M > 0x2) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab20, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab21, distances + F, tail), an += 4, distances += stride; + if (M > 0x3) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab30, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab31, distances + F, tail), an += 4, distances += stride; + if (M > 0x4) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab40, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab41, distances + F, tail), an += 4, distances += stride; + if (M > 0x5) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab50, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab51, distances + F, tail), an += 4, distances += stride; + if (M > 0x6) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab60, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab61, distances + F, tail), an += 4, distances += stride; + if (M > 0x7) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab70, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab71, distances + F, tail), an += 4, distances += stride; + if (M > 0x8) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab80, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab81, distances + F, tail), an += 4, distances += stride; + if (M > 0x9) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab90, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab91, distances + F, tail), an += 4, distances += stride; + if (M > 0xA) DecodeCosineDistances1xF(an, bn + 0, bnStride, abA0, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, abA1, distances + F, tail), an += 4, distances += stride; + if (M > 0xB) DecodeCosineDistances1xF(an, bn + 0, bnStride, abB0, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, abB1, distances + F, tail), an += 4, distances += stride; + } + else + { + if (M > 0x0) ab00 = _mm512_setzero_si512(); + if (M > 0x1) ab10 = _mm512_setzero_si512(); + if (M > 0x2) ab20 = _mm512_setzero_si512(); + if (M > 0x3) ab30 = _mm512_setzero_si512(); + if (M > 0x4) ab40 = _mm512_setzero_si512(); + if (M > 0x5) ab50 = _mm512_setzero_si512(); + if (M > 0x6) ab60 = _mm512_setzero_si512(); + if (M > 0x7) ab70 = _mm512_setzero_si512(); + if (M > 0x8) ab80 = _mm512_setzero_si512(); + if (M > 0x9) ab90 = _mm512_setzero_si512(); + if (M > 0xA) abA0 = _mm512_setzero_si512(); + if (M > 0xB) abB0 = _mm512_setzero_si512(); + for (size_t k0 = 0, k6 = K * 6; k0 < K; k0 += 4, k6 += 4) + { + b0 = _mm512_loadu_si512((__m512i*)bd + 0); + if (M > 0x0) a0 = Set4(ad0 + k0), Madd4(ab00, a0, b0); + if (M > 0x1) a0 = Set4(ad1 + k0), Madd4(ab10, a0, b0); + if (M > 0x2) a0 = Set4(ad2 + k0), Madd4(ab20, a0, b0); + if (M > 0x3) a0 = Set4(ad3 + k0), Madd4(ab30, a0, b0); + if (M > 0x4) a0 = Set4(ad4 + k0), Madd4(ab40, a0, b0); + if (M > 0x5) a0 = Set4(ad5 + k0), Madd4(ab50, a0, b0); + if (M > 0x6) a0 = Set4(ad0 + k6), Madd4(ab60, a0, b0); + if (M > 0x7) a0 = Set4(ad1 + k6), Madd4(ab70, a0, b0); + if (M > 0x8) a0 = Set4(ad2 + k6), Madd4(ab80, a0, b0); + if (M > 0x9) a0 = Set4(ad3 + k6), Madd4(ab90, a0, b0); + if (M > 0xA) a0 = Set4(ad4 + k6), Madd4(abA0, a0, b0); + if (M > 0xB) a0 = Set4(ad5 + k6), Madd4(abB0, a0, b0); + bd += DA; + } + __mmask16 tail = TailMask16(N); + if (M > 0x0) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab00, distances + 0, tail), an += 4, distances += stride; + if (M > 0x1) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab10, distances + 0, tail), an += 4, distances += stride; + if (M > 0x2) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab20, distances + 0, tail), an += 4, distances += stride; + if (M > 0x3) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab30, distances + 0, tail), an += 4, distances += stride; + if (M > 0x4) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab40, distances + 0, tail), an += 4, distances += stride; + if (M > 0x5) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab50, distances + 0, tail), an += 4, distances += stride; + if (M > 0x6) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab60, distances + 0, tail), an += 4, distances += stride; + if (M > 0x7) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab70, distances + 0, tail), an += 4, distances += stride; + if (M > 0x8) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab80, distances + 0, tail), an += 4, distances += stride; + if (M > 0x9) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab90, distances + 0, tail), an += 4, distances += stride; + if (M > 0xA) DecodeCosineDistances1xF(an, bn + 0, bnStride, abA0, distances + 0, tail), an += 4, distances += stride; + if (M > 0xB) DecodeCosineDistances1xF(an, bn + 0, bnStride, abB0, distances + 0, tail), an += 4, distances += stride; + } + } + + typedef void(*Correlation8_2xM_Ptr)(size_t N, size_t K, const uint8_t* ad0, const uint8_t* bd, const float* an, const float* bn, size_t bnStride, float* distances, size_t stride); + + SIMD_INLINE Correlation8_2xM_Ptr GetCorrelation8_2xM(size_t M) + { + switch (M) + { + case 0x0: return NULL; + case 0x1: return Correlation8_2xM<0x1>; + case 0x2: return Correlation8_2xM<0x2>; + case 0x3: return Correlation8_2xM<0x3>; + case 0x4: return Correlation8_2xM<0x4>; + case 0x5: return Correlation8_2xM<0x5>; + case 0x6: return Correlation8_2xM<0x6>; + case 0x7: return Correlation8_2xM<0x7>; + case 0x8: return Correlation8_2xM<0x8>; + case 0x9: return Correlation8_2xM<0x9>; + case 0xA: return Correlation8_2xM<0xA>; + case 0xB: return Correlation8_2xM<0xB>; + case 0xC: return Correlation8_2xM<0xC>; + } + assert(0); + return NULL; + } + + void MacroCorrelation8(size_t M, size_t N, size_t K, const uint8_t* ad, const float* an, const uint8_t* bd, const float* bn, float* distances, size_t stride) + { + size_t M12 = AlignLoAny(M, 12); + Correlation8_2xM_Ptr correlation_2x12 = GetCorrelation8_2xM(12); + Correlation8_2xM_Ptr correlation_2xT = GetCorrelation8_2xM(M - M12); + for (size_t j = 0; j < N; j += DF) + { + size_t dN = Simd::Min(DF, N - j); + size_t i = 0; + for (; i < M12; i += 12) + correlation_2x12(dN, K, ad + i * K, bd, an + i * 4, bn, N, distances + i * stride, stride); + if (i < M) + correlation_2xT(dN, K, ad + i * K, bd, an + i * 4, bn, N, distances + i * stride, stride); + bd += K * DF; + bn += DF; + distances += DF; + } + } + + //------------------------------------------------------------------------------------------------- + + Sse41::DescrInt::MacroCosineDistancesUnpackPtr GetMacroCosineDistancesUnpack(size_t depth) + { + return depth == 8 ? NULL : MacroCorrelation8; + } + } +#endif +} diff --git a/src/Simd/SimdDescrInt.h b/src/Simd/SimdDescrInt.h index 40444bf96b..1065b874fd 100644 --- a/src/Simd/SimdDescrInt.h +++ b/src/Simd/SimdDescrInt.h @@ -187,5 +187,24 @@ namespace Simd void* DescrIntInit(size_t size, size_t depth); } #endif + +#ifdef SIMD_AVX512VNNI_ENABLE + namespace Avx512vnni + { + class DescrInt : public Avx512bw::DescrInt + { + public: + DescrInt(size_t size, size_t depth); + }; + + //------------------------------------------------------------------------------------------------- + + Sse41::DescrInt::MacroCosineDistancesUnpackPtr GetMacroCosineDistancesUnpack(size_t depth); + + //------------------------------------------------------------------------------------------------- + + void* DescrIntInit(size_t size, size_t depth); + } +#endif } #endif//__SimdDescrInt_h__ diff --git a/src/Simd/SimdLib.cpp b/src/Simd/SimdLib.cpp index ad5de4e830..d7d0dab8a1 100644 --- a/src/Simd/SimdLib.cpp +++ b/src/Simd/SimdLib.cpp @@ -1922,7 +1922,7 @@ SIMD_API void* SimdDescrIntInit(size_t size, size_t depth) { SIMD_EMPTY(); typedef void* (*SimdDescrIntInitPtr) (size_t size, size_t depth); - const static SimdDescrIntInitPtr simdDescrIntInit = SIMD_FUNC3(DescrIntInit, SIMD_AVX512BW_FUNC, SIMD_AVX2_FUNC, SIMD_SSE41_FUNC);// , SIMD_NEON_FUNC); + const static SimdDescrIntInitPtr simdDescrIntInit = SIMD_FUNC4(DescrIntInit, SIMD_AVX512VNNI_FUNC, SIMD_AVX512BW_FUNC, SIMD_AVX2_FUNC, SIMD_SSE41_FUNC);// , SIMD_NEON_FUNC); return simdDescrIntInit(size, depth); } diff --git a/src/Test/TestDescrInt.cpp b/src/Test/TestDescrInt.cpp index 5b65e282ab..311a335845 100644 --- a/src/Test/TestDescrInt.cpp +++ b/src/Test/TestDescrInt.cpp @@ -610,6 +610,11 @@ namespace Test if (Simd::Avx512bw::Enable) result = result && DescrIntCosineDistancesMxNaAutoTest(FUNC_DI(Simd::Avx512bw::DescrIntInit), FUNC_DI(SimdDescrIntInit)); #endif + +#ifdef SIMD_AVX512VNNI_ENABLE + if (Simd::Avx512vnni::Enable) + result = result && DescrIntCosineDistancesMxNaAutoTest(FUNC_DI(Simd::Avx512vnni::DescrIntInit), FUNC_DI(SimdDescrIntInit)); +#endif //#if defined(SIMD_NEON_ENABLE) // if (Simd::Neon::Enable) @@ -684,6 +689,11 @@ namespace Test result = result && DescrIntCosineDistancesMxNpAutoTest(FUNC_DI(Simd::Avx512bw::DescrIntInit), FUNC_DI(SimdDescrIntInit)); #endif +#ifdef SIMD_AVX512VNNI_ENABLE + if (Simd::Avx512vnni::Enable) + result = result && DescrIntCosineDistancesMxNpAutoTest(FUNC_DI(Simd::Avx512vnni::DescrIntInit), FUNC_DI(SimdDescrIntInit)); +#endif + //#if defined(SIMD_NEON_ENABLE) // if (Simd::Neon::Enable) // result = result && DescrIntCosineDistancesMxNpAutoTest(FUNC_DI(Simd::Neon::DescrIntInit), FUNC_DI(SimdDescrIntInit));