diff options
author | Jan Eilers <jan.eilers@arm.com> | 2019-07-17 11:07:49 +0100 |
---|---|---|
committer | Jan Eilers <jan.eilers@arm.com> | 2019-07-23 11:01:29 +0100 |
commit | f8c629701e760e2476582c91e6b7f5e1313dc02a (patch) | |
tree | 3c9dc1a673fe99367a7f25d98199d8ab386e8bc2 /src/armnnSerializer/Serializer.cpp | |
parent | 07f2121feeeeae36a7e67eeb8a6965df63b848f3 (diff) | |
download | armnn-f8c629701e760e2476582c91e6b7f5e1313dc02a.tar.gz |
IVGCVSW-3526 Add layer norm support for lstm serialization
* Adds layer norm support for serialization/deserialization
* Adds related unit tests
Change-Id: If80b668accc8b0754a93d18ab3a243284cb383d1
Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Diffstat (limited to 'src/armnnSerializer/Serializer.cpp')
-rw-r--r-- | src/armnnSerializer/Serializer.cpp | 24 |
1 files changed, 22 insertions, 2 deletions
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index b59bac6041..05df2c942a 100644 --- a/src/armnnSerializer/Serializer.cpp +++ b/src/armnnSerializer/Serializer.cpp @@ -402,7 +402,8 @@ void SerializerVisitor::VisitLstmLayer(const armnn::IConnectableLayer* layer, co descriptor.m_ClippingThresProj, descriptor.m_CifgEnabled, descriptor.m_PeepholeEnabled, - descriptor.m_ProjectionEnabled); + descriptor.m_ProjectionEnabled, + descriptor.m_LayerNormEnabled); // Get mandatory input parameters auto inputToForgetWeights = CreateConstTensorInfo(*params.m_InputToForgetWeights); @@ -424,6 +425,10 @@ void SerializerVisitor::VisitLstmLayer(const armnn::IConnectableLayer* layer, co flatbuffers::Offset<serializer::ConstTensor> projectionBias; flatbuffers::Offset<serializer::ConstTensor> cellToForgetWeights; flatbuffers::Offset<serializer::ConstTensor> cellToOutputWeights; + flatbuffers::Offset<serializer::ConstTensor> inputLayerNormWeights; + flatbuffers::Offset<serializer::ConstTensor> forgetLayerNormWeights; + flatbuffers::Offset<serializer::ConstTensor> cellLayerNormWeights; + flatbuffers::Offset<serializer::ConstTensor> outputLayerNormWeights; if (!descriptor.m_CifgEnabled) { @@ -445,6 +450,17 @@ void SerializerVisitor::VisitLstmLayer(const armnn::IConnectableLayer* layer, co cellToOutputWeights = CreateConstTensorInfo(*params.m_CellToOutputWeights); } + if (descriptor.m_LayerNormEnabled) + { + if (!descriptor.m_CifgEnabled) + { + inputLayerNormWeights = CreateConstTensorInfo((*params.m_InputLayerNormWeights)); + } + forgetLayerNormWeights = CreateConstTensorInfo(*params.m_ForgetLayerNormWeights); + cellLayerNormWeights = CreateConstTensorInfo(*params.m_CellLayerNormWeights); + outputLayerNormWeights = CreateConstTensorInfo(*params.m_OutputLayerNormWeights); + } + auto fbLstmParams = serializer::CreateLstmInputParams( m_flatBufferBuilder, inputToForgetWeights, @@ -463,7 +479,11 @@ void SerializerVisitor::VisitLstmLayer(const armnn::IConnectableLayer* layer, co projectionWeights, projectionBias, cellToForgetWeights, - cellToOutputWeights); + cellToOutputWeights, + inputLayerNormWeights, + forgetLayerNormWeights, + cellLayerNormWeights, + outputLayerNormWeights); auto fbLstmLayer = serializer::CreateLstmLayer( m_flatBufferBuilder, |