From 8d33318a7ac33d90ed79701ff717de8d9940cc67 Mon Sep 17 00:00:00 2001 From: James Conroy Date: Wed, 13 May 2020 10:27:58 +0100 Subject: IVGCVSW-4777 Add QLstm serialization support * Adds serialization/deserialization for QLstm. * 3 unit tests: basic, layer norm and advanced. Signed-off-by: James Conroy Change-Id: I97d825e06b0d4a1257713cdd71ff06afa10d4380 --- src/armnnDeserializer/Deserializer.cpp | 152 +++++++++++++++++++++++++++++++++ 1 file changed, 152 insertions(+) (limited to 'src/armnnDeserializer/Deserializer.cpp') diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp index 42b0052b03..36beebc1cd 100644 --- a/src/armnnDeserializer/Deserializer.cpp +++ b/src/armnnDeserializer/Deserializer.cpp @@ -222,6 +222,7 @@ m_ParserFunctions(Layer_MAX+1, &Deserializer::ParseUnsupportedLayer) m_ParserFunctions[Layer_PermuteLayer] = &Deserializer::ParsePermute; m_ParserFunctions[Layer_Pooling2dLayer] = &Deserializer::ParsePooling2d; m_ParserFunctions[Layer_PreluLayer] = &Deserializer::ParsePrelu; + m_ParserFunctions[Layer_QLstmLayer] = &Deserializer::ParseQLstm; m_ParserFunctions[Layer_QuantizeLayer] = &Deserializer::ParseQuantize; m_ParserFunctions[Layer_QuantizedLstmLayer] = &Deserializer::ParseQuantizedLstm; m_ParserFunctions[Layer_ReshapeLayer] = &Deserializer::ParseReshape; @@ -322,6 +323,8 @@ Deserializer::LayerBaseRawPtr Deserializer::GetBaseLayer(const GraphPtr& graphPt return graphPtr->layers()->Get(layerIndex)->layer_as_Pooling2dLayer()->base(); case Layer::Layer_PreluLayer: return graphPtr->layers()->Get(layerIndex)->layer_as_PreluLayer()->base(); + case Layer::Layer_QLstmLayer: + return graphPtr->layers()->Get(layerIndex)->layer_as_QLstmLayer()->base(); case Layer::Layer_QuantizeLayer: return graphPtr->layers()->Get(layerIndex)->layer_as_QuantizeLayer()->base(); case Layer::Layer_QuantizedLstmLayer: @@ -2610,6 +2613,155 @@ void Deserializer::ParseLstm(GraphPtr graph, unsigned int layerIndex) RegisterOutputSlots(graph, 
layerIndex, layer); } +armnn::QLstmDescriptor Deserializer::GetQLstmDescriptor(Deserializer::QLstmDescriptorPtr qLstmDescriptor) +{ + armnn::QLstmDescriptor desc; + + desc.m_CifgEnabled = qLstmDescriptor->cifgEnabled(); + desc.m_PeepholeEnabled = qLstmDescriptor->peepholeEnabled(); + desc.m_ProjectionEnabled = qLstmDescriptor->projectionEnabled(); + desc.m_LayerNormEnabled = qLstmDescriptor->layerNormEnabled(); + + desc.m_CellClip = qLstmDescriptor->cellClip(); + desc.m_ProjectionClip = qLstmDescriptor->projectionClip(); + + desc.m_InputIntermediateScale = qLstmDescriptor->inputIntermediateScale(); + desc.m_ForgetIntermediateScale = qLstmDescriptor->forgetIntermediateScale(); + desc.m_CellIntermediateScale = qLstmDescriptor->cellIntermediateScale(); + desc.m_OutputIntermediateScale = qLstmDescriptor->outputIntermediateScale(); + + desc.m_HiddenStateScale = qLstmDescriptor->hiddenStateScale(); + desc.m_HiddenStateZeroPoint = qLstmDescriptor->hiddenStateZeroPoint(); + + return desc; +} + +void Deserializer::ParseQLstm(GraphPtr graph, unsigned int layerIndex) +{ + CHECK_LAYERS(graph, 0, layerIndex); + + auto inputs = GetInputs(graph, layerIndex); + CHECK_VALID_SIZE(inputs.size(), 3); + + auto outputs = GetOutputs(graph, layerIndex); + CHECK_VALID_SIZE(outputs.size(), 3); + + auto flatBufferLayer = graph->layers()->Get(layerIndex)->layer_as_QLstmLayer(); + auto layerName = GetLayerName(graph, layerIndex); + auto flatBufferDescriptor = flatBufferLayer->descriptor(); + auto flatBufferInputParams = flatBufferLayer->inputParams(); + + auto qLstmDescriptor = GetQLstmDescriptor(flatBufferDescriptor); + armnn::LstmInputParams qLstmInputParams; + + // Mandatory params + armnn::ConstTensor inputToForgetWeights = ToConstTensor(flatBufferInputParams->inputToForgetWeights()); + armnn::ConstTensor inputToCellWeights = ToConstTensor(flatBufferInputParams->inputToCellWeights()); + armnn::ConstTensor inputToOutputWeights = ToConstTensor(flatBufferInputParams->inputToOutputWeights()); 
+ armnn::ConstTensor recurrentToForgetWeights = ToConstTensor(flatBufferInputParams->recurrentToForgetWeights()); + armnn::ConstTensor recurrentToCellWeights = ToConstTensor(flatBufferInputParams->recurrentToCellWeights()); + armnn::ConstTensor recurrentToOutputWeights = ToConstTensor(flatBufferInputParams->recurrentToOutputWeights()); + armnn::ConstTensor forgetGateBias = ToConstTensor(flatBufferInputParams->forgetGateBias()); + armnn::ConstTensor cellBias = ToConstTensor(flatBufferInputParams->cellBias()); + armnn::ConstTensor outputGateBias = ToConstTensor(flatBufferInputParams->outputGateBias()); + + qLstmInputParams.m_InputToForgetWeights = &inputToForgetWeights; + qLstmInputParams.m_InputToCellWeights = &inputToCellWeights; + qLstmInputParams.m_InputToOutputWeights = &inputToOutputWeights; + qLstmInputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeights; + qLstmInputParams.m_RecurrentToCellWeights = &recurrentToCellWeights; + qLstmInputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeights; + qLstmInputParams.m_ForgetGateBias = &forgetGateBias; + qLstmInputParams.m_CellBias = &cellBias; + qLstmInputParams.m_OutputGateBias = &outputGateBias; + + // Optional CIFG params + armnn::ConstTensor inputToInputWeights; + armnn::ConstTensor recurrentToInputWeights; + armnn::ConstTensor inputGateBias; + + if (!qLstmDescriptor.m_CifgEnabled) + { + inputToInputWeights = ToConstTensor(flatBufferInputParams->inputToInputWeights()); + recurrentToInputWeights = ToConstTensor(flatBufferInputParams->recurrentToInputWeights()); + inputGateBias = ToConstTensor(flatBufferInputParams->inputGateBias()); + + qLstmInputParams.m_InputToInputWeights = &inputToInputWeights; + qLstmInputParams.m_RecurrentToInputWeights = &recurrentToInputWeights; + qLstmInputParams.m_InputGateBias = &inputGateBias; + } + + // Optional projection params + armnn::ConstTensor projectionWeights; + armnn::ConstTensor projectionBias; + + if (qLstmDescriptor.m_ProjectionEnabled) + { + 
projectionWeights = ToConstTensor(flatBufferInputParams->projectionWeights()); + projectionBias = ToConstTensor(flatBufferInputParams->projectionBias()); + + qLstmInputParams.m_ProjectionWeights = &projectionWeights; + qLstmInputParams.m_ProjectionBias = &projectionBias; + } + + // Optional peephole params + armnn::ConstTensor cellToInputWeights; + armnn::ConstTensor cellToForgetWeights; + armnn::ConstTensor cellToOutputWeights; + + if (qLstmDescriptor.m_PeepholeEnabled) + { + if (!qLstmDescriptor.m_CifgEnabled) + { + cellToInputWeights = ToConstTensor(flatBufferInputParams->cellToInputWeights()); + qLstmInputParams.m_CellToInputWeights = &cellToInputWeights; + } + + cellToForgetWeights = ToConstTensor(flatBufferInputParams->cellToForgetWeights()); + cellToOutputWeights = ToConstTensor(flatBufferInputParams->cellToOutputWeights()); + + qLstmInputParams.m_CellToForgetWeights = &cellToForgetWeights; + qLstmInputParams.m_CellToOutputWeights = &cellToOutputWeights; + } + + // Optional layer norm params + armnn::ConstTensor inputLayerNormWeights; + armnn::ConstTensor forgetLayerNormWeights; + armnn::ConstTensor cellLayerNormWeights; + armnn::ConstTensor outputLayerNormWeights; + + if (qLstmDescriptor.m_LayerNormEnabled) + { + if (!qLstmDescriptor.m_CifgEnabled) + { + inputLayerNormWeights = ToConstTensor(flatBufferInputParams->inputLayerNormWeights()); + qLstmInputParams.m_InputLayerNormWeights = &inputLayerNormWeights; + } + + forgetLayerNormWeights = ToConstTensor(flatBufferInputParams->forgetLayerNormWeights()); + cellLayerNormWeights = ToConstTensor(flatBufferInputParams->cellLayerNormWeights()); + outputLayerNormWeights = ToConstTensor(flatBufferInputParams->outputLayerNormWeights()); + + qLstmInputParams.m_ForgetLayerNormWeights = &forgetLayerNormWeights; + qLstmInputParams.m_CellLayerNormWeights = &cellLayerNormWeights; + qLstmInputParams.m_OutputLayerNormWeights = &outputLayerNormWeights; + } + + IConnectableLayer* layer = 
m_Network->AddQLstmLayer(qLstmDescriptor, qLstmInputParams, layerName.c_str()); + + armnn::TensorInfo outputStateOutInfo = ToTensorInfo(outputs[0]); + layer->GetOutputSlot(0).SetTensorInfo(outputStateOutInfo); + + armnn::TensorInfo cellStateOutInfo = ToTensorInfo(outputs[1]); + layer->GetOutputSlot(1).SetTensorInfo(cellStateOutInfo); + + armnn::TensorInfo outputInfo = ToTensorInfo(outputs[2]); + layer->GetOutputSlot(2).SetTensorInfo(outputInfo); + + RegisterInputSlots(graph, layerIndex, layer); + RegisterOutputSlots(graph, layerIndex, layer); +} + void Deserializer::ParseQuantizedLstm(GraphPtr graph, unsigned int layerIndex) { CHECK_LAYERS(graph, 0, layerIndex); -- cgit v1.2.1