diff options
Diffstat (limited to 'src/armnn/layers/GatherLayer.cpp')
-rw-r--r-- | src/armnn/layers/GatherLayer.cpp | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/src/armnn/layers/GatherLayer.cpp b/src/armnn/layers/GatherLayer.cpp index e8b67b8348..a808c42384 100644 --- a/src/armnn/layers/GatherLayer.cpp +++ b/src/armnn/layers/GatherLayer.cpp @@ -37,6 +37,11 @@ std::vector<TensorShape> GatherLayer::InferOutputShapes(const std::vector<Tensor const TensorShape& params = inputShapes[0]; const TensorShape& indices = inputShapes[1]; + if (indices.GetDimensionality() == Dimensionality::Scalar && indices.GetNumDimensions() == 1) + { + return std::vector<TensorShape>({ TensorShape(Dimensionality::Scalar)}); + } + const unsigned int paramsDim = params.GetNumDimensions(); const unsigned int indicesDim = indices.GetNumDimensions(); const unsigned int outputDim = paramsDim - 1 + indicesDim; @@ -78,7 +83,8 @@ void GatherLayer::ValidateTensorShapesFromInputs() {GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(), GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape()}); ARMNN_ASSERT(inferredShapes.size() == 1); - ARMNN_ASSERT(inferredShapes[0].GetDimensionality() == Dimensionality::Specified); + ARMNN_ASSERT(inferredShapes[0].GetDimensionality() == Dimensionality::Specified || + inferredShapes[0].GetDimensionality() == Dimensionality::Scalar); ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "GatherLayer"); } |