aboutsummaryrefslogtreecommitdiff
path: root/include
diff options
context:
space:
mode:
authorTeresa Charlin <teresa.charlinreyes@arm.com>2020-06-29 16:27:03 +0100
committerTeresaARM <teresa.charlinreyes@arm.com>2020-07-01 08:26:47 +0000
commit526647333571169076f5e72c9fb18c71025bf7c0 (patch)
tree6dc559a7b0fae3705172b09a88fa552926652040 /include
parentcbd2c230b7ce5f26e2ccccf36b7ad450f6e1ad09 (diff)
downloadarmnn-526647333571169076f5e72c9fb18c71025bf7c0.tar.gz
IVGCVSW-4903 Connect axis parameter in Gather from android to ACL.
!android-nn-driver:3302 Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com> Change-Id: Ifbc49acb5272f8a36719bb68676e44817190537d
Diffstat (limited to 'include')
-rw-r--r--include/armnn/Descriptors.hpp22
-rw-r--r--include/armnn/DescriptorsFwd.hpp3
-rw-r--r--include/armnn/ILayerSupport.hpp9
-rw-r--r--include/armnn/ILayerVisitor.hpp11
-rw-r--r--include/armnn/INetwork.hpp28
-rw-r--r--include/armnn/LayerVisitorBase.hpp7
6 files changed, 65 insertions, 15 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index 653e64701a..60aa219638 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -730,6 +730,26 @@ struct FillDescriptor
float m_Value;
};
+/// A GatherDescriptor for the GatherLayer.
+struct GatherDescriptor
+{
+ GatherDescriptor()
+ : m_Axis(0)
+ {}
+
+ GatherDescriptor(int32_t axis)
+ : m_Axis(axis)
+ {}
+
+ bool operator ==(const GatherDescriptor& rhs) const
+ {
+ return m_Axis == rhs.m_Axis;
+ }
+
+ /// The axis in params to gather indices from
+ int32_t m_Axis;
+};
+
/// A ResizeBilinearDescriptor for the ResizeBilinearLayer.
struct ResizeBilinearDescriptor
{
diff --git a/include/armnn/DescriptorsFwd.hpp b/include/armnn/DescriptorsFwd.hpp
index e31fb96aec..fba976c788 100644
--- a/include/armnn/DescriptorsFwd.hpp
+++ b/include/armnn/DescriptorsFwd.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -20,6 +20,7 @@ struct ElementwiseUnaryDescriptor;
struct FakeQuantizationDescriptor;
struct FillDescriptor;
struct FullyConnectedDescriptor;
+struct GatherDescriptor;
struct InstanceNormalizationDescriptor;
struct L2NormalizationDescriptor;
struct LstmDescriptor;
diff --git a/include/armnn/ILayerSupport.hpp b/include/armnn/ILayerSupport.hpp
index 33389eb25f..889b811903 100644
--- a/include/armnn/ILayerSupport.hpp
+++ b/include/armnn/ILayerSupport.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -173,11 +173,18 @@ public:
const FullyConnectedDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+ ARMNN_DEPRECATED_MSG("Use IsGatherSupported with descriptor instead")
virtual bool IsGatherSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+ virtual bool IsGatherSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ const GatherDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+
ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead")
virtual bool IsGreaterSupported(const TensorInfo& input0,
const TensorInfo& input1,
diff --git a/include/armnn/ILayerVisitor.hpp b/include/armnn/ILayerVisitor.hpp
index aa5bdba33c..9b3998db9a 100644
--- a/include/armnn/ILayerVisitor.hpp
+++ b/include/armnn/ILayerVisitor.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -214,9 +214,18 @@ public:
/// Function a Gather layer should call back to when its Accept(ILayerVisitor&) function is invoked.
/// @param layer - pointer to the layer which is calling back to this visit function.
/// @param name - Optional name for the layer.
+ ARMNN_DEPRECATED_MSG("Use VisitGatherLayer with descriptor instead")
virtual void VisitGatherLayer(const IConnectableLayer* layer,
const char* name = nullptr) = 0;
+ /// Function a Gather layer should call back to when its Accept(ILayerVisitor&) function is invoked.
+ /// @param layer - pointer to the layer which is calling back to this visit function.
+ /// @param gatherDescriptor - Parameters for the gather operation.
+ /// @param name - Optional name for the layer.
+ virtual void VisitGatherLayer(const IConnectableLayer* layer,
+ const GatherDescriptor& gatherDescriptor,
+ const char* name = nullptr) = 0;
+
/// Function a Greater layer should call back to when its Accept(ILayerVisitor&) function is invoked.
/// @param layer - pointer to the layer which is calling back to this visit function.
/// @param name - Optional name for the layer.
diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp
index 49cd582e67..8e7a4437c8 100644
--- a/include/armnn/INetwork.hpp
+++ b/include/armnn/INetwork.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -128,7 +128,7 @@ public:
/// Add a Comparison layer to the network.
/// @param name - Optional name for the layer.
/// @param desc - Descriptor for the comparison operation.
- /// @ return - Interface for configuring the layer.
+ /// @return - Interface for configuring the layer.
virtual IConnectableLayer* AddComparisonLayer(const ComparisonDescriptor& comparisonDescriptor,
const char* name = nullptr) = 0;
@@ -324,7 +324,7 @@ public:
/// Add absolute layer to the network.
/// @param name - Optional name for the layer.
- /// @ return - Interface for configuring the layer.
+ /// @return - Interface for configuring the layer.
ARMNN_DEPRECATED_MSG("Use AddElementwiseUnaryLayer instead")
virtual IConnectableLayer* AddAbsLayer(const char* name = nullptr) = 0;
@@ -453,13 +453,13 @@ public:
/// Add a Maximum layer to the network.
/// @param name - Optional name for the layer.
- /// @ return - Interface for configuring the layer.
+ /// @return - Interface for configuring the layer.
virtual IConnectableLayer* AddMaximumLayer(const char* name = nullptr) = 0;
/// Add a Mean layer to the network.
/// @param meanDescriptor - Parameters for the mean operation.
/// @param name - Optional name for the layer.
- /// @ return - Interface for configuring the layer.
+ /// @return - Interface for configuring the layer.
virtual IConnectableLayer* AddMeanLayer(const MeanDescriptor& meanDescriptor, const char* name = nullptr) = 0;
/// Adds a fully pad layer to the network.
@@ -485,32 +485,40 @@ public:
/// Add a Minimum layer to the network.
/// @param name - Optional name for the layer.
- /// @ return - Interface for configuring the layer.
+ /// @return - Interface for configuring the layer.
virtual IConnectableLayer* AddMinimumLayer(const char* name = nullptr) = 0;
/// Add a Greater layer to the network.
/// @param name - Optional name for the layer.
- /// @ return - Interface for configuring the layer.
+ /// @return - Interface for configuring the layer.
ARMNN_DEPRECATED_MSG("Use AddComparisonLayer instead")
virtual IConnectableLayer* AddGreaterLayer(const char* name = nullptr) = 0;
/// Add a Equal layer to the network.
/// @param name - Optional name for the layer.
- /// @ return - Interface for configuring the layer.
+ /// @return - Interface for configuring the layer.
ARMNN_DEPRECATED_MSG("Use AddComparisonLayer instead")
virtual IConnectableLayer* AddEqualLayer(const char* name = nullptr) = 0;
/// Add Reciprocal of square root layer to the network.
/// @param name - Optional name for the layer.
- /// @ return - Interface for configuring the layer.
+ /// @return - Interface for configuring the layer.
ARMNN_DEPRECATED_MSG("Use AddElementwiseUnaryLayer instead")
virtual IConnectableLayer* AddRsqrtLayer(const char* name = nullptr) = 0;
/// Add Gather layer to the network.
/// @param name - Optional name for the layer.
- /// @ return - Interface for configuring the layer.
+ /// @return - Interface for configuring the layer.
+ ARMNN_DEPRECATED_MSG("Use AddGatherLayer with descriptor instead")
virtual IConnectableLayer* AddGatherLayer(const char* name = nullptr) = 0;
+ /// Add Gather layer to the network.
+ /// @param descriptor - Description of the gather layer.
+ /// @param name - Optional name for the layer.
+ /// @return - Interface for configuring the layer.
+ virtual IConnectableLayer* AddGatherLayer(const GatherDescriptor& descriptor,
+ const char* name = nullptr) = 0;
+
/// Adds a switch layer to the network.
/// @param name - Optional name for the layer.
/// @return - Interface for configuring the layer.
diff --git a/include/armnn/LayerVisitorBase.hpp b/include/armnn/LayerVisitorBase.hpp
index 0dc5e545e3..93ba7fe287 100644
--- a/include/armnn/LayerVisitorBase.hpp
+++ b/include/armnn/LayerVisitorBase.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -114,9 +114,14 @@ public:
const Optional<ConstTensor>&,
const char*) override { DefaultPolicy::Apply(__func__); }
+ ARMNN_DEPRECATED_MSG("Use VisitGatherLayer with descriptor instead")
void VisitGatherLayer(const IConnectableLayer*,
const char*) override { DefaultPolicy::Apply(__func__); }
+ void VisitGatherLayer(const IConnectableLayer*,
+ const GatherDescriptor&,
+ const char*) override { DefaultPolicy::Apply(__func__); }
+
void VisitGreaterLayer(const IConnectableLayer*,
const char*) override { DefaultPolicy::Apply(__func__); }