From f10b15a8946f39bdf3f60cebc59d2963069eedca Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Fri, 17 Sep 2021 21:08:57 +0100 Subject: IVGCVSW-6382 Add Gather operator support to ONNX parser * Add ParseGather to support Gather operator on ONNX * Add Support of int64 converted to int32 for constant * Add OnnxParserTestUtils * Refactor ValidateTensorShapesFromInputs of GatherLayer * Unit tests Signed-off-by: Narumol Prangnawarat Change-Id: Ie9dff640240e14a062fef38f7faf0ccc212de5f7 --- src/armnn/layers/GatherLayer.cpp | 38 ++++++++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 14 deletions(-) (limited to 'src/armnn/layers/GatherLayer.cpp') diff --git a/src/armnn/layers/GatherLayer.cpp b/src/armnn/layers/GatherLayer.cpp index 9a4f9bf8f0..cdbdaabcdc 100644 --- a/src/armnn/layers/GatherLayer.cpp +++ b/src/armnn/layers/GatherLayer.cpp @@ -31,16 +31,11 @@ GatherLayer* GatherLayer::Clone(Graph& graph) const return CloneBase(graph, m_Param, GetName()); } -void GatherLayer::ValidateTensorShapesFromInputs() +std::vector GatherLayer::InferOutputShapes(const std::vector& inputShapes) const { - VerifyLayerConnections(2, CHECK_LOCATION()); - - const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape(); - - VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod); - - const TensorInfo& params = GetInputSlot(0).GetConnection()->GetTensorInfo(); - const TensorInfo& indices = GetInputSlot(1).GetConnection()->GetTensorInfo(); + ARMNN_ASSERT(inputShapes.size() == 2); + const TensorShape& params = inputShapes[0]; + const TensorShape& indices = inputShapes[1]; const unsigned int paramsDim = params.GetNumDimensions(); const unsigned int indicesDim = indices.GetNumDimensions(); @@ -57,20 +52,35 @@ void GatherLayer::ValidateTensorShapesFromInputs() for (unsigned int i = 0; i < axis; ++i) { - dimSizes.push_back(params.GetShape()[i]); + dimSizes.push_back(params[i]); } for (unsigned int i = axis; i < indicesDim + axis; ++i) { - dimSizes.push_back(indices.GetShape()[i - axis]); + dimSizes.push_back(indices[i - axis]); } for (unsigned int i = 1 + axis; i < paramsDim; ++i) { - dimSizes.push_back(params.GetShape()[i]); + dimSizes.push_back(params[i]); } - const TensorShape& inferredShape = TensorShape(outputDim, dimSizes.data()); + return std::vector({ TensorShape({outputDim, dimSizes.data()})}); +} + +void GatherLayer::ValidateTensorShapesFromInputs() +{ + VerifyLayerConnections(2, CHECK_LOCATION()); + + const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape(); + + VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod); + + std::vector inferredShapes = InferOutputShapes( + {GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(), + GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape()}); + ARMNN_ASSERT(inferredShapes.size() == 1); + ARMNN_ASSERT(inferredShapes[0].GetDimensionality() == Dimensionality::Specified); - ValidateAndCopyShape(outputShape, inferredShape, m_ShapeInferenceMethod, "GatherLayer"); + ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "GatherLayer"); } void GatherLayer::Accept(ILayerVisitor& visitor) const -- cgit v1.2.1