aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/GatherLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/GatherLayer.cpp')
-rw-r--r--src/armnn/layers/GatherLayer.cpp38
1 files changed, 24 insertions, 14 deletions
diff --git a/src/armnn/layers/GatherLayer.cpp b/src/armnn/layers/GatherLayer.cpp
index 9a4f9bf8f0..cdbdaabcdc 100644
--- a/src/armnn/layers/GatherLayer.cpp
+++ b/src/armnn/layers/GatherLayer.cpp
@@ -31,16 +31,11 @@ GatherLayer* GatherLayer::Clone(Graph& graph) const
return CloneBase<GatherLayer>(graph, m_Param, GetName());
}
-void GatherLayer::ValidateTensorShapesFromInputs()
+std::vector<TensorShape> GatherLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
{
- VerifyLayerConnections(2, CHECK_LOCATION());
-
- const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
-
- VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
-
- const TensorInfo& params = GetInputSlot(0).GetConnection()->GetTensorInfo();
- const TensorInfo& indices = GetInputSlot(1).GetConnection()->GetTensorInfo();
+ ARMNN_ASSERT(inputShapes.size() == 2);
+ const TensorShape& params = inputShapes[0];
+ const TensorShape& indices = inputShapes[1];
const unsigned int paramsDim = params.GetNumDimensions();
const unsigned int indicesDim = indices.GetNumDimensions();
@@ -57,20 +52,35 @@ void GatherLayer::ValidateTensorShapesFromInputs()
for (unsigned int i = 0; i < axis; ++i)
{
- dimSizes.push_back(params.GetShape()[i]);
+ dimSizes.push_back(params[i]);
}
for (unsigned int i = axis; i < indicesDim + axis; ++i)
{
- dimSizes.push_back(indices.GetShape()[i - axis]);
+ dimSizes.push_back(indices[i - axis]);
}
for (unsigned int i = 1 + axis; i < paramsDim; ++i)
{
- dimSizes.push_back(params.GetShape()[i]);
+ dimSizes.push_back(params[i]);
}
- const TensorShape& inferredShape = TensorShape(outputDim, dimSizes.data());
+ return std::vector<TensorShape>({ TensorShape({outputDim, dimSizes.data()})});
+}
+
+void GatherLayer::ValidateTensorShapesFromInputs()
+{
+ VerifyLayerConnections(2, CHECK_LOCATION());
+
+ const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
+
+ VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
+
+ std::vector<TensorShape> inferredShapes = InferOutputShapes(
+ {GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
+ GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape()});
+ ARMNN_ASSERT(inferredShapes.size() == 1);
+ ARMNN_ASSERT(inferredShapes[0].GetDimensionality() == Dimensionality::Specified);
- ValidateAndCopyShape(outputShape, inferredShape, m_ShapeInferenceMethod, "GatherLayer");
+ ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "GatherLayer");
}
void GatherLayer::Accept(ILayerVisitor& visitor) const