aboutsummaryrefslogtreecommitdiff
path: root/include
diff options
context:
space:
mode:
authorJan Eilers <jan.eilers@arm.com>2019-07-03 18:20:40 +0100
committerNikhil Raj Arm <nikhil.raj@arm.com>2019-07-09 11:22:28 +0000
commitd01a83c8de77c44a938a618918d17385da3baa88 (patch)
treefca6f5422adfbdcce059049b36d32e0168edcef4 /include
parente6eaf661c5b84f4ca051daaf08281d9b8de3fcb9 (diff)
downloadarmnn-d01a83c8de77c44a938a618918d17385da3baa88.tar.gz
IVGCVSW-3397 Join lstm parameter infos in a struct for isLstmSupported
!android-nn-driver:1461 Change-Id: I9d8fe7adf13832ed0cbcfe98b2353c2f37011d22 Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Diffstat (limited to 'include')
-rw-r--r--include/armnn/ILayerSupport.hpp25
-rw-r--r--include/armnn/LayerSupport.hpp11
-rw-r--r--include/armnn/LstmParams.hpp145
3 files changed, 150 insertions, 31 deletions
diff --git a/include/armnn/ILayerSupport.hpp b/include/armnn/ILayerSupport.hpp
index 58722fe1a0..53dd29d87e 100644
--- a/include/armnn/ILayerSupport.hpp
+++ b/include/armnn/ILayerSupport.hpp
@@ -7,6 +7,7 @@
#include <armnn/Deprecated.hpp>
#include <armnn/DescriptorsFwd.hpp>
#include <armnn/Optional.hpp>
+#include <armnn/LstmParams.hpp>
#include <cctype>
#include <functional>
@@ -153,28 +154,8 @@ public:
const TensorInfo& cellStateOut,
const TensorInfo& output,
const LstmDescriptor& descriptor,
- const TensorInfo& inputToForgetWeights,
- const TensorInfo& inputToCellWeights,
- const TensorInfo& inputToOutputWeights,
- const TensorInfo& recurrentToForgetWeights,
- const TensorInfo& recurrentToCellWeights,
- const TensorInfo& recurrentToOutputWeights,
- const TensorInfo& forgetGateBias,
- const TensorInfo& cellBias,
- const TensorInfo& outputGateBias,
- const TensorInfo* inputToInputWeights,
- const TensorInfo* recurrentToInputWeights,
- const TensorInfo* cellToInputWeights,
- const TensorInfo* inputGateBias,
- const TensorInfo* projectionWeights,
- const TensorInfo* projectionBias,
- const TensorInfo* cellToForgetWeights,
- const TensorInfo* cellToOutputWeights,
- Optional<std::string&> reasonIfUnsupported = EmptyOptional(),
- const TensorInfo* inputLayerNormWeights = nullptr,
- const TensorInfo* forgetLayerNormWeights = nullptr,
- const TensorInfo* cellLayerNormWeights = nullptr,
- const TensorInfo* outputLayerNormWeights = nullptr) const = 0;
+ const LstmInputParamsInfo& paramsInfo,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
virtual bool IsMaximumSupported(const TensorInfo& input0,
const TensorInfo& input1,
diff --git a/include/armnn/LayerSupport.hpp b/include/armnn/LayerSupport.hpp
index 35336ed7da..65f9d089ba 100644
--- a/include/armnn/LayerSupport.hpp
+++ b/include/armnn/LayerSupport.hpp
@@ -9,6 +9,7 @@
#include <armnn/Optional.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>
+#include "LstmParams.hpp"
namespace armnn
{
@@ -178,15 +179,7 @@ bool IsLstmSupported(const BackendId& backend, const TensorInfo& input, const Te
const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
const TensorInfo& output, const LstmDescriptor& descriptor,
- const TensorInfo& inputToForgetWeights, const TensorInfo& inputToCellWeights,
- const TensorInfo& inputToOutputWeights, const TensorInfo& recurrentToForgetWeights,
- const TensorInfo& recurrentToCellWeights, const TensorInfo& recurrentToOutputWeights,
- const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
- const TensorInfo& outputGateBias, const TensorInfo* inputToInputWeights,
- const TensorInfo* recurrentToInputWeights, const TensorInfo* cellToInputWeights,
- const TensorInfo* inputGateBias, const TensorInfo* projectionWeights,
- const TensorInfo* projectionBias, const TensorInfo* cellToForgetWeights,
- const TensorInfo* cellToOutputWeights, char* reasonIfUnsupported = nullptr,
+ const LstmInputParamsInfo& paramsInfo, char* reasonIfUnsupported = nullptr,
size_t reasonIfUnsupportedMaxLength = 1024);
/// Deprecated in favor of IBackend and ILayerSupport interfaces
diff --git a/include/armnn/LstmParams.hpp b/include/armnn/LstmParams.hpp
index a7c57c78b2..0c8e66dfde 100644
--- a/include/armnn/LstmParams.hpp
+++ b/include/armnn/LstmParams.hpp
@@ -5,6 +5,7 @@
#pragma once
#include "TensorFwd.hpp"
+#include "Exceptions.hpp"
namespace armnn
{
@@ -59,5 +60,149 @@ struct LstmInputParams
const ConstTensor* m_OutputLayerNormWeights;
};
+struct LstmInputParamsInfo
+{
+ LstmInputParamsInfo()
+ : m_InputToInputWeights(nullptr)
+ , m_InputToForgetWeights(nullptr)
+ , m_InputToCellWeights(nullptr)
+ , m_InputToOutputWeights(nullptr)
+ , m_RecurrentToInputWeights(nullptr)
+ , m_RecurrentToForgetWeights(nullptr)
+ , m_RecurrentToCellWeights(nullptr)
+ , m_RecurrentToOutputWeights(nullptr)
+ , m_CellToInputWeights(nullptr)
+ , m_CellToForgetWeights(nullptr)
+ , m_CellToOutputWeights(nullptr)
+ , m_InputGateBias(nullptr)
+ , m_ForgetGateBias(nullptr)
+ , m_CellBias(nullptr)
+ , m_OutputGateBias(nullptr)
+ , m_ProjectionWeights(nullptr)
+ , m_ProjectionBias(nullptr)
+ , m_InputLayerNormWeights(nullptr)
+ , m_ForgetLayerNormWeights(nullptr)
+ , m_CellLayerNormWeights(nullptr)
+ , m_OutputLayerNormWeights(nullptr)
+ {
+ }
+ const TensorInfo* m_InputToInputWeights;
+ const TensorInfo* m_InputToForgetWeights;
+ const TensorInfo* m_InputToCellWeights;
+ const TensorInfo* m_InputToOutputWeights;
+ const TensorInfo* m_RecurrentToInputWeights;
+ const TensorInfo* m_RecurrentToForgetWeights;
+ const TensorInfo* m_RecurrentToCellWeights;
+ const TensorInfo* m_RecurrentToOutputWeights;
+ const TensorInfo* m_CellToInputWeights;
+ const TensorInfo* m_CellToForgetWeights;
+ const TensorInfo* m_CellToOutputWeights;
+ const TensorInfo* m_InputGateBias;
+ const TensorInfo* m_ForgetGateBias;
+ const TensorInfo* m_CellBias;
+ const TensorInfo* m_OutputGateBias;
+ const TensorInfo* m_ProjectionWeights;
+ const TensorInfo* m_ProjectionBias;
+ const TensorInfo* m_InputLayerNormWeights;
+ const TensorInfo* m_ForgetLayerNormWeights;
+ const TensorInfo* m_CellLayerNormWeights;
+ const TensorInfo* m_OutputLayerNormWeights;
+
+ const TensorInfo& deref(const TensorInfo* tensorInfo) const
+ {
+ if (tensorInfo != nullptr)
+ {
+ const TensorInfo &temp = *tensorInfo;
+ return temp;
+ }
+ throw InvalidArgumentException("Can't dereference a null pointer");
+ }
+
+ const TensorInfo& get_InputToInputWeights() const
+ {
+ return deref(m_InputToInputWeights);
+ }
+ const TensorInfo& get_InputToForgetWeights() const
+ {
+ return deref(m_InputToForgetWeights);
+ }
+ const TensorInfo& get_InputToCellWeights() const
+ {
+ return deref(m_InputToCellWeights);
+ }
+ const TensorInfo& get_InputToOutputWeights() const
+ {
+ return deref(m_InputToOutputWeights);
+ }
+ const TensorInfo& get_RecurrentToInputWeights() const
+ {
+ return deref(m_RecurrentToInputWeights);
+ }
+ const TensorInfo& get_RecurrentToForgetWeights() const
+ {
+ return deref(m_RecurrentToForgetWeights);
+ }
+ const TensorInfo& get_RecurrentToCellWeights() const
+ {
+ return deref(m_RecurrentToCellWeights);
+ }
+ const TensorInfo& get_RecurrentToOutputWeights() const
+ {
+ return deref(m_RecurrentToOutputWeights);
+ }
+ const TensorInfo& get_CellToInputWeights() const
+ {
+ return deref(m_CellToInputWeights);
+ }
+ const TensorInfo& get_CellToForgetWeights() const
+ {
+ return deref(m_CellToForgetWeights);
+ }
+ const TensorInfo& get_CellToOutputWeights() const
+ {
+ return deref(m_CellToOutputWeights);
+ }
+ const TensorInfo& get_InputGateBias() const
+ {
+ return deref(m_InputGateBias);
+ }
+ const TensorInfo& get_ForgetGateBias() const
+ {
+ return deref(m_ForgetGateBias);
+ }
+ const TensorInfo& get_CellBias() const
+ {
+ return deref(m_CellBias);
+ }
+ const TensorInfo& get_OutputGateBias() const
+ {
+ return deref(m_OutputGateBias);
+ }
+ const TensorInfo& get_ProjectionWeights() const
+ {
+ return deref(m_ProjectionWeights);
+ }
+ const TensorInfo& get_ProjectionBias() const
+ {
+ return deref(m_ProjectionBias);
+ }
+ const TensorInfo& get_InputLayerNormWeights() const
+ {
+ return deref(m_InputLayerNormWeights);
+ }
+ const TensorInfo& get_ForgetLayerNormWeights() const
+ {
+ return deref(m_ForgetLayerNormWeights);
+ }
+ const TensorInfo& get_CellLayerNormWeights() const
+ {
+ return deref(m_CellLayerNormWeights);
+ }
+ const TensorInfo& get_OutputLayerNormWeights() const
+ {
+ return deref(m_OutputLayerNormWeights);
+ }
+};
+
} // namespace armnn