Skip to content

Commit

Permalink
opt++
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Sep 3, 2024
1 parent e538796 commit 464da66
Show file tree
Hide file tree
Showing 2 changed files with 430 additions and 4 deletions.
186 changes: 183 additions & 3 deletions src/layer/arm/rmsnorm_arm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ RMSNorm_arm::RMSNorm_arm()
#if __ARM_NEON
support_packing = true;
#if NCNN_ARM82
// support_fp16_storage = cpu_support_arm_asimdhp();
support_fp16_storage = cpu_support_arm_asimdhp();
#endif
#endif // __ARM_NEON

#if NCNN_BF16
// support_bf16_storage = true;
support_bf16_storage = true;
#endif
}

Expand Down Expand Up @@ -72,7 +72,7 @@ static void rmsnorm(float* ptr, const float* gamma_ptr, float eps, int elemcount
float32x4_t _eps = vdupq_n_f32(eps);

#if __aarch64__
_sqsum = vdivq_f32(_sqsum, vdupq_n_f32(elemcount));
_sqsum = vdivq_f32(_sqsum, _elemcount);
_sqsum = vaddq_f32(_sqsum, _eps);
#else
float32x4_t _inv_elemcount = vrecpeq_f32(_elemcount);
Expand Down Expand Up @@ -232,8 +232,188 @@ int RMSNorm_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
}

#if NCNN_BF16
static void rmsnorm_bf16s(unsigned short* ptr, const float* gamma_ptr, float eps, int elemcount, int elempack)
{
const int size = elemcount * elempack;

#if __ARM_NEON
float32x4_t _sqsum = vdupq_n_f32(0.f);
#endif // __ARM_NEON
float sqsum = 0.f;
{
const unsigned short* ptr0 = ptr;

int i = 0;
#if __ARM_NEON
for (; i + 3 < size; i += 4)
{
float32x4_t _p = bfloat2float(vld1_u16(ptr0));
_sqsum = vmlaq_f32(_sqsum, _p, _p);
ptr0 += 4;
}
#endif // __ARM_NEON
for (; i < size; i++)
{
float v = bfloat16_to_float32(ptr0[0]);
sqsum += v * v;
ptr0++;
}
}

#if __ARM_NEON
float32x4_t _a;
if (elempack == 4)
{
float32x4_t _elemcount = vdupq_n_f32(elemcount);
float32x4_t _eps = vdupq_n_f32(eps);

#if __aarch64__
_sqsum = vdivq_f32(_sqsum, _elemcount);
_sqsum = vaddq_f32(_sqsum, _eps);
#else
float32x4_t _inv_elemcount = vrecpeq_f32(_elemcount);
_inv_elemcount = vmulq_f32(vrecpsq_f32(_elemcount, _inv_elemcount), _inv_elemcount);
_inv_elemcount = vmulq_f32(vrecpsq_f32(_elemcount, _inv_elemcount), _inv_elemcount);
_sqsum = vmlaq_f32(_eps, _sqsum, _inv_elemcount);
#endif

_a = vrsqrteq_f32(_sqsum);
_a = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_sqsum, _a), _a), _a);
_a = vmulq_f32(vrsqrtsq_f32(vmulq_f32(_sqsum, _a), _a), _a);
}
#endif // __ARM_NEON

float a;
if (elempack == 1)
{
#if __aarch64__
sqsum += vaddvq_f32(_sqsum);
#else
float32x2_t _s2 = vadd_f32(vget_low_f32(_sqsum), vget_high_f32(_sqsum));
_s2 = vpadd_f32(_s2, _s2);
sqsum += vget_lane_f32(_s2, 0);
#endif

a = 1.f / sqrtf(sqsum / elemcount + eps);
#if __ARM_NEON
_a = vdupq_n_f32(a);
#endif // __ARM_NEON
}

if (gamma_ptr)
{
int i = 0;
#if __ARM_NEON
if (elempack == 4)
{
for (; i + 3 < size; i += 4)
{
float32x4_t _p = bfloat2float(vld1_u16(ptr));
float32x4_t _gamma = vdupq_n_f32(gamma_ptr[0]);
_p = vmulq_f32(_p, _a);
_p = vmulq_f32(_p, _gamma);
vst1_u16(ptr, float2bfloat(_p));
ptr += 4;
gamma_ptr += 1;
}
}

if (elempack == 1)
{
for (; i + 3 < size; i += 4)
{
float32x4_t _p = bfloat2float(vld1_u16(ptr));
float32x4_t _gamma = vld1q_f32(gamma_ptr);
_p = vmulq_f32(_p, _a);
_p = vmulq_f32(_p, _gamma);
vst1_u16(ptr, float2bfloat(_p));
ptr += 4;
gamma_ptr += 4;
}
}
#endif // __ARM_NEON
for (; i < size; i++)
{
float v = bfloat16_to_float32(ptr[0]);
ptr[0] = float32_to_bfloat16((v * a) * gamma_ptr[0]);
ptr++;
gamma_ptr++;
}
}
else
{
int i = 0;
#if __ARM_NEON
for (; i + 3 < size; i += 4)
{
float32x4_t _p = bfloat2float(vld1_u16(ptr));
_p = vmulq_f32(_p, _a);
vst1_u16(ptr, float2bfloat(_p));
ptr += 4;
}
#endif // __ARM_NEON
for (; i < size; i++)
{
float v = bfloat16_to_float32(ptr[0]);
ptr[0] = float32_to_bfloat16(v * a);
ptr++;
}
}
}

int RMSNorm_arm::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const
{
const int dims = bottom_top_blob.dims;
const int w = bottom_top_blob.w;
const int h = bottom_top_blob.h;
const int channels = bottom_top_blob.c;
const int elempack = bottom_top_blob.elempack;

if (dims == 1)
{
// assert affine_size == w

unsigned short* ptr = bottom_top_blob;
rmsnorm_bf16s(ptr, gamma_data, eps, w * elempack, 1);
}

if (dims == 2)
{
// assert affine_size == w

#pragma omp parallel for num_threads(opt.num_threads)
for (int i = 0; i < h; i++)
{
unsigned short* ptr = bottom_top_blob.row<unsigned short>(i);
rmsnorm_bf16s(ptr, gamma_data, eps, w, elempack);
}
}

if (dims == 3)
{
if (affine_size == w)
{
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < channels; q++)
{
for (int i = 0; i < h; i++)
{
unsigned short* ptr = bottom_top_blob.channel(q).row<unsigned short>(i);
rmsnorm_bf16s(ptr, gamma_data, eps, w, elempack);
}
}
}
else // if (affine_size == w * h)
{
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < channels; q++)
{
unsigned short* ptr = bottom_top_blob.channel(q);
rmsnorm_bf16s(ptr, gamma_data, eps, w * h, elempack);
}
}
}

return 0;
}
#endif // NCNN_BF16
Expand Down
Loading

0 comments on commit 464da66

Please sign in to comment.