aboutsummaryrefslogtreecommitdiff
path: root/src/armnnSerializer/Serializer.cpp
diff options
context:
space:
mode:
authorJan Eilers <jan.eilers@arm.com>2019-07-17 11:07:49 +0100
committerJan Eilers <jan.eilers@arm.com>2019-07-23 11:01:29 +0100
commitf8c629701e760e2476582c91e6b7f5e1313dc02a (patch)
tree3c9dc1a673fe99367a7f25d98199d8ab386e8bc2 /src/armnnSerializer/Serializer.cpp
parent07f2121feeeeae36a7e67eeb8a6965df63b848f3 (diff)
downloadarmnn-f8c629701e760e2476582c91e6b7f5e1313dc02a.tar.gz
IVGCVSW-3526 Add layer norm support for lstm serialization
* Adds layer norm support for serialization/deserialization * Adds related unit tests Change-Id: If80b668accc8b0754a93d18ab3a243284cb383d1 Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Diffstat (limited to 'src/armnnSerializer/Serializer.cpp')
-rw-r--r--src/armnnSerializer/Serializer.cpp24
1 files changed, 22 insertions, 2 deletions
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index b59bac6041..05df2c942a 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -402,7 +402,8 @@ void SerializerVisitor::VisitLstmLayer(const armnn::IConnectableLayer* layer, co
descriptor.m_ClippingThresProj,
descriptor.m_CifgEnabled,
descriptor.m_PeepholeEnabled,
- descriptor.m_ProjectionEnabled);
+ descriptor.m_ProjectionEnabled,
+ descriptor.m_LayerNormEnabled);
// Get mandatory input parameters
auto inputToForgetWeights = CreateConstTensorInfo(*params.m_InputToForgetWeights);
@@ -424,6 +425,10 @@ void SerializerVisitor::VisitLstmLayer(const armnn::IConnectableLayer* layer, co
flatbuffers::Offset<serializer::ConstTensor> projectionBias;
flatbuffers::Offset<serializer::ConstTensor> cellToForgetWeights;
flatbuffers::Offset<serializer::ConstTensor> cellToOutputWeights;
+ flatbuffers::Offset<serializer::ConstTensor> inputLayerNormWeights;
+ flatbuffers::Offset<serializer::ConstTensor> forgetLayerNormWeights;
+ flatbuffers::Offset<serializer::ConstTensor> cellLayerNormWeights;
+ flatbuffers::Offset<serializer::ConstTensor> outputLayerNormWeights;
if (!descriptor.m_CifgEnabled)
{
@@ -445,6 +450,17 @@ void SerializerVisitor::VisitLstmLayer(const armnn::IConnectableLayer* layer, co
cellToOutputWeights = CreateConstTensorInfo(*params.m_CellToOutputWeights);
}
+ if (descriptor.m_LayerNormEnabled)
+ {
+ if (!descriptor.m_CifgEnabled)
+ {
+ inputLayerNormWeights = CreateConstTensorInfo((*params.m_InputLayerNormWeights));
+ }
+ forgetLayerNormWeights = CreateConstTensorInfo(*params.m_ForgetLayerNormWeights);
+ cellLayerNormWeights = CreateConstTensorInfo(*params.m_CellLayerNormWeights);
+ outputLayerNormWeights = CreateConstTensorInfo(*params.m_OutputLayerNormWeights);
+ }
+
auto fbLstmParams = serializer::CreateLstmInputParams(
m_flatBufferBuilder,
inputToForgetWeights,
@@ -463,7 +479,11 @@ void SerializerVisitor::VisitLstmLayer(const armnn::IConnectableLayer* layer, co
projectionWeights,
projectionBias,
cellToForgetWeights,
- cellToOutputWeights);
+ cellToOutputWeights,
+ inputLayerNormWeights,
+ forgetLayerNormWeights,
+ cellLayerNormWeights,
+ outputLayerNormWeights);
auto fbLstmLayer = serializer::CreateLstmLayer(
m_flatBufferBuilder,