Skip to content

Commit

Permalink
*fix bugs in Base implementation of class SynetInnerProduct16bRef.
Browse files Browse the repository at this point in the history
  • Loading branch information
ermig1979 committed Jun 5, 2024
1 parent 9fbca53 commit 2521253
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 9 deletions.
5 changes: 3 additions & 2 deletions src/Simd/SimdBaseSynetInnerProduct16b.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace Simd
{
_sizeA = p.typeA == SimdTensorData32f ? p.M * p.K : 0;
_sizeB = (p.typeB == SimdTensorData32f && !p.constB) ? p.K * p.N : 0;
_sizeC = p.typeB == SimdTensorData16b ? p.M * p.N : 0;
_sizeC = p.typeC == SimdTensorData16b ? p.M * p.N : 0;
}

String SynetInnerProduct16bRef::Desc() const
Expand Down Expand Up @@ -101,7 +101,8 @@ namespace Simd
{
const uint16_t* pA = A + i * p.K;
const uint16_t* pB = B + j * p.K;
for (size_t k = 0; k < p.K; ++k)
pC[j] = 0;
for (size_t k = 0; k < p.K; ++k)
pC[j] += BFloat16ToFloat32(pA[k]) * BFloat16ToFloat32(pB[k]);
}
}
Expand Down
5 changes: 3 additions & 2 deletions src/Simd/SimdSynetInnerProduct16b.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,9 @@ namespace Simd
// SynetInnerProduct16bGemm(const InnerProductParam16b& p);
// virtual String Ext() const { return "Base"; }
// virtual String Desc() const;
// virtual void SetParams(const float* weight, SimdBool* internal, const float* bias);
// virtual void Forward(const uint8_t* A, const uint8_t* B, const uint8_t* bias, uint8_t* C);
// virtual size_t ExternalBufferSize() const;
// virtual void SetParams(const float* weight, const float* bias);
// virtual void Forward(const uint8_t* A, const uint8_t* B, uint8_t* buf, uint8_t* C);

//protected:
// //typedef void(*GemmPtr)(size_t M, size_t N, size_t K, const float* alpha, const float* A, size_t lda, const float* B, size_t ldb, const float* beta, float* C, size_t ldc);
Expand Down
13 changes: 8 additions & 5 deletions src/Test/TestSynetInnerProduct16b.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,14 @@ namespace Test
}
result = result && Compare(C1f, C2f, eps, true, 64, DifferenceBoth);

if(0)
if(1)
{
void* context3 = SimdSynetInnerProduct32fInit(p.M, p.K, p.N, p.transB, SimdConvolutionActivationIdentity);
::SimdSynetInnerProduct32fSetParams(context3, Bf.Data(), NULL, bias.Data(), NULL);
::SimdSynetInnerProduct32fSetParams(context3, Bf.Data(), NULL, p.bias ? bias.Data() : NULL, NULL);
::SimdSynetInnerProduct32fForward(context3, Af.Data(), C3f.Data());
::SimdRelease(context3);

result = result && Compare(C1f, C3f, 0.03, true, 64, DifferenceBoth, " Compare to SynetInnerProduct32f.");//0.129
result = result && Compare(C1f, C3f, 0.033, true, 64, DifferenceBoth, " Compare to SynetInnerProduct32f.");//0.129
}

return result;
Expand All @@ -141,10 +141,13 @@ namespace Test

#if defined(NDEBUG)
#if 1
result = result && SynetInnerProduct16bForwardAutoTest(eps, Param(128, 128, 128, f32, f32, f32, f, t, f), f1, f2);
result = result && SynetInnerProduct16bForwardAutoTest(eps, Param(128, 128, 128, f32, f32, b16, f, f, t), f1, f2);
result = result && SynetInnerProduct16bForwardAutoTest(eps, Param(128, 128, 128, b16, b16, f32, f, t, f), f1, f2);
result = result && SynetInnerProduct16bForwardAutoTest(eps, Param(128, 128, 128, f32, f32, b16, t, f, t), f1, f2);
result = result && SynetInnerProduct16bForwardAutoTest(eps, Param(128, 128, 128, b16, b16, f32, t, t, f), f1, f2);
#endif
#else
result = result && SynetInnerProduct16bForwardAutoTest(eps, Param(128, 128, 128, f32, f32, f32, f, t, f), f1, f2);
result = result && SynetInnerProduct16bForwardAutoTest(eps, Param(128, 128, 128, f32, f32, b16, f, f, f), f1, f2);
#endif

return result;
Expand Down

0 comments on commit 2521253

Please sign in to comment.