diff options
Diffstat (limited to 'src/armnn/layers/ArgMinMaxLayer.cpp')
-rw-r--r-- | src/armnn/layers/ArgMinMaxLayer.cpp | 39 |
1 files changed, 39 insertions, 0 deletions
diff --git a/src/armnn/layers/ArgMinMaxLayer.cpp b/src/armnn/layers/ArgMinMaxLayer.cpp index aad95eb0cf..bfd71d519b 100644 --- a/src/armnn/layers/ArgMinMaxLayer.cpp +++ b/src/armnn/layers/ArgMinMaxLayer.cpp @@ -6,6 +6,8 @@ #include "LayerCloneBase.hpp" +#include <TensorUtils.hpp> + #include <armnn/TypesUtils.hpp> #include <backendsCommon/WorkloadData.hpp> #include <backendsCommon/WorkloadFactory.hpp> @@ -30,6 +32,43 @@ ArgMinMaxLayer* ArgMinMaxLayer::Clone(Graph& graph) const return CloneBase<ArgMinMaxLayer>(graph, m_Param, GetName()); } +std::vector<TensorShape> ArgMinMaxLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const +{ + BOOST_ASSERT(inputShapes.size() == 1); + + TensorShape inputShape = inputShapes[0]; + auto inputNumDimensions = inputShape.GetNumDimensions(); + + auto axis = m_Param.m_Axis; + auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, axis); + + BOOST_ASSERT(unsignedAxis <= inputNumDimensions); + + // 1D input shape results in scalar output + if (inputShape.GetNumDimensions() == 1) + { + std::vector<unsigned int> tensorDimensions(1, 1); + TensorShape outputShape(1, tensorDimensions.data()); + + return std::vector<TensorShape>({ outputShape }); + } + + std::vector<unsigned int> tensorDimensions(inputNumDimensions - 1, 0); + for (unsigned int i = 0; i < unsignedAxis; ++i) + { + tensorDimensions[i] = inputShape[i]; + } + + for (unsigned int i = unsignedAxis + 1; i < inputNumDimensions; ++i) + { + tensorDimensions[i - 1] = inputShape[i]; + } + + TensorShape outputShape = TensorShape(inputNumDimensions - 1, tensorDimensions.data()); + + return std::vector<TensorShape>({ outputShape }); +} + void ArgMinMaxLayer::ValidateTensorShapesFromInputs() { VerifyLayerConnections(1, CHECK_LOCATION()); |