diff options
author | James Conroy <james.conroy@arm.com> | 2019-10-08 15:41:34 +0100 |
---|---|---|
committer | Jim Flynn Arm <jim.flynn@arm.com> | 2019-10-10 08:48:39 +0000 |
commit | c8724c7b9ff663538bd32ad789dbcc3e1aa88637 (patch) | |
tree | 629a55e0f949fd9bede7e3b34cdc7ff337f9762c /src/armnn/layers/ArgMinMaxLayer.cpp | |
parent | 92bbcaed655d1dfe696f12f264599589b8ada602 (diff) | |
download | armnn-c8724c7b9ff663538bd32ad789dbcc3e1aa88637.tar.gz |
IVGCVSW-3944 Add ArgMinMax output shape validation
Signed-off-by: James Conroy <james.conroy@arm.com>
Change-Id: I469895da158b062cd19248832525fa21527f7d41
Diffstat (limited to 'src/armnn/layers/ArgMinMaxLayer.cpp')
-rw-r--r-- | src/armnn/layers/ArgMinMaxLayer.cpp | 39 |
1 files changed, 39 insertions, 0 deletions
diff --git a/src/armnn/layers/ArgMinMaxLayer.cpp b/src/armnn/layers/ArgMinMaxLayer.cpp index aad95eb0cf..bfd71d519b 100644 --- a/src/armnn/layers/ArgMinMaxLayer.cpp +++ b/src/armnn/layers/ArgMinMaxLayer.cpp @@ -6,6 +6,8 @@ #include "LayerCloneBase.hpp" +#include <TensorUtils.hpp> + #include <armnn/TypesUtils.hpp> #include <backendsCommon/WorkloadData.hpp> #include <backendsCommon/WorkloadFactory.hpp> @@ -30,6 +32,43 @@ ArgMinMaxLayer* ArgMinMaxLayer::Clone(Graph& graph) const return CloneBase<ArgMinMaxLayer>(graph, m_Param, GetName()); } +std::vector<TensorShape> ArgMinMaxLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const +{ + BOOST_ASSERT(inputShapes.size() == 1); + + TensorShape inputShape = inputShapes[0]; + auto inputNumDimensions = inputShape.GetNumDimensions(); + + auto axis = m_Param.m_Axis; + auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, axis); + + BOOST_ASSERT(unsignedAxis <= inputNumDimensions); + + // 1D input shape results in scalar output + if (inputShape.GetNumDimensions() == 1) + { + std::vector<unsigned int> tensorDimensions(1, 1); + TensorShape outputShape(1, tensorDimensions.data()); + + return std::vector<TensorShape>({ outputShape }); + } + + std::vector<unsigned int> tensorDimensions(inputNumDimensions - 1, 0); + for (unsigned int i = 0; i < unsignedAxis; ++i) + { + tensorDimensions[i] = inputShape[i]; + } + + for (unsigned int i = unsignedAxis + 1; i < inputNumDimensions; ++i) + { + tensorDimensions[i - 1] = inputShape[i]; + } + + TensorShape outputShape = TensorShape(inputNumDimensions - 1, tensorDimensions.data()); + + return std::vector<TensorShape>({ outputShape }); +} + void ArgMinMaxLayer::ValidateTensorShapesFromInputs() { VerifyLayerConnections(1, CHECK_LOCATION()); |