diff options
author | Isabella Gottardi <isabella.gottardi@arm.com> | 2018-01-26 12:32:45 +0000 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:46:07 +0000 |
commit | 6e464c37b5335e362ac3f988cc4b0beed5205ff4 (patch) | |
tree | 24f758be47f1dd439a4ab3f8f3631a1ed7ef4566 /arm_compute/core/Types.h | |
parent | 3364c4ca658d44f449a8d3d6e9eee31d90254f15 (diff) | |
download | ComputeLibrary-6e464c37b5335e362ac3f988cc4b0beed5205ff4.tar.gz |
COMPMID-828 - Add support for non square pool size - Part1
Change-Id: Ib8100e7c659c49694c746fa3f36ce20f44f6929f
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/117804
Reviewed-by: Michele DiGiorgio <michele.digiorgio@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/core/Types.h')
-rw-r--r-- | arm_compute/core/Types.h | 25 |
1 files changed, 21 insertions, 4 deletions
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h index aa415acebe..72be5cba2b 100644 --- a/arm_compute/core/Types.h +++ b/arm_compute/core/Types.h @@ -27,6 +27,7 @@ #include "arm_compute/core/Coordinates.h" #include "arm_compute/core/QAsymm8.h" #include "arm_compute/core/Rounding.h" +#include "arm_compute/core/Size2D.h" #include "arm_compute/core/Strides.h" #include "arm_compute/core/TensorShape.h" #include "support/Half.h" @@ -578,7 +579,7 @@ class PoolingLayerInfo public: /** Default Constructor */ PoolingLayerInfo() - : _pool_type(PoolingType::MAX), _pool_size(0), _pad_stride_info(PadStrideInfo()), _exclude_padding(false), _is_global_pooling(false) + : _pool_type(PoolingType::MAX), _pool_size(Size2D()), _pad_stride_info(PadStrideInfo()), _exclude_padding(false), _is_global_pooling(false) { } /** Default Constructor @@ -594,6 +595,22 @@ public: unsigned int pool_size, PadStrideInfo pad_stride_info = PadStrideInfo(), bool exclude_padding = false) + : _pool_type(pool_type), _pool_size(Size2D(pool_size, pool_size)), _pad_stride_info(pad_stride_info), _exclude_padding(exclude_padding), _is_global_pooling(false) + { + } + /** Default Constructor + * + * @param[in] pool_type Pooling type @ref PoolingType. + * @param[in] pool_size Pooling size, in elements, across x and y. + * @param[in] pad_stride_info (Optional) Padding and stride information @ref PadStrideInfo + * @param[in] exclude_padding (Optional) Strategy when accounting padding in calculations. + * True will exclude padding while false will not (Used in AVG/L2 pooling to determine the pooling area). + * Defaults to false; + */ + explicit PoolingLayerInfo(PoolingType pool_type, + Size2D pool_size, + PadStrideInfo pad_stride_info = PadStrideInfo(), + bool exclude_padding = false) : _pool_type(pool_type), _pool_size(pool_size), _pad_stride_info(pad_stride_info), _exclude_padding(exclude_padding), _is_global_pooling(false) { } @@ -604,14 +621,14 @@ public: * @param[in] pool_type Pooling type @ref PoolingType. */ explicit PoolingLayerInfo(PoolingType pool_type) - : _pool_type(pool_type), _pool_size(0), _pad_stride_info(PadStrideInfo(1, 1, 0, 0)), _exclude_padding(false), _is_global_pooling(true) + : _pool_type(pool_type), _pool_size(Size2D()), _pad_stride_info(PadStrideInfo(1, 1, 0, 0)), _exclude_padding(false), _is_global_pooling(true) { } PoolingType pool_type() const { return _pool_type; } - unsigned int pool_size() const + const Size2D &pool_size() const { return _pool_size; } @@ -630,7 +647,7 @@ public: private: PoolingType _pool_type; - unsigned int _pool_size; + Size2D _pool_size; PadStrideInfo _pad_stride_info; bool _exclude_padding; bool _is_global_pooling; |