aboutsummaryrefslogtreecommitdiff
path: root/include/armnn/Descriptors.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'include/armnn/Descriptors.hpp')
-rw-r--r--include/armnn/Descriptors.hpp8
1 files changed, 7 insertions, 1 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index 683ef7ac98..bcee902d75 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -926,6 +926,7 @@ struct LstmDescriptor : BaseDescriptor
, m_PeepholeEnabled(false)
, m_ProjectionEnabled(false)
, m_LayerNormEnabled(false)
+ , m_TimeMajor(true)
{}
bool operator ==(const LstmDescriptor& rhs) const
@@ -935,7 +936,8 @@ struct LstmDescriptor : BaseDescriptor
m_ClippingThresProj == rhs.m_ClippingThresProj &&
m_CifgEnabled == rhs.m_CifgEnabled &&
m_PeepholeEnabled == rhs.m_PeepholeEnabled &&
- m_LayerNormEnabled == rhs.m_LayerNormEnabled;
+ m_LayerNormEnabled == rhs.m_LayerNormEnabled &&
+ m_TimeMajor == rhs.m_TimeMajor;
}
/// @brief The activation function to use.
@@ -953,8 +955,12 @@ struct LstmDescriptor : BaseDescriptor
bool m_ProjectionEnabled;
/// Enable/disable layer normalization
bool m_LayerNormEnabled;
+ /// Enable/disable time major
+ bool m_TimeMajor;
};
+using UnidirectionalSequenceLstmDescriptor = LstmDescriptor;
+
/// A MeanDescriptor for the MeanLayer.
struct MeanDescriptor : BaseDescriptor
{