diff options
author | James Conroy <james.conroy@arm.com> | 2019-10-08 15:41:34 +0100 |
---|---|---|
committer | Jim Flynn Arm <jim.flynn@arm.com> | 2019-10-10 08:48:39 +0000 |
commit | c8724c7b9ff663538bd32ad789dbcc3e1aa88637 (patch) | |
tree | 629a55e0f949fd9bede7e3b34cdc7ff337f9762c /src/armnn/layers | |
parent | 92bbcaed655d1dfe696f12f264599589b8ada602 (diff) | |
download | armnn-c8724c7b9ff663538bd32ad789dbcc3e1aa88637.tar.gz |
IVGCVSW-3944 Add ArgMinMax output shape validation
Signed-off-by: James Conroy <james.conroy@arm.com>
Change-Id: I469895da158b062cd19248832525fa21527f7d41
Diffstat (limited to 'src/armnn/layers')
-rw-r--r-- | src/armnn/layers/ArgMinMaxLayer.cpp | 39 | ||||
-rw-r--r-- | src/armnn/layers/ArgMinMaxLayer.hpp | 5 |
2 files changed, 44 insertions, 0 deletions
diff --git a/src/armnn/layers/ArgMinMaxLayer.cpp b/src/armnn/layers/ArgMinMaxLayer.cpp index aad95eb0cf..bfd71d519b 100644 --- a/src/armnn/layers/ArgMinMaxLayer.cpp +++ b/src/armnn/layers/ArgMinMaxLayer.cpp @@ -6,6 +6,8 @@ #include "LayerCloneBase.hpp" +#include <TensorUtils.hpp> + #include <armnn/TypesUtils.hpp> #include <backendsCommon/WorkloadData.hpp> #include <backendsCommon/WorkloadFactory.hpp> @@ -30,6 +32,43 @@ ArgMinMaxLayer* ArgMinMaxLayer::Clone(Graph& graph) const return CloneBase<ArgMinMaxLayer>(graph, m_Param, GetName()); } +std::vector<TensorShape> ArgMinMaxLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const +{ + BOOST_ASSERT(inputShapes.size() == 1); + + TensorShape inputShape = inputShapes[0]; + auto inputNumDimensions = inputShape.GetNumDimensions(); + + auto axis = m_Param.m_Axis; + auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, axis); + + BOOST_ASSERT(unsignedAxis <= inputNumDimensions); + + // 1D input shape results in scalar output + if (inputShape.GetNumDimensions() == 1) + { + std::vector<unsigned int> tensorDimensions(1, 1); + TensorShape outputShape(1, tensorDimensions.data()); + + return std::vector<TensorShape>({ outputShape }); + } + + std::vector<unsigned int> tensorDimensions(inputNumDimensions - 1, 0); + for (unsigned int i = 0; i < unsignedAxis; ++i) + { + tensorDimensions[i] = inputShape[i]; + } + + for (unsigned int i = unsignedAxis + 1; i < inputNumDimensions; ++i) + { + tensorDimensions[i - 1] = inputShape[i]; + } + + TensorShape outputShape = TensorShape(inputNumDimensions - 1, tensorDimensions.data()); + + return std::vector<TensorShape>({ outputShape }); +} + void ArgMinMaxLayer::ValidateTensorShapesFromInputs() { VerifyLayerConnections(1, CHECK_LOCATION()); diff --git a/src/armnn/layers/ArgMinMaxLayer.hpp b/src/armnn/layers/ArgMinMaxLayer.hpp index ca1337f065..43ea056c9e 100644 --- a/src/armnn/layers/ArgMinMaxLayer.hpp +++ b/src/armnn/layers/ArgMinMaxLayer.hpp @@ -25,6 +25,11 @@ public: /// @param [in] graph The graph into which this layer is being cloned. ArgMinMaxLayer* Clone(Graph& graph) const override; + /// Infers the output shape from a given input shape and axis parameter. + /// @param [in] inputShapes The vector of input shapes for ArgMinMax. + /// @return A vector of inferred output shapes. + std::vector<TensorShape> InferOutputShapes(const std::vector<TensorShape>& inputShapes) const override; + /// Check if the input tensor shape(s) /// will lead to a valid configuration of @ref ArgMinMaxLayer. void ValidateTensorShapesFromInputs() override; |