diff options
Diffstat (limited to 'tests/validation/reference/ReductionOperation.cpp')
-rw-r--r-- | tests/validation/reference/ReductionOperation.cpp | 59 |
1 files changed, 16 insertions, 43 deletions
diff --git a/tests/validation/reference/ReductionOperation.cpp b/tests/validation/reference/ReductionOperation.cpp index 37a9be86c0..fc12e31d75 100644 --- a/tests/validation/reference/ReductionOperation.cpp +++ b/tests/validation/reference/ReductionOperation.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -49,20 +49,20 @@ OT reduce_operation(const T *ptr, int reduce_elements, ReductionOperation op, in uint32_t int_res = 0; for(int i = 0; i < reduce_elements; ++i) { - auto elem = static_cast<uint32_t>(*(ptr + stride * i)); + auto elem = *(ptr + stride * i); switch(op) { case ReductionOperation::ARG_IDX_MIN: - if(static_cast<uint32_t>(*(ptr + stride * static_cast<uint32_t>(res))) > elem) + if(*(ptr + stride * static_cast<uint32_t>(int_res)) > elem) { - res = static_cast<uint32_t>(i); + int_res = static_cast<uint32_t>(i); } break; case ReductionOperation::ARG_IDX_MAX: - if(static_cast<uint32_t>(*(ptr + stride * static_cast<uint32_t>(res))) < elem) + if(*(ptr + stride * static_cast<uint32_t>(int_res)) < elem) { - res = static_cast<uint32_t>(i); + int_res = static_cast<uint32_t>(i); } break; case ReductionOperation::SUM_SQUARE: @@ -122,13 +122,13 @@ OT reduce_operation(const T *ptr, int reduce_elements, ReductionOperation op, in } } // namespace -template <typename T> -SimpleTensor<T> reduction_operation(const SimpleTensor<T> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op) +template <typename T, typename OT> +SimpleTensor<OT> reduction_operation(const SimpleTensor<T> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op) { // Create reference const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX); DataType output_data_type = is_arg_min_max ? DataType::U32 : src.data_type(); - SimpleTensor<T> dst{ dst_shape, output_data_type, 1, src.quantization_info() }; + SimpleTensor<OT> dst{ dst_shape, output_data_type, 1, src.quantization_info() }; const unsigned int src_width = src.shape().x(); const unsigned int src_height = src.shape().y(); const unsigned int src_depth = src.shape().z(); @@ -143,14 +143,7 @@ SimpleTensor<T> reduction_operation(const SimpleTensor<T> &src, const TensorShap for(unsigned int du = 0; du < upper_dims; ++du) { const T *src_row_ptr = src.data() + du * reduce_elems; - if(is_arg_min_max) - { - dst[du] = reduce_operation<T, uint32_t>(src_row_ptr, reduce_elems, op, 1); - } - else - { - dst[du] = reduce_operation<T, T>(src_row_ptr, reduce_elems, op, 1); - } + dst[du] = reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, 1); } } break; @@ -164,15 +157,7 @@ SimpleTensor<T> reduction_operation(const SimpleTensor<T> &src, const TensorShap const int in_offset = du * src_height * src_width + x; const int out_offset = du * src_width + x; const T *src_row_ptr = src.data() + in_offset; - - if(is_arg_min_max) - { - dst[out_offset] = reduce_operation<T, uint32_t>(src_row_ptr, reduce_elems, op, src_width); - } - else - { - dst[out_offset] = reduce_operation<T, T>(src_row_ptr, reduce_elems, op, src_width); - } + dst[out_offset] = reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_width); } } } @@ -189,15 +174,7 @@ SimpleTensor<T> reduction_operation(const SimpleTensor<T> &src, const TensorShap const int in_offset = du * src_depth * src_height * src_width + y * src_width + x; const int out_offset = du * src_width * src_height + y * src_width + x; const T *src_row_ptr = src.data() + in_offset; - - if(is_arg_min_max) - { - dst[out_offset] = reduce_operation<T, uint32_t>(src_row_ptr, reduce_elems, op, src_height * src_width); - } - else - { - dst[out_offset] = reduce_operation<T, T>(src_row_ptr, reduce_elems, op, src_height * src_width); - } + dst[out_offset] = reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_height * src_width); } } } @@ -217,14 +194,7 @@ SimpleTensor<T> reduction_operation(const SimpleTensor<T> &src, const TensorShap const int in_offset = du * src_batch * src_depth * src_height * src_width + z * src_width * src_height + y * src_width + x; const int out_offset = du * src_depth * src_height * src_width + z * src_width * src_height + y * src_width + x; const T *src_row_ptr = src.data() + in_offset; - if(is_arg_min_max) - { - dst[out_offset] = reduce_operation<T, uint32_t>(src_row_ptr, reduce_elems, op, src_width * src_height * src_depth); - } - else - { - dst[out_offset] = reduce_operation<T, T>(src_row_ptr, reduce_elems, op, src_width * src_height * src_depth); - } + dst[out_offset] = reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_width * src_height * src_depth); } } } @@ -238,6 +208,9 @@ SimpleTensor<T> reduction_operation(const SimpleTensor<T> &src, const TensorShap return dst; } +template SimpleTensor<uint32_t> reduction_operation(const SimpleTensor<float> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); +template SimpleTensor<uint32_t> reduction_operation(const SimpleTensor<half> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); +template SimpleTensor<uint32_t> reduction_operation(const SimpleTensor<uint8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); template SimpleTensor<float> reduction_operation(const SimpleTensor<float> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); template SimpleTensor<half> reduction_operation(const SimpleTensor<half> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); template SimpleTensor<uint8_t> reduction_operation(const SimpleTensor<uint8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); |