aboutsummaryrefslogtreecommitdiff
path: root/src/armnnSerializer/Serializer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnSerializer/Serializer.cpp')
-rw-r--r--src/armnnSerializer/Serializer.cpp24
1 files changed, 22 insertions, 2 deletions
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index b59bac6041..05df2c942a 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -402,7 +402,8 @@ void SerializerVisitor::VisitLstmLayer(const armnn::IConnectableLayer* layer, co
descriptor.m_ClippingThresProj,
descriptor.m_CifgEnabled,
descriptor.m_PeepholeEnabled,
- descriptor.m_ProjectionEnabled);
+ descriptor.m_ProjectionEnabled,
+ descriptor.m_LayerNormEnabled);
// Get mandatory input parameters
auto inputToForgetWeights = CreateConstTensorInfo(*params.m_InputToForgetWeights);
@@ -424,6 +425,10 @@ void SerializerVisitor::VisitLstmLayer(const armnn::IConnectableLayer* layer, co
flatbuffers::Offset<serializer::ConstTensor> projectionBias;
flatbuffers::Offset<serializer::ConstTensor> cellToForgetWeights;
flatbuffers::Offset<serializer::ConstTensor> cellToOutputWeights;
+ flatbuffers::Offset<serializer::ConstTensor> inputLayerNormWeights;
+ flatbuffers::Offset<serializer::ConstTensor> forgetLayerNormWeights;
+ flatbuffers::Offset<serializer::ConstTensor> cellLayerNormWeights;
+ flatbuffers::Offset<serializer::ConstTensor> outputLayerNormWeights;
if (!descriptor.m_CifgEnabled)
{
@@ -445,6 +450,17 @@ void SerializerVisitor::VisitLstmLayer(const armnn::IConnectableLayer* layer, co
cellToOutputWeights = CreateConstTensorInfo(*params.m_CellToOutputWeights);
}
+ if (descriptor.m_LayerNormEnabled)
+ {
+ if (!descriptor.m_CifgEnabled)
+ {
+ inputLayerNormWeights = CreateConstTensorInfo((*params.m_InputLayerNormWeights));
+ }
+ forgetLayerNormWeights = CreateConstTensorInfo(*params.m_ForgetLayerNormWeights);
+ cellLayerNormWeights = CreateConstTensorInfo(*params.m_CellLayerNormWeights);
+ outputLayerNormWeights = CreateConstTensorInfo(*params.m_OutputLayerNormWeights);
+ }
+
auto fbLstmParams = serializer::CreateLstmInputParams(
m_flatBufferBuilder,
inputToForgetWeights,
@@ -463,7 +479,11 @@ void SerializerVisitor::VisitLstmLayer(const armnn::IConnectableLayer* layer, co
projectionWeights,
projectionBias,
cellToForgetWeights,
- cellToOutputWeights);
+ cellToOutputWeights,
+ inputLayerNormWeights,
+ forgetLayerNormWeights,
+ cellLayerNormWeights,
+ outputLayerNormWeights);
auto fbLstmLayer = serializer::CreateLstmLayer(
m_flatBufferBuilder,