From a0162e17c56538ee6d72ecce4c3e0836cbb34c56 Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Fri, 23 Jul 2021 14:47:49 +0100 Subject: MLCE-530 Add Serializer and Deserializer for UnidirectionalSequenceLstm Signed-off-by: Narumol Prangnawarat Change-Id: Ic1c56a57941ebede19ab8b9032e7f9df1221be7a --- src/armnnDeserializer/Deserializer.cpp | 133 +++++++++++++++++++++++++++++++++ 1 file changed, 133 insertions(+) (limited to 'src/armnnDeserializer/Deserializer.cpp') diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp index af6ff842a7..2d9194a350 100644 --- a/src/armnnDeserializer/Deserializer.cpp +++ b/src/armnnDeserializer/Deserializer.cpp @@ -270,6 +270,7 @@ m_ParserFunctions(Layer_MAX+1, &IDeserializer::DeserializerImpl::ParseUnsupporte m_ParserFunctions[Layer_SwitchLayer] = &DeserializerImpl::ParseSwitch; m_ParserFunctions[Layer_TransposeConvolution2dLayer] = &DeserializerImpl::ParseTransposeConvolution2d; m_ParserFunctions[Layer_TransposeLayer] = &DeserializerImpl::ParseTranspose; + m_ParserFunctions[Layer_UnidirectionalSequenceLstmLayer] = &DeserializerImpl::ParseUnidirectionalSequenceLstm; } LayerBaseRawPtr IDeserializer::DeserializerImpl::GetBaseLayer(const GraphPtr& graphPtr, unsigned int layerIndex) @@ -404,6 +405,8 @@ LayerBaseRawPtr IDeserializer::DeserializerImpl::GetBaseLayer(const GraphPtr& gr return graphPtr->layers()->Get(layerIndex)->layer_as_TransposeConvolution2dLayer()->base(); case Layer::Layer_TransposeLayer: return graphPtr->layers()->Get(layerIndex)->layer_as_TransposeLayer()->base(); + case Layer::Layer_UnidirectionalSequenceLstmLayer: + return graphPtr->layers()->Get(layerIndex)->layer_as_UnidirectionalSequenceLstmLayer()->base(); case Layer::Layer_NONE: default: throw ParseException(fmt::format("Layer type {} not recognized", layerType)); @@ -3325,4 +3328,134 @@ void IDeserializer::DeserializerImpl::ParseStandIn(GraphPtr graph, unsigned int RegisterOutputSlots(graph, layerIndex, layer); } 
+/// Copies a serialized UnidirectionalSequenceLstm descriptor into its armnn counterpart.
+///
+/// @param descriptor FlatBuffers descriptor table taken from the serialized layer.
+/// @return A descriptor whose activation function, clipping thresholds and
+///         CIFG / peephole / projection / layer-norm / time-major flags are
+///         copied field by field from @p descriptor.
+/// @throws ParseException if @p descriptor is null.
+armnn::UnidirectionalSequenceLstmDescriptor IDeserializer::DeserializerImpl::GetUnidirectionalSequenceLstmDescriptor(
+    UnidirectionalSequenceLstmDescriptorPtr descriptor)
+{
+    // A missing table would otherwise be dereferenced below; report it as a parse error instead.
+    if (descriptor == nullptr)
+    {
+        throw ParseException("UnidirectionalSequenceLstm descriptor is missing");
+    }
+
+    armnn::UnidirectionalSequenceLstmDescriptor desc;
+
+    desc.m_ActivationFunc = descriptor->activationFunc();
+    desc.m_ClippingThresCell = descriptor->clippingThresCell();
+    desc.m_ClippingThresProj = descriptor->clippingThresProj();
+    desc.m_CifgEnabled = descriptor->cifgEnabled();
+    desc.m_PeepholeEnabled = descriptor->peepholeEnabled();
+    desc.m_ProjectionEnabled = descriptor->projectionEnabled();
+    desc.m_LayerNormEnabled = descriptor->layerNormEnabled();
+    desc.m_TimeMajor = descriptor->timeMajor();
+
+    return desc;
+}
+
+/// Deserializes the UnidirectionalSequenceLstm layer at @p layerIndex and adds it to m_Network.
+///
+/// The layer must have exactly three inputs and one output. The mandatory weights and
+/// biases are always read; the optional ones are read only when the matching descriptor
+/// flag (CIFG disabled, peephole, projection, layer normalisation) asks for them.
+///
+/// @param graph      Serialized graph that owns the layer.
+/// @param layerIndex Index of the layer inside @p graph.
+/// @throws ParseException if the layer carries no descriptor or no input params table.
+///         NOTE(review): CHECK_LAYERS / CHECK_VALID_SIZE / ToConstTensor are defined
+///         elsewhere in this file and presumably raise on invalid input as well - confirm.
+void IDeserializer::DeserializerImpl::ParseUnidirectionalSequenceLstm(GraphPtr graph, unsigned int layerIndex)
+{
+    CHECK_LAYERS(graph, 0, layerIndex);
+
+    auto inputs = GetInputs(graph, layerIndex);
+    CHECK_VALID_SIZE(inputs.size(), 3);
+
+    auto outputs = GetOutputs(graph, layerIndex);
+    CHECK_VALID_SIZE(outputs.size(), 1);
+
+    auto flatBufferLayer = graph->layers()->Get(layerIndex)->layer_as_UnidirectionalSequenceLstmLayer();
+    auto layerName = GetLayerName(graph, layerIndex);
+    auto flatBufferDescriptor = flatBufferLayer->descriptor();
+    auto flatBufferInputParams = flatBufferLayer->inputParams();
+
+    // Throws if the descriptor table is absent.
+    auto descriptor = GetUnidirectionalSequenceLstmDescriptor(flatBufferDescriptor);
+
+    // Every weight below is fetched through this table, so reject a layer that lacks it.
+    if (flatBufferInputParams == nullptr)
+    {
+        throw ParseException(fmt::format("UnidirectionalSequenceLstm layer {} has no input params", layerIndex));
+    }
+
+    // lstmInputParams only stores the addresses of the ConstTensor locals declared in this
+    // function, so all of them live at function scope and stay alive until the layer is added.
+    armnn::LstmInputParams lstmInputParams;
+
+    // Mandatory parameters.
+    armnn::ConstTensor inputToForgetWeights = ToConstTensor(flatBufferInputParams->inputToForgetWeights());
+    armnn::ConstTensor inputToCellWeights = ToConstTensor(flatBufferInputParams->inputToCellWeights());
+    armnn::ConstTensor inputToOutputWeights = ToConstTensor(flatBufferInputParams->inputToOutputWeights());
+    armnn::ConstTensor recurrentToForgetWeights = ToConstTensor(flatBufferInputParams->recurrentToForgetWeights());
+    armnn::ConstTensor recurrentToCellWeights = ToConstTensor(flatBufferInputParams->recurrentToCellWeights());
+    armnn::ConstTensor recurrentToOutputWeights = ToConstTensor(flatBufferInputParams->recurrentToOutputWeights());
+    armnn::ConstTensor forgetGateBias = ToConstTensor(flatBufferInputParams->forgetGateBias());
+    armnn::ConstTensor cellBias = ToConstTensor(flatBufferInputParams->cellBias());
+    armnn::ConstTensor outputGateBias = ToConstTensor(flatBufferInputParams->outputGateBias());
+
+    lstmInputParams.m_InputToForgetWeights = &inputToForgetWeights;
+    lstmInputParams.m_InputToCellWeights = &inputToCellWeights;
+    lstmInputParams.m_InputToOutputWeights = &inputToOutputWeights;
+    lstmInputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
+    lstmInputParams.m_RecurrentToCellWeights = &recurrentToCellWeights;
+    lstmInputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
+    lstmInputParams.m_ForgetGateBias = &forgetGateBias;
+    lstmInputParams.m_CellBias = &cellBias;
+    lstmInputParams.m_OutputGateBias = &outputGateBias;
+
+    // Input-gate parameters are read only when CIFG is disabled.
+    armnn::ConstTensor inputToInputWeights;
+    armnn::ConstTensor recurrentToInputWeights;
+    armnn::ConstTensor cellToInputWeights;
+    armnn::ConstTensor inputGateBias;
+    if (!descriptor.m_CifgEnabled)
+    {
+        inputToInputWeights = ToConstTensor(flatBufferInputParams->inputToInputWeights());
+        recurrentToInputWeights = ToConstTensor(flatBufferInputParams->recurrentToInputWeights());
+        inputGateBias = ToConstTensor(flatBufferInputParams->inputGateBias());
+
+        lstmInputParams.m_InputToInputWeights = &inputToInputWeights;
+        lstmInputParams.m_RecurrentToInputWeights = &recurrentToInputWeights;
+        lstmInputParams.m_InputGateBias = &inputGateBias;
+
+        // The cell-to-input peephole weight needs both the input gate and peephole enabled.
+        if (descriptor.m_PeepholeEnabled)
+        {
+            cellToInputWeights = ToConstTensor(flatBufferInputParams->cellToInputWeights());
+            lstmInputParams.m_CellToInputWeights = &cellToInputWeights;
+        }
+    }
+
+    // Projection parameters.
+    armnn::ConstTensor projectionWeights;
+    armnn::ConstTensor projectionBias;
+    if (descriptor.m_ProjectionEnabled)
+    {
+        projectionWeights = ToConstTensor(flatBufferInputParams->projectionWeights());
+        projectionBias = ToConstTensor(flatBufferInputParams->projectionBias());
+
+        lstmInputParams.m_ProjectionWeights = &projectionWeights;
+        lstmInputParams.m_ProjectionBias = &projectionBias;
+    }
+
+    // Remaining peephole parameters.
+    armnn::ConstTensor cellToForgetWeights;
+    armnn::ConstTensor cellToOutputWeights;
+    if (descriptor.m_PeepholeEnabled)
+    {
+        cellToForgetWeights = ToConstTensor(flatBufferInputParams->cellToForgetWeights());
+        cellToOutputWeights = ToConstTensor(flatBufferInputParams->cellToOutputWeights());
+
+        lstmInputParams.m_CellToForgetWeights = &cellToForgetWeights;
+        lstmInputParams.m_CellToOutputWeights = &cellToOutputWeights;
+    }
+
+    // Layer-normalisation parameters; the input-gate one is skipped when CIFG is enabled.
+    armnn::ConstTensor inputLayerNormWeights;
+    armnn::ConstTensor forgetLayerNormWeights;
+    armnn::ConstTensor cellLayerNormWeights;
+    armnn::ConstTensor outputLayerNormWeights;
+    if (descriptor.m_LayerNormEnabled)
+    {
+        if (!descriptor.m_CifgEnabled)
+        {
+            inputLayerNormWeights = ToConstTensor(flatBufferInputParams->inputLayerNormWeights());
+            lstmInputParams.m_InputLayerNormWeights = &inputLayerNormWeights;
+        }
+        forgetLayerNormWeights = ToConstTensor(flatBufferInputParams->forgetLayerNormWeights());
+        cellLayerNormWeights = ToConstTensor(flatBufferInputParams->cellLayerNormWeights());
+        outputLayerNormWeights = ToConstTensor(flatBufferInputParams->outputLayerNormWeights());
+
+        lstmInputParams.m_ForgetLayerNormWeights = &forgetLayerNormWeights;
+        lstmInputParams.m_CellLayerNormWeights = &cellLayerNormWeights;
+        lstmInputParams.m_OutputLayerNormWeights = &outputLayerNormWeights;
+    }
+
+    IConnectableLayer* layer = m_Network->AddUnidirectionalSequenceLstmLayer(descriptor,
+                                                                             lstmInputParams,
+                                                                             layerName.c_str());
+
+    // The layer has a single output (checked above); give slot 0 its tensor info.
+    armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
+    layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+    RegisterInputSlots(graph, layerIndex, layer);
+    RegisterOutputSlots(graph, layerIndex, layer);
+}
+
 } // namespace armnnDeserializer
-- 
cgit v1.2.1