From 6fe04b58f6b0e4641221c68f9fd80513f5e06019 Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Tue, 23 Jul 2024 18:16:14 +0300 Subject: [PATCH] +add Sse41::Convolution16bNchwGemm_2. --- prj/vs2019/Sse41.vcxproj | 1 + prj/vs2019/Sse41.vcxproj.filters | 3 + prj/vs2022/Sse41.vcxproj | 1 + prj/vs2022/Sse41.vcxproj.filters | 3 + .../SimdSse41SynetConvolution16bNchwGemm.cpp | 408 ++++++++++++++++++ src/Simd/SimdSynetConvolution16b.h | 10 +- src/Simd/SimdSynetConvolution16bCommon.h | 68 +++ 7 files changed, 493 insertions(+), 1 deletion(-) create mode 100644 src/Simd/SimdSse41SynetConvolution16bNchwGemm.cpp diff --git a/prj/vs2019/Sse41.vcxproj b/prj/vs2019/Sse41.vcxproj index dec3c27583..20e2354d74 100644 --- a/prj/vs2019/Sse41.vcxproj +++ b/prj/vs2019/Sse41.vcxproj @@ -93,6 +93,7 @@ + diff --git a/prj/vs2019/Sse41.vcxproj.filters b/prj/vs2019/Sse41.vcxproj.filters index b849fdf6e4..20e847bdde 100644 --- a/prj/vs2019/Sse41.vcxproj.filters +++ b/prj/vs2019/Sse41.vcxproj.filters @@ -415,6 +415,9 @@ Sse41 + + Sse41 + diff --git a/prj/vs2022/Sse41.vcxproj b/prj/vs2022/Sse41.vcxproj index dec3c27583..20e2354d74 100644 --- a/prj/vs2022/Sse41.vcxproj +++ b/prj/vs2022/Sse41.vcxproj @@ -93,6 +93,7 @@ + diff --git a/prj/vs2022/Sse41.vcxproj.filters b/prj/vs2022/Sse41.vcxproj.filters index b849fdf6e4..20e847bdde 100644 --- a/prj/vs2022/Sse41.vcxproj.filters +++ b/prj/vs2022/Sse41.vcxproj.filters @@ -415,6 +415,9 @@ Sse41 + + Sse41 + diff --git a/src/Simd/SimdSse41SynetConvolution16bNchwGemm.cpp b/src/Simd/SimdSse41SynetConvolution16bNchwGemm.cpp new file mode 100644 index 0000000000..6b91494368 --- /dev/null +++ b/src/Simd/SimdSse41SynetConvolution16bNchwGemm.cpp @@ -0,0 +1,408 @@ +/* +* Simd Library (http://ermig1979.github.io/Simd). +* +* Copyright (c) 2011-2024 Yermalayeu Ihar. 
+* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +* SOFTWARE. 
+*/ +#include "Simd/SimdSynetConvolution16b.h" +#include "Simd/SimdSynetConvolution16bCommon.h" +#include "Simd/SimdBFloat16.h" +#include "Simd/SimdSynet.h" +#include "Simd/SimdSse41.h" +#include "Simd/SimdCpu.h" + +namespace Simd +{ +#if defined(SIMD_SSE41_ENABLE) && defined(SIMD_SYNET_ENABLE) + namespace Sse41 + { + typedef Base::SynetConvolution16bNchwGemm::AlgParam AlgParam; + typedef Base::SynetConvolution16bNchwGemm::ConvolutionPtr Convolution; + + //----------------------------------------------------------------------------------------- + + //static void Convert16bNhwcGemm(const uint8_t* src8, const ConvParam& p, const AlgParam& a, size_t yBeg, size_t yEnd, uint16_t* dst) + //{ + // const float* src = (float*)src8; + // size_t srcC8 = Simd::AlignLo(p.srcC, 8); + // size_t srcC4 = Simd::AlignLo(p.srcC, 4); + // size_t gap = a.bufK - a.K; + // for (size_t dy = yBeg, dr = 0; dy < yEnd; ++dy) + // { + // for (size_t dx = 0; dx < p.dstW; ++dx, ++dr) + // { + // uint16_t* row = dst + dr * a.bufK; + // for (size_t ky = 0, k = 0; ky < p.kernelY; ky++) + // { + // size_t sy = dy * p.strideY + ky * p.dilationY - p.padY; + // if (sy < p.srcH) + // { + // for (size_t kx = 0; kx < p.kernelX; kx++) + // { + // size_t sx = dx * p.strideX + kx * p.dilationX - p.padX; + // if (sx < p.srcW) + // { + // const float* ps = src + (sy * p.srcW + sx) * p.srcC; + // size_t sc = 0; + // for (; sc < srcC8; sc += 8) + // { + // __m128i d0 = Sse41::Float32ToBFloat16(_mm_loadu_ps(ps + sc + 0)); + // __m128i d1 = Sse41::Float32ToBFloat16(_mm_loadu_ps(ps + sc + 4)); + // _mm_storeu_si128((__m128i*)(row + sc), _mm_packus_epi32(d0, d1)); + // } + // for (; sc < srcC4; sc += 4) + // { + // __m128i d0 = Sse41::Float32ToBFloat16(_mm_loadu_ps(ps + sc + 0)); + // _mm_storel_epi64((__m128i*)(row + sc), _mm_packus_epi32(d0, Sse41::K_ZERO)); + // } + // for (; sc < p.srcC; ++sc) + // row[sc] = Base::Float32ToBFloat16(ps[sc]); + // row += p.srcC; + // } + // else + // { + // memset(row, 0, p.srcC * 
2); + // row += p.srcC; + // } + // } + // } + // else + // { + // memset(row, 0, p.kernelX * p.srcC * 2); + // row += p.kernelX * p.srcC; + // } + // } + // for (size_t g = 0; g < gap; ++g) + // *(row++) = 0; + // } + // } + //} + + //static void Reorder16bNhwcGemm(const uint8_t* src8, const ConvParam& p, const AlgParam& a, size_t yBeg, size_t yEnd, uint16_t* dst) + //{ + // const uint16_t* src = (uint16_t*)src8; + // size_t gap = a.bufK - a.K; + // for (size_t dy = yBeg, dr = 0; dy < yEnd; ++dy) + // { + // for (size_t dx = 0; dx < p.dstW; ++dx, ++dr) + // { + // uint16_t* row = dst + dr * a.bufK; + // for (size_t ky = 0, k = 0; ky < p.kernelY; ky++) + // { + // size_t sy = dy * p.strideY + ky * p.dilationY - p.padY; + // if (sy < p.srcH) + // { + // for (size_t kx = 0; kx < p.kernelX; kx++) + // { + // size_t sx = dx * p.strideX + kx * p.dilationX - p.padX; + // if (sx < p.srcW) + // { + // const uint16_t* ps = src + (sy * p.srcW + sx) * p.srcC; + // memcpy(row, ps, p.srcC * 2); + // row += p.srcC; + // } + // else + // { + // memset(row, 0, p.srcC * 2); + // row += p.srcC; + // } + // } + // } + // else + // { + // memset(row, 0, p.kernelX * p.srcC * 2); + // row += p.kernelX * p.srcC; + // } + // } + // for (size_t g = 0; g < gap; ++g) + // *(row++) = 0; + // } + // } + //} + + //----------------------------------------------------------------------------------------- + + template void Convolution16bNchwGemm_2xM(const uint16_t* weight0, const ConvParam& p, const AlgParam& a, + size_t K, size_t dstS, int zero, const uint16_t* src0, const float* bias, const float* params, float* buf, uint8_t* dst) + { + __m128 d00, d01, d10, d11, d20, d21, d30, d31, d40, d41, w0, s00, s01, s10, s11, m = _mm_castsi128_ps(Bf16::MASK); + size_t dB = a.N, dD = a.N * a.elem; + const uint16_t* src1 = src0 + K * F; + const uint16_t* weight1 = weight0 + 1 * K; + const uint16_t* weight2 = weight0 + 2 * K; + const uint16_t* weight3 = weight0 + 3 * K; + const uint16_t* weight4 = weight0 + 
4 * K; + if (dstS > F) + { + if (zero) + { + if (M > 0) d00 = _mm_setzero_ps(), d01 = _mm_setzero_ps(); + if (M > 1) d10 = _mm_setzero_ps(), d11 = _mm_setzero_ps(); + if (M > 2) d20 = _mm_setzero_ps(), d21 = _mm_setzero_ps(); + if (M > 3) d30 = _mm_setzero_ps(), d31 = _mm_setzero_ps(); + if (M > 4) d40 = _mm_setzero_ps(), d41 = _mm_setzero_ps(); + } + else + { + if (M > 0) d00 = _mm_loadu_ps(buf + 0 * dB + 0), d01 = _mm_loadu_ps(buf + 0 * dB + F); + if (M > 1) d10 = _mm_loadu_ps(buf + 1 * dB + 0), d11 = _mm_loadu_ps(buf + 1 * dB + F); + if (M > 2) d20 = _mm_loadu_ps(buf + 2 * dB + 0), d21 = _mm_loadu_ps(buf + 2 * dB + F); + if (M > 3) d30 = _mm_loadu_ps(buf + 3 * dB + 0), d31 = _mm_loadu_ps(buf + 3 * dB + F); + if (M > 4) d40 = _mm_loadu_ps(buf + 4 * dB + 0), d41 = _mm_loadu_ps(buf + 4 * dB + F); + } + for (size_t k = 0; k < K; k += 2) + { + s01 = _mm_loadu_ps((float*)src0); + s00 = _mm_castsi128_ps(_mm_slli_epi32(_mm_castps_si128(s01), Base::Bf16::SHIFT)); + s01 = _mm_and_ps(s01, m); + s11 = _mm_loadu_ps((float*)src1); + s10 = _mm_castsi128_ps(_mm_slli_epi32(_mm_castps_si128(s11), Base::Bf16::SHIFT)); + s11 = _mm_and_ps(s11, m); + if (M > 0) + { + w0 = _mm_and_ps(_mm_set1_ps(*(float*)(weight0 + k - 1)), m); + d00 = _mm_add_ps(_mm_mul_ps(w0, s00), d00); + d01 = _mm_add_ps(_mm_mul_ps(w0, s10), d01); + w0 = _mm_and_ps(_mm_set1_ps(*(float*)(weight0 + k - 0)), m); + d00 = _mm_add_ps(_mm_mul_ps(w0, s01), d00); + d01 = _mm_add_ps(_mm_mul_ps(w0, s11), d01); + } + if (M > 1) + { + w0 = _mm_and_ps(_mm_set1_ps(*(float*)(weight1 + k - 1)), m); + d10 = _mm_add_ps(_mm_mul_ps(w0, s00), d10); + d11 = _mm_add_ps(_mm_mul_ps(w0, s10), d11); + w0 = _mm_and_ps(_mm_set1_ps(*(float*)(weight1 + k - 0)), m); + d10 = _mm_add_ps(_mm_mul_ps(w0, s01), d10); + d11 = _mm_add_ps(_mm_mul_ps(w0, s11), d11); + } + if (M > 2) + { + w0 = _mm_and_ps(_mm_set1_ps(*(float*)(weight2 + k - 1)), m); + d20 = _mm_add_ps(_mm_mul_ps(w0, s00), d20); + d21 = _mm_add_ps(_mm_mul_ps(w0, s10), d21); + w0 = 
_mm_and_ps(_mm_set1_ps(*(float*)(weight2 + k - 0)), m); + d20 = _mm_add_ps(_mm_mul_ps(w0, s01), d20); + d21 = _mm_add_ps(_mm_mul_ps(w0, s11), d21); + } + if (M > 3) + { + w0 = _mm_and_ps(_mm_set1_ps(*(float*)(weight3 + k - 1)), m); + d30 = _mm_add_ps(_mm_mul_ps(w0, s00), d30); + d31 = _mm_add_ps(_mm_mul_ps(w0, s10), d31); + w0 = _mm_and_ps(_mm_set1_ps(*(float*)(weight3 + k - 0)), m); + d30 = _mm_add_ps(_mm_mul_ps(w0, s01), d30); + d31 = _mm_add_ps(_mm_mul_ps(w0, s11), d31); + } + if (M > 4) + { + w0 = _mm_and_ps(_mm_set1_ps(*(float*)(weight4 + k - 1)), m); + d40 = _mm_add_ps(_mm_mul_ps(w0, s00), d40); + d41 = _mm_add_ps(_mm_mul_ps(w0, s10), d41); + w0 = _mm_and_ps(_mm_set1_ps(*(float*)(weight4 + k - 0)), m); + d40 = _mm_add_ps(_mm_mul_ps(w0, s01), d40); + d41 = _mm_add_ps(_mm_mul_ps(w0, s11), d41); + } + src0 += DF; + src1 += DF; + } + if (dstS == DF) + { + if (M > 0) Save2(dst, buf, d00, d01, bias, params, 0), dst += dD, buf += dB; + if (M > 1) Save2(dst, buf, d10, d11, bias, params, 1), dst += dD, buf += dB; + if (M > 2) Save2(dst, buf, d20, d21, bias, params, 2), dst += dD, buf += dB; + if (M > 3) Save2(dst, buf, d30, d31, bias, params, 3), dst += dD, buf += dB; + if (M > 4) Save2(dst, buf, d40, d41, bias, params, 4), dst += dD, buf += dB; + } + else + { + dstS -= F; + if (M > 0) Save2(dst, buf, d00, d01, bias, params, 0, dstS), dst += dD, buf += dB; + if (M > 1) Save2(dst, buf, d10, d11, bias, params, 1, dstS), dst += dD, buf += dB; + if (M > 2) Save2(dst, buf, d20, d21, bias, params, 2, dstS), dst += dD, buf += dB; + if (M > 3) Save2(dst, buf, d30, d31, bias, params, 3, dstS), dst += dD, buf += dB; + if (M > 4) Save2(dst, buf, d40, d41, bias, params, 4, dstS), dst += dD, buf += dB; + } + } + else + { + if (zero) + { + if (M > 0) d00 = _mm_setzero_ps(); + if (M > 1) d10 = _mm_setzero_ps(); + if (M > 2) d20 = _mm_setzero_ps(); + if (M > 3) d30 = _mm_setzero_ps(); + if (M > 4) d40 = _mm_setzero_ps(); + } + else + { + if (M > 0) d00 = _mm_loadu_ps(buf + 0 * dB + 
0); + if (M > 1) d10 = _mm_loadu_ps(buf + 1 * dB + 0); + if (M > 2) d20 = _mm_loadu_ps(buf + 2 * dB + 0); + if (M > 3) d30 = _mm_loadu_ps(buf + 3 * dB + 0); + if (M > 4) d40 = _mm_loadu_ps(buf + 4 * dB + 0); + } + for (size_t k = 0; k < K; k += 2) + { + s01 = _mm_loadu_ps((float*)src0); + s00 = _mm_castsi128_ps(_mm_slli_epi32(_mm_castps_si128(s01), Base::Bf16::SHIFT)); + s01 = _mm_and_ps(s01, m); + if (M > 0) + { + w0 = _mm_and_ps(_mm_set1_ps(*(float*)(weight0 + k - 1)), m); + d00 = _mm_add_ps(_mm_mul_ps(w0, s00), d00); + w0 = _mm_and_ps(_mm_set1_ps(*(float*)(weight0 + k - 0)), m); + d00 = _mm_add_ps(_mm_mul_ps(w0, s01), d00); + } + if (M > 1) + { + w0 = _mm_and_ps(_mm_set1_ps(*(float*)(weight1 + k - 1)), m); + d10 = _mm_add_ps(_mm_mul_ps(w0, s00), d10); + w0 = _mm_and_ps(_mm_set1_ps(*(float*)(weight1 + k - 0)), m); + d10 = _mm_add_ps(_mm_mul_ps(w0, s01), d10); + } + if (M > 2) + { + w0 = _mm_and_ps(_mm_set1_ps(*(float*)(weight2 + k - 1)), m); + d20 = _mm_add_ps(_mm_mul_ps(w0, s00), d20); + w0 = _mm_and_ps(_mm_set1_ps(*(float*)(weight2 + k - 0)), m); + d20 = _mm_add_ps(_mm_mul_ps(w0, s01), d20); + } + if (M > 3) + { + w0 = _mm_and_ps(_mm_set1_ps(*(float*)(weight3 + k - 1)), m); + d30 = _mm_add_ps(_mm_mul_ps(w0, s00), d30); + w0 = _mm_and_ps(_mm_set1_ps(*(float*)(weight3 + k - 0)), m); + d30 = _mm_add_ps(_mm_mul_ps(w0, s01), d30); + } + if (M > 4) + { + w0 = _mm_and_ps(_mm_set1_ps(*(float*)(weight4 + k - 1)), m); + d40 = _mm_add_ps(_mm_mul_ps(w0, s00), d40); + w0 = _mm_and_ps(_mm_set1_ps(*(float*)(weight4 + k - 0)), m); + d40 = _mm_add_ps(_mm_mul_ps(w0, s01), d40); + } + src0 += DF; + } + if (dstS == F) + { + if (M > 0) Save1(dst, buf, d00, bias, params, 0), dst += dD, buf += dB; + if (M > 1) Save1(dst, buf, d10, bias, params, 1), dst += dD, buf += dB; + if (M > 2) Save1(dst, buf, d20, bias, params, 2), dst += dD, buf += dB; + if (M > 3) Save1(dst, buf, d30, bias, params, 3), dst += dD, buf += dB; + if (M > 4) Save1(dst, buf, d40, bias, params, 4), dst += dD, buf += 
dB; + } + else + { + if (M > 0) Save1(dst, buf, d00, bias, params, 0, dstS), dst += dD, buf += dB; + if (M > 1) Save1(dst, buf, d10, bias, params, 1, dstS), dst += dD, buf += dB; + if (M > 2) Save1(dst, buf, d20, bias, params, 2, dstS), dst += dD, buf += dB; + if (M > 3) Save1(dst, buf, d30, bias, params, 3, dstS), dst += dD, buf += dB; + if (M > 4) Save1(dst, buf, d40, bias, params, 4, dstS), dst += dD, buf += dB; + } + } + } + + typedef void(*Convolution16bNchwGemm_2xM_Ptr)(const uint16_t* src0, const ConvParam& p, const AlgParam& a, + size_t srcC, size_t dstC, int zero, const uint16_t* weight0, const float* bias, const float* params, float* buf, uint8_t* dst); + + template Convolution16bNchwGemm_2xM_Ptr GetConvolution16bNchwGemm_2xM(size_t M) + { + switch (M) + { + case 0: return NULL; + case 1: return Convolution16bNchwGemm_2xM; + case 2: return Convolution16bNchwGemm_2xM; + case 3: return Convolution16bNchwGemm_2xM; + case 4: return Convolution16bNchwGemm_2xM; + case 5: return Convolution16bNchwGemm_2xM; + } + assert(0); + return NULL; + } + + template void Convolution16bNchwGemm_2(const uint16_t* weight, const ConvParam& p, const AlgParam& a, + size_t dstC, size_t dstH, size_t K, int zero, const uint16_t* src, const float* bias, const float* params, float* buf, uint8_t* dst) + { + size_t dstS = dstH * p.dstW, n1 = dstC, n = 5; + size_t nn = AlignLoAny(n1, n), m = n1 - nn; + size_t dB = a.N, dD = a.N * a.elem, dW = K, dp = type == ::SimdConvolutionActivationPrelu ? 
1 : 0; + Convolution16bNchwGemm_2xM_Ptr convolution_2xN = GetConvolution16bNchwGemm_2xM(n); + Convolution16bNchwGemm_2xM_Ptr convolution_2xM = GetConvolution16bNchwGemm_2xM(m); + + for (size_t ds = 0; ds < dstS; ds += DF) + { + size_t dS = Simd::Min(DF, dstS - ds); + const uint16_t* w = weight; + float* b = buf + ds; + uint8_t* d = dst + ds * a.elem; + size_t i = 0; + for (; i < nn; i += n, w += n * dW, b += n * dB, d += n * dD) + convolution_2xN(w, p, a, K, dS, zero, src, bias + i, params + i * dp, b, d); + for (; i < n1; i += m, w += m * dW, b += m * dB, d += m * dD) + convolution_2xM(w, p, a, K, dS, zero, src, bias + i, params + i * dp, b, d); + src += K * DF; + } + } + + //----------------------------------------------------------------------------------------- + + template SIMD_INLINE void Set(const ConvParam& p, const AlgParam & a, Convolution* convolutions) + { + convolutions[0] = Convolution16bNchwGemm_2; + if(p.dstT == SimdTensorData16b) + convolutions[1] = Convolution16bNchwGemm_2; + else + convolutions[1] = Convolution16bNchwGemm_2; + } + + SynetConvolution16bNchwGemm::SynetConvolution16bNchwGemm(const ConvParam & p) + : Base::SynetConvolution16bNchwGemm(p) + { + SetAlgParam(F, F * 2, 5, 2, Base::AlgCacheL1(), Base::AlgCacheL2(), Base::AlgCacheL3()); + if (_src16b) + { + AlgParam& a = _alg; + //if (_is1x1 && a.K == a.bufK) + // _convert = NULL; + //else + // _convert = Reorder16bNhwcGemm; + } + //else + // _convert = Convert16bNhwcGemm; + switch (p.activation) + { + case SimdConvolutionActivationIdentity: Set(p, _alg, _convolutions); break; + case SimdConvolutionActivationRelu: Set(p, _alg, _convolutions); break; + case SimdConvolutionActivationLeakyRelu: Set(p, _alg, _convolutions); break; + case SimdConvolutionActivationRestrictRange: Set(p, _alg, _convolutions); break; + case SimdConvolutionActivationPrelu: Set(p, _alg, _convolutions); break; + case SimdConvolutionActivationElu: Set(p, _alg, _convolutions); break; + case 
SimdConvolutionActivationHswish: Set(p, _alg, _convolutions); break; + case SimdConvolutionActivationMish: Set(p, _alg, _convolutions); break; + case SimdConvolutionActivationHardSigmoid: Set(p, _alg, _convolutions); break; + case SimdConvolutionActivationSwish: Set(p, _alg, _convolutions); break; + case SimdConvolutionActivationGelu: Set(p, _alg, _convolutions); break; + default: assert(0); + } + } + } +#endif +} diff --git a/src/Simd/SimdSynetConvolution16b.h b/src/Simd/SimdSynetConvolution16b.h index 2bf17ccd2f..651339b7df 100644 --- a/src/Simd/SimdSynetConvolution16b.h +++ b/src/Simd/SimdSynetConvolution16b.h @@ -224,7 +224,7 @@ namespace Simd typedef void(*ConvertPtr)(const uint8_t* src, const ConvParam& p, const AlgParam& a, size_t yBeg, size_t yEnd, size_t cBeg, size_t cEnd, uint16_t* dst); typedef void(*ConvolutionPtr)(const uint16_t* weight, const ConvParam& p, const AlgParam& a, size_t dstC, size_t dstH, - size_t srcC, int zero, const uint16_t* src, const float* bias, const float* params, float* sum, uint8_t* dst); + size_t K, int zero, const uint16_t* src, const float* bias, const float* params, float* sum, uint8_t* dst); protected: void SetAlgParam(size_t F, size_t microD, size_t microN, size_t microK, size_t L1, size_t L2, size_t L3); @@ -260,6 +260,14 @@ namespace Simd virtual String Ext() const { return "Sse41"; } }; + class SynetConvolution16bNchwGemm : public Base::SynetConvolution16bNchwGemm + { + public: + SynetConvolution16bNchwGemm(const ConvParam& p); + + virtual String Ext() const { return "Sse41"; } + }; + //------------------------------------------------------------------------------------------------- void* SynetConvolution16bInit(size_t batch, const SimdConvolutionParameters* conv, SimdSynetCompatibilityType compatibility); diff --git a/src/Simd/SimdSynetConvolution16bCommon.h b/src/Simd/SimdSynetConvolution16bCommon.h index b9b75d9e36..253c41763d 100644 --- a/src/Simd/SimdSynetConvolution16bCommon.h +++ 
b/src/Simd/SimdSynetConvolution16bCommon.h @@ -50,6 +50,9 @@ namespace Simd template static SIMD_INLINE void Postprocess(const float* src, const float* bias, const float* params, size_t offset, uint8_t* dst); template static SIMD_INLINE void Postprocess(const float* src, const float* bias, const float* params, size_t offset, uint8_t* dst, size_t tail); + + template static SIMD_INLINE void Save(uint8_t* ptr, float* buf, __m128 value, const float* bias, const float* params, size_t offset); + template static SIMD_INLINE void Save(uint8_t* ptr, float* buf, __m128 value, const float* bias, const float* params, size_t offset, size_t tail); }; template <> struct Term16b @@ -98,6 +101,21 @@ namespace Simd for (size_t i = 0; i < tail; ++i) ((uint16_t*)dst)[offset + i] = tmp[i]; } + + template static SIMD_INLINE void Save(uint8_t* ptr, float* buf, __m128 value, const float* bias, const float* params, size_t offset) + { + __m128 f32 = Activate(_mm_add_ps(value, _mm_set1_ps(bias[offset])), params, offset); + _mm_storel_epi64((__m128i*)(ptr + index * DF), _mm_packus_epi32(Float32ToBFloat16(f32), K_ZERO)); + } + + template static SIMD_INLINE void Save(uint8_t* ptr, float* buf, __m128 value, const float* bias, const float* params, size_t offset, size_t tail) + { + __m128 f32 = Activate(_mm_add_ps(value, _mm_set1_ps(bias[offset])), params, offset); + uint16_t tmp[F]; + _mm_storel_epi64((__m128i*)tmp, _mm_packus_epi32(Float32ToBFloat16(f32), K_ZERO)); + for (size_t i = 0; i < tail; ++i) + ((uint16_t*)ptr)[i + index * F] = tmp[i]; + } }; template <> struct Term16b @@ -142,6 +160,19 @@ namespace Simd for (size_t i = 0; i < tail; ++i) ((float*)dst)[offset + i] = tmp[i]; } + + template static SIMD_INLINE void Save(uint8_t* ptr, float* buf, __m128 value, const float* bias, const float* params, size_t offset) + { + _mm_storeu_ps((float*)ptr + index * F, Activate(_mm_add_ps(value, _mm_set1_ps(bias[offset])), params, offset)); + } + + template static SIMD_INLINE void Save(uint8_t* ptr, 
float* buf, __m128 value, const float* bias, const float* params, size_t offset, size_t tail) + { + float tmp[F]; + _mm_storeu_ps(tmp, Activate(_mm_add_ps(value, _mm_set1_ps(bias[offset])), params, offset)); + for (size_t i = 0; i < tail; ++i) + ((float*)ptr)[i + index * F] = tmp[i]; + } }; template <> struct Term16b { @@ -179,6 +210,19 @@ namespace Simd template static SIMD_INLINE void Postprocess(const float* src, const float* bias, const float* params, size_t offset, uint8_t* dst, size_t tail) { } + + template static SIMD_INLINE void Save(uint8_t* ptr, float* buf, __m128 value, const float* bias, const float* params, size_t offset) + { + _mm_storeu_ps(buf + index * F, value); + } + + template static SIMD_INLINE void Save(uint8_t* ptr, float* buf, __m128 value, const float* bias, const float* params, size_t offset, size_t tail) + { + float tmp[F]; + _mm_storeu_ps(tmp, value); + for (size_t i = 0; i < tail; ++i) + buf[i + index * F] = tmp[i]; + } }; //------------------------------------------------------------------------------------------------- @@ -260,6 +304,30 @@ namespace Simd Term16b::template Save(ptr, buf, val0, bias, NULL); Term16b::template Save(ptr, buf, val1, bias, NULL, tail); } + + //------------------------------------------------------------------------------------------------- + + template SIMD_INLINE void Save1(uint8_t* ptr, float* buf, __m128 val0, const float* bias, const float* params, size_t offset) + { + Term16b::template Save(ptr, buf, val0, bias, params, offset); + } + + template SIMD_INLINE void Save1(uint8_t* ptr, float* buf, __m128 val0, const float* bias, const float* params, size_t offset, size_t tail) + { + Term16b::template Save(ptr, buf, val0, bias, params, offset, tail); + } + + template SIMD_INLINE void Save2(uint8_t* ptr, float* buf, __m128 val0, __m128 val1, const float* bias, const float* params, size_t offset) + { + Term16b::template Save(ptr, buf, val0, bias, params, offset); + Term16b::template Save(ptr, buf, val1, bias, 
params, offset); + } + + template SIMD_INLINE void Save2(uint8_t* ptr, float* buf, __m128 val0, __m128 val1, const float* bias, const float* params, size_t offset, size_t tail) + { + Term16b::template Save(ptr, buf, val0, bias, params, offset); + Term16b::template Save(ptr, buf, val1, bias, params, offset, tail); + } } #endif