diff options
Diffstat (limited to 'tests/validation/fixtures/ArgMinMaxFixture.h')
-rw-r--r-- | tests/validation/fixtures/ArgMinMaxFixture.h | 63 |
1 files changed, 32 insertions, 31 deletions
diff --git a/tests/validation/fixtures/ArgMinMaxFixture.h b/tests/validation/fixtures/ArgMinMaxFixture.h index f140c9846b..7a823568a8 100644 --- a/tests/validation/fixtures/ArgMinMaxFixture.h +++ b/tests/validation/fixtures/ArgMinMaxFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021 Arm Limited. + * Copyright (c) 2018-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -42,15 +42,14 @@ namespace test { namespace validation { -template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2> class ArgMinMaxValidationBaseFixture : public framework::Fixture { public: - template <typename...> - void setup(TensorShape shape, DataType data_type, int axis, ReductionOperation op, QuantizationInfo q_info) + void setup(TensorShape shape, DataType input_type, DataType output_type, int axis, ReductionOperation op, QuantizationInfo q_info) { - _target = compute_target(shape, data_type, axis, op, q_info); - _reference = compute_reference(shape, data_type, axis, op, q_info); + _target = compute_target(shape, input_type, output_type, axis, op, q_info); + _reference = compute_reference(shape, input_type, output_type, axis, op, q_info); } protected: @@ -80,7 +79,7 @@ protected: case DataType::QASYMM8: { std::pair<int, int> bounds = get_quantized_bounds(tensor.quantization_info(), -1.0f, 1.0f); - std::uniform_int_distribution<uint8_t> distribution(bounds.first, bounds.second); + std::uniform_int_distribution<uint32_t> distribution(bounds.first, bounds.second); library->fill(tensor, distribution, 0); break; @@ -88,7 +87,7 @@ protected: case DataType::QASYMM8_SIGNED: { std::pair<int, int> bounds = get_quantized_qasymm8_signed_bounds(tensor.quantization_info(), -1.0f, 1.0f); - std::uniform_int_distribution<int8_t> distribution(bounds.first, bounds.second); + std::uniform_int_distribution<int32_t> distribution(bounds.first, bounds.second); library->fill(tensor, distribution, 0); break; @@ -98,25 +97,25 @@ protected: } } - TensorType compute_target(TensorShape &src_shape, DataType data_type, int axis, ReductionOperation op, QuantizationInfo q_info) + TensorType compute_target(TensorShape &src_shape, DataType input_type, DataType output_type, int axis, ReductionOperation op, QuantizationInfo q_info) { // Create tensors - TensorType src = create_tensor<TensorType>(src_shape, data_type, 1, q_info); - TensorType dst; + TensorType src = create_tensor<TensorType>(src_shape, input_type, 1, q_info); + TensorType dst = create_tensor<TensorType>(compute_output_shape(src_shape, axis), output_type, 1, q_info); // Create and configure function FunctionType arg_min_max_layer; arg_min_max_layer.configure(&src, axis, &dst, op); - ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_ASSERT(src.info()->is_resizable()); + ARM_COMPUTE_ASSERT(dst.info()->is_resizable()); // Allocate tensors src.allocator()->allocate(); dst.allocator()->allocate(); - ARM_COMPUTE_EXPECT(!src.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_ASSERT(!src.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!dst.info()->is_resizable()); // Fill tensors fill(AccessorType(src)); @@ -127,41 +126,43 @@ protected: return dst; } - SimpleTensor<int32_t> compute_reference(TensorShape &src_shape, DataType data_type, int axis, ReductionOperation op, QuantizationInfo q_info) + TensorShape compute_output_shape(const TensorShape &src_shape, int axis) + { + return arm_compute::misc::shape_calculator::compute_reduced_shape(src_shape, axis, false); + } + + SimpleTensor<T2> compute_reference(TensorShape &src_shape, DataType input_type, DataType output_type, int axis, ReductionOperation op, QuantizationInfo q_info) { // Create reference - SimpleTensor<T> src{ src_shape, data_type, 1, q_info }; + SimpleTensor<T1> src{ src_shape, input_type, 1, q_info }; // Fill reference fill(src); - TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(src_shape, axis, false); - return reference::reduction_operation<T, int32_t>(src, output_shape, axis, op); + return reference::reduction_operation<T1, T2>(src, compute_output_shape(src_shape, axis), axis, op, output_type); } - TensorType _target{}; - SimpleTensor<int32_t> _reference{}; + TensorType _target{}; + SimpleTensor<T2> _reference{}; }; -template <typename TensorType, typename AccessorType, typename FunctionType, typename T> -class ArgMinMaxValidationQuantizedFixture : public ArgMinMaxValidationBaseFixture<TensorType, AccessorType, FunctionType, T> +template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2> +class ArgMinMaxValidationQuantizedFixture : public ArgMinMaxValidationBaseFixture<TensorType, AccessorType, FunctionType, T1, T2> { public: - template <typename...> - void setup(const TensorShape &shape, DataType data_type, int axis, ReductionOperation op, QuantizationInfo quantization_info) + void setup(const TensorShape &shape, DataType input_type, DataType output_type, int axis, ReductionOperation op, QuantizationInfo quantization_info) { - ArgMinMaxValidationBaseFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, axis, op, quantization_info); + ArgMinMaxValidationBaseFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape, input_type, output_type, axis, op, quantization_info); } }; -template <typename TensorType, typename AccessorType, typename FunctionType, typename T> -class ArgMinMaxValidationFixture : public ArgMinMaxValidationBaseFixture<TensorType, AccessorType, FunctionType, T> +template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2> +class ArgMinMaxValidationFixture : public ArgMinMaxValidationBaseFixture<TensorType, AccessorType, FunctionType, T1, T2> { public: - template <typename...> - void setup(const TensorShape &shape, DataType data_type, int axis, ReductionOperation op) + void setup(const TensorShape &shape, DataType input_type, DataType output_type, int axis, ReductionOperation op) { - ArgMinMaxValidationBaseFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, axis, op, QuantizationInfo()); + ArgMinMaxValidationBaseFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape, input_type, output_type, axis, op, QuantizationInfo()); } }; } // namespace validation |