aboutsummaryrefslogtreecommitdiff
path: root/include/armnn/Descriptors.hpp
diff options
context:
space:
mode:
authorNarumol Prangnawarat <narumol.prangnawarat@arm.com>2021-07-15 16:16:25 +0100
committerNarumol Prangnawarat <narumol.prangnawarat@arm.com>2021-07-22 18:29:55 +0100
commit8ed39ae450a077c7e4d672b5f05ff1d68ee67aab (patch)
tree31a1cf006e50db54f3e7a605825c8e9e3f9d689e /include/armnn/Descriptors.hpp
parent15fcc7ed3163c9d4b1856955271854198c3c2696 (diff)
downloadarmnn-8ed39ae450a077c7e4d672b5f05ff1d68ee67aab.tar.gz
MLCE-530 Add front end support for UnidirectionalSequenceLstm on ArmNN
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com> Change-Id: I57bcbdec3eb0155f41af0fe7d6abf9bac2ec86eb
Diffstat (limited to 'include/armnn/Descriptors.hpp')
-rw-r--r--include/armnn/Descriptors.hpp8
1 files changed, 7 insertions, 1 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index 683ef7ac98..bcee902d75 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -926,6 +926,7 @@ struct LstmDescriptor : BaseDescriptor
, m_PeepholeEnabled(false)
, m_ProjectionEnabled(false)
, m_LayerNormEnabled(false)
+ , m_TimeMajor(true)
{}
bool operator ==(const LstmDescriptor& rhs) const
@@ -935,7 +936,8 @@ struct LstmDescriptor : BaseDescriptor
m_ClippingThresProj == rhs.m_ClippingThresProj &&
m_CifgEnabled == rhs.m_CifgEnabled &&
m_PeepholeEnabled == rhs.m_PeepholeEnabled &&
- m_LayerNormEnabled == rhs.m_LayerNormEnabled;
+ m_LayerNormEnabled == rhs.m_LayerNormEnabled &&
+ m_TimeMajor == rhs.m_TimeMajor;
}
/// @brief The activation function to use.
@@ -953,8 +955,12 @@ struct LstmDescriptor : BaseDescriptor
bool m_ProjectionEnabled;
/// Enable/disable layer normalization
bool m_LayerNormEnabled;
+ /// Enable/disable time major
+ bool m_TimeMajor;
};
+using UnidirectionalSequenceLstmDescriptor = LstmDescriptor;
+
/// A MeanDescriptor for the MeanLayer.
struct MeanDescriptor : BaseDescriptor
{