aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon
diff options
context:
space:
mode:
authorJames Conroy <james.conroy@arm.com>2019-10-08 15:41:34 +0100
committerJim Flynn Arm <jim.flynn@arm.com>2019-10-10 08:48:39 +0000
commitc8724c7b9ff663538bd32ad789dbcc3e1aa88637 (patch)
tree629a55e0f949fd9bede7e3b34cdc7ff337f9762c /src/backends/backendsCommon
parent92bbcaed655d1dfe696f12f264599589b8ada602 (diff)
downloadarmnn-c8724c7b9ff663538bd32ad789dbcc3e1aa88637.tar.gz
IVGCVSW-3944 Add ArgMinMax output shape validation
Signed-off-by: James Conroy <james.conroy@arm.com> Change-Id: I469895da158b062cd19248832525fa21527f7d41
Diffstat (limited to 'src/backends/backendsCommon')
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp36
-rw-r--r--src/backends/backendsCommon/test/layerTests/ArgMinMaxTestImpl.cpp4
2 files changed, 38 insertions, 2 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 89277d798f..ea0e5c82b8 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -15,6 +15,7 @@
#include <boost/format.hpp>
#include <boost/numeric/conversion/cast.hpp>
+#include <TensorUtils.hpp>
using namespace armnnUtils;
@@ -485,6 +486,41 @@ void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
};
ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
+
+ auto inputShape = inputTensorInfo.GetShape();
+ auto outputShape = outputTensorInfo.GetShape();
+
+ auto inputNumDimensions = inputShape.GetNumDimensions();
+ auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);
+
+ const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};
+
+ // 1D input shape results in scalar output shape
+ if (inputShape.GetNumDimensions() == 1)
+ {
+ if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
+ {
+ throw InvalidArgumentException(descriptorName + outputShapeError);
+ }
+ }
+ else
+ {
+ for (unsigned int i = 0; i < unsignedAxis; ++i)
+ {
+ if (outputShape[i] != inputShape[i])
+ {
+ throw InvalidArgumentException(descriptorName + outputShapeError);
+ }
+ }
+
+ for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
+ {
+ if (outputShape[i - 1] != inputShape[i])
+ {
+ throw InvalidArgumentException(descriptorName + outputShapeError);
+ }
+ }
+ }
}
void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
diff --git a/src/backends/backendsCommon/test/layerTests/ArgMinMaxTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/ArgMinMaxTestImpl.cpp
index e023d60bf0..be7ef4e32e 100644
--- a/src/backends/backendsCommon/test/layerTests/ArgMinMaxTestImpl.cpp
+++ b/src/backends/backendsCommon/test/layerTests/ArgMinMaxTestImpl.cpp
@@ -190,7 +190,7 @@ LayerTestResult<int32_t, 3> ArgMaxHeightTest(
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
{
const armnn::TensorShape inputShape{ 1, 3, 2, 4};
- const armnn::TensorShape outputShape{ 3, 1, 4 };
+ const armnn::TensorShape outputShape{ 1, 3, 4 };
armnn::TensorInfo inputTensorInfo(inputShape, ArmnnType);
armnn::TensorInfo outputTensorInfo(outputShape, armnn::DataType::Signed32);
@@ -219,7 +219,7 @@ LayerTestResult<int32_t, 3> ArgMinWidthTest(
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
{
const armnn::TensorShape inputShape{ 1, 3, 2, 4};
- const armnn::TensorShape outputShape{ 3, 2, 1 };
+ const armnn::TensorShape outputShape{ 1, 3, 2 };
armnn::TensorInfo inputTensorInfo(inputShape, ArmnnType);
armnn::TensorInfo outputTensorInfo(outputShape, armnn::DataType::Signed32);