From eb06191bdaeab6640691a00349eb93141fb49fe8 Mon Sep 17 00:00:00 2001 From: narpra01 Date: Mon, 10 Sep 2018 17:35:27 +0100 Subject: IVGCVSW-1831 - Add dimension check to MeanQueueDescriptor::Validate to check if the output dimension is correct from a given input and options. Change-Id: Ibc15d9ea3151a7ba1935feafeb1843ee035e7f2e --- src/armnn/backends/WorkloadData.cpp | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) (limited to 'src/armnn/backends/WorkloadData.cpp') diff --git a/src/armnn/backends/WorkloadData.cpp b/src/armnn/backends/WorkloadData.cpp index c934a53a5e..3ed77dacdb 100644 --- a/src/armnn/backends/WorkloadData.cpp +++ b/src/armnn/backends/WorkloadData.cpp @@ -129,6 +129,18 @@ void ValidateTensorNumDimensions(const TensorInfo& tensor, } } +void ValidateTensorMaxNumElements(const TensorInfo& tensor, + std::string const& descName, + unsigned int maxNumElements, + std::string const& tensorName) +{ + if (tensor.GetNumElements() > maxNumElements) + { + throw InvalidArgumentException(descName + ": Expected maximum of " + to_string(maxNumElements) + " but got " + + to_string(tensor.GetNumElements()) + " elements for " + tensorName + " tensor."); + } +} + //--------------------------------------------------------------- void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType, const std::string& descName, std::string const& tensorName) @@ -828,6 +840,29 @@ void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { ValidateSingleInput(workloadInfo, "MeanQueueDescriptor"); ValidateSingleOutput(workloadInfo, "MeanQueueDescriptor"); + + const TensorInfo& input = workloadInfo.m_InputTensorInfos[0]; + const TensorInfo& output = workloadInfo.m_OutputTensorInfos[0]; + + if (m_Keepdims) + { + ValidateTensorNumDimensions(output, "MeanQueueDescriptor", input.GetNumDimensions(), "output"); + } + else if (m_Axis == nullptr) + { + ValidateTensorNumDimensions(output, "MeanQueueDescriptor", 1, "output"); + } + else + { + const TensorInfo& axis = m_Axis->GetTensorInfo(); + ValidateTensorNumDimensions(axis, "MeanQueueDescriptor", 1, "axis"); + ValidateTensorMaxNumElements(axis, "MeanQueueDescriptor", input.GetNumDimensions(), "axis"); + unsigned int outputDim = input.GetNumDimensions() - axis.GetNumElements(); + ValidateTensorNumDimensions(output, + "MeanQueueDescriptor", + outputDim > 0 ? outputDim : 1, + "output"); + } } } //namespace armnn -- cgit v1.2.1