diff options
author | Michalis Spyrou <michalis.spyrou@arm.com> | 2019-01-03 11:10:25 +0000 |
---|---|---|
committer | Michalis Spyrou <michalis.spyrou@arm.com> | 2019-01-10 16:24:26 +0000 |
commit | aea14c63e2efeda9d5f7492099389d439c65204f (patch) | |
tree | 176a6181bbf00e4df078d5da0a17dd44f248958e /tests/validation/fixtures/ArgMinMaxFixture.h | |
parent | c10bc0b5db5169a6ccea02a1aaefe34f082709e5 (diff) | |
download | ComputeLibrary-aea14c63e2efeda9d5f7492099389d439c65204f.tar.gz |
COMPMID-1764 NEON: Implement ArgMax/ArgMin
Change-Id: Ibe23aa90b36ffd8553d1d1c35fada5d300fab829
Reviewed-on: https://review.mlplatform.org/475
Reviewed-by: Isabella Gottardi <isabella.gottardi@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Giuseppe Rossini <giuseppe.rossini@arm.com>
Diffstat (limited to 'tests/validation/fixtures/ArgMinMaxFixture.h')
-rw-r--r-- | tests/validation/fixtures/ArgMinMaxFixture.h | 60 |
1 files changed, 46 insertions, 14 deletions
diff --git a/tests/validation/fixtures/ArgMinMaxFixture.h b/tests/validation/fixtures/ArgMinMaxFixture.h index 5f5f85c104..e263b25bf2 100644 --- a/tests/validation/fixtures/ArgMinMaxFixture.h +++ b/tests/validation/fixtures/ArgMinMaxFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -42,28 +42,38 @@ namespace test namespace validation { template <typename TensorType, typename AccessorType, typename FunctionType, typename T> -class ArgMinMaxValidationFixture : public framework::Fixture +class ArgMinMaxValidationBaseFixture : public framework::Fixture { public: template <typename...> - void setup(TensorShape shape, DataType data_type, int axis, ReductionOperation op) + void setup(TensorShape shape, DataType data_type, int axis, ReductionOperation op, QuantizationInfo q_info) { - _target = compute_target(shape, data_type, axis, op); - _reference = compute_reference(shape, data_type, axis, op); + _target = compute_target(shape, data_type, axis, op, q_info); + _reference = compute_reference(shape, data_type, axis, op, q_info); } protected: template <typename U> void fill(U &&tensor) { - std::uniform_real_distribution<> distribution(-1.0f, 1.0f); - library->fill(tensor, distribution, 0); + if(!is_data_type_quantized(tensor.data_type())) + { + std::uniform_real_distribution<> distribution(-1.0f, 1.0f); + library->fill(tensor, distribution, 0); + } + else + { + std::pair<int, int> bounds = get_quantized_bounds(tensor.quantization_info(), -1.0f, 1.0f); + std::uniform_int_distribution<uint8_t> distribution(bounds.first, bounds.second); + + library->fill(tensor, distribution, 0); + } } - TensorType compute_target(TensorShape &src_shape, DataType data_type, int axis, ReductionOperation op) + TensorType compute_target(TensorShape &src_shape, DataType data_type, int axis, ReductionOperation op, QuantizationInfo q_info) { // Create tensors - TensorType src = create_tensor<TensorType>(src_shape, data_type, 1); + TensorType src = create_tensor<TensorType>(src_shape, data_type, 1, q_info); TensorType dst; // Create and configure function @@ -89,21 +99,43 @@ protected: return dst; } - SimpleTensor<T> compute_reference(TensorShape &src_shape, DataType data_type, int axis, ReductionOperation op) + SimpleTensor<uint32_t> compute_reference(TensorShape &src_shape, DataType data_type, int axis, ReductionOperation op, QuantizationInfo q_info) { // Create reference - SimpleTensor<T> src{ src_shape, data_type, 1 }; + SimpleTensor<T> src{ src_shape, data_type, 1, q_info }; // Fill reference fill(src); TensorShape output_shape = src_shape; output_shape.set(axis, 1); - return reference::reduction_operation<T>(src, output_shape, axis, op); + return reference::reduction_operation<T, uint32_t>(src, output_shape, axis, op); } - TensorType _target{}; - SimpleTensor<T> _reference{}; + TensorType _target{}; + SimpleTensor<uint32_t> _reference{}; +}; + +template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +class ArgMinMaxValidationQuantizedFixture : public ArgMinMaxValidationBaseFixture<TensorType, AccessorType, FunctionType, T> +{ +public: + template <typename...> + void setup(const TensorShape &shape, DataType data_type, int axis, ReductionOperation op, QuantizationInfo quantization_info) + { + ArgMinMaxValidationBaseFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, axis, op, quantization_info); + } +}; + +template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +class ArgMinMaxValidationFixture : public ArgMinMaxValidationBaseFixture<TensorType, AccessorType, FunctionType, T> +{ +public: + template <typename...> + void setup(const TensorShape &shape, DataType data_type, int axis, ReductionOperation op) + { + ArgMinMaxValidationBaseFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, axis, op, QuantizationInfo()); + } }; } // namespace validation } // namespace test |