aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/CPP
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2017-11-13 12:58:41 +0000
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:35:24 +0000
commit4c2dd54d6983275530ef20f9dbb4ce6080c7307b (patch)
treeddef97fa862d8a08faa42e2c624b029018591d13 /tests/validation/CPP
parentfa330439b88af04f96c23b75e36a5a7813b89711 (diff)
downloadComputeLibrary-4c2dd54d6983275530ef20f9dbb4ce6080c7307b.tar.gz
COMPMID-671: Add global pooling layer support.
Change-Id: Iead7497cc03e1e7bde440d2965a7bf54cbfa88bf Reviewed-on: http://mpd-gerrit.cambridge.arm.com/95579 Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com> Reviewed-by: Joel Liang <joel.liang@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Diffstat (limited to 'tests/validation/CPP')
-rw-r--r--tests/validation/CPP/PoolingLayer.cpp13
1 files changed, 9 insertions, 4 deletions
diff --git a/tests/validation/CPP/PoolingLayer.cpp b/tests/validation/CPP/PoolingLayer.cpp
index 90a48e0c44..1a7dd4cbb7 100644
--- a/tests/validation/CPP/PoolingLayer.cpp
+++ b/tests/validation/CPP/PoolingLayer.cpp
@@ -40,10 +40,11 @@ namespace
TensorShape calculate_output_shape(TensorShape shape, PoolingLayerInfo info)
{
TensorShape dst_shape = shape;
+ const int pool_size = info.is_global_pooling() ? shape.x() : info.pool_size();
const std::pair<unsigned int, unsigned int> scaled_dims = arm_compute::scaled_dimensions(shape.x(),
shape.y(),
- info.pool_size(),
- info.pool_size(),
+ pool_size,
+ pool_size,
info.pad_stride_info());
dst_shape.set(0, scaled_dims.first);
dst_shape.set(1, scaled_dims.second);
@@ -55,7 +56,9 @@ TensorShape calculate_output_shape(TensorShape shape, PoolingLayerInfo info)
template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type>
SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, PoolingLayerInfo info)
{
- const int pool_size = info.pool_size();
+ ARM_COMPUTE_ERROR_ON(info.is_global_pooling() && (src.shape().x() != src.shape().y()));
+
+ const int pool_size = info.is_global_pooling() ? src.shape().x() : info.pool_size();
PoolingType type = info.pool_type();
int pool_stride_x = info.pad_stride_info().stride().first;
int pool_stride_y = info.pad_stride_info().stride().second;
@@ -164,7 +167,9 @@ SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, PoolingLayerInfo info)
template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type>
SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, PoolingLayerInfo info)
{
- const int pool_size = info.pool_size();
+ ARM_COMPUTE_ERROR_ON(info.is_global_pooling() && (src.shape().x() != src.shape().y()));
+
+ const int pool_size = info.is_global_pooling() ? src.shape().x() : info.pool_size();
PoolingType type = info.pool_type();
int pool_stride_x = info.pad_stride_info().stride().first;
int pool_stride_y = info.pad_stride_info().stride().second;