diff options
Diffstat (limited to 'tests/validation/CPP/PoolingLayer.cpp')
-rw-r--r-- | tests/validation/CPP/PoolingLayer.cpp | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/tests/validation/CPP/PoolingLayer.cpp b/tests/validation/CPP/PoolingLayer.cpp index 90a48e0c44..1a7dd4cbb7 100644 --- a/tests/validation/CPP/PoolingLayer.cpp +++ b/tests/validation/CPP/PoolingLayer.cpp @@ -40,10 +40,11 @@ namespace TensorShape calculate_output_shape(TensorShape shape, PoolingLayerInfo info) { TensorShape dst_shape = shape; + const int pool_size = info.is_global_pooling() ? shape.x() : info.pool_size(); const std::pair<unsigned int, unsigned int> scaled_dims = arm_compute::scaled_dimensions(shape.x(), shape.y(), - info.pool_size(), - info.pool_size(), + pool_size, + pool_size, info.pad_stride_info()); dst_shape.set(0, scaled_dims.first); dst_shape.set(1, scaled_dims.second); @@ -55,7 +56,9 @@ TensorShape calculate_output_shape(TensorShape shape, PoolingLayerInfo info) template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type> SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, PoolingLayerInfo info) { - const int pool_size = info.pool_size(); + ARM_COMPUTE_ERROR_ON(info.is_global_pooling() && (src.shape().x() != src.shape().y())); + + const int pool_size = info.is_global_pooling() ? src.shape().x() : info.pool_size(); PoolingType type = info.pool_type(); int pool_stride_x = info.pad_stride_info().stride().first; int pool_stride_y = info.pad_stride_info().stride().second; @@ -164,7 +167,9 @@ SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, PoolingLayerInfo info) template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type> SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, PoolingLayerInfo info) { - const int pool_size = info.pool_size(); + ARM_COMPUTE_ERROR_ON(info.is_global_pooling() && (src.shape().x() != src.shape().y())); + + const int pool_size = info.is_global_pooling() ? src.shape().x() : info.pool_size(); PoolingType type = info.pool_type(); int pool_stride_x = info.pad_stride_info().stride().first; int pool_stride_y = info.pad_stride_info().stride().second; |