*improve AVX-512BW optimizations of function ConvolutionDirectNhwcConvolutionBiasActivationDepthwise.
ermig1979 committed Oct 1, 2024
1 parent 76fcac5 commit 4660c1c
Showing 2 changed files with 204 additions and 17 deletions.
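For context, here is a minimal scalar sketch of the computation this kernel vectorizes: a depthwise convolution over NHWC data with per-channel bias and a fused activation. The struct, function name, and the ReLU stand-in below are illustrative only, not the library's API; the indexing mirrors the intrinsic code (src[(sy * srcW + sx) * dstC + c], weight[(ky * kernelX + kx) * dstC + c], zero padding handled by skipping out-of-image taps).

#include <algorithm>
#include <cstddef>

// Hypothetical parameter bundle; the library's ConvParam carries equivalent fields.
struct ConvParamSketch
{
    size_t srcH, srcW, dstH, dstW, dstC;            // depthwise: srcC == dstC
    size_t kernelY, kernelX, strideY, strideX;
    size_t dilationY, dilationX, padY, padX;
};

// Scalar reference: dst[dy][dx][c] = act(bias[c] + sum_{ky,kx} src[sy][sx][c] * weight[ky][kx][c]).
// ReLU stands in for the templated activation; taps outside the image contribute nothing.
static void DepthwiseNhwcReference(const float* src, const ConvParamSketch& p,
    const float* weight, const float* bias, float* dst)
{
    for (size_t dy = 0; dy < p.dstH; ++dy)
        for (size_t dx = 0; dx < p.dstW; ++dx)
            for (size_t c = 0; c < p.dstC; ++c)
            {
                float sum = bias ? bias[c] : 0.0f;
                for (size_t ky = 0; ky < p.kernelY; ++ky)
                {
                    size_t sy = dy * p.strideY + ky * p.dilationY - p.padY;     // wraps to a huge value above the image
                    if (sy >= p.srcH)
                        continue;
                    for (size_t kx = 0; kx < p.kernelX; ++kx)
                    {
                        size_t sx = dx * p.strideX + kx * p.dilationX - p.padX; // wraps left of the image
                        if (sx >= p.srcW)
                            continue;
                        sum += src[(sy * p.srcW + sx) * p.dstC + c] * weight[(ky * p.kernelX + kx) * p.dstC + c];
                    }
                }
                dst[(dy * p.dstW + dx) * p.dstC + c] = std::max(sum, 0.0f);     // placeholder activation
            }
}

The AVX-512BW kernel below evaluates this sum for F channels per __m512 register (16 floats); the change in this commit additionally blocks the work over 4 output pixels and up to 4 channel registers at a time (16 accumulators), falling back to the narrower existing paths for the remainders.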
9 changes: 9 additions & 0 deletions docs/2024.html
@@ -33,6 +33,15 @@ <h1>Simd Library Release Notes (2024).</h1>
<a href="2013.html">2013</a>
</center>

<a href="#HOME">Home</a>
<hr/>
<h3 id="R143">November X, 2024 (version X.X.143)</h3>
<h4>Algorithms</h4>
<h5>Improving</h5>
<ul>
<li>AVX-512BW optimizations of function ConvolutionDirectNhwcConvolutionBiasActivationDepthwise.</li>
</ul>

<a href="#HOME">Home</a>
<hr/>
<h3 id="R142">October 1, 2024 (version 6.1.142)</h3>
212 changes: 195 additions & 17 deletions src/Simd/SimdAvx512bwSynetConvolution32fDirectNhwc.cpp
@@ -665,15 +665,190 @@ namespace Simd

template<::SimdConvolutionActivationType type> void ConvolutionDirectNhwcConvolutionBiasActivationDepthwise(const float * src, const ConvParam & p, const float * weight, const float * bias, const float * params, float * dst)
{
size_t srcW = p.srcW, strideX = p.strideX, dilationX = p.dilationX, kernelX = p.kernelX;
size_t dstC = p.dstC, dstCF = AlignLo(p.dstC, F), dstC2F = AlignLo(p.dstC, 2 * F), dstC4F = AlignLo(p.dstC, 4 * F);
size_t dstW2 = AlignLo(p.dstW, 2), dstW4 = AlignLo(p.dstW, 4);
__m512 d00, d01, d02, d03, d10, d11, d12, d13, d20, d21, d22, d23, d30, d31, d32, d33, w0;
for (size_t dy = 0; dy < p.dstH; ++dy)
{
size_t dx = 0;
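// New in this commit: the main loop handles 4 output pixels per iteration; the 2-pixel path below covers the remainder.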
for (; dx < dstW4; dx += 4)
{
float* dst0 = dst + 0 * p.dstC, * dst1 = dst + 1 * p.dstC, * dst2 = dst + 2 * p.dstC, * dst3 = dst + 3 * p.dstC;
size_t sx0 = dx * p.strideX - p.padX;
size_t dc = 0;
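// 4 pixels x 4 channel blocks of F floats: 16 accumulators d00..d33, one per (pixel, channel block), seeded from bias or zero.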
for (; dc < dstC4F; dc += 4 * F)
{
if (bias)
{
d00 = _mm512_loadu_ps(bias + dc + 0 * F);
d01 = _mm512_loadu_ps(bias + dc + 1 * F);
d02 = _mm512_loadu_ps(bias + dc + 2 * F);
d03 = _mm512_loadu_ps(bias + dc + 3 * F);
}
else
{
d00 = _mm512_setzero_ps();
d01 = _mm512_setzero_ps();
d02 = _mm512_setzero_ps();
d03 = _mm512_setzero_ps();
}
d10 = d00; d11 = d01; d12 = d02; d13 = d03;
d20 = d00; d21 = d01; d22 = d02; d23 = d03;
d30 = d00; d31 = d01; d32 = d02; d33 = d03;
for (size_t ky = 0; ky < p.kernelY; ++ky)
{
size_t sy = dy * p.strideY + ky * p.dilationY - p.padY;
const float* psy = src + sy * p.srcW * dstC + dc;
const float* pwy = weight + ky * p.kernelX * dstC + dc;
if (sy < p.srcH)
{
for (size_t kx = 0; kx < kernelX; ++kx)
{
size_t sx = sx0 + kx * dilationX;
const float* pw = pwy + kx * dstC;
__mmask16 mask0 = sx + 0 * strideX < srcW ? 0xFFFF : 0x0000;
__mmask16 mask1 = sx + 1 * strideX < srcW ? 0xFFFF : 0x0000;
__mmask16 mask2 = sx + 2 * strideX < srcW ? 0xFFFF : 0x0000;
__mmask16 mask3 = sx + 3 * strideX < srcW ? 0xFFFF : 0x0000;
const float* ps0 = psy + sx * dstC, * ps1 = ps0 + 1 * dstC, * ps2 = ps0 + 2 * dstC, * ps3 = ps0 + 3 * dstC;
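// mask_i is all-ones only when pixel i's source column (sx + i * strideX) lies inside the row; columns in the
// left or right padding (sx wraps around as an unsigned value) contribute nothing, as both load and fmadd are masked.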

w0 = _mm512_loadu_ps(pw + 0 * F);
d00 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask0, ps0 + 0 * F), w0, d00, mask0);
d10 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask1, ps1 + 0 * F), w0, d10, mask1);
d20 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask2, ps2 + 0 * F), w0, d20, mask2);
d30 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask3, ps3 + 0 * F), w0, d30, mask3);
w0 = _mm512_loadu_ps(pw + 1 * F);
d01 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask0, ps0 + 1 * F), w0, d01, mask0);
d11 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask1, ps1 + 1 * F), w0, d11, mask1);
d21 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask2, ps2 + 1 * F), w0, d21, mask2);
d31 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask3, ps3 + 1 * F), w0, d31, mask3);
w0 = _mm512_loadu_ps(pw + 2 * F);
d02 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask0, ps0 + 2 * F), w0, d02, mask0);
d12 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask1, ps1 + 2 * F), w0, d12, mask1);
d22 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask2, ps2 + 2 * F), w0, d22, mask2);
d32 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask3, ps3 + 2 * F), w0, d32, mask3);
w0 = _mm512_loadu_ps(pw + 3 * F);
d03 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask0, ps0 + 3 * F), w0, d03, mask0);
d13 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask1, ps1 + 3 * F), w0, d13, mask1);
d23 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask2, ps2 + 3 * F), w0, d23, mask2);
d33 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask3, ps3 + 3 * F), w0, d33, mask3);
}
}
}
_mm512_storeu_ps(dst0 + dc + 0 * F, Activate<type>(d00, params, dc + 0 * F));
_mm512_storeu_ps(dst0 + dc + 1 * F, Activate<type>(d01, params, dc + 1 * F));
_mm512_storeu_ps(dst0 + dc + 2 * F, Activate<type>(d02, params, dc + 2 * F));
_mm512_storeu_ps(dst0 + dc + 3 * F, Activate<type>(d03, params, dc + 3 * F));
_mm512_storeu_ps(dst1 + dc + 0 * F, Activate<type>(d10, params, dc + 0 * F));
_mm512_storeu_ps(dst1 + dc + 1 * F, Activate<type>(d11, params, dc + 1 * F));
_mm512_storeu_ps(dst1 + dc + 2 * F, Activate<type>(d12, params, dc + 2 * F));
_mm512_storeu_ps(dst1 + dc + 3 * F, Activate<type>(d13, params, dc + 3 * F));
_mm512_storeu_ps(dst2 + dc + 0 * F, Activate<type>(d20, params, dc + 0 * F));
_mm512_storeu_ps(dst2 + dc + 1 * F, Activate<type>(d21, params, dc + 1 * F));
_mm512_storeu_ps(dst2 + dc + 2 * F, Activate<type>(d22, params, dc + 2 * F));
_mm512_storeu_ps(dst2 + dc + 3 * F, Activate<type>(d23, params, dc + 3 * F));
_mm512_storeu_ps(dst3 + dc + 0 * F, Activate<type>(d30, params, dc + 0 * F));
_mm512_storeu_ps(dst3 + dc + 1 * F, Activate<type>(d31, params, dc + 1 * F));
_mm512_storeu_ps(dst3 + dc + 2 * F, Activate<type>(d32, params, dc + 2 * F));
_mm512_storeu_ps(dst3 + dc + 3 * F, Activate<type>(d33, params, dc + 3 * F));
}
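// Channel tail of 2 * F: same 4-pixel scheme with 8 accumulators.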
for (; dc < dstC2F; dc += 2 * F)
{
if (bias)
{
d00 = _mm512_loadu_ps(bias + dc + 0 * F);
d01 = _mm512_loadu_ps(bias + dc + 1 * F);
}
else
{
d00 = _mm512_setzero_ps();
d01 = _mm512_setzero_ps();
}
d10 = d00; d11 = d01;
d20 = d00; d21 = d01;
d30 = d00; d31 = d01;

for (size_t ky = 0; ky < p.kernelY; ++ky)
{
size_t sy = dy * p.strideY + ky * p.dilationY - p.padY;
const float* psy = src + sy * p.srcW * dstC + dc;
const float* pwy = weight + ky * p.kernelX * dstC + dc;
if (sy < p.srcH)
{
for (size_t kx = 0; kx < kernelX; ++kx)
{
size_t sx = sx0 + kx * dilationX;
const float* pw = pwy + kx * dstC;
__mmask16 mask0 = sx + 0 * strideX < srcW ? 0xFFFF : 0x0000;
__mmask16 mask1 = sx + 1 * strideX < srcW ? 0xFFFF : 0x0000;
__mmask16 mask2 = sx + 2 * strideX < srcW ? 0xFFFF : 0x0000;
__mmask16 mask3 = sx + 3 * strideX < srcW ? 0xFFFF : 0x0000;
const float* ps0 = psy + sx * dstC, * ps1 = ps0 + 1 * dstC, * ps2 = ps0 + 2 * dstC, * ps3 = ps0 + 3 * dstC;

w0 = _mm512_loadu_ps(pw + 0 * F);
d00 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask0, ps0 + 0 * F), w0, d00, mask0);
d10 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask1, ps1 + 0 * F), w0, d10, mask1);
d20 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask2, ps2 + 0 * F), w0, d20, mask2);
d30 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask3, ps3 + 0 * F), w0, d30, mask3);
w0 = _mm512_loadu_ps(pw + 1 * F);
d01 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask0, ps0 + 1 * F), w0, d01, mask0);
d11 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask1, ps1 + 1 * F), w0, d11, mask1);
d21 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask2, ps2 + 1 * F), w0, d21, mask2);
d31 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask3, ps3 + 1 * F), w0, d31, mask3);
}
}
}
_mm512_storeu_ps(dst0 + dc + 0 * F, Activate<type>(d00, params, dc + 0 * F));
_mm512_storeu_ps(dst0 + dc + 1 * F, Activate<type>(d01, params, dc + 1 * F));
_mm512_storeu_ps(dst1 + dc + 0 * F, Activate<type>(d10, params, dc + 0 * F));
_mm512_storeu_ps(dst1 + dc + 1 * F, Activate<type>(d11, params, dc + 1 * F));
_mm512_storeu_ps(dst2 + dc + 0 * F, Activate<type>(d20, params, dc + 0 * F));
_mm512_storeu_ps(dst2 + dc + 1 * F, Activate<type>(d21, params, dc + 1 * F));
_mm512_storeu_ps(dst3 + dc + 0 * F, Activate<type>(d30, params, dc + 0 * F));
_mm512_storeu_ps(dst3 + dc + 1 * F, Activate<type>(d31, params, dc + 1 * F));
}
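// Remaining channels, F at a time; tailC masks off lanes beyond dstC in the last partial block.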
for (; dc < dstC; dc += F)
{
__mmask16 tailC = dc < dstCF ? __mmask16(-1) : TailMask16(dstC - dc);
d00 = bias ? _mm512_maskz_loadu_ps(tailC, bias + dc) : _mm512_setzero_ps();
d10 = d00; d20 = d00; d30 = d00;
for (size_t ky = 0; ky < p.kernelY; ++ky)
{
size_t sy = dy * p.strideY + ky * p.dilationY - p.padY;
const float* psy = src + sy * p.srcW * dstC + dc;
const float* pwy = weight + ky * p.kernelX * dstC + dc;
if (sy < p.srcH)
{
for (size_t kx = 0; kx < kernelX; ++kx)
{
size_t sx = sx0 + kx * dilationX;
const float* pw = pwy + kx * dstC;
__mmask16 mask0 = sx + 0 * strideX < srcW ? tailC : 0x0000;
__mmask16 mask1 = sx + 1 * strideX < srcW ? tailC : 0x0000;
__mmask16 mask2 = sx + 2 * strideX < srcW ? tailC : 0x0000;
__mmask16 mask3 = sx + 3 * strideX < srcW ? tailC : 0x0000;
const float* ps0 = psy + sx * dstC, * ps1 = ps0 + 1 * dstC, * ps2 = ps0 + 2 * dstC, * ps3 = ps0 + 3 * dstC;

w0 = _mm512_loadu_ps(pw + 0 * F);
d00 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask0, ps0 + 0 * F), w0, d00, mask0);
d10 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask1, ps1 + 0 * F), w0, d10, mask1);
d20 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask2, ps2 + 0 * F), w0, d20, mask2);
d30 = _mm512_mask3_fmadd_ps(_mm512_maskz_loadu_ps(mask3, ps3 + 0 * F), w0, d30, mask3);
}
}
}
_mm512_mask_storeu_ps(dst0 + dc, tailC, Activate<type>(d00, params, dc, tailC));
_mm512_mask_storeu_ps(dst1 + dc, tailC, Activate<type>(d10, params, dc, tailC));
_mm512_mask_storeu_ps(dst2 + dc, tailC, Activate<type>(d20, params, dc, tailC));
_mm512_mask_storeu_ps(dst3 + dc, tailC, Activate<type>(d30, params, dc, tailC));
}
dst += 4 * p.dstC;
}
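// Remaining output pixels: the pre-existing 2-pixel path, updated here to reuse the hoisted
// srcW/strideX/dilationX/kernelX locals and the per-row weight pointer pwy.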
for (; dx < dstW2; dx += 2)
{
float* dst0 = dst + 0 * p.dstC, *dst1 = dst + 1 * p.dstC;
size_t sx0 = dx * p.strideX - p.padX;
size_t dc = 0;
for (; dc < dstC4F; dc += 4 * F)
{
@@ -696,14 +871,15 @@ namespace Simd
{
size_t sy = dy * p.strideY + ky * p.dilationY - p.padY;
const float* psy = src + sy * p.srcW * dstC + dc;
const float* pwy = weight + ky * p.kernelX * dstC + dc;
if (sy < p.srcH)
{
for (size_t kx = 0; kx < kernelX; ++kx)
{
size_t sx = sx0 + kx * dilationX;
const float* pw = pwy + kx * dstC;
__mmask16 mask0 = sx + 0 * strideX < srcW ? 0xFFFF : 0x0000;
__mmask16 mask1 = sx + 1 * strideX < srcW ? 0xFFFF : 0x0000;
const float* ps0 = psy + sx * dstC, *ps1 = ps0 + dstC;

w0 = _mm512_loadu_ps(pw + 0 * F);
@@ -747,14 +923,15 @@ namespace Simd
{
size_t sy = dy * p.strideY + ky * p.dilationY - p.padY;
const float* psy = src + sy * p.srcW * dstC + dc;
const float* pwy = weight + ky * p.kernelX * dstC + dc;
if (sy < p.srcH)
{
for (size_t kx = 0; kx < kernelX; ++kx)
{
size_t sx = sx0 + kx * dilationX;
const float* pw = pwy + kx * dstC;
__mmask16 mask0 = sx + 0 * strideX < srcW ? 0xFFFF : 0x0000;
__mmask16 mask1 = sx + 1 * strideX < srcW ? 0xFFFF : 0x0000;
const float* ps0 = psy + sx * dstC, * ps1 = ps0 + dstC;

w0 = _mm512_loadu_ps(pw + 0 * F);
@@ -780,14 +957,15 @@ namespace Simd
{
size_t sy = dy * p.strideY + ky * p.dilationY - p.padY;
const float* psy = src + sy * p.srcW * dstC + dc;
const float* pwy = weight + ky * p.kernelX * dstC + dc;
if (sy < p.srcH)
{
for (size_t kx = 0; kx < kernelX; ++kx)
{
size_t sx = sx0 + kx * dilationX;
const float* pw = pwy + kx * dstC;
__mmask16 mask0 = sx + 0 * strideX < srcW ? tailC : 0x0000;
__mmask16 mask1 = sx + 1 * strideX < srcW ? tailC : 0x0000;
const float* ps0 = psy + sx * dstC, * ps1 = ps0 + dstC;

w0 = _mm512_maskz_loadu_ps(tailC, pw);
