From 452274c86245082ce20563ede12b92af81dba38a Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Thu, 23 Sep 2021 16:12:19 +0100 Subject: IVGCVSW-6459 Add support of scalar and flexible output datatypes to ONNX parser Signed-off-by: Narumol Prangnawarat Change-Id: Id1e933f6ae55ddc1a57c80c9f6a5757ccb61f018 --- src/armnn/layers/GatherLayer.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) (limited to 'src/armnn/layers') diff --git a/src/armnn/layers/GatherLayer.cpp b/src/armnn/layers/GatherLayer.cpp index e8b67b8348..a808c42384 100644 --- a/src/armnn/layers/GatherLayer.cpp +++ b/src/armnn/layers/GatherLayer.cpp @@ -37,6 +37,11 @@ std::vector GatherLayer::InferOutputShapes(const std::vector({ TensorShape(Dimensionality::Scalar)}); + } + const unsigned int paramsDim = params.GetNumDimensions(); const unsigned int indicesDim = indices.GetNumDimensions(); const unsigned int outputDim = paramsDim - 1 + indicesDim; @@ -78,7 +83,8 @@ void GatherLayer::ValidateTensorShapesFromInputs() {GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(), GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape()}); ARMNN_ASSERT(inferredShapes.size() == 1); - ARMNN_ASSERT(inferredShapes[0].GetDimensionality() == Dimensionality::Specified); + ARMNN_ASSERT(inferredShapes[0].GetDimensionality() == Dimensionality::Specified || + inferredShapes[0].GetDimensionality() == Dimensionality::Scalar); ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "GatherLayer"); } -- cgit v1.2.1