diff options
author | Teresa Charlin <teresa.charlinreyes@arm.com> | 2021-02-11 23:05:40 +0000 |
---|---|---|
committer | Teresa Charlin <teresa.charlinreyes@arm.com> | 2021-02-11 23:07:30 +0000 |
commit | 2226ca98e01a8f7de37357ebdb2a0ed14fd3b0d2 (patch) | |
tree | 4dfe5e1768f023c71f23b9c689a5e37bb8f1ec8e | |
parent | 49bdb794170c3d25e3e51fc7b4c267c3d8dbcebf (diff) | |
download | armnn-2226ca98e01a8f7de37357ebdb2a0ed14fd3b0d2.tar.gz |
MLCE-347 Bug fixes in Reduce: QueueDescriptor.validate and init REDUCE_MIN
* Allow input tensors of any rank in ReduceQueueDescriptor::validate
* Fix VTS tests failing for REDUCE_MIN due to initialization
Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com>
Change-Id: Id8fba1662ade4e0a967093fe5a53b275847f2393
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 2 | ||||
-rw-r--r-- | src/backends/reference/workloads/Reduce.cpp | 24 |
2 files changed, 14 insertions, 12 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index b51099ff79..90db57f953 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -3643,8 +3643,6 @@ void ReduceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0]; const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0]; - ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input"); - std::vector<DataType> supportedTypes = { DataType::BFloat16, diff --git a/src/backends/reference/workloads/Reduce.cpp b/src/backends/reference/workloads/Reduce.cpp index 31c6262c9a..392ef8e5ba 100644 --- a/src/backends/reference/workloads/Reduce.cpp +++ b/src/backends/reference/workloads/Reduce.cpp @@ -81,17 +81,21 @@ void Reduce(const TensorInfo& inputInfo, // Initialise temp output std::vector<float> tempOut(numOutputs); - if (reduceOperation == ReduceOperation::Max || reduceOperation == ReduceOperation::Min) + switch(reduceOperation) { - for (unsigned int idx = 0; idx < numOutputs; ++idx) - { - input[idx]; - tempOut[idx] = input.Get(); - } - } - else - { - std::fill(tempOut.begin(), tempOut.end(), 0.0); + case ReduceOperation::Mean: + case ReduceOperation::Sum: + std::fill(tempOut.begin(), tempOut.end(), 0.0); + break; + case ReduceOperation::Max: + std::fill(tempOut.begin(), tempOut.end(), -1 * std::numeric_limits<float>::max()); + break; + case ReduceOperation::Min: + std::fill(tempOut.begin(), tempOut.end(), std::numeric_limits<float>::max()); + break; + default: + throw armnn::InvalidArgumentException("Unknown reduce method: " + + std::to_string(static_cast<int>(reduceOperation))); } // Initialise temp index |