aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTeresa Charlin <teresa.charlinreyes@arm.com>2020-06-29 16:27:03 +0100
committerTeresaARM <teresa.charlinreyes@arm.com>2020-07-01 08:26:47 +0000
commit526647333571169076f5e72c9fb18c71025bf7c0 (patch)
tree6dc559a7b0fae3705172b09a88fa552926652040
parentcbd2c230b7ce5f26e2ccccf36b7ad450f6e1ad09 (diff)
downloadarmnn-526647333571169076f5e72c9fb18c71025bf7c0.tar.gz
IVGCVSW-4903 Connect axis parameter in Gather from android to ACL.
!android-nn-driver:3302 Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com> Change-Id: Ifbc49acb5272f8a36719bb68676e44817190537d
-rw-r--r--include/armnn/Descriptors.hpp22
-rw-r--r--include/armnn/DescriptorsFwd.hpp3
-rw-r--r--include/armnn/ILayerSupport.hpp9
-rw-r--r--include/armnn/ILayerVisitor.hpp11
-rw-r--r--include/armnn/INetwork.hpp28
-rw-r--r--include/armnn/LayerVisitorBase.hpp7
-rw-r--r--src/armnn/LayerSupport.cpp17
-rw-r--r--src/armnn/Network.cpp11
-rw-r--r--src/armnn/Network.hpp6
-rw-r--r--src/armnn/layers/GatherLayer.cpp27
-rw-r--r--src/armnn/layers/GatherLayer.hpp11
-rw-r--r--src/armnn/test/OptimizerTests.cpp5
-rw-r--r--src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp9
-rw-r--r--src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp3
-rw-r--r--src/armnn/test/TestNameOnlyLayerVisitor.cpp3
-rw-r--r--src/armnn/test/TestNameOnlyLayerVisitor.hpp3
-rw-r--r--src/armnnDeserializer/Deserializer.cpp5
-rw-r--r--src/armnnDeserializer/test/DeserializeGather.cpp18
-rw-r--r--src/armnnSerializer/ArmnnSchema.fbs7
-rw-r--r--src/armnnSerializer/Serializer.cpp17
-rw-r--r--src/armnnSerializer/Serializer.hpp7
-rw-r--r--src/armnnSerializer/test/SerializerTests.cpp21
-rwxr-xr-xsrc/armnnTfParser/TfParser.cpp6
-rw-r--r--src/armnnTfParser/test/Gather.cpp26
-rw-r--r--src/backends/backendsCommon/LayerSupportBase.cpp445
-rw-r--r--src/backends/backendsCommon/LayerSupportBase.hpp9
-rw-r--r--src/backends/backendsCommon/WorkloadData.hpp2
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.cpp5
-rw-r--r--src/backends/backendsCommon/test/GatherEndToEndTestImpl.hpp5
-rw-r--r--src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp4
-rw-r--r--src/backends/cl/ClLayerSupport.cpp4
-rw-r--r--src/backends/cl/ClLayerSupport.hpp3
-rw-r--r--src/backends/cl/workloads/ClGatherWorkload.cpp9
-rw-r--r--src/backends/cl/workloads/ClGatherWorkload.hpp3
-rw-r--r--src/backends/neon/NeonLayerSupport.cpp6
-rw-r--r--src/backends/neon/NeonLayerSupport.hpp1
-rw-r--r--src/backends/neon/workloads/NeonGatherWorkload.cpp9
-rw-r--r--src/backends/neon/workloads/NeonGatherWorkload.hpp5
-rw-r--r--src/backends/reference/RefLayerSupport.cpp8
-rw-r--r--src/backends/reference/RefLayerSupport.hpp4
-rw-r--r--src/backends/reference/workloads/Gather.cpp7
-rw-r--r--src/backends/reference/workloads/Gather.hpp5
-rw-r--r--src/backends/reference/workloads/RefGatherWorkload.cpp4
43 files changed, 503 insertions, 317 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index 653e64701a..60aa219638 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -730,6 +730,26 @@ struct FillDescriptor
float m_Value;
};
+/// A GatherDescriptor for the GatherLayer.
+struct GatherDescriptor
+{
+ GatherDescriptor()
+ : m_Axis(0)
+ {}
+
+ GatherDescriptor(int32_t axis)
+ : m_Axis(axis)
+ {}
+
+ bool operator ==(const GatherDescriptor& rhs) const
+ {
+ return m_Axis == rhs.m_Axis;
+ }
+
+ /// The axis in params to gather indices from
+ int32_t m_Axis;
+};
+
/// A ResizeBilinearDescriptor for the ResizeBilinearLayer.
struct ResizeBilinearDescriptor
{
diff --git a/include/armnn/DescriptorsFwd.hpp b/include/armnn/DescriptorsFwd.hpp
index e31fb96aec..fba976c788 100644
--- a/include/armnn/DescriptorsFwd.hpp
+++ b/include/armnn/DescriptorsFwd.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -20,6 +20,7 @@ struct ElementwiseUnaryDescriptor;
struct FakeQuantizationDescriptor;
struct FillDescriptor;
struct FullyConnectedDescriptor;
+struct GatherDescriptor;
struct InstanceNormalizationDescriptor;
struct L2NormalizationDescriptor;
struct LstmDescriptor;
diff --git a/include/armnn/ILayerSupport.hpp b/include/armnn/ILayerSupport.hpp
index 33389eb25f..889b811903 100644
--- a/include/armnn/ILayerSupport.hpp
+++ b/include/armnn/ILayerSupport.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -173,11 +173,18 @@ public:
const FullyConnectedDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+ ARMNN_DEPRECATED_MSG("Use IsGatherSupported with descriptor instead")
virtual bool IsGatherSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+ virtual bool IsGatherSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ const GatherDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+
ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead")
virtual bool IsGreaterSupported(const TensorInfo& input0,
const TensorInfo& input1,
diff --git a/include/armnn/ILayerVisitor.hpp b/include/armnn/ILayerVisitor.hpp
index aa5bdba33c..9b3998db9a 100644
--- a/include/armnn/ILayerVisitor.hpp
+++ b/include/armnn/ILayerVisitor.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -214,9 +214,18 @@ public:
/// Function a Gather layer should call back to when its Accept(ILayerVisitor&) function is invoked.
/// @param layer - pointer to the layer which is calling back to this visit function.
/// @param name - Optional name for the layer.
+ ARMNN_DEPRECATED_MSG("Use VisitGatherLayer with descriptor instead")
virtual void VisitGatherLayer(const IConnectableLayer* layer,
const char* name = nullptr) = 0;
+ /// Function a Gather layer should call back to when its Accept(ILayerVisitor&) function is invoked.
+ /// @param layer - pointer to the layer which is calling back to this visit function.
+ /// @param gatherDescriptor - Parameters for the gather operation.
+ /// @param name - Optional name for the layer.
+ virtual void VisitGatherLayer(const IConnectableLayer* layer,
+ const GatherDescriptor& gatherDescriptor,
+ const char* name = nullptr) = 0;
+
/// Function a Greater layer should call back to when its Accept(ILayerVisitor&) function is invoked.
/// @param layer - pointer to the layer which is calling back to this visit function.
/// @param name - Optional name for the layer.
diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp
index 49cd582e67..8e7a4437c8 100644
--- a/include/armnn/INetwork.hpp
+++ b/include/armnn/INetwork.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -128,7 +128,7 @@ public:
/// Add a Comparison layer to the network.
/// @param name - Optional name for the layer.
/// @param desc - Descriptor for the comparison operation.
- /// @ return - Interface for configuring the layer.
+ /// @return - Interface for configuring the layer.
virtual IConnectableLayer* AddComparisonLayer(const ComparisonDescriptor& comparisonDescriptor,
const char* name = nullptr) = 0;
@@ -324,7 +324,7 @@ public:
/// Add absolute layer to the network.
/// @param name - Optional name for the layer.
- /// @ return - Interface for configuring the layer.
+ /// @return - Interface for configuring the layer.
ARMNN_DEPRECATED_MSG("Use AddElementwiseUnaryLayer instead")
virtual IConnectableLayer* AddAbsLayer(const char* name = nullptr) = 0;
@@ -453,13 +453,13 @@ public:
/// Add a Maximum layer to the network.
/// @param name - Optional name for the layer.
- /// @ return - Interface for configuring the layer.
+ /// @return - Interface for configuring the layer.
virtual IConnectableLayer* AddMaximumLayer(const char* name = nullptr) = 0;
/// Add a Mean layer to the network.
/// @param meanDescriptor - Parameters for the mean operation.
/// @param name - Optional name for the layer.
- /// @ return - Interface for configuring the layer.
+ /// @return - Interface for configuring the layer.
virtual IConnectableLayer* AddMeanLayer(const MeanDescriptor& meanDescriptor, const char* name = nullptr) = 0;
/// Adds a fully pad layer to the network.
@@ -485,32 +485,40 @@ public:
/// Add a Minimum layer to the network.
/// @param name - Optional name for the layer.
- /// @ return - Interface for configuring the layer.
+ /// @return - Interface for configuring the layer.
virtual IConnectableLayer* AddMinimumLayer(const char* name = nullptr) = 0;
/// Add a Greater layer to the network.
/// @param name - Optional name for the layer.
- /// @ return - Interface for configuring the layer.
+ /// @return - Interface for configuring the layer.
ARMNN_DEPRECATED_MSG("Use AddComparisonLayer instead")
virtual IConnectableLayer* AddGreaterLayer(const char* name = nullptr) = 0;
/// Add a Equal layer to the network.
/// @param name - Optional name for the layer.
- /// @ return - Interface for configuring the layer.
+ /// @return - Interface for configuring the layer.
ARMNN_DEPRECATED_MSG("Use AddComparisonLayer instead")
virtual IConnectableLayer* AddEqualLayer(const char* name = nullptr) = 0;
/// Add Reciprocal of square root layer to the network.
/// @param name - Optional name for the layer.
- /// @ return - Interface for configuring the layer.
+ /// @return - Interface for configuring the layer.
ARMNN_DEPRECATED_MSG("Use AddElementwiseUnaryLayer instead")
virtual IConnectableLayer* AddRsqrtLayer(const char* name = nullptr) = 0;
/// Add Gather layer to the network.
/// @param name - Optional name for the layer.
- /// @ return - Interface for configuring the layer.
+ /// @return - Interface for configuring the layer.
+ ARMNN_DEPRECATED_MSG("Use AddGatherLayer with descriptor instead")
virtual IConnectableLayer* AddGatherLayer(const char* name = nullptr) = 0;
+ /// Add Gather layer to the network.
+ /// @param descriptor - Description of the gather layer.
+ /// @param name - Optional name for the layer.
+ /// @return - Interface for configuring the layer.
+ virtual IConnectableLayer* AddGatherLayer(const GatherDescriptor& descriptor,
+ const char* name = nullptr) = 0;
+
/// Adds a switch layer to the network.
/// @param name - Optional name for the layer.
/// @return - Interface for configuring the layer.
diff --git a/include/armnn/LayerVisitorBase.hpp b/include/armnn/LayerVisitorBase.hpp
index 0dc5e545e3..93ba7fe287 100644
--- a/include/armnn/LayerVisitorBase.hpp
+++ b/include/armnn/LayerVisitorBase.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -114,9 +114,14 @@ public:
const Optional<ConstTensor>&,
const char*) override { DefaultPolicy::Apply(__func__); }
+ ARMNN_DEPRECATED_MSG("Use VisitGatherLayer with descriptor instead")
void VisitGatherLayer(const IConnectableLayer*,
const char*) override { DefaultPolicy::Apply(__func__); }
+ void VisitGatherLayer(const IConnectableLayer*,
+ const GatherDescriptor&,
+ const char*) override { DefaultPolicy::Apply(__func__); }
+
void VisitGreaterLayer(const IConnectableLayer*,
const char*) override { DefaultPolicy::Apply(__func__); }
diff --git a/src/armnn/LayerSupport.cpp b/src/armnn/LayerSupport.cpp
index fe5b542867..197e1afe18 100644
--- a/src/armnn/LayerSupport.cpp
+++ b/src/armnn/LayerSupport.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -305,6 +305,7 @@ bool IsFullyConnectedSupported(const BackendId& backend,
FORWARD_LAYER_SUPPORT_FUNC(backend, IsFullyConnectedSupported, input, output, weights, biases, descriptor);
}
+ARMNN_DEPRECATED_MSG("Use IsGatherSupported with descriptor instead")
bool IsGatherSupported(const BackendId& backend,
const TensorInfo& input0,
const TensorInfo& input1,
@@ -312,7 +313,19 @@ bool IsGatherSupported(const BackendId& backend,
char* reasonIfUnsupported,
size_t reasonIfUnsupportedMaxLength)
{
- FORWARD_LAYER_SUPPORT_FUNC(backend, IsGatherSupported, input0, input1, output);
+ const GatherDescriptor descriptor{};
+ FORWARD_LAYER_SUPPORT_FUNC(backend, IsGatherSupported, input0, input1, output, descriptor);
+}
+
+bool IsGatherSupported(const BackendId& backend,
+ const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ const GatherDescriptor& descriptor,
+ char* reasonIfUnsupported,
+ size_t reasonIfUnsupportedMaxLength)
+{
+ FORWARD_LAYER_SUPPORT_FUNC(backend, IsGatherSupported, input0, input1, output, descriptor);
}
bool IsGreaterSupported(const BackendId& backend,
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp
index 180c00b0c3..6c7314feb2 100644
--- a/src/armnn/Network.cpp
+++ b/src/armnn/Network.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -1724,7 +1724,14 @@ IConnectableLayer* Network::AddRsqrtLayer(const char * name)
IConnectableLayer* Network::AddGatherLayer(const char* name)
{
- return m_Graph->AddLayer<GatherLayer>(name);
+ GatherDescriptor gatherDescriptor{};
+ return AddGatherLayer(gatherDescriptor, name);
+}
+
+IConnectableLayer* Network::AddGatherLayer(const GatherDescriptor& gatherDescriptor,
+ const char* name)
+{
+ return m_Graph->AddLayer<GatherLayer>(gatherDescriptor, name);
}
IConnectableLayer* Network::AddMergeLayer(const char* name)
diff --git a/src/armnn/Network.hpp b/src/armnn/Network.hpp
index cac2b3a0e6..53bf3115f1 100644
--- a/src/armnn/Network.hpp
+++ b/src/armnn/Network.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -116,8 +116,12 @@ public:
const ConstTensor& biases,
const char* name = nullptr) override;
+ ARMNN_DEPRECATED_MSG("This AddGatherLayer overload is deprecated")
IConnectableLayer* AddGatherLayer(const char* name = nullptr) override;
+ IConnectableLayer* AddGatherLayer(const GatherDescriptor& gatherDescriptor,
+ const char* name = nullptr) override;
+
IConnectableLayer* AddPermuteLayer(const PermuteDescriptor& permuteDescriptor,
const char* name = nullptr) override;
diff --git a/src/armnn/layers/GatherLayer.cpp b/src/armnn/layers/GatherLayer.cpp
index a99913073f..3e85d25dac 100644
--- a/src/armnn/layers/GatherLayer.cpp
+++ b/src/armnn/layers/GatherLayer.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -13,8 +13,8 @@
namespace armnn
{
-GatherLayer::GatherLayer(const char* name)
- : Layer(2, 1, LayerType::Gather, name)
+GatherLayer::GatherLayer(const GatherDescriptor& param, const char* name)
+ : LayerWithParameters(2, 1, LayerType::Gather, param, name)
{
}
@@ -26,7 +26,7 @@ std::unique_ptr<IWorkload> GatherLayer::CreateWorkload(const armnn::IWorkloadFac
GatherLayer* GatherLayer::Clone(Graph& graph) const
{
- return CloneBase<GatherLayer>(graph, GetName());
+ return CloneBase<GatherLayer>(graph, m_Param, GetName());
}
void GatherLayer::ValidateTensorShapesFromInputs(ShapeInferenceMethod shapeInferenceMethod)
@@ -44,11 +44,22 @@ void GatherLayer::ValidateTensorShapesFromInputs(ShapeInferenceMethod shapeInfer
std::vector<unsigned int> dimSizes;
- for (unsigned int i = 0; i < indicesDim; ++i)
+ unsigned int axis = static_cast<unsigned int>(m_Param.m_Axis);
+ if (m_Param.m_Axis < 0)
{
- dimSizes.push_back(indices.GetShape()[i]);
+ int32_t axis_aux = static_cast<int32_t>(paramsDim) + m_Param.m_Axis;
+ axis = static_cast<unsigned int> (axis_aux);
}
- for (unsigned int i = 1; i < paramsDim; ++i)
+
+ for (unsigned int i = 0; i < axis; ++i)
+ {
+ dimSizes.push_back(params.GetShape()[i]);
+ }
+ for (unsigned int i = axis; i < indicesDim + axis; ++i)
+ {
+ dimSizes.push_back(indices.GetShape()[i - axis]);
+ }
+ for (unsigned int i = 1 + axis; i < paramsDim; ++i)
{
dimSizes.push_back(params.GetShape()[i]);
}
@@ -63,7 +74,7 @@ void GatherLayer::ValidateTensorShapesFromInputs(ShapeInferenceMethod shapeInfer
void GatherLayer::Accept(ILayerVisitor& visitor) const
{
- visitor.VisitGatherLayer(this, GetName());
+ visitor.VisitGatherLayer(this, GetParameters(), GetName());
}
} // namespace armnn
diff --git a/src/armnn/layers/GatherLayer.hpp b/src/armnn/layers/GatherLayer.hpp
index 598ca44dc4..d8737adbee 100644
--- a/src/armnn/layers/GatherLayer.hpp
+++ b/src/armnn/layers/GatherLayer.hpp
@@ -1,17 +1,17 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
-#include "Layer.hpp"
+#include "LayerWithParameters.hpp"
namespace armnn
{
/// This layer represents a Gather operator.
-class GatherLayer : public Layer
+class GatherLayer : public LayerWithParameters<GatherDescriptor>
{
public:
/// Makes a workload for the Gather type.
@@ -24,7 +24,7 @@ public:
/// @param [in] graph The graph into which this layer is being cloned.
GatherLayer* Clone(Graph& graph) const override;
- /// Check if the input tensor shape(s)
+ /// Check if the input tensor shape(s).
/// will lead to a valid configuration of @ref GatherLayer.
/// @param [in] shapeInferenceMethod Indicates if output shape shall be overwritten or just validate.
void ValidateTensorShapesFromInputs(
@@ -34,8 +34,9 @@ public:
protected:
/// Constructor to create a GatherLayer.
+ /// @param [in] param GatherDescriptor to configure the stack operation.
/// @param [in] name Optional name for the layer.
- GatherLayer(const char* name);
+ GatherLayer(const GatherDescriptor& param, const char* name);
/// Default destructor
~GatherLayer() = default;
diff --git a/src/armnn/test/OptimizerTests.cpp b/src/armnn/test/OptimizerTests.cpp
index 65ea91d402..3af50ecf3a 100644
--- a/src/armnn/test/OptimizerTests.cpp
+++ b/src/armnn/test/OptimizerTests.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -458,7 +458,8 @@ void CreateGatherGraph(Graph& graph, const armnn::TensorInfo& paramsInfo, const
Layer* input1 = graph.AddLayer<InputLayer>(1, "indices");
input1->GetOutputSlot().SetTensorInfo(indicesInfo);
- GatherLayer* layer = graph.AddLayer<GatherLayer>("gather");
+ GatherDescriptor descriptor;
+ GatherLayer* layer = graph.AddLayer<GatherLayer>(descriptor, "gather");
layer->GetOutputSlot().SetTensorInfo(outputInfo);
Layer* output = graph.AddLayer<OutputLayer>(0, "output");
diff --git a/src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp b/src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp
index e07e497ab8..6ab9f9e1e4 100644
--- a/src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp
+++ b/src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "TestNameAndDescriptorLayerVisitor.hpp"
@@ -98,6 +98,12 @@ armnn::FillDescriptor GetDescriptor<armnn::FillDescriptor>()
}
template<>
+armnn::GatherDescriptor GetDescriptor<armnn::GatherDescriptor>()
+{
+ return armnn::GatherDescriptor();
+}
+
+template<>
armnn::InstanceNormalizationDescriptor GetDescriptor<armnn::InstanceNormalizationDescriptor>()
{
armnn::InstanceNormalizationDescriptor descriptor;
@@ -271,6 +277,7 @@ TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(Comparison)
TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(Concat)
TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(ElementwiseUnary)
TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(Fill)
+TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(Gather)
TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(InstanceNormalization)
TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(L2Normalization)
TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(LogSoftmax)
diff --git a/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp b/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp
index c8df505db0..df0e055157 100644
--- a/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp
+++ b/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -50,6 +50,7 @@ DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(Concat)
DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(DepthToSpace)
DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(ElementwiseUnary)
DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(Fill)
+DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(Gather)
DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(InstanceNormalization)
DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(L2Normalization)
DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(LogSoftmax)
diff --git a/src/armnn/test/TestNameOnlyLayerVisitor.cpp b/src/armnn/test/TestNameOnlyLayerVisitor.cpp
index 0653b39e58..945afa8ff5 100644
--- a/src/armnn/test/TestNameOnlyLayerVisitor.cpp
+++ b/src/armnn/test/TestNameOnlyLayerVisitor.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -42,7 +42,6 @@ TEST_SUITE_NAME_ONLY_LAYER_VISITOR(Addition)
TEST_SUITE_NAME_ONLY_LAYER_VISITOR(Dequantize)
TEST_SUITE_NAME_ONLY_LAYER_VISITOR(Division)
TEST_SUITE_NAME_ONLY_LAYER_VISITOR(Floor)
-TEST_SUITE_NAME_ONLY_LAYER_VISITOR(Gather)
TEST_SUITE_NAME_ONLY_LAYER_VISITOR(Maximum)
TEST_SUITE_NAME_ONLY_LAYER_VISITOR(Merge)
TEST_SUITE_NAME_ONLY_LAYER_VISITOR(Minimum)
diff --git a/src/armnn/test/TestNameOnlyLayerVisitor.hpp b/src/armnn/test/TestNameOnlyLayerVisitor.hpp
index 84dfdd6539..0e1ea8eac7 100644
--- a/src/armnn/test/TestNameOnlyLayerVisitor.hpp
+++ b/src/armnn/test/TestNameOnlyLayerVisitor.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -29,7 +29,6 @@ DECLARE_TEST_NAME_ONLY_LAYER_VISITOR_CLASS(Addition)
DECLARE_TEST_NAME_ONLY_LAYER_VISITOR_CLASS(Dequantize)
DECLARE_TEST_NAME_ONLY_LAYER_VISITOR_CLASS(Division)
DECLARE_TEST_NAME_ONLY_LAYER_VISITOR_CLASS(Floor)
-DECLARE_TEST_NAME_ONLY_LAYER_VISITOR_CLASS(Gather)
DECLARE_TEST_NAME_ONLY_LAYER_VISITOR_CLASS(Maximum)
DECLARE_TEST_NAME_ONLY_LAYER_VISITOR_CLASS(Merge)
DECLARE_TEST_NAME_ONLY_LAYER_VISITOR_CLASS(Minimum)
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp
index f59c757f4b..31fae2af86 100644
--- a/src/armnnDeserializer/Deserializer.cpp
+++ b/src/armnnDeserializer/Deserializer.cpp
@@ -2427,8 +2427,11 @@ void Deserializer::ParseGather(GraphPtr graph, unsigned int layerIndex)
Deserializer::TensorRawPtrVector outputs = GetOutputs(graph, layerIndex);
CHECK_VALID_SIZE(outputs.size(), 1);
+ armnn::GatherDescriptor descriptor;
+ descriptor.m_Axis = graph->layers()->Get(layerIndex)->layer_as_GatherLayer()->descriptor()->axis();
+
auto layerName = GetLayerName(graph, layerIndex);
- IConnectableLayer* layer = m_Network->AddGatherLayer(layerName.c_str());
+ IConnectableLayer* layer = m_Network->AddGatherLayer(descriptor, layerName.c_str());
armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
diff --git a/src/armnnDeserializer/test/DeserializeGather.cpp b/src/armnnDeserializer/test/DeserializeGather.cpp
index 3fdcf51aed..0f75db4abb 100644
--- a/src/armnnDeserializer/test/DeserializeGather.cpp
+++ b/src/armnnDeserializer/test/DeserializeGather.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -14,10 +14,11 @@ BOOST_AUTO_TEST_SUITE(Deserializer)
struct GatherFixture : public ParserFlatbuffersSerializeFixture
{
- explicit GatherFixture(const std::string &inputShape,
- const std::string &indicesShape,
- const std::string &input1Content,
- const std::string &outputShape,
+ explicit GatherFixture(const std::string& inputShape,
+ const std::string& indicesShape,
+ const std::string& input1Content,
+ const std::string& outputShape,
+ const std::string& axis,
const std::string dataType,
const std::string constDataType)
{
@@ -94,7 +95,10 @@ struct GatherFixture : public ParserFlatbuffersSerializeFixture
dimensions: )" + outputShape + R"(,
dataType: )" + dataType + R"(
- }}]}
+ }}]},
+ descriptor: {
+ axis: )" + axis + R"(
+ }
}},
{
layer_type: "OutputLayer",
@@ -127,7 +131,7 @@ struct GatherFixture : public ParserFlatbuffersSerializeFixture
struct SimpleGatherFixtureFloat32 : GatherFixture
{
SimpleGatherFixtureFloat32() : GatherFixture("[ 3, 2, 3 ]", "[ 2, 3 ]", "[1, 2, 1, 2, 1, 0]",
- "[ 2, 3, 2, 3 ]", "Float32", "IntData") {}
+ "[ 2, 3, 2, 3 ]", "0", "Float32", "IntData") {}
};
BOOST_FIXTURE_TEST_CASE(GatherFloat32, SimpleGatherFixtureFloat32)
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs
index 18415ce785..6a388db699 100644
--- a/src/armnnSerializer/ArmnnSchema.fbs
+++ b/src/armnnSerializer/ArmnnSchema.fbs
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -312,6 +312,11 @@ table FullyConnectedDescriptor {
table GatherLayer {
base:LayerBase;
+ descriptor:GatherDescriptor;
+}
+
+table GatherDescriptor {
+ axis:int = 0;
}
/// @deprecated Use ComparisonLayer instead
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index 17076c62ab..6555a34be7 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -493,12 +493,23 @@ void SerializerVisitor::VisitFloorLayer(const armnn::IConnectableLayer *layer, c
CreateAnyLayer(flatBufferFloorLayer.o, serializer::Layer::Layer_FloorLayer);
}
-void SerializerVisitor::VisitGatherLayer(const armnn::IConnectableLayer* layer, const char* name)
+void SerializerVisitor::VisitGatherLayer(const armnn::IConnectableLayer* layer,
+ const char* name)
+{
+ armnn::GatherDescriptor gatherDescriptor{};
+ VisitGatherLayer(layer, gatherDescriptor, name);
+}
+
+void SerializerVisitor::VisitGatherLayer(const armnn::IConnectableLayer* layer,
+ const armnn::GatherDescriptor& gatherDescriptor,
+ const char* name)
{
IgnoreUnused(name);
+ auto fbGatherDescriptor = CreateGatherDescriptor(m_flatBufferBuilder,
+ gatherDescriptor.m_Axis);
auto fbGatherBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Gather);
- auto flatBufferLayer = serializer::CreateGatherLayer(m_flatBufferBuilder, fbGatherBaseLayer);
+ auto flatBufferLayer = serializer::CreateGatherLayer(m_flatBufferBuilder, fbGatherBaseLayer, fbGatherDescriptor);
CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_GatherLayer);
}
diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp
index 65d87b7cf7..e4104dda8e 100644
--- a/src/armnnSerializer/Serializer.hpp
+++ b/src/armnnSerializer/Serializer.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -134,9 +134,14 @@ public:
const armnn::Optional<armnn::ConstTensor>& biases,
const char* name = nullptr) override;
+ ARMNN_DEPRECATED_MSG("Use VisitGatherLayer with descriptor instead")
void VisitGatherLayer(const armnn::IConnectableLayer* layer,
const char* name = nullptr) override;
+ void VisitGatherLayer(const armnn::IConnectableLayer* layer,
+ const armnn::GatherDescriptor& gatherDescriptor,
+ const char* name = nullptr) override;
+
ARMNN_DEPRECATED_MSG("Use VisitComparisonLayer instead")
void VisitGreaterLayer(const armnn::IConnectableLayer* layer,
const char* name = nullptr) override;
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp
index fa43e09647..088282a18a 100644
--- a/src/armnnSerializer/test/SerializerTests.cpp
+++ b/src/armnnSerializer/test/SerializerTests.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -1333,17 +1333,22 @@ BOOST_AUTO_TEST_CASE(SerializeFullyConnected)
BOOST_AUTO_TEST_CASE(SerializeGather)
{
- class GatherLayerVerifier : public LayerVerifierBase
+ using GatherDescriptor = armnn::GatherDescriptor;
+ class GatherLayerVerifier : public LayerVerifierBaseWithDescriptor<GatherDescriptor>
{
public:
GatherLayerVerifier(const std::string& layerName,
const std::vector<armnn::TensorInfo>& inputInfos,
- const std::vector<armnn::TensorInfo>& outputInfos)
- : LayerVerifierBase(layerName, inputInfos, outputInfos) {}
+ const std::vector<armnn::TensorInfo>& outputInfos,
+ const GatherDescriptor& descriptor)
+ : LayerVerifierBaseWithDescriptor<GatherDescriptor>(layerName, inputInfos, outputInfos, descriptor) {}
- void VisitGatherLayer(const armnn::IConnectableLayer* layer, const char *name) override
+ void VisitGatherLayer(const armnn::IConnectableLayer* layer,
+ const GatherDescriptor& descriptor,
+ const char *name) override
{
VerifyNameAndConnections(layer, name);
+ BOOST_CHECK(descriptor.m_Axis == m_Descriptor.m_Axis);
}
void VisitConstantLayer(const armnn::IConnectableLayer*,
@@ -1355,6 +1360,8 @@ BOOST_AUTO_TEST_CASE(SerializeGather)
armnn::TensorInfo paramsInfo({ 8 }, armnn::DataType::QAsymmU8);
armnn::TensorInfo outputInfo({ 3 }, armnn::DataType::QAsymmU8);
const armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32);
+ GatherDescriptor descriptor;
+ descriptor.m_Axis = 1;
paramsInfo.SetQuantizationScale(1.0f);
paramsInfo.SetQuantizationOffset(0);
@@ -1367,7 +1374,7 @@ BOOST_AUTO_TEST_CASE(SerializeGather)
armnn::IConnectableLayer *const inputLayer = network->AddInputLayer(0);
armnn::IConnectableLayer *const constantLayer =
network->AddConstantLayer(armnn::ConstTensor(indicesInfo, indicesData));
- armnn::IConnectableLayer *const gatherLayer = network->AddGatherLayer(layerName.c_str());
+ armnn::IConnectableLayer *const gatherLayer = network->AddGatherLayer(descriptor, layerName.c_str());
armnn::IConnectableLayer *const outputLayer = network->AddOutputLayer(0);
inputLayer->GetOutputSlot(0).Connect(gatherLayer->GetInputSlot(0));
@@ -1381,7 +1388,7 @@ BOOST_AUTO_TEST_CASE(SerializeGather)
armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
BOOST_CHECK(deserializedNetwork);
- GatherLayerVerifier verifier(layerName, {paramsInfo, indicesInfo}, {outputInfo});
+ GatherLayerVerifier verifier(layerName, {paramsInfo, indicesInfo}, {outputInfo}, descriptor);
deserializedNetwork->Accept(verifier);
}
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index 7a7c5a4375..38202fcf94 100755
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -1853,6 +1853,8 @@ ParsedTfOperationPtr TfParser::ParseGather(const tensorflow::NodeDef& nodeDef,
std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 2);
IOutputSlot& params = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index);
IOutputSlot& indices = inputs[1].m_IndexedValue->ResolveArmnnOutputSlot(inputs[1].m_Index);
+ GatherDescriptor descriptor;
+ descriptor.m_Axis = ReadMandatoryNodeInt32Attribute(nodeDef, "axis");
// Infer shape of output tensor
unsigned int paramsDim = params.GetTensorInfo().GetNumDimensions();
@@ -1874,7 +1876,7 @@ ParsedTfOperationPtr TfParser::ParseGather(const tensorflow::NodeDef& nodeDef,
const TensorInfo inferredOutputInfo(inferredShape, params.GetTensorInfo().GetDataType());
- IConnectableLayer* const layer = m_Network->AddGatherLayer(nodeDef.name().c_str());
+ IConnectableLayer* const layer = m_Network->AddGatherLayer(descriptor, nodeDef.name().c_str());
layer->GetOutputSlot(0).SetTensorInfo(inferredOutputInfo);
params.Connect(layer->GetInputSlot(0));
diff --git a/src/armnnTfParser/test/Gather.cpp b/src/armnnTfParser/test/Gather.cpp
index 8c4b891141..ab5fb7104d 100644
--- a/src/armnnTfParser/test/Gather.cpp
+++ b/src/armnnTfParser/test/Gather.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -38,7 +38,8 @@ struct GatherFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::I
const armnn::TensorShape& inputShape1,
const std::vector<int>& input1Content,
const std::vector<int>& input0Dims,
- const std::vector<int>& input1Dims)
+ const std::vector<int>& input1Dims,
+ int axis = 0)
{
m_Prototext = R"(
node {
@@ -56,6 +57,7 @@ node {
shape {
)";
dimsHelper(input0Dims, m_Prototext);
+
m_Prototext.append(R"(
}
}
@@ -78,6 +80,7 @@ node {
tensor_shape {
)");
dimsHelper(input1Dims, m_Prototext);
+
m_Prototext.append(R"(
}
tensor_content: ")");
@@ -104,8 +107,18 @@ node {
type: DT_FLOAT
}
}
+ attr {
+ key: "axis"
+ value {
+ i: )");
+ m_Prototext += std::to_string(axis);
+
+ m_Prototext.append(R"(
+ }
+ }
}
)");
+
Setup({ { "input0", inputShape0 },
{ "input1", inputShape1 } },
{ "output" });
@@ -121,7 +134,8 @@ struct GatherFixture1DParams1DIndices : public GatherFixture
{ 4, 0, 0, 0 },
{ 0, 2, 1, 3 },
{ 4 },
- { 4 }) {}
+ { 4 },
+ 0) {}
};
struct GatherFixture1DParamsMultiDimIndices : public GatherFixture
@@ -131,7 +145,8 @@ struct GatherFixture1DParamsMultiDimIndices : public GatherFixture
{ 2, 2, 1, 1 },
{ 0, 1, 1, 3 },
{ 4 },
- { 2, 2 }) {}
+ { 2, 2 },
+ 0) {}
};
struct GatherFixtureMultiDimParamMultiDimIndices : public GatherFixture
@@ -141,7 +156,8 @@ struct GatherFixtureMultiDimParamMultiDimIndices : public GatherFixture
{ 2, 1, 4 },
{ 1, 3, 0, 2 },
{ 5, 2 },
- { 2, 2 }) {}
+ { 2, 2 },
+ 0) {}
};
BOOST_FIXTURE_TEST_CASE(ParseGather1DParams1DIndices, GatherFixture1DParams1DIndices)
diff --git a/src/backends/backendsCommon/LayerSupportBase.cpp b/src/backends/backendsCommon/LayerSupportBase.cpp
index e509a7b929..52e615a2d9 100644
--- a/src/backends/backendsCommon/LayerSupportBase.cpp
+++ b/src/backends/backendsCommon/LayerSupportBase.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -37,177 +37,177 @@ bool DefaultLayerSupport(const char* func,
namespace armnn
{
-bool LayerSupportBase::IsAbsSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
+bool LayerSupportBase::IsAbsSupported(const TensorInfo&, // input
+ const TensorInfo&, // output
Optional<std::string &> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsActivationSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
- const ActivationDescriptor& /*descriptor*/,
+bool LayerSupportBase::IsActivationSupported(const TensorInfo&, // input
+ const TensorInfo&, //output
+ const ActivationDescriptor&, // descriptor
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsAdditionSupported(const TensorInfo& /*input0*/,
- const TensorInfo& /*input1*/,
- const TensorInfo& /*output*/,
+bool LayerSupportBase::IsAdditionSupported(const TensorInfo&, // input0
+ const TensorInfo&, // input1
+ const TensorInfo&, // output
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsArgMinMaxSupported(const armnn::TensorInfo &/*input*/,
- const armnn::TensorInfo &/*output*/,
- const armnn::ArgMinMaxDescriptor& /*descriptor*/,
+bool LayerSupportBase::IsArgMinMaxSupported(const armnn::TensorInfo&, // input
+ const armnn::TensorInfo&, // output
+ const armnn::ArgMinMaxDescriptor&, // descriptor
armnn::Optional<std::string &> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsBatchNormalizationSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
- const TensorInfo& /*mean*/,
- const TensorInfo& /*var*/,
- const TensorInfo& /*beta*/,
- const TensorInfo& /*gamma*/,
- const BatchNormalizationDescriptor& /*descriptor*/,
+bool LayerSupportBase::IsBatchNormalizationSupported(const TensorInfo&, //input
+ const TensorInfo&, // output
+ const TensorInfo&, //mean
+ const TensorInfo&, //var
+ const TensorInfo&, //beta
+ const TensorInfo&, //gamma
+ const BatchNormalizationDescriptor&, // descriptor
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsBatchToSpaceNdSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
- const BatchToSpaceNdDescriptor& /*descriptor*/,
+bool LayerSupportBase::IsBatchToSpaceNdSupported(const TensorInfo&, // input
+ const TensorInfo&, // output
+ const BatchToSpaceNdDescriptor&, //descriptor
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsComparisonSupported(const TensorInfo& /*input0*/,
- const TensorInfo& /*input1*/,
- const TensorInfo& /*output*/,
- const ComparisonDescriptor& /*descriptor*/,
+bool LayerSupportBase::IsComparisonSupported(const TensorInfo&, // input0
+ const TensorInfo&, // input1
+ const TensorInfo&, // output
+ const ComparisonDescriptor&, // descriptor
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsConcatSupported(const std::vector<const TensorInfo*> /*inputs*/,
- const TensorInfo& /*output*/,
- const OriginsDescriptor& /*descriptor*/,
+bool LayerSupportBase::IsConcatSupported(const std::vector<const TensorInfo*>, // inputs
+ const TensorInfo&, // output
+ const OriginsDescriptor&, // descriptor
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsConstantSupported(const TensorInfo& /*output*/,
+bool LayerSupportBase::IsConstantSupported(const TensorInfo&, // output
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsConvertBf16ToFp32Supported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
+bool LayerSupportBase::IsConvertBf16ToFp32Supported(const TensorInfo&, // input
+ const TensorInfo&, // output
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsConvertFp16ToFp32Supported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
+bool LayerSupportBase::IsConvertFp16ToFp32Supported(const TensorInfo&, // input
+ const TensorInfo&, // output
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsConvertFp32ToBf16Supported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
+bool LayerSupportBase::IsConvertFp32ToBf16Supported(const TensorInfo&, // input
+ const TensorInfo&, // output
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsConvertFp32ToFp16Supported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
+bool LayerSupportBase::IsConvertFp32ToFp16Supported(const TensorInfo&, // input
+ const TensorInfo&, // output
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsConvolution2dSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
- const Convolution2dDescriptor& /*descriptor*/,
- const TensorInfo& /*weights*/,
- const Optional<TensorInfo>& /*biases*/,
+bool LayerSupportBase::IsConvolution2dSupported(const TensorInfo&, // input
+ const TensorInfo&, // output
+ const Convolution2dDescriptor&, // descriptor
+ const TensorInfo&, // weights
+ const Optional<TensorInfo>&, // biases
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsDebugSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
+bool LayerSupportBase::IsDebugSupported(const TensorInfo&, // input
+ const TensorInfo&, // output
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsDepthToSpaceSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
- const DepthToSpaceDescriptor& /*descriptor*/,
+bool LayerSupportBase::IsDepthToSpaceSupported(const TensorInfo&, // input
+ const TensorInfo&, // output
+ const DepthToSpaceDescriptor&, // descriptor
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsDepthwiseConvolutionSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
- const DepthwiseConvolution2dDescriptor& /*descriptor*/,
- const TensorInfo& /*weights*/,
- const Optional<TensorInfo>& /*biases*/,
+bool LayerSupportBase::IsDepthwiseConvolutionSupported(const TensorInfo&, //input
+ const TensorInfo&, //output
+ const DepthwiseConvolution2dDescriptor&, // descriptor
+ const TensorInfo&, // weights
+ const Optional<TensorInfo>&, // biases
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsDequantizeSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
+bool LayerSupportBase::IsDequantizeSupported(const TensorInfo&, // input
+ const TensorInfo&, // output
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsDetectionPostProcessSupported(const TensorInfo& /*boxEncodings*/,
- const TensorInfo& /*scores*/,
- const TensorInfo& /*anchors*/,
- const TensorInfo& /*detectionBoxes*/,
- const TensorInfo& /*detectionClasses*/,
- const TensorInfo& /*detectionScores*/,
- const TensorInfo& /*numDetections*/,
- const DetectionPostProcessDescriptor& /*descriptor*/,
+bool LayerSupportBase::IsDetectionPostProcessSupported(const TensorInfo&, // boxEncodings
+ const TensorInfo&, // scores
+ const TensorInfo&, // anchors
+ const TensorInfo&, // detectionBoxes
+ const TensorInfo&, // detectionClasses
+ const TensorInfo&, // detectionScores
+ const TensorInfo&, // numDetections
+ const DetectionPostProcessDescriptor&, //descriptor
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
- const DepthwiseConvolution2dDescriptor& /*descriptor*/,
- const TensorInfo& /*weights*/,
- const Optional<TensorInfo>& /*biases*/,
+bool LayerSupportBase::IsDilatedDepthwiseConvolutionSupported(const TensorInfo&, // input
+ const TensorInfo&, // output
+ const DepthwiseConvolution2dDescriptor&, // descriptor
+ const TensorInfo&,// weights
+ const Optional<TensorInfo>&, // biases
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsDivisionSupported(const TensorInfo& /*input0*/,
- const TensorInfo& /*input1*/,
- const TensorInfo& /*output*/,
+bool LayerSupportBase::IsDivisionSupported(const TensorInfo&, // input0
+ const TensorInfo&, // input1
+ const TensorInfo&, // output
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
@@ -233,139 +233,148 @@ bool LayerSupportBase::IsElementwiseUnarySupported(const TensorInfo& input,
return false;
}
-bool LayerSupportBase::IsEqualSupported(const armnn::TensorInfo& /*input0*/,
- const armnn::TensorInfo& /*input1*/,
- const armnn::TensorInfo& /*output*/,
+bool LayerSupportBase::IsEqualSupported(const armnn::TensorInfo&, // input0
+ const armnn::TensorInfo&, // input1
+ const armnn::TensorInfo&, // output
armnn::Optional<std::string &> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsFakeQuantizationSupported(const TensorInfo& /*input*/,
- const FakeQuantizationDescriptor& /*descriptor*/,
+bool LayerSupportBase::IsFakeQuantizationSupported(const TensorInfo&, // input
+ const FakeQuantizationDescriptor&, // descriptor
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsFillSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
- const FillDescriptor& /*descriptor*/,
+bool LayerSupportBase::IsFillSupported(const TensorInfo&, // input
+ const TensorInfo&, // output
+ const FillDescriptor&, // descriptor
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsFloorSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
+bool LayerSupportBase::IsFloorSupported(const TensorInfo&, // input
+ const TensorInfo&, // output
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsFullyConnectedSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
- const TensorInfo& /*weights*/,
- const TensorInfo& /*biases*/,
- const FullyConnectedDescriptor& /*descriptor*/,
+bool LayerSupportBase::IsFullyConnectedSupported(const TensorInfo&, // input
+ const TensorInfo&, // output
+ const TensorInfo&, // weights
+ const TensorInfo&, // biases
+ const FullyConnectedDescriptor&, // descriptor
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsGatherSupported(const armnn::TensorInfo& /*input0*/,
- const armnn::TensorInfo& /*input1*/,
- const armnn::TensorInfo& /*output*/,
+bool LayerSupportBase::IsGatherSupported(const armnn::TensorInfo&, // input0
+ const armnn::TensorInfo&, // input1
+ const armnn::TensorInfo&, // output
armnn::Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsGreaterSupported(const TensorInfo& /*input0*/,
- const TensorInfo& /*input1*/,
- const TensorInfo& /*output*/,
+bool LayerSupportBase::IsGatherSupported(const armnn::TensorInfo&, // input0
+ const armnn::TensorInfo&, // input1
+ const armnn::TensorInfo&, // output
+ const GatherDescriptor&, // descriptor
+ armnn::Optional<std::string&> reasonIfUnsupported) const
+{
+ return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
+}
+
+bool LayerSupportBase::IsGreaterSupported(const TensorInfo&, // input0
+ const TensorInfo&, // input1
+ const TensorInfo&, // output
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsInputSupported(const TensorInfo& /*input*/,
+bool LayerSupportBase::IsInputSupported(const TensorInfo&, // input
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsInstanceNormalizationSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
- const InstanceNormalizationDescriptor& /*descriptor*/,
+bool LayerSupportBase::IsInstanceNormalizationSupported(const TensorInfo&, // input
+ const TensorInfo&, // output
+ const InstanceNormalizationDescriptor&, // descriptor
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsL2NormalizationSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
- const L2NormalizationDescriptor& /*descriptor*/,
+bool LayerSupportBase::IsL2NormalizationSupported(const TensorInfo&, // input
+ const TensorInfo&, // output
+ const L2NormalizationDescriptor&, // descriptor
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsLogSoftmaxSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
- const LogSoftmaxDescriptor& /*descriptor*/,
+bool LayerSupportBase::IsLogSoftmaxSupported(const TensorInfo&, // input
+ const TensorInfo&, // output
+ const LogSoftmaxDescriptor&, // descriptor
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsLstmSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*outputStateIn*/,
- const TensorInfo& /*cellStateIn*/,
- const TensorInfo& /*scratchBuffer*/,
- const TensorInfo& /*outputStateOut*/,
- const TensorInfo& /*cellStateOut*/,
- const TensorInfo& /*output*/,
- const LstmDescriptor& /*descriptor*/,
- const LstmInputParamsInfo& /*paramsInfo*/,
+bool LayerSupportBase::IsLstmSupported(const TensorInfo&, // input
+ const TensorInfo&, // outputStateIn
+ const TensorInfo&, // cellStateIn
+ const TensorInfo&, // scratchBuffer
+ const TensorInfo&, // outputStateOut
+ const TensorInfo&, // cellStateOut
+ const TensorInfo&, // output
+ const LstmDescriptor&, // descriptor
+ const LstmInputParamsInfo&, // paramsInfo
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsMaximumSupported(const TensorInfo& /*input0*/,
- const TensorInfo& /*input1*/,
- const TensorInfo& /*output*/,
+bool LayerSupportBase::IsMaximumSupported(const TensorInfo&, // input0
+ const TensorInfo&, // input1
+ const TensorInfo&, // output
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsMeanSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
- const MeanDescriptor& /*descriptor*/,
+bool LayerSupportBase::IsMeanSupported(const TensorInfo&, // input
+ const TensorInfo&, // output
+ const MeanDescriptor&, // descriptor
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsMemCopySupported(const armnn::TensorInfo& /*input*/,
- const armnn::TensorInfo& /*output*/,
- armnn::Optional<std::string &> /*reasonIfUnsupported*/) const
+bool LayerSupportBase::IsMemCopySupported(const armnn::TensorInfo&, // input
+ const armnn::TensorInfo&, // output
+ armnn::Optional<std::string &> ) const // reasonIfUnsupported
{
return true;
}
-bool LayerSupportBase::IsMemImportSupported(const armnn::TensorInfo& /*input*/,
- const armnn::TensorInfo& /*output*/,
- armnn::Optional<std::string &> /*reasonIfUnsupported*/) const
+bool LayerSupportBase::IsMemImportSupported(const armnn::TensorInfo&, // input
+ const armnn::TensorInfo&, // output
+ armnn::Optional<std::string &> ) const // reasonIfUnsupported
{
return true;
}
-bool LayerSupportBase::IsMergeSupported(const TensorInfo& /*input0*/,
- const TensorInfo& /*input1*/,
- const TensorInfo& /*output*/,
+bool LayerSupportBase::IsMergeSupported(const TensorInfo&, // input0
+ const TensorInfo&, // input1
+ const TensorInfo&, // output
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
@@ -379,194 +388,194 @@ bool LayerSupportBase::IsMergerSupported(const std::vector<const TensorInfo*> in
return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
}
-bool LayerSupportBase::IsMinimumSupported(const TensorInfo& /*input0*/,
- const TensorInfo& /*input1*/,
- const TensorInfo& /*output*/,
+bool LayerSupportBase::IsMinimumSupported(const TensorInfo&, // input0
+ const TensorInfo&, // input1
+ const TensorInfo&, // output
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsMultiplicationSupported(const TensorInfo& /*input0*/,
- const TensorInfo& /*input1*/,
- const TensorInfo& /*output*/,
+bool LayerSupportBase::IsMultiplicationSupported(const TensorInfo&, // input0
+ const TensorInfo&, // input1
+ const TensorInfo&, // output
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsNormalizationSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
- const NormalizationDescriptor& /*descriptor*/,
+bool LayerSupportBase::IsNormalizationSupported(const TensorInfo&, // input
+ const TensorInfo&, // output
+ const NormalizationDescriptor&, // descriptor
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsOutputSupported(const TensorInfo& /*output*/,
+bool LayerSupportBase::IsOutputSupported(const TensorInfo&, // output
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsPadSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
- const PadDescriptor& /*descriptor*/,
+bool LayerSupportBase::IsPadSupported(const TensorInfo&, // input
+ const TensorInfo&, // output
+ const PadDescriptor&, // descriptor
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsPermuteSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
- const PermuteDescriptor& /*descriptor*/,
+bool LayerSupportBase::IsPermuteSupported(const TensorInfo&, // input
+ const TensorInfo&, // output
+ const PermuteDescriptor&, // descriptor
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsPooling2dSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
- const Pooling2dDescriptor& /*descriptor*/,
+bool LayerSupportBase::IsPooling2dSupported(const TensorInfo&, // input
+ const TensorInfo&, // output
+ const Pooling2dDescriptor&, // descriptor
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsPreCompiledSupported(const TensorInfo& /*input*/,
- const PreCompiledDescriptor& /*descriptor*/,
+bool LayerSupportBase::IsPreCompiledSupported(const TensorInfo&, // input
+ const PreCompiledDescriptor&, // descriptor
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsPreluSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*alpha*/,
- const TensorInfo& /*output*/,
+bool LayerSupportBase::IsPreluSupported(const TensorInfo&, // input
+ const TensorInfo&, // alpha
+ const TensorInfo&, // output
Optional<std::string &> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsQuantizeSupported(const armnn::TensorInfo& /*input*/,
- const armnn::TensorInfo& /*output*/,
+bool LayerSupportBase::IsQuantizeSupported(const armnn::TensorInfo&, // input
+ const armnn::TensorInfo&, // output
armnn::Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsQLstmSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*previousOutputIn*/,
- const TensorInfo& /*previousCellStateIn*/,
- const TensorInfo& /*outputStateOut*/,
- const TensorInfo& /*cellStateOut*/,
- const TensorInfo& /*output*/,
- const QLstmDescriptor& /*descriptor*/,
- const LstmInputParamsInfo& /*paramsInfo*/,
+bool LayerSupportBase::IsQLstmSupported(const TensorInfo&, // input
+ const TensorInfo&, // previousOutputIn
+ const TensorInfo&, // previousCellStateIn
+ const TensorInfo&, // outputStateOut
+ const TensorInfo&, // cellStateOut
+ const TensorInfo&, // output
+ const QLstmDescriptor&, // descriptor
+ const LstmInputParamsInfo&, // paramsInfo
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsQuantizedLstmSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*previousCellStateIn*/,
- const TensorInfo& /*previousOutputIn*/,
- const TensorInfo& /*cellStateOut*/,
- const TensorInfo& /*output*/,
- const QuantizedLstmInputParamsInfo& /*paramsInfo*/,
+bool LayerSupportBase::IsQuantizedLstmSupported(const TensorInfo&, // input
+ const TensorInfo&, // previousCellStateIn
+ const TensorInfo&, // previousOutputIn
+ const TensorInfo&, // cellStateOut
+ const TensorInfo&, // output
+ const QuantizedLstmInputParamsInfo&, // paramsInfo
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsReshapeSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
- const ReshapeDescriptor& /*descriptor*/,
+bool LayerSupportBase::IsReshapeSupported(const TensorInfo&, // input
+ const TensorInfo&, // output
+ const ReshapeDescriptor&, // descriptor
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsResizeBilinearSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
+bool LayerSupportBase::IsResizeBilinearSupported(const TensorInfo&, // input
+ const TensorInfo&, // output
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsResizeSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
- const ResizeDescriptor& /*descriptor*/,
+bool LayerSupportBase::IsResizeSupported(const TensorInfo&, // input
+ const TensorInfo&, // output
+ const ResizeDescriptor&, // descriptor
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsRsqrtSupported(const TensorInfo &/*input*/,
- const TensorInfo &/*output*/,
+bool LayerSupportBase::IsRsqrtSupported(const TensorInfo&, // input
+ const TensorInfo&, // output
Optional<std::string &> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsSliceSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
- const SliceDescriptor& /*descriptor*/,
+bool LayerSupportBase::IsSliceSupported(const TensorInfo&, // input
+ const TensorInfo&, // output
+ const SliceDescriptor&, // descriptor
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsSoftmaxSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
- const SoftmaxDescriptor& /*descriptor*/,
+bool LayerSupportBase::IsSoftmaxSupported(const TensorInfo&, // input
+ const TensorInfo&, // output
+ const SoftmaxDescriptor&, // descriptor
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
/**/
-bool LayerSupportBase::IsSpaceToBatchNdSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
- const SpaceToBatchNdDescriptor& /*descriptor*/,
+bool LayerSupportBase::IsSpaceToBatchNdSupported(const TensorInfo&, // input
+ const TensorInfo&, // output
+ const SpaceToBatchNdDescriptor&, // descriptor
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsSpaceToDepthSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
- const SpaceToDepthDescriptor& /*descriptor*/,
+bool LayerSupportBase::IsSpaceToDepthSupported(const TensorInfo&, // input
+ const TensorInfo&, // output
+ const SpaceToDepthDescriptor&, // descriptor
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsSplitterSupported(const TensorInfo& /*input*/,
- const ViewsDescriptor& /*descriptor*/,
+bool LayerSupportBase::IsSplitterSupported(const TensorInfo&, // input
+ const ViewsDescriptor&, // descriptor
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsSplitterSupported(const TensorInfo& /*input*/,
- const std::vector<std::reference_wrapper<TensorInfo>>& /*outputs*/,
- const ViewsDescriptor& /*descriptor*/,
+bool LayerSupportBase::IsSplitterSupported(const TensorInfo&, // input
+ const std::vector<std::reference_wrapper<TensorInfo>>&, // outputs
+ const ViewsDescriptor&, // descriptor
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsStackSupported(const std::vector<const TensorInfo*>& /*inputs*/,
- const TensorInfo& /*output*/,
- const StackDescriptor& /*descriptor*/,
+bool LayerSupportBase::IsStackSupported(const std::vector<const TensorInfo*>&, // inputs
+ const TensorInfo&, // output
+ const StackDescriptor&, // descriptor
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsStandInSupported(const std::vector<const TensorInfo*>& /*inputs*/,
- const std::vector<const TensorInfo*>& /*outputs*/,
- const StandInDescriptor& /*descriptor*/,
+bool LayerSupportBase::IsStandInSupported(const std::vector<const TensorInfo*>&, // inputs
+ const std::vector<const TensorInfo*>&, // outputs
+ const StandInDescriptor&, // descriptor
Optional<std::string&> reasonIfUnsupported) const
{
if (reasonIfUnsupported)
@@ -580,44 +589,44 @@ bool LayerSupportBase::IsStandInSupported(const std::vector<const TensorInfo*>&
return false;
}
-bool LayerSupportBase::IsStridedSliceSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
- const StridedSliceDescriptor& /*descriptor*/,
+bool LayerSupportBase::IsStridedSliceSupported(const TensorInfo&, // input
+ const TensorInfo&, // output
+ const StridedSliceDescriptor&, // descriptor
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsSubtractionSupported(const TensorInfo& /*input0*/,
- const TensorInfo& /*input1*/,
- const TensorInfo& /*output*/,
+bool LayerSupportBase::IsSubtractionSupported(const TensorInfo&, // input0
+ const TensorInfo&, // input1
+ const TensorInfo&, // output
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsSwitchSupported(const TensorInfo& /*input0*/,
- const TensorInfo& /*input1*/,
- const TensorInfo& /*output0*/,
- const TensorInfo& /*output1*/,
+bool LayerSupportBase::IsSwitchSupported(const TensorInfo&, // input0
+ const TensorInfo&, // input1
+ const TensorInfo&, // output0
+ const TensorInfo&, // output1
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
-bool LayerSupportBase::IsTransposeConvolution2dSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
- const TransposeConvolution2dDescriptor& /*descriptor*/,
- const TensorInfo& /*weights*/,
- const Optional<TensorInfo>& /*biases*/,
+bool LayerSupportBase::IsTransposeConvolution2dSupported(const TensorInfo&, // input
+ const TensorInfo&, // output
+ const TransposeConvolution2dDescriptor&, // descriptor
+ const TensorInfo&, // weights
+ const Optional<TensorInfo>&, // biases
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
-}
+}
-bool LayerSupportBase::IsTransposeSupported(const TensorInfo& /*input*/,
- const TensorInfo& /*output*/,
- const TransposeDescriptor& /*descriptor*/,
+bool LayerSupportBase::IsTransposeSupported(const TensorInfo&, // input
+ const TensorInfo&, // output
+ const TransposeDescriptor&, // descriptor
Optional<std::string&> reasonIfUnsupported) const
{
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
diff --git a/src/backends/backendsCommon/LayerSupportBase.hpp b/src/backends/backendsCommon/LayerSupportBase.hpp
index aff4529417..8d5535ab4e 100644
--- a/src/backends/backendsCommon/LayerSupportBase.hpp
+++ b/src/backends/backendsCommon/LayerSupportBase.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -159,11 +159,18 @@ public:
const FullyConnectedDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ ARMNN_DEPRECATED_MSG("Use IsGatherSupported with descriptor instead")
bool IsGatherSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ bool IsGatherSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ const GatherDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead")
bool IsGreaterSupported(const TensorInfo& input0,
const TensorInfo& input1,
diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp
index f2f7089040..6b2c00e298 100644
--- a/src/backends/backendsCommon/WorkloadData.hpp
+++ b/src/backends/backendsCommon/WorkloadData.hpp
@@ -471,7 +471,7 @@ struct RsqrtQueueDescriptor : QueueDescriptor
void Validate(const WorkloadInfo& workloadInfo) const;
};
-struct GatherQueueDescriptor : QueueDescriptor
+struct GatherQueueDescriptor : QueueDescriptorWithParameters<GatherDescriptor>
{
void Validate(const WorkloadInfo& workloadInfo) const;
};
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index d2565cf21d..788cb7e712 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -414,9 +414,12 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+ auto cLayer = PolymorphicDowncast<const GatherLayer*>(&layer);
+ const GatherDescriptor& descriptor = cLayer->GetParameters();
result = layerSupportObject->IsGatherSupported(OverrideDataType(input0, dataType),
input1,
OverrideDataType(output, dataType),
+ descriptor,
reason);
break;
}
diff --git a/src/backends/backendsCommon/test/GatherEndToEndTestImpl.hpp b/src/backends/backendsCommon/test/GatherEndToEndTestImpl.hpp
index 1c97bef467..82f94512c3 100644
--- a/src/backends/backendsCommon/test/GatherEndToEndTestImpl.hpp
+++ b/src/backends/backendsCommon/test/GatherEndToEndTestImpl.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -19,9 +19,10 @@ armnn::INetworkPtr CreateGatherNetwork(const armnn::TensorInfo& paramsInfo,
{
armnn::INetworkPtr net(armnn::INetwork::Create());
+ armnn::GatherDescriptor descriptor;
armnn::IConnectableLayer* paramsLayer = net->AddInputLayer(0);
armnn::IConnectableLayer* indicesLayer = net->AddConstantLayer(armnn::ConstTensor(indicesInfo, indicesData));
- armnn::IConnectableLayer* gatherLayer = net->AddGatherLayer("gather");
+ armnn::IConnectableLayer* gatherLayer = net->AddGatherLayer(descriptor, "gather");
armnn::IConnectableLayer* outputLayer = net->AddOutputLayer(0, "output");
Connect(paramsLayer, gatherLayer, paramsInfo, 0, 0);
Connect(indicesLayer, gatherLayer, indicesInfo, 0, 1);
diff --git a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
index dcd073d279..e30cbb3d31 100644
--- a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
+++ b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -551,7 +551,7 @@ DECLARE_LAYER_POLICY_1_PARAM(Floor)
DECLARE_LAYER_POLICY_2_PARAM(FullyConnected)
-DECLARE_LAYER_POLICY_1_PARAM(Gather)
+DECLARE_LAYER_POLICY_2_PARAM(Gather)
DECLARE_LAYER_POLICY_CUSTOM_PARAM(Input, armnn::LayerBindingId)
diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp
index 44da423f92..0bff96345a 100644
--- a/src/backends/cl/ClLayerSupport.cpp
+++ b/src/backends/cl/ClLayerSupport.cpp
@@ -466,13 +466,15 @@ bool ClLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
bool ClLayerSupport::IsGatherSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output,
+ const GatherDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported) const
{
FORWARD_WORKLOAD_VALIDATE_FUNC(ClGatherWorkloadValidate,
reasonIfUnsupported,
input0,
input1,
- output);
+ output,
+ descriptor);
}
bool ClLayerSupport::IsGreaterSupported(const TensorInfo& input0,
diff --git a/src/backends/cl/ClLayerSupport.hpp b/src/backends/cl/ClLayerSupport.hpp
index d3c3295d2f..49100bf6a2 100644
--- a/src/backends/cl/ClLayerSupport.hpp
+++ b/src/backends/cl/ClLayerSupport.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -127,6 +127,7 @@ public:
bool IsGatherSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output,
+ const GatherDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported) const override;
ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead")
diff --git a/src/backends/cl/workloads/ClGatherWorkload.cpp b/src/backends/cl/workloads/ClGatherWorkload.cpp
index 068487039b..c76b9c7a17 100644
--- a/src/backends/cl/workloads/ClGatherWorkload.cpp
+++ b/src/backends/cl/workloads/ClGatherWorkload.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2020 Arm Ltd. All rights reserved.
+// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -14,13 +14,14 @@ namespace armnn
{
arm_compute::Status ClGatherWorkloadValidate(const TensorInfo& input,
const TensorInfo& indices,
- const TensorInfo& output)
+ const TensorInfo& output,
+ const GatherDescriptor& descriptor)
{
const arm_compute::TensorInfo aclInput = BuildArmComputeTensorInfo(input);
const arm_compute::TensorInfo aclIndices = BuildArmComputeTensorInfo(indices);
const arm_compute::TensorInfo aclOutput = BuildArmComputeTensorInfo(output);
- int aclAxis = ComputeAclAxis(0, input);
+ int aclAxis = ComputeAclAxis(descriptor.m_Axis, input);
return arm_compute::CLGather::validate(&aclInput, &aclIndices, &aclOutput, aclAxis);
}
@@ -35,7 +36,7 @@ ClGatherWorkload::ClGatherWorkload(const GatherQueueDescriptor& descriptor,
arm_compute::ICLTensor& indices = static_cast<IClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
- int aclAxis = ComputeAclAxis(0, info.m_InputTensorInfos[0]);
+ int aclAxis = ComputeAclAxis(descriptor.m_Parameters.m_Axis, info.m_InputTensorInfos[0]);
m_Layer.configure(&input, &indices, &output, aclAxis);
};
diff --git a/src/backends/cl/workloads/ClGatherWorkload.hpp b/src/backends/cl/workloads/ClGatherWorkload.hpp
index 5dbeaade59..df71a99fa0 100644
--- a/src/backends/cl/workloads/ClGatherWorkload.hpp
+++ b/src/backends/cl/workloads/ClGatherWorkload.hpp
@@ -13,7 +13,8 @@ namespace armnn
{
arm_compute::Status ClGatherWorkloadValidate(const TensorInfo& input,
const TensorInfo& indices,
- const TensorInfo& output);
+ const TensorInfo& output,
+ const GatherDescriptor& descriptor);
class ClGatherWorkload : public BaseWorkload<GatherQueueDescriptor>
{
diff --git a/src/backends/neon/NeonLayerSupport.cpp b/src/backends/neon/NeonLayerSupport.cpp
index b611bf45f9..f6b3b7627a 100644
--- a/src/backends/neon/NeonLayerSupport.cpp
+++ b/src/backends/neon/NeonLayerSupport.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -447,13 +447,15 @@ bool NeonLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
bool NeonLayerSupport::IsGatherSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output,
+ const GatherDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported) const
{
FORWARD_WORKLOAD_VALIDATE_FUNC(NeonGatherWorkloadValidate,
reasonIfUnsupported,
input0,
input1,
- output);
+ output,
+ descriptor);
}
bool NeonLayerSupport::IsGreaterSupported(const armnn::TensorInfo& input0,
diff --git a/src/backends/neon/NeonLayerSupport.hpp b/src/backends/neon/NeonLayerSupport.hpp
index 7217fc8971..aff62d18d7 100644
--- a/src/backends/neon/NeonLayerSupport.hpp
+++ b/src/backends/neon/NeonLayerSupport.hpp
@@ -131,6 +131,7 @@ public:
bool IsGatherSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output,
+ const GatherDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported) const override;
ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead")
diff --git a/src/backends/neon/workloads/NeonGatherWorkload.cpp b/src/backends/neon/workloads/NeonGatherWorkload.cpp
index 2e7c741781..2c94cb5184 100644
--- a/src/backends/neon/workloads/NeonGatherWorkload.cpp
+++ b/src/backends/neon/workloads/NeonGatherWorkload.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2020 Arm Ltd. All rights reserved.
+// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -12,13 +12,14 @@ namespace armnn
{
arm_compute::Status NeonGatherWorkloadValidate(const TensorInfo& input,
const TensorInfo& indices,
- const TensorInfo& output)
+ const TensorInfo& output,
+ const GatherDescriptor& descriptor)
{
const arm_compute::TensorInfo aclInput = BuildArmComputeTensorInfo(input);
const arm_compute::TensorInfo aclIndices = BuildArmComputeTensorInfo(indices);
const arm_compute::TensorInfo aclOutput = BuildArmComputeTensorInfo(output);
- int aclAxis = ComputeAclAxis(0, input);
+ int aclAxis = ComputeAclAxis(descriptor.m_Axis, input);
return arm_compute::NEGather::validate(&aclInput, &aclIndices, &aclOutput, aclAxis);
}
@@ -33,7 +34,7 @@ NeonGatherWorkload::NeonGatherWorkload(const GatherQueueDescriptor& descriptor,
arm_compute::ITensor& indices = PolymorphicDowncast<IAclTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
arm_compute::ITensor& output = PolymorphicDowncast<IAclTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
- int aclAxis = ComputeAclAxis(0, info.m_InputTensorInfos[0]);
+ int aclAxis = ComputeAclAxis(descriptor.m_Parameters.m_Axis, info.m_InputTensorInfos[0]);
m_Layer.configure(&input, &indices, &output, aclAxis);
}
diff --git a/src/backends/neon/workloads/NeonGatherWorkload.hpp b/src/backends/neon/workloads/NeonGatherWorkload.hpp
index b1b47a5069..e5b7b57629 100644
--- a/src/backends/neon/workloads/NeonGatherWorkload.hpp
+++ b/src/backends/neon/workloads/NeonGatherWorkload.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2020 Arm Ltd. All rights reserved.
+// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -13,7 +13,8 @@ namespace armnn
{
arm_compute::Status NeonGatherWorkloadValidate(const TensorInfo& input,
const TensorInfo& indices,
- const TensorInfo& output);
+ const TensorInfo& output,
+ const GatherDescriptor& descriptor);
class NeonGatherWorkload : public BaseWorkload<GatherQueueDescriptor>
{
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 18b36a5fa8..696c6d9dac 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -987,6 +987,7 @@ bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
const armnn::TensorInfo& input1,
const armnn::TensorInfo& output,
+ const GatherDescriptor& descriptor,
armnn::Optional<std::string&> reasonIfUnsupported) const
{
bool supported = true;
@@ -1001,6 +1002,11 @@ bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
DataType::Signed32
};
+ if (descriptor.m_Axis != 0)
+ {
+ reasonIfUnsupported.value() += std::string("Reference Gather: axis not supported\n");
+ supported &= false;
+ }
supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
"Reference Gather: input type not supported");
diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp
index 96bff56a42..7d2bbf240e 100644
--- a/src/backends/reference/RefLayerSupport.hpp
+++ b/src/backends/reference/RefLayerSupport.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -160,6 +160,7 @@ public:
bool IsGatherSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output,
+ const GatherDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead")
@@ -346,7 +347,6 @@ public:
const TensorInfo& output,
const TransposeDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
-
};
} // namespace armnn
diff --git a/src/backends/reference/workloads/Gather.cpp b/src/backends/reference/workloads/Gather.cpp
index c23edcd3bd..3e2190c81b 100644
--- a/src/backends/reference/workloads/Gather.cpp
+++ b/src/backends/reference/workloads/Gather.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -20,9 +20,12 @@ void Gather(const TensorInfo& paramsInfo,
const TensorInfo& outputInfo,
Decoder<float>& params,
const int32_t* indices,
- Encoder<float>& output)
+ Encoder<float>& output,
+ const int32_t axis)
{
IgnoreUnused(outputInfo);
+ IgnoreUnused(axis);
+
const TensorShape& paramsShape = paramsInfo.GetShape();
unsigned int paramsProduct = 1;
diff --git a/src/backends/reference/workloads/Gather.hpp b/src/backends/reference/workloads/Gather.hpp
index 16c983eec4..1550f4b97c 100644
--- a/src/backends/reference/workloads/Gather.hpp
+++ b/src/backends/reference/workloads/Gather.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -19,6 +19,7 @@ void Gather(const TensorInfo& paramsInfo,
const TensorInfo& outputInfo,
Decoder<float>& params,
const int32_t* indices,
- Encoder<float>& output);
+ Encoder<float>& output,
+ const int32_t = 0);
} //namespace armnn
diff --git a/src/backends/reference/workloads/RefGatherWorkload.cpp b/src/backends/reference/workloads/RefGatherWorkload.cpp
index 8edf14c8f8..eaeed61b0a 100644
--- a/src/backends/reference/workloads/RefGatherWorkload.cpp
+++ b/src/backends/reference/workloads/RefGatherWorkload.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -29,7 +29,7 @@ void RefGatherWorkload::Execute() const
std::unique_ptr<Encoder<float>> encoderPtr = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
Encoder<float>& encoder = *encoderPtr;
- Gather(inputInfo0, inputInfo1, outputInfo, decoder, indicesData, encoder);
+ Gather(inputInfo0, inputInfo1, outputInfo, decoder, indicesData, encoder, m_Data.m_Parameters.m_Axis);
}
} //namespace armnn