diff --git a/src/Simd/SimdAvx2SynetConvolution16bNhwcDirect.cpp b/src/Simd/SimdAvx2SynetConvolution16bNhwcDirect.cpp
index 9e482659f3..f1b269ce69 100644
--- a/src/Simd/SimdAvx2SynetConvolution16bNhwcDirect.cpp
+++ b/src/Simd/SimdAvx2SynetConvolution16bNhwcDirect.cpp
@@ -174,6 +174,160 @@ namespace Simd
 
         //-----------------------------------------------------------------------------------------
 
+        // Converts FP32 NHWC source rows to BF16 and stores them as zero-padded blocks of a.microC channels
+        // (the preprocess step used when SIMD_CONV16B_COMPATIBLE_WITH_AMX is defined). Channel block c / a.microC of
+        // buffer pixel s is written to dst + c * cD + s * sD; border pixels and unused tail channels are zeroed.
+        static void Convert16bNhwcDirect2(const uint8_t* src8, const ConvParam& p, const AlgParam& a, size_t dyBeg, size_t dyEnd, uint16_t* dst)
+        {
+            assert(a.microC == DF * 2);
+            const float* src = (float*)src8;
+            size_t srcCQF = Simd::AlignLo(p.srcC, QF), tailC = p.srcC - srcCQF;
+            size_t syPad = p.kernelY - 1 - p.padY, syBeg, syEnd = (dyEnd == p.dstH ? p.srcH : dyEnd + syPad);
+            size_t cD = a.batch * a.srcH * a.srcW, sD = a.microC;
+            if (dyBeg == 0)
+            {
+                for (size_t s = 0, n = p.padY * a.srcW; s < n; ++s)
+                    for (size_t c = 0; c < a.srcC; c += a.microC)
+                        Avx2::SetZero2(dst + c * cD + s * sD);
+                dst += p.padY * a.srcW * sD;
+                syBeg = 0;
+            }
+            else
+            {
+                syBeg = dyBeg + syPad;
+                src += syBeg * p.srcW * p.srcC;
+                dst += (dyBeg + p.kernelY - 1) * a.srcW * sD;
+            }
+            for (size_t sy = syBeg; sy < syEnd; ++sy)
+            {
+                if (p.padX)
+                {
+                    for (size_t s = 0; s < p.padX; ++s)
+                        for (size_t c = 0; c < a.srcC; c += a.microC)
+                            Avx2::SetZero2(dst + c * cD + s * sD);
+                    dst += p.padX * sD;
+                }
+                for (size_t sx = 0; sx < p.srcW; ++sx)
+                {
+                    size_t sc = 0;
+                    for (; sc < srcCQF; sc += QF)
+                    {
+                        Avx2::Float32ToBFloat16(src + sc + 0 * F, dst + sc * cD + 0 * F);
+                        Avx2::Float32ToBFloat16(src + sc + 2 * F, dst + sc * cD + 2 * F);
+                    }
+                    if (tailC)
+                    {
+                        // Fewer than QF channels remain: clear the whole block first so channels above p.srcC are zero,
+                        // then convert the remainder starting at offset 0 (first half, then the rest at offset DF).
+                        uint16_t* dstT = dst + sc * cD;
+                        Avx2::SetZero2(dstT);
+                        if (tailC >= DF)
+                        {
+                            Avx2::Float32ToBFloat16(src + sc, dstT);
+                            if (tailC > DF)
+                                Avx2::Float32ToBFloat16Tail(src + sc + DF, tailC - DF, dstT + DF);
+                        }
+                        else
+                            Avx2::Float32ToBFloat16Tail(src + sc, tailC, dstT);
+                    }
+                    src += p.srcC;
+                    dst += sD;
+                }
+                if (p.padW)
+                {
+                    for (size_t s = 0; s < p.padW; ++s)
+                        for (size_t c = 0; c < a.srcC; c += a.microC)
+                            Avx2::SetZero2(dst + c * cD + s * sD);
+                    dst += p.padW * sD;
+                }
+            }
+            if (dyEnd == p.dstH)
+            {
+                for (size_t s = 0, n = p.padH * a.srcW; s < n; ++s)
+                    for (size_t c = 0; c < a.srcC; c += a.microC)
+                        Avx2::SetZero2(dst + c * cD + s * sD);
+                dst += p.padH * a.srcW * sD;
+            }
+        }
+
+        // Copies 16-bit NHWC source rows into zero-padded blocks of a.microC channels; the 16-bit-input counterpart
+        // of Convert16bNhwcDirect2 (used for _src16b when SIMD_CONV16B_COMPATIBLE_WITH_AMX is defined). Channel block
+        // c / a.microC of buffer pixel s is written to dst + c * cD + s * sD; borders and unused tail channels are zeroed.
+        static void Reorder16bNhwcDirect2(const uint8_t* src8, const ConvParam& p, const AlgParam& a, size_t dyBeg, size_t dyEnd, uint16_t* dst)
+        {
+            assert(a.microC == DF * 2);
+            const uint16_t* src = (uint16_t*)src8;
+            size_t srcCQF = Simd::AlignLo(p.srcC, QF), tailC = p.srcC - srcCQF;
+            size_t syPad = p.kernelY - 1 - p.padY, syBeg, syEnd = (dyEnd == p.dstH ? p.srcH : dyEnd + syPad);
+            size_t cD = a.batch * a.srcH * a.srcW, sD = a.microC;
+            if (dyBeg == 0)
+            {
+                for (size_t s = 0, n = p.padY * a.srcW; s < n; ++s)
+                    for (size_t c = 0; c < a.srcC; c += a.microC)
+                        Avx2::SetZero2(dst + c * cD + s * sD);
+                dst += p.padY * a.srcW * sD;
+                syBeg = 0;
+            }
+            else
+            {
+                syBeg = dyBeg + syPad;
+                src += syBeg * p.srcW * p.srcC;
+                dst += (dyBeg + p.kernelY - 1) * a.srcW * sD;
+            }
+            for (size_t sy = syBeg; sy < syEnd; ++sy)
+            {
+                if (p.padX)
+                {
+                    for (size_t s = 0; s < p.padX; ++s)
+                        for (size_t c = 0; c < a.srcC; c += a.microC)
+                            Avx2::SetZero2(dst + c * cD + s * sD);
+                    dst += p.padX * sD;
+                }
+                for (size_t sx = 0; sx < p.srcW; ++sx)
+                {
+                    size_t sc = 0;
+                    for (; sc < srcCQF; sc += QF)
+                    {
+                        Avx2::Copy(src + sc + 0 * F, dst + sc * cD + 0 * F);
+                        Avx2::Copy(src + sc + 2 * F, dst + sc * cD + 2 * F);
+                    }
+                    if (tailC)
+                    {
+                        // Fewer than QF channels remain: clear the whole block first so channels above p.srcC are zero,
+                        // then copy the remainder starting at offset 0 (first half, then the rest at offset DF).
+                        uint16_t* dstT = dst + sc * cD;
+                        Avx2::SetZero2(dstT);
+                        if (tailC >= DF)
+                        {
+                            Avx2::Copy(src + sc, dstT);
+                            if (tailC > DF)
+                                Avx2::Copy(src + sc + DF, tailC - DF, dstT + DF);
+                        }
+                        else
+                            Avx2::Copy(src + sc, tailC, dstT);
+                    }
+                    src += p.srcC;
+                    dst += sD;
+                }
+                if (p.padW)
+                {
+                    for (size_t s = 0; s < p.padW; ++s)
+                        for (size_t c = 0; c < a.srcC; c += a.microC)
+                            Avx2::SetZero2(dst + c * cD + s * sD);
+                    dst += p.padW * sD;
+                }
+            }
+            if (dyEnd == p.dstH)
+            {
+                for (size_t s = 0, n = p.padH * a.srcW; s < n; ++s)
+                    for (size_t c = 0; c < a.srcC; c += a.microC)
+                        Avx2::SetZero2(dst + c * cD + s * sD);
+                dst += p.padH * a.srcW * sD;
+            }
+        }
+
+        //-----------------------------------------------------------------------------------------
+
         template void Convolution16bNhwcDirect_2xM(const uint16_t* src0, const ConvParam& p, const AlgParam& a, size_t srcC, size_t dstC, int zero, const uint16_t* weight0, float* dst)
         {
             __m256 d00, d01, d10, d11, d20, d21, d30, d31, d40, d41, s0, w00, w01, w10, w11, m = _mm256_castsi256_ps(Bf16::MASK);
@@ -422,14 +576,24 @@ namespace Simd
 
         //-----------------------------------------------------------------------------------------
 
+//#define SIMD_CONV16B_COMPATIBLE_WITH_AMX
+
         SynetConvolution16bNhwcDirect::SynetConvolution16bNhwcDirect(const ConvParam & p)
             : Sse41::SynetConvolution16bNhwcDirect(p)
         {
+#ifdef SIMD_CONV16B_COMPATIBLE_WITH_AMX
+            SetAlgParam(F, F * 2, 5, F * 4, Base::AlgCacheL1(), Base::AlgCacheL2(), Base::AlgCacheL3());
+            if (_src16b)
+                _preprocess = Reorder16bNhwcDirect2;
+            else
+                _preprocess = Convert16bNhwcDirect2;
+#else
             SetAlgParam(F, F * 2, 5, F * 2, Base::AlgCacheL1(), Base::AlgCacheL2(), Base::AlgCacheL3());
             if (_src16b)
                 _preprocess = Reorder16bNhwcDirect;
             else
                 _preprocess = Convert16bNhwcDirect;
+#endif
             _convolution = Convolution16bNhwcDirect_2;
             switch (p.activation)
             {
diff --git a/src/Simd/SimdSet.h b/src/Simd/SimdSet.h
index 6a719b4337..ec416ce2ac 100644
--- a/src/Simd/SimdSet.h
+++ b/src/Simd/SimdSet.h
@@ -131,6 +131,12 @@ namespace Simd
         {
             _mm256_storeu_si256((__m256i*)dst, _mm256_setzero_si256());
         }
+
+        SIMD_INLINE void SetZero2(uint16_t* dst)
+        {
+            _mm256_storeu_si256((__m256i*)dst + 0, _mm256_setzero_si256());
+            _mm256_storeu_si256((__m256i*)dst + 1, _mm256_setzero_si256());
+        }
     }
 #endif
 
diff --git a/src/Test/TestSynetConvolution16b.cpp b/src/Test/TestSynetConvolution16b.cpp
index 52a5cd567a..050a799b36 100644
--- a/src/Test/TestSynetConvolution16b.cpp
+++ b/src/Test/TestSynetConvolution16b.cpp
@@ -129,7 +129,7 @@ namespace Test
         buf8u1.Extend({ ::SimdSynetConvolution16bExternalBufferSize(context1) });
         buf8u2.Extend({ ::SimdSynetConvolution16bExternalBufferSize(context2) });
         Fill(buf8u1, uint8_t(1));
-        Fill(buf8u2, uint8_t(2));
+        Fill(buf8u2, uint8_t(-1));
 
         ::SimdSynetConvolution16bSetParams(context1, weight.Data(), bias.Data(), params.Data());
         ::SimdSynetConvolution16bSetParams(context2, weight.Data(), bias.Data(), params.Data());
@@ -272,17 +272,13 @@ namespace Test
         result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 55, 15, 15, 55, _1, _1, _1, _0, _0, 1, aPr, tF, b16, b16), c, f1, f2);
 #endif
 #if 1
-        result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 608, 8, 8, 608, _1, _1, _1, _0, _0, 1, aPr, tT, b16, b16), c, f1, f2);
-        result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 608, 8, 8, 608, _1, _1, _1, _0, _0, 1, aPr, tT, f32, b16), c, f1, f2);
+        result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 48, 192, 256, 48, _3, _1, _1, _1, _1, 1, aRe, tT, f32, b16), c, f1, f2);
+        //result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 608, 8, 8, 608, _1, _1, _1, _0, _0, 1, aPr, tT, b16, b16), c, f1, f2);
+        //result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 608, 8, 8, 608, _1, _1, _1, _0, _0, 1, aPr, tT, f32, b16), c, f1, f2);
 #endif
 #else
-        result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 2156, 4, 4, 4, _1, _1, _1, _0, _0, 1, aId, tF, b16, b16), c, f1, f2);
-        //result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 64, 88, 88, 128, _3, _1, _2, _1, _1, 1, aRe, tT, f32, b16), c, f1, f2);
-        //result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 64, 88, 88, 128, _3, _1, _2, _1, _1, 1, aRe, tT, f32, b16), c, f1, f2);
-        //result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 64, 88, 88, 64, _3, _1, _1, _1, _1, 1, aSw, tT, b16, f32), c, f1, f2);
-        //result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 256, 44, 44, 256, _1, _1, _1, _0, _0, 1, aRe, tT, b16, f32), c, f1, f2);
-
+        result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 48, 192, 256, 48, _3, _1, _1, _1, _1, 1, aRe, tT, f32, b16), c, f1, f2);
 #endif
 
         return result;