aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference
diff options
context:
space:
mode:
authorTeresa Charlin <teresa.charlinreyes@arm.com>2021-02-11 23:05:40 +0000
committerTeresa Charlin <teresa.charlinreyes@arm.com>2021-02-11 23:07:30 +0000
commit2226ca98e01a8f7de37357ebdb2a0ed14fd3b0d2 (patch)
tree4dfe5e1768f023c71f23b9c689a5e37bb8f1ec8e /src/backends/reference
parent49bdb794170c3d25e3e51fc7b4c267c3d8dbcebf (diff)
downloadarmnn-2226ca98e01a8f7de37357ebdb2a0ed14fd3b0d2.tar.gz
MLCE-347 Bug fixes in Reduce: QueueDescriptor.validate and init REDUCE_MIN
* Allow input tensors of any rank in ReduceQueueDescriptor::validate * Fix VTS tests failing for REDUCE_MIN due to initialization Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com> Change-Id: Id8fba1662ade4e0a967093fe5a53b275847f2393
Diffstat (limited to 'src/backends/reference')
-rw-r--r--src/backends/reference/workloads/Reduce.cpp24
1 files changed, 14 insertions, 10 deletions
diff --git a/src/backends/reference/workloads/Reduce.cpp b/src/backends/reference/workloads/Reduce.cpp
index 31c6262c9a..392ef8e5ba 100644
--- a/src/backends/reference/workloads/Reduce.cpp
+++ b/src/backends/reference/workloads/Reduce.cpp
@@ -81,17 +81,21 @@ void Reduce(const TensorInfo& inputInfo,
// Initialise temp output
std::vector<float> tempOut(numOutputs);
- if (reduceOperation == ReduceOperation::Max || reduceOperation == ReduceOperation::Min)
+ switch(reduceOperation)
{
- for (unsigned int idx = 0; idx < numOutputs; ++idx)
- {
- input[idx];
- tempOut[idx] = input.Get();
- }
- }
- else
- {
- std::fill(tempOut.begin(), tempOut.end(), 0.0);
+ case ReduceOperation::Mean:
+ case ReduceOperation::Sum:
+ std::fill(tempOut.begin(), tempOut.end(), 0.0);
+ break;
+ case ReduceOperation::Max:
+ std::fill(tempOut.begin(), tempOut.end(), -1 * std::numeric_limits<float>::max());
+ break;
+ case ReduceOperation::Min:
+ std::fill(tempOut.begin(), tempOut.end(), std::numeric_limits<float>::max());
+ break;
+ default:
+ throw armnn::InvalidArgumentException("Unknown reduce method: " +
+ std::to_string(static_cast<int>(reduceOperation)));
}
// Initialise temp index