From 4cb0bd488f70a07b222e1ed7008d888642dfec6f Mon Sep 17 00:00:00 2001 From: Pablo Marquez Tello Date: Thu, 27 Jul 2023 18:02:37 +0100 Subject: Improved testing for ArgMinMax * ArgMinMax output was fixed to S32, this patch makes the changes required to allow other output types like U64/S64 * Made changes to the ArgMinMax fixture and tests to allow specifying output data type. * Made changes to the reference reduction_operation to allow specifying the output type * Added tests case to output S64 for the CL backend. * Added missing test cases in the neon backend. * Partially resolves MLCE-1089 Change-Id: I6f1cbc7093669d12c2a3aff6974cf19d83b2ecda Signed-off-by: Pablo Marquez Tello Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10003 Reviewed-by: Viet-Hoa Do Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins Benchmark: Arm Jenkins --- tests/validation/fixtures/ArgMinMaxFixture.h | 46 ++++++++++++---------- tests/validation/fixtures/ReduceMeanFixture.h | 6 +-- .../fixtures/ReductionOperationFixture.h | 2 +- 3 files changed, 29 insertions(+), 25 deletions(-) (limited to 'tests/validation/fixtures') diff --git a/tests/validation/fixtures/ArgMinMaxFixture.h b/tests/validation/fixtures/ArgMinMaxFixture.h index 9a600b8645..7a823568a8 100644 --- a/tests/validation/fixtures/ArgMinMaxFixture.h +++ b/tests/validation/fixtures/ArgMinMaxFixture.h @@ -42,14 +42,14 @@ namespace test { namespace validation { -template +template class ArgMinMaxValidationBaseFixture : public framework::Fixture { public: - void setup(TensorShape shape, DataType data_type, int axis, ReductionOperation op, QuantizationInfo q_info) + void setup(TensorShape shape, DataType input_type, DataType output_type, int axis, ReductionOperation op, QuantizationInfo q_info) { - _target = compute_target(shape, data_type, axis, op, q_info); - _reference = compute_reference(shape, data_type, axis, op, q_info); + _target = compute_target(shape, input_type, output_type, axis, op, q_info); + _reference = compute_reference(shape, input_type, output_type, axis, op, q_info); } protected: @@ -97,11 +97,11 @@ protected: } } - TensorType compute_target(TensorShape &src_shape, DataType data_type, int axis, ReductionOperation op, QuantizationInfo q_info) + TensorType compute_target(TensorShape &src_shape, DataType input_type, DataType output_type, int axis, ReductionOperation op, QuantizationInfo q_info) { // Create tensors - TensorType src = create_tensor(src_shape, data_type, 1, q_info); - TensorType dst; + TensorType src = create_tensor(src_shape, input_type, 1, q_info); + TensorType dst = create_tensor(compute_output_shape(src_shape, axis), output_type, 1, q_info); // Create and configure function FunctionType arg_min_max_layer; @@ -126,39 +126,43 @@ protected: return dst; } - SimpleTensor compute_reference(TensorShape &src_shape, DataType data_type, int axis, ReductionOperation op, QuantizationInfo q_info) + TensorShape compute_output_shape(const TensorShape &src_shape, int axis) + { + return arm_compute::misc::shape_calculator::compute_reduced_shape(src_shape, axis, false); + } + + SimpleTensor compute_reference(TensorShape &src_shape, DataType input_type, DataType output_type, int axis, ReductionOperation op, QuantizationInfo q_info) { // Create reference - SimpleTensor src{ src_shape, data_type, 1, q_info }; + SimpleTensor src{ src_shape, input_type, 1, q_info }; // Fill reference fill(src); - TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(src_shape, axis, false); - return reference::reduction_operation(src, output_shape, axis, op); + return reference::reduction_operation(src, compute_output_shape(src_shape, axis), axis, op, output_type); } - TensorType _target{}; - SimpleTensor _reference{}; + TensorType _target{}; + SimpleTensor _reference{}; }; -template -class ArgMinMaxValidationQuantizedFixture : public ArgMinMaxValidationBaseFixture +template +class ArgMinMaxValidationQuantizedFixture : public ArgMinMaxValidationBaseFixture { public: - void setup(const TensorShape &shape, DataType data_type, int axis, ReductionOperation op, QuantizationInfo quantization_info) + void setup(const TensorShape &shape, DataType input_type, DataType output_type, int axis, ReductionOperation op, QuantizationInfo quantization_info) { - ArgMinMaxValidationBaseFixture::setup(shape, data_type, axis, op, quantization_info); + ArgMinMaxValidationBaseFixture::setup(shape, input_type, output_type, axis, op, quantization_info); } }; -template -class ArgMinMaxValidationFixture : public ArgMinMaxValidationBaseFixture +template +class ArgMinMaxValidationFixture : public ArgMinMaxValidationBaseFixture { public: - void setup(const TensorShape &shape, DataType data_type, int axis, ReductionOperation op) + void setup(const TensorShape &shape, DataType input_type, DataType output_type, int axis, ReductionOperation op) { - ArgMinMaxValidationBaseFixture::setup(shape, data_type, axis, op, QuantizationInfo()); + ArgMinMaxValidationBaseFixture::setup(shape, input_type, output_type, axis, op, QuantizationInfo()); } }; } // namespace validation diff --git a/tests/validation/fixtures/ReduceMeanFixture.h b/tests/validation/fixtures/ReduceMeanFixture.h index 39fdea59fc..5363d6ba53 100644 --- a/tests/validation/fixtures/ReduceMeanFixture.h +++ b/tests/validation/fixtures/ReduceMeanFixture.h @@ -127,9 +127,9 @@ protected: #ifdef ARM_COMPUTE_OPENCL_ENABLED is_opencl = std::is_same::value; // Round down to zero on opencl to match kernel -#endif /* ARM_COMPUTE_OPENCL_ENABLED */ - out = reference::reduction_operation(i == 0 ? src : out, output_shape, axis[i], ReductionOperation::MEAN_SUM, quantization_info_output, is_opencl ? RoundingPolicy::TO_ZERO : RoundingPolicy::TO_NEAREST_UP); - +#endif /* ARM_COMPUTE_OPENCL_ENABLED */ + out = reference::reduction_operation(i == 0 ? src : out, output_shape, axis[i], ReductionOperation::MEAN_SUM, data_type, quantization_info_output, + is_opencl ? RoundingPolicy::TO_ZERO : RoundingPolicy::TO_NEAREST_UP); } if(!keep_dims) diff --git a/tests/validation/fixtures/ReductionOperationFixture.h b/tests/validation/fixtures/ReductionOperationFixture.h index 36e6309d6f..b44f299486 100644 --- a/tests/validation/fixtures/ReductionOperationFixture.h +++ b/tests/validation/fixtures/ReductionOperationFixture.h @@ -134,7 +134,7 @@ protected: // Fill reference fill(src); - return reference::reduction_operation(src, dst_shape, axis, op, quantization_info); + return reference::reduction_operation(src, dst_shape, axis, op, data_type, quantization_info); } TensorType _target{}; -- cgit v1.2.1