diff options
Diffstat (limited to 'src/armnn/layers')
-rw-r--r-- | src/armnn/layers/ArgMinMaxLayer.cpp | 39 | ||||
-rw-r--r-- | src/armnn/layers/ArgMinMaxLayer.hpp | 5 |
2 files changed, 44 insertions, 0 deletions
diff --git a/src/armnn/layers/ArgMinMaxLayer.cpp b/src/armnn/layers/ArgMinMaxLayer.cpp index aad95eb0cf..bfd71d519b 100644 --- a/src/armnn/layers/ArgMinMaxLayer.cpp +++ b/src/armnn/layers/ArgMinMaxLayer.cpp @@ -6,6 +6,8 @@ #include "LayerCloneBase.hpp" +#include <TensorUtils.hpp> + #include <armnn/TypesUtils.hpp> #include <backendsCommon/WorkloadData.hpp> #include <backendsCommon/WorkloadFactory.hpp> @@ -30,6 +32,43 @@ ArgMinMaxLayer* ArgMinMaxLayer::Clone(Graph& graph) const return CloneBase<ArgMinMaxLayer>(graph, m_Param, GetName()); } +std::vector<TensorShape> ArgMinMaxLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const +{ + BOOST_ASSERT(inputShapes.size() == 1); + + TensorShape inputShape = inputShapes[0]; + auto inputNumDimensions = inputShape.GetNumDimensions(); + + auto axis = m_Param.m_Axis; + auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, axis); + + BOOST_ASSERT(unsignedAxis <= inputNumDimensions); + + // 1D input shape results in scalar output + if (inputShape.GetNumDimensions() == 1) + { + std::vector<unsigned int> tensorDimensions(1, 1); + TensorShape outputShape(1, tensorDimensions.data()); + + return std::vector<TensorShape>({ outputShape }); + } + + std::vector<unsigned int> tensorDimensions(inputNumDimensions - 1, 0); + for (unsigned int i = 0; i < unsignedAxis; ++i) + { + tensorDimensions[i] = inputShape[i]; + } + + for (unsigned int i = unsignedAxis + 1; i < inputNumDimensions; ++i) + { + tensorDimensions[i - 1] = inputShape[i]; + } + + TensorShape outputShape = TensorShape(inputNumDimensions - 1, tensorDimensions.data()); + + return std::vector<TensorShape>({ outputShape }); +} + void ArgMinMaxLayer::ValidateTensorShapesFromInputs() { VerifyLayerConnections(1, CHECK_LOCATION()); diff --git a/src/armnn/layers/ArgMinMaxLayer.hpp b/src/armnn/layers/ArgMinMaxLayer.hpp index ca1337f065..43ea056c9e 100644 --- a/src/armnn/layers/ArgMinMaxLayer.hpp +++ b/src/armnn/layers/ArgMinMaxLayer.hpp @@ -25,6 +25,11 @@ public: /// @param [in] graph The graph into which this layer is being cloned. ArgMinMaxLayer* Clone(Graph& graph) const override; + /// Infers the output shape from a given input shape and axis parameter. + /// @param [in] inputShapes The vector of input shapes for ArgMinMax. + /// @return A vector of inferred output shapes. + std::vector<TensorShape> InferOutputShapes(const std::vector<TensorShape>& inputShapes) const override; + /// Check if the input tensor shape(s) /// will lead to a valid configuration of @ref ArgMinMaxLayer. void ValidateTensorShapesFromInputs() override; |