aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers
diff options
context:
space:
mode:
authorTeresa Charlin <teresa.charlinreyes@arm.com>2020-06-29 16:27:03 +0100
committerTeresaARM <teresa.charlinreyes@arm.com>2020-07-01 08:26:47 +0000
commit526647333571169076f5e72c9fb18c71025bf7c0 (patch)
tree6dc559a7b0fae3705172b09a88fa552926652040 /src/armnn/layers
parentcbd2c230b7ce5f26e2ccccf36b7ad450f6e1ad09 (diff)
downloadarmnn-526647333571169076f5e72c9fb18c71025bf7c0.tar.gz
IVGCVSW-4903 Connect axis parameter in Gather from android to ACL.
!android-nn-driver:3302 Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com> Change-Id: Ifbc49acb5272f8a36719bb68676e44817190537d
Diffstat (limited to 'src/armnn/layers')
-rw-r--r--src/armnn/layers/GatherLayer.cpp27
-rw-r--r--src/armnn/layers/GatherLayer.hpp11
2 files changed, 25 insertions, 13 deletions
diff --git a/src/armnn/layers/GatherLayer.cpp b/src/armnn/layers/GatherLayer.cpp
index a99913073f..3e85d25dac 100644
--- a/src/armnn/layers/GatherLayer.cpp
+++ b/src/armnn/layers/GatherLayer.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -13,8 +13,8 @@
namespace armnn
{
-GatherLayer::GatherLayer(const char* name)
- : Layer(2, 1, LayerType::Gather, name)
+GatherLayer::GatherLayer(const GatherDescriptor& param, const char* name)
+ : LayerWithParameters(2, 1, LayerType::Gather, param, name)
{
}
@@ -26,7 +26,7 @@ std::unique_ptr<IWorkload> GatherLayer::CreateWorkload(const armnn::IWorkloadFac
GatherLayer* GatherLayer::Clone(Graph& graph) const
{
- return CloneBase<GatherLayer>(graph, GetName());
+ return CloneBase<GatherLayer>(graph, m_Param, GetName());
}
void GatherLayer::ValidateTensorShapesFromInputs(ShapeInferenceMethod shapeInferenceMethod)
@@ -44,11 +44,22 @@ void GatherLayer::ValidateTensorShapesFromInputs(ShapeInferenceMethod shapeInfer
std::vector<unsigned int> dimSizes;
- for (unsigned int i = 0; i < indicesDim; ++i)
+ unsigned int axis = static_cast<unsigned int>(m_Param.m_Axis);
+ if (m_Param.m_Axis < 0)
{
- dimSizes.push_back(indices.GetShape()[i]);
+ int32_t axis_aux = static_cast<int32_t>(paramsDim) + m_Param.m_Axis;
+ axis = static_cast<unsigned int> (axis_aux);
}
- for (unsigned int i = 1; i < paramsDim; ++i)
+
+ for (unsigned int i = 0; i < axis; ++i)
+ {
+ dimSizes.push_back(params.GetShape()[i]);
+ }
+ for (unsigned int i = axis; i < indicesDim + axis; ++i)
+ {
+ dimSizes.push_back(indices.GetShape()[i - axis]);
+ }
+ for (unsigned int i = 1 + axis; i < paramsDim; ++i)
{
dimSizes.push_back(params.GetShape()[i]);
}
@@ -63,7 +74,7 @@ void GatherLayer::ValidateTensorShapesFromInputs(ShapeInferenceMethod shapeInfer
void GatherLayer::Accept(ILayerVisitor& visitor) const
{
- visitor.VisitGatherLayer(this, GetName());
+ visitor.VisitGatherLayer(this, GetParameters(), GetName());
}
} // namespace armnn
diff --git a/src/armnn/layers/GatherLayer.hpp b/src/armnn/layers/GatherLayer.hpp
index 598ca44dc4..d8737adbee 100644
--- a/src/armnn/layers/GatherLayer.hpp
+++ b/src/armnn/layers/GatherLayer.hpp
@@ -1,17 +1,17 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
-#include "Layer.hpp"
+#include "LayerWithParameters.hpp"
namespace armnn
{
/// This layer represents a Gather operator.
-class GatherLayer : public Layer
+class GatherLayer : public LayerWithParameters<GatherDescriptor>
{
public:
/// Makes a workload for the Gather type.
@@ -24,7 +24,7 @@ public:
/// @param [in] graph The graph into which this layer is being cloned.
GatherLayer* Clone(Graph& graph) const override;
- /// Check if the input tensor shape(s)
+ /// Check if the input tensor shape(s).
/// will lead to a valid configuration of @ref GatherLayer.
/// @param [in] shapeInferenceMethod Indicates if output shape shall be overwritten or just validate.
void ValidateTensorShapesFromInputs(
@@ -34,8 +34,9 @@ public:
protected:
/// Constructor to create a GatherLayer.
+ /// @param [in] param GatherDescriptor to configure the stack operation.
/// @param [in] name Optional name for the layer.
- GatherLayer(const char* name);
+ GatherLayer(const GatherDescriptor& param, const char* name);
/// Default destructor
~GatherLayer() = default;