aboutsummaryrefslogtreecommitdiff
path: root/src/armnn
diff options
context:
space:
mode:
authorNarumol Prangnawarat <narumol.prangnawarat@arm.com>2021-09-17 21:08:57 +0100
committerJim Flynn <jim.flynn@arm.com>2021-09-24 14:17:31 +0000
commitf10b15a8946f39bdf3f60cebc59d2963069eedca (patch)
tree9cba39db69acad2bd5728cefbad578161e6ba63c /src/armnn
parent4fcc8632aaa64e683d98199659093d1aa99ffb08 (diff)
downloadarmnn-f10b15a8946f39bdf3f60cebc59d2963069eedca.tar.gz
IVGCVSW-6382 Add Gather operator support to ONNX parser
* Add ParseGather to support Gather operator on ONNX * Add Support of int64 converted to int32 for constant * Add OnnxParserTestUtils * Refactor ValidateTensorShapesFromInputs of GatherLayer * Unit tests Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com> Change-Id: Ie9dff640240e14a062fef38f7faf0ccc212de5f7
Diffstat (limited to 'src/armnn')
-rw-r--r--src/armnn/layers/GatherLayer.cpp38
-rw-r--r--src/armnn/layers/GatherLayer.hpp5
2 files changed, 29 insertions, 14 deletions
diff --git a/src/armnn/layers/GatherLayer.cpp b/src/armnn/layers/GatherLayer.cpp
index 9a4f9bf8f0..cdbdaabcdc 100644
--- a/src/armnn/layers/GatherLayer.cpp
+++ b/src/armnn/layers/GatherLayer.cpp
@@ -31,16 +31,11 @@ GatherLayer* GatherLayer::Clone(Graph& graph) const
return CloneBase<GatherLayer>(graph, m_Param, GetName());
}
-void GatherLayer::ValidateTensorShapesFromInputs()
+std::vector<TensorShape> GatherLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
{
- VerifyLayerConnections(2, CHECK_LOCATION());
-
- const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
-
- VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
-
- const TensorInfo& params = GetInputSlot(0).GetConnection()->GetTensorInfo();
- const TensorInfo& indices = GetInputSlot(1).GetConnection()->GetTensorInfo();
+ ARMNN_ASSERT(inputShapes.size() == 2);
+ const TensorShape& params = inputShapes[0];
+ const TensorShape& indices = inputShapes[1];
const unsigned int paramsDim = params.GetNumDimensions();
const unsigned int indicesDim = indices.GetNumDimensions();
@@ -57,20 +52,35 @@ void GatherLayer::ValidateTensorShapesFromInputs()
for (unsigned int i = 0; i < axis; ++i)
{
- dimSizes.push_back(params.GetShape()[i]);
+ dimSizes.push_back(params[i]);
}
for (unsigned int i = axis; i < indicesDim + axis; ++i)
{
- dimSizes.push_back(indices.GetShape()[i - axis]);
+ dimSizes.push_back(indices[i - axis]);
}
for (unsigned int i = 1 + axis; i < paramsDim; ++i)
{
- dimSizes.push_back(params.GetShape()[i]);
+ dimSizes.push_back(params[i]);
}
- const TensorShape& inferredShape = TensorShape(outputDim, dimSizes.data());
+ return std::vector<TensorShape>({ TensorShape({outputDim, dimSizes.data()})});
+}
+
+void GatherLayer::ValidateTensorShapesFromInputs()
+{
+ VerifyLayerConnections(2, CHECK_LOCATION());
+
+ const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
+
+ VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
+
+ std::vector<TensorShape> inferredShapes = InferOutputShapes(
+ {GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
+ GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape()});
+ ARMNN_ASSERT(inferredShapes.size() == 1);
+ ARMNN_ASSERT(inferredShapes[0].GetDimensionality() == Dimensionality::Specified);
- ValidateAndCopyShape(outputShape, inferredShape, m_ShapeInferenceMethod, "GatherLayer");
+ ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "GatherLayer");
}
void GatherLayer::Accept(ILayerVisitor& visitor) const
diff --git a/src/armnn/layers/GatherLayer.hpp b/src/armnn/layers/GatherLayer.hpp
index 010af37b49..3bc8c69bc4 100644
--- a/src/armnn/layers/GatherLayer.hpp
+++ b/src/armnn/layers/GatherLayer.hpp
@@ -24,6 +24,11 @@ public:
/// @param [in] graph The graph into which this layer is being cloned.
GatherLayer* Clone(Graph& graph) const override;
+ /// Infers the output shapes from given input shapes and layer properties.
+ /// @param [in] inputShapes The input shapes layer has.
+ /// @return A vector to the inferred output shape.
+ std::vector<TensorShape> InferOutputShapes(const std::vector<TensorShape>& inputShapes) const override;
+
/// Check if the input tensor shape(s).
/// will lead to a valid configuration of @ref GatherLayer.
/// @param [in] shapeInferenceMethod Indicates if output shape shall be overwritten or just validate.