diff options
Diffstat (limited to 'include/armnn/LstmParams.hpp')
-rw-r--r-- | include/armnn/LstmParams.hpp | 145 |
1 files changed, 145 insertions, 0 deletions
diff --git a/include/armnn/LstmParams.hpp b/include/armnn/LstmParams.hpp index a7c57c78b2..0c8e66dfde 100644 --- a/include/armnn/LstmParams.hpp +++ b/include/armnn/LstmParams.hpp @@ -5,6 +5,7 @@ #pragma once #include "TensorFwd.hpp" +#include "Exceptions.hpp" namespace armnn { @@ -59,5 +60,149 @@ struct LstmInputParams const ConstTensor* m_OutputLayerNormWeights; }; +struct LstmInputParamsInfo +{ + LstmInputParamsInfo() + : m_InputToInputWeights(nullptr) + , m_InputToForgetWeights(nullptr) + , m_InputToCellWeights(nullptr) + , m_InputToOutputWeights(nullptr) + , m_RecurrentToInputWeights(nullptr) + , m_RecurrentToForgetWeights(nullptr) + , m_RecurrentToCellWeights(nullptr) + , m_RecurrentToOutputWeights(nullptr) + , m_CellToInputWeights(nullptr) + , m_CellToForgetWeights(nullptr) + , m_CellToOutputWeights(nullptr) + , m_InputGateBias(nullptr) + , m_ForgetGateBias(nullptr) + , m_CellBias(nullptr) + , m_OutputGateBias(nullptr) + , m_ProjectionWeights(nullptr) + , m_ProjectionBias(nullptr) + , m_InputLayerNormWeights(nullptr) + , m_ForgetLayerNormWeights(nullptr) + , m_CellLayerNormWeights(nullptr) + , m_OutputLayerNormWeights(nullptr) + { + } + const TensorInfo* m_InputToInputWeights; + const TensorInfo* m_InputToForgetWeights; + const TensorInfo* m_InputToCellWeights; + const TensorInfo* m_InputToOutputWeights; + const TensorInfo* m_RecurrentToInputWeights; + const TensorInfo* m_RecurrentToForgetWeights; + const TensorInfo* m_RecurrentToCellWeights; + const TensorInfo* m_RecurrentToOutputWeights; + const TensorInfo* m_CellToInputWeights; + const TensorInfo* m_CellToForgetWeights; + const TensorInfo* m_CellToOutputWeights; + const TensorInfo* m_InputGateBias; + const TensorInfo* m_ForgetGateBias; + const TensorInfo* m_CellBias; + const TensorInfo* m_OutputGateBias; + const TensorInfo* m_ProjectionWeights; + const TensorInfo* m_ProjectionBias; + const TensorInfo* m_InputLayerNormWeights; + const TensorInfo* m_ForgetLayerNormWeights; + const TensorInfo* m_CellLayerNormWeights; + const TensorInfo* m_OutputLayerNormWeights; + + const TensorInfo& deref(const TensorInfo* tensorInfo) const + { + if (tensorInfo != nullptr) + { + const TensorInfo &temp = *tensorInfo; + return temp; + } + throw InvalidArgumentException("Can't dereference a null pointer"); + } + + const TensorInfo& get_InputToInputWeights() const + { + return deref(m_InputToInputWeights); + } + const TensorInfo& get_InputToForgetWeights() const + { + return deref(m_InputToForgetWeights); + } + const TensorInfo& get_InputToCellWeights() const + { + return deref(m_InputToCellWeights); + } + const TensorInfo& get_InputToOutputWeights() const + { + return deref(m_InputToOutputWeights); + } + const TensorInfo& get_RecurrentToInputWeights() const + { + return deref(m_RecurrentToInputWeights); + } + const TensorInfo& get_RecurrentToForgetWeights() const + { + return deref(m_RecurrentToForgetWeights); + } + const TensorInfo& get_RecurrentToCellWeights() const + { + return deref(m_RecurrentToCellWeights); + } + const TensorInfo& get_RecurrentToOutputWeights() const + { + return deref(m_RecurrentToOutputWeights); + } + const TensorInfo& get_CellToInputWeights() const + { + return deref(m_CellToInputWeights); + } + const TensorInfo& get_CellToForgetWeights() const + { + return deref(m_CellToForgetWeights); + } + const TensorInfo& get_CellToOutputWeights() const + { + return deref(m_CellToOutputWeights); + } + const TensorInfo& get_InputGateBias() const + { + return deref(m_InputGateBias); + } + const TensorInfo& get_ForgetGateBias() const + { + return deref(m_ForgetGateBias); + } + const TensorInfo& get_CellBias() const + { + return deref(m_CellBias); + } + const TensorInfo& get_OutputGateBias() const + { + return deref(m_OutputGateBias); + } + const TensorInfo& get_ProjectionWeights() const + { + return deref(m_ProjectionWeights); + } + const TensorInfo& get_ProjectionBias() const + { + return deref(m_ProjectionBias); + } + const TensorInfo& get_InputLayerNormWeights() const + { + return deref(m_InputLayerNormWeights); + } + const TensorInfo& get_ForgetLayerNormWeights() const + { + return deref(m_ForgetLayerNormWeights); + } + const TensorInfo& get_CellLayerNormWeights() const + { + return deref(m_CellLayerNormWeights); + } + const TensorInfo& get_OutputLayerNormWeights() const + { + return deref(m_OutputLayerNormWeights); + } +}; + } // namespace armnn |