diff options
author | Michalis Spyrou <michalis.spyrou@arm.com> | 2018-03-01 16:03:50 +0000 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:50:48 +0000 |
commit | 57dac8400d56a4b68975d5563a9540c96d49fe5f (patch) | |
tree | cf474c6690c02115e623d5e7d867be665050f87e /tests/validation/reference/PoolingLayer.cpp | |
parent | 0ef7e670a22248806458d7327db9e8b8c4db4ce6 (diff) | |
download | ComputeLibrary-57dac8400d56a4b68975d5563a9540c96d49fe5f.tar.gz |
COMPMID-806 Add NHWC data format support format for NEON pooling
Change-Id: I7ab174c72f3d56134fcec259a137739061fd12e9
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/123065
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/reference/PoolingLayer.cpp')
-rw-r--r-- | tests/validation/reference/PoolingLayer.cpp | 35 |
1 files changed, 26 insertions, 9 deletions
diff --git a/tests/validation/reference/PoolingLayer.cpp b/tests/validation/reference/PoolingLayer.cpp index c14ab98c28..071c20ed56 100644 --- a/tests/validation/reference/PoolingLayer.cpp +++ b/tests/validation/reference/PoolingLayer.cpp @@ -23,6 +23,7 @@ */ #include "PoolingLayer.h" +#include "Permute.h" #include "arm_compute/core/Types.h" #include "tests/validation/FixedPoint.h" #include "tests/validation/Helpers.h" @@ -54,8 +55,8 @@ TensorShape calculate_output_shape(TensorShape shape, const PoolingLayerInfo &in } } // namespace -template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type> -SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, const PoolingLayerInfo &info) +template <typename T> +SimpleTensor<T> pooling_layer_nchw(const SimpleTensor<T> &src, SimpleTensor<T> &dst, const PoolingLayerInfo &info) { ARM_COMPUTE_ERROR_ON(info.is_global_pooling() && (src.shape().x() != src.shape().y())); @@ -74,9 +75,6 @@ SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, const PoolingLayerInfo const auto h_src = static_cast<int>(src.shape()[1]); const int upper_dims = src.shape().total_size() / (w_src * h_src); - // Create reference - SimpleTensor<T> dst{ calculate_output_shape(src.shape(), info), src.data_type(), 1, src.fixed_point_position() }; - const auto w_dst = static_cast<int>(dst.shape()[0]); const auto h_dst = static_cast<int>(dst.shape()[1]); @@ -173,6 +171,10 @@ SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, const PoolingLayerInfo { ARM_COMPUTE_ERROR_ON(info.is_global_pooling() && (src.shape().x() != src.shape().y())); + const auto w_src = static_cast<int>(src.shape()[0]); + const auto h_src = static_cast<int>(src.shape()[1]); + const int upper_dims = src.shape().total_size() / (w_src * h_src); + const int pool_size_x = info.is_global_pooling() ? src.shape().x() : info.pool_size().width; const int pool_size_y = info.is_global_pooling() ? src.shape().y() : info.pool_size().height; PoolingType type = info.pool_type(); @@ -184,10 +186,6 @@ SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, const PoolingLayerInfo int pad_bottom = info.pad_stride_info().pad_bottom(); bool exclude_padding = info.exclude_padding(); - const auto w_src = static_cast<int>(src.shape()[0]); - const auto h_src = static_cast<int>(src.shape()[1]); - const int upper_dims = src.shape().total_size() / (w_src * h_src); - // Create reference SimpleTensor<T> dst{ calculate_output_shape(src.shape(), info), src.data_type(), 1, src.fixed_point_position() }; @@ -299,6 +297,25 @@ SimpleTensor<uint8_t> pooling_layer<uint8_t>(const SimpleTensor<uint8_t> &src, c return dst; } +template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type> +SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, const PoolingLayerInfo &info) +{ + if(src.data_layout() == DataLayout::NHWC) + { + SimpleTensor<T> src_nchw = reference::permute<T>(src, PermutationVector(1U, 2U, 0U)); + SimpleTensor<T> dst{ calculate_output_shape(src_nchw.shape(), info), src_nchw.data_type(), 1, src_nchw.fixed_point_position() }; + + pooling_layer_nchw<T>(src_nchw, dst, info); + + return reference::permute<T>(dst, PermutationVector(2U, 0U, 1U)); + } + else + { + SimpleTensor<T> dst{ calculate_output_shape(src.shape(), info), src.data_type(), 1, src.fixed_point_position() }; + return pooling_layer_nchw<T>(src, dst, info); + } +} + template SimpleTensor<float> pooling_layer(const SimpleTensor<float> &src, const PoolingLayerInfo &info); template SimpleTensor<half> pooling_layer(const SimpleTensor<half> &src, const PoolingLayerInfo &info); template SimpleTensor<qint8_t> pooling_layer(const SimpleTensor<qint8_t> &src, const PoolingLayerInfo &info); |