From b9531540dadce8331a703c32456f3c9defdfefa9 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Thu, 5 Nov 2020 20:06:49 +0000 Subject: COMPMID-3850: NEPooling regression for NHWC Expand left-over loop to handle multiples of 8 for quantized data type during MaxPooling. Signed-off-by: Georgios Pinitas Change-Id: I1304d174c45d2c98247470ac8b4bb6752bbc03a6 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4339 Reviewed-by: Michele Di Giorgio Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins --- src/core/NEON/kernels/NEPoolingLayerKernel.cpp | 28 +++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) (limited to 'src/core/NEON/kernels/NEPoolingLayerKernel.cpp') diff --git a/src/core/NEON/kernels/NEPoolingLayerKernel.cpp b/src/core/NEON/kernels/NEPoolingLayerKernel.cpp index 0f0b9eed5a..b46843badd 100644 --- a/src/core/NEON/kernels/NEPoolingLayerKernel.cpp +++ b/src/core/NEON/kernels/NEPoolingLayerKernel.cpp @@ -2283,9 +2283,10 @@ void NEPoolingLayerKernel::poolingMxN_q8_nchw(const Window &window_input, const template void NEPoolingLayerKernel::poolingMxN_q8_nhwc(const Window &window_input, const Window &window, PoolingType pooling_type, bool exclude_padding) { - const int window_start_x = window.x().start(); - const int window_end_x = window.x().end(); - const int window_step_x = 16; + const int window_start_x = window.x().start(); + const int window_end_x = window.x().end(); + const int window_step_x = 16; + const int window_half_step_x = window_step_x / 2; Window window_out = window; window_out.set(Window::DimX, Window::Dimension(0, 1, 1)); @@ -2422,6 +2423,27 @@ void NEPoolingLayerKernel::poolingMxN_q8_nhwc(const Window &window_input, const } } + if(pooling_type == PoolingType::MAX) + { + for(; x_off <= (window_end_x - window_half_step_x); x_off += window_half_step_x) + { + q8x8_t vres = wrapper::vdup_n(std::numeric_limits::min(), wrapper::traits::vector_64_tag{}); + for(int y = pool_start_y; y < pool_end_y; ++y) + { + for(int x = pool_start_x; x < pool_end_x; ++x) + { + const q8x8_t data = wrapper::vload(reinterpret_cast(input.ptr() + (x - pool_pad_left) * static_cast(_input->info()->strides_in_bytes().y()) + (y - pool_pad_top) * static_cast + (_input->info()->strides_in_bytes().z())) + x_off); + vres = wrapper::vmax(vres, data); + } + } + + // Store result + wrapper::vstore(reinterpret_cast(output.ptr()) + x_off, + (input_qinfo != output_qinfo) ? vrequantize_pooling(vres, requant_qinfo) : vres); + } + } + // Left-overs loop for(; x_off < window_end_x; ++x_off) { -- cgit v1.2.1