diff options
author | Mike Kelly <mike.kelly@arm.com> | 2022-01-28 16:18:54 +0000 |
---|---|---|
committer | mike.kelly <mike.kelly@arm.com> | 2022-04-22 15:46:50 +0000 |
commit | 5880b911bf4b7fd8308c93e299d77ac78f282c19 (patch) | |
tree | b256346d6fc78e78735cc50ec822286f809dd37f /include/armnn/Descriptors.hpp | |
parent | 4dae5794644b44be8c93bc6db553a205551bc077 (diff) | |
download | armnn-5880b911bf4b7fd8308c93e299d77ac78f282c19.tar.gz |
MLCE-604 Add Unidirectional Sequence Lstm support to TFLite
* Added Unidirectional Sequence Lstm support to TFLite Parser
* Added support for float operations with int8 weights to TFLite Parser
* Added to Conv2d, Conv3D, DepthwiseConv2D, FullyConnected,
TransposeConv and UnidirectionalSequenceLstm
* Renamed subgraphIndex to subgraph to fix name-shadowing warning.
Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: I818976ab88abc05dcb4bad246fb4108e6e879283
Diffstat (limited to 'include/armnn/Descriptors.hpp')
-rw-r--r-- | include/armnn/Descriptors.hpp | 38 |
1 files changed, 31 insertions, 7 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp index 280c18e78c..4c2242e1ad 100644 --- a/include/armnn/Descriptors.hpp +++ b/include/armnn/Descriptors.hpp @@ -1086,17 +1086,29 @@ struct LstmDescriptor : BaseDescriptor , m_ProjectionEnabled(false) , m_LayerNormEnabled(false) , m_TimeMajor(false) + , m_InputIntermediateScale(0.0) + , m_ForgetIntermediateScale(0.0) + , m_CellIntermediateScale(0.0) + , m_OutputIntermediateScale(0.0) + , m_HiddenStateZeroPoint(0) + , m_HiddenStateScale(0.0) {} bool operator ==(const LstmDescriptor& rhs) const { - return m_ActivationFunc == rhs.m_ActivationFunc && - m_ClippingThresCell == rhs.m_ClippingThresCell && - m_ClippingThresProj == rhs.m_ClippingThresProj && - m_CifgEnabled == rhs.m_CifgEnabled && - m_PeepholeEnabled == rhs.m_PeepholeEnabled && - m_LayerNormEnabled == rhs.m_LayerNormEnabled && - m_TimeMajor == rhs.m_TimeMajor; + return m_ActivationFunc == rhs.m_ActivationFunc && + m_ClippingThresCell == rhs.m_ClippingThresCell && + m_ClippingThresProj == rhs.m_ClippingThresProj && + m_CifgEnabled == rhs.m_CifgEnabled && + m_PeepholeEnabled == rhs.m_PeepholeEnabled && + m_LayerNormEnabled == rhs.m_LayerNormEnabled && + m_TimeMajor == rhs.m_TimeMajor && + m_InputIntermediateScale == rhs.m_InputIntermediateScale && + m_ForgetIntermediateScale == rhs.m_ForgetIntermediateScale && + m_CellIntermediateScale == rhs.m_CellIntermediateScale && + m_OutputIntermediateScale == rhs.m_OutputIntermediateScale && + m_HiddenStateZeroPoint == rhs.m_HiddenStateZeroPoint && + m_HiddenStateScale == rhs.m_HiddenStateScale; } /// @brief The activation function to use. @@ -1116,6 +1128,18 @@ struct LstmDescriptor : BaseDescriptor bool m_LayerNormEnabled; /// Enable/disable time major bool m_TimeMajor; + /// Input intermediate quantization scale + float m_InputIntermediateScale; + /// Forget intermediate quantization scale + float m_ForgetIntermediateScale; + /// Cell intermediate quantization scale + float m_CellIntermediateScale; + /// Output intermediate quantization scale + float m_OutputIntermediateScale; + /// Hidden State zero point + int32_t m_HiddenStateZeroPoint; + /// Hidden State quantization scale + float m_HiddenStateScale; }; using UnidirectionalSequenceLstmDescriptor = LstmDescriptor; |