diff options
Diffstat (limited to 'tests/validation/fixtures/PoolingLayerFixture.h')
-rw-r--r-- | tests/validation/fixtures/PoolingLayerFixture.h | 42 |
1 files changed, 30 insertions, 12 deletions
diff --git a/tests/validation/fixtures/PoolingLayerFixture.h b/tests/validation/fixtures/PoolingLayerFixture.h index ec186564b7..7f2d7ac225 100644 --- a/tests/validation/fixtures/PoolingLayerFixture.h +++ b/tests/validation/fixtures/PoolingLayerFixture.h @@ -34,7 +34,6 @@ #include "tests/framework/Asserts.h" #include "tests/framework/Fixture.h" #include "tests/validation/reference/PoolingLayer.h" - #include <random> namespace arm_compute @@ -48,7 +47,7 @@ class PoolingLayerValidationGenericFixture : public framework::Fixture { public: template <typename...> - void setup(TensorShape shape, PoolingLayerInfo pool_info, DataType data_type, DataLayout data_layout) + void setup(TensorShape shape, PoolingLayerInfo pool_info, DataType data_type, DataLayout data_layout, bool indices = false) { std::mt19937 gen(library->seed()); std::uniform_int_distribution<> offset_dis(0, 20); @@ -59,8 +58,8 @@ public: const QuantizationInfo output_qinfo(scale, scale_out); _pool_info = pool_info; - _target = compute_target(shape, pool_info, data_type, data_layout, input_qinfo, output_qinfo); - _reference = compute_reference(shape, pool_info, data_type, input_qinfo, output_qinfo); + _target = compute_target(shape, pool_info, data_type, data_layout, input_qinfo, output_qinfo, indices); + _reference = compute_reference(shape, pool_info, data_type, input_qinfo, output_qinfo, indices); } protected: @@ -79,7 +78,9 @@ protected: } TensorType compute_target(TensorShape shape, PoolingLayerInfo info, - DataType data_type, DataLayout data_layout, QuantizationInfo input_qinfo, QuantizationInfo output_qinfo) + DataType data_type, DataLayout data_layout, + QuantizationInfo input_qinfo, QuantizationInfo output_qinfo, + bool indices) { // Change shape in case of NHWC. if(data_layout == DataLayout::NHWC) @@ -91,20 +92,24 @@ protected: TensorType src = create_tensor<TensorType>(shape, data_type, 1, input_qinfo, data_layout); const TensorShape dst_shape = misc::shape_calculator::compute_pool_shape(*(src.info()), info); TensorType dst = create_tensor<TensorType>(dst_shape, data_type, 1, output_qinfo, data_layout); + _target_indices = create_tensor<TensorType>(dst_shape, DataType::U32, 1); // Create and configure function FunctionType pool_layer; - pool_layer.configure(&src, &dst, info); + pool_layer.configure(&src, &dst, info, (indices) ? &_target_indices : nullptr); ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(_target_indices.info()->is_resizable(), framework::LogLevel::ERRORS); // Allocate tensors src.allocator()->allocate(); dst.allocator()->allocate(); + _target_indices.allocator()->allocate(); ARM_COMPUTE_EXPECT(!src.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(!_target_indices.info()->is_resizable(), framework::LogLevel::ERRORS); // Fill tensors fill(AccessorType(src)); @@ -115,20 +120,33 @@ protected: return dst; } - SimpleTensor<T> compute_reference(const TensorShape &shape, PoolingLayerInfo info, DataType data_type, QuantizationInfo input_qinfo, QuantizationInfo output_qinfo) + SimpleTensor<T> compute_reference(const TensorShape &shape, PoolingLayerInfo info, DataType data_type, + QuantizationInfo input_qinfo, QuantizationInfo output_qinfo, bool indices) { // Create reference SimpleTensor<T> src{ shape, data_type, 1, input_qinfo }; - // Fill reference fill(src); - return reference::pooling_layer<T>(src, info, output_qinfo); + return reference::pooling_layer<T>(src, info, output_qinfo, indices ? &_ref_indices : nullptr); } - TensorType _target{}; - SimpleTensor<T> _reference{}; - PoolingLayerInfo _pool_info{}; + TensorType _target{}; + SimpleTensor<T> _reference{}; + PoolingLayerInfo _pool_info{}; + TensorType _target_indices{}; + SimpleTensor<uint32_t> _ref_indices{}; +}; +template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +class PoolingLayerIndicesValidationFixture : public PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T> +{ +public: + template <typename...> + void setup(TensorShape shape, PoolingType pool_type, Size2D pool_size, PadStrideInfo pad_stride_info, bool exclude_padding, DataType data_type, DataLayout data_layout) + { + PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, PoolingLayerInfo(pool_type, pool_size, data_layout, pad_stride_info, exclude_padding), + data_type, data_layout, true); + } }; template <typename TensorType, typename AccessorType, typename FunctionType, typename T> |