diff options
author | Manuel Bottini <manuel.bottini@arm.com> | 2020-08-07 16:49:15 +0100 |
---|---|---|
committer | Manuel Bottini <manuel.bottini@arm.com> | 2020-08-19 08:53:41 +0000 |
commit | c58f0ad7ac6d91f2789a78049d3cec7355113f9a (patch) | |
tree | 09124c0b141892e35c9293c3ebde06f3766812dd /tests/validation/reference/ReductionOperation.cpp | |
parent | 97c1a6751c4f9bf52f0a4421b94da80a3028ca78 (diff) | |
download | ComputeLibrary-c58f0ad7ac6d91f2789a78049d3cec7355113f9a.tar.gz |
COMPMID-3502: Add support of different quantization input/output for ReduceMean
Change-Id: If9a5c6ee3902a7381f4117e473adbddf006f3347
Signed-off-by: Manuel Bottini <manuel.bottini@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3731
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Sang-Hoon Park <sang-hoon.park@arm.com>
Diffstat (limited to 'tests/validation/reference/ReductionOperation.cpp')
-rw-r--r-- | tests/validation/reference/ReductionOperation.cpp | 36 |
1 files changed, 22 insertions, 14 deletions
diff --git a/tests/validation/reference/ReductionOperation.cpp b/tests/validation/reference/ReductionOperation.cpp index 5bdd4f7e95..ffb79f86c5 100644 --- a/tests/validation/reference/ReductionOperation.cpp +++ b/tests/validation/reference/ReductionOperation.cpp @@ -269,18 +269,19 @@ SimpleTensor<OT> compute_reduction_operation(const SimpleTensor<T> &src, const T } template <typename T, typename OT> -SimpleTensor<OT> reduction_operation(const SimpleTensor<T> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op) +SimpleTensor<OT> reduction_operation(const SimpleTensor<T> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info_output) { + ARM_COMPUTE_UNUSED(quantization_info_output); return compute_reduction_operation<T, OT>(src, dst_shape, axis, op); } template <> -SimpleTensor<uint8_t> reduction_operation(const SimpleTensor<uint8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op) +SimpleTensor<uint8_t> reduction_operation(const SimpleTensor<uint8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info_output) { if(src.data_type() == DataType::QASYMM8) { // If the operation is MEAN_SUM, we can directly use the uint8 implementation without taking into account scale and offset - if(op == ReductionOperation::MEAN_SUM) + if(op == ReductionOperation::MEAN_SUM && src.quantization_info() == quantization_info_output) { return compute_reduction_operation<uint8_t, uint8_t>(src, dst_shape, axis, op); } @@ -288,7 +289,7 @@ SimpleTensor<uint8_t> reduction_operation(const SimpleTensor<uint8_t> &src, cons { SimpleTensor<float> src_f = convert_from_asymmetric(src); SimpleTensor<float> dst_f = reference::reduction_operation<float, float>(src_f, dst_shape, axis, op); - return convert_to_asymmetric<uint8_t>(dst_f, src.quantization_info()); + return convert_to_asymmetric<uint8_t>(dst_f, quantization_info_output); } } else @@ -298,12 +299,12 @@ SimpleTensor<uint8_t> reduction_operation(const SimpleTensor<uint8_t> &src, cons } template <> -SimpleTensor<int8_t> reduction_operation(const SimpleTensor<int8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op) +SimpleTensor<int8_t> reduction_operation(const SimpleTensor<int8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info_output) { if(src.data_type() == DataType::QASYMM8_SIGNED) { // If the operation is MEAN_SUM, we can directly use the int8 implementation without taking into account scale and offset - if(op == ReductionOperation::MEAN_SUM) + if(op == ReductionOperation::MEAN_SUM && src.quantization_info() == quantization_info_output) { return compute_reduction_operation<int8_t, int8_t>(src, dst_shape, axis, op); } @@ -311,7 +312,7 @@ SimpleTensor<int8_t> reduction_operation(const SimpleTensor<int8_t> &src, const { SimpleTensor<float> src_f = convert_from_asymmetric(src); SimpleTensor<float> dst_f = reference::reduction_operation<float, float>(src_f, dst_shape, axis, op); - return convert_to_asymmetric<int8_t>(dst_f, src.quantization_info()); + return convert_to_asymmetric<int8_t>(dst_f, quantization_info_output); } } else @@ -320,14 +321,21 @@ SimpleTensor<int8_t> reduction_operation(const SimpleTensor<int8_t> &src, const } } -template SimpleTensor<float> reduction_operation(const SimpleTensor<float> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); -template SimpleTensor<half> reduction_operation(const SimpleTensor<half> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); +template SimpleTensor<float> reduction_operation(const SimpleTensor<float> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, + QuantizationInfo quantization_info_output = QuantizationInfo()); +template SimpleTensor<half> reduction_operation(const SimpleTensor<half> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, + QuantizationInfo quantization_info_output = QuantizationInfo()); -template SimpleTensor<int32_t> reduction_operation(const SimpleTensor<float> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); -template SimpleTensor<int32_t> reduction_operation(const SimpleTensor<int32_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); -template SimpleTensor<int32_t> reduction_operation(const SimpleTensor<half> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); -template SimpleTensor<int32_t> reduction_operation(const SimpleTensor<uint8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); -template SimpleTensor<int32_t> reduction_operation(const SimpleTensor<int8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); +template SimpleTensor<int32_t> reduction_operation(const SimpleTensor<float> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, + QuantizationInfo quantization_info_output = QuantizationInfo()); +template SimpleTensor<int32_t> reduction_operation(const SimpleTensor<int32_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, + QuantizationInfo quantization_info_output = QuantizationInfo()); +template SimpleTensor<int32_t> reduction_operation(const SimpleTensor<half> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, + QuantizationInfo quantization_info_output = QuantizationInfo()); +template SimpleTensor<int32_t> reduction_operation(const SimpleTensor<uint8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, + QuantizationInfo quantization_info_output = QuantizationInfo()); +template SimpleTensor<int32_t> reduction_operation(const SimpleTensor<int8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, + QuantizationInfo quantization_info_output = QuantizationInfo()); } // namespace reference } // namespace validation |