diff options
Diffstat (limited to 'tests/validation/reference')
-rw-r--r-- | tests/validation/reference/ReductionOperation.cpp | 30 |
1 files changed, 26 insertions, 4 deletions
diff --git a/tests/validation/reference/ReductionOperation.cpp b/tests/validation/reference/ReductionOperation.cpp index 8e79c3bfb0..fb7a6d6997 100644 --- a/tests/validation/reference/ReductionOperation.cpp +++ b/tests/validation/reference/ReductionOperation.cpp @@ -128,7 +128,7 @@ OT reduce_operation(const T *ptr, int reduce_elements, ReductionOperation op, in } // namespace template <typename T, typename OT> -SimpleTensor<OT> reduction_operation(const SimpleTensor<T> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op) +SimpleTensor<OT> compute_reduction_operation(const SimpleTensor<T> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op) { // Create reference const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX); @@ -213,12 +213,34 @@ SimpleTensor<OT> reduction_operation(const SimpleTensor<T> &src, const TensorSha return dst; } +template <typename T, typename OT> +SimpleTensor<OT> reduction_operation(const SimpleTensor<T> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op) +{ + return compute_reduction_operation<T, OT>(src, dst_shape, axis, op); +} + +template <> +SimpleTensor<uint8_t> reduction_operation(const SimpleTensor<uint8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op) +{ + if(src.data_type() == DataType::QASYMM8 && op != ReductionOperation::MEAN_SUM) + { + SimpleTensor<float> src_f = convert_from_asymmetric(src); + SimpleTensor<float> dst_f = reference::reduction_operation<float, float>(src_f, dst_shape, axis, op); + return convert_to_asymmetric(dst_f, src.quantization_info()); + } + else + { + return compute_reduction_operation<uint8_t, uint8_t>(src, dst_shape, axis, op); + } +} + +template SimpleTensor<float> reduction_operation(const SimpleTensor<float> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); +template SimpleTensor<half> reduction_operation(const SimpleTensor<half> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); + template SimpleTensor<uint32_t> reduction_operation(const SimpleTensor<float> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); template SimpleTensor<uint32_t> reduction_operation(const SimpleTensor<half> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); template SimpleTensor<uint32_t> reduction_operation(const SimpleTensor<uint8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); -template SimpleTensor<float> reduction_operation(const SimpleTensor<float> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); -template SimpleTensor<half> reduction_operation(const SimpleTensor<half> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); -template SimpleTensor<uint8_t> reduction_operation(const SimpleTensor<uint8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); + } // namespace reference } // namespace validation } // namespace test |