diff options
Diffstat (limited to 'include/armnn/Descriptors.hpp')
-rw-r--r-- | include/armnn/Descriptors.hpp | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp index 683ef7ac98..bcee902d75 100644 --- a/include/armnn/Descriptors.hpp +++ b/include/armnn/Descriptors.hpp @@ -926,6 +926,7 @@ struct LstmDescriptor : BaseDescriptor , m_PeepholeEnabled(false) , m_ProjectionEnabled(false) , m_LayerNormEnabled(false) + , m_TimeMajor(true) {} bool operator ==(const LstmDescriptor& rhs) const @@ -935,7 +936,8 @@ struct LstmDescriptor : BaseDescriptor m_ClippingThresProj == rhs.m_ClippingThresProj && m_CifgEnabled == rhs.m_CifgEnabled && m_PeepholeEnabled == rhs.m_PeepholeEnabled && - m_LayerNormEnabled == rhs.m_LayerNormEnabled; + m_LayerNormEnabled == rhs.m_LayerNormEnabled && + m_TimeMajor == rhs.m_TimeMajor; } /// @brief The activation function to use. @@ -953,8 +955,12 @@ struct LstmDescriptor : BaseDescriptor bool m_ProjectionEnabled; /// Enable/disable layer normalization bool m_LayerNormEnabled; + /// Enable/disable time major + bool m_TimeMajor; }; +using UnidirectionalSequenceLstmDescriptor = LstmDescriptor; + /// A MeanDescriptor for the MeanLayer. struct MeanDescriptor : BaseDescriptor { |