diff options
Diffstat (limited to 'src/armnn')
-rw-r--r-- | src/armnn/layers/GatherLayer.cpp | 38 | ||||
-rw-r--r-- | src/armnn/layers/GatherLayer.hpp | 5 |
2 files changed, 29 insertions, 14 deletions
diff --git a/src/armnn/layers/GatherLayer.cpp b/src/armnn/layers/GatherLayer.cpp index 9a4f9bf8f0..cdbdaabcdc 100644 --- a/src/armnn/layers/GatherLayer.cpp +++ b/src/armnn/layers/GatherLayer.cpp @@ -31,16 +31,11 @@ GatherLayer* GatherLayer::Clone(Graph& graph) const return CloneBase<GatherLayer>(graph, m_Param, GetName()); } -void GatherLayer::ValidateTensorShapesFromInputs() +std::vector<TensorShape> GatherLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const { - VerifyLayerConnections(2, CHECK_LOCATION()); - - const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape(); - - VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod); - - const TensorInfo& params = GetInputSlot(0).GetConnection()->GetTensorInfo(); - const TensorInfo& indices = GetInputSlot(1).GetConnection()->GetTensorInfo(); + ARMNN_ASSERT(inputShapes.size() == 2); + const TensorShape& params = inputShapes[0]; + const TensorShape& indices = inputShapes[1]; const unsigned int paramsDim = params.GetNumDimensions(); const unsigned int indicesDim = indices.GetNumDimensions(); @@ -57,20 +52,35 @@ void GatherLayer::ValidateTensorShapesFromInputs() for (unsigned int i = 0; i < axis; ++i) { - dimSizes.push_back(params.GetShape()[i]); + dimSizes.push_back(params[i]); } for (unsigned int i = axis; i < indicesDim + axis; ++i) { - dimSizes.push_back(indices.GetShape()[i - axis]); + dimSizes.push_back(indices[i - axis]); } for (unsigned int i = 1 + axis; i < paramsDim; ++i) { - dimSizes.push_back(params.GetShape()[i]); + dimSizes.push_back(params[i]); } - const TensorShape& inferredShape = TensorShape(outputDim, dimSizes.data()); + return std::vector<TensorShape>({ TensorShape({outputDim, dimSizes.data()})}); +} + +void GatherLayer::ValidateTensorShapesFromInputs() +{ + VerifyLayerConnections(2, CHECK_LOCATION()); + + const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape(); + + VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod); + + std::vector<TensorShape> inferredShapes = InferOutputShapes( + {GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(), + GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape()}); + ARMNN_ASSERT(inferredShapes.size() == 1); + ARMNN_ASSERT(inferredShapes[0].GetDimensionality() == Dimensionality::Specified); - ValidateAndCopyShape(outputShape, inferredShape, m_ShapeInferenceMethod, "GatherLayer"); + ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "GatherLayer"); } void GatherLayer::Accept(ILayerVisitor& visitor) const diff --git a/src/armnn/layers/GatherLayer.hpp b/src/armnn/layers/GatherLayer.hpp index 010af37b49..3bc8c69bc4 100644 --- a/src/armnn/layers/GatherLayer.hpp +++ b/src/armnn/layers/GatherLayer.hpp @@ -24,6 +24,11 @@ public: /// @param [in] graph The graph into which this layer is being cloned. GatherLayer* Clone(Graph& graph) const override; + /// Infers the output shapes from given input shapes and layer properties. + /// @param [in] inputShapes The input shapes layer has. + /// @return A vector to the inferred output shape. + std::vector<TensorShape> InferOutputShapes(const std::vector<TensorShape>& inputShapes) const override; + /// Check if the input tensor shape(s). /// will lead to a valid configuration of @ref GatherLayer. /// @param [in] shapeInferenceMethod Indicates if output shape shall be overwritten or just validate. |