Skip to content

Commit

Permalink
+add Sse41::DescrInt::CosineDistancesUnpack (part 2, depth=8).
Browse files Browse the repository at this point in the history
  • Loading branch information
ermig1979 committed Jun 27, 2023
1 parent 307db98 commit 0649860
Show file tree
Hide file tree
Showing 3 changed files with 235 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/Simd/SimdDescrInt.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ namespace Simd

typedef void (*UnpackNormPtr)(size_t count, const uint8_t* const* src, float* dst, size_t stride);
typedef void (*UnpackDataPtr)(size_t count, const uint8_t* const* src, size_t size, uint8_t* dst, size_t stride);
typedef void (*MacroCosineDistancesUnpackPtr)(size_t M, size_t N, const uint8_t* ad, const float * an, const uint8_t* bd, const float* bn, size_t size, float* distances, size_t stride);
typedef void (*MacroCosineDistancesUnpackPtr)(size_t M, size_t N, size_t K, const uint8_t* ad, const float * an, const uint8_t* bd, const float* bn, float* distances, size_t stride);
UnpackNormPtr _unpackNormA, _unpackNormB;
UnpackDataPtr _unpackDataA, _unpackDataB;
MacroCosineDistancesUnpackPtr _macroCosineDistancesUnpack;
Expand Down
24 changes: 24 additions & 0 deletions src/Simd/SimdDescrIntCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,30 @@ namespace Simd

_mm_storeu_ps(distances, _mm_min_ps(_mm_max_ps(_mm_sub_ps(_mm_set1_ps(1.0f), _mm_div_ps(ab, _mm_mul_ps(aNorm, bNorm))), _mm_setzero_ps()), _mm_set1_ps(2.0f)));
}

SIMD_INLINE void DecodeCosineDistances1x4(const float* a, const float *b, size_t stride, __m128i abSum, float* distances)
{
__m128 aScale = _mm_set1_ps(a[0]);
__m128 aShift = _mm_set1_ps(a[1]);
__m128 aMean = _mm_set1_ps(a[2]);
__m128 aNorm = _mm_set1_ps(a[3]);
__m128 bScale = _mm_loadu_ps(b + 0 * stride);
__m128 bShift = _mm_loadu_ps(b + 1 * stride);
__m128 bMean = _mm_loadu_ps(b + 2 * stride);
__m128 bNorm = _mm_loadu_ps(b + 3 * stride);
__m128 ab = _mm_mul_ps(_mm_cvtepi32_ps(abSum), _mm_mul_ps(aScale, bScale));
ab = _mm_add_ps(_mm_mul_ps(aMean, bShift), ab);
ab = _mm_add_ps(_mm_mul_ps(bMean, aShift), ab);
_mm_storeu_ps(distances, _mm_min_ps(_mm_max_ps(_mm_sub_ps(_mm_set1_ps(1.0f), _mm_div_ps(ab, _mm_mul_ps(aNorm, bNorm))), _mm_setzero_ps()), _mm_set1_ps(2.0f)));
}

SIMD_INLINE void DecodeCosineDistances1x4(const float* a, const float* b, size_t stride, __m128i abSum, float* distances, size_t N)
{
float d[4];
DecodeCosineDistances1x4(a, b, stride, abSum, d);
for (size_t i = 0; i < N; ++i)
distances[i] = d[i];
}
}
#endif

Expand Down
214 changes: 210 additions & 4 deletions src/Simd/SimdSse41DescrInt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1088,7 +1088,7 @@ namespace Simd
template<> void MicroCosineDistancesDirect1x4<4>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride)
{
size_t i = 0, size32 = AlignLo(size, 32), o = 16;
__m128i a0, a1, b0;
__m128i a0, b0;
__m128i ab00 = _mm_setzero_si128();
__m128i ab01 = _mm_setzero_si128();
__m128i ab02 = _mm_setzero_si128();
Expand Down Expand Up @@ -1443,6 +1443,209 @@ namespace Simd

//-------------------------------------------------------------------------------------------------

static void UnpackDataA8(size_t count, const uint8_t* const* src, size_t size, uint8_t* dst, size_t stride)
{
size_t size16 = AlignLo(size, 16);
for (size_t i = 0, j; i < count; i++)
{
const uint8_t* ps = src[i] + 16;
uint16_t* pd = (uint16_t*)dst + i * size;
for (j = 0; j < size16; j += 16, ps += 16, pd += 16)
{
__m128i s = _mm_loadu_si128((__m128i*)ps);
_mm_storeu_si128((__m128i*)pd + 0, UnpackU8<0>(s));
_mm_storeu_si128((__m128i*)pd + 1, UnpackU8<1>(s));
}
for (; j < size; j += 8, ps += 8, pd += 8)
{
__m128i s = _mm_loadl_epi64((__m128i*)ps);
_mm_storeu_si128((__m128i*)pd, UnpackU8<0>(s));
}
}
}

//-------------------------------------------------------------------------------------------------

SIMD_INLINE void UnpackDataB8x4(const uint8_t* const* src, size_t offset, uint8_t* dst)
{
__m128i a0 = UnpackU8<0>(_mm_loadl_epi64((__m128i*)(src[0] + offset)));
__m128i a1 = UnpackU8<0>(_mm_loadl_epi64((__m128i*)(src[1] + offset)));
__m128i a2 = UnpackU8<0>(_mm_loadl_epi64((__m128i*)(src[2] + offset)));
__m128i a3 = UnpackU8<0>(_mm_loadl_epi64((__m128i*)(src[3] + offset)));
__m128i b0 = _mm_unpacklo_epi32(a0, a2);
__m128i b1 = _mm_unpacklo_epi32(a1, a3);
__m128i b2 = _mm_unpackhi_epi32(a0, a2);
__m128i b3 = _mm_unpackhi_epi32(a1, a3);
_mm_storeu_si128((__m128i*)dst + 0, _mm_unpacklo_epi32(b0, b1));
_mm_storeu_si128((__m128i*)dst + 2, _mm_unpackhi_epi32(b0, b1));
_mm_storeu_si128((__m128i*)dst + 4, _mm_unpacklo_epi32(b2, b3));
_mm_storeu_si128((__m128i*)dst + 6, _mm_unpackhi_epi32(b2, b3));
}

static void UnpackDataB8(size_t count, const uint8_t* const* src, size_t size, uint8_t* dst, size_t stride)
{
size_t count8 = AlignLo(count, 8), i;
for (i = 0, size += 16; i < count8; i += 8, src += 8)
{
for (size_t j = 16; j < size; j += 8, dst += 8 * A)
{
UnpackDataB8x4(src + 0, j, dst + 0);
UnpackDataB8x4(src + 4, j, dst + A);
}
}
if (i < count)
{
const uint8_t* _src[8];
for (size_t j = 0; j < 8; i++, j++)
_src[j] = i < count ? *src++ : src[-1];
for (size_t j = 16; j < size; j += 8, dst += 8 * A)
{
UnpackDataB8x4(_src + 0, j, dst + 0);
UnpackDataB8x4(_src + 4, j, dst + A);
}
}
}

//-------------------------------------------------------------------------------------------------

SIMD_INLINE __m128i Set2(const int16_t* src)
{
return _mm_set1_epi32(*(int32_t*)src);
}

SIMD_INLINE void Madd2(__m128i& ab, __m128i a, __m128i b)
{
ab = _mm_add_epi32(ab, _mm_madd_epi16(a, b));
}

template<int M> void Correlation16_2xM(size_t N, size_t K, const int16_t* ad0, const int16_t* bd, const float *an, const float *bn, size_t bnStride, float* distances, size_t stride)
{
__m128i ab00, ab01, ab10, ab11, ab20, ab21, ab30, ab31, ab40, ab41, ab50, ab51, a0, b0, b1;
const int16_t* ad1 = ad0 + 1 * K;
const int16_t* ad2 = ad0 + 2 * K;
const int16_t* ad3 = ad0 + 3 * K;
const int16_t* ad4 = ad0 + 4 * K;
const int16_t* ad5 = ad0 + 5 * K;
if (N > 4)
{
if (M > 0) ab00 = _mm_setzero_si128(), ab01 = _mm_setzero_si128();
if (M > 1) ab10 = _mm_setzero_si128(), ab11 = _mm_setzero_si128();
if (M > 2) ab20 = _mm_setzero_si128(), ab21 = _mm_setzero_si128();
if (M > 3) ab30 = _mm_setzero_si128(), ab31 = _mm_setzero_si128();
if (M > 4) ab40 = _mm_setzero_si128(), ab41 = _mm_setzero_si128();
if (M > 5) ab50 = _mm_setzero_si128(), ab51 = _mm_setzero_si128();
for (size_t k = 0; k < K; k += 2)
{
b0 = _mm_loadu_si128((__m128i*)bd + 0);
b1 = _mm_loadu_si128((__m128i*)bd + 1);
if (M > 0) a0 = Set2(ad0 + k), Madd2(ab00, a0, b0), Madd2(ab01, a0, b1);
if (M > 1) a0 = Set2(ad1 + k), Madd2(ab10, a0, b0), Madd2(ab11, a0, b1);
if (M > 2) a0 = Set2(ad2 + k), Madd2(ab20, a0, b0), Madd2(ab21, a0, b1);
if (M > 3) a0 = Set2(ad3 + k), Madd2(ab30, a0, b0), Madd2(ab31, a0, b1);
if (M > 4) a0 = Set2(ad4 + k), Madd2(ab40, a0, b0), Madd2(ab41, a0, b1);
if (M > 5) a0 = Set2(ad5 + k), Madd2(ab50, a0, b0), Madd2(ab51, a0, b1);
bd += 16;
}
if (N == 8)
{
if (M > 0) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab00, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab01, distances + 4), an += 4, distances += stride;
if (M > 1) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab10, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab11, distances + 4), an += 4, distances += stride;
if (M > 2) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab20, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab21, distances + 4), an += 4, distances += stride;
if (M > 3) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab30, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab31, distances + 4), an += 4, distances += stride;
if (M > 4) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab40, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab41, distances + 4), an += 4, distances += stride;
if (M > 5) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab50, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab51, distances + 4), an += 4, distances += stride;
}
else
{
if (M > 0) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab00, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab01, distances + 4, N - 4), an += 4, distances += stride;
if (M > 1) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab10, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab11, distances + 4, N - 4), an += 4, distances += stride;
if (M > 2) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab20, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab21, distances + 4, N - 4), an += 4, distances += stride;
if (M > 3) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab30, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab31, distances + 4, N - 4), an += 4, distances += stride;
if (M > 4) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab40, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab41, distances + 4, N - 4), an += 4, distances += stride;
if (M > 5) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab50, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab51, distances + 4, N - 4), an += 4, distances += stride;
}
}
else
{
if (M > 0) ab00 = _mm_setzero_si128();
if (M > 1) ab10 = _mm_setzero_si128();
if (M > 2) ab20 = _mm_setzero_si128();
if (M > 3) ab30 = _mm_setzero_si128();
if (M > 4) ab40 = _mm_setzero_si128();
if (M > 5) ab50 = _mm_setzero_si128();
for (size_t k = 0; k < K; k += 2)
{
b0 = _mm_loadu_si128((__m128i*)bd + 0);
if (M > 0) a0 = Set2(ad0 + k), Madd2(ab00, a0, b0);
if (M > 1) a0 = Set2(ad1 + k), Madd2(ab10, a0, b0);
if (M > 2) a0 = Set2(ad2 + k), Madd2(ab20, a0, b0);
if (M > 3) a0 = Set2(ad3 + k), Madd2(ab30, a0, b0);
if (M > 4) a0 = Set2(ad4 + k), Madd2(ab40, a0, b0);
if (M > 5) a0 = Set2(ad5 + k), Madd2(ab50, a0, b0);
bd += 16;
}
if (N == 4)
{
if (M > 0) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab00, distances + 0), an += 4, distances += stride;
if (M > 1) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab10, distances + 0), an += 4, distances += stride;
if (M > 2) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab20, distances + 0), an += 4, distances += stride;
if (M > 3) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab30, distances + 0), an += 4, distances += stride;
if (M > 4) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab40, distances + 0), an += 4, distances += stride;
if (M > 5) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab50, distances + 0), an += 4, distances += stride;
}
else
{
if (M > 0) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab00, distances + 0, N), an += 4, distances += stride;
if (M > 1) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab10, distances + 0, N), an += 4, distances += stride;
if (M > 2) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab20, distances + 0, N), an += 4, distances += stride;
if (M > 3) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab30, distances + 0, N), an += 4, distances += stride;
if (M > 4) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab40, distances + 0, N), an += 4, distances += stride;
if (M > 5) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab50, distances + 0, N), an += 4, distances += stride;
}
}
}

typedef void(*Correlation16_2xM_Ptr)(size_t N, size_t K, const int16_t* ad0, const int16_t* bd, const float* an, const float* bn, size_t bnStride, float* distances, size_t stride);

SIMD_INLINE Correlation16_2xM_Ptr GetCorrelation16_2xM(size_t M)
{
switch (M)
{
case 0: return NULL;
case 1: return Correlation16_2xM<1>;
case 2: return Correlation16_2xM<2>;
case 3: return Correlation16_2xM<3>;
case 4: return Correlation16_2xM<4>;
case 5: return Correlation16_2xM<5>;
case 6: return Correlation16_2xM<6>;
}
assert(0);
return NULL;
}

void MacroCorrelation16(size_t M, size_t N, size_t K, const uint8_t* ad, const float* an, const uint8_t* bd, const float* bn, float* distances, size_t stride)
{
size_t M6 = AlignLoAny(M, 6);
Correlation16_2xM_Ptr correlation_2x6 = GetCorrelation16_2xM(6);
Correlation16_2xM_Ptr correlation_2xT = GetCorrelation16_2xM(M - M6);
const int16_t* a = (int16_t*)ad;
const int16_t* b = (int16_t*)bd;
for (size_t j = 0; j < N; j += 8)
{
size_t dN = Simd::Min<size_t>(8, N - j);
size_t i = 0;
for (; i < M6; i += 6)
correlation_2x6(dN, K, a + i * K, b, an + i * 4, bn, N, distances + i * stride, stride);
if(i < M)
correlation_2xT(dN, K, a + i * K, b, an + i * 4, bn, N, distances + i * stride, stride);
b += K * 8;
bn += 8;
distances += 8;
}
}

//-------------------------------------------------------------------------------------------------

DescrInt::DescrInt(size_t size, size_t depth)
: Base::DescrInt(size, depth)
{
Expand Down Expand Up @@ -1505,6 +1708,9 @@ namespace Simd
_decode16f = Decode16f8;
_cosineDistance = Sse41::CosineDistance<8>;
_macroCosineDistancesDirect = Sse41::MacroCosineDistancesDirect<8>;
_unpackDataA = UnpackDataA8;
_unpackDataB = UnpackDataB8;
_macroCosineDistancesUnpack = MacroCorrelation16;
break;
}
default:
Expand Down Expand Up @@ -1559,13 +1765,13 @@ namespace Simd
{
size_t dM = Simd::Min(M, i + macroM) - i;
_unpackNormA(dM, A + i, nA.data, 1);
//_unpackDataA(dM, A + i, _size, dA.data, 1);
_unpackDataA(dM, A + i, _size, dA.data, _unpSize);
for (size_t j = 0; j < N; j += macroN)
{
size_t dN = Simd::Min(N, j + macroN) - j;
_unpackNormB(dN, B + j, nB.data, dN);
//_unpackDataB(dN, B + j, _size, dB.data, _microNu);
//_macroCosineDistancesUnpack(dM, dN, dA.data, nA.data, dB.data, nB.data, _size, distances + i * N + j, N);
_unpackDataB(dN, B + j, _size, dB.data, 1);
_macroCosineDistancesUnpack(dM, dN, _size, dA.data, nA.data, dB.data, nB.data, distances + i * N + j, N);
}
}
}
Expand Down

0 comments on commit 0649860

Please sign in to comment.