diff options
Diffstat (limited to 'tests/validation/fixtures/PoolingLayerFixture.h')
-rw-r--r-- | tests/validation/fixtures/PoolingLayerFixture.h | 16 |
1 files changed, 5 insertions, 11 deletions
diff --git a/tests/validation/fixtures/PoolingLayerFixture.h b/tests/validation/fixtures/PoolingLayerFixture.h index ec4e9f80dd..59c920868b 100644 --- a/tests/validation/fixtures/PoolingLayerFixture.h +++ b/tests/validation/fixtures/PoolingLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -45,7 +45,6 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ class PoolingLayerValidationGenericFixture : public framework::Fixture { public: - template <typename...> void setup(TensorShape shape, PoolingLayerInfo pool_info, DataType data_type, DataLayout data_layout, bool indices = false, QuantizationInfo input_qinfo = QuantizationInfo(), QuantizationInfo output_qinfo = QuantizationInfo(), bool mixed_layout = false) { @@ -161,10 +160,10 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ class PoolingLayerIndicesValidationFixture : public PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T> { public: - template <typename...> - void setup(TensorShape shape, PoolingType pool_type, Size2D pool_size, PadStrideInfo pad_stride_info, bool exclude_padding, DataType data_type, DataLayout data_layout) + void setup(TensorShape shape, PoolingType pool_type, Size2D pool_size, PadStrideInfo pad_stride_info, bool exclude_padding, DataType data_type, DataLayout data_layout, bool use_kernel_indices) { - PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, PoolingLayerInfo(pool_type, pool_size, data_layout, pad_stride_info, exclude_padding), + PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, PoolingLayerInfo(pool_type, pool_size, data_layout, pad_stride_info, exclude_padding, false, + true, use_kernel_indices), data_type, data_layout, true); } }; @@ -173,7 +172,6 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ class PoolingLayerValidationFixture : public PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T> { public: - template <typename...> void setup(TensorShape shape, PoolingType pool_type, Size2D pool_size, PadStrideInfo pad_stride_info, bool exclude_padding, DataType data_type, DataLayout data_layout) { PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, PoolingLayerInfo(pool_type, pool_size, data_layout, pad_stride_info, exclude_padding), @@ -185,7 +183,6 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ class PoolingLayerValidationMixedPrecisionFixture : public PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T> { public: - template <typename...> void setup(TensorShape shape, PoolingType pool_type, Size2D pool_size, PadStrideInfo pad_stride_info, bool exclude_padding, DataType data_type, DataLayout data_layout, bool fp_mixed_precision = false) { PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, PoolingLayerInfo(pool_type, pool_size, data_layout, pad_stride_info, exclude_padding, fp_mixed_precision), @@ -197,7 +194,6 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ class PoolingLayerValidationQuantizedFixture : public PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T> { public: - template <typename...> void setup(TensorShape shape, PoolingType pool_type, Size2D pool_size, PadStrideInfo pad_stride_info, bool exclude_padding, DataType data_type, DataLayout data_layout = DataLayout::NCHW, QuantizationInfo input_qinfo = QuantizationInfo(), QuantizationInfo output_qinfo = QuantizationInfo()) { @@ -210,10 +206,9 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ class SpecialPoolingLayerValidationFixture : public PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T> { public: - template <typename...> void setup(TensorShape src_shape, PoolingLayerInfo pool_info, DataType data_type) { - PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, pool_info, data_type, DataLayout::NCHW); + PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, pool_info, data_type, pool_info.data_layout); } }; @@ -221,7 +216,6 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ class GlobalPoolingLayerValidationFixture : public PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T> { public: - template <typename...> void setup(TensorShape shape, PoolingType pool_type, DataType data_type, DataLayout data_layout = DataLayout::NCHW) { PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, PoolingLayerInfo(pool_type, data_layout), data_type, data_layout); |