aboutsummaryrefslogtreecommitdiff
path: root/src/armnnDeserializer/Deserializer.cpp
diff options
context:
space:
mode:
authorJan Eilers <jan.eilers@arm.com>2019-07-17 11:07:49 +0100
committerJan Eilers <jan.eilers@arm.com>2019-07-23 11:01:29 +0100
commitf8c629701e760e2476582c91e6b7f5e1313dc02a (patch)
tree3c9dc1a673fe99367a7f25d98199d8ab386e8bc2 /src/armnnDeserializer/Deserializer.cpp
parent07f2121feeeeae36a7e67eeb8a6965df63b848f3 (diff)
downloadarmnn-f8c629701e760e2476582c91e6b7f5e1313dc02a.tar.gz
IVGCVSW-3526 Add layer norm support for lstm serialization
* Adds layer norm support for serialization/deserialization * Adds related unit tests Change-Id: If80b668accc8b0754a93d18ab3a243284cb383d1 Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Diffstat (limited to 'src/armnnDeserializer/Deserializer.cpp')
-rw-r--r--src/armnnDeserializer/Deserializer.cpp21
1 files changed, 21 insertions, 0 deletions
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp
index 7f1831c989..47ed3a65ed 100644
--- a/src/armnnDeserializer/Deserializer.cpp
+++ b/src/armnnDeserializer/Deserializer.cpp
@@ -2089,6 +2089,7 @@ armnn::LstmDescriptor Deserializer::GetLstmDescriptor(Deserializer::LstmDescript
desc.m_CifgEnabled = lstmDescriptor->cifgEnabled();
desc.m_PeepholeEnabled = lstmDescriptor->peepholeEnabled();
desc.m_ProjectionEnabled = lstmDescriptor->projectionEnabled();
+ desc.m_LayerNormEnabled = lstmDescriptor->layerNormEnabled();
return desc;
}
@@ -2171,6 +2172,26 @@ void Deserializer::ParseLstm(GraphPtr graph, unsigned int layerIndex)
lstmInputParams.m_CellToOutputWeights = &cellToOutputWeights;
}
+ armnn::ConstTensor inputLayerNormWeights;
+ armnn::ConstTensor forgetLayerNormWeights;
+ armnn::ConstTensor cellLayerNormWeights;
+ armnn::ConstTensor outputLayerNormWeights;
+ if (lstmDescriptor.m_LayerNormEnabled)
+ {
+ if (!lstmDescriptor.m_CifgEnabled)
+ {
+ inputLayerNormWeights = ToConstTensor(flatBufferInputParams->inputLayerNormWeights());
+ lstmInputParams.m_InputLayerNormWeights = &inputLayerNormWeights;
+ }
+ forgetLayerNormWeights = ToConstTensor(flatBufferInputParams->forgetLayerNormWeights());
+ cellLayerNormWeights = ToConstTensor(flatBufferInputParams->cellLayerNormWeights());
+ outputLayerNormWeights = ToConstTensor(flatBufferInputParams->outputLayerNormWeights());
+
+ lstmInputParams.m_ForgetLayerNormWeights = &forgetLayerNormWeights;
+ lstmInputParams.m_CellLayerNormWeights = &cellLayerNormWeights;
+ lstmInputParams.m_OutputLayerNormWeights = &outputLayerNormWeights;
+ }
+
IConnectableLayer* layer = m_Network->AddLstmLayer(lstmDescriptor, lstmInputParams, layerName.c_str());
armnn::TensorInfo outputTensorInfo1 = ToTensorInfo(outputs[0]);