aboutsummaryrefslogtreecommitdiff
path: root/include
diff options
context:
space:
mode:
authorSadik Armagan <sadik.armagan@arm.com>2021-03-25 07:46:55 +0000
committerSadik Armagan <sadik.armagan@arm.com>2021-03-25 07:46:55 +0000
commitf0a6dec75832604d5ab18242dc216852821a8279 (patch)
treeff25e64c62c63975a54abd16a8bff744be70d7c0 /include
parent16fb1a2d9c1d3d80c0f0b6ab549919fbabd2a0b9 (diff)
downloadarmnn-f0a6dec75832604d5ab18242dc216852821a8279.tar.gz
IVGCVSW-5736 and IVGCVSW-5743 'NonConstWeights: Update front-end and TfLiteDelegate support for FullyConnected Operator'
* Added front-end support for non-const weights for FULLY_CONNECTED operator * Added FULLY_CONNECTED end-to-end test * Updated FULLY_CONNECTED operator support in TfLite Arm NN Delegate for non-const weights * Updated the version numbers Signed-off-by: Sadik Armagan <sadik.armagan@arm.com> Change-Id: Iffa5b9aa9297aca4c02d923cce4636c88ac21faa
Diffstat (limited to 'include')
-rw-r--r--include/armnn/BackendHelper.hpp10
-rw-r--r--include/armnn/Descriptors.hpp10
-rw-r--r--include/armnn/INetwork.hpp11
-rw-r--r--include/armnn/Types.hpp10
-rw-r--r--include/armnn/Version.hpp2
-rw-r--r--include/armnn/backends/IBackendInternal.hpp3
-rw-r--r--include/armnnCaffeParser/Version.hpp2
-rw-r--r--include/armnnOnnxParser/Version.hpp2
-rw-r--r--include/armnnTfLiteParser/Version.hpp2
-rw-r--r--include/armnnTfParser/Version.hpp2
10 files changed, 47 insertions, 7 deletions
diff --git a/include/armnn/BackendHelper.hpp b/include/armnn/BackendHelper.hpp
index a562f60c23..41bb5f9c3a 100644
--- a/include/armnn/BackendHelper.hpp
+++ b/include/armnn/BackendHelper.hpp
@@ -7,6 +7,7 @@
#include <armnn/BackendId.hpp>
#include <armnn/ILayerSupport.hpp>
+#include <armnn/Types.hpp>
namespace armnn
{
@@ -19,7 +20,10 @@ class LayerSupportHandle
{
public:
explicit LayerSupportHandle(std::shared_ptr<ILayerSupport> layerSupport)
- : m_LayerSupport(std::move(layerSupport)) {};
+ : m_LayerSupport(std::move(layerSupport)), m_BackendId(Compute::Undefined) {};
+
+ explicit LayerSupportHandle(std::shared_ptr<ILayerSupport> layerSupport, const BackendId& backendId)
+ : m_LayerSupport(std::move(layerSupport)), m_BackendId(backendId) {};
bool IsBackendRegistered() const;
@@ -422,9 +426,13 @@ public:
private:
std::shared_ptr<ILayerSupport> m_LayerSupport;
+ const BackendId m_BackendId;
};
/// Convenience function to retrieve the ILayerSupportHandle for a backend
LayerSupportHandle GetILayerSupportByBackendId(const armnn::BackendId& backend);
+/// Convenience function to check a capability on a backend
+bool IsCapabilitySupported(const armnn::BackendId& backend, armnn::BackendCapability capability);
+
}
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index 20511ab00f..278c61f7d4 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -391,17 +391,25 @@ struct FullyConnectedDescriptor : BaseDescriptor
FullyConnectedDescriptor()
: m_BiasEnabled(false)
, m_TransposeWeightMatrix(false)
+ , m_ConstantWeights(true)
{}
bool operator ==(const FullyConnectedDescriptor& rhs) const
{
- return m_BiasEnabled == rhs.m_BiasEnabled && m_TransposeWeightMatrix == rhs.m_TransposeWeightMatrix;
+ return m_BiasEnabled == rhs.m_BiasEnabled
+ && m_TransposeWeightMatrix == rhs.m_TransposeWeightMatrix
+ && m_ConstantWeights == rhs.m_ConstantWeights;
}
+ /// Get the number of views/inputs.
+ uint32_t GetNumViews() const;
+
/// Enable/disable bias.
bool m_BiasEnabled;
/// Enable/disable transpose weight matrix.
bool m_TransposeWeightMatrix;
+ /// Enable/disable constant weights and biases.
+ bool m_ConstantWeights;
};
/// A Convolution2dDescriptor for the Convolution2dLayer.
diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp
index d1d4744a42..bceb07405a 100644
--- a/include/armnn/INetwork.hpp
+++ b/include/armnn/INetwork.hpp
@@ -299,6 +299,17 @@ public:
/// Adds a fully connected layer to the network.
/// @param fullyConnectedDescriptor - Description of the fully connected layer.
+ /// @param weights -Optional Tensor for the weights data.
+ /// @param biases - Optional tensor for the bias data.
+ /// @param name - Optional name for the layer.
+ /// @return - Interface for configuring the layer.
+ IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
+ const Optional<ConstTensor>& weights,
+ const Optional<ConstTensor>& biases,
+ const char* name = nullptr);
+
+ /// Adds a fully connected layer to the network.
+ /// @param fullyConnectedDescriptor - Description of the fully connected layer.
/// @param weights - Tensor for the weights data.
/// @param biases - Optional tensor for the bias data.
/// @param name - Optional name for the layer.
diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp
index e1ff46b023..576e67ea18 100644
--- a/include/armnn/Types.hpp
+++ b/include/armnn/Types.hpp
@@ -196,6 +196,16 @@ public:
using IBackendSharedPtr = std::shared_ptr<IBackend>;
using IBackendUniquePtr = std::unique_ptr<IBackend, void(*)(IBackend* backend)>;
+/// BackendCapability class
+enum class BackendCapability : uint32_t
+{
+ /// Constant weights can be accessed through the descriptors,
+ /// On the other hand, non-const weights can be accessed through inputs.
+ NonConstWeights,
+
+ // add new enum values here
+};
+
/// Device specific knowledge to be passed to the optimizer.
class IDeviceSpec
{
diff --git a/include/armnn/Version.hpp b/include/armnn/Version.hpp
index d8c14ab262..2139637b5b 100644
--- a/include/armnn/Version.hpp
+++ b/include/armnn/Version.hpp
@@ -10,7 +10,7 @@
#define STRINGIFY_MACRO(s) #s
// ArmNN version components
-#define ARMNN_MAJOR_VERSION 24
+#define ARMNN_MAJOR_VERSION 25
#define ARMNN_MINOR_VERSION 0
#define ARMNN_PATCH_VERSION 0
diff --git a/include/armnn/backends/IBackendInternal.hpp b/include/armnn/backends/IBackendInternal.hpp
index c7ed8efa78..8035cff456 100644
--- a/include/armnn/backends/IBackendInternal.hpp
+++ b/include/armnn/backends/IBackendInternal.hpp
@@ -164,6 +164,9 @@ public:
/// Returns the version of the Backend API
static constexpr BackendVersion GetApiVersion() { return BackendVersion(1, 0); }
+
+ /// Returns true if backend support the capability false otherwise
+ virtual bool HasCapability(BackendCapability /*capabilityClass*/) const { return false; }
};
using IBackendInternalUniquePtr = std::unique_ptr<IBackendInternal>;
diff --git a/include/armnnCaffeParser/Version.hpp b/include/armnnCaffeParser/Version.hpp
index d7135bf158..6e7ce5a539 100644
--- a/include/armnnCaffeParser/Version.hpp
+++ b/include/armnnCaffeParser/Version.hpp
@@ -14,7 +14,7 @@ namespace armnnCaffeParser
// CaffeParser version components
#define CAFFE_PARSER_MAJOR_VERSION 24
-#define CAFFE_PARSER_MINOR_VERSION 0
+#define CAFFE_PARSER_MINOR_VERSION 1
#define CAFFE_PARSER_PATCH_VERSION 0
/// CAFFE_PARSER_VERSION: "X.Y.Z"
diff --git a/include/armnnOnnxParser/Version.hpp b/include/armnnOnnxParser/Version.hpp
index e42adf711d..d6308b376a 100644
--- a/include/armnnOnnxParser/Version.hpp
+++ b/include/armnnOnnxParser/Version.hpp
@@ -14,7 +14,7 @@ namespace armnnOnnxParser
// OnnxParser version components
#define ONNX_PARSER_MAJOR_VERSION 24
-#define ONNX_PARSER_MINOR_VERSION 0
+#define ONNX_PARSER_MINOR_VERSION 1
#define ONNX_PARSER_PATCH_VERSION 0
/// ONNX_PARSER_VERSION: "X.Y.Z"
diff --git a/include/armnnTfLiteParser/Version.hpp b/include/armnnTfLiteParser/Version.hpp
index 7d239bba38..99237f325d 100644
--- a/include/armnnTfLiteParser/Version.hpp
+++ b/include/armnnTfLiteParser/Version.hpp
@@ -14,7 +14,7 @@ namespace armnnTfLiteParser
// TfLiteParser version components
#define TFLITE_PARSER_MAJOR_VERSION 24
-#define TFLITE_PARSER_MINOR_VERSION 0
+#define TFLITE_PARSER_MINOR_VERSION 1
#define TFLITE_PARSER_PATCH_VERSION 0
/// TFLITE_PARSER_VERSION: "X.Y.Z"
diff --git a/include/armnnTfParser/Version.hpp b/include/armnnTfParser/Version.hpp
index 6f6aac9b38..25449f3180 100644
--- a/include/armnnTfParser/Version.hpp
+++ b/include/armnnTfParser/Version.hpp
@@ -14,7 +14,7 @@ namespace armnnTfParser
// tfParser version components
#define TF_PARSER_MAJOR_VERSION 24
-#define TF_PARSER_MINOR_VERSION 0
+#define TF_PARSER_MINOR_VERSION 1
#define TF_PARSER_PATCH_VERSION 0
/// TF_PARSER_VERSION: "X.Y.Z"