diff options
Diffstat (limited to 'src/backends/reference/workloads/Reduce.cpp')
-rw-r--r-- | src/backends/reference/workloads/Reduce.cpp | 43 |
1 files changed, 26 insertions, 17 deletions
diff --git a/src/backends/reference/workloads/Reduce.cpp b/src/backends/reference/workloads/Reduce.cpp index 8bf422aea3..3f929c43bc 100644 --- a/src/backends/reference/workloads/Reduce.cpp +++ b/src/backends/reference/workloads/Reduce.cpp @@ -9,7 +9,6 @@ #include <backendsCommon/WorkloadData.hpp> -#include <cmath> #include <cstddef> #include <functional> #include <limits> @@ -87,6 +86,9 @@ void Reduce(const TensorInfo& inputInfo, case ReduceOperation::Sum: std::fill(tempOut.begin(), tempOut.end(), 0.0f); break; + case ReduceOperation::Prod: + std::fill(tempOut.begin(), tempOut.end(), 1.0f); + break; case ReduceOperation::Max: std::fill(tempOut.begin(), tempOut.end(), -1 * std::numeric_limits<float>::max()); break; @@ -119,23 +121,30 @@ void Reduce(const TensorInfo& inputInfo, numResolvedAxis, resolvedAxis); input[inputOffset]; auto inputValue = input.Get(); - if (reduceOperation == ReduceOperation::Max) - { - if (inputValue > tempOut[outputOffset]) - { - tempOut[outputOffset] = inputValue; - } - } - else if (reduceOperation == ReduceOperation::Min) - { - if (inputValue < tempOut[outputOffset]) - { - tempOut[outputOffset] = inputValue; - } - } - else + switch(reduceOperation) { - tempOut[outputOffset] += inputValue; + case ReduceOperation::Mean: + case ReduceOperation::Sum: + tempOut[outputOffset] += inputValue; + break; + case ReduceOperation::Prod: + tempOut[outputOffset] *= inputValue; + break; + case ReduceOperation::Max: + if (inputValue > tempOut[outputOffset]) + { + tempOut[outputOffset] = inputValue; + } + break; + case ReduceOperation::Min: + if (inputValue < tempOut[outputOffset]) + { + tempOut[outputOffset] = inputValue; + } + break; + default: + throw armnn::InvalidArgumentException("Unknown reduce method: " + + std::to_string(static_cast<int>(reduceOperation))); } } |