aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/CPP/PoolingLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/CPP/PoolingLayer.cpp')
-rw-r--r--tests/validation/CPP/PoolingLayer.cpp13
1 files changed, 9 insertions, 4 deletions
diff --git a/tests/validation/CPP/PoolingLayer.cpp b/tests/validation/CPP/PoolingLayer.cpp
index 90a48e0c44..1a7dd4cbb7 100644
--- a/tests/validation/CPP/PoolingLayer.cpp
+++ b/tests/validation/CPP/PoolingLayer.cpp
@@ -40,10 +40,11 @@ namespace
TensorShape calculate_output_shape(TensorShape shape, PoolingLayerInfo info)
{
TensorShape dst_shape = shape;
+ const int pool_size = info.is_global_pooling() ? shape.x() : info.pool_size();
const std::pair<unsigned int, unsigned int> scaled_dims = arm_compute::scaled_dimensions(shape.x(),
shape.y(),
- info.pool_size(),
- info.pool_size(),
+ pool_size,
+ pool_size,
info.pad_stride_info());
dst_shape.set(0, scaled_dims.first);
dst_shape.set(1, scaled_dims.second);
@@ -55,7 +56,9 @@ TensorShape calculate_output_shape(TensorShape shape, PoolingLayerInfo info)
template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type>
SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, PoolingLayerInfo info)
{
- const int pool_size = info.pool_size();
+ ARM_COMPUTE_ERROR_ON(info.is_global_pooling() && (src.shape().x() != src.shape().y()));
+
+ const int pool_size = info.is_global_pooling() ? src.shape().x() : info.pool_size();
PoolingType type = info.pool_type();
int pool_stride_x = info.pad_stride_info().stride().first;
int pool_stride_y = info.pad_stride_info().stride().second;
@@ -164,7 +167,9 @@ SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, PoolingLayerInfo info)
template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type>
SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, PoolingLayerInfo info)
{
- const int pool_size = info.pool_size();
+ ARM_COMPUTE_ERROR_ON(info.is_global_pooling() && (src.shape().x() != src.shape().y()));
+
+ const int pool_size = info.is_global_pooling() ? src.shape().x() : info.pool_size();
PoolingType type = info.pool_type();
int pool_stride_x = info.pad_stride_info().stride().first;
int pool_stride_y = info.pad_stride_info().stride().second;