aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/cpu/kernels/activation/generic/neon/impl.h7
-rw-r--r--tests/datasets/ActivationFunctionsDataset.h2
2 files changed, 7 insertions, 2 deletions
diff --git a/src/cpu/kernels/activation/generic/neon/impl.h b/src/cpu/kernels/activation/generic/neon/impl.h
index 35abcb5408..4d4aa8d212 100644
--- a/src/cpu/kernels/activation/generic/neon/impl.h
+++ b/src/cpu/kernels/activation/generic/neon/impl.h
@@ -72,14 +72,15 @@ void fp_neon_activation_impl(const ITensor *src, ITensor *dst, const ActivationL
// In case of aarh64, we call vsqrt directly, so we don't use delta.
#ifndef __aarch64__
const auto delta = wrapper::vdup_n(static_cast<T>(P.delta), ExactTagType {});
+#else /* #ifndef __aarch64__ */
+ const auto const_inv_2 = wrapper::vdup_n(static_cast<T>(0.5f), ExactTagType {});
+ const auto const_inv_sqrt_2 = wrapper::vdup_n(static_cast<T>(0.70710678118f), ExactTagType{});
#endif /* __aarch64__ */
const auto const_1 = wrapper::vdup_n(static_cast<T>(1.f), ExactTagType {});
const auto const_0 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
const auto const_6 = wrapper::vdup_n(static_cast<T>(6.f), ExactTagType{});
const auto const_3 = wrapper::vdup_n(static_cast<T>(3.f), ExactTagType{});
- const auto const_inv_2 = wrapper::vdup_n(static_cast<T>(0.5f), ExactTagType{});
const auto const_inv_6 = wrapper::vdup_n(static_cast<T>(0.166666667f), ExactTagType{});
- const auto const_inv_sqrt_2 = wrapper::vdup_n(static_cast<T>(0.70710678118f), ExactTagType{});
constexpr float soft_relu_thresh = 12.f;
const auto vsoft_relu_thresh = wrapper::vdup_n(static_cast<T>(soft_relu_thresh), ExactTagType{});
const auto va = wrapper::vdup_n(static_cast<T>(act_info.a()), ExactTagType{});
@@ -148,9 +149,11 @@ void fp_neon_activation_impl(const ITensor *src, ITensor *dst, const ActivationL
case ActivationLayerInfo::ActivationFunction::HARD_SWISH:
tmp = wrapper::vmul(vin, wrapper::vmul(const_inv_6, wrapper::vmin(const_6, wrapper::vmax(const_0, wrapper::vadd(vin, const_3)))));
break;
+#ifdef __aarch64__
case ActivationLayerInfo::ActivationFunction::GELU:
tmp = wrapper::vmul(vin, wrapper::vmul(const_inv_2, wrapper::vadd(const_1, wrapper::verf(wrapper::vmul(vin, const_inv_sqrt_2)))));
break;
+#endif /* __aarch64__ */
default:
ARM_COMPUTE_ERROR("Unsupported activation function");
}
diff --git a/tests/datasets/ActivationFunctionsDataset.h b/tests/datasets/ActivationFunctionsDataset.h
index 85febd4df4..9b0d775376 100644
--- a/tests/datasets/ActivationFunctionsDataset.h
+++ b/tests/datasets/ActivationFunctionsDataset.h
@@ -54,7 +54,9 @@ public:
ActivationLayerInfo::ActivationFunction::SQUARE,
ActivationLayerInfo::ActivationFunction::TANH,
ActivationLayerInfo::ActivationFunction::IDENTITY,
+#ifdef __aarch64__
ActivationLayerInfo::ActivationFunction::GELU,
+#endif /* __aarch64__ */
})
{
}