From 526647333571169076f5e72c9fb18c71025bf7c0 Mon Sep 17 00:00:00 2001 From: Teresa Charlin Date: Mon, 29 Jun 2020 16:27:03 +0100 Subject: IVGCVSW-4903 Connect axis parameter in Gather from android to ACL. !android-nn-driver:3302 Signed-off-by: Teresa Charlin Change-Id: Ifbc49acb5272f8a36719bb68676e44817190537d --- src/armnn/layers/GatherLayer.cpp | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) (limited to 'src/armnn/layers/GatherLayer.cpp') diff --git a/src/armnn/layers/GatherLayer.cpp b/src/armnn/layers/GatherLayer.cpp index a99913073f..3e85d25dac 100644 --- a/src/armnn/layers/GatherLayer.cpp +++ b/src/armnn/layers/GatherLayer.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -13,8 +13,8 @@ namespace armnn { -GatherLayer::GatherLayer(const char* name) - : Layer(2, 1, LayerType::Gather, name) +GatherLayer::GatherLayer(const GatherDescriptor& param, const char* name) + : LayerWithParameters(2, 1, LayerType::Gather, param, name) { } @@ -26,7 +26,7 @@ std::unique_ptr GatherLayer::CreateWorkload(const armnn::IWorkloadFac GatherLayer* GatherLayer::Clone(Graph& graph) const { - return CloneBase(graph, GetName()); + return CloneBase(graph, m_Param, GetName()); } void GatherLayer::ValidateTensorShapesFromInputs(ShapeInferenceMethod shapeInferenceMethod) @@ -44,11 +44,22 @@ void GatherLayer::ValidateTensorShapesFromInputs(ShapeInferenceMethod shapeInfer std::vector dimSizes; - for (unsigned int i = 0; i < indicesDim; ++i) + unsigned int axis = static_cast(m_Param.m_Axis); + if (m_Param.m_Axis < 0) { - dimSizes.push_back(indices.GetShape()[i]); + int32_t axis_aux = static_cast(paramsDim) + m_Param.m_Axis; + axis = static_cast (axis_aux); } - for (unsigned int i = 1; i < paramsDim; ++i) + + for (unsigned int i = 0; i < axis; ++i) + { + dimSizes.push_back(params.GetShape()[i]); + } + for (unsigned int i = axis; i < indicesDim + axis; ++i) + { + dimSizes.push_back(indices.GetShape()[i - axis]); + } + for (unsigned int i = 1 + axis; i < paramsDim; ++i) { dimSizes.push_back(params.GetShape()[i]); } @@ -63,7 +74,7 @@ void GatherLayer::ValidateTensorShapesFromInputs(ShapeInferenceMethod shapeInfer void GatherLayer::Accept(ILayerVisitor& visitor) const { - visitor.VisitGatherLayer(this, GetName()); + visitor.VisitGatherLayer(this, GetParameters(), GetName()); } } // namespace armnn -- cgit v1.2.1