Skip to content

Commit

Permalink
+add macro SIMD_CONV16B_COMPATIBLE_WITH_AMX.
Browse files Browse the repository at this point in the history
  • Loading branch information
ermig1979 committed Oct 3, 2024
1 parent bda5b58 commit 9699a32
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 9 deletions.
136 changes: 136 additions & 0 deletions src/Simd/SimdAvx2SynetConvolution16bNhwcDirect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,132 @@ namespace Simd

//-----------------------------------------------------------------------------------------

static void Convert16bNhwcDirect2(const uint8_t* src8, const ConvParam& p, const AlgParam& a, size_t dyBeg, size_t dyEnd, uint16_t* dst)
{
assert(a.microC == DF * 2);
const float* src = (float*)src8;
size_t srcCDF = Simd::AlignLo(p.srcC, DF), srcCQF = Simd::AlignLo(p.srcC, QF), tailC0 = srcCDF - srcCQF, tailC1 = p.srcC - srcCDF;
size_t syPad = p.kernelY - 1 - p.padY, syBeg, syEnd = (dyEnd == p.dstH ? p.srcH : dyEnd + syPad);
size_t cD = a.batch * a.srcH * a.srcW, sD = a.microC;
if (dyBeg == 0)
{
for (size_t s = 0, n = p.padY * a.srcW; s < n; ++s)
for (size_t c = 0; c < a.srcC; c += a.microC)
Avx2::SetZero2(dst + c * cD + s * sD);
dst += p.padY * a.srcW * sD;
syBeg = 0;
}
else
{
syBeg = dyBeg + syPad;
src += syBeg * p.srcW * p.srcC;
dst += (dyBeg + p.kernelY - 1) * a.srcW * sD;
}
for (size_t sy = syBeg; sy < syEnd; ++sy)
{
if (p.padX)
{
for (size_t s = 0; s < p.padX; ++s)
for (size_t c = 0; c < a.srcC; c += a.microC)
Avx2::SetZero2(dst + c * cD + s * sD);
dst += p.padX * sD;
}
for (size_t sx = 0; sx < p.srcW; ++sx)
{
size_t sc = 0;
for (; sc < srcCQF; sc += QF)
{
Avx2::Float32ToBFloat16(src + sc + 0 * F, dst + sc * cD + 0 * F);
Avx2::Float32ToBFloat16(src + sc + 2 * F, dst + sc * cD + 2 * F);
}
if (tailC0)
Avx2::Float32ToBFloat16Tail(src + sc + 0 * F, tailC0, dst + sc * cD + 0 * F);
if (tailC1)
Avx2::Float32ToBFloat16Tail(src + sc + 2 * F, tailC1, dst + sc * cD + 2 * F);
src += p.srcC;
dst += sD;
}
if (p.padW)
{
for (size_t s = 0; s < p.padW; ++s)
for (size_t c = 0; c < a.srcC; c += a.microC)
Avx2::SetZero2(dst + c * cD + s * sD);
dst += p.padW * sD;
}
}
if (dyEnd == p.dstH)
{
for (size_t s = 0, n = p.padH * a.srcW; s < n; ++s)
for (size_t c = 0; c < a.srcC; c += a.microC)
Avx2::SetZero2(dst + c * cD + s * sD);
dst += p.padH * a.srcW * sD;
}
}

static void Reorder16bNhwcDirect2(const uint8_t* src8, const ConvParam& p, const AlgParam& a, size_t dyBeg, size_t dyEnd, uint16_t* dst)
{
assert(a.microC == DF * 2);
const uint16_t* src = (uint16_t*)src8;
size_t srcCDF = Simd::AlignLo(p.srcC, DF), srcCQF = Simd::AlignLo(p.srcC, QF), tailC0 = srcCDF - srcCQF, tailC1 = p.srcC - srcCDF;
size_t syPad = p.kernelY - 1 - p.padY, syBeg, syEnd = (dyEnd == p.dstH ? p.srcH : dyEnd + syPad);
size_t cD = a.batch * a.srcH * a.srcW, sD = a.microC;
if (dyBeg == 0)
{
for (size_t s = 0, n = p.padY * a.srcW; s < n; ++s)
for (size_t c = 0; c < a.srcC; c += a.microC)
Avx2::SetZero2(dst + c * cD + s * sD);
dst += p.padY * a.srcW * sD;
syBeg = 0;
}
else
{
syBeg = dyBeg + syPad;
src += syBeg * p.srcW * p.srcC;
dst += (dyBeg + p.kernelY - 1) * a.srcW * sD;
}
for (size_t sy = syBeg; sy < syEnd; ++sy)
{
if (p.padX)
{
for (size_t s = 0; s < p.padX; ++s)
for (size_t c = 0; c < a.srcC; c += a.microC)
Avx2::SetZero2(dst + c * cD + s * sD);
dst += p.padX * sD;
}
for (size_t sx = 0; sx < p.srcW; ++sx)
{
size_t sc = 0;
for (; sc < srcCQF; sc += QF)
{
Avx2::Copy(src + sc + 0 * F, dst + sc * cD + 0 * F);
Avx2::Copy(src + sc + 2 * F, dst + sc * cD + 2 * F);
}
if (tailC0)
Avx2::Copy(src + sc + 0 * F, tailC0, dst + sc * cD + 0 * F);
if (tailC1)
Avx2::Copy(src + sc + 2 * F, tailC1, dst + sc * cD + 2 * F);
src += p.srcC;
dst += sD;
}
if (p.padW)
{
for (size_t s = 0; s < p.padW; ++s)
for (size_t c = 0; c < a.srcC; c += a.microC)
Avx2::SetZero2(dst + c * cD + s * sD);
dst += p.padW * sD;
}
}
if (dyEnd == p.dstH)
{
for (size_t s = 0, n = p.padH * a.srcW; s < n; ++s)
for (size_t c = 0; c < a.srcC; c += a.microC)
Avx2::SetZero2(dst + c * cD + s * sD);
dst += p.padH * a.srcW * sD;
}
}

//-----------------------------------------------------------------------------------------

template<int M> void Convolution16bNhwcDirect_2xM(const uint16_t* src0, const ConvParam& p, const AlgParam& a, size_t srcC, size_t dstC, int zero, const uint16_t* weight0, float* dst)
{
__m256 d00, d01, d10, d11, d20, d21, d30, d31, d40, d41, s0, w00, w01, w10, w11, m = _mm256_castsi256_ps(Bf16::MASK);
Expand Down Expand Up @@ -422,14 +548,24 @@ namespace Simd

//-----------------------------------------------------------------------------------------

//#define SIMD_CONV16B_COMPATIBLE_WITH_AMX

SynetConvolution16bNhwcDirect::SynetConvolution16bNhwcDirect(const ConvParam & p)
: Sse41::SynetConvolution16bNhwcDirect(p)
{
#ifdef SIMD_CONV16B_COMPATIBLE_WITH_AMX
SetAlgParam(F, F * 2, 5, F * 4, Base::AlgCacheL1(), Base::AlgCacheL2(), Base::AlgCacheL3());
if (_src16b)
_preprocess = Reorder16bNhwcDirect2;
else
_preprocess = Convert16bNhwcDirect2;
#else
SetAlgParam(F, F * 2, 5, F * 2, Base::AlgCacheL1(), Base::AlgCacheL2(), Base::AlgCacheL3());
if (_src16b)
_preprocess = Reorder16bNhwcDirect;
else
_preprocess = Convert16bNhwcDirect;
#endif
_convolution = Convolution16bNhwcDirect_2;
switch (p.activation)
{
Expand Down
6 changes: 6 additions & 0 deletions src/Simd/SimdSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,12 @@ namespace Simd
{
_mm256_storeu_si256((__m256i*)dst, _mm256_setzero_si256());
}

SIMD_INLINE void SetZero2(uint16_t* dst)
{
_mm256_storeu_si256((__m256i*)dst + 0, _mm256_setzero_si256());
_mm256_storeu_si256((__m256i*)dst + 1, _mm256_setzero_si256());
}
}
#endif

Expand Down
14 changes: 5 additions & 9 deletions src/Test/TestSynetConvolution16b.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ namespace Test
buf8u1.Extend({ ::SimdSynetConvolution16bExternalBufferSize(context1) });
buf8u2.Extend({ ::SimdSynetConvolution16bExternalBufferSize(context2) });
Fill(buf8u1, uint8_t(1));
Fill(buf8u2, uint8_t(2));
Fill(buf8u2, uint8_t(-1));

::SimdSynetConvolution16bSetParams(context1, weight.Data(), bias.Data(), params.Data());
::SimdSynetConvolution16bSetParams(context2, weight.Data(), bias.Data(), params.Data());
Expand Down Expand Up @@ -272,17 +272,13 @@ namespace Test
result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 55, 15, 15, 55, _1, _1, _1, _0, _0, 1, aPr, tF, b16, b16), c, f1, f2);
#endif
#if 1
result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 608, 8, 8, 608, _1, _1, _1, _0, _0, 1, aPr, tT, b16, b16), c, f1, f2);
result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 608, 8, 8, 608, _1, _1, _1, _0, _0, 1, aPr, tT, f32, b16), c, f1, f2);
result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 48, 192, 256, 48, _3, _1, _1, _1, _1, 1, aRe, tT, f32, b16), c, f1, f2);
//result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 608, 8, 8, 608, _1, _1, _1, _0, _0, 1, aPr, tT, b16, b16), c, f1, f2);
//result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 608, 8, 8, 608, _1, _1, _1, _0, _0, 1, aPr, tT, f32, b16), c, f1, f2);
#endif

#else
result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 2156, 4, 4, 4, _1, _1, _1, _0, _0, 1, aId, tF, b16, b16), c, f1, f2);
//result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 64, 88, 88, 128, _3, _1, _2, _1, _1, 1, aRe, tT, f32, b16), c, f1, f2);
//result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 64, 88, 88, 128, _3, _1, _2, _1, _1, 1, aRe, tT, f32, b16), c, f1, f2);
//result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 64, 88, 88, 64, _3, _1, _1, _1, _1, 1, aSw, tT, b16, f32), c, f1, f2);
//result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 256, 44, 44, 256, _1, _1, _1, _0, _0, 1, aRe, tT, b16, f32), c, f1, f2);

result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 48, 192, 256, 48, _3, _1, _1, _1, _1, 1, aRe, tT, f32, b16), c, f1, f2);
#endif

return result;
Expand Down

0 comments on commit 9699a32

Please sign in to comment.