aboutsummaryrefslogtreecommitdiff
path: root/src/armnn
diff options
context:
space:
mode:
authorNarumol Prangnawarat <narumol.prangnawarat@arm.com>2021-09-23 16:12:19 +0100
committerNarumol Prangnawarat <narumol.prangnawarat@arm.com>2021-10-07 14:43:09 +0000
commit452274c86245082ce20563ede12b92af81dba38a (patch)
tree79718c6cf86acbb21138068c17aae15c4b172306 /src/armnn
parent4d217c02fe2c0a32ff9da69d8fe375a75173c0f3 (diff)
downloadarmnn-452274c86245082ce20563ede12b92af81dba38a.tar.gz
IVGCVSW-6459 Add support of scalar and flexible output datatypes to ONNX parser
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com> Change-Id: Id1e933f6ae55ddc1a57c80c9f6a5757ccb61f018
Diffstat (limited to 'src/armnn')
-rw-r--r--src/armnn/layers/GatherLayer.cpp8
1 files changed, 7 insertions, 1 deletions
diff --git a/src/armnn/layers/GatherLayer.cpp b/src/armnn/layers/GatherLayer.cpp
index e8b67b8348..a808c42384 100644
--- a/src/armnn/layers/GatherLayer.cpp
+++ b/src/armnn/layers/GatherLayer.cpp
@@ -37,6 +37,11 @@ std::vector<TensorShape> GatherLayer::InferOutputShapes(const std::vector<Tensor
const TensorShape& params = inputShapes[0];
const TensorShape& indices = inputShapes[1];
+ if (indices.GetDimensionality() == Dimensionality::Scalar && indices.GetNumDimensions() == 1)
+ {
+ return std::vector<TensorShape>({ TensorShape(Dimensionality::Scalar)});
+ }
+
const unsigned int paramsDim = params.GetNumDimensions();
const unsigned int indicesDim = indices.GetNumDimensions();
const unsigned int outputDim = paramsDim - 1 + indicesDim;
@@ -78,7 +83,8 @@ void GatherLayer::ValidateTensorShapesFromInputs()
{GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape()});
ARMNN_ASSERT(inferredShapes.size() == 1);
- ARMNN_ASSERT(inferredShapes[0].GetDimensionality() == Dimensionality::Specified);
+ ARMNN_ASSERT(inferredShapes[0].GetDimensionality() == Dimensionality::Specified ||
+ inferredShapes[0].GetDimensionality() == Dimensionality::Scalar);
ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "GatherLayer");
}