diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2017-11-13 12:58:41 +0000 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:35:24 +0000 |
commit | 4c2dd54d6983275530ef20f9dbb4ce6080c7307b (patch) | |
tree | ddef97fa862d8a08faa42e2c624b029018591d13 /tests/validation/CPP | |
parent | fa330439b88af04f96c23b75e36a5a7813b89711 (diff) | |
download | ComputeLibrary-4c2dd54d6983275530ef20f9dbb4ce6080c7307b.tar.gz |
COMPMID-671: Add global pooling layer support.
Change-Id: Iead7497cc03e1e7bde440d2965a7bf54cbfa88bf
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/95579
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Reviewed-by: Joel Liang <joel.liang@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Diffstat (limited to 'tests/validation/CPP')
-rw-r--r-- | tests/validation/CPP/PoolingLayer.cpp | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/tests/validation/CPP/PoolingLayer.cpp b/tests/validation/CPP/PoolingLayer.cpp index 90a48e0c44..1a7dd4cbb7 100644 --- a/tests/validation/CPP/PoolingLayer.cpp +++ b/tests/validation/CPP/PoolingLayer.cpp @@ -40,10 +40,11 @@ namespace TensorShape calculate_output_shape(TensorShape shape, PoolingLayerInfo info) { TensorShape dst_shape = shape; + const int pool_size = info.is_global_pooling() ? shape.x() : info.pool_size(); const std::pair<unsigned int, unsigned int> scaled_dims = arm_compute::scaled_dimensions(shape.x(), shape.y(), - info.pool_size(), - info.pool_size(), + pool_size, + pool_size, info.pad_stride_info()); dst_shape.set(0, scaled_dims.first); dst_shape.set(1, scaled_dims.second); @@ -55,7 +56,9 @@ TensorShape calculate_output_shape(TensorShape shape, PoolingLayerInfo info) template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type> SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, PoolingLayerInfo info) { - const int pool_size = info.pool_size(); + ARM_COMPUTE_ERROR_ON(info.is_global_pooling() && (src.shape().x() != src.shape().y())); + + const int pool_size = info.is_global_pooling() ? src.shape().x() : info.pool_size(); PoolingType type = info.pool_type(); int pool_stride_x = info.pad_stride_info().stride().first; int pool_stride_y = info.pad_stride_info().stride().second; @@ -164,7 +167,9 @@ SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, PoolingLayerInfo info) template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type> SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, PoolingLayerInfo info) { - const int pool_size = info.pool_size(); + ARM_COMPUTE_ERROR_ON(info.is_global_pooling() && (src.shape().x() != src.shape().y())); + + const int pool_size = info.is_global_pooling() ? src.shape().x() : info.pool_size(); PoolingType type = info.pool_type(); int pool_stride_x = info.pad_stride_info().stride().first; int pool_stride_y = info.pad_stride_info().stride().second; |