diff options
Diffstat (limited to 'arm_compute/core/NEON/kernels/detail/NEActivationFunctionDetail.h')
-rw-r--r-- | arm_compute/core/NEON/kernels/detail/NEActivationFunctionDetail.h | 91 |
1 files changed, 91 insertions, 0 deletions
diff --git a/arm_compute/core/NEON/kernels/detail/NEActivationFunctionDetail.h b/arm_compute/core/NEON/kernels/detail/NEActivationFunctionDetail.h index 71d5a9eef7..9344235d09 100644 --- a/arm_compute/core/NEON/kernels/detail/NEActivationFunctionDetail.h +++ b/arm_compute/core/NEON/kernels/detail/NEActivationFunctionDetail.h @@ -54,6 +54,97 @@ struct dummy ARM_COMPUTE_UNUSED(vval); } }; +/** Linear activation object */ +template <typename T, int S> +struct linear +{ + /** NEON vector type. */ + using ExactType = typename wrapper::traits::neon_vector<T, S>::type; + /** NEON vector tag type. */ + using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type; + + /** Construct a Linear activation object. + * + * @param[in] act_info Activation layer information. + */ + explicit linear(ActivationLayerInfo act_info) + : valpha(wrapper::vdup_n(static_cast<T>(act_info.a()), ExactTagType{})), + vbeta(wrapper::vdup_n(static_cast<T>(act_info.b()), ExactTagType{})) + { + } + + /** Run activation function. + * + * @param[in] vval Vector of values. + */ + void operator()(ExactType &vval) + { + vval = wrapper::vmla(vval, valpha, vbeta); + } + + /** Vector of alphas. */ + const ExactType valpha; + /** Vector of betas. */ + const ExactType vbeta; +}; +/** Square activation object */ +template <typename T, int S> +struct square +{ + /** NEON vector type. */ + using ExactType = typename wrapper::traits::neon_vector<T, S>::type; + /** NEON vector tag type. */ + using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type; + + /** Construct a Square activation object. + * + * @param[in] act_info Activation layer information. + */ + explicit square(ActivationLayerInfo act_info) + { + ARM_COMPUTE_UNUSED(act_info); + } + + /** Run activation function. + * + * @param[in] vval Vector of values. + */ + void operator()(ExactType &vval) + { + vval = wrapper::vmul(vval, vval); + } +}; +/** Logistic activation object */ +template <typename T, int S> +struct logistic +{ + /** NEON vector type. */ + using ExactType = typename wrapper::traits::neon_vector<T, S>::type; + /** NEON vector tag type. */ + using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type; + + /** Construct a Logistic activation object. + * + * @param[in] act_info Activation layer information. + */ + explicit logistic(ActivationLayerInfo act_info) + : vone(wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{})) + { + ARM_COMPUTE_UNUSED(act_info); + } + + /** Run activation function. + * + * @param[in] vval Vector of values. + */ + void operator()(ExactType &vval) + { + vval = wrapper::vinv(wrapper::vadd(vone, wrapper::vexpq(wrapper::vnegq(vval)))); + } + + /** Vector of ones. */ + const ExactType vone; +}; /** RELU activation object */ template <typename T, int S> struct relu |