diff options
Diffstat (limited to 'src/armnnDeserializer/Deserializer.cpp')
-rw-r--r-- | src/armnnDeserializer/Deserializer.cpp | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp index 7f1831c989..47ed3a65ed 100644 --- a/src/armnnDeserializer/Deserializer.cpp +++ b/src/armnnDeserializer/Deserializer.cpp @@ -2089,6 +2089,7 @@ armnn::LstmDescriptor Deserializer::GetLstmDescriptor(Deserializer::LstmDescript desc.m_CifgEnabled = lstmDescriptor->cifgEnabled(); desc.m_PeepholeEnabled = lstmDescriptor->peepholeEnabled(); desc.m_ProjectionEnabled = lstmDescriptor->projectionEnabled(); + desc.m_LayerNormEnabled = lstmDescriptor->layerNormEnabled(); return desc; } @@ -2171,6 +2172,26 @@ void Deserializer::ParseLstm(GraphPtr graph, unsigned int layerIndex) lstmInputParams.m_CellToOutputWeights = &cellToOutputWeights; } + armnn::ConstTensor inputLayerNormWeights; + armnn::ConstTensor forgetLayerNormWeights; + armnn::ConstTensor cellLayerNormWeights; + armnn::ConstTensor outputLayerNormWeights; + if (lstmDescriptor.m_LayerNormEnabled) + { + if (!lstmDescriptor.m_CifgEnabled) + { + inputLayerNormWeights = ToConstTensor(flatBufferInputParams->inputLayerNormWeights()); + lstmInputParams.m_InputLayerNormWeights = &inputLayerNormWeights; + } + forgetLayerNormWeights = ToConstTensor(flatBufferInputParams->forgetLayerNormWeights()); + cellLayerNormWeights = ToConstTensor(flatBufferInputParams->cellLayerNormWeights()); + outputLayerNormWeights = ToConstTensor(flatBufferInputParams->outputLayerNormWeights()); + + lstmInputParams.m_ForgetLayerNormWeights = &forgetLayerNormWeights; + lstmInputParams.m_CellLayerNormWeights = &cellLayerNormWeights; + lstmInputParams.m_OutputLayerNormWeights = &outputLayerNormWeights; + } + IConnectableLayer* layer = m_Network->AddLstmLayer(lstmDescriptor, lstmInputParams, layerName.c_str()); armnn::TensorInfo outputTensorInfo1 = ToTensorInfo(outputs[0]); |