diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.hpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.hpp | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp index 36653bdc0d..78da00be5d 100644 --- a/src/backends/backendsCommon/WorkloadData.hpp +++ b/src/backends/backendsCommon/WorkloadData.hpp @@ -695,4 +695,56 @@ struct ShapeQueueDescriptor : QueueDescriptor void Validate(const WorkloadInfo& workloadInfo) const; }; +struct UnidirectionalSequenceLstmQueueDescriptor : QueueDescriptorWithParameters<LstmDescriptor> +{ + UnidirectionalSequenceLstmQueueDescriptor() + : m_InputToInputWeights(nullptr) + , m_InputToForgetWeights(nullptr) + , m_InputToCellWeights(nullptr) + , m_InputToOutputWeights(nullptr) + , m_RecurrentToInputWeights(nullptr) + , m_RecurrentToForgetWeights(nullptr) + , m_RecurrentToCellWeights(nullptr) + , m_RecurrentToOutputWeights(nullptr) + , m_CellToInputWeights(nullptr) + , m_CellToForgetWeights(nullptr) + , m_CellToOutputWeights(nullptr) + , m_InputGateBias(nullptr) + , m_ForgetGateBias(nullptr) + , m_CellBias(nullptr) + , m_OutputGateBias(nullptr) + , m_ProjectionWeights(nullptr) + , m_ProjectionBias(nullptr) + , m_InputLayerNormWeights(nullptr) + , m_ForgetLayerNormWeights(nullptr) + , m_CellLayerNormWeights(nullptr) + , m_OutputLayerNormWeights(nullptr) + { + } + + const ConstTensorHandle* m_InputToInputWeights; + const ConstTensorHandle* m_InputToForgetWeights; + const ConstTensorHandle* m_InputToCellWeights; + const ConstTensorHandle* m_InputToOutputWeights; + const ConstTensorHandle* m_RecurrentToInputWeights; + const ConstTensorHandle* m_RecurrentToForgetWeights; + const ConstTensorHandle* m_RecurrentToCellWeights; + const ConstTensorHandle* m_RecurrentToOutputWeights; + const ConstTensorHandle* m_CellToInputWeights; + const ConstTensorHandle* m_CellToForgetWeights; + const ConstTensorHandle* m_CellToOutputWeights; + const ConstTensorHandle* m_InputGateBias; + const ConstTensorHandle* m_ForgetGateBias; + const ConstTensorHandle* m_CellBias; + const ConstTensorHandle* m_OutputGateBias; + const ConstTensorHandle* m_ProjectionWeights; + const ConstTensorHandle* m_ProjectionBias; + const ConstTensorHandle* m_InputLayerNormWeights; + const ConstTensorHandle* m_ForgetLayerNormWeights; + const ConstTensorHandle* m_CellLayerNormWeights; + const ConstTensorHandle* m_OutputLayerNormWeights; + + void Validate(const WorkloadInfo& workloadInfo) const; +}; + } // namespace armnn |