aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads
diff options
context:
space:
mode:
authorTeresa Charlin <teresa.charlinreyes@arm.com>2021-08-05 12:34:37 +0100
committerTeresaARM <teresa.charlinreyes@arm.com>2021-09-03 08:41:21 +0000
commit4e3e831da1d6d85dffffacf57e9de8fc891b7e58 (patch)
tree9a3653729feba788dcfbbdc5255ad379cbbf597d /src/backends/reference/workloads
parent14bef9f83f7cd58e5038ae7432d75da2d50e7b68 (diff)
downloadarmnn-4e3e831da1d6d85dffffacf57e9de8fc891b7e58.tar.gz
IVGCVSW-6262 Add support for Reduce Prod
* Tflite parser * Tflite delegate * Serializer * Deserializer * Ref, CpuAcc and GpuAcc workloads Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com> Change-Id: I601a9ee1680b372c7955d9a628857d08c3cfd377
Diffstat (limited to 'src/backends/reference/workloads')
-rw-r--r--src/backends/reference/workloads/Reduce.cpp43
1 files changed, 26 insertions, 17 deletions
diff --git a/src/backends/reference/workloads/Reduce.cpp b/src/backends/reference/workloads/Reduce.cpp
index 8bf422aea3..3f929c43bc 100644
--- a/src/backends/reference/workloads/Reduce.cpp
+++ b/src/backends/reference/workloads/Reduce.cpp
@@ -9,7 +9,6 @@
#include <backendsCommon/WorkloadData.hpp>
-#include <cmath>
#include <cstddef>
#include <functional>
#include <limits>
@@ -87,6 +86,9 @@ void Reduce(const TensorInfo& inputInfo,
case ReduceOperation::Sum:
std::fill(tempOut.begin(), tempOut.end(), 0.0f);
break;
+ case ReduceOperation::Prod:
+ std::fill(tempOut.begin(), tempOut.end(), 1.0f);
+ break;
case ReduceOperation::Max:
std::fill(tempOut.begin(), tempOut.end(), -1 * std::numeric_limits<float>::max());
break;
@@ -119,23 +121,30 @@ void Reduce(const TensorInfo& inputInfo,
numResolvedAxis, resolvedAxis);
input[inputOffset];
auto inputValue = input.Get();
- if (reduceOperation == ReduceOperation::Max)
- {
- if (inputValue > tempOut[outputOffset])
- {
- tempOut[outputOffset] = inputValue;
- }
- }
- else if (reduceOperation == ReduceOperation::Min)
- {
- if (inputValue < tempOut[outputOffset])
- {
- tempOut[outputOffset] = inputValue;
- }
- }
- else
+ switch(reduceOperation)
{
- tempOut[outputOffset] += inputValue;
+ case ReduceOperation::Mean:
+ case ReduceOperation::Sum:
+ tempOut[outputOffset] += inputValue;
+ break;
+ case ReduceOperation::Prod:
+ tempOut[outputOffset] *= inputValue;
+ break;
+ case ReduceOperation::Max:
+ if (inputValue > tempOut[outputOffset])
+ {
+ tempOut[outputOffset] = inputValue;
+ }
+ break;
+ case ReduceOperation::Min:
+ if (inputValue < tempOut[outputOffset])
+ {
+ tempOut[outputOffset] = inputValue;
+ }
+ break;
+ default:
+ throw armnn::InvalidArgumentException("Unknown reduce method: " +
+ std::to_string(static_cast<int>(reduceOperation)));
}
}