aboutsummaryrefslogtreecommitdiff
path: root/include/armnn/LstmParams.hpp
diff options
context:
space:
mode:
authorJan Eilers <jan.eilers@arm.com>2019-07-03 18:20:40 +0100
committerNikhil Raj Arm <nikhil.raj@arm.com>2019-07-09 11:22:28 +0000
commitd01a83c8de77c44a938a618918d17385da3baa88 (patch)
treefca6f5422adfbdcce059049b36d32e0168edcef4 /include/armnn/LstmParams.hpp
parente6eaf661c5b84f4ca051daaf08281d9b8de3fcb9 (diff)
downloadarmnn-d01a83c8de77c44a938a618918d17385da3baa88.tar.gz
IVGCVSW-3397 Join lstm parameter infos in a struct for isLstmSupported
!android-nn-driver:1461 Change-Id: I9d8fe7adf13832ed0cbcfe98b2353c2f37011d22 Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Diffstat (limited to 'include/armnn/LstmParams.hpp')
-rw-r--r--include/armnn/LstmParams.hpp145
1 files changed, 145 insertions, 0 deletions
diff --git a/include/armnn/LstmParams.hpp b/include/armnn/LstmParams.hpp
index a7c57c78b2..0c8e66dfde 100644
--- a/include/armnn/LstmParams.hpp
+++ b/include/armnn/LstmParams.hpp
@@ -5,6 +5,7 @@
#pragma once
#include "TensorFwd.hpp"
+#include "Exceptions.hpp"
namespace armnn
{
@@ -59,5 +60,149 @@ struct LstmInputParams
const ConstTensor* m_OutputLayerNormWeights;
};
+struct LstmInputParamsInfo
+{
+ LstmInputParamsInfo()
+ : m_InputToInputWeights(nullptr)
+ , m_InputToForgetWeights(nullptr)
+ , m_InputToCellWeights(nullptr)
+ , m_InputToOutputWeights(nullptr)
+ , m_RecurrentToInputWeights(nullptr)
+ , m_RecurrentToForgetWeights(nullptr)
+ , m_RecurrentToCellWeights(nullptr)
+ , m_RecurrentToOutputWeights(nullptr)
+ , m_CellToInputWeights(nullptr)
+ , m_CellToForgetWeights(nullptr)
+ , m_CellToOutputWeights(nullptr)
+ , m_InputGateBias(nullptr)
+ , m_ForgetGateBias(nullptr)
+ , m_CellBias(nullptr)
+ , m_OutputGateBias(nullptr)
+ , m_ProjectionWeights(nullptr)
+ , m_ProjectionBias(nullptr)
+ , m_InputLayerNormWeights(nullptr)
+ , m_ForgetLayerNormWeights(nullptr)
+ , m_CellLayerNormWeights(nullptr)
+ , m_OutputLayerNormWeights(nullptr)
+ {
+ }
+ const TensorInfo* m_InputToInputWeights;
+ const TensorInfo* m_InputToForgetWeights;
+ const TensorInfo* m_InputToCellWeights;
+ const TensorInfo* m_InputToOutputWeights;
+ const TensorInfo* m_RecurrentToInputWeights;
+ const TensorInfo* m_RecurrentToForgetWeights;
+ const TensorInfo* m_RecurrentToCellWeights;
+ const TensorInfo* m_RecurrentToOutputWeights;
+ const TensorInfo* m_CellToInputWeights;
+ const TensorInfo* m_CellToForgetWeights;
+ const TensorInfo* m_CellToOutputWeights;
+ const TensorInfo* m_InputGateBias;
+ const TensorInfo* m_ForgetGateBias;
+ const TensorInfo* m_CellBias;
+ const TensorInfo* m_OutputGateBias;
+ const TensorInfo* m_ProjectionWeights;
+ const TensorInfo* m_ProjectionBias;
+ const TensorInfo* m_InputLayerNormWeights;
+ const TensorInfo* m_ForgetLayerNormWeights;
+ const TensorInfo* m_CellLayerNormWeights;
+ const TensorInfo* m_OutputLayerNormWeights;
+
+ const TensorInfo& deref(const TensorInfo* tensorInfo) const
+ {
+ if (tensorInfo != nullptr)
+ {
+ const TensorInfo &temp = *tensorInfo;
+ return temp;
+ }
+ throw InvalidArgumentException("Can't dereference a null pointer");
+ }
+
+ const TensorInfo& get_InputToInputWeights() const
+ {
+ return deref(m_InputToInputWeights);
+ }
+ const TensorInfo& get_InputToForgetWeights() const
+ {
+ return deref(m_InputToForgetWeights);
+ }
+ const TensorInfo& get_InputToCellWeights() const
+ {
+ return deref(m_InputToCellWeights);
+ }
+ const TensorInfo& get_InputToOutputWeights() const
+ {
+ return deref(m_InputToOutputWeights);
+ }
+ const TensorInfo& get_RecurrentToInputWeights() const
+ {
+ return deref(m_RecurrentToInputWeights);
+ }
+ const TensorInfo& get_RecurrentToForgetWeights() const
+ {
+ return deref(m_RecurrentToForgetWeights);
+ }
+ const TensorInfo& get_RecurrentToCellWeights() const
+ {
+ return deref(m_RecurrentToCellWeights);
+ }
+ const TensorInfo& get_RecurrentToOutputWeights() const
+ {
+ return deref(m_RecurrentToOutputWeights);
+ }
+ const TensorInfo& get_CellToInputWeights() const
+ {
+ return deref(m_CellToInputWeights);
+ }
+ const TensorInfo& get_CellToForgetWeights() const
+ {
+ return deref(m_CellToForgetWeights);
+ }
+ const TensorInfo& get_CellToOutputWeights() const
+ {
+ return deref(m_CellToOutputWeights);
+ }
+ const TensorInfo& get_InputGateBias() const
+ {
+ return deref(m_InputGateBias);
+ }
+ const TensorInfo& get_ForgetGateBias() const
+ {
+ return deref(m_ForgetGateBias);
+ }
+ const TensorInfo& get_CellBias() const
+ {
+ return deref(m_CellBias);
+ }
+ const TensorInfo& get_OutputGateBias() const
+ {
+ return deref(m_OutputGateBias);
+ }
+ const TensorInfo& get_ProjectionWeights() const
+ {
+ return deref(m_ProjectionWeights);
+ }
+ const TensorInfo& get_ProjectionBias() const
+ {
+ return deref(m_ProjectionBias);
+ }
+ const TensorInfo& get_InputLayerNormWeights() const
+ {
+ return deref(m_InputLayerNormWeights);
+ }
+ const TensorInfo& get_ForgetLayerNormWeights() const
+ {
+ return deref(m_ForgetLayerNormWeights);
+ }
+ const TensorInfo& get_CellLayerNormWeights() const
+ {
+ return deref(m_CellLayerNormWeights);
+ }
+ const TensorInfo& get_OutputLayerNormWeights() const
+ {
+ return deref(m_OutputLayerNormWeights);
+ }
+};
+
} // namespace armnn