diff options
Diffstat (limited to 'tests/validation/reference')
-rw-r--r-- | tests/validation/reference/ReductionOperation.cpp | 36 | ||||
-rw-r--r-- | tests/validation/reference/ReductionOperation.h | 5 |
2 files changed, 25 insertions, 16 deletions
diff --git a/tests/validation/reference/ReductionOperation.cpp b/tests/validation/reference/ReductionOperation.cpp index 5bdd4f7e95..ffb79f86c5 100644 --- a/tests/validation/reference/ReductionOperation.cpp +++ b/tests/validation/reference/ReductionOperation.cpp @@ -269,18 +269,19 @@ SimpleTensor<OT> compute_reduction_operation(const SimpleTensor<T> &src, const T } template <typename T, typename OT> -SimpleTensor<OT> reduction_operation(const SimpleTensor<T> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op) +SimpleTensor<OT> reduction_operation(const SimpleTensor<T> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info_output) { + ARM_COMPUTE_UNUSED(quantization_info_output); return compute_reduction_operation<T, OT>(src, dst_shape, axis, op); } template <> -SimpleTensor<uint8_t> reduction_operation(const SimpleTensor<uint8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op) +SimpleTensor<uint8_t> reduction_operation(const SimpleTensor<uint8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info_output) { if(src.data_type() == DataType::QASYMM8) { // If the operation is MEAN_SUM, we can directly use the uint8 implementation without taking into account scale and offset - if(op == ReductionOperation::MEAN_SUM) + if(op == ReductionOperation::MEAN_SUM && src.quantization_info() == quantization_info_output) { return compute_reduction_operation<uint8_t, uint8_t>(src, dst_shape, axis, op); } @@ -288,7 +289,7 @@ SimpleTensor<uint8_t> reduction_operation(const SimpleTensor<uint8_t> &src, cons { SimpleTensor<float> src_f = convert_from_asymmetric(src); SimpleTensor<float> dst_f = reference::reduction_operation<float, float>(src_f, dst_shape, axis, op); - return convert_to_asymmetric<uint8_t>(dst_f, src.quantization_info()); + return convert_to_asymmetric<uint8_t>(dst_f, quantization_info_output); } } else @@ -298,12 +299,12 @@ SimpleTensor<uint8_t> reduction_operation(const SimpleTensor<uint8_t> &src, cons } template <> -SimpleTensor<int8_t> reduction_operation(const SimpleTensor<int8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op) +SimpleTensor<int8_t> reduction_operation(const SimpleTensor<int8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info_output) { if(src.data_type() == DataType::QASYMM8_SIGNED) { // If the operation is MEAN_SUM, we can directly use the int8 implementation without taking into account scale and offset - if(op == ReductionOperation::MEAN_SUM) + if(op == ReductionOperation::MEAN_SUM && src.quantization_info() == quantization_info_output) { return compute_reduction_operation<int8_t, int8_t>(src, dst_shape, axis, op); } @@ -311,7 +312,7 @@ SimpleTensor<int8_t> reduction_operation(const SimpleTensor<int8_t> &src, const { SimpleTensor<float> src_f = convert_from_asymmetric(src); SimpleTensor<float> dst_f = reference::reduction_operation<float, float>(src_f, dst_shape, axis, op); - return convert_to_asymmetric<int8_t>(dst_f, src.quantization_info()); + return convert_to_asymmetric<int8_t>(dst_f, quantization_info_output); } } else @@ -320,14 +321,21 @@ SimpleTensor<int8_t> reduction_operation(const SimpleTensor<int8_t> &src, const } } -template SimpleTensor<float> reduction_operation(const SimpleTensor<float> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); -template SimpleTensor<half> reduction_operation(const SimpleTensor<half> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); +template SimpleTensor<float> reduction_operation(const SimpleTensor<float> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, + QuantizationInfo quantization_info_output = QuantizationInfo()); +template SimpleTensor<half> reduction_operation(const SimpleTensor<half> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, + QuantizationInfo quantization_info_output = QuantizationInfo()); -template SimpleTensor<int32_t> reduction_operation(const SimpleTensor<float> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); -template SimpleTensor<int32_t> reduction_operation(const SimpleTensor<int32_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); -template SimpleTensor<int32_t> reduction_operation(const SimpleTensor<half> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); -template SimpleTensor<int32_t> reduction_operation(const SimpleTensor<uint8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); -template SimpleTensor<int32_t> reduction_operation(const SimpleTensor<int8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); +template SimpleTensor<int32_t> reduction_operation(const SimpleTensor<float> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, + QuantizationInfo quantization_info_output = QuantizationInfo()); +template SimpleTensor<int32_t> reduction_operation(const SimpleTensor<int32_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, + QuantizationInfo quantization_info_output = QuantizationInfo()); +template SimpleTensor<int32_t> reduction_operation(const SimpleTensor<half> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, + QuantizationInfo quantization_info_output = QuantizationInfo()); +template SimpleTensor<int32_t> reduction_operation(const SimpleTensor<uint8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, + QuantizationInfo quantization_info_output = QuantizationInfo()); +template SimpleTensor<int32_t> reduction_operation(const SimpleTensor<int8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, + QuantizationInfo quantization_info_output = QuantizationInfo()); } // namespace reference } // namespace validation diff --git a/tests/validation/reference/ReductionOperation.h b/tests/validation/reference/ReductionOperation.h index 56d37e4f4d..9c9e721b29 100644 --- a/tests/validation/reference/ReductionOperation.h +++ b/tests/validation/reference/ReductionOperation.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 Arm Limited. + * Copyright (c) 2017-2020 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -36,7 +36,8 @@ namespace validation namespace reference { template <typename T, typename OT> -SimpleTensor<OT> reduction_operation(const SimpleTensor<T> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); +SimpleTensor<OT> reduction_operation(const SimpleTensor<T> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, + QuantizationInfo quantization_info_output = QuantizationInfo()); } // namespace reference } // namespace validation } // namespace test |