aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadData.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.hpp')
-rw-r--r--src/backends/backendsCommon/WorkloadData.hpp52
1 files changed, 52 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp
index 36653bdc0d..78da00be5d 100644
--- a/src/backends/backendsCommon/WorkloadData.hpp
+++ b/src/backends/backendsCommon/WorkloadData.hpp
@@ -695,4 +695,56 @@ struct ShapeQueueDescriptor : QueueDescriptor
void Validate(const WorkloadInfo& workloadInfo) const;
};
+struct UnidirectionalSequenceLstmQueueDescriptor : QueueDescriptorWithParameters<LstmDescriptor>
+{
+ UnidirectionalSequenceLstmQueueDescriptor()
+ : m_InputToInputWeights(nullptr)
+ , m_InputToForgetWeights(nullptr)
+ , m_InputToCellWeights(nullptr)
+ , m_InputToOutputWeights(nullptr)
+ , m_RecurrentToInputWeights(nullptr)
+ , m_RecurrentToForgetWeights(nullptr)
+ , m_RecurrentToCellWeights(nullptr)
+ , m_RecurrentToOutputWeights(nullptr)
+ , m_CellToInputWeights(nullptr)
+ , m_CellToForgetWeights(nullptr)
+ , m_CellToOutputWeights(nullptr)
+ , m_InputGateBias(nullptr)
+ , m_ForgetGateBias(nullptr)
+ , m_CellBias(nullptr)
+ , m_OutputGateBias(nullptr)
+ , m_ProjectionWeights(nullptr)
+ , m_ProjectionBias(nullptr)
+ , m_InputLayerNormWeights(nullptr)
+ , m_ForgetLayerNormWeights(nullptr)
+ , m_CellLayerNormWeights(nullptr)
+ , m_OutputLayerNormWeights(nullptr)
+ {
+ }
+
+ const ConstTensorHandle* m_InputToInputWeights;
+ const ConstTensorHandle* m_InputToForgetWeights;
+ const ConstTensorHandle* m_InputToCellWeights;
+ const ConstTensorHandle* m_InputToOutputWeights;
+ const ConstTensorHandle* m_RecurrentToInputWeights;
+ const ConstTensorHandle* m_RecurrentToForgetWeights;
+ const ConstTensorHandle* m_RecurrentToCellWeights;
+ const ConstTensorHandle* m_RecurrentToOutputWeights;
+ const ConstTensorHandle* m_CellToInputWeights;
+ const ConstTensorHandle* m_CellToForgetWeights;
+ const ConstTensorHandle* m_CellToOutputWeights;
+ const ConstTensorHandle* m_InputGateBias;
+ const ConstTensorHandle* m_ForgetGateBias;
+ const ConstTensorHandle* m_CellBias;
+ const ConstTensorHandle* m_OutputGateBias;
+ const ConstTensorHandle* m_ProjectionWeights;
+ const ConstTensorHandle* m_ProjectionBias;
+ const ConstTensorHandle* m_InputLayerNormWeights;
+ const ConstTensorHandle* m_ForgetLayerNormWeights;
+ const ConstTensorHandle* m_CellLayerNormWeights;
+ const ConstTensorHandle* m_OutputLayerNormWeights;
+
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
} // namespace armnn