aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/Reduce.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/Reduce.cpp')
-rw-r--r--src/backends/reference/workloads/Reduce.cpp24
1 files changed, 14 insertions, 10 deletions
diff --git a/src/backends/reference/workloads/Reduce.cpp b/src/backends/reference/workloads/Reduce.cpp
index 31c6262c9a..392ef8e5ba 100644
--- a/src/backends/reference/workloads/Reduce.cpp
+++ b/src/backends/reference/workloads/Reduce.cpp
@@ -81,17 +81,21 @@ void Reduce(const TensorInfo& inputInfo,
// Initialise temp output
std::vector<float> tempOut(numOutputs);
- if (reduceOperation == ReduceOperation::Max || reduceOperation == ReduceOperation::Min)
+ switch(reduceOperation)
{
- for (unsigned int idx = 0; idx < numOutputs; ++idx)
- {
- input[idx];
- tempOut[idx] = input.Get();
- }
- }
- else
- {
- std::fill(tempOut.begin(), tempOut.end(), 0.0);
+ case ReduceOperation::Mean:
+ case ReduceOperation::Sum:
+ std::fill(tempOut.begin(), tempOut.end(), 0.0);
+ break;
+ case ReduceOperation::Max:
+ std::fill(tempOut.begin(), tempOut.end(), -1 * std::numeric_limits<float>::max());
+ break;
+ case ReduceOperation::Min:
+ std::fill(tempOut.begin(), tempOut.end(), std::numeric_limits<float>::max());
+ break;
+ default:
+ throw armnn::InvalidArgumentException("Unknown reduce method: " +
+ std::to_string(static_cast<int>(reduceOperation)));
}
// Initialise temp index