diff options
Diffstat (limited to 'src/core/NEON')
-rw-r--r-- | src/core/NEON/NEAsymm.h | 8 | ||||
-rw-r--r-- | src/core/NEON/NEAsymm.inl | 60 |
2 files changed, 65 insertions, 3 deletions
diff --git a/src/core/NEON/NEAsymm.h b/src/core/NEON/NEAsymm.h index 9b92a865d0..5b8d2be04b 100644 --- a/src/core/NEON/NEAsymm.h +++ b/src/core/NEON/NEAsymm.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020 Arm Limited. + * Copyright (c) 2017-2020, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -715,6 +715,12 @@ inline uint16x8x2_t vquantize_qasymm16(const float32x4x4_t &qv, const UniformQua const uint16x8_t pb = vcombine_u16(vqmovun_s32(rf.val[2]), vqmovun_s32(rf.val[3])); return { pa, pb }; } + +template <RoundingPolicy round_policy = RoundingPolicy::TO_ZERO> +qasymm8x16_signed_t vmlaq_qasymm8(qasymm8x16_signed_t vd, float32x4_t vs, float32x4_t vo); + +template <RoundingPolicy round_policy = RoundingPolicy::TO_ZERO> +qasymm8x16_signed_t vmlaq_qasymm8_signed(qasymm8x16_signed_t vd, float32x4_t vs, float32x4_t vo); } // namespace arm_compute #include "src/core/NEON/NEAsymm.inl" #endif // ARM_COMPUTE_NEASYMM_H diff --git a/src/core/NEON/NEAsymm.inl b/src/core/NEON/NEAsymm.inl index 6ee1a336b8..ca2aea1e18 100644 --- a/src/core/NEON/NEAsymm.inl +++ b/src/core/NEON/NEAsymm.inl @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020 Arm Limited. + * Copyright (c) 2017-2020, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,8 +21,12 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ + +#include "arm_compute/core/Rounding.h" + namespace arm_compute { +template <RoundingPolicy round_policy> inline qasymm8x16_t vmlaq_qasymm8(qasymm8x16_t vd, float32x4_t vs, float32x4_t vo) { // Convert uint8 vectors to uint16 vectors @@ -46,16 +50,43 @@ inline qasymm8x16_t vmlaq_qasymm8(qasymm8x16_t vd, float32x4_t vs, float32x4_t v C_f32x4 = vmlaq_f32(vo, C_f32x4, vs); D_f32x4 = vmlaq_f32(vo, D_f32x4, vs); // Convert float32 vectors to uint32 vectors +#if __aarch64__ + if(round_policy == RoundingPolicy::TO_NEAREST_EVEN) + { + A_u32x4 = vcvtnq_u32_f32(A_f32x4); + B_u32x4 = vcvtnq_u32_f32(B_f32x4); + C_u32x4 = vcvtnq_u32_f32(C_f32x4); + D_u32x4 = vcvtnq_u32_f32(D_f32x4); + } + else if(round_policy == RoundingPolicy::TO_NEAREST_UP) + { + A_u32x4 = vcvtaq_u32_f32(A_f32x4); + B_u32x4 = vcvtaq_u32_f32(B_f32x4); + C_u32x4 = vcvtaq_u32_f32(C_f32x4); + D_u32x4 = vcvtaq_u32_f32(D_f32x4); + } + else + { + A_u32x4 = vcvtq_u32_f32(A_f32x4); + B_u32x4 = vcvtq_u32_f32(B_f32x4); + C_u32x4 = vcvtq_u32_f32(C_f32x4); + D_u32x4 = vcvtq_u32_f32(D_f32x4); + } +#else // #if __aarch64__ + // rounding mode only supported in aarch64 A_u32x4 = vcvtq_u32_f32(A_f32x4); B_u32x4 = vcvtq_u32_f32(B_f32x4); C_u32x4 = vcvtq_u32_f32(C_f32x4); D_u32x4 = vcvtq_u32_f32(D_f32x4); +#endif // #if __aarch64__ // Convert uint32 vectors to uint16 vectors (with saturation) vd_low_u16x8 = vcombine_u16(vqmovn_u32(A_u32x4), vqmovn_u32(B_u32x4)); vd_high_u16x8 = vcombine_u16(vqmovn_u32(C_u32x4), vqmovn_u32(D_u32x4)); // convert uint16 vectors to uint8 vectors (with saturation) return vcombine_u8(vqmovn_u16(vd_low_u16x8), vqmovn_u16(vd_high_u16x8)); } + +template <RoundingPolicy round_policy> inline qasymm8x16_signed_t vmlaq_qasymm8_signed(qasymm8x16_signed_t vd, float32x4_t vs, float32x4_t vo) { // Convert uint8 vectors to int16 vectors @@ -78,11 +109,36 @@ inline qasymm8x16_signed_t vmlaq_qasymm8_signed(qasymm8x16_signed_t vd, float32x B_f32x4 = vmlaq_f32(vo, B_f32x4, vs); C_f32x4 = vmlaq_f32(vo, C_f32x4, vs); D_f32x4 = vmlaq_f32(vo, D_f32x4, vs); - // Convert float32 vectors to int32 vectors +#if __aarch64__ + if(round_policy == RoundingPolicy::TO_NEAREST_EVEN) + { + A_s32x4 = vcvtnq_s32_f32(A_f32x4); + B_s32x4 = vcvtnq_s32_f32(B_f32x4); + C_s32x4 = vcvtnq_s32_f32(C_f32x4); + D_s32x4 = vcvtnq_s32_f32(D_f32x4); + } + else if(round_policy == RoundingPolicy::TO_NEAREST_UP) + { + A_s32x4 = vcvtaq_s32_f32(A_f32x4); + B_s32x4 = vcvtaq_s32_f32(B_f32x4); + C_s32x4 = vcvtaq_s32_f32(C_f32x4); + D_s32x4 = vcvtaq_s32_f32(D_f32x4); + } + else + { + A_s32x4 = vcvtq_s32_f32(A_f32x4); + B_s32x4 = vcvtq_s32_f32(B_f32x4); + C_s32x4 = vcvtq_s32_f32(C_f32x4); + D_s32x4 = vcvtq_s32_f32(D_f32x4); + } +#else // #if __aarch64__ + // rounding mode only supported in aarch64 A_s32x4 = vcvtq_s32_f32(A_f32x4); B_s32x4 = vcvtq_s32_f32(B_f32x4); C_s32x4 = vcvtq_s32_f32(C_f32x4); D_s32x4 = vcvtq_s32_f32(D_f32x4); +#endif // #if __aarch64__ + // Convert int32 vectors to int16 vectors (with saturation) vd_low_s16x8 = vcombine_s16(vqmovn_s32(A_s32x4), vqmovn_s32(B_s32x4)); vd_high_s16x8 = vcombine_s16(vqmovn_s32(C_s32x4), vqmovn_s32(D_s32x4)); |