aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/GatherLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/GatherLayer.cpp')
-rw-r--r--src/armnn/layers/GatherLayer.cpp8
1 files changed, 7 insertions, 1 deletions
diff --git a/src/armnn/layers/GatherLayer.cpp b/src/armnn/layers/GatherLayer.cpp
index e8b67b8348..a808c42384 100644
--- a/src/armnn/layers/GatherLayer.cpp
+++ b/src/armnn/layers/GatherLayer.cpp
@@ -37,6 +37,11 @@ std::vector<TensorShape> GatherLayer::InferOutputShapes(const std::vector<Tensor
const TensorShape& params = inputShapes[0];
const TensorShape& indices = inputShapes[1];
+ if (indices.GetDimensionality() == Dimensionality::Scalar && indices.GetNumDimensions() == 1)
+ {
+ return std::vector<TensorShape>({ TensorShape(Dimensionality::Scalar)});
+ }
+
const unsigned int paramsDim = params.GetNumDimensions();
const unsigned int indicesDim = indices.GetNumDimensions();
const unsigned int outputDim = paramsDim - 1 + indicesDim;
@@ -78,7 +83,8 @@ void GatherLayer::ValidateTensorShapesFromInputs()
{GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape()});
ARMNN_ASSERT(inferredShapes.size() == 1);
- ARMNN_ASSERT(inferredShapes[0].GetDimensionality() == Dimensionality::Specified);
+ ARMNN_ASSERT(inferredShapes[0].GetDimensionality() == Dimensionality::Specified ||
+ inferredShapes[0].GetDimensionality() == Dimensionality::Scalar);
ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "GatherLayer");
}