From c58f0ad7ac6d91f2789a78049d3cec7355113f9a Mon Sep 17 00:00:00 2001 From: Manuel Bottini Date: Fri, 7 Aug 2020 16:49:15 +0100 Subject: COMPMID-3502: Add support of different quantization input/output for ReduceMean Change-Id: If9a5c6ee3902a7381f4117e473adbddf006f3347 Signed-off-by: Manuel Bottini Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3731 Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Sang-Hoon Park --- tests/validation/reference/ReductionOperation.cpp | 36 ++++++++++++++--------- 1 file changed, 22 insertions(+), 14 deletions(-) (limited to 'tests/validation/reference/ReductionOperation.cpp') diff --git a/tests/validation/reference/ReductionOperation.cpp b/tests/validation/reference/ReductionOperation.cpp index 5bdd4f7e95..ffb79f86c5 100644 --- a/tests/validation/reference/ReductionOperation.cpp +++ b/tests/validation/reference/ReductionOperation.cpp @@ -269,18 +269,19 @@ SimpleTensor compute_reduction_operation(const SimpleTensor &src, const T } template -SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op) +SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info_output) { + ARM_COMPUTE_UNUSED(quantization_info_output); return compute_reduction_operation(src, dst_shape, axis, op); } template <> -SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op) +SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info_output) { if(src.data_type() == DataType::QASYMM8) { // If the operation is MEAN_SUM, we can directly use the uint8 implementation without taking into account scale and offset - if(op == ReductionOperation::MEAN_SUM) + if(op == ReductionOperation::MEAN_SUM && src.quantization_info() == quantization_info_output) { return compute_reduction_operation(src, dst_shape, axis, op); } @@ -288,7 +289,7 @@ SimpleTensor reduction_operation(const SimpleTensor &src, cons { SimpleTensor src_f = convert_from_asymmetric(src); SimpleTensor dst_f = reference::reduction_operation(src_f, dst_shape, axis, op); - return convert_to_asymmetric(dst_f, src.quantization_info()); + return convert_to_asymmetric(dst_f, quantization_info_output); } } else @@ -298,12 +299,12 @@ SimpleTensor reduction_operation(const SimpleTensor &src, cons } template <> -SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op) +SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info_output) { if(src.data_type() == DataType::QASYMM8_SIGNED) { // If the operation is MEAN_SUM, we can directly use the int8 implementation without taking into account scale and offset - if(op == ReductionOperation::MEAN_SUM) + if(op == ReductionOperation::MEAN_SUM && src.quantization_info() == quantization_info_output) { return compute_reduction_operation(src, dst_shape, axis, op); } @@ -311,7 +312,7 @@ SimpleTensor reduction_operation(const SimpleTensor &src, const { SimpleTensor src_f = convert_from_asymmetric(src); SimpleTensor dst_f = reference::reduction_operation(src_f, dst_shape, axis, op); - return convert_to_asymmetric(dst_f, src.quantization_info()); + return convert_to_asymmetric(dst_f, quantization_info_output); } } else @@ -320,14 +321,21 @@ SimpleTensor reduction_operation(const SimpleTensor &src, const } } -template SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); -template SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); +template SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, + QuantizationInfo quantization_info_output = QuantizationInfo()); +template SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, + QuantizationInfo quantization_info_output = QuantizationInfo()); -template SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); -template SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); -template SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); -template SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); -template SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); +template SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, + QuantizationInfo quantization_info_output = QuantizationInfo()); +template SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, + QuantizationInfo quantization_info_output = QuantizationInfo()); +template SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, + QuantizationInfo quantization_info_output = QuantizationInfo()); +template SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, + QuantizationInfo quantization_info_output = QuantizationInfo()); +template SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, + QuantizationInfo quantization_info_output = QuantizationInfo()); } // namespace reference } // namespace validation -- cgit v1.2.1