diff options
Diffstat (limited to 'tests/validation/fixtures/ReductionOperationFixture.h')
-rw-r--r-- | tests/validation/fixtures/ReductionOperationFixture.h | 52 |
1 files changed, 42 insertions, 10 deletions
diff --git a/tests/validation/fixtures/ReductionOperationFixture.h b/tests/validation/fixtures/ReductionOperationFixture.h index 0dee7eb707..9079b47cbb 100644 --- a/tests/validation/fixtures/ReductionOperationFixture.h +++ b/tests/validation/fixtures/ReductionOperationFixture.h @@ -45,26 +45,36 @@ class ReductionOperationValidationFixture : public framework::Fixture { public: template <typename...> - void setup(TensorShape shape, DataType data_type, unsigned int axis, ReductionOperation op) + void setup(TensorShape shape, DataType data_type, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info) { const TensorShape output_shape = get_output_shape(shape, axis); - _target = compute_target(shape, output_shape, data_type, axis, op); - _reference = compute_reference(shape, output_shape, data_type, axis, op); + _target = compute_target(shape, output_shape, data_type, axis, op, quantization_info); + _reference = compute_reference(shape, output_shape, data_type, axis, op, quantization_info); } protected: template <typename U> void fill(U &&tensor) { - std::uniform_real_distribution<> distribution(-1.0f, 1.0f); - library->fill(tensor, distribution, 0); + if(!is_data_type_quantized(tensor.data_type())) + { + std::uniform_real_distribution<> distribution(-1.0f, 1.0f); + library->fill(tensor, distribution, 0); + } + else + { + std::pair<int, int> bounds = get_quantized_bounds(tensor.quantization_info(), -1.0f, 1.0f); + std::uniform_int_distribution<uint8_t> distribution(bounds.first, bounds.second); + + library->fill(tensor, distribution, 0); + } } - TensorType compute_target(const TensorShape &src_shape, const TensorShape &dst_shape, DataType data_type, unsigned int axis, ReductionOperation op) + TensorType compute_target(const TensorShape &src_shape, const TensorShape &dst_shape, DataType data_type, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info) { // Create tensors - TensorType src = create_tensor<TensorType>(src_shape, data_type); - TensorType dst = create_tensor<TensorType>(dst_shape, data_type); + TensorType src = create_tensor<TensorType>(src_shape, data_type, 1, quantization_info); + TensorType dst = create_tensor<TensorType>(dst_shape, data_type, 1, quantization_info); // Create and configure function FunctionType reduction_func; @@ -89,10 +99,10 @@ protected: return dst; } - SimpleTensor<T> compute_reference(const TensorShape &src_shape, const TensorShape &dst_shape, DataType data_type, unsigned int axis, ReductionOperation op) + SimpleTensor<T> compute_reference(const TensorShape &src_shape, const TensorShape &dst_shape, DataType data_type, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info) { // Create reference - SimpleTensor<T> src{ src_shape, data_type }; + SimpleTensor<T> src{ src_shape, data_type, 1, quantization_info }; // Fill reference fill(src); @@ -111,6 +121,28 @@ private: return output_shape; } }; + +template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +class ReductionOperationQuantizedFixture : public ReductionOperationValidationFixture<TensorType, AccessorType, FunctionType, T> +{ +public: + template <typename...> + void setup(TensorShape shape, DataType data_type, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info = QuantizationInfo()) + { + ReductionOperationValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, axis, op, quantization_info); + } +}; + +template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +class ReductionOperationFixture : public ReductionOperationValidationFixture<TensorType, AccessorType, FunctionType, T> +{ +public: + template <typename...> + void setup(TensorShape shape, DataType data_type, unsigned int axis, ReductionOperation op) + { + ReductionOperationValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, axis, op, QuantizationInfo()); + } +}; } // namespace validation } // namespace test } // namespace arm_compute |