Skip to content

Commit

Permalink
+add Sse41::DescrInt::CosineDistancesUnpack (part 1).
Browse files Browse the repository at this point in the history
  • Loading branch information
ermig1979 committed Jun 26, 2023
1 parent 3393d41 commit 307db98
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 11 deletions.
7 changes: 0 additions & 7 deletions src/Simd/SimdBaseDescrInt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -543,8 +543,6 @@ namespace Simd
{
_encSize = 16 + DivHi(size * depth, 8);
_range = float((1 << _depth) - 1);
_microMd = 1;
_microNd = 1;
_minMax32f = MinMax32f;
_minMax16f = MinMax16f;
switch (depth)
Expand All @@ -556,7 +554,6 @@ namespace Simd
_decode32f = Decode32f4;
_decode16f = Decode16f4;
_cosineDistance = Base::CosineDistance<4>;
_macroCosineDistancesDirect = Base::MacroCosineDistancesDirect<4>;
break;
}
case 5:
Expand All @@ -566,7 +563,6 @@ namespace Simd
_decode32f = Decode32f5;
_decode16f = Decode16f5;
_cosineDistance = Base::CosineDistance<5>;
_macroCosineDistancesDirect = Base::MacroCosineDistancesDirect<5>;
break;
}
case 6:
Expand All @@ -576,7 +572,6 @@ namespace Simd
_decode32f = Decode32f6;
_decode16f = Decode16f6;
_cosineDistance = Base::CosineDistance<6>;
_macroCosineDistancesDirect = Base::MacroCosineDistancesDirect<6>;
break;
}
case 7:
Expand All @@ -586,7 +581,6 @@ namespace Simd
_decode32f = Decode32f7;
_decode16f = Decode16f7;
_cosineDistance = Base::CosineDistance<7>;
_macroCosineDistancesDirect = Base::MacroCosineDistancesDirect<7>;
break;
}
case 8:
Expand All @@ -596,7 +590,6 @@ namespace Simd
_decode32f = Decode32f8;
_decode16f = Decode16f8;
_cosineDistance = Base::CosineDistance<8>;
_macroCosineDistancesDirect = Base::MacroCosineDistancesDirect<8>;
break;
}
default:
Expand Down
17 changes: 14 additions & 3 deletions src/Simd/SimdDescrInt.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ namespace Simd
typedef void (*Decode32fPtr)(const uint8_t * src, float scale, float shift, size_t size, float* dst);
typedef void (*Decode16fPtr)(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst);
typedef void (*CosineDistancePtr)(const uint8_t* a, const uint8_t* b, size_t size, float* distance);
typedef void (*MacroCosineDistancesDirectPtr)(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride);

MinMax32fPtr _minMax32f;
MinMax16fPtr _minMax16f;
Expand All @@ -70,8 +69,7 @@ namespace Simd
Decode32fPtr _decode32f;
Decode16fPtr _decode16f;
CosineDistancePtr _cosineDistance;
MacroCosineDistancesDirectPtr _macroCosineDistancesDirect;
size_t _size, _depth, _encSize, _microMd, _microNd;
size_t _size, _depth, _encSize;
float _range;
};

Expand All @@ -94,6 +92,19 @@ namespace Simd
protected:
void CosineDistancesDirect(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, float* distances) const;

typedef void (*MacroCosineDistancesDirectPtr)(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride);
MacroCosineDistancesDirectPtr _macroCosineDistancesDirect;
size_t _microMd, _microNd;

void CosineDistancesUnpack(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, float* distances) const;

typedef void (*UnpackNormPtr)(size_t count, const uint8_t* const* src, float* dst, size_t stride);
typedef void (*UnpackDataPtr)(size_t count, const uint8_t* const* src, size_t size, uint8_t* dst, size_t stride);
typedef void (*MacroCosineDistancesUnpackPtr)(size_t M, size_t N, const uint8_t* ad, const float * an, const uint8_t* bd, const float* bn, size_t size, float* distances, size_t stride);
UnpackNormPtr _unpackNormA, _unpackNormB;
UnpackDataPtr _unpackDataA, _unpackDataB;
MacroCosineDistancesUnpackPtr _macroCosineDistancesUnpack;
size_t _microMu, _microNu, _unpSize;
};

//-------------------------------------------------------------------------------------------------
Expand Down
71 changes: 70 additions & 1 deletion src/Simd/SimdSse41DescrInt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1405,13 +1405,56 @@ namespace Simd

//-------------------------------------------------------------------------------------------------

// Copies the first 16 bytes (four floats) of every source descriptor into
// consecutive 4-float slots of dst: descriptor i lands in dst[4 * i .. 4 * i + 3].
//   count  - number of descriptor pointers in src.
//   src    - array of pointers to descriptors; at least 16 readable bytes each.
//   dst    - output buffer of at least 4 * count floats.
//   stride - unused here; kept so the signature matches UnpackNormPtr.
static void UnpackNormA(size_t count, const uint8_t* const* src, float* dst, size_t stride)
{
    __m128i* out = (__m128i*)dst;
    for (const uint8_t* const* end = src + count; src != end; ++src, ++out)
    {
        // Unaligned load/store: neither the descriptor nor dst is assumed to be 16-byte aligned.
        const __m128i header = _mm_loadu_si128((const __m128i*)*src);
        _mm_storeu_si128(out, header);
    }
}

//-------------------------------------------------------------------------------------------------


// Transposes the 4-float headers of `count` descriptors into four planes:
// float k of descriptor i is written to dst[k * stride + i], k = 0..3.
//   count  - number of descriptor pointers in src.
//   src    - array of pointers to descriptors; at least 16 readable bytes each.
//   dst    - output buffer; must hold at least 3 * stride + count floats.
//   stride - distance in floats between consecutive planes of dst.
static void UnpackNormB(size_t count, const uint8_t* const* src, float* dst, size_t stride)
{
    size_t i = 0;
    // Main path: 4x4 transpose of four descriptor headers at a time.
    for (; i + 4 <= count; i += 4, src += 4, dst += 4)
    {
        __m128 s0 = _mm_loadu_ps((const float*)src[0]);
        __m128 s1 = _mm_loadu_ps((const float*)src[1]);
        __m128 s2 = _mm_loadu_ps((const float*)src[2]);
        __m128 s3 = _mm_loadu_ps((const float*)src[3]);
        __m128 s00 = _mm_unpacklo_ps(s0, s2);
        __m128 s01 = _mm_unpacklo_ps(s1, s3);
        __m128 s10 = _mm_unpackhi_ps(s0, s2);
        __m128 s11 = _mm_unpackhi_ps(s1, s3);
        _mm_storeu_ps(dst + 0 * stride, _mm_unpacklo_ps(s00, s01));
        _mm_storeu_ps(dst + 1 * stride, _mm_unpackhi_ps(s00, s01));
        _mm_storeu_ps(dst + 2 * stride, _mm_unpacklo_ps(s10, s11));
        _mm_storeu_ps(dst + 3 * stride, _mm_unpackhi_ps(s10, s11));
    }
    // Tail: remaining 1..3 descriptors, one at a time.
    for (; i < count; i++, src++, dst++)
    {
        // Read from the descriptor itself (src[0]), not from the pointer array:
        // casting `src` directly would reinterpret pointer bytes as floats.
        const float* norm = (const float*)src[0];
        dst[0 * stride] = norm[0];
        dst[1 * stride] = norm[1];
        dst[2 * stride] = norm[2];
        dst[3 * stride] = norm[3];
    }
}

//-------------------------------------------------------------------------------------------------

DescrInt::DescrInt(size_t size, size_t depth)
: Base::DescrInt(size, depth)
{
_minMax32f = MinMax32f;
_minMax16f = MinMax16f;
_unpackNormA = UnpackNormA;
_unpackNormB = UnpackNormB;
_microMd = 2;
_microNd = 4;
_unpSize = _size * (_depth == 8 ? 2 : 1);
_microMu = 5;
_microNu = 8;
switch (depth)
{
case 4:
Expand Down Expand Up @@ -1471,7 +1514,10 @@ namespace Simd

// Computes the M x N cosine distances between descriptor sets A and B
// (given as arrays of pointers) by dispatching to one of two implementations.
void DescrInt::CosineDistancesMxNa(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, float* distances) const
{
    // One micro-row of unpacked B descriptors does not fit the L1 budget.
    const bool exceedsL1 = _unpSize * _microNu > Base::AlgCacheL1();
    // Too few B descriptors relative to the unpack micro-kernel width.
    const bool tooFewColumns = N * 2 < _microNu;
    // Constant override: always selects the direct path regardless of the checks above.
    // NOTE(review): presumably temporary while the unpack path is under construction - confirm before removing.
    const bool forceDirect = true;

    if (exceedsL1 || tooFewColumns || forceDirect)
    {
        CosineDistancesDirect(M, N, A, B, distances);
        return;
    }
    CosineDistancesUnpack(M, N, A, B, distances);
}

void DescrInt::CosineDistancesMxNp(size_t M, size_t N, const uint8_t* A, const uint8_t* B, float* distances) const
Expand Down Expand Up @@ -1501,6 +1547,29 @@ namespace Simd
}
}

void DescrInt::CosineDistancesUnpack(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, float* distances) const
{
size_t macroM = AlignLoAny(Base::AlgCacheL2() / _unpSize, _microMu);
size_t macroN = AlignLoAny(Base::AlgCacheL3() / _unpSize, _microNu);
Array8u dA(Min(macroM, M) * _unpSize);
Array8u dB(Min(macroN, N) * _unpSize);
Array32f nA(Min(macroM, M) * 4);
Array32f nB(AlignHi(Min(macroN, N), _microNu) * 4);
for (size_t i = 0; i < M; i += macroM)
{
size_t dM = Simd::Min(M, i + macroM) - i;
_unpackNormA(dM, A + i, nA.data, 1);
//_unpackDataA(dM, A + i, _size, dA.data, 1);
for (size_t j = 0; j < N; j += macroN)
{
size_t dN = Simd::Min(N, j + macroN) - j;
_unpackNormB(dN, B + j, nB.data, dN);
//_unpackDataB(dN, B + j, _size, dB.data, _microNu);
//_macroCosineDistancesUnpack(dM, dN, dA.data, nA.data, dB.data, nB.data, _size, distances + i * N + j, N);
}
}
}

//-------------------------------------------------------------------------------------------------

void* DescrIntInit(size_t size, size_t depth)
Expand Down

0 comments on commit 307db98

Please sign in to comment.