aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/reference/ReductionOperation.cpp
diff options
context:
space:
mode:
authorMichalis Spyrou <michalis.spyrou@arm.com>2018-11-22 17:36:28 +0000
committerMichalis Spyrou <michalis.spyrou@arm.com>2018-11-30 15:46:49 +0000
commit7930db48e12dd3a14c1971f41f5b83527efea281 (patch)
treed17899ba82203423320bfa8d2dea1e07b045c898 /tests/validation/reference/ReductionOperation.cpp
parent95abfddfa08ab85d4f88c6f4d2e077969178f2d5 (diff)
downloadComputeLibrary-7930db48e12dd3a14c1971f41f5b83527efea281.tar.gz
COMPMID-1728 CL: Implement ArgMax/ArgMin
Change-Id: I7eae2e55cc0b0b7bbebb7617299daaca6f75f40c Reviewed-on: https://review.mlplatform.org/292 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Diffstat (limited to 'tests/validation/reference/ReductionOperation.cpp')
-rw-r--r--tests/validation/reference/ReductionOperation.cpp103
1 files changed, 89 insertions, 14 deletions
diff --git a/tests/validation/reference/ReductionOperation.cpp b/tests/validation/reference/ReductionOperation.cpp
index 2f103a6f65..37a9be86c0 100644
--- a/tests/validation/reference/ReductionOperation.cpp
+++ b/tests/validation/reference/ReductionOperation.cpp
@@ -38,10 +38,10 @@ namespace reference
{
namespace
{
-template <typename T>
-T reduce_operation(T *ptr, int reduce_elements, ReductionOperation op, int stride)
+template <typename T, typename OT>
+OT reduce_operation(const T *ptr, int reduce_elements, ReductionOperation op, int stride)
{
- using type = typename std::remove_cv<T>::type;
+ using type = typename std::remove_cv<OT>::type;
auto res = type(0);
if(std::is_integral<type>::value)
@@ -50,7 +50,31 @@ T reduce_operation(T *ptr, int reduce_elements, ReductionOperation op, int strid
for(int i = 0; i < reduce_elements; ++i)
{
auto elem = static_cast<uint32_t>(*(ptr + stride * i));
- int_res += (op == ReductionOperation::SUM_SQUARE) ? elem * elem : elem;
+
+ switch(op)
+ {
+ case ReductionOperation::ARG_IDX_MIN:
+ if(static_cast<uint32_t>(*(ptr + stride * static_cast<uint32_t>(res))) > elem)
+ {
+ res = static_cast<uint32_t>(i);
+ }
+ break;
+ case ReductionOperation::ARG_IDX_MAX:
+ if(static_cast<uint32_t>(*(ptr + stride * static_cast<uint32_t>(res))) < elem)
+ {
+ res = static_cast<uint32_t>(i);
+ }
+ break;
+ case ReductionOperation::SUM_SQUARE:
+ int_res += elem * elem;
+ break;
+ case ReductionOperation::MEAN_SUM:
+ case ReductionOperation::SUM:
+ int_res += elem;
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Operation not supported");
+ }
}
if(op == ReductionOperation::MEAN_SUM && reduce_elements > 0)
{
@@ -63,7 +87,30 @@ T reduce_operation(T *ptr, int reduce_elements, ReductionOperation op, int strid
for(int i = 0; i < reduce_elements; ++i)
{
auto elem = *(ptr + stride * i);
- res += (op == ReductionOperation::SUM_SQUARE) ? elem * elem : elem;
+ switch(op)
+ {
+ case ReductionOperation::ARG_IDX_MIN:
+ if(*(ptr + stride * static_cast<uint32_t>(res)) > elem)
+ {
+ res = static_cast<uint32_t>(i);
+ }
+ break;
+ case ReductionOperation::ARG_IDX_MAX:
+ if(*(ptr + stride * static_cast<uint32_t>(res)) < elem)
+ {
+ res = static_cast<uint32_t>(i);
+ }
+ break;
+ case ReductionOperation::SUM_SQUARE:
+ res += elem * elem;
+ break;
+ case ReductionOperation::MEAN_SUM:
+ case ReductionOperation::SUM:
+ res += elem;
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Operation not supported");
+ }
}
if(op == ReductionOperation::MEAN_SUM && reduce_elements > 0)
{
@@ -79,7 +126,9 @@ template <typename T>
SimpleTensor<T> reduction_operation(const SimpleTensor<T> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op)
{
// Create reference
- SimpleTensor<T> dst{ dst_shape, src.data_type(), 1, src.quantization_info() };
+ const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX);
+ DataType output_data_type = is_arg_min_max ? DataType::U32 : src.data_type();
+ SimpleTensor<T> dst{ dst_shape, output_data_type, 1, src.quantization_info() };
const unsigned int src_width = src.shape().x();
const unsigned int src_height = src.shape().y();
const unsigned int src_depth = src.shape().z();
@@ -94,8 +143,14 @@ SimpleTensor<T> reduction_operation(const SimpleTensor<T> &src, const TensorShap
for(unsigned int du = 0; du < upper_dims; ++du)
{
const T *src_row_ptr = src.data() + du * reduce_elems;
- auto res = reduce_operation(src_row_ptr, reduce_elems, op, 1);
- dst[du] = res;
+ if(is_arg_min_max)
+ {
+ dst[du] = reduce_operation<T, uint32_t>(src_row_ptr, reduce_elems, op, 1);
+ }
+ else
+ {
+ dst[du] = reduce_operation<T, T>(src_row_ptr, reduce_elems, op, 1);
+ }
}
}
break;
@@ -109,8 +164,15 @@ SimpleTensor<T> reduction_operation(const SimpleTensor<T> &src, const TensorShap
const int in_offset = du * src_height * src_width + x;
const int out_offset = du * src_width + x;
const T *src_row_ptr = src.data() + in_offset;
- auto res = reduce_operation(src_row_ptr, reduce_elems, op, src_width);
- dst[out_offset] = res;
+
+ if(is_arg_min_max)
+ {
+ dst[out_offset] = reduce_operation<T, uint32_t>(src_row_ptr, reduce_elems, op, src_width);
+ }
+ else
+ {
+ dst[out_offset] = reduce_operation<T, T>(src_row_ptr, reduce_elems, op, src_width);
+ }
}
}
}
@@ -127,8 +189,15 @@ SimpleTensor<T> reduction_operation(const SimpleTensor<T> &src, const TensorShap
const int in_offset = du * src_depth * src_height * src_width + y * src_width + x;
const int out_offset = du * src_width * src_height + y * src_width + x;
const T *src_row_ptr = src.data() + in_offset;
- auto res = reduce_operation(src_row_ptr, reduce_elems, op, src_height * src_width);
- dst[out_offset] = res;
+
+ if(is_arg_min_max)
+ {
+ dst[out_offset] = reduce_operation<T, uint32_t>(src_row_ptr, reduce_elems, op, src_height * src_width);
+ }
+ else
+ {
+ dst[out_offset] = reduce_operation<T, T>(src_row_ptr, reduce_elems, op, src_height * src_width);
+ }
}
}
}
@@ -148,8 +217,14 @@ SimpleTensor<T> reduction_operation(const SimpleTensor<T> &src, const TensorShap
const int in_offset = du * src_batch * src_depth * src_height * src_width + z * src_width * src_height + y * src_width + x;
const int out_offset = du * src_depth * src_height * src_width + z * src_width * src_height + y * src_width + x;
const T *src_row_ptr = src.data() + in_offset;
- auto res = reduce_operation(src_row_ptr, reduce_elems, op, src_width * src_height * src_depth);
- dst[out_offset] = res;
+ if(is_arg_min_max)
+ {
+ dst[out_offset] = reduce_operation<T, uint32_t>(src_row_ptr, reduce_elems, op, src_width * src_height * src_depth);
+ }
+ else
+ {
+ dst[out_offset] = reduce_operation<T, T>(src_row_ptr, reduce_elems, op, src_width * src_height * src_depth);
+ }
}
}
}