aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/backends/WorkloadData.cpp
diff options
context:
space:
mode:
authornarpra01 <narumol.prangnawarat@arm.com>2018-09-10 17:35:27 +0100
committerMatthew Bentham <matthew.bentham@arm.com>2018-09-25 14:54:29 +0100
commiteb06191bdaeab6640691a00349eb93141fb49fe8 (patch)
treeb850e2eddee1e61f79a1b0efc6b6670a0920d88c /src/armnn/backends/WorkloadData.cpp
parent4a8692cf18ebd3c4de125274d5c840d7be64e3cd (diff)
downloadarmnn-eb06191bdaeab6640691a00349eb93141fb49fe8.tar.gz
IVGCVSW-1831 - Add dimension check to MeanQueueDescriptor::Validate to check if the output dimension is correct from a given input and options.
Change-Id: Ibc15d9ea3151a7ba1935feafeb1843ee035e7f2e
Diffstat (limited to 'src/armnn/backends/WorkloadData.cpp')
-rw-r--r--src/armnn/backends/WorkloadData.cpp35
1 files changed, 35 insertions, 0 deletions
diff --git a/src/armnn/backends/WorkloadData.cpp b/src/armnn/backends/WorkloadData.cpp
index c934a53a5e..3ed77dacdb 100644
--- a/src/armnn/backends/WorkloadData.cpp
+++ b/src/armnn/backends/WorkloadData.cpp
@@ -129,6 +129,18 @@ void ValidateTensorNumDimensions(const TensorInfo& tensor,
}
}
+void ValidateTensorMaxNumElements(const TensorInfo& tensor,
+ std::string const& descName,
+ unsigned int maxNumElements,
+ std::string const& tensorName)
+{
+ if (tensor.GetNumElements() > maxNumElements)
+ {
+ throw InvalidArgumentException(descName + ": Expected maximum of " + to_string(maxNumElements) + " but got " +
+ to_string(tensor.GetNumElements()) + " elements for " + tensorName + " tensor.");
+ }
+}
+
//---------------------------------------------------------------
void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
const std::string& descName, std::string const& tensorName)
@@ -828,6 +840,29 @@ void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
ValidateSingleInput(workloadInfo, "MeanQueueDescriptor");
ValidateSingleOutput(workloadInfo, "MeanQueueDescriptor");
+
+ const TensorInfo& input = workloadInfo.m_InputTensorInfos[0];
+ const TensorInfo& output = workloadInfo.m_OutputTensorInfos[0];
+
+ if (m_Keepdims)
+ {
+ ValidateTensorNumDimensions(output, "MeanQueueDescriptor", input.GetNumDimensions(), "output");
+ }
+ else if (m_Axis == nullptr)
+ {
+ ValidateTensorNumDimensions(output, "MeanQueueDescriptor", 1, "output");
+ }
+ else
+ {
+ const TensorInfo& axis = m_Axis->GetTensorInfo();
+ ValidateTensorNumDimensions(axis, "MeanQueueDescriptor", 1, "axis");
+ ValidateTensorMaxNumElements(axis, "MeanQueueDescriptor", input.GetNumDimensions(), "axis");
+ unsigned int outputDim = input.GetNumDimensions() - axis.GetNumElements();
+ ValidateTensorNumDimensions(output,
+ "MeanQueueDescriptor",
+ outputDim > 0 ? outputDim : 1,
+ "output");
+ }
}
} //namespace armnn