diff options
Diffstat (limited to 'tests/validation/fixtures/ROIAlignLayerFixture.h')
-rw-r--r-- | tests/validation/fixtures/ROIAlignLayerFixture.h | 22 |
1 files changed, 10 insertions, 12 deletions
diff --git a/tests/validation/fixtures/ROIAlignLayerFixture.h b/tests/validation/fixtures/ROIAlignLayerFixture.h index b9b85d3073..e4470c99a0 100644 --- a/tests/validation/fixtures/ROIAlignLayerFixture.h +++ b/tests/validation/fixtures/ROIAlignLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 ARM Limited. + * Copyright (c) 2018-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -41,12 +41,10 @@ namespace test { namespace validation { -template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +template <typename TensorType, typename AccessorType, typename FunctionType, typename T, typename TRois> class ROIAlignLayerGenericFixture : public framework::Fixture { public: - using TRois = typename std::conditional<std::is_same<typename std::decay<T>::type, uint8_t>::value, uint16_t, T>::type; - template <typename...> void setup(TensorShape input_shape, const ROIPoolingLayerInfo pool_info, TensorShape rois_shape, DataType data_type, DataLayout data_layout, QuantizationInfo qinfo, QuantizationInfo output_qinfo) { @@ -187,28 +185,28 @@ protected: DataType _rois_data_type{}; }; -template <typename TensorType, typename AccessorType, typename FunctionType, typename T> -class ROIAlignLayerFixture : public ROIAlignLayerGenericFixture<TensorType, AccessorType, FunctionType, T> +template <typename TensorType, typename AccessorType, typename FunctionType, typename T, typename TRois> +class ROIAlignLayerFixture : public ROIAlignLayerGenericFixture<TensorType, AccessorType, FunctionType, T, TRois> { public: template <typename...> void setup(TensorShape input_shape, const ROIPoolingLayerInfo pool_info, TensorShape rois_shape, DataType data_type, DataLayout data_layout) { - ROIAlignLayerGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, pool_info, rois_shape, data_type, data_layout, - QuantizationInfo(), QuantizationInfo()); + ROIAlignLayerGenericFixture<TensorType, AccessorType, FunctionType, T, TRois>::setup(input_shape, pool_info, rois_shape, data_type, data_layout, + QuantizationInfo(), QuantizationInfo()); } }; -template <typename TensorType, typename AccessorType, typename FunctionType, typename T> -class ROIAlignLayerQuantizedFixture : public ROIAlignLayerGenericFixture<TensorType, AccessorType, FunctionType, T> +template <typename TensorType, typename AccessorType, typename FunctionType, typename T, typename TRois> +class ROIAlignLayerQuantizedFixture : public ROIAlignLayerGenericFixture<TensorType, AccessorType, FunctionType, T, TRois> { public: template <typename...> void setup(TensorShape input_shape, const ROIPoolingLayerInfo pool_info, TensorShape rois_shape, DataType data_type, DataLayout data_layout, QuantizationInfo qinfo, QuantizationInfo output_qinfo) { - ROIAlignLayerGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, pool_info, rois_shape, - data_type, data_layout, qinfo, output_qinfo); + ROIAlignLayerGenericFixture<TensorType, AccessorType, FunctionType, T, TRois>::setup(input_shape, pool_info, rois_shape, + data_type, data_layout, qinfo, output_qinfo); } }; } // namespace validation |