aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--tests/validation/reference/ReductionOperation.cpp69
1 files changed, 41 insertions, 28 deletions
diff --git a/tests/validation/reference/ReductionOperation.cpp b/tests/validation/reference/ReductionOperation.cpp
index 571b991b92..fe128cc6ac 100644
--- a/tests/validation/reference/ReductionOperation.cpp
+++ b/tests/validation/reference/ReductionOperation.cpp
@@ -71,18 +71,6 @@ OT reduce_operation(const T *ptr, int reduce_elements, ReductionOperation op, in
switch(op)
{
- case ReductionOperation::ARG_IDX_MIN:
- if(*(ptr + stride * static_cast<uint32_t>(int_res)) > elem)
- {
- int_res = static_cast<uint32_t>(i);
- }
- break;
- case ReductionOperation::ARG_IDX_MAX:
- if(*(ptr + stride * static_cast<uint32_t>(int_res)) < elem)
- {
- int_res = static_cast<uint32_t>(i);
- }
- break;
case ReductionOperation::MIN:
if(static_cast<T>(int_res) > elem)
{
@@ -122,18 +110,6 @@ OT reduce_operation(const T *ptr, int reduce_elements, ReductionOperation op, in
auto elem = *(ptr + stride * i);
switch(op)
{
- case ReductionOperation::ARG_IDX_MIN:
- if(*(ptr + stride * static_cast<uint32_t>(res)) > elem)
- {
- res = static_cast<uint32_t>(i);
- }
- break;
- case ReductionOperation::ARG_IDX_MAX:
- if(*(ptr + stride * static_cast<uint32_t>(res)) < elem)
- {
- res = static_cast<uint32_t>(i);
- }
- break;
case ReductionOperation::MIN:
if(res > elem)
{
@@ -167,6 +143,35 @@ OT reduce_operation(const T *ptr, int reduce_elements, ReductionOperation op, in
}
return res;
}
+
+template <typename T, typename OT>
+OT reduce_operation_arg_min_max(const T *ptr, int reduce_elements, ReductionOperation op, int stride)
+{
+ uint32_t res = 0;
+ for(int i = 0; i < reduce_elements; ++i)
+ {
+ auto elem = *(ptr + stride * i);
+ switch(op)
+ {
+ case ReductionOperation::ARG_IDX_MIN:
+ if(*(ptr + stride * res) > elem)
+ {
+ res = static_cast<uint32_t>(i);
+ }
+ break;
+ case ReductionOperation::ARG_IDX_MAX:
+ if(*(ptr + stride * res) < elem)
+ {
+ res = static_cast<uint32_t>(i);
+ }
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Operation not supported");
+ }
+ }
+ return static_cast<OT>(res);
+}
+
} // namespace
template <typename T, typename OT>
@@ -190,7 +195,9 @@ SimpleTensor<OT> compute_reduction_operation(const SimpleTensor<T> &src, const T
for(unsigned int du = 0; du < upper_dims; ++du)
{
const T *src_row_ptr = src.data() + du * reduce_elems;
- dst[du] = reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, 1);
+ dst[du] = is_arg_min_max ?
+ reduce_operation_arg_min_max<T, OT>(src_row_ptr, reduce_elems, op, 1) :
+ reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, 1);
}
}
break;
@@ -204,7 +211,9 @@ SimpleTensor<OT> compute_reduction_operation(const SimpleTensor<T> &src, const T
const int in_offset = du * src_height * src_width + x;
const int out_offset = du * src_width + x;
const T *src_row_ptr = src.data() + in_offset;
- dst[out_offset] = reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_width);
+ dst[out_offset] = is_arg_min_max ?
+ reduce_operation_arg_min_max<T, OT>(src_row_ptr, reduce_elems, op, src_width) :
+ reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_width);
}
}
}
@@ -221,7 +230,9 @@ SimpleTensor<OT> compute_reduction_operation(const SimpleTensor<T> &src, const T
const int in_offset = du * src_depth * src_height * src_width + y * src_width + x;
const int out_offset = du * src_width * src_height + y * src_width + x;
const T *src_row_ptr = src.data() + in_offset;
- dst[out_offset] = reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_height * src_width);
+ dst[out_offset] = is_arg_min_max ?
+ reduce_operation_arg_min_max<T, OT>(src_row_ptr, reduce_elems, op, src_width * src_height) :
+ reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_width * src_height);
}
}
}
@@ -241,7 +252,9 @@ SimpleTensor<OT> compute_reduction_operation(const SimpleTensor<T> &src, const T
const int in_offset = du * src_batch * src_depth * src_height * src_width + z * src_width * src_height + y * src_width + x;
const int out_offset = du * src_depth * src_height * src_width + z * src_width * src_height + y * src_width + x;
const T *src_row_ptr = src.data() + in_offset;
- dst[out_offset] = reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_width * src_height * src_depth);
+ dst[out_offset] = is_arg_min_max ?
+ reduce_operation_arg_min_max<T, OT>(src_row_ptr, reduce_elems, op, src_width * src_height * src_depth) :
+ reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_width * src_height * src_depth);
}
}
}