diff options
Diffstat (limited to 'src/cpu/kernels')
-rw-r--r-- | src/cpu/kernels/activation/generic/neon/impl.h | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/src/cpu/kernels/activation/generic/neon/impl.h b/src/cpu/kernels/activation/generic/neon/impl.h index 35abcb5408..4d4aa8d212 100644 --- a/src/cpu/kernels/activation/generic/neon/impl.h +++ b/src/cpu/kernels/activation/generic/neon/impl.h @@ -72,14 +72,15 @@ void fp_neon_activation_impl(const ITensor *src, ITensor *dst, const ActivationL // In case of aarh64, we call vsqrt directly, so we don't use delta. #ifndef __aarch64__ const auto delta = wrapper::vdup_n(static_cast<T>(P.delta), ExactTagType {}); +#else /* #ifndef __aarch64__ */ + const auto const_inv_2 = wrapper::vdup_n(static_cast<T>(0.5f), ExactTagType {}); + const auto const_inv_sqrt_2 = wrapper::vdup_n(static_cast<T>(0.70710678118f), ExactTagType{}); #endif /* __aarch64__ */ const auto const_1 = wrapper::vdup_n(static_cast<T>(1.f), ExactTagType {}); const auto const_0 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{}); const auto const_6 = wrapper::vdup_n(static_cast<T>(6.f), ExactTagType{}); const auto const_3 = wrapper::vdup_n(static_cast<T>(3.f), ExactTagType{}); - const auto const_inv_2 = wrapper::vdup_n(static_cast<T>(0.5f), ExactTagType{}); const auto const_inv_6 = wrapper::vdup_n(static_cast<T>(0.166666667f), ExactTagType{}); - const auto const_inv_sqrt_2 = wrapper::vdup_n(static_cast<T>(0.70710678118f), ExactTagType{}); constexpr float soft_relu_thresh = 12.f; const auto vsoft_relu_thresh = wrapper::vdup_n(static_cast<T>(soft_relu_thresh), ExactTagType{}); const auto va = wrapper::vdup_n(static_cast<T>(act_info.a()), ExactTagType{}); @@ -148,9 +149,11 @@ void fp_neon_activation_impl(const ITensor *src, ITensor *dst, const ActivationL case ActivationLayerInfo::ActivationFunction::HARD_SWISH: tmp = wrapper::vmul(vin, wrapper::vmul(const_inv_6, wrapper::vmin(const_6, wrapper::vmax(const_0, wrapper::vadd(vin, const_3))))); break; +#ifdef __aarch64__ case ActivationLayerInfo::ActivationFunction::GELU: tmp = wrapper::vmul(vin, wrapper::vmul(const_inv_2, wrapper::vadd(const_1, wrapper::verf(wrapper::vmul(vin, const_inv_sqrt_2))))); break; +#endif /* __aarch64__ */ default: ARM_COMPUTE_ERROR("Unsupported activation function"); } |