diff --git a/src/layer/arm/neon_mathfun_fp16s.h b/src/layer/arm/neon_mathfun_fp16s.h index 2ad7e214470..4738052180c 100644 --- a/src/layer/arm/neon_mathfun_fp16s.h +++ b/src/layer/arm/neon_mathfun_fp16s.h @@ -197,13 +197,7 @@ static inline float16x4_t exp_ps_f16(float16x4_t x) #endif /* perform a floorf */ - tmp = vcvt_f16_s16(vcvt_s16_f16(fx)); - - /* if greater, substract 1 */ - uint16x4_t mask = vcgt_f16(tmp, fx); - mask = vand_u16(mask, (uint16x4_t)(one)); - - fx = vsub_f16(tmp, (float16x4_t)(mask)); + fx = vrndm_f16(fx); #if defined(_MSC_VER) && !defined(__clang__) tmp = vmul_f16(fx, vcvt_f16_f32(vdupq_n_f32(c_cephes_exp_C1))); @@ -255,13 +249,7 @@ static inline float16x8_t exp_ps_f16(float16x8_t x) #endif /* perform a floorf */ - tmp = vcvtq_f16_s16(vcvtq_s16_f16(fx)); - - /* if greater, substract 1 */ - uint16x8_t mask = vcgtq_f16(tmp, fx); - mask = vandq_u16(mask, vreinterpretq_u16_f16(one)); - - fx = vsubq_f16(tmp, vreinterpretq_f16_u16(mask)); + fx = vrndmq_f16(fx); #if defined(_MSC_VER) && !defined(__clang__) float16x4_t _c_cephes_exp_C1 = vcvt_f16_f32(vdupq_n_f32(c_cephes_exp_C1));