diff options
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 2 | ||||
-rw-r--r-- | src/backends/reference/workloads/Reduce.cpp | 24 |
2 files changed, 14 insertions, 12 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index b51099ff79..90db57f953 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -3643,8 +3643,6 @@ void ReduceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0]; const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0]; - ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input"); - std::vector<DataType> supportedTypes = { DataType::BFloat16, diff --git a/src/backends/reference/workloads/Reduce.cpp b/src/backends/reference/workloads/Reduce.cpp index 31c6262c9a..392ef8e5ba 100644 --- a/src/backends/reference/workloads/Reduce.cpp +++ b/src/backends/reference/workloads/Reduce.cpp @@ -81,17 +81,21 @@ void Reduce(const TensorInfo& inputInfo, // Initialise temp output std::vector<float> tempOut(numOutputs); - if (reduceOperation == ReduceOperation::Max || reduceOperation == ReduceOperation::Min) + switch(reduceOperation) { - for (unsigned int idx = 0; idx < numOutputs; ++idx) - { - input[idx]; - tempOut[idx] = input.Get(); - } - } - else - { - std::fill(tempOut.begin(), tempOut.end(), 0.0); + case ReduceOperation::Mean: + case ReduceOperation::Sum: + std::fill(tempOut.begin(), tempOut.end(), 0.0); + break; + case ReduceOperation::Max: + std::fill(tempOut.begin(), tempOut.end(), -1 * std::numeric_limits<float>::max()); + break; + case ReduceOperation::Min: + std::fill(tempOut.begin(), tempOut.end(), std::numeric_limits<float>::max()); + break; + default: + throw armnn::InvalidArgumentException("Unknown reduce method: " + + std::to_string(static_cast<int>(reduceOperation))); } // Initialise temp index |