diff options
author | James Conroy <james.conroy@arm.com> | 2020-03-20 08:49:33 +0000 |
---|---|---|
committer | James Conroy <james.conroy@arm.com> | 2020-03-20 14:53:44 +0000 |
commit | 586a9aac99312eb9cb304cbbd18cec46b9158e23 (patch) | |
tree | 6d620eae6dcfb920ac04eae43424548dc602a1eb /src/backends/backendsCommon/WorkloadData.hpp | |
parent | c94d3f7107b84b586791aa096f8641e6efa18c90 (diff) | |
download | armnn-586a9aac99312eb9cb304cbbd18cec46b9158e23.tar.gz |
IVGCVSW-4549 Add front end for new QLSTM layer
* Added new layer QLstm (Android R HAL 1.3)
* Made necessary updates to APIs
* Added unit tests
* This layer is functionally equivalent to the
original unquantized LSTM layer with some
additonal quantization features added. Due
to this, original LstmParams are used for
this layer.
Signed-off-by: James Conroy <james.conroy@arm.com>
Change-Id: I5b7f2d2fb6e17e81573b41a31bc55f49ae79608f
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.hpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.hpp | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp index 85bda5469a..448de6a1ee 100644 --- a/src/backends/backendsCommon/WorkloadData.hpp +++ b/src/backends/backendsCommon/WorkloadData.hpp @@ -519,6 +519,58 @@ struct TransposeQueueDescriptor : QueueDescriptorWithParameters<TransposeDescrip void Validate(const WorkloadInfo& workloadInfo) const; }; +struct QLstmQueueDescriptor : QueueDescriptorWithParameters<QLstmDescriptor> +{ + QLstmQueueDescriptor() + : m_InputToInputWeights(nullptr) + , m_InputToForgetWeights(nullptr) + , m_InputToCellWeights(nullptr) + , m_InputToOutputWeights(nullptr) + , m_RecurrentToInputWeights(nullptr) + , m_RecurrentToForgetWeights(nullptr) + , m_RecurrentToCellWeights(nullptr) + , m_RecurrentToOutputWeights(nullptr) + , m_CellToInputWeights(nullptr) + , m_CellToForgetWeights(nullptr) + , m_CellToOutputWeights(nullptr) + , m_InputGateBias(nullptr) + , m_ForgetGateBias(nullptr) + , m_CellBias(nullptr) + , m_OutputGateBias(nullptr) + , m_ProjectionWeights(nullptr) + , m_ProjectionBias(nullptr) + , m_InputLayerNormWeights(nullptr) + , m_ForgetLayerNormWeights(nullptr) + , m_CellLayerNormWeights(nullptr) + , m_OutputLayerNormWeights(nullptr) + { + } + + const ConstCpuTensorHandle* m_InputToInputWeights; + const ConstCpuTensorHandle* m_InputToForgetWeights; + const ConstCpuTensorHandle* m_InputToCellWeights; + const ConstCpuTensorHandle* m_InputToOutputWeights; + const ConstCpuTensorHandle* m_RecurrentToInputWeights; + const ConstCpuTensorHandle* m_RecurrentToForgetWeights; + const ConstCpuTensorHandle* m_RecurrentToCellWeights; + const ConstCpuTensorHandle* m_RecurrentToOutputWeights; + const ConstCpuTensorHandle* m_CellToInputWeights; + const ConstCpuTensorHandle* m_CellToForgetWeights; + const ConstCpuTensorHandle* m_CellToOutputWeights; + const ConstCpuTensorHandle* m_InputGateBias; + const ConstCpuTensorHandle* m_ForgetGateBias; + const ConstCpuTensorHandle* m_CellBias; + const ConstCpuTensorHandle* m_OutputGateBias; + const ConstCpuTensorHandle* m_ProjectionWeights; + const ConstCpuTensorHandle* m_ProjectionBias; + const ConstCpuTensorHandle* m_InputLayerNormWeights; + const ConstCpuTensorHandle* m_ForgetLayerNormWeights; + const ConstCpuTensorHandle* m_CellLayerNormWeights; + const ConstCpuTensorHandle* m_OutputLayerNormWeights; + + void Validate(const WorkloadInfo& workloadInfo) const; +}; + struct QuantizedLstmQueueDescriptor : QueueDescriptor { QuantizedLstmQueueDescriptor() |