diff options
Diffstat (limited to 'src/backends/reference')
-rw-r--r-- | src/backends/reference/workloads/Reduce.cpp | 24 |
1 files changed, 14 insertions, 10 deletions
diff --git a/src/backends/reference/workloads/Reduce.cpp b/src/backends/reference/workloads/Reduce.cpp index 31c6262c9a..392ef8e5ba 100644 --- a/src/backends/reference/workloads/Reduce.cpp +++ b/src/backends/reference/workloads/Reduce.cpp @@ -81,17 +81,21 @@ void Reduce(const TensorInfo& inputInfo, // Initialise temp output std::vector<float> tempOut(numOutputs); - if (reduceOperation == ReduceOperation::Max || reduceOperation == ReduceOperation::Min) + switch(reduceOperation) { - for (unsigned int idx = 0; idx < numOutputs; ++idx) - { - input[idx]; - tempOut[idx] = input.Get(); - } - } - else - { - std::fill(tempOut.begin(), tempOut.end(), 0.0); + case ReduceOperation::Mean: + case ReduceOperation::Sum: + std::fill(tempOut.begin(), tempOut.end(), 0.0); + break; + case ReduceOperation::Max: + std::fill(tempOut.begin(), tempOut.end(), -1 * std::numeric_limits<float>::max()); + break; + case ReduceOperation::Min: + std::fill(tempOut.begin(), tempOut.end(), std::numeric_limits<float>::max()); + break; + default: + throw armnn::InvalidArgumentException("Unknown reduce method: " + + std::to_string(static_cast<int>(reduceOperation))); } // Initialise temp index |