From aea14c63e2efeda9d5f7492099389d439c65204f Mon Sep 17 00:00:00 2001 From: Michalis Spyrou Date: Thu, 3 Jan 2019 11:10:25 +0000 Subject: COMPMID-1764 NEON: Implement ArgMax/ArgMin Change-Id: Ibe23aa90b36ffd8553d1d1c35fada5d300fab829 Reviewed-on: https://review.mlplatform.org/475 Reviewed-by: Isabella Gottardi Tested-by: Arm Jenkins Reviewed-by: Giuseppe Rossini --- tests/validation/reference/ReductionOperation.cpp | 59 ++++++----------------- 1 file changed, 16 insertions(+), 43 deletions(-) (limited to 'tests/validation/reference/ReductionOperation.cpp') diff --git a/tests/validation/reference/ReductionOperation.cpp b/tests/validation/reference/ReductionOperation.cpp index 37a9be86c0..fc12e31d75 100644 --- a/tests/validation/reference/ReductionOperation.cpp +++ b/tests/validation/reference/ReductionOperation.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -49,20 +49,20 @@ OT reduce_operation(const T *ptr, int reduce_elements, ReductionOperation op, in uint32_t int_res = 0; for(int i = 0; i < reduce_elements; ++i) { - auto elem = static_cast(*(ptr + stride * i)); + auto elem = *(ptr + stride * i); switch(op) { case ReductionOperation::ARG_IDX_MIN: - if(static_cast(*(ptr + stride * static_cast(res))) > elem) + if(*(ptr + stride * static_cast(int_res)) > elem) { - res = static_cast(i); + int_res = static_cast(i); } break; case ReductionOperation::ARG_IDX_MAX: - if(static_cast(*(ptr + stride * static_cast(res))) < elem) + if(*(ptr + stride * static_cast(int_res)) < elem) { - res = static_cast(i); + int_res = static_cast(i); } break; case ReductionOperation::SUM_SQUARE: @@ -122,13 +122,13 @@ OT reduce_operation(const T *ptr, int reduce_elements, ReductionOperation op, in } } // namespace -template -SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op) +template +SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op) { // Create reference const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX); DataType output_data_type = is_arg_min_max ? DataType::U32 : src.data_type(); - SimpleTensor dst{ dst_shape, output_data_type, 1, src.quantization_info() }; + SimpleTensor dst{ dst_shape, output_data_type, 1, src.quantization_info() }; const unsigned int src_width = src.shape().x(); const unsigned int src_height = src.shape().y(); const unsigned int src_depth = src.shape().z(); @@ -143,14 +143,7 @@ SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShap for(unsigned int du = 0; du < upper_dims; ++du) { const T *src_row_ptr = src.data() + du * reduce_elems; - if(is_arg_min_max) - { - dst[du] = reduce_operation(src_row_ptr, reduce_elems, op, 1); - } - else - { - dst[du] = reduce_operation(src_row_ptr, reduce_elems, op, 1); - } + dst[du] = reduce_operation(src_row_ptr, reduce_elems, op, 1); } } break; @@ -164,15 +157,7 @@ SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShap const int in_offset = du * src_height * src_width + x; const int out_offset = du * src_width + x; const T *src_row_ptr = src.data() + in_offset; - - if(is_arg_min_max) - { - dst[out_offset] = reduce_operation(src_row_ptr, reduce_elems, op, src_width); - } - else - { - dst[out_offset] = reduce_operation(src_row_ptr, reduce_elems, op, src_width); - } + dst[out_offset] = reduce_operation(src_row_ptr, reduce_elems, op, src_width); } } } @@ -189,15 +174,7 @@ SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShap const int in_offset = du * src_depth * src_height * src_width + y * src_width + x; const int out_offset = du * src_width * src_height + y * src_width + x; const T *src_row_ptr = src.data() + in_offset; - - if(is_arg_min_max) - { - dst[out_offset] = reduce_operation(src_row_ptr, reduce_elems, op, src_height * src_width); - } - else - { - dst[out_offset] = reduce_operation(src_row_ptr, reduce_elems, op, src_height * src_width); - } + dst[out_offset] = reduce_operation(src_row_ptr, reduce_elems, op, src_height * src_width); } } } @@ -217,14 +194,7 @@ SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShap const int in_offset = du * src_batch * src_depth * src_height * src_width + z * src_width * src_height + y * src_width + x; const int out_offset = du * src_depth * src_height * src_width + z * src_width * src_height + y * src_width + x; const T *src_row_ptr = src.data() + in_offset; - if(is_arg_min_max) - { - dst[out_offset] = reduce_operation(src_row_ptr, reduce_elems, op, src_width * src_height * src_depth); - } - else - { - dst[out_offset] = reduce_operation(src_row_ptr, reduce_elems, op, src_width * src_height * src_depth); - } + dst[out_offset] = reduce_operation(src_row_ptr, reduce_elems, op, src_width * src_height * src_depth); } } } @@ -238,6 +208,9 @@ SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShap return dst; } +template SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); +template SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); +template SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); template SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); template SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); template SimpleTensor reduction_operation(const SimpleTensor &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op); -- cgit v1.2.1