aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/GatherLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/GatherLayer.cpp')
-rw-r--r--src/armnn/layers/GatherLayer.cpp27
1 files changed, 19 insertions, 8 deletions
diff --git a/src/armnn/layers/GatherLayer.cpp b/src/armnn/layers/GatherLayer.cpp
index a99913073f..3e85d25dac 100644
--- a/src/armnn/layers/GatherLayer.cpp
+++ b/src/armnn/layers/GatherLayer.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -13,8 +13,8 @@
namespace armnn
{
-GatherLayer::GatherLayer(const char* name)
- : Layer(2, 1, LayerType::Gather, name)
+GatherLayer::GatherLayer(const GatherDescriptor& param, const char* name)
+ : LayerWithParameters(2, 1, LayerType::Gather, param, name)
{
}
@@ -26,7 +26,7 @@ std::unique_ptr<IWorkload> GatherLayer::CreateWorkload(const armnn::IWorkloadFac
GatherLayer* GatherLayer::Clone(Graph& graph) const
{
- return CloneBase<GatherLayer>(graph, GetName());
+ return CloneBase<GatherLayer>(graph, m_Param, GetName());
}
void GatherLayer::ValidateTensorShapesFromInputs(ShapeInferenceMethod shapeInferenceMethod)
@@ -44,11 +44,22 @@ void GatherLayer::ValidateTensorShapesFromInputs(ShapeInferenceMethod shapeInfer
std::vector<unsigned int> dimSizes;
- for (unsigned int i = 0; i < indicesDim; ++i)
+ unsigned int axis = static_cast<unsigned int>(m_Param.m_Axis);
+ if (m_Param.m_Axis < 0)
{
- dimSizes.push_back(indices.GetShape()[i]);
+ int32_t axis_aux = static_cast<int32_t>(paramsDim) + m_Param.m_Axis;
+ axis = static_cast<unsigned int> (axis_aux);
}
- for (unsigned int i = 1; i < paramsDim; ++i)
+
+ for (unsigned int i = 0; i < axis; ++i)
+ {
+ dimSizes.push_back(params.GetShape()[i]);
+ }
+ for (unsigned int i = axis; i < indicesDim + axis; ++i)
+ {
+ dimSizes.push_back(indices.GetShape()[i - axis]);
+ }
+ for (unsigned int i = 1 + axis; i < paramsDim; ++i)
{
dimSizes.push_back(params.GetShape()[i]);
}
@@ -63,7 +74,7 @@ void GatherLayer::ValidateTensorShapesFromInputs(ShapeInferenceMethod shapeInfer
void GatherLayer::Accept(ILayerVisitor& visitor) const
{
- visitor.VisitGatherLayer(this, GetName());
+ visitor.VisitGatherLayer(this, GetParameters(), GetName());
}
} // namespace armnn