From 283fc606dbd9058f636b91350a1c47b97aba1a87 Mon Sep 17 00:00:00 2001
From: Georgios Pinitas
Date: Fri, 9 Nov 2018 10:46:43 +0000
Subject: COMPMID-1451: Reduces accuracy issue in NEPoolingLayer for QASYMM8 NHWC

Adds 0.5f after scaling AVG pooling to be able to round to nearest as
vcvtq_u32_f32 rounds towards zero.

Change-Id: I22ce78f9e628cf4184a317edabce47211ab09456
---
 src/core/NEON/kernels/NEPoolingLayerKernel.cpp | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

(limited to 'src/core/NEON/kernels/NEPoolingLayerKernel.cpp')

diff --git a/src/core/NEON/kernels/NEPoolingLayerKernel.cpp b/src/core/NEON/kernels/NEPoolingLayerKernel.cpp
index f5d5281884..310560b48a 100644
--- a/src/core/NEON/kernels/NEPoolingLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEPoolingLayerKernel.cpp
@@ -1951,6 +1951,8 @@ void NEPoolingLayerKernel::poolingMxN_qasymm8_nhwc(const Window &window_input, c
     const int upper_bound_w = _input->info()->dimension(1) + (exclude_padding ? 0 : pool_pad_right);
     const int upper_bound_h = _input->info()->dimension(2) + (exclude_padding ? 0 : pool_pad_bottom);
 
+    const float32x4_t half_scale_v = vdupq_n_f32(0.5f);
+
     execute_window_loop(window, [&](const Coordinates & id)
     {
         const int idx_width = id.y() * pool_stride_x;
@@ -1991,11 +1993,11 @@ void NEPoolingLayerKernel::poolingMxN_qasymm8_nhwc(const Window &window_input, c
                     vres4 = vaddq_u32(vres4, vmovl_u16(vget_high_u16(data2_u16)));
                 }
             }
-            // Divide by scale
-            vres1 = vcvtq_u32_f32(vmulq_f32(vcvtq_f32_u32(vres1), scale_v));
-            vres2 = vcvtq_u32_f32(vmulq_f32(vcvtq_f32_u32(vres2), scale_v));
-            vres3 = vcvtq_u32_f32(vmulq_f32(vcvtq_f32_u32(vres3), scale_v));
-            vres4 = vcvtq_u32_f32(vmulq_f32(vcvtq_f32_u32(vres4), scale_v));
+            // Divide by scale and add 0.5f to round to nearest instead of rounding towards zero
+            vres1 = vcvtq_u32_f32(vmlaq_f32(half_scale_v, vcvtq_f32_u32(vres1), scale_v));
+            vres2 = vcvtq_u32_f32(vmlaq_f32(half_scale_v, vcvtq_f32_u32(vres2), scale_v));
+            vres3 = vcvtq_u32_f32(vmlaq_f32(half_scale_v, vcvtq_f32_u32(vres3), scale_v));
+            vres4 = vcvtq_u32_f32(vmlaq_f32(half_scale_v, vcvtq_f32_u32(vres4), scale_v));
 
             uint8x8_t res1 = vmovn_u16(vcombine_u16(vmovn_u32(vres1), vmovn_u32(vres2)));
             uint8x8_t res2 = vmovn_u16(vcombine_u16(vmovn_u32(vres3), vmovn_u32(vres4)));
-- 
cgit v1.2.1