diff options
Diffstat (limited to 'tests/validation/reference')
-rw-r--r-- | tests/validation/reference/PoolingLayer.cpp | 37 |
1 files changed, 14 insertions, 23 deletions
diff --git a/tests/validation/reference/PoolingLayer.cpp b/tests/validation/reference/PoolingLayer.cpp index 071c20ed56..f3f456b26e 100644 --- a/tests/validation/reference/PoolingLayer.cpp +++ b/tests/validation/reference/PoolingLayer.cpp @@ -25,6 +25,7 @@ #include "Permute.h" #include "arm_compute/core/Types.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "tests/validation/FixedPoint.h" #include "tests/validation/Helpers.h" @@ -36,24 +37,7 @@ namespace validation { namespace reference { -namespace -{ -TensorShape calculate_output_shape(TensorShape shape, const PoolingLayerInfo &info) -{ - TensorShape dst_shape = shape; - const int pool_size_x = info.is_global_pooling() ? shape.x() : info.pool_size().width; - const int pool_size_y = info.is_global_pooling() ? shape.y() : info.pool_size().height; - const std::pair<unsigned int, unsigned int> scaled_dims = arm_compute::scaled_dimensions(shape.x(), - shape.y(), - pool_size_x, - pool_size_y, - info.pad_stride_info()); - dst_shape.set(0, scaled_dims.first); - dst_shape.set(1, scaled_dims.second); - - return dst_shape; -} -} // namespace +using namespace arm_compute::misc::shape_calculator; template <typename T> SimpleTensor<T> pooling_layer_nchw(const SimpleTensor<T> &src, SimpleTensor<T> &dst, const PoolingLayerInfo &info) @@ -187,7 +171,10 @@ SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, const PoolingLayerInfo bool exclude_padding = info.exclude_padding(); // Create reference - SimpleTensor<T> dst{ calculate_output_shape(src.shape(), info), src.data_type(), 1, src.fixed_point_position() }; + TensorInfo src_info(src.shape(), 1, src.data_type(), src.fixed_point_position()); + src_info.set_data_layout(src.data_layout()); + + SimpleTensor<T> dst{ compute_pool_shape(src_info, info), src.data_type(), 1, src.fixed_point_position() }; const auto w_dst = static_cast<int>(dst.shape()[0]); const auto h_dst = static_cast<int>(dst.shape()[1]); @@ -300,18 +287,22 @@ SimpleTensor<uint8_t> pooling_layer<uint8_t>(const SimpleTensor<uint8_t> &src, c template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type> SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, const PoolingLayerInfo &info) { + TensorInfo src_info(src.shape(), 1, src.data_type(), src.fixed_point_position()); + src_info.set_data_layout(src.data_layout()); + + SimpleTensor<T> dst{ compute_pool_shape(src_info, info), src.data_type(), 1, src.fixed_point_position() }; + if(src.data_layout() == DataLayout::NHWC) { SimpleTensor<T> src_nchw = reference::permute<T>(src, PermutationVector(1U, 2U, 0U)); - SimpleTensor<T> dst{ calculate_output_shape(src_nchw.shape(), info), src_nchw.data_type(), 1, src_nchw.fixed_point_position() }; + SimpleTensor<T> dst_nchw = reference::permute<T>(dst, PermutationVector(1U, 2U, 0U)); - pooling_layer_nchw<T>(src_nchw, dst, info); + pooling_layer_nchw<T>(src_nchw, dst_nchw, info); - return reference::permute<T>(dst, PermutationVector(2U, 0U, 1U)); + return reference::permute<T>(dst_nchw, PermutationVector(2U, 0U, 1U)); } else { - SimpleTensor<T> dst{ calculate_output_shape(src.shape(), info), src.data_type(), 1, src.fixed_point_position() }; return pooling_layer_nchw<T>(src, dst, info); } } |