diff options
Diffstat (limited to 'src/armnn/layers/GatherLayer.cpp')
-rw-r--r-- | src/armnn/layers/GatherLayer.cpp | 38 |
1 files changed, 24 insertions, 14 deletions
diff --git a/src/armnn/layers/GatherLayer.cpp b/src/armnn/layers/GatherLayer.cpp index 9a4f9bf8f0..cdbdaabcdc 100644 --- a/src/armnn/layers/GatherLayer.cpp +++ b/src/armnn/layers/GatherLayer.cpp @@ -31,16 +31,11 @@ GatherLayer* GatherLayer::Clone(Graph& graph) const return CloneBase<GatherLayer>(graph, m_Param, GetName()); } -void GatherLayer::ValidateTensorShapesFromInputs() +std::vector<TensorShape> GatherLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const { - VerifyLayerConnections(2, CHECK_LOCATION()); - - const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape(); - - VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod); - - const TensorInfo& params = GetInputSlot(0).GetConnection()->GetTensorInfo(); - const TensorInfo& indices = GetInputSlot(1).GetConnection()->GetTensorInfo(); + ARMNN_ASSERT(inputShapes.size() == 2); + const TensorShape& params = inputShapes[0]; + const TensorShape& indices = inputShapes[1]; const unsigned int paramsDim = params.GetNumDimensions(); const unsigned int indicesDim = indices.GetNumDimensions(); @@ -57,20 +52,35 @@ void GatherLayer::ValidateTensorShapesFromInputs() for (unsigned int i = 0; i < axis; ++i) { - dimSizes.push_back(params.GetShape()[i]); + dimSizes.push_back(params[i]); } for (unsigned int i = axis; i < indicesDim + axis; ++i) { - dimSizes.push_back(indices.GetShape()[i - axis]); + dimSizes.push_back(indices[i - axis]); } for (unsigned int i = 1 + axis; i < paramsDim; ++i) { - dimSizes.push_back(params.GetShape()[i]); + dimSizes.push_back(params[i]); } - const TensorShape& inferredShape = TensorShape(outputDim, dimSizes.data()); + return std::vector<TensorShape>({ TensorShape({outputDim, dimSizes.data()})}); +} + +void GatherLayer::ValidateTensorShapesFromInputs() +{ + VerifyLayerConnections(2, CHECK_LOCATION()); + + const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape(); + + VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod); + + std::vector<TensorShape> inferredShapes = InferOutputShapes( + {GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(), + GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape()}); + ARMNN_ASSERT(inferredShapes.size() == 1); + ARMNN_ASSERT(inferredShapes[0].GetDimensionality() == Dimensionality::Specified); - ValidateAndCopyShape(outputShape, inferredShape, m_ShapeInferenceMethod, "GatherLayer"); + ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "GatherLayer"); } void GatherLayer::Accept(ILayerVisitor& visitor) const |