From 20cfa45faefbf56f62c8b1aa95dfd0b4f52e5641 Mon Sep 17 00:00:00 2001 From: Pablo Marquez Tello Date: Mon, 20 Mar 2023 16:29:21 +0000 Subject: Round to nearest with ties to away from zero in Relu * This patch adds support for rounding modes in vmlaq_qasymm8_signed which is used to compute Relu for quantized types * Partially resolves MLCE-1018 Change-Id: I2a267b84745430e1ffe92b8bc79828a39332db18 Signed-off-by: Pablo Marquez Tello Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9354 Comments-Addressed: Arm Jenkins Reviewed-by: Gunes Bayir Reviewed-by: Viet-Hoa Do Tested-by: Arm Jenkins Benchmark: Arm Jenkins --- src/core/NEON/NEAsymm.h | 8 ++++++- src/core/NEON/NEAsymm.inl | 60 +++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 65 insertions(+), 3 deletions(-) (limited to 'src/core/NEON') diff --git a/src/core/NEON/NEAsymm.h b/src/core/NEON/NEAsymm.h index 9b92a865d0..5b8d2be04b 100644 --- a/src/core/NEON/NEAsymm.h +++ b/src/core/NEON/NEAsymm.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020 Arm Limited. + * Copyright (c) 2017-2020, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -715,6 +715,12 @@ inline uint16x8x2_t vquantize_qasymm16(const float32x4x4_t &qv, const UniformQua const uint16x8_t pb = vcombine_u16(vqmovun_s32(rf.val[2]), vqmovun_s32(rf.val[3])); return { pa, pb }; } + +template +qasymm8x16_signed_t vmlaq_qasymm8(qasymm8x16_signed_t vd, float32x4_t vs, float32x4_t vo); + +template +qasymm8x16_signed_t vmlaq_qasymm8_signed(qasymm8x16_signed_t vd, float32x4_t vs, float32x4_t vo); } // namespace arm_compute #include "src/core/NEON/NEAsymm.inl" #endif // ARM_COMPUTE_NEASYMM_H diff --git a/src/core/NEON/NEAsymm.inl b/src/core/NEON/NEAsymm.inl index 6ee1a336b8..ca2aea1e18 100644 --- a/src/core/NEON/NEAsymm.inl +++ b/src/core/NEON/NEAsymm.inl @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020 Arm Limited. + * Copyright (c) 2017-2020, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,8 +21,12 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ + +#include "arm_compute/core/Rounding.h" + namespace arm_compute { +template inline qasymm8x16_t vmlaq_qasymm8(qasymm8x16_t vd, float32x4_t vs, float32x4_t vo) { // Convert uint8 vectors to uint16 vectors @@ -46,16 +50,43 @@ inline qasymm8x16_t vmlaq_qasymm8(qasymm8x16_t vd, float32x4_t vs, float32x4_t v C_f32x4 = vmlaq_f32(vo, C_f32x4, vs); D_f32x4 = vmlaq_f32(vo, D_f32x4, vs); // Convert float32 vectors to uint32 vectors +#if __aarch64__ + if(round_policy == RoundingPolicy::TO_NEAREST_EVEN) + { + A_u32x4 = vcvtnq_u32_f32(A_f32x4); + B_u32x4 = vcvtnq_u32_f32(B_f32x4); + C_u32x4 = vcvtnq_u32_f32(C_f32x4); + D_u32x4 = vcvtnq_u32_f32(D_f32x4); + } + else if(round_policy == RoundingPolicy::TO_NEAREST_UP) + { + A_u32x4 = vcvtaq_u32_f32(A_f32x4); + B_u32x4 = vcvtaq_u32_f32(B_f32x4); + C_u32x4 = vcvtaq_u32_f32(C_f32x4); + D_u32x4 = vcvtaq_u32_f32(D_f32x4); + } + else + { + A_u32x4 = vcvtq_u32_f32(A_f32x4); + B_u32x4 = vcvtq_u32_f32(B_f32x4); + C_u32x4 = vcvtq_u32_f32(C_f32x4); + D_u32x4 = vcvtq_u32_f32(D_f32x4); + } +#else // #if __aarch64__ + // rounding mode only supported in aarch64 A_u32x4 = vcvtq_u32_f32(A_f32x4); B_u32x4 = vcvtq_u32_f32(B_f32x4); C_u32x4 = vcvtq_u32_f32(C_f32x4); D_u32x4 = vcvtq_u32_f32(D_f32x4); +#endif // #if __aarch64__ // Convert uint32 vectors to uint16 vectors (with saturation) vd_low_u16x8 = vcombine_u16(vqmovn_u32(A_u32x4), vqmovn_u32(B_u32x4)); vd_high_u16x8 = vcombine_u16(vqmovn_u32(C_u32x4), vqmovn_u32(D_u32x4)); // convert uint16 vectors to uint8 vectors (with saturation) return vcombine_u8(vqmovn_u16(vd_low_u16x8), vqmovn_u16(vd_high_u16x8)); } + +template inline qasymm8x16_signed_t vmlaq_qasymm8_signed(qasymm8x16_signed_t vd, float32x4_t vs, float32x4_t vo) { // Convert uint8 vectors to int16 vectors @@ -78,11 +109,36 @@ inline qasymm8x16_signed_t vmlaq_qasymm8_signed(qasymm8x16_signed_t vd, float32x B_f32x4 = vmlaq_f32(vo, B_f32x4, vs); C_f32x4 = vmlaq_f32(vo, C_f32x4, vs); D_f32x4 = vmlaq_f32(vo, D_f32x4, vs); - // Convert float32 vectors to int32 vectors +#if __aarch64__ + if(round_policy == RoundingPolicy::TO_NEAREST_EVEN) + { + A_s32x4 = vcvtnq_s32_f32(A_f32x4); + B_s32x4 = vcvtnq_s32_f32(B_f32x4); + C_s32x4 = vcvtnq_s32_f32(C_f32x4); + D_s32x4 = vcvtnq_s32_f32(D_f32x4); + } + else if(round_policy == RoundingPolicy::TO_NEAREST_UP) + { + A_s32x4 = vcvtaq_s32_f32(A_f32x4); + B_s32x4 = vcvtaq_s32_f32(B_f32x4); + C_s32x4 = vcvtaq_s32_f32(C_f32x4); + D_s32x4 = vcvtaq_s32_f32(D_f32x4); + } + else + { + A_s32x4 = vcvtq_s32_f32(A_f32x4); + B_s32x4 = vcvtq_s32_f32(B_f32x4); + C_s32x4 = vcvtq_s32_f32(C_f32x4); + D_s32x4 = vcvtq_s32_f32(D_f32x4); + } +#else // #if __aarch64__ + // rounding mode only supported in aarch64 A_s32x4 = vcvtq_s32_f32(A_f32x4); B_s32x4 = vcvtq_s32_f32(B_f32x4); C_s32x4 = vcvtq_s32_f32(C_f32x4); D_s32x4 = vcvtq_s32_f32(D_f32x4); +#endif // #if __aarch64__ + // Convert int32 vectors to int16 vectors (with saturation) vd_low_s16x8 = vcombine_s16(vqmovn_s32(A_s32x4), vqmovn_s32(B_s32x4)); vd_high_s16x8 = vcombine_s16(vqmovn_s32(C_s32x4), vqmovn_s32(D_s32x4)); -- cgit v1.2.1