From bcf8a968da4b26926df8bb770df16d82146bcb54 Mon Sep 17 00:00:00 2001 From: Michalis Spyrou Date: Fri, 12 Oct 2018 10:51:31 +0100 Subject: COMPMID-1580 Implement ReduceMean in NEON Change-Id: Id974efad304c2513b8824a6561ad45ee60b9e7fb Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/153763 Reviewed-by: Giuseppe Rossini Reviewed-by: Isabella Gottardi Tested-by: bsgcomp --- .../fixtures/ReductionOperationFixture.h | 52 +++++++++++++++++----- 1 file changed, 42 insertions(+), 10 deletions(-) (limited to 'tests/validation/fixtures/ReductionOperationFixture.h') diff --git a/tests/validation/fixtures/ReductionOperationFixture.h b/tests/validation/fixtures/ReductionOperationFixture.h index 0dee7eb707..9079b47cbb 100644 --- a/tests/validation/fixtures/ReductionOperationFixture.h +++ b/tests/validation/fixtures/ReductionOperationFixture.h @@ -45,26 +45,36 @@ class ReductionOperationValidationFixture : public framework::Fixture { public: template - void setup(TensorShape shape, DataType data_type, unsigned int axis, ReductionOperation op) + void setup(TensorShape shape, DataType data_type, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info) { const TensorShape output_shape = get_output_shape(shape, axis); - _target = compute_target(shape, output_shape, data_type, axis, op); - _reference = compute_reference(shape, output_shape, data_type, axis, op); + _target = compute_target(shape, output_shape, data_type, axis, op, quantization_info); + _reference = compute_reference(shape, output_shape, data_type, axis, op, quantization_info); } protected: template void fill(U &&tensor) { - std::uniform_real_distribution<> distribution(-1.0f, 1.0f); - library->fill(tensor, distribution, 0); + if(!is_data_type_quantized(tensor.data_type())) + { + std::uniform_real_distribution<> distribution(-1.0f, 1.0f); + library->fill(tensor, distribution, 0); + } + else + { + std::pair bounds = get_quantized_bounds(tensor.quantization_info(), -1.0f, 1.0f); + std::uniform_int_distribution distribution(bounds.first, bounds.second); + + library->fill(tensor, distribution, 0); + } } - TensorType compute_target(const TensorShape &src_shape, const TensorShape &dst_shape, DataType data_type, unsigned int axis, ReductionOperation op) + TensorType compute_target(const TensorShape &src_shape, const TensorShape &dst_shape, DataType data_type, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info) { // Create tensors - TensorType src = create_tensor(src_shape, data_type); - TensorType dst = create_tensor(dst_shape, data_type); + TensorType src = create_tensor(src_shape, data_type, 1, quantization_info); + TensorType dst = create_tensor(dst_shape, data_type, 1, quantization_info); // Create and configure function FunctionType reduction_func; @@ -89,10 +99,10 @@ protected: return dst; } - SimpleTensor compute_reference(const TensorShape &src_shape, const TensorShape &dst_shape, DataType data_type, unsigned int axis, ReductionOperation op) + SimpleTensor compute_reference(const TensorShape &src_shape, const TensorShape &dst_shape, DataType data_type, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info) { // Create reference - SimpleTensor src{ src_shape, data_type }; + SimpleTensor src{ src_shape, data_type, 1, quantization_info }; // Fill reference fill(src); @@ -111,6 +121,28 @@ private: return output_shape; } }; + +template +class ReductionOperationQuantizedFixture : public ReductionOperationValidationFixture +{ +public: + template + void setup(TensorShape shape, DataType data_type, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info = QuantizationInfo()) + { + ReductionOperationValidationFixture::setup(shape, data_type, axis, op, quantization_info); + } +}; + +template +class ReductionOperationFixture : public ReductionOperationValidationFixture +{ +public: + template + void setup(TensorShape shape, DataType data_type, unsigned int axis, ReductionOperation op) + { + ReductionOperationValidationFixture::setup(shape, data_type, axis, op, QuantizationInfo()); + } +}; } // namespace validation } // namespace test } // namespace arm_compute -- cgit v1.2.1