diff options
author | James Conroy <james.conroy@arm.com> | 2019-10-08 15:41:34 +0100 |
---|---|---|
committer | Jim Flynn Arm <jim.flynn@arm.com> | 2019-10-10 08:48:39 +0000 |
commit | c8724c7b9ff663538bd32ad789dbcc3e1aa88637 (patch) | |
tree | 629a55e0f949fd9bede7e3b34cdc7ff337f9762c /src/backends/backendsCommon/WorkloadData.cpp | |
parent | 92bbcaed655d1dfe696f12f264599589b8ada602 (diff) | |
download | armnn-c8724c7b9ff663538bd32ad789dbcc3e1aa88637.tar.gz |
IVGCVSW-3944 Add ArgMinMax output shape validation
Signed-off-by: James Conroy <james.conroy@arm.com>
Change-Id: I469895da158b062cd19248832525fa21527f7d41
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 36 |
1 files changed, 36 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 89277d798f..ea0e5c82b8 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -15,6 +15,7 @@ #include <boost/format.hpp> #include <boost/numeric/conversion/cast.hpp> +#include <TensorUtils.hpp> using namespace armnnUtils; @@ -485,6 +486,41 @@ void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const }; ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName); + + auto inputShape = inputTensorInfo.GetShape(); + auto outputShape = outputTensorInfo.GetShape(); + + auto inputNumDimensions = inputShape.GetNumDimensions(); + auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis); + + const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."}; + + // 1D input shape results in scalar output shape + if (inputShape.GetNumDimensions() == 1) + { + if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1) + { + throw InvalidArgumentException(descriptorName + outputShapeError); + } + } + else + { + for (unsigned int i = 0; i < unsignedAxis; ++i) + { + if (outputShape[i] != inputShape[i]) + { + throw InvalidArgumentException(descriptorName + outputShapeError); + } + } + + for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i) + { + if (outputShape[i - 1] != inputShape[i]) + { + throw InvalidArgumentException(descriptorName + outputShapeError); + } + } + } } void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const |