diff options
author | Jim Flynn <jim.flynn@arm.com> | 2019-03-19 17:22:29 +0000 |
---|---|---|
committer | Jim Flynn <jim.flynn@arm.com> | 2019-03-21 16:09:19 +0000 |
commit | 11af375a5a6bf88b4f3b933a86d53000b0d91ed0 (patch) | |
tree | f4f4db5192b275be44d96d96c7f3c8c10f15b3f1 /src/armnnDeserializer/Deserializer.cpp | |
parent | db059fd50f9afb398b8b12cd4592323fc8f60d7f (diff) | |
download | armnn-11af375a5a6bf88b4f3b933a86d53000b0d91ed0.tar.gz |
IVGCVSW-2694: serialize/deserialize LSTM
* added serialize/deserialize methods for LSTM and tests
Change-Id: Ic59557f03001c496008c4bef92c2e0406e1fbc6c
Signed-off-by: Nina Drozd <nina.drozd@arm.com>
Signed-off-by: Jim Flynn <jim.flynn@arm.com>
Diffstat (limited to 'src/armnnDeserializer/Deserializer.cpp')
-rw-r--r-- | src/armnnDeserializer/Deserializer.cpp | 113 |
1 files changed, 113 insertions, 0 deletions
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp index 152a5b4c93..d64bed7409 100644 --- a/src/armnnDeserializer/Deserializer.cpp +++ b/src/armnnDeserializer/Deserializer.cpp @@ -201,6 +201,7 @@ m_ParserFunctions(Layer_MAX+1, &Deserializer::ParseUnsupportedLayer) m_ParserFunctions[Layer_GatherLayer] = &Deserializer::ParseGather; m_ParserFunctions[Layer_GreaterLayer] = &Deserializer::ParseGreater; m_ParserFunctions[Layer_L2NormalizationLayer] = &Deserializer::ParseL2Normalization; + m_ParserFunctions[Layer_LstmLayer] = &Deserializer::ParseLstm; m_ParserFunctions[Layer_MaximumLayer] = &Deserializer::ParseMaximum; m_ParserFunctions[Layer_MeanLayer] = &Deserializer::ParseMean; m_ParserFunctions[Layer_MinimumLayer] = &Deserializer::ParseMinimum; @@ -258,6 +259,8 @@ Deserializer::LayerBaseRawPtr Deserializer::GetBaseLayer(const GraphPtr& graphPt return graphPtr->layers()->Get(layerIndex)->layer_as_InputLayer()->base()->base(); case Layer::Layer_L2NormalizationLayer: return graphPtr->layers()->Get(layerIndex)->layer_as_L2NormalizationLayer()->base(); + case Layer::Layer_LstmLayer: + return graphPtr->layers()->Get(layerIndex)->layer_as_LstmLayer()->base(); case Layer::Layer_MeanLayer: return graphPtr->layers()->Get(layerIndex)->layer_as_MeanLayer()->base(); case Layer::Layer_MinimumLayer: @@ -1927,4 +1930,114 @@ void Deserializer::ParseSplitter(GraphPtr graph, unsigned int layerIndex) RegisterOutputSlots(graph, layerIndex, layer); } +armnn::LstmDescriptor Deserializer::GetLstmDescriptor(Deserializer::LstmDescriptorPtr lstmDescriptor) +{ + armnn::LstmDescriptor desc; + + desc.m_ActivationFunc = lstmDescriptor->activationFunc(); + desc.m_ClippingThresCell = lstmDescriptor->clippingThresCell(); + desc.m_ClippingThresProj = lstmDescriptor->clippingThresProj(); + desc.m_CifgEnabled = lstmDescriptor->cifgEnabled(); + desc.m_PeepholeEnabled = lstmDescriptor->peepholeEnabled(); + desc.m_ProjectionEnabled = lstmDescriptor->projectionEnabled(); + + return desc; +} + +void Deserializer::ParseLstm(GraphPtr graph, unsigned int layerIndex) +{ + CHECK_LAYERS(graph, 0, layerIndex); + + auto inputs = GetInputs(graph, layerIndex); + CHECK_VALID_SIZE(inputs.size(), 3); + + auto outputs = GetOutputs(graph, layerIndex); + CHECK_VALID_SIZE(outputs.size(), 4); + + auto flatBufferLayer = graph->layers()->Get(layerIndex)->layer_as_LstmLayer(); + auto layerName = GetLayerName(graph, layerIndex); + auto flatBufferDescriptor = flatBufferLayer->descriptor(); + auto flatBufferInputParams = flatBufferLayer->inputParams(); + + auto lstmDescriptor = GetLstmDescriptor(flatBufferDescriptor); + + armnn::LstmInputParams lstmInputParams; + + armnn::ConstTensor inputToForgetWeights = ToConstTensor(flatBufferInputParams->inputToForgetWeights()); + armnn::ConstTensor inputToCellWeights = ToConstTensor(flatBufferInputParams->inputToCellWeights()); + armnn::ConstTensor inputToOutputWeights = ToConstTensor(flatBufferInputParams->inputToOutputWeights()); + armnn::ConstTensor recurrentToForgetWeights = ToConstTensor(flatBufferInputParams->recurrentToForgetWeights()); + armnn::ConstTensor recurrentToCellWeights = ToConstTensor(flatBufferInputParams->recurrentToCellWeights()); + armnn::ConstTensor recurrentToOutputWeights = ToConstTensor(flatBufferInputParams->recurrentToOutputWeights()); + armnn::ConstTensor forgetGateBias = ToConstTensor(flatBufferInputParams->forgetGateBias()); + armnn::ConstTensor cellBias = ToConstTensor(flatBufferInputParams->cellBias()); + armnn::ConstTensor outputGateBias = ToConstTensor(flatBufferInputParams->outputGateBias()); + + lstmInputParams.m_InputToForgetWeights = &inputToForgetWeights; + lstmInputParams.m_InputToCellWeights = &inputToCellWeights; + lstmInputParams.m_InputToOutputWeights = &inputToOutputWeights; + lstmInputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeights; + lstmInputParams.m_RecurrentToCellWeights = &recurrentToCellWeights; + lstmInputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeights; + lstmInputParams.m_ForgetGateBias = &forgetGateBias; + lstmInputParams.m_CellBias = &cellBias; + lstmInputParams.m_OutputGateBias = &outputGateBias; + + armnn::ConstTensor inputToInputWeights; + armnn::ConstTensor recurrentToInputWeights; + armnn::ConstTensor cellToInputWeights; + armnn::ConstTensor inputGateBias; + if (!lstmDescriptor.m_CifgEnabled) + { + inputToInputWeights = ToConstTensor(flatBufferInputParams->inputToInputWeights()); + recurrentToInputWeights = ToConstTensor(flatBufferInputParams->recurrentToInputWeights()); + cellToInputWeights = ToConstTensor(flatBufferInputParams->cellToInputWeights()); + inputGateBias = ToConstTensor(flatBufferInputParams->inputGateBias()); + + lstmInputParams.m_InputToInputWeights = &inputToInputWeights; + lstmInputParams.m_RecurrentToInputWeights = &recurrentToInputWeights; + lstmInputParams.m_CellToInputWeights = &cellToInputWeights; + lstmInputParams.m_InputGateBias = &inputGateBias; + } + + armnn::ConstTensor projectionWeights; + armnn::ConstTensor projectionBias; + if (lstmDescriptor.m_ProjectionEnabled) + { + projectionWeights = ToConstTensor(flatBufferInputParams->projectionWeights()); + projectionBias = ToConstTensor(flatBufferInputParams->projectionBias()); + + lstmInputParams.m_ProjectionWeights = &projectionWeights; + lstmInputParams.m_ProjectionBias = &projectionBias; + } + + armnn::ConstTensor cellToForgetWeights; + armnn::ConstTensor cellToOutputWeights; + if (lstmDescriptor.m_PeepholeEnabled) + { + cellToForgetWeights = ToConstTensor(flatBufferInputParams->cellToForgetWeights()); + cellToOutputWeights = ToConstTensor(flatBufferInputParams->cellToOutputWeights()); + + lstmInputParams.m_CellToForgetWeights = &cellToForgetWeights; + lstmInputParams.m_CellToOutputWeights = &cellToOutputWeights; + } + + IConnectableLayer* layer = m_Network->AddLstmLayer(lstmDescriptor, lstmInputParams, layerName.c_str()); + + armnn::TensorInfo outputTensorInfo1 = ToTensorInfo(outputs[0]); + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo1); + + armnn::TensorInfo outputTensorInfo2 = ToTensorInfo(outputs[1]); + layer->GetOutputSlot(1).SetTensorInfo(outputTensorInfo2); + + armnn::TensorInfo outputTensorInfo3 = ToTensorInfo(outputs[2]); + layer->GetOutputSlot(2).SetTensorInfo(outputTensorInfo3); + + armnn::TensorInfo outputTensorInfo4 = ToTensorInfo(outputs[3]); + layer->GetOutputSlot(3).SetTensorInfo(outputTensorInfo4); + + RegisterInputSlots(graph, layerIndex, layer); + RegisterOutputSlots(graph, layerIndex, layer); +} + } // namespace armnnDeserializer |