diff options
Diffstat (limited to 'tests/validation/reference/ReductionOperation.cpp')
-rw-r--r-- | tests/validation/reference/ReductionOperation.cpp | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/tests/validation/reference/ReductionOperation.cpp b/tests/validation/reference/ReductionOperation.cpp index fc12e31d75..8e79c3bfb0 100644 --- a/tests/validation/reference/ReductionOperation.cpp +++ b/tests/validation/reference/ReductionOperation.cpp @@ -42,11 +42,11 @@ template <typename T, typename OT> OT reduce_operation(const T *ptr, int reduce_elements, ReductionOperation op, int stride) { using type = typename std::remove_cv<OT>::type; - auto res = type(0); + auto res = (op == ReductionOperation::PROD) ? type(1) : type(0); if(std::is_integral<type>::value) { - uint32_t int_res = 0; + auto int_res = static_cast<uint32_t>(res); for(int i = 0; i < reduce_elements; ++i) { auto elem = *(ptr + stride * i); @@ -72,6 +72,9 @@ OT reduce_operation(const T *ptr, int reduce_elements, ReductionOperation op, in case ReductionOperation::SUM: int_res += elem; break; + case ReductionOperation::PROD: + int_res *= elem; + break; default: ARM_COMPUTE_ERROR("Operation not supported"); } @@ -108,6 +111,9 @@ OT reduce_operation(const T *ptr, int reduce_elements, ReductionOperation op, in case ReductionOperation::SUM: res += elem; break; + case ReductionOperation::PROD: + res *= elem; + break; default: ARM_COMPUTE_ERROR("Operation not supported"); } @@ -117,7 +123,6 @@ OT reduce_operation(const T *ptr, int reduce_elements, ReductionOperation op, in res /= reduce_elements; } } - return res; } } // namespace |