diff options
-rw-r--r-- | src/core/cpu/kernels/activation/NEON/fp16.cpp | 41 | ||||
-rw-r--r-- | src/core/cpu/kernels/activation/NEON/fp32.cpp | 15 |
2 files changed, 31 insertions, 25 deletions
diff --git a/src/core/cpu/kernels/activation/NEON/fp16.cpp b/src/core/cpu/kernels/activation/NEON/fp16.cpp index 27ae2830cc..bd459e9e77 100644 --- a/src/core/cpu/kernels/activation/NEON/fp16.cpp +++ b/src/core/cpu/kernels/activation/NEON/fp16.cpp @@ -46,7 +46,7 @@ inline float16x8_t mask_float_vector(const float16x8_t &in, const uint16x8_t &ma auto int_in = vreinterpretq_u16_f16(in); return vreinterpretq_f16_u16(wrapper::vand(int_in, mask)); } -#endif /* __arch64__ */ +#endif /* __aarch64__ */ } // namespace void fp16_neon_activation(const ITensor *src, ITensor *dst, const ActivationLayerInfo &act_info, const Window &window) @@ -69,20 +69,23 @@ void fp16_neon_activation(const ITensor *src, ITensor *dst, const ActivationLaye // to prevent NAN values caused by zeros in inputs to SQRT. // In case of aarh64, we call vsqrt directly, so we don't use delta. #ifndef __aarch64__ - const auto delta = wrapper::vdup_n(static_cast<float16_t>((1e-7), ExactTagType {}); -#endif /* __aarch64 */ - - const auto const_1 = wrapper::vdup_n(static_cast<float16_t>(1.f), ExactTagType {}); - const auto const_0 = wrapper::vdup_n(static_cast<float16_t>(0.f), ExactTagType{}); - const auto const_6 = wrapper::vdup_n(static_cast<float16_t>(6.f), ExactTagType{}); - const auto const_3 = wrapper::vdup_n(static_cast<float16_t>(3.f), ExactTagType{}); - const auto const_inv_6 = wrapper::vdup_n(static_cast<float16_t>(0.166666667f), ExactTagType{}); - - const auto va = wrapper::vdup_n(static_cast<float16_t>(act_info.a()), ExactTagType{}); - const auto vb = wrapper::vdup_n(static_cast<float16_t>(act_info.b()), ExactTagType{}); - const auto a = static_cast<float16_t>(act_info.a()); - const auto b = static_cast<float16_t>(act_info.b()); - execute_window_loop(win_collapsed, [&](const Coordinates &) + const auto delta = wrapper::vdup_n(static_cast<float16_t>((1e-7), ExactTagType {})); +#endif /* __aarch64__ */ + + const auto const_1 = wrapper::vdup_n(static_cast<float16_t>(1.f), ExactTagType{}); + const auto const_0 = wrapper::vdup_n(static_cast<float16_t>(0.f), ExactTagType{}); + const auto const_6 = wrapper::vdup_n(static_cast<float16_t>(6.f), ExactTagType{}); + const auto const_3 = wrapper::vdup_n(static_cast<float16_t>(3.f), ExactTagType{}); + const auto const_inv_6 = wrapper::vdup_n(static_cast<float16_t>(0.166666667f), ExactTagType{}); + + constexpr float soft_relu_thresh = 12.f; + const auto vsoft_relu_thresh = wrapper::vdup_n(static_cast<float16_t>(soft_relu_thresh), ExactTagType{}); + + const auto va = wrapper::vdup_n(static_cast<float16_t>(act_info.a()), ExactTagType{}); + const auto vb = wrapper::vdup_n(static_cast<float16_t>(act_info.b()), ExactTagType{}); + const auto a = static_cast<float16_t>(act_info.a()); + const auto b = static_cast<float16_t>(act_info.b()); + execute_window_loop(win_collapsed, [&](const Coordinates &) { const auto input_ptr = reinterpret_cast<const float16_t *>(input.ptr()); const auto output_ptr = reinterpret_cast<float16_t *>(output.ptr()); @@ -118,7 +121,7 @@ void fp16_neon_activation(const ITensor *src, ITensor *dst, const ActivationLaye tmp = wrapper::vbsl(wrapper::vcgt(vin, const_0), vin, wrapper::vmul(va, vin)); break; case ActivationLayerInfo::ActivationFunction::SOFT_RELU: - tmp = wrapper::vlog(wrapper::vadd(const_1, wrapper::vexpq(vin))); + tmp = wrapper::vbsl(wrapper::vcgt(vin, vsoft_relu_thresh), vin, wrapper::vlog(wrapper::vadd(const_1, wrapper::vexpq(vin)))); break; case ActivationLayerInfo::ActivationFunction::ELU: tmp = wrapper::vbsl(wrapper::vcge(vin, const_0), vin, wrapper::vmul(va, wrapper::vsub(wrapper::vexpq(vin), const_1))); @@ -126,13 +129,13 @@ void fp16_neon_activation(const ITensor *src, ITensor *dst, const ActivationLaye case ActivationLayerInfo::ActivationFunction::SQRT: #ifdef __aarch64__ tmp = wrapper::vsqrt(vin); -#else /* aarch64 */ +#else /* __aarch64__ */ { const auto bitmask = wrapper::vceq(vin, wrapper::vdup_n(0, ExactTagType{})); tmp = wrapper::vinv(wrapper::vinvsqrt(wrapper::vadd(vin, mask_float_vector(delta, bitmask)))); tmp = mask_float_vector(tmp, wrapper::vnot(bitmask)); } -#endif /* aarch64 */ +#endif /* __aarch64__ */ break; case ActivationLayerInfo::ActivationFunction::SQUARE: tmp = wrapper::vmul(vin, vin); @@ -181,7 +184,7 @@ void fp16_neon_activation(const ITensor *src, ITensor *dst, const ActivationLaye tmp = (in > 0) ? in : a * in; break; case ActivationLayerInfo::ActivationFunction::SOFT_RELU: - tmp = std::log(static_cast<float16_t>(1) + std::exp(in)); + tmp = (in > soft_relu_thresh) ? in : std::log(static_cast<float16_t>(1) + std::exp(in)); break; case ActivationLayerInfo::ActivationFunction::ELU: tmp = (in >= 0) ? in : a * (std::exp(in) - 1); diff --git a/src/core/cpu/kernels/activation/NEON/fp32.cpp b/src/core/cpu/kernels/activation/NEON/fp32.cpp index 0687646be7..c76035b5d2 100644 --- a/src/core/cpu/kernels/activation/NEON/fp32.cpp +++ b/src/core/cpu/kernels/activation/NEON/fp32.cpp @@ -44,7 +44,7 @@ inline float32x4_t mask_float_vector(const float32x4_t &in, const uint32x4_t &ma auto int_in = vreinterpretq_u32_f32(in); return vreinterpretq_f32_u32(wrapper::vand(int_in, mask)); } -#endif /* __arch64__ */ +#endif /* __aarch64__ */ } // namespace void fp32_neon_activation(const ITensor *src, ITensor *dst, const ActivationLayerInfo &act_info, const Window &window) @@ -68,13 +68,16 @@ void fp32_neon_activation(const ITensor *src, ITensor *dst, const ActivationLaye // In case of aarh64, we call vsqrt directly, so we don't use delta. #ifndef __aarch64__ const auto delta = wrapper::vdup_n(static_cast<float>(1e-24), ExactTagType {}); -#endif /* __aarch64 */ +#endif /* __aarch64__ */ const auto const_1 = wrapper::vdup_n(static_cast<float>(1.f), ExactTagType {}); const auto const_0 = wrapper::vdup_n(static_cast<float>(0.f), ExactTagType{}); const auto const_6 = wrapper::vdup_n(static_cast<float>(6.f), ExactTagType{}); const auto const_3 = wrapper::vdup_n(static_cast<float>(3.f), ExactTagType{}); const auto const_inv_6 = wrapper::vdup_n(static_cast<float>(0.166666667f), ExactTagType{}); + constexpr float soft_relu_thresh = 12.f; + const auto vsoft_relu_thresh = wrapper::vdup_n(static_cast<float>(soft_relu_thresh), ExactTagType{}); + const auto va = wrapper::vdup_n(static_cast<float>(act_info.a()), ExactTagType{}); const auto vb = wrapper::vdup_n(static_cast<float>(act_info.b()), ExactTagType{}); const auto a = static_cast<float>(act_info.a()); @@ -115,7 +118,7 @@ void fp32_neon_activation(const ITensor *src, ITensor *dst, const ActivationLaye tmp = wrapper::vbsl(wrapper::vcgt(vin, const_0), vin, wrapper::vmul(va, vin)); break; case ActivationLayerInfo::ActivationFunction::SOFT_RELU: - tmp = wrapper::vlog(wrapper::vadd(const_1, wrapper::vexpq(vin))); + tmp = wrapper::vbsl(wrapper::vcgt(vin, vsoft_relu_thresh), vin, wrapper::vlog(wrapper::vadd(const_1, wrapper::vexpq(vin)))); break; case ActivationLayerInfo::ActivationFunction::ELU: tmp = wrapper::vbsl(wrapper::vcge(vin, const_0), vin, wrapper::vmul(va, wrapper::vsub(wrapper::vexpq(vin), const_1))); @@ -123,13 +126,13 @@ void fp32_neon_activation(const ITensor *src, ITensor *dst, const ActivationLaye case ActivationLayerInfo::ActivationFunction::SQRT: #ifdef __aarch64__ tmp = wrapper::vsqrt(vin); -#else /* aarch64 */ +#else /* __aarch64__ */ { const auto bitmask = wrapper::vceq(vin, wrapper::vdup_n(0.f, ExactTagType{})); tmp = wrapper::vinv(wrapper::vinvsqrt(wrapper::vadd(vin, mask_float_vector(delta, bitmask)))); tmp = mask_float_vector(tmp, wrapper::vnot(bitmask)); } -#endif /* aarch64 */ +#endif /* __aarch64__ */ break; case ActivationLayerInfo::ActivationFunction::SQUARE: tmp = wrapper::vmul(vin, vin); @@ -178,7 +181,7 @@ void fp32_neon_activation(const ITensor *src, ITensor *dst, const ActivationLaye tmp = (in > 0) ? in : a * in; break; case ActivationLayerInfo::ActivationFunction::SOFT_RELU: - tmp = std::log(static_cast<float>(1) + std::exp(in)); + tmp = (in > soft_relu_thresh) ? in : std::log(static_cast<float>(1) + std::exp(in)); break; case ActivationLayerInfo::ActivationFunction::ELU: tmp = (in >= 0) ? in : a * (std::exp(in) - 1); |