aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/core/NEON/kernels/NEPoolingLayerKernel.cpp28
1 files changed, 25 insertions, 3 deletions
diff --git a/src/core/NEON/kernels/NEPoolingLayerKernel.cpp b/src/core/NEON/kernels/NEPoolingLayerKernel.cpp
index 0f0b9eed5a..b46843badd 100644
--- a/src/core/NEON/kernels/NEPoolingLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEPoolingLayerKernel.cpp
@@ -2283,9 +2283,10 @@ void NEPoolingLayerKernel::poolingMxN_q8_nchw(const Window &window_input, const
template <typename T>
void NEPoolingLayerKernel::poolingMxN_q8_nhwc(const Window &window_input, const Window &window, PoolingType pooling_type, bool exclude_padding)
{
- const int window_start_x = window.x().start();
- const int window_end_x = window.x().end();
- const int window_step_x = 16;
+ const int window_start_x = window.x().start();
+ const int window_end_x = window.x().end();
+ const int window_step_x = 16;
+ const int window_half_step_x = window_step_x / 2;
Window window_out = window;
window_out.set(Window::DimX, Window::Dimension(0, 1, 1));
@@ -2422,6 +2423,27 @@ void NEPoolingLayerKernel::poolingMxN_q8_nhwc(const Window &window_input, const
}
}
+ if(pooling_type == PoolingType::MAX)
+ {
+ for(; x_off <= (window_end_x - window_half_step_x); x_off += window_half_step_x)
+ {
+ q8x8_t vres = wrapper::vdup_n(std::numeric_limits<T>::min(), wrapper::traits::vector_64_tag{});
+ for(int y = pool_start_y; y < pool_end_y; ++y)
+ {
+ for(int x = pool_start_x; x < pool_end_x; ++x)
+ {
+ const q8x8_t data = wrapper::vload(reinterpret_cast<const T *>(input.ptr() + (x - pool_pad_left) * static_cast<int>(_input->info()->strides_in_bytes().y()) + (y - pool_pad_top) * static_cast<int>
+ (_input->info()->strides_in_bytes().z())) + x_off);
+ vres = wrapper::vmax(vres, data);
+ }
+ }
+
+ // Store result
+ wrapper::vstore(reinterpret_cast<T *>(output.ptr()) + x_off,
+ (input_qinfo != output_qinfo) ? vrequantize_pooling<q8x8_t>(vres, requant_qinfo) : vres);
+ }
+ }
+
// Left-overs loop
for(; x_off < window_end_x; ++x_off)
{