Skip to content

Commit

Permalink
+add functions Base::InnerProduct16bGemmNN_ConvertBn, Base::InnerProduct16bGemmNN_ConvertBt.
Browse files Browse the repository at this point in the history
  • Loading branch information
ermig1979 committed Jun 11, 2024
1 parent f165525 commit 2263558
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 37 deletions.
20 changes: 11 additions & 9 deletions src/Simd/SimdBaseSynetInnerProduct16b.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,35 +87,37 @@ namespace Simd
// Reference matrix product with bias on bfloat16 inputs.
// For each of the p.M rows of A, every output element j in [0, p.N) is the sum over k in [0, p.K)
// of A and B values widened with BFloat16ToFloat32, after which _bias[j] is added.
// B is read as B[j * p.K + k] when p.transB is set and as B[k * p.N + j] otherwise.
// NOTE(review): this listing is a rendered diff with the removed and added lines interleaved and no
// +/- markers. The lines using pC / pA / 'a' appear to be the pre-commit version (absolute row
// indexing), the lines using Af / C the post-commit version (row converted once into Af, A and C
// advanced per row). As shown here the braces do not balance, so it is not compilable verbatim.
void SynetInnerProduct16bRef::GemmAndBias(const uint16_t* A, const uint16_t* B, float* C)
{
const InnerProductParam16b& p = _param;
// Scratch buffer of p.K elements holding the current row of A converted to 32-bit float.
Array32f Af(p.K);
for (size_t i = 0; i < p.M; ++i)
{
float* pC = C + i * p.N;
// Widen the current row of A once so the inner loops do not repeat the conversion.
for (size_t k = 0; k < p.K; ++k)
Af[k] = BFloat16ToFloat32(A[k]);
if (p.transB)
{
// Transposed B: output j is the dot product of the A row with the p.K values starting at B + j * p.K.
for (size_t j = 0; j < p.N; ++j)
{
const uint16_t* pA = A + i * p.K;
const uint16_t* pB = B + j * p.K;
pC[j] = 0;
C[j] = 0;
for (size_t k = 0; k < p.K; ++k)
pC[j] += BFloat16ToFloat32(pA[k]) * BFloat16ToFloat32(pB[k]);
C[j] += Af[k] * BFloat16ToFloat32(pB[k]);
}
}
else
{
// Non-transposed B: clear the output row, then add one scaled row of B (p.N values at B + k * p.N) per k.
for (size_t j = 0; j < p.N; ++j)
pC[j] = 0.0;
C[j] = 0.0;
for (size_t k = 0; k < p.K; ++k)
{
const uint16_t* pB = B + k * p.N;
float a = BFloat16ToFloat32(A[i * p.K + k]);
for (size_t j = 0; j < p.N; ++j)
pC[j] += a * BFloat16ToFloat32(pB[j]);
C[j] += Af[k] * BFloat16ToFloat32(pB[j]);
}
}
// Add the per-output bias to the finished row.
for (size_t j = 0; j < p.N; ++j)
pC[j] += _bias[j];
}
C[j] += _bias[j];
// Step both pointers to the next row (p.K inputs, p.N outputs).
A += p.K;
C += p.N;
}
}

//-------------------------------------------------------------------------------------------------
Expand Down
84 changes: 62 additions & 22 deletions src/Simd/SimdBaseSynetInnerProduct16bGemmNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,61 @@ namespace Simd
#if defined(SIMD_SYNET_ENABLE)
namespace Base
{
typedef Base::SynetInnerProduct16bGemmNN::AlgParam AlgParam;

//-----------------------------------------------------------------------------------------

// Repacks a non-transposed 32-bit float matrix B into bfloat16.
// The source element for row r and column c is read at src[r * p.N + c].
// Output order: for every panel of a.F columns, for every pair of rows (r, r + 1) with r stepping
// by 2 up to a.aK, for every column of the panel, two consecutive values are written (row r, then
// row r + 1). Positions whose column is >= p.N or whose row is >= p.K are written as 0.
//   src8 - source matrix, reinterpreted as const float*.
//   p    - inner product parameters; p.N and p.K are read.
//   a    - algorithm parameters; a.F (panel width) and a.aK (row bound) are read.
//   size, K - not read by this implementation.
//   dst  - output buffer; DivHi(p.N, a.F) * (a.aK rounded up to even) * a.F values are written.
static void InnerProduct16bGemmNN_ConvertBn(const uint8_t* src8, const InnerProductParam16b& p, const AlgParam& a, size_t size, size_t K, uint16_t* dst)
{
    const float* const weights = reinterpret_cast<const float*>(src8);
    const size_t panels = DivHi(p.N, a.F);
    for (size_t panel = 0; panel < panels; ++panel)
    {
        const size_t firstCol = panel * a.F;
        for (size_t row = 0; row < a.aK; row += 2)
        {
            for (size_t lane = 0; lane < a.F; ++lane)
            {
                const size_t col = firstCol + lane;
                // Emit the two vertically adjacent elements of this column back to back.
                for (size_t r = row; r < row + 2; ++r, ++dst)
                {
                    if (col < p.N && r < p.K)
                        *dst = Float32ToBFloat16(weights[r * p.N + col]);
                    else
                        *dst = 0;
                }
            }
        }
    }
}

// Repacks a transposed 32-bit float matrix B into bfloat16.
// The source element for row r and column c is read at src[c * p.K + r].
// Output order: for every panel of a.F columns, for every pair of rows (r, r + 1) with r stepping
// by 2 up to a.aK, for every column of the panel, two consecutive values are written (row r, then
// row r + 1). Positions whose column is >= p.N or whose row is >= p.K are written as 0.
//   src8 - source matrix, reinterpreted as const float*.
//   p    - inner product parameters; p.N and p.K are read.
//   a    - algorithm parameters; a.F (panel width) and a.aK (row bound) are read.
//   size, K - not read by this implementation.
//   dst  - output buffer; DivHi(p.N, a.F) * (a.aK rounded up to even) * a.F values are written.
static void InnerProduct16bGemmNN_ConvertBt(const uint8_t* src8, const InnerProductParam16b& p, const AlgParam& a, size_t size, size_t K, uint16_t* dst)
{
    const float* const weights = reinterpret_cast<const float*>(src8);
    const size_t panels = DivHi(p.N, a.F);
    for (size_t panel = 0; panel < panels; ++panel)
    {
        const size_t firstCol = panel * a.F;
        for (size_t row = 0; row < a.aK; row += 2)
        {
            for (size_t lane = 0; lane < a.F; ++lane)
            {
                const size_t col = firstCol + lane;
                // Emit the two elements of this column that are adjacent along K back to back.
                for (size_t r = row; r < row + 2; ++r, ++dst)
                {
                    if (col < p.N && r < p.K)
                        *dst = Float32ToBFloat16(weights[col * p.K + r]);
                    else
                        *dst = 0;
                }
            }
        }
    }
}

//-----------------------------------------------------------------------------------------

// Predicate over the inner product parameters; presumably used to decide whether this GEMM-NN
// implementation is selected - confirm at the call site.
// NOTE(review): two return statements appear because this listing interleaves the removed and the
// added line of the diff. The first (constB && !transB) appears to be the pre-commit condition and
// the second (constB || typeB == SimdTensorData32f) the post-commit one. Read literally, only the
// first return is reachable.
bool SynetInnerProduct16bGemmNN::Preferable(const InnerProductParam16b& p)
{
return p.constB == SimdTrue && p.transB == SimdFalse;
return p.constB == SimdTrue || p.typeB == SimdTensorData32f;
}

SynetInnerProduct16bGemmNN::SynetInnerProduct16bGemmNN(const InnerProductParam16b& p)
Expand All @@ -44,6 +96,13 @@ namespace Simd
, _gemm(0)
, _post(0)
{
if (p.typeB == SimdTensorData32f || p.constB)
{
if (p.transB)
_prepB = InnerProduct16bGemmNN_ConvertBt;
else
_prepB = InnerProduct16bGemmNN_ConvertBn;
}
}

String SynetInnerProduct16bGemmNN::Desc() const
Expand Down Expand Up @@ -86,26 +145,7 @@ namespace Simd
{
assert(weight);
_weight.Resize(a.aK * a.aN, true);
size_t N = DivHi(p.N, _alg.F);
uint16_t* dst = _weight.data;
for (size_t n = 0; n < N; n++)
{
for (size_t k = 0; k < a.aK; k += 2)
{
const float* src = weight + k * p.N + n * _alg.F;
for (size_t f = 0; f < _alg.F; ++f)
{
for (size_t i = 0; i < 2; ++i)
{
if (n * _alg.F + f < p.N && k + i < p.K)
*(dst++) = Float32ToBFloat16(src[i * p.N]);
else
*(dst++) = 0;
}
src++;
}
}
}
_prepB((uint8_t*)weight, p, a, p.N, p.K, _weight.data);
}
_bias.Resize(a.aN, true);
if (p.bias && bias)
Expand Down Expand Up @@ -134,7 +174,7 @@ namespace Simd
size_t offsC = _sizeC ? 0 : i * a.cN + j;
if (j == 0 && k == 0 && _prepA)
_prepA(A + i * p.K * a.eA, p, a, macroM, p.K, bufA + offsA);
if (i == 0 && _prepB)
if (i == 0 && _prepB && !p.constB)
_prepB(B + (k * p.N + j) * a.eB, p, a, macroN, macroK, bufB + offsB);
_gemm(bufA + offsA + k, p, a, macroM, macroN, macroK, (int)k, bufB + offsB, bufC + offsC);
if (k + macroK == p.K && _post)
Expand Down
54 changes: 49 additions & 5 deletions src/Simd/SimdSse41SynetInnerProduct16bGemmNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,18 @@ namespace Simd

//-----------------------------------------------------------------------------------------

static void InnerProduct16bGemmNN_ConvertA(const uint8_t* src8, const InnerProductParam16b& p, const AlgParam& a, size_t size, size_t K, uint16_t* dst)
static void InnerProduct16bGemmNN_ConvertA(const uint8_t* src8, const InnerProductParam16b& p, const AlgParam& a, size_t M, size_t K, uint16_t* dst)
{
const float* src = (float*)src8;
if (p.K == a.aK)
{
Float32ToBFloat16(src, K * size, dst);
Float32ToBFloat16(src, K * M, dst);
}
else
{
size_t KDF = Simd::AlignLo(p.K, DF);
size_t KF = Simd::AlignLo(p.K, F);
for (size_t i = 0; i < size; ++i)
for (size_t i = 0; i < M; ++i)
{
size_t k = 0;
for (; k < KDF; k += DF)
Expand All @@ -77,11 +77,11 @@ namespace Simd
}
}

static void InnerProduct16bGemmNN_ReorderA(const uint8_t* src8, const InnerProductParam16b& p, const AlgParam& a, size_t size, size_t K, uint16_t* dst)
static void InnerProduct16bGemmNN_ReorderA(const uint8_t* src8, const InnerProductParam16b& p, const AlgParam& a, size_t M, size_t K, uint16_t* dst)
{
const uint16_t* src = (uint16_t*)src8;
size_t KDF = Simd::AlignLo(p.K, DF);
for (size_t i = 0; i < size; ++i)
for (size_t i = 0; i < M; ++i)
{
size_t k = 0;
for (; k < KDF; k += DF)
Expand All @@ -97,6 +97,50 @@ namespace Simd

//-----------------------------------------------------------------------------------------

// Converts four floats from each of two rows (src and src + stride) to bfloat16 and stores them
// interleaved: each 32-bit output lane holds the row-0 value in its low 16 bits and the row-1 value
// in its high 16 bits, giving 8 uint16_t values at dst.
//   src    - pointer to four consecutive floats of the first row (unaligned load).
//   stride - distance in floats between the two rows.
//   dst    - destination for 8 bfloat16 values (unaligned store).
// Assumes Bf16::ROUND is the rounding increment, Base::Bf16::SHIFT is 16 and Bf16::MASK selects the
// upper 16 bits of every 32-bit lane - TODO confirm in SimdBFloat16.h.
SIMD_INLINE void ConvertBn(const float* src, int stride, uint16_t* dst)
{
    // Row 0: round, then move the upper half of each lane down into the lower half.
    __m128i d0 = _mm_srli_epi32(_mm_add_epi32(_mm_castps_si128(_mm_loadu_ps(src + 0 * stride)), Bf16::ROUND), Base::Bf16::SHIFT);
    // Row 1: round, then keep only the upper half of each lane. This must be an AND: an OR with the
    // mask cannot clear the low bits, so the merge below would corrupt the row-0 halves.
    __m128i d1 = _mm_and_si128(_mm_add_epi32(_mm_castps_si128(_mm_loadu_ps(src + 1 * stride)), Bf16::ROUND), Bf16::MASK);
    // The halves are now disjoint, so OR merges them without interference.
    _mm_storeu_si128((__m128i*)dst, _mm_or_si128(d0, d1));
}

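// NOTE(review): disabled draft. The body below still refers to 'size', which is not among this
// function's parameters (N, K), and it does not yet call ConvertBn(); it looks like a copy of
// InnerProduct16bGemmNN_ConvertA that has not been adapted - rework before enabling.
//static void InnerProduct16bGemmNN_ConvertBn(const uint8_t* src8, const InnerProductParam16b& p, const AlgParam& a, size_t N, size_t K, uint16_t* dst)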
//{
// const float* src = (float*)src8;
// if (p.K == a.aK)
// {
// Float32ToBFloat16(src, K * size, dst);
// }
// else
// {
// size_t KDF = Simd::AlignLo(p.K, DF);
// size_t KF = Simd::AlignLo(p.K, F);
// for (size_t i = 0; i < size; ++i)
// {
// size_t k = 0;
// for (; k < KDF; k += DF)
// {
// __m128i d0 = Float32ToBFloat16(_mm_loadu_ps(src + k + 0));
// __m128i d1 = Float32ToBFloat16(_mm_loadu_ps(src + k + F));
// _mm_storeu_si128((__m128i*)(dst + k), _mm_packus_epi32(d0, d1));
// }
// for (; k < KF; k += F)
// {
// __m128i d0 = Float32ToBFloat16(_mm_loadu_ps(src + k));
// _mm_storel_epi64((__m128i*)(dst + k), _mm_packus_epi32(d0, K_ZERO));
// }
// for (; k < p.K; ++k)
// dst[k] = Base::Float32ToBFloat16(src[k]);
// for (; k < a.aK; ++k)
// dst[k] = 0;
// src += p.K;
// dst += a.aK;
// }
// }
//}

//-----------------------------------------------------------------------------------------

template<int M> void InnerProduct16bGemmNN_2xM(const uint16_t* A0, const InnerProductParam16b& p, const AlgParam& a,
size_t N, size_t K, int update, const uint16_t* B0, float* C)
{
Expand Down
1 change: 1 addition & 0 deletions src/Simd/SimdSynetInnerProduct16bCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#define __SimdSynetInnerProduct16bCommon_h__

#include "Simd/SimdStore.h"
#include "Simd/SimdBFloat16.h"

namespace Simd
{
Expand Down
3 changes: 2 additions & 1 deletion src/Test/TestSynetInnerProduct16b.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ namespace Test
::SimdSynetInnerProduct32fForward(context3, Af.Data(), C3f.Data());
::SimdRelease(context3);

result = result && Compare(C1f, C3f, 0.038, true, 64, DifferenceBoth, " Compare to SynetInnerProduct32f.");//0.129
result = result && Compare(C1f, C3f, 0.039, true, 64, DifferenceBoth, " Compare to SynetInnerProduct32f.");
}

return result;
Expand Down Expand Up @@ -154,6 +154,7 @@ namespace Test
result = result && SynetInnerProduct16bForwardAutoTest(eps, Param(128, 128, 128, f32, f32, b16, t, f, t), f1, f2);
result = result && SynetInnerProduct16bForwardAutoTest(eps, Param(128, 128, 128, b16, b16, f32, t, t, f), f1, f2);
result = result && SynetInnerProduct16bForwardAutoTest(eps, Param(128, 128, 128, b16, b16, b16, f, t, f), f1, f2);
result = result && SynetInnerProduct16bForwardAutoTest(eps, Param(128, 128, 128, b16, b16, b16, t, f, t), f1, f2);
#endif
#else
result = result && SynetInnerProduct16bForwardAutoTest(eps, Param(128, 128, 128, b16, b16, b16, f, t, f), f1, f2);
Expand Down

0 comments on commit 2263558

Please sign in to comment.