aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp2
-rw-r--r--src/backends/reference/workloads/Reduce.cpp24
2 files changed, 14 insertions, 12 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index b51099ff79..90db57f953 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -3643,8 +3643,6 @@ void ReduceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
- ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
-
std::vector<DataType> supportedTypes =
{
DataType::BFloat16,
diff --git a/src/backends/reference/workloads/Reduce.cpp b/src/backends/reference/workloads/Reduce.cpp
index 31c6262c9a..392ef8e5ba 100644
--- a/src/backends/reference/workloads/Reduce.cpp
+++ b/src/backends/reference/workloads/Reduce.cpp
@@ -81,17 +81,21 @@ void Reduce(const TensorInfo& inputInfo,
// Initialise temp output
std::vector<float> tempOut(numOutputs);
- if (reduceOperation == ReduceOperation::Max || reduceOperation == ReduceOperation::Min)
+ switch(reduceOperation)
{
- for (unsigned int idx = 0; idx < numOutputs; ++idx)
- {
- input[idx];
- tempOut[idx] = input.Get();
- }
- }
- else
- {
- std::fill(tempOut.begin(), tempOut.end(), 0.0);
+ case ReduceOperation::Mean:
+ case ReduceOperation::Sum:
+ std::fill(tempOut.begin(), tempOut.end(), 0.0);
+ break;
+ case ReduceOperation::Max:
+ std::fill(tempOut.begin(), tempOut.end(), -1 * std::numeric_limits<float>::max());
+ break;
+ case ReduceOperation::Min:
+ std::fill(tempOut.begin(), tempOut.end(), std::numeric_limits<float>::max());
+ break;
+ default:
+ throw armnn::InvalidArgumentException("Unknown reduce method: " +
+ std::to_string(static_cast<int>(reduceOperation)));
}
// Initialise temp index