diff options
-rw-r--r-- | tests/validation/reference/ReductionOperation.cpp | 69 |
1 files changed, 41 insertions, 28 deletions
diff --git a/tests/validation/reference/ReductionOperation.cpp b/tests/validation/reference/ReductionOperation.cpp index 571b991b92..fe128cc6ac 100644 --- a/tests/validation/reference/ReductionOperation.cpp +++ b/tests/validation/reference/ReductionOperation.cpp @@ -71,18 +71,6 @@ OT reduce_operation(const T *ptr, int reduce_elements, ReductionOperation op, in switch(op) { - case ReductionOperation::ARG_IDX_MIN: - if(*(ptr + stride * static_cast<uint32_t>(int_res)) > elem) - { - int_res = static_cast<uint32_t>(i); - } - break; - case ReductionOperation::ARG_IDX_MAX: - if(*(ptr + stride * static_cast<uint32_t>(int_res)) < elem) - { - int_res = static_cast<uint32_t>(i); - } - break; case ReductionOperation::MIN: if(static_cast<T>(int_res) > elem) { @@ -122,18 +110,6 @@ OT reduce_operation(const T *ptr, int reduce_elements, ReductionOperation op, in auto elem = *(ptr + stride * i); switch(op) { - case ReductionOperation::ARG_IDX_MIN: - if(*(ptr + stride * static_cast<uint32_t>(res)) > elem) - { - res = static_cast<uint32_t>(i); - } - break; - case ReductionOperation::ARG_IDX_MAX: - if(*(ptr + stride * static_cast<uint32_t>(res)) < elem) - { - res = static_cast<uint32_t>(i); - } - break; case ReductionOperation::MIN: if(res > elem) { @@ -167,6 +143,35 @@ OT reduce_operation(const T *ptr, int reduce_elements, ReductionOperation op, in } return res; } + +template <typename T, typename OT> +OT reduce_operation_arg_min_max(const T *ptr, int reduce_elements, ReductionOperation op, int stride) +{ + uint32_t res = 0; + for(int i = 0; i < reduce_elements; ++i) + { + auto elem = *(ptr + stride * i); + switch(op) + { + case ReductionOperation::ARG_IDX_MIN: + if(*(ptr + stride * res) > elem) + { + res = static_cast<uint32_t>(i); + } + break; + case ReductionOperation::ARG_IDX_MAX: + if(*(ptr + stride * res) < elem) + { + res = static_cast<uint32_t>(i); + } + break; + default: + ARM_COMPUTE_ERROR("Operation not supported"); + } + } + return static_cast<OT>(res); +} + } // namespace template <typename T, typename OT> @@ -190,7 +195,9 @@ SimpleTensor<OT> compute_reduction_operation(const SimpleTensor<T> &src, const T for(unsigned int du = 0; du < upper_dims; ++du) { const T *src_row_ptr = src.data() + du * reduce_elems; - dst[du] = reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, 1); + dst[du] = is_arg_min_max ? + reduce_operation_arg_min_max<T, OT>(src_row_ptr, reduce_elems, op, 1) : + reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, 1); } } break; @@ -204,7 +211,9 @@ SimpleTensor<OT> compute_reduction_operation(const SimpleTensor<T> &src, const T const int in_offset = du * src_height * src_width + x; const int out_offset = du * src_width + x; const T *src_row_ptr = src.data() + in_offset; - dst[out_offset] = reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_width); + dst[out_offset] = is_arg_min_max ? + reduce_operation_arg_min_max<T, OT>(src_row_ptr, reduce_elems, op, src_width) : + reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_width); } } } @@ -221,7 +230,9 @@ SimpleTensor<OT> compute_reduction_operation(const SimpleTensor<T> &src, const T const int in_offset = du * src_depth * src_height * src_width + y * src_width + x; const int out_offset = du * src_width * src_height + y * src_width + x; const T *src_row_ptr = src.data() + in_offset; - dst[out_offset] = reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_height * src_width); + dst[out_offset] = is_arg_min_max ? + reduce_operation_arg_min_max<T, OT>(src_row_ptr, reduce_elems, op, src_width * src_height) : + reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_width * src_height); } } } @@ -241,7 +252,9 @@ SimpleTensor<OT> compute_reduction_operation(const SimpleTensor<T> &src, const T const int in_offset = du * src_batch * src_depth * src_height * src_width + z * src_width * src_height + y * src_width + x; const int out_offset = du * src_depth * src_height * src_width + z * src_width * src_height + y * src_width + x; const T *src_row_ptr = src.data() + in_offset; - dst[out_offset] = reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_width * src_height * src_depth); + dst[out_offset] = is_arg_min_max ? + reduce_operation_arg_min_max<T, OT>(src_row_ptr, reduce_elems, op, src_width * src_height * src_depth) : + reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_width * src_height * src_depth); } } } |