From 5880b911bf4b7fd8308c93e299d77ac78f282c19 Mon Sep 17 00:00:00 2001 From: Mike Kelly Date: Fri, 28 Jan 2022 16:18:54 +0000 Subject: MLCE-604 Add Unidirectional Sequence Lstm support to TFLite * Added Unidirectional Sequence Lstm support to TFLite Parser * Added support for float operations with int8 weights to TFLite Parser * Added to Conv2d, Conv3D, DepthwiseConv2D, FullyConnected, TransposeConv and UnidirectionalSequenceLstm * Renamed subgraphIndex to subgraph to fix name-shadowing warning. Signed-off-by: Mike Kelly Change-Id: I818976ab88abc05dcb4bad246fb4108e6e879283 --- include/armnn/Descriptors.hpp | 38 +++++++++++++++++++++++++++++++------- 1 file changed, 31 insertions(+), 7 deletions(-) (limited to 'include/armnn') diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp index 280c18e78c..4c2242e1ad 100644 --- a/include/armnn/Descriptors.hpp +++ b/include/armnn/Descriptors.hpp @@ -1086,17 +1086,29 @@ struct LstmDescriptor : BaseDescriptor , m_ProjectionEnabled(false) , m_LayerNormEnabled(false) , m_TimeMajor(false) + , m_InputIntermediateScale(0.0) + , m_ForgetIntermediateScale(0.0) + , m_CellIntermediateScale(0.0) + , m_OutputIntermediateScale(0.0) + , m_HiddenStateZeroPoint(0) + , m_HiddenStateScale(0.0) {} bool operator ==(const LstmDescriptor& rhs) const { - return m_ActivationFunc == rhs.m_ActivationFunc && - m_ClippingThresCell == rhs.m_ClippingThresCell && - m_ClippingThresProj == rhs.m_ClippingThresProj && - m_CifgEnabled == rhs.m_CifgEnabled && - m_PeepholeEnabled == rhs.m_PeepholeEnabled && - m_LayerNormEnabled == rhs.m_LayerNormEnabled && - m_TimeMajor == rhs.m_TimeMajor; + return m_ActivationFunc == rhs.m_ActivationFunc && + m_ClippingThresCell == rhs.m_ClippingThresCell && + m_ClippingThresProj == rhs.m_ClippingThresProj && + m_CifgEnabled == rhs.m_CifgEnabled && + m_PeepholeEnabled == rhs.m_PeepholeEnabled && + m_LayerNormEnabled == rhs.m_LayerNormEnabled && + m_TimeMajor == rhs.m_TimeMajor && + m_InputIntermediateScale == rhs.m_InputIntermediateScale && + m_ForgetIntermediateScale == rhs.m_ForgetIntermediateScale && + m_CellIntermediateScale == rhs.m_CellIntermediateScale && + m_OutputIntermediateScale == rhs.m_OutputIntermediateScale && + m_HiddenStateZeroPoint == rhs.m_HiddenStateZeroPoint && + m_HiddenStateScale == rhs.m_HiddenStateScale; } /// @brief The activation function to use. @@ -1116,6 +1128,18 @@ struct LstmDescriptor : BaseDescriptor bool m_LayerNormEnabled; /// Enable/disable time major bool m_TimeMajor; + /// Input intermediate quantization scale + float m_InputIntermediateScale; + /// Forget intermediate quantization scale + float m_ForgetIntermediateScale; + /// Cell intermediate quantization scale + float m_CellIntermediateScale; + /// Output intermediate quantization scale + float m_OutputIntermediateScale; + /// Hidden State zero point + int32_t m_HiddenStateZeroPoint; + /// Hidden State quantization scale + float m_HiddenStateScale; }; using UnidirectionalSequenceLstmDescriptor = LstmDescriptor; -- cgit v1.2.1