From 2226ca98e01a8f7de37357ebdb2a0ed14fd3b0d2 Mon Sep 17 00:00:00 2001 From: Teresa Charlin Date: Thu, 11 Feb 2021 23:05:40 +0000 Subject: MLCE-347 Bug fixes in Reduce: QueueDescriptor.validate and init REDUCE_MIN * Allow input tensors of any rank in ReduceQueueDescriptor::validate * Fix VTS tests failing for REDUCE_MIN due to initialization Signed-off-by: Teresa Charlin Change-Id: Id8fba1662ade4e0a967093fe5a53b275847f2393 --- src/backends/backendsCommon/WorkloadData.cpp | 2 -- src/backends/reference/workloads/Reduce.cpp | 24 ++++++++++++++---------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index b51099ff79..90db57f953 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -3643,8 +3643,6 @@ void ReduceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0]; const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0]; - ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input"); - std::vector supportedTypes = { DataType::BFloat16, diff --git a/src/backends/reference/workloads/Reduce.cpp b/src/backends/reference/workloads/Reduce.cpp index 31c6262c9a..392ef8e5ba 100644 --- a/src/backends/reference/workloads/Reduce.cpp +++ b/src/backends/reference/workloads/Reduce.cpp @@ -81,17 +81,21 @@ void Reduce(const TensorInfo& inputInfo, // Initialise temp output std::vector tempOut(numOutputs); - if (reduceOperation == ReduceOperation::Max || reduceOperation == ReduceOperation::Min) + switch(reduceOperation) { - for (unsigned int idx = 0; idx < numOutputs; ++idx) - { - input[idx]; - tempOut[idx] = input.Get(); - } - } - else - { - std::fill(tempOut.begin(), tempOut.end(), 0.0); + case ReduceOperation::Mean: + case ReduceOperation::Sum: + std::fill(tempOut.begin(), tempOut.end(), 0.0); + break; + case ReduceOperation::Max: + std::fill(tempOut.begin(), tempOut.end(), -1 * std::numeric_limits::max()); + break; + case ReduceOperation::Min: + std::fill(tempOut.begin(), tempOut.end(), std::numeric_limits::max()); + break; + default: + throw armnn::InvalidArgumentException("Unknown reduce method: " + + std::to_string(static_cast(reduceOperation))); } // Initialise temp index -- cgit v1.2.1