diff options
Diffstat (limited to 'arm_compute/core/NEON/kernels/NEPoolingLayerKernel.h')
-rw-r--r-- | arm_compute/core/NEON/kernels/NEPoolingLayerKernel.h | 19 |
1 files changed, 13 insertions, 6 deletions
diff --git a/arm_compute/core/NEON/kernels/NEPoolingLayerKernel.h b/arm_compute/core/NEON/kernels/NEPoolingLayerKernel.h index 15d63c746b..2be25080cd 100644 --- a/arm_compute/core/NEON/kernels/NEPoolingLayerKernel.h +++ b/arm_compute/core/NEON/kernels/NEPoolingLayerKernel.h @@ -91,12 +91,6 @@ private: * @param[in] window_input Input region on which to execute the kernel. * @param[in] window Output region on which to execute the kernel. */ - void pooling2_f32_nchw_maxpool_indices(const Window &window_input, const Window &window); - /** Function to perform 2x2 pooling and compute the pooling indices. The indices can be used for max unpool. - * - * @param[in] window_input Input region on which to execute the kernel. - * @param[in] window Output region on which to execute the kernel. - */ void pooling2_f32_nhwc_maxpool_indices(const Window &window_input, const Window &window); /** Function to perform MxN pooling for 32-bit floating point values. * @@ -138,6 +132,19 @@ private: * @param[in] exclude_padding Flag to specify exclusion of padding from the operation. */ void pooling2_f16_nchw(const Window &window_input, const Window &window, PoolingType pooling_type, bool exclude_padding = false); + /** Function to perform 2x2 pooling and compute the pooling indices for FP32/FP16. The indices can be used for max unpool. + * + * @param[in] window_input Input region on which to execute the kernel. + * @param[in] window Output region on which to execute the kernel. + */ + template <typename T> + void pooling2_nchw_maxpool_indices(const Window &window_input, const Window &window); + /** Function to perform 2x2 pooling and compute the pooling indices. The indices can be used for max unpool. + * + * @param[in] window_input Input region on which to execute the kernel. + * @param[in] window Output region on which to execute the kernel. + */ + void pooling2_f16_nhwc_maxpool_indices(const Window &window_input, const Window &window); /** Function to perform 3x3 pooling. * * @param[in] window_input Input region on which to execute the kernel. |