diff options
author | Sang-Hoon Park <sang-hoon.park@arm.com> | 2019-10-15 16:49:24 +0100 |
---|---|---|
committer | Sang-Hoon Park <sang-hoon.park@arm.com> | 2019-10-30 14:49:34 +0000 |
commit | 2697fd8fa42425f7bfdd60dd486d4c2132b06523 (patch) | |
tree | 098450f7f60211c7e5bfbd41eb1a7a10c1c0437f /tests/validation/fixtures | |
parent | df4cf57c7394265b27d051cb1cf0152c53659126 (diff) | |
download | ComputeLibrary-2697fd8fa42425f7bfdd60dd486d4c2132b06523.tar.gz |
COMPMID-2707: add keep_dims parameter to Reduction Operation
The added parameter is used to decide whether or not to keep
the target dimension of reduction operation. ArgMinMax operations
will always remove the reduced dimension. Following things
are updated to support the parameter.
- [CL/NEON] functions and reference kernel
- [CL/NEON] ArgMinMax function to use ReductionOperation function
- [CL/NEON] validation test suite for Reduction and ArgMinMax operations
to validate the added parameter
- ReductionOperationFixture is modified NOT to pre-populate output
tensor and now relies on underlying kernel/function.
- Adjust CL validation test suite for Reduction operation to remove
excessive test cases with axis values beyond input tensor's
dimension.
Change-Id: I3e24d276ed469a4201f323001708f0c525f11c4f
Signed-off-by: Sang-Hoon Park <sang-hoon.park@arm.com>
Reviewed-on: https://review.mlplatform.org/c/2167
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r-- | tests/validation/fixtures/ArgMinMaxFixture.h | 4 | ||||
-rw-r--r-- | tests/validation/fixtures/ReductionOperationFixture.h | 34 |
2 files changed, 19 insertions, 19 deletions
diff --git a/tests/validation/fixtures/ArgMinMaxFixture.h b/tests/validation/fixtures/ArgMinMaxFixture.h index ed6b51abe5..f8fe4ff1ee 100644 --- a/tests/validation/fixtures/ArgMinMaxFixture.h +++ b/tests/validation/fixtures/ArgMinMaxFixture.h @@ -26,6 +26,7 @@ #include "arm_compute/core/TensorShape.h" #include "arm_compute/core/Types.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "arm_compute/runtime/Tensor.h" #include "tests/AssetsLibrary.h" #include "tests/Globals.h" @@ -121,8 +122,7 @@ protected: // Fill reference fill(src); - TensorShape output_shape = src_shape; - output_shape.set(axis, 1); + TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(src_shape, axis, false); return reference::reduction_operation<T, uint32_t>(src, output_shape, axis, op); } diff --git a/tests/validation/fixtures/ReductionOperationFixture.h b/tests/validation/fixtures/ReductionOperationFixture.h index d01f41abf0..867c08ec3a 100644 --- a/tests/validation/fixtures/ReductionOperationFixture.h +++ b/tests/validation/fixtures/ReductionOperationFixture.h @@ -26,6 +26,7 @@ #include "arm_compute/core/TensorShape.h" #include "arm_compute/core/Types.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "arm_compute/runtime/Tensor.h" #include "tests/AssetsLibrary.h" #include "tests/Globals.h" @@ -45,11 +46,15 @@ class ReductionOperationValidationFixture : public framework::Fixture { public: template <typename...> - void setup(TensorShape shape, DataType data_type, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info) + void setup(TensorShape shape, DataType data_type, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info, bool keep_dims = false) { - const TensorShape output_shape = get_output_shape(shape, axis); - _target = compute_target(shape, output_shape, data_type, axis, op, quantization_info); - _reference = compute_reference(shape, output_shape, data_type, axis, op, quantization_info); + const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MAX) || (op == ReductionOperation::ARG_IDX_MIN); + _keep_dims = keep_dims && !is_arg_min_max; + + const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(shape, axis, _keep_dims); + + _target = compute_target(shape, data_type, axis, op, quantization_info); + _reference = compute_reference(shape, output_shape, data_type, axis, op, quantization_info); } protected: @@ -70,15 +75,15 @@ protected: } } - TensorType compute_target(const TensorShape &src_shape, const TensorShape &dst_shape, DataType data_type, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info) + TensorType compute_target(const TensorShape &src_shape, DataType data_type, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info) { // Create tensors TensorType src = create_tensor<TensorType>(src_shape, data_type, 1, quantization_info); - TensorType dst = create_tensor<TensorType>(dst_shape, data_type, 1, quantization_info); + TensorType dst; // Create and configure function FunctionType reduction_func; - reduction_func.configure(&src, &dst, axis, op); + reduction_func.configure(&src, &dst, axis, op, _keep_dims); ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS); @@ -114,12 +119,7 @@ protected: SimpleTensor<T> _reference{}; private: - TensorShape get_output_shape(TensorShape shape, unsigned int axis) - { - TensorShape output_shape(shape); - output_shape.set(axis, 1); - return output_shape; - } + bool _keep_dims{ false }; }; template <typename TensorType, typename AccessorType, typename FunctionType, typename T> @@ -127,9 +127,9 @@ class ReductionOperationQuantizedFixture : public ReductionOperationValidationFi { public: template <typename...> - void setup(TensorShape shape, DataType data_type, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info = QuantizationInfo()) + void setup(TensorShape shape, DataType data_type, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info = QuantizationInfo(), bool keep_dims = false) { - ReductionOperationValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, axis, op, quantization_info); + ReductionOperationValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, axis, op, quantization_info, keep_dims); } }; @@ -138,9 +138,9 @@ class ReductionOperationFixture : public ReductionOperationValidationFixture<Ten { public: template <typename...> - void setup(TensorShape shape, DataType data_type, unsigned int axis, ReductionOperation op) + void setup(TensorShape shape, DataType data_type, unsigned int axis, ReductionOperation op, bool keep_dims = false) { - ReductionOperationValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, axis, op, QuantizationInfo()); + ReductionOperationValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, axis, op, QuantizationInfo(), keep_dims); } }; } // namespace validation |