diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.hpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.hpp | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp index 85bda5469a..448de6a1ee 100644 --- a/src/backends/backendsCommon/WorkloadData.hpp +++ b/src/backends/backendsCommon/WorkloadData.hpp @@ -519,6 +519,58 @@ struct TransposeQueueDescriptor : QueueDescriptorWithParameters<TransposeDescrip void Validate(const WorkloadInfo& workloadInfo) const; }; +struct QLstmQueueDescriptor : QueueDescriptorWithParameters<QLstmDescriptor> +{ + QLstmQueueDescriptor() + : m_InputToInputWeights(nullptr) + , m_InputToForgetWeights(nullptr) + , m_InputToCellWeights(nullptr) + , m_InputToOutputWeights(nullptr) + , m_RecurrentToInputWeights(nullptr) + , m_RecurrentToForgetWeights(nullptr) + , m_RecurrentToCellWeights(nullptr) + , m_RecurrentToOutputWeights(nullptr) + , m_CellToInputWeights(nullptr) + , m_CellToForgetWeights(nullptr) + , m_CellToOutputWeights(nullptr) + , m_InputGateBias(nullptr) + , m_ForgetGateBias(nullptr) + , m_CellBias(nullptr) + , m_OutputGateBias(nullptr) + , m_ProjectionWeights(nullptr) + , m_ProjectionBias(nullptr) + , m_InputLayerNormWeights(nullptr) + , m_ForgetLayerNormWeights(nullptr) + , m_CellLayerNormWeights(nullptr) + , m_OutputLayerNormWeights(nullptr) + { + } + + const ConstCpuTensorHandle* m_InputToInputWeights; + const ConstCpuTensorHandle* m_InputToForgetWeights; + const ConstCpuTensorHandle* m_InputToCellWeights; + const ConstCpuTensorHandle* m_InputToOutputWeights; + const ConstCpuTensorHandle* m_RecurrentToInputWeights; + const ConstCpuTensorHandle* m_RecurrentToForgetWeights; + const ConstCpuTensorHandle* m_RecurrentToCellWeights; + const ConstCpuTensorHandle* m_RecurrentToOutputWeights; + const ConstCpuTensorHandle* m_CellToInputWeights; + const ConstCpuTensorHandle* m_CellToForgetWeights; + const ConstCpuTensorHandle* m_CellToOutputWeights; + const ConstCpuTensorHandle* m_InputGateBias; + const ConstCpuTensorHandle* m_ForgetGateBias; + const ConstCpuTensorHandle* m_CellBias; + const ConstCpuTensorHandle* m_OutputGateBias; + const ConstCpuTensorHandle* m_ProjectionWeights; + const ConstCpuTensorHandle* m_ProjectionBias; + const ConstCpuTensorHandle* m_InputLayerNormWeights; + const ConstCpuTensorHandle* m_ForgetLayerNormWeights; + const ConstCpuTensorHandle* m_CellLayerNormWeights; + const ConstCpuTensorHandle* m_OutputLayerNormWeights; + + void Validate(const WorkloadInfo& workloadInfo) const; +}; + struct QuantizedLstmQueueDescriptor : QueueDescriptor { QuantizedLstmQueueDescriptor() |