aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTeresa Charlin <teresa.charlinreyes@arm.com>2021-02-11 23:05:40 +0000
committerTeresa Charlin <teresa.charlinreyes@arm.com>2021-02-11 23:07:30 +0000
commit2226ca98e01a8f7de37357ebdb2a0ed14fd3b0d2 (patch)
tree4dfe5e1768f023c71f23b9c689a5e37bb8f1ec8e
parent49bdb794170c3d25e3e51fc7b4c267c3d8dbcebf (diff)
downloadarmnn-2226ca98e01a8f7de37357ebdb2a0ed14fd3b0d2.tar.gz
MLCE-347 Bug fixes in Reduce: QueueDescriptor.validate and init REDUCE_MIN
* Allow input tensors of any rank in ReduceQueueDescriptor::validate * Fix VTS tests failing for REDUCE_MIN due to initialization Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com> Change-Id: Id8fba1662ade4e0a967093fe5a53b275847f2393
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp2
-rw-r--r--src/backends/reference/workloads/Reduce.cpp24
2 files changed, 14 insertions, 12 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index b51099ff79..90db57f953 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -3643,8 +3643,6 @@ void ReduceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
- ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
-
std::vector<DataType> supportedTypes =
{
DataType::BFloat16,
diff --git a/src/backends/reference/workloads/Reduce.cpp b/src/backends/reference/workloads/Reduce.cpp
index 31c6262c9a..392ef8e5ba 100644
--- a/src/backends/reference/workloads/Reduce.cpp
+++ b/src/backends/reference/workloads/Reduce.cpp
@@ -81,17 +81,21 @@ void Reduce(const TensorInfo& inputInfo,
// Initialise temp output
std::vector<float> tempOut(numOutputs);
- if (reduceOperation == ReduceOperation::Max || reduceOperation == ReduceOperation::Min)
+ switch(reduceOperation)
{
- for (unsigned int idx = 0; idx < numOutputs; ++idx)
- {
- input[idx];
- tempOut[idx] = input.Get();
- }
- }
- else
- {
- std::fill(tempOut.begin(), tempOut.end(), 0.0);
+ case ReduceOperation::Mean:
+ case ReduceOperation::Sum:
+ std::fill(tempOut.begin(), tempOut.end(), 0.0);
+ break;
+ case ReduceOperation::Max:
+ std::fill(tempOut.begin(), tempOut.end(), -1 * std::numeric_limits<float>::max());
+ break;
+ case ReduceOperation::Min:
+ std::fill(tempOut.begin(), tempOut.end(), std::numeric_limits<float>::max());
+ break;
+ default:
+ throw armnn::InvalidArgumentException("Unknown reduce method: " +
+ std::to_string(static_cast<int>(reduceOperation)));
}
// Initialise temp index