From ccc65d44a53eaa61c718cbc4d826c811e2ccebda Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Tue, 27 Jun 2017 17:39:11 +0100 Subject: COMPMID-427: Port NEActivationLayer in 16bit fixed point. Change-Id: Iebd61807f7b597c6bd990673bc7655c68ee16f4b Reviewed-on: http://mpd-gerrit.cambridge.arm.com/79085 Reviewed-by: Moritz Pflanzer Tested-by: Kaizen Reviewed-by: Gian Marco Iodice --- arm_compute/core/NEON/NEFixedPoint.h | 32 +++++++----- arm_compute/core/NEON/NEFixedPoint.inl | 61 ++++++++++++---------- .../core/NEON/kernels/NEActivationLayerKernel.h | 8 ++- 3 files changed, 60 insertions(+), 41 deletions(-) (limited to 'arm_compute/core/NEON') diff --git a/arm_compute/core/NEON/NEFixedPoint.h b/arm_compute/core/NEON/NEFixedPoint.h index e3eb5d4638..e30509cd0a 100644 --- a/arm_compute/core/NEON/NEFixedPoint.h +++ b/arm_compute/core/NEON/NEFixedPoint.h @@ -176,6 +176,14 @@ void vst1q_qs8(qint8_t *addr, qint8x16_t b); */ void vst1q_qs16(qint16_t *addr, qint16x8_t b); +/** Store two 16 bit fixed point vector to memory (8x2 elements) +* +* @param[in] addr Memory address where the 16 bit fixed point vectors should be stored +* @param[in] b 16 bit fixed point vectors to store +* +*/ +void vst2q_qs16(qint16_t *addr, qint16x8x2_t b); + /** 16 bit fixed point vector saturating narrow (8 elements) * * @param[in] a 16 bit fixed point vector to convert @@ -1122,7 +1130,7 @@ qint16x8_t vqinvsqrtq_qs16(qint16x8_t a, int fixed_point_position); * * @return The calculated Hyperbolic Tangent. */ -qint8x8_t vtanh_qs8(qint8x8_t a, int fixed_point_position); +qint8x8_t vqtanh_qs8(qint8x8_t a, int fixed_point_position); /** Calculate hyperbolic tangent for fixed point 16 bit (4 elements) * @@ -1131,7 +1139,7 @@ qint8x8_t vtanh_qs8(qint8x8_t a, int fixed_point_position); * * @return The calculated Hyperbolic Tangent. */ -qint16x4_t vtanh_qs16(qint16x4_t a, int fixed_point_position); +qint16x4_t vqtanh_qs16(qint16x4_t a, int fixed_point_position); /** Calculate hyperbolic tangent for fixed point 8bit (16 elements) * @@ -1140,7 +1148,16 @@ qint16x4_t vtanh_qs16(qint16x4_t a, int fixed_point_position); * * @return The calculated Hyperbolic Tangent. */ -qint8x16_t vtanhq_qs8(qint8x16_t a, int fixed_point_position); +qint8x16_t vqtanhq_qs8(qint8x16_t a, int fixed_point_position); + +/** Calculate hyperbolic tangent for fixed point 16bit (8 elements) + * + * @param[in] a 16 bit fixed point input vector + * @param[in] fixed_point_position Fixed point position that expresses the number of bits for the fractional part of the number + * + * @return The calculated Hyperbolic Tangent. + */ +qint16x8_t vqtanhq_qs16(qint16x8_t a, int fixed_point_position); /** Calculate saturating n power for fixed point 8bit (16 elements). * @@ -1162,15 +1179,6 @@ qint8x8_t vqpowq_qs8(qint8x8_t a, qint8x16_t b, int fixed_point_position); * @return The lane-by-lane maximum -> float32x4x2 */ float32x4x2_t vmax2q_f32(float32x4x2_t a, float32x4x2_t b); - -/** Calculate hyperbolic tangent for fixed point 8bit (8 elements) - * - * @param[in] a 16 bit fixed point input vector - * @param[in] fixed_point_position Fixed point position that expresses the number of bits for the fractional part of the number - * - * @return The calculated Hyperbolic Tangent. - */ -qint16x8_t vtanhq_qs16(qint16x8_t a, int fixed_point_position); } #include "arm_compute/core/NEON/NEFixedPoint.inl" #endif /* __ARM_COMPUTE_NEFIXEDPOINT_H__ */ diff --git a/arm_compute/core/NEON/NEFixedPoint.inl b/arm_compute/core/NEON/NEFixedPoint.inl index 92af82cf71..b241dd5069 100644 --- a/arm_compute/core/NEON/NEFixedPoint.inl +++ b/arm_compute/core/NEON/NEFixedPoint.inl @@ -200,6 +200,11 @@ inline void vst1q_qs16(qint16_t *addr, qint16x8_t b) vst1q_s16(addr, b); } +inline void vst2q_qs16(qint16_t *addr, qint16x8x2_t b) +{ + vst2q_s16(addr, b); +} + inline qint8x8_t vqmovn_qs16(qint16x8_t a) { return vqmovn_s16(a); @@ -1641,15 +1646,15 @@ inline qint8x8_t vqinvsqrt_qs8(qint8x8_t a, int fixed_point_position) const qint8x8_t const_three = vdup_n_s8(3 << fixed_point_position); // Find shift value. Number must be in (0.5, 2) range. - qint8x8_t shift_value = vneg_s8(vqsub_s8(vdup_n_s8(8), vqadd_s8(vclz_s8(a), vdup_n_s8(fixed_point_position)))); + qint8x8_t shift_value = vqneg_s8(vqsub_s8(vdup_n_s8(8), vqadd_s8(vclz_s8(a), vdup_n_s8(fixed_point_position)))); // Add one when the shift value is negative in order to get the correct result when we shift right with 1 qint8x8_t temp = vqsub_s8(vdup_n_s8(8), vqadd_s8(vclz_s8(a), vdup_n_s8(fixed_point_position))); uint8x8_t temp_ltz = vclt_s8(temp, vdup_n_qs8(0)); temp = vbsl_s8(temp_ltz, vqadd_s8(temp, vdup_n_s8(1)), temp); - qint8x8_t shift_value2 = vneg_s8(vshr_n_s8(temp, 1)); + qint8x8_t shift_value2 = vqneg_s8(vshr_n_s8(temp, 1)); - temp = vshl_s8(a, shift_value); + temp = vqshl_s8(a, shift_value); // Initial guess qint8x8_t x = temp; @@ -1660,7 +1665,7 @@ inline qint8x8_t vqinvsqrt_qs8(qint8x8_t a, int fixed_point_position) x = vshr_n_s8(vqmul_qs8(x, vqsub_s8(const_three, vqmul_qs8(temp, vqmul_qs8(x, x, fixed_point_position), fixed_point_position)), fixed_point_position), 1); x = vshr_n_s8(vqmul_qs8(x, vqsub_s8(const_three, vqmul_qs8(temp, vqmul_qs8(x, x, fixed_point_position), fixed_point_position)), fixed_point_position), 1); - return vshl_s8(x, shift_value2); + return vqshl_s8(x, shift_value2); } inline qint16x4_t vqinvsqrt_qs16(qint16x4_t a, int fixed_point_position) @@ -1668,15 +1673,15 @@ inline qint16x4_t vqinvsqrt_qs16(qint16x4_t a, int fixed_point_position) const qint16x4_t const_three = vdup_n_s16(3 << fixed_point_position); // Find shift value. Number must be in (0.5, 2) range. - qint16x4_t shift_value = vneg_s16(vqsub_s16(vdup_n_s16(16), vqadd_s16(vclz_s16(a), vdup_n_s16(fixed_point_position)))); + qint16x4_t shift_value = vqneg_s16(vqsub_s16(vdup_n_s16(16), vqadd_s16(vclz_s16(a), vdup_n_s16(fixed_point_position)))); // Add one when the shift value is negative in order to get the correct result when we shift right with 1 qint16x4_t temp = vqsub_s16(vdup_n_s16(16), vqadd_s16(vclz_s16(a), vdup_n_s16(fixed_point_position))); uint16x4_t temp_ltz = vclt_s16(temp, vdup_n_qs16(0)); temp = vbsl_s16(temp_ltz, vqadd_s16(temp, vdup_n_s16(1)), temp); - qint16x4_t shift_value2 = vneg_s16(vshr_n_s16(temp, 1)); + qint16x4_t shift_value2 = vqneg_s16(vshr_n_s16(temp, 1)); - temp = vshl_s16(a, shift_value); + temp = vqshl_s16(a, shift_value); // Initial guess qint16x4_t x = temp; @@ -1753,15 +1758,15 @@ inline qint8x16_t vqinvsqrtq_qs8(qint8x16_t a, int fixed_point_position) const qint8x16_t const_three = vdupq_n_s8(3 << fixed_point_position); // Find shift value. Number must be in (0.5, 2) range. - qint8x16_t shift_value = vnegq_s8(vqsubq_s8(vdupq_n_s8(8), vqaddq_s8(vclzq_s8(a), vdupq_n_s8(fixed_point_position)))); + qint8x16_t shift_value = vqnegq_s8(vqsubq_s8(vdupq_n_s8(8), vqaddq_s8(vclzq_s8(a), vdupq_n_s8(fixed_point_position)))); // Add one when the shift value is negative in order to get the correct result when we shift right with 1 qint8x16_t temp = vqsubq_s8(vdupq_n_s8(8), vqaddq_s8(vclzq_s8(a), vdupq_n_s8(fixed_point_position))); uint8x16_t temp_ltz = vcltq_s8(temp, vdupq_n_qs8(0)); temp = vbslq_s8(temp_ltz, vqaddq_s8(temp, vdupq_n_s8(1)), temp); - qint8x16_t shift_value2 = vnegq_s8(vshrq_n_s8(temp, 1)); + qint8x16_t shift_value2 = vqnegq_s8(vshrq_n_s8(temp, 1)); - temp = vshlq_s8(a, shift_value); + temp = vqshlq_s8(a, shift_value); // Initial guess qint8x16_t x = temp; @@ -1780,13 +1785,13 @@ inline qint16x8_t vqinvsqrtq_qs16(qint16x8_t a, int fixed_point_position) const qint16x8_t const_three = vdupq_n_s16(3 << fixed_point_position); // Find shift value. Number must be in (0.5, 2) range. - qint16x8_t shift_value = vnegq_s16(vqsubq_s16(vdupq_n_s16(16), vqaddq_s16(vclzq_s16(a), vdupq_n_s16(fixed_point_position)))); + qint16x8_t shift_value = vqnegq_s16(vqsubq_s16(vdupq_n_s16(16), vqaddq_s16(vclzq_s16(a), vdupq_n_s16(fixed_point_position)))); // Add one when the shift value is negative in order to get the correct result when we shift right with 1 qint16x8_t temp = vqsubq_s16(vdupq_n_s16(16), vqaddq_s16(vclzq_s16(a), vdupq_n_s16(fixed_point_position))); uint16x8_t temp_ltz = vcltq_s16(temp, vdupq_n_qs16(0)); temp = vbslq_s16(temp_ltz, vqaddq_s16(temp, vdupq_n_s16(1)), temp); - qint16x8_t shift_value2 = vnegq_s16(vshrq_n_s16(temp, 1)); + qint16x8_t shift_value2 = vqnegq_s16(vshrq_n_s16(temp, 1)); temp = vqshlq_s16(a, shift_value); @@ -1804,7 +1809,7 @@ inline qint16x8_t vqinvsqrtq_qs16(qint16x8_t a, int fixed_point_position) return vqshlq_s16(x, shift_value2); } -inline qint8x8_t vtanh_qs8(qint8x8_t a, int fixed_point_position) +inline qint8x8_t vqtanh_qs8(qint8x8_t a, int fixed_point_position) { const qint8x8_t const_one = vdup_n_s8(1 << fixed_point_position); const qint8x8_t const_two = vdup_n_s8(2 << fixed_point_position); @@ -1817,7 +1822,7 @@ inline qint8x8_t vtanh_qs8(qint8x8_t a, int fixed_point_position) return tanh; } -inline qint16x4_t vtanh_qs16(qint16x4_t a, int fixed_point_position) +inline qint16x4_t vqtanh_qs16(qint16x4_t a, int fixed_point_position) { const qint16x4_t const_one = vdup_n_s16(1 << fixed_point_position); const qint16x4_t const_two = vdup_n_s16(2 << fixed_point_position); @@ -1830,7 +1835,7 @@ inline qint16x4_t vtanh_qs16(qint16x4_t a, int fixed_point_position) return tanh; } -inline qint8x16_t vtanhq_qs8(qint8x16_t a, int fixed_point_position) +inline qint8x16_t vqtanhq_qs8(qint8x16_t a, int fixed_point_position) { const qint8x16_t const_one = vdupq_n_s8(1 << fixed_point_position); const qint8x16_t const_two = vdupq_n_s8(2 << fixed_point_position); @@ -1843,6 +1848,19 @@ inline qint8x16_t vtanhq_qs8(qint8x16_t a, int fixed_point_position) return tanh; } +inline qint16x8_t vqtanhq_qs16(qint16x8_t a, int fixed_point_position) +{ + const qint16x8_t const_one = vdupq_n_s16(1 << fixed_point_position); + const qint16x8_t const_two = vdupq_n_s16(2 << fixed_point_position); + + qint16x8_t exp2x = vqexpq_qs16(vqmulq_qs16(const_two, a, fixed_point_position), fixed_point_position); + qint16x8_t num = vqsubq_qs16(exp2x, const_one); + qint16x8_t den = vqaddq_qs16(exp2x, const_one); + qint16x8_t tanh = vqmulq_qs16(num, vqrecipq_qs16(den, fixed_point_position), fixed_point_position); + + return tanh; +} + inline qint8x16_t vqpowq_qs8(qint8x16_t a, qint8x16_t b, int fixed_point_position) { return vqexpq_qs8(vqmulq_qs8(b, vlogq_qs8(a, fixed_point_position), fixed_point_position), fixed_point_position); @@ -1859,17 +1877,4 @@ inline float32x4x2_t vmax2q_f32(float32x4x2_t a, float32x4x2_t b) }; return res; } - -inline qint16x8_t vtanhq_qs16(qint16x8_t a, int fixed_point_position) -{ - const qint16x8_t const_one = vdupq_n_s16(1 << fixed_point_position); - const qint16x8_t const_two = vdupq_n_s16(2 << fixed_point_position); - - qint16x8_t exp2x = vqexpq_qs16(vqmulq_qs16(const_two, a, fixed_point_position), fixed_point_position); - qint16x8_t num = vqsubq_qs16(exp2x, const_one); - qint16x8_t den = vqaddq_qs16(exp2x, const_one); - qint16x8_t tanh = vqmulq_qs16(num, vqrecipq_qs16(den, fixed_point_position), fixed_point_position); - - return tanh; -} } diff --git a/arm_compute/core/NEON/kernels/NEActivationLayerKernel.h b/arm_compute/core/NEON/kernels/NEActivationLayerKernel.h index 539bca587a..e995f1e5e0 100644 --- a/arm_compute/core/NEON/kernels/NEActivationLayerKernel.h +++ b/arm_compute/core/NEON/kernels/NEActivationLayerKernel.h @@ -50,7 +50,7 @@ public: * @note If the output tensor is a nullptr, the activation function will be performed in-place * * @param[in, out] input Source tensor. In case of @p output tensor = nullptr, this tensor will store the result - * of the activation function. Data types supported: QS8/F32. + * of the activation function. Data types supported: QS8/QS16/F32. * @param[out] output Destination tensor. Data type supported: same as @p input * @param[in] activation_info Activation layer information. */ @@ -78,6 +78,12 @@ private: */ template typename std::enable_if::value, void>::type activation(const Window &window); + /** Function to apply an activation function on a tensor. + * + * @param[in] window Region on which to execute the kernel + */ + template + typename std::enable_if::value, void>::type activation(const Window &window); private: ITensor *_input; -- cgit v1.2.1