aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/ROIAlignLayerFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/ROIAlignLayerFixture.h')
-rw-r--r--tests/validation/fixtures/ROIAlignLayerFixture.h22
1 files changed, 10 insertions, 12 deletions
diff --git a/tests/validation/fixtures/ROIAlignLayerFixture.h b/tests/validation/fixtures/ROIAlignLayerFixture.h
index b9b85d3073..e4470c99a0 100644
--- a/tests/validation/fixtures/ROIAlignLayerFixture.h
+++ b/tests/validation/fixtures/ROIAlignLayerFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 ARM Limited.
+ * Copyright (c) 2018-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -41,12 +41,10 @@ namespace test
{
namespace validation
{
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, typename TRois>
class ROIAlignLayerGenericFixture : public framework::Fixture
{
public:
- using TRois = typename std::conditional<std::is_same<typename std::decay<T>::type, uint8_t>::value, uint16_t, T>::type;
-
template <typename...>
void setup(TensorShape input_shape, const ROIPoolingLayerInfo pool_info, TensorShape rois_shape, DataType data_type, DataLayout data_layout, QuantizationInfo qinfo, QuantizationInfo output_qinfo)
{
@@ -187,28 +185,28 @@ protected:
DataType _rois_data_type{};
};
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class ROIAlignLayerFixture : public ROIAlignLayerGenericFixture<TensorType, AccessorType, FunctionType, T>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, typename TRois>
+class ROIAlignLayerFixture : public ROIAlignLayerGenericFixture<TensorType, AccessorType, FunctionType, T, TRois>
{
public:
template <typename...>
void setup(TensorShape input_shape, const ROIPoolingLayerInfo pool_info, TensorShape rois_shape, DataType data_type, DataLayout data_layout)
{
- ROIAlignLayerGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, pool_info, rois_shape, data_type, data_layout,
- QuantizationInfo(), QuantizationInfo());
+ ROIAlignLayerGenericFixture<TensorType, AccessorType, FunctionType, T, TRois>::setup(input_shape, pool_info, rois_shape, data_type, data_layout,
+ QuantizationInfo(), QuantizationInfo());
}
};
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class ROIAlignLayerQuantizedFixture : public ROIAlignLayerGenericFixture<TensorType, AccessorType, FunctionType, T>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, typename TRois>
+class ROIAlignLayerQuantizedFixture : public ROIAlignLayerGenericFixture<TensorType, AccessorType, FunctionType, T, TRois>
{
public:
template <typename...>
void setup(TensorShape input_shape, const ROIPoolingLayerInfo pool_info, TensorShape rois_shape, DataType data_type,
DataLayout data_layout, QuantizationInfo qinfo, QuantizationInfo output_qinfo)
{
- ROIAlignLayerGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, pool_info, rois_shape,
- data_type, data_layout, qinfo, output_qinfo);
+ ROIAlignLayerGenericFixture<TensorType, AccessorType, FunctionType, T, TRois>::setup(input_shape, pool_info, rois_shape,
+ data_type, data_layout, qinfo, output_qinfo);
}
};
} // namespace validation