diff options
author | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2021-09-23 16:12:19 +0100 |
---|---|---|
committer | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2021-10-07 14:43:09 +0000 |
commit | 452274c86245082ce20563ede12b92af81dba38a (patch) | |
tree | 79718c6cf86acbb21138068c17aae15c4b172306 /src/armnn/layers/GatherLayer.cpp | |
parent | 4d217c02fe2c0a32ff9da69d8fe375a75173c0f3 (diff) | |
download | armnn-452274c86245082ce20563ede12b92af81dba38a.tar.gz |
IVGCVSW-6459 Add support of scalar and flexible output datatypes to ONNX parser
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: Id1e933f6ae55ddc1a57c80c9f6a5757ccb61f018
Diffstat (limited to 'src/armnn/layers/GatherLayer.cpp')
-rw-r--r-- | src/armnn/layers/GatherLayer.cpp | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/src/armnn/layers/GatherLayer.cpp b/src/armnn/layers/GatherLayer.cpp index e8b67b8348..a808c42384 100644 --- a/src/armnn/layers/GatherLayer.cpp +++ b/src/armnn/layers/GatherLayer.cpp @@ -37,6 +37,11 @@ std::vector<TensorShape> GatherLayer::InferOutputShapes(const std::vector<Tensor const TensorShape& params = inputShapes[0]; const TensorShape& indices = inputShapes[1]; + if (indices.GetDimensionality() == Dimensionality::Scalar && indices.GetNumDimensions() == 1) + { + return std::vector<TensorShape>({ TensorShape(Dimensionality::Scalar)}); + } + const unsigned int paramsDim = params.GetNumDimensions(); const unsigned int indicesDim = indices.GetNumDimensions(); const unsigned int outputDim = paramsDim - 1 + indicesDim; @@ -78,7 +83,8 @@ void GatherLayer::ValidateTensorShapesFromInputs() {GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(), GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape()}); ARMNN_ASSERT(inferredShapes.size() == 1); - ARMNN_ASSERT(inferredShapes[0].GetDimensionality() == Dimensionality::Specified); + ARMNN_ASSERT(inferredShapes[0].GetDimensionality() == Dimensionality::Specified || + inferredShapes[0].GetDimensionality() == Dimensionality::Scalar); ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "GatherLayer"); } |