Skip to content

Commit

Permalink
+add SSE4.1 optimizations of class SynetDeconvolution16bNhwcGemm.
Browse files Browse the repository at this point in the history
  • Loading branch information
ermig1979 committed Sep 10, 2024
1 parent efdefd3 commit 645dfd7
Show file tree
Hide file tree
Showing 8 changed files with 600 additions and 22 deletions.
2 changes: 1 addition & 1 deletion docs/2024.html
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ <h4>Algorithms</h4>
<h5>New features</h5>
<ul>
<li>Base implementation of class SynetDeconvolution16bGemm.</li>
<li>Base implementation of class SynetDeconvolution16bNhwcGemm.</li>
<li>Base implementation, SSE4.1 optimizations of class SynetDeconvolution16bNhwcGemm.</li>
</ul>

<h4>Test framework</h4>
Expand Down
2 changes: 2 additions & 0 deletions prj/vs2022/Sse41.vcxproj
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@
<ClCompile Include="..\..\src\Simd\SimdSse41SynetConvolution8iNhwcDirect.cpp" />
<ClCompile Include="..\..\src\Simd\SimdSse41SynetConvolution8iNhwcDirect1x1.cpp" />
<ClCompile Include="..\..\src\Simd\SimdSse41SynetConvolution8iNhwcDirectAny.cpp" />
<ClCompile Include="..\..\src\Simd\SimdSse41SynetDeconvolution16b.cpp" />
<ClCompile Include="..\..\src\Simd\SimdSse41SynetDeconvolution16bNhwcGemm.cpp" />
<ClCompile Include="..\..\src\Simd\SimdSse41SynetDeconvolution32f.cpp" />
<ClCompile Include="..\..\src\Simd\SimdSse41SynetGridSample.cpp" />
<ClCompile Include="..\..\src\Simd\SimdSse41SynetGridSample2d32fBlZ.cpp" />
Expand Down
6 changes: 6 additions & 0 deletions prj/vs2022/Sse41.vcxproj.filters
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,12 @@
<ClCompile Include="..\..\src\Simd\SimdSse41SynetConvolution16bNchwGemm.cpp">
<Filter>Sse41</Filter>
</ClCompile>
<ClCompile Include="..\..\src\Simd\SimdSse41SynetDeconvolution16b.cpp">
<Filter>Sse41</Filter>
</ClCompile>
<ClCompile Include="..\..\src\Simd\SimdSse41SynetDeconvolution16bNhwcGemm.cpp">
<Filter>Sse41</Filter>
</ClCompile>
</ItemGroup>
<ItemGroup>
<Filter Include="Sse41">
Expand Down
8 changes: 4 additions & 4 deletions src/Simd/SimdBaseSynetDeconvolution16b.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,14 +210,14 @@ namespace Simd
assert(p.group == 1);
GemmNN(_M, _N, _K, src16b, _ldS, wgt, _ldW, buf32f, _ldD);
if (!_is1x1)
ImgToRow(buf32f, dst32f);
RowToImg(buf32f, dst32f);
}
else
{
for (size_t g = 0; g < p.group; ++g)
GemmNN(_M, _N, _K, wgt + _grW * g, _ldW, src16b + _grS * g, _ldS, buf32f + _grD * g, _ldD);
if (!_is1x1)
ImgToCol(buf32f, dst32f);
ColToImg(buf32f, dst32f);
}
ConvolutionBiasAndActivation(_bias.data, p.dstC, p.dstH * p.dstW, p.activation, _params.data, p.trans, dst32f);
if (_dst16b)
Expand All @@ -227,7 +227,7 @@ namespace Simd
}
}

void SynetDeconvolution16bGemm::ImgToCol(const float* src, float* dst)
void SynetDeconvolution16bGemm::ColToImg(const float* src, float* dst)
{
const DeconvParam& p = _param;
assert(!p.trans);
Expand Down Expand Up @@ -261,7 +261,7 @@ namespace Simd
}
}

void SynetDeconvolution16bGemm::ImgToRow(const float* src, float* dst)
void SynetDeconvolution16bGemm::RowToImg(const float* src, float* dst)
{
const DeconvParam& p = _param;
assert(p.trans && p.group == 1);
Expand Down
53 changes: 42 additions & 11 deletions src/Simd/SimdBaseSynetDeconvolution16bNhwcGemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ namespace Simd
SynetDeconvolution16bNhwcGemm::SynetDeconvolution16bNhwcGemm(const DeconvParam& p)
: SynetDeconvolution16b(p)
, _convert(0)
, _gemm(0)
, _toImg(0)
, _biasAct(0)
{
assert(p.trans && p.group == 1);
}
Expand Down Expand Up @@ -70,6 +73,7 @@ namespace Simd
a.bufM = p.dstH * AlignHi(p.dstW, a.F);
a.macroK = Simd::RestrictRange(AlignLo(L1 / a.microN / 2, a.microK), a.microK, a.bufK);
a.macroH = Simd::RestrictRange(L2 / a.macroK / p.dstW / 2, size_t(1), p.dstH);
a.macroM = a.macroH * p.dstW;
a.macroN = Simd::RestrictRange(AlignLoAny(L3 / a.macroK / 2, a.microN), a.microN, a.bufN);
_stepS = p.srcH * p.srcW * p.srcC * _elemS;
_stepD = p.dstH * p.dstW * p.dstC * _elemD;
Expand Down Expand Up @@ -112,24 +116,51 @@ namespace Simd
uint16_t* bufS = _src16b && a.bufK == a.K ? NULL : Allocate<uint16_t>(buf, a.bufK * a.bufM);
float* bufB = _is1x1 ? NULL : Allocate<float>(buf, a.bufN * a.bufM);
float* bufD = _dst16b ? Allocate<float>(buf, p.dstH * p.dstW * p.dstC) : NULL;
const uint16_t* wgt = _weight.data;
for (size_t b = 0; b < p.batch; ++b)
{
const uint16_t* src16b = _src16b ? (uint16_t*)src : bufS;
float* dst32f = _dst16b ? bufD : (float*)dst;
float* buf32f = _is1x1 ? dst32f : bufB;
if (!_src16b || a.bufK != a.K)
_convert(src, p, a, 0, p.srcH, bufS);
//GemmNN(_M, _N, _K, src16b, _ldS, wgt, _ldW, buf32f, _ldD);
//if (!_is1x1)
// ImgToRow(buf32f, dst32f);
_biasAct(dst32f, p, a, p.dstC, p.dstH, _bias.data, _params.data, dst);
ForwardCommon(src, bufS, bufB, bufD, dst);
src += _stepS;
dst += _stepD;
}
}

bool SynetDeconvolution16bNhwcGemm::Preferable(const ConvParam& p)
void SynetDeconvolution16bNhwcGemm::ForwardCommon(const uint8_t* src, uint16_t* bufS, float* bufB, float* bufD, uint8_t* dst)
{
const DeconvParam& p = _param;
const AlgParam& a = _alg;
const uint16_t* src16b = _src16b ? (uint16_t*)src : bufS;
float* dst32f = _dst16b ? bufD : (float*)dst;
float* buf32f = _is1x1 ? dst32f : bufB;
if (!_src16b || a.bufK != a.K)
_convert(src, p, a, 0, p.srcH, bufS);
GemmCommon(src16b, buf32f);
if (!_is1x1)
_toImg(buf32f, p, a, p.dstC, 0, p.dstH, dst32f);
_biasAct(dst32f, p, a, p.dstC, p.dstH, _bias.data, _params.data, dst);
}

void SynetDeconvolution16bNhwcGemm::GemmCommon(const uint16_t* src, float* dst)
{
const AlgParam& a = _alg;
for (size_t man = 0; man < a.N; man += a.macroN)
{
size_t macroN = Simd::Min(a.N, man + a.macroN) - man;
const uint16_t* wgt = _weight.data + man * a.bufK;
for (size_t mak = 0; mak < a.K; mak += a.macroK)
{
size_t macroK = Simd::Min(a.bufK, mak + a.macroK) - mak;
for (size_t mam = 0; mam < a.M; mam += a.macroM)
{
size_t macroM = Simd::Min(a.bufM, mam + a.macroM) - mam;
_gemm(src + mam * a.bufK + mak, _param, a, macroM, macroN, macroK, mak == 0 ? 1 : 0, wgt, dst + macroM * a.bufN);
}
wgt += macroK * a.F;
}
dst += macroN;
}
}

bool SynetDeconvolution16bNhwcGemm::Preferable(const DeconvParam& p)
{
return false;
}
Expand Down
42 changes: 42 additions & 0 deletions src/Simd/SimdSse41SynetDeconvolution16b.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Simd Library (http://ermig1979.github.io/Simd).
*
* Copyright (c) 2011-2024 Yermalayeu Ihar.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#include "Simd/SimdSynetDeconvolution16b.h"

namespace Simd
{
#if defined(SIMD_SSE41_ENABLE) && defined(SIMD_SYNET_ENABLE)
namespace Sse41
{
void* SynetDeconvolution16bInit(size_t batch, const SimdConvolutionParameters* conv, SimdSynetCompatibilityType compatibility)
{
DeconvParam param(batch, conv, compatibility);
if (!param.Valid(SimdTensorData32f, SimdTensorData16b))
return NULL;
if (SynetDeconvolution16bNhwcGemm::Preferable(param))
return new Sse41::SynetDeconvolution16bNhwcGemm(param);
return new Base::SynetDeconvolution16bGemm(param);
}
}
#endif
}
Loading

0 comments on commit 645dfd7

Please sign in to comment.