Skip to content

Commit

Permalink
*improve AVX-512BW optimizations of function ConvolutionDirectNhwcCon…
Browse files Browse the repository at this point in the history
…volutionBiasActivationDepthwise.
  • Loading branch information
ermig1979 committed Oct 1, 2024
1 parent b6711aa commit b047877
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 66 deletions.
6 changes: 4 additions & 2 deletions docs/2024.html
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ <h5>New features</h5>
<li>AMX-BF16 (AVX-512VBMI) optimizations of function DeinterleaveBgr.</li>
<li>AMX-BF16 (AVX-512VBMI) optimizations of function DeinterleaveBgra.</li>
</ul>

<h5>Improving</h5>
<ul>
<li>AVX-512BW optimizations of function ConvolutionDirectNhwcConvolutionBiasActivationDepthwise.</li>
</ul>
<h5>Removing</h5>
<ul>
<li>Base implementation, SSE4.1, AVX2, AVX-512BW, AMX-BF16 optimizations of class SynetConvolution32fBf16NhwcGemm.</li>
Expand All @@ -58,7 +61,6 @@ <h5>Removing</h5>
<li>Parameter 'compatibility' from function SynetMergedConvolution32fInit.</li>
</ul>


<h4>Test framework</h4>
<h5>New features</h5>
<ul>
Expand Down
53 changes: 22 additions & 31 deletions src/Simd/SimdAvx512bwSynetConvolution32fDirectNhwc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,6 @@ namespace Simd
for (size_t dy = 0; dy < p.dstH; ++dy)
{
size_t dx = 0;
#if 0
for (; dx < dstW2; dx += 2)
{
float* dst0 = dst + 0 * p.dstC, *dst1 = dst + 1 * p.dstC;
Expand All @@ -684,25 +683,19 @@ namespace Simd
d01 = _mm512_loadu_ps(bias + dc + 1 * F);
d02 = _mm512_loadu_ps(bias + dc + 2 * F);
d03 = _mm512_loadu_ps(bias + dc + 3 * F);
d10 = _mm512_loadu_ps(bias + dc + 0 * F);
d11 = _mm512_loadu_ps(bias + dc + 1 * F);
d12 = _mm512_loadu_ps(bias + dc + 2 * F);
d13 = _mm512_loadu_ps(bias + dc + 3 * F);
}
else
{
d00 = _mm512_setzero_ps();
d01 = _mm512_setzero_ps();
d02 = _mm512_setzero_ps();
d03 = _mm512_setzero_ps();
d10 = _mm512_setzero_ps();
d11 = _mm512_setzero_ps();
d12 = _mm512_setzero_ps();
d13 = _mm512_setzero_ps();
}
d10 = d00; d11 = d01; d12 = d02; d13 = d03;
for (size_t ky = 0; ky < p.kernelY; ++ky)
{
size_t sy = dy * p.strideY + ky * p.dilationY - p.padY;
const float* psy = src + sy * p.srcW * dstC + dc;
if (sy < p.srcH)
{
for (size_t kx = 0; kx < p.kernelX; ++kx)
Expand All @@ -711,20 +704,20 @@ namespace Simd
const float* pw = weight + (ky * p.kernelX + kx) * dstC + dc;
__mmask16 mask0 = sx + 0 * p.strideX < p.srcW ? 0xFFFF : 0x0000;
__mmask16 mask1 = sx + 1 * p.strideX < p.srcW ? 0xFFFF : 0x0000;
const float* ps0 = src + (sy * p.srcW + sx + 0) * dstC + dc, *ps1 = ps0 + dstC;
const float* ps0 = psy + sx * dstC, *ps1 = ps0 + dstC;

w0 = _mm512_loadu_ps(pw + 0 * F);
d00 = _mm512_mask_fmadd_ps(d00, mask0, _mm512_maskz_loadu_ps(mask0, ps0 + 0 * F), w0);
d10 = _mm512_mask_fmadd_ps(d10, mask1, _mm512_maskz_loadu_ps(mask1, ps1 + 0 * F), w0);
d00 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask0, ps0 + 0 * F), w0, d00, mask0);
d10 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask1, ps1 + 0 * F), w0, d10, mask1);
w0 = _mm512_loadu_ps(pw + 1 * F);
d01 = _mm512_mask_fmadd_ps(d01, mask0, _mm512_maskz_loadu_ps(mask0, ps0 + 1 * F), w0);
d11 = _mm512_mask_fmadd_ps(d11, mask1, _mm512_maskz_loadu_ps(mask1, ps1 + 1 * F), w0);
d01 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask0, ps0 + 1 * F), w0, d01, mask0);
d11 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask1, ps1 + 1 * F), w0, d11, mask1);
w0 = _mm512_loadu_ps(pw + 2 * F);
d02 = _mm512_mask_fmadd_ps(d02, mask0, _mm512_maskz_loadu_ps(mask0, ps0 + 2 * F), w0);
d12 = _mm512_mask_fmadd_ps(d12, mask1, _mm512_maskz_loadu_ps(mask1, ps1 + 2 * F), w0);
d02 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask0, ps0 + 2 * F), w0, d02, mask0);
d12 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask1, ps1 + 2 * F), w0, d12, mask1);
w0 = _mm512_loadu_ps(pw + 3 * F);
d03 = _mm512_mask_fmadd_ps(d03, mask0, _mm512_maskz_loadu_ps(mask0, ps0 + 0 * F), w0);
d13 = _mm512_mask_fmadd_ps(d13, mask1, _mm512_maskz_loadu_ps(mask1, ps1 + 0 * F), w0);
d03 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask0, ps0 + 3 * F), w0, d03, mask0);
d13 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask1, ps1 + 3 * F), w0, d13, mask1);
}
}
}
Expand All @@ -743,19 +736,17 @@ namespace Simd
{
d00 = _mm512_loadu_ps(bias + dc + 0 * F);
d01 = _mm512_loadu_ps(bias + dc + 1 * F);
d10 = _mm512_loadu_ps(bias + dc + 0 * F);
d11 = _mm512_loadu_ps(bias + dc + 1 * F);
}
else
{
d00 = _mm512_setzero_ps();
d01 = _mm512_setzero_ps();
d10 = _mm512_setzero_ps();
d11 = _mm512_setzero_ps();
}
d10 = d00; d11 = d01;
for (size_t ky = 0; ky < p.kernelY; ++ky)
{
size_t sy = dy * p.strideY + ky * p.dilationY - p.padY;
const float* psy = src + sy * p.srcW * dstC + dc;
if (sy < p.srcH)
{
for (size_t kx = 0; kx < p.kernelX; ++kx)
Expand All @@ -764,14 +755,14 @@ namespace Simd
const float* pw = weight + (ky * p.kernelX + kx) * dstC + dc;
__mmask16 mask0 = sx + 0 * p.strideX < p.srcW ? 0xFFFF : 0x0000;
__mmask16 mask1 = sx + 1 * p.strideX < p.srcW ? 0xFFFF : 0x0000;
const float* ps0 = src + (sy * p.srcW + sx + 0) * dstC + dc, * ps1 = ps0 + dstC;
const float* ps0 = psy + sx * dstC, * ps1 = ps0 + dstC;

w0 = _mm512_loadu_ps(pw + 0 * F);
d00 = _mm512_mask_fmadd_ps(d00, mask0, _mm512_maskz_loadu_ps(mask0, ps0 + 0 * F), w0);
d10 = _mm512_mask_fmadd_ps(d10, mask1, _mm512_maskz_loadu_ps(mask1, ps1 + 0 * F), w0);
d00 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask0, ps0 + 0 * F), w0, d00, mask0);
d10 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask1, ps1 + 0 * F), w0, d10, mask1);
w0 = _mm512_loadu_ps(pw + 1 * F);
d01 = _mm512_mask_fmadd_ps(d01, mask0, _mm512_maskz_loadu_ps(mask0, ps0 + 1 * F), w0);
d11 = _mm512_mask_fmadd_ps(d11, mask1, _mm512_maskz_loadu_ps(mask1, ps1 + 1 * F), w0);
d01 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask0, ps0 + 1 * F), w0, d01, mask0);
d11 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask1, ps1 + 1 * F), w0, d11, mask1);
}
}
}
Expand All @@ -788,6 +779,7 @@ namespace Simd
for (size_t ky = 0; ky < p.kernelY; ++ky)
{
size_t sy = dy * p.strideY + ky * p.dilationY - p.padY;
const float* psy = src + sy * p.srcW * dstC + dc;
if (sy < p.srcH)
{
for (size_t kx = 0; kx < p.kernelX; ++kx)
Expand All @@ -796,11 +788,11 @@ namespace Simd
const float* pw = weight + (ky * p.kernelX + kx) * dstC + dc;
__mmask16 mask0 = sx + 0 * p.strideX < p.srcW ? tailC : 0x0000;
__mmask16 mask1 = sx + 1 * p.strideX < p.srcW ? tailC : 0x0000;
const float* ps0 = src + (sy * p.srcW + sx + 0) * dstC + dc, * ps1 = ps0 + dstC;
const float* ps0 = psy + sx * dstC, * ps1 = ps0 + dstC;

w0 = _mm512_maskz_loadu_ps(tailC, pw);
d00 = _mm512_mask_fmadd_ps(d00, mask0, _mm512_maskz_loadu_ps(mask0, ps0 + 0 * F), w0);
d10 = _mm512_mask_fmadd_ps(d10, mask1, _mm512_maskz_loadu_ps(mask1, ps1 + 0 * F), w0);
d00 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask0, ps0 + 0 * F), w0, d00, mask0);
d10 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask1, ps1 + 0 * F), w0, d10, mask1);
}
}
}
Expand All @@ -809,7 +801,6 @@ namespace Simd
}
dst += 2 * p.dstC;
}
#endif
for (; dx < p.dstW; ++dx)
{
size_t dc = 0;
Expand Down
33 changes: 0 additions & 33 deletions src/Test/TestSynetConvolution32f.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,44 +276,11 @@ namespace Test
return result;
}

static float Sum(const float* src, size_t size, __mmask16 mask)
{
__m512 sum = _mm512_setzero_ps();
for (size_t i = 0; i < size; i += 16)
sum = _mm512_mask_add_ps(sum, mask, sum, _mm512_maskz_load_ps(mask, src + i));
float dst[16] = { 0 };
_mm512_storeu_ps(dst, sum);
return dst[0];
}

void ZeroLoadTest()
{
size_t n = 1024 * 64 * 16;
std::vector<float> src(n, 0);
float dst[16] = { 0 };
for (size_t j = 0; j < 100; j++)
{
{
TEST_PERFORMANCE_TEST("zero_mask_load_add")
dst[0] += Sum(src.data(), n, 0);
}
{
TEST_PERFORMANCE_TEST("full_mask_load")
dst[0] += Sum(src.data(), n, -1);
}
}
std::cout << dst[0];
std::cout << Test::PerformanceMeasurerStorage::s_storage.ConsoleReport(false, true);
exit(0);
}

bool SynetConvolution32fForwardAutoTest()
{
const float EPS = 0.001f;
bool result = true;

//ZeroLoadTest();

if(TestBase())
result = result && SynetConvolution32fForwardAutoTest(2 * EPS, FUNC_C(Simd::Base::SynetConvolution32fInit), FUNC_C(SimdSynetConvolution32fInit));

Expand Down

0 comments on commit b047877

Please sign in to comment.