Skip to content

Commit

Permalink
*fix bug: case of unaligned input size for BF16.
Browse files Browse the repository at this point in the history
  • Loading branch information
ermig1979 committed May 31, 2024
1 parent 95ff4fc commit 1a6db0d
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 15 deletions.
39 changes: 37 additions & 2 deletions src/Simd/SimdAmxBf16SynetMergedConvolution16b.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "Simd/SimdBFloat16.h"
#include "Simd/SimdAmxBf16.h"
#include "Simd/SimdCpu.h"
#include "Simd/SimdCopy.h"

namespace Simd
{
Expand Down Expand Up @@ -82,6 +83,30 @@ namespace Simd
}
}

static void ReorderBf16(const uint8_t* src8, const ConvParam& p, const AlgParam& a, size_t yBeg, size_t yEnd, uint16_t* dst)
{
const uint16_t* src = (uint16_t*)src8;
size_t bufH = a.bufH[0], mask = bufH - 1;
size_t srcC = AlignHi(p.srcC, a.miK);
size_t srcCDF = Simd::AlignLo(p.srcC, DF);
__mmask32 tailC = TailMask32(p.srcC - srcCDF);
for (size_t y = yBeg; y < yEnd; ++y)
{
const uint16_t* ps = src + y * p.srcW * p.srcC;
uint16_t* pd = dst + (y & mask) * p.srcW * srcC;
for (size_t x = 0; x < p.srcW; ++x)
{
size_t c = 0;
for (; c < srcCDF; c += DF)
Avx512bw::Copy(ps + c, pd + c);
if (tailC)
Avx512bw::Copy(ps + c, pd + c, tailC);
ps += p.srcC;
pd += srcC;
}
}
}

//-------------------------------------------------------------------------------------------------

SynetMergedConvolution16bCdc::SynetMergedConvolution16bCdc(const MergConvParam& p)
Expand All @@ -90,7 +115,12 @@ namespace Simd
if (p.conv[2].dstC > HF)
{
SetSize(Avx512bw::F, Avx512bw::DF);
_convert = ConvertFp32ToBf16;
if (!_src16b)
_convert = ConvertFp32ToBf16;
else if (!Aligned(p.conv[0].srcC, Avx512bw::DF))
_convert = ReorderBf16;
else
_convert = NULL;
if (_param.conv[0].Is1x1())
SetInput(_param.conv[0], _input);
else
Expand All @@ -108,7 +138,12 @@ namespace Simd
if (p.conv[1].dstC > HF)
{
SetSize(Avx512bw::F, Avx512bw::DF);
_convert = ConvertFp32ToBf16;
if (!_src16b)
_convert = ConvertFp32ToBf16;
else if (!Aligned(p.conv[0].srcC, Avx512bw::DF))
_convert = ReorderBf16;
else
_convert = NULL;
if (_param.conv[0].Is1x1())
SetInput(_param.conv[0], _input);
else
Expand Down
36 changes: 35 additions & 1 deletion src/Simd/SimdAvx2SynetMergedConvolution16b.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "Simd/SimdUpdate.h"
#include "Simd/SimdAvx2.h"
#include "Simd/SimdCpu.h"
#include "Simd/SimdCopy.h"

namespace Simd
{
Expand Down Expand Up @@ -91,14 +92,43 @@ namespace Simd
}
}

// Copies rows [yBeg, yEnd) of a 16-bit source image into the intermediate buffer 'dst',
// widening the per-pixel channel stride from p.srcC to AlignHi(p.srcC, a.miK) and
// writing zeros into the added padding elements.
//   src8       - source data, reinterpreted as 16-bit elements.
//   p          - convolution parameters (srcW and srcC are read).
//   a          - algorithm parameters (bufH[0] and miK are read).
//   yBeg, yEnd - half-open range of source rows to process.
//   dst        - destination buffer; the row index wraps as (y & (a.bufH[0] - 1)).
// NOTE(review): the wrap assumes a.bufH[0] is a power of two - confirm against SetSize.
static void ReorderBf16(const uint8_t* src8, const ConvParam& p, const AlgParam& a, size_t yBeg, size_t yEnd, uint16_t* dst)
{
    const uint16_t* src = reinterpret_cast<const uint16_t*>(src8);
    const size_t dstC = AlignHi(p.srcC, a.miK);
    const size_t rowMask = a.bufH[0] - 1;
    const size_t bodyC = Simd::AlignLo(p.srcC, DF);
    for (size_t y = yBeg; y < yEnd; ++y)
    {
        const uint16_t* ps = src + y * p.srcW * p.srcC;
        uint16_t* pd = dst + (y & rowMask) * p.srcW * dstC;
        for (size_t x = 0; x < p.srcW; ++x, ps += p.srcC, pd += dstC)
        {
            size_t c = 0;
            // Whole DF-sized chunks go through the vector Copy helper.
            while (c < bodyC)
            {
                Copy(ps + c, pd + c);
                c += DF;
            }
            // Leftover source channels are copied one element at a time.
            while (c < p.srcC)
            {
                pd[c] = ps[c];
                ++c;
            }
            // Padding up to the aligned channel count is zero-filled.
            while (c < dstC)
            {
                pd[c] = 0;
                ++c;
            }
        }
    }
}

//-----------------------------------------------------------------------------------------

SynetMergedConvolution16bCdc::SynetMergedConvolution16bCdc(const MergConvParam& p)
: Sse41::SynetMergedConvolution16bCdc(p)
{
SetSize(F, 2);
if(!_src16b)
if (!_src16b)
_convert = ConvertFp32ToBf16;
else if (!Aligned(p.conv[0].srcC, 2))
_convert = ReorderBf16;
else
_convert = NULL;
SetInput(_param.conv[0], _input);
SetDepthwise(_param.conv[1], _depthwise);
SetOutput(_param.conv[2], _output);
Expand All @@ -112,6 +142,10 @@ namespace Simd
SetSize(F, 2);
if (!_src16b)
_convert = ConvertFp32ToBf16;
else if (!Aligned(p.conv[0].srcC, 2))
_convert = ReorderBf16;
else
_convert = NULL;
SetInput(_param.conv[0], _input);
SetDepthwise(_param.conv[1], _depthwise);
}
Expand Down
34 changes: 33 additions & 1 deletion src/Simd/SimdAvx512bwSynetMergedConvolution16b.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
#include "Simd/SimdUpdate.h"
#include "Simd/SimdAvx512bw.h"
#include "Simd/SimdCpu.h"
#include "Simd/SimdCopy.h"
#include "Simd/SimdSet.h"

namespace Simd
{
Expand Down Expand Up @@ -82,14 +84,40 @@ namespace Simd
}
}

static void ReorderBf16(const uint8_t* src8, const ConvParam& p, const AlgParam& a, size_t yBeg, size_t yEnd, uint16_t* dst)
{
const uint16_t* src = (uint16_t*)src8;
size_t srcC = AlignHi(p.srcC, a.miK);
size_t bufH = a.bufH[0], mask = bufH - 1;
size_t srcCDF = Simd::AlignLo(p.srcC, DF);
__mmask32 srcMask = TailMask32(p.srcC - srcCDF), gapMask = TailMask32(srcC - p.srcC);
for (size_t y = yBeg; y < yEnd; ++y)
{
const uint16_t* ps = src + y * p.srcW * p.srcC;
uint16_t* pd = dst + (y & mask) * p.srcW * srcC;
for (size_t x = 0; x < p.srcW; ++x)
{
Copy(ps, srcCDF, srcMask, pd);
if (gapMask)
SetZero(pd + p.srcC, gapMask);
ps += p.srcC;
pd += srcC;
}
}
}

//-----------------------------------------------------------------------------------------

SynetMergedConvolution16bCdc::SynetMergedConvolution16bCdc(const MergConvParam& p)
: Avx2::SynetMergedConvolution16bCdc(p)
{
SetSize(F, 2);
if(!_src16b)
if (!_src16b)
_convert = ConvertFp32ToBf16;
else if (!Aligned(p.conv[0].srcC, 2))
_convert = ReorderBf16;
else
_convert = NULL;
SetInput(_param.conv[0], _input);
SetDepthwise(_param.conv[1], _depthwise);
SetOutput(_param.conv[2], _output);
Expand All @@ -103,6 +131,10 @@ namespace Simd
SetSize(F, 2);
if (!_src16b)
_convert = ConvertFp32ToBf16;
else if (!Aligned(p.conv[0].srcC, 2))
_convert = ReorderBf16;
else
_convert = NULL;
SetInput(_param.conv[0], _input);
SetDepthwise(_param.conv[1], _depthwise);
}
Expand Down
12 changes: 6 additions & 6 deletions src/Simd/SimdBaseSynetMergedConvolution16b.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -636,9 +636,9 @@ namespace Simd
size_t yEnd2 = Simd::RestrictRange(yBeg2 + a.yStep[2], a.yStart[2], c1.dstH);
size_t yEnd1 = Simd::RestrictRange(yBeg1 + a.yStep[1], a.yStart[1], c1.srcH);
size_t yEnd0 = Simd::RestrictRange(yBeg0 + a.yStep[0], a.yStart[0], c0.srcH);
if (!_src16b)
if (_convert)
_convert(src, c0, a, yBeg0, yEnd0, buf0);
const uint16_t* src16b = _src16b ? (uint16_t*)src : buf0;
const uint16_t* src16b = _convert ? buf0 : (uint16_t*)src;
_input(src16b, c0, a, maC, yBeg1, yEnd1, _weightI.data + c * a.dw[0],
_bias[0].data + c, _params[0].data + c * a.dp[0], buf1);
_depthwise((uint8_t*)buf1, c1, a, maC, yBeg2, yEnd2, _weightD.data + c * a.dw[1],
Expand Down Expand Up @@ -700,7 +700,7 @@ namespace Simd
a.yStart[0] = Simd::Min(a.yStart[1], c0.srcH);
a.bufH[0] = Pow2Hi(Simd::Max(a.yStep[1], a.yStart[0]));

_sizeB[0] = _src16b ? 0 : a.bufH[0] * p.conv[0].srcW * AlignHi(p.conv[0].srcC, a.miK);
_sizeB[0] = _src16b && Aligned(c0.srcC, a.miK) ? 0 : a.bufH[0] * p.conv[0].srcW * AlignHi(p.conv[0].srcC, a.miK);
_sizeB[1] = a.bufH[1] * p.conv[1].srcW * a.maC;
_sizeB[2] = a.bufH[2] * p.conv[1].dstW * a.maC;
if (_sizeB[0] * 2 + _sizeB[1] * 4 + _sizeB[2] * 2 <= L2)
Expand Down Expand Up @@ -746,9 +746,9 @@ namespace Simd
size_t yEnd2 = Simd::RestrictRange(yBeg2 + a.yStep[2], a.yStart[2], c1.dstH);
size_t yEnd1 = Simd::RestrictRange(yBeg1 + a.yStep[1], a.yStart[1], c1.srcH);
size_t yEnd0 = Simd::RestrictRange(yBeg0 + a.yStep[0], a.yStart[0], c0.srcH);
if(!_src16b)
if(_convert)
_convert(src, c0, a, yBeg0, yEnd0, buf0);
const uint16_t* src16b = _src16b ? (uint16_t*)src : buf0;
const uint16_t* src16b = _convert ? buf0 : (uint16_t*)src;
_input(src16b, c0, a, maC, yBeg1, yEnd1, _weightI.data + c * a.dw[0],
_bias[0].data + c, _params[0].data + c * a.dp[0], buf1);
_depthwise((uint8_t*)buf1, c1, a, maC, yBeg2, yEnd2, _weightD.data + c * a.dw[1],
Expand Down Expand Up @@ -802,7 +802,7 @@ namespace Simd
a.yStart[0] = Simd::Min(a.yStart[1], c0.srcH);
a.bufH[0] = Pow2Hi(Simd::Max(a.yStep[1], a.yStart[0]));

_sizeB[0] = _src16b ? 0 : a.bufH[0] * p.conv[0].srcW * AlignHi(p.conv[0].srcC, a.miK);
_sizeB[0] = _src16b && Aligned(c0.srcC, a.miK) ? 0 : a.bufH[0] * p.conv[0].srcW * AlignHi(p.conv[0].srcC, a.miK);
_sizeB[1] = a.bufH[1] * p.conv[1].srcW * a.maC;
if (_sizeB[0] * 2 + _sizeB[1] * 4 <= L2)
break;
Expand Down
36 changes: 35 additions & 1 deletion src/Simd/SimdSse41SynetMergedConvolution16b.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "Simd/SimdUpdate.h"
#include "Simd/SimdSse41.h"
#include "Simd/SimdCpu.h"
#include "Simd/SimdCopy.h"

namespace Simd
{
Expand Down Expand Up @@ -84,14 +85,43 @@ namespace Simd
}
}

// Repacks rows [yBeg, yEnd) of a 16-bit source image into the intermediate buffer 'dst':
// each pixel's p.srcC channels are copied and the pixel is then padded with zeros up to
// AlignHi(p.srcC, a.miK) channels.
//   src8       - source data, reinterpreted as 16-bit elements.
//   p          - convolution parameters (srcW and srcC are read).
//   a          - algorithm parameters (bufH[0] and miK are read).
//   yBeg, yEnd - half-open range of source rows to process.
//   dst        - destination buffer; the row index wraps as (y & (a.bufH[0] - 1)).
// NOTE(review): the wrap assumes a.bufH[0] is a power of two - confirm against SetSize.
static void ReorderBf16(const uint8_t* src8, const ConvParam& p, const AlgParam& a, size_t yBeg, size_t yEnd, uint16_t* dst)
{
    const uint16_t* src = reinterpret_cast<const uint16_t*>(src8);
    const size_t paddedC = AlignHi(p.srcC, a.miK);
    const size_t wrap = a.bufH[0] - 1;
    const size_t vectorC = Simd::AlignLo(p.srcC, DF);
    for (size_t y = yBeg; y < yEnd; ++y)
    {
        const uint16_t* srcPixel = src + y * p.srcW * p.srcC;
        uint16_t* dstPixel = dst + (y & wrap) * p.srcW * paddedC;
        for (size_t x = 0; x < p.srcW; ++x)
        {
            size_t c = 0;
            // Vector part: whole DF-sized chunks.
            for (; c < vectorC; c += DF)
                Copy(srcPixel + c, dstPixel + c);
            // Scalar part: the remaining source channels.
            for (; c < p.srcC; ++c)
                dstPixel[c] = srcPixel[c];
            // Zero padding up to the aligned channel count.
            for (; c < paddedC; ++c)
                dstPixel[c] = 0;
            srcPixel += p.srcC;
            dstPixel += paddedC;
        }
    }
}

//-----------------------------------------------------------------------------------------

SynetMergedConvolution16bCdc::SynetMergedConvolution16bCdc(const MergConvParam& p)
: Base::SynetMergedConvolution16bCdc(p)
{
SetSize(F, 2);
if(!_src16b)
if (!_src16b)
_convert = ConvertFp32ToBf16;
else if (!Aligned(p.conv[0].srcC, 2))
_convert = ReorderBf16;
else
_convert = NULL;
SetInput(_param.conv[0], _input);
SetDepthwise(_param.conv[1], _depthwise);
SetOutput(_param.conv[2], _output);
Expand All @@ -105,6 +135,10 @@ namespace Simd
SetSize(F, 2);
if (!_src16b)
_convert = ConvertFp32ToBf16;
else if (!Aligned(p.conv[0].srcC, 2))
_convert = ReorderBf16;
else
_convert = NULL;
SetInput(_param.conv[0], _input);
SetDepthwise(_param.conv[1], _depthwise);
}
Expand Down
9 changes: 5 additions & 4 deletions src/Test/TestSynetMergedConvolution16b.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ namespace Test
#if defined(NDEBUG)
#if 1
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 555, 40, 23), Cnv(a1, 1, 1, 256), Cnv(a0, 3, 1), f32, b16, c), f1, f2);
//result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 555, 40, 23), Cnv(a1, 1, 1, 256), Cnv(a0, 3, 1), b16, b16, c), f1, f2);
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 555, 40, 23), Cnv(a1, 1, 1, 256), Cnv(a0, 3, 1), b16, b16, c), f1, f2);
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 555, 40, 23), Cnv(a0, 3, 2), Cnv(a1, 1, 1, 1555), f32, f32, c), f1, f2);
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 1024, 8, 6), Cnv(a0, 1, 1, 1548), Cnv(a1, 3, 1), f32, f32, c), f1, f2);
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 256, 10, 6), Cnv(a0, 1, 1, 64), Cnv(a1, 3, 2), Cnv(a2, 1, 1, 256), f32, f32, c), f1, f2);
Expand All @@ -275,11 +275,12 @@ namespace Test
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 256, 10, 6), Cnv(a0, 1, 1, 64), Cnv(a1, 3, 2), Cnv(a2, 1, 1, 256), b16, b16, c), f1, f2);
#endif
#else
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 555, 40, 23), Cnv(a0, 3, 2), Cnv(a1, 1, 1, 1555), f32, f32, c), f1, f2);
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 1024, 8, 6), Cnv(a0, 1, 1, 1548), Cnv(a1, 3, 1), f32, f32, c), f1, f2);
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 555, 40, 23), Cnv(a1, 1, 1, 256), Cnv(a0, 3, 1), b16, b16, c), f1, f2);
//result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 555, 40, 23), Cnv(a0, 3, 2), Cnv(a1, 1, 1, 1555), f32, f32, c), f1, f2);
//result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 1024, 8, 6), Cnv(a0, 1, 1, 1548), Cnv(a1, 3, 1), f32, f32, c), f1, f2);
//result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 256, 10, 6), Cnv(a0, 1, 1, 64), Cnv(a1, 3, 2), Cnv(a2, 1, 1, 256), f32, f32, c), f1, f2);
//result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 64, 40, 23), Cnv(a0, 3, 2), Cnv(a1, 1, 1, 128), b16, b16, c), f1, f2);
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 1024, 8, 6), Cnv(a0, 1, 1, 1548), Cnv(a1, 3, 1), b16, b16, c), f1, f2);
//result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 1024, 8, 6), Cnv(a0, 1, 1, 1548), Cnv(a1, 3, 1), b16, b16, c), f1, f2);
//result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 256, 10, 6), Cnv(a0, 1, 1, 64), Cnv(a1, 3, 2), Cnv(a2, 1, 1, 256), b16, b16, c), f1, f2);
#endif
return result;
Expand Down

0 comments on commit 1a6db0d

Please sign in to comment.