From 11af375a5a6bf88b4f3b933a86d53000b0d91ed0 Mon Sep 17 00:00:00 2001 From: Jim Flynn Date: Tue, 19 Mar 2019 17:22:29 +0000 Subject: IVGCVSW-2694: serialize/deserialize LSTM * added serialize/deserialize methods for LSTM and tests Change-Id: Ic59557f03001c496008c4bef92c2e0406e1fbc6c Signed-off-by: Nina Drozd Signed-off-by: Jim Flynn --- src/armnn/layers/LstmLayer.cpp | 92 ++++--- src/armnnDeserializer/Deserializer.cpp | 113 ++++++++ src/armnnDeserializer/Deserializer.hpp | 6 + src/armnnDeserializer/DeserializerSupport.md | 1 + src/armnnSerializer/ArmnnSchema.fbs | 44 +++- src/armnnSerializer/Serializer.cpp | 84 ++++++ src/armnnSerializer/Serializer.hpp | 5 + src/armnnSerializer/SerializerSupport.md | 1 + src/armnnSerializer/test/SerializerTests.cpp | 375 +++++++++++++++++++++++++++ 9 files changed, 690 insertions(+), 31 deletions(-) (limited to 'src') diff --git a/src/armnn/layers/LstmLayer.cpp b/src/armnn/layers/LstmLayer.cpp index fa836d0317..2b99f284e8 100644 --- a/src/armnn/layers/LstmLayer.cpp +++ b/src/armnn/layers/LstmLayer.cpp @@ -252,110 +252,144 @@ Layer::ConstantTensors LstmLayer::GetConstantTensorsByRef() void LstmLayer::Accept(ILayerVisitor& visitor) const { LstmInputParams inputParams; + ConstTensor inputToInputWeightsTensor; if (m_CifgParameters.m_InputToInputWeights != nullptr) { - ConstTensor inputToInputWeightsTensor(m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), - m_CifgParameters.m_InputToInputWeights->Map(true)); + ConstTensor inputToInputWeightsTensorCopy(m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), + m_CifgParameters.m_InputToInputWeights->Map(true)); + inputToInputWeightsTensor = inputToInputWeightsTensorCopy; inputParams.m_InputToInputWeights = &inputToInputWeightsTensor; } + ConstTensor inputToForgetWeightsTensor; if (m_BasicParameters.m_InputToForgetWeights != nullptr) { - ConstTensor inputToForgetWeightsTensor(m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), - m_BasicParameters.m_InputToForgetWeights->Map(true)); + ConstTensor inputToForgetWeightsTensorCopy(m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), + m_BasicParameters.m_InputToForgetWeights->Map(true)); + inputToForgetWeightsTensor = inputToForgetWeightsTensorCopy; inputParams.m_InputToForgetWeights = &inputToForgetWeightsTensor; } + ConstTensor inputToCellWeightsTensor; if (m_BasicParameters.m_InputToCellWeights != nullptr) { - ConstTensor inputToCellWeightsTensor(m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), - m_BasicParameters.m_InputToCellWeights->Map(true)); + ConstTensor inputToCellWeightsTensorCopy(m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), + m_BasicParameters.m_InputToCellWeights->Map(true)); + inputToCellWeightsTensor = inputToCellWeightsTensorCopy; inputParams.m_InputToCellWeights = &inputToCellWeightsTensor; } + ConstTensor inputToOutputWeightsTensor; if (m_BasicParameters.m_InputToOutputWeights != nullptr) { - ConstTensor inputToOutputWeightsTensor(m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), - m_BasicParameters.m_InputToOutputWeights->Map(true)); + ConstTensor inputToOutputWeightsTensorCopy(m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), + m_BasicParameters.m_InputToOutputWeights->Map(true)); + inputToOutputWeightsTensor = inputToOutputWeightsTensorCopy; inputParams.m_InputToOutputWeights = &inputToOutputWeightsTensor; } + ConstTensor recurrentToInputWeightsTensor; if (m_CifgParameters.m_RecurrentToInputWeights != nullptr) { - ConstTensor recurrentToInputWeightsTensor( + ConstTensor recurrentToInputWeightsTensorCopy( m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), m_CifgParameters.m_RecurrentToInputWeights->Map(true)); + recurrentToInputWeightsTensor = recurrentToInputWeightsTensorCopy; inputParams.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor; } + ConstTensor recurrentToForgetWeightsTensor; if (m_BasicParameters.m_RecurrentToForgetWeights != nullptr) { - ConstTensor recurrentToForgetWeightsTensor( + ConstTensor recurrentToForgetWeightsTensorCopy( m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), m_BasicParameters.m_RecurrentToForgetWeights->Map(true)); + recurrentToForgetWeightsTensor = recurrentToForgetWeightsTensorCopy; inputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor; } + ConstTensor recurrentToCellWeightsTensor; if (m_BasicParameters.m_RecurrentToCellWeights != nullptr) { - ConstTensor recurrentToCellWeightsTensor( + ConstTensor recurrentToCellWeightsTensorCopy( m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), m_BasicParameters.m_RecurrentToCellWeights->Map(true)); + recurrentToCellWeightsTensor = recurrentToCellWeightsTensorCopy; inputParams.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor; } + ConstTensor recurrentToOutputWeightsTensor; if (m_BasicParameters.m_RecurrentToOutputWeights != nullptr) { - ConstTensor recurrentToOutputWeightsTensor( + ConstTensor recurrentToOutputWeightsTensorCopy( m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), m_BasicParameters.m_RecurrentToOutputWeights->Map(true)); + recurrentToOutputWeightsTensor = recurrentToOutputWeightsTensorCopy; inputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor; } + ConstTensor cellToInputWeightsTensor; if (m_CifgParameters.m_CellToInputWeights != nullptr) { - ConstTensor cellToInputWeightsTensor(m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), - m_CifgParameters.m_CellToInputWeights->Map(true)); + ConstTensor cellToInputWeightsTensorCopy(m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), + m_CifgParameters.m_CellToInputWeights->Map(true)); + cellToInputWeightsTensor = cellToInputWeightsTensorCopy; inputParams.m_CellToInputWeights = &cellToInputWeightsTensor; } + ConstTensor cellToForgetWeightsTensor; if (m_PeepholeParameters.m_CellToForgetWeights != nullptr) { - ConstTensor cellToForgetWeightsTensor(m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), - m_PeepholeParameters.m_CellToForgetWeights->Map(true)); + ConstTensor cellToForgetWeightsTensorCopy(m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), + m_PeepholeParameters.m_CellToForgetWeights->Map(true)); + cellToForgetWeightsTensor = cellToForgetWeightsTensorCopy; inputParams.m_CellToForgetWeights = &cellToForgetWeightsTensor; } + ConstTensor cellToOutputWeightsTensor; if (m_PeepholeParameters.m_CellToOutputWeights != nullptr) { - ConstTensor cellToOutputWeightsTensor(m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), - m_PeepholeParameters.m_CellToOutputWeights->Map(true)); + ConstTensor cellToOutputWeightsTensorCopy(m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), + m_PeepholeParameters.m_CellToOutputWeights->Map(true)); + cellToOutputWeightsTensor = cellToOutputWeightsTensorCopy; inputParams.m_CellToOutputWeights = &cellToOutputWeightsTensor; } + ConstTensor inputGateBiasTensor; if (m_CifgParameters.m_InputGateBias != nullptr) { - ConstTensor inputGateBiasTensor(m_CifgParameters.m_InputGateBias->GetTensorInfo(), + ConstTensor inputGateBiasTensorCopy(m_CifgParameters.m_InputGateBias->GetTensorInfo(), m_CifgParameters.m_InputGateBias->Map(true)); + inputGateBiasTensor = inputGateBiasTensorCopy; inputParams.m_InputGateBias = &inputGateBiasTensor; } + ConstTensor forgetGateBiasTensor; if (m_BasicParameters.m_ForgetGateBias != nullptr) { - ConstTensor forgetGateBiasTensor(m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), - m_BasicParameters.m_ForgetGateBias->Map(true)); + ConstTensor forgetGateBiasTensorCopy(m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), + m_BasicParameters.m_ForgetGateBias->Map(true)); + forgetGateBiasTensor = forgetGateBiasTensorCopy; inputParams.m_ForgetGateBias = &forgetGateBiasTensor; } + ConstTensor cellBiasTensor; if (m_BasicParameters.m_CellBias != nullptr) { - ConstTensor cellBiasTensor(m_BasicParameters.m_CellBias->GetTensorInfo(), - m_BasicParameters.m_CellBias->Map(true)); + ConstTensor cellBiasTensorCopy(m_BasicParameters.m_CellBias->GetTensorInfo(), + m_BasicParameters.m_CellBias->Map(true)); + cellBiasTensor = cellBiasTensorCopy; inputParams.m_CellBias = &cellBiasTensor; } + ConstTensor outputGateBias; if (m_BasicParameters.m_OutputGateBias != nullptr) { - ConstTensor outputGateBias(m_BasicParameters.m_OutputGateBias->GetTensorInfo(), - m_BasicParameters.m_OutputGateBias->Map(true)); + ConstTensor outputGateBiasCopy(m_BasicParameters.m_OutputGateBias->GetTensorInfo(), + m_BasicParameters.m_OutputGateBias->Map(true)); + outputGateBias = outputGateBiasCopy; inputParams.m_OutputGateBias = &outputGateBias; } + ConstTensor projectionWeightsTensor; if (m_ProjectionParameters.m_ProjectionWeights != nullptr) { - ConstTensor projectionWeightsTensor(m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), - m_ProjectionParameters.m_ProjectionWeights->Map(true)); + ConstTensor projectionWeightsTensorCopy(m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), + m_ProjectionParameters.m_ProjectionWeights->Map(true)); + projectionWeightsTensor = projectionWeightsTensorCopy; inputParams.m_ProjectionWeights = &projectionWeightsTensor; } + ConstTensor projectionBiasTensor; if (m_ProjectionParameters.m_ProjectionBias != nullptr) { - ConstTensor projectionBiasTensor(m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), - m_ProjectionParameters.m_ProjectionBias->Map(true)); + ConstTensor projectionBiasTensorCopy(m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), + m_ProjectionParameters.m_ProjectionBias->Map(true)); + projectionBiasTensor = projectionBiasTensorCopy; inputParams.m_ProjectionBias = &projectionBiasTensor; } diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp index 152a5b4c93..d64bed7409 100644 --- a/src/armnnDeserializer/Deserializer.cpp +++ b/src/armnnDeserializer/Deserializer.cpp @@ -201,6 +201,7 @@ m_ParserFunctions(Layer_MAX+1, &Deserializer::ParseUnsupportedLayer) m_ParserFunctions[Layer_GatherLayer] = &Deserializer::ParseGather; m_ParserFunctions[Layer_GreaterLayer] = &Deserializer::ParseGreater; m_ParserFunctions[Layer_L2NormalizationLayer] = &Deserializer::ParseL2Normalization; + m_ParserFunctions[Layer_LstmLayer] = &Deserializer::ParseLstm; m_ParserFunctions[Layer_MaximumLayer] = &Deserializer::ParseMaximum; m_ParserFunctions[Layer_MeanLayer] = &Deserializer::ParseMean; m_ParserFunctions[Layer_MinimumLayer] = &Deserializer::ParseMinimum; @@ -258,6 +259,8 @@ Deserializer::LayerBaseRawPtr Deserializer::GetBaseLayer(const GraphPtr& graphPt return graphPtr->layers()->Get(layerIndex)->layer_as_InputLayer()->base()->base(); case Layer::Layer_L2NormalizationLayer: return graphPtr->layers()->Get(layerIndex)->layer_as_L2NormalizationLayer()->base(); + case Layer::Layer_LstmLayer: + return graphPtr->layers()->Get(layerIndex)->layer_as_LstmLayer()->base(); case Layer::Layer_MeanLayer: return graphPtr->layers()->Get(layerIndex)->layer_as_MeanLayer()->base(); case Layer::Layer_MinimumLayer: @@ -1927,4 +1930,114 @@ void Deserializer::ParseSplitter(GraphPtr graph, unsigned int layerIndex) RegisterOutputSlots(graph, layerIndex, layer); } +armnn::LstmDescriptor Deserializer::GetLstmDescriptor(Deserializer::LstmDescriptorPtr lstmDescriptor) +{ + armnn::LstmDescriptor desc; + + desc.m_ActivationFunc = lstmDescriptor->activationFunc(); + desc.m_ClippingThresCell = lstmDescriptor->clippingThresCell(); + desc.m_ClippingThresProj = lstmDescriptor->clippingThresProj(); + desc.m_CifgEnabled = lstmDescriptor->cifgEnabled(); + desc.m_PeepholeEnabled = lstmDescriptor->peepholeEnabled(); + desc.m_ProjectionEnabled = lstmDescriptor->projectionEnabled(); + + return desc; +} + +void Deserializer::ParseLstm(GraphPtr graph, unsigned int layerIndex) +{ + CHECK_LAYERS(graph, 0, layerIndex); + + auto inputs = GetInputs(graph, layerIndex); + CHECK_VALID_SIZE(inputs.size(), 3); + + auto outputs = GetOutputs(graph, layerIndex); + CHECK_VALID_SIZE(outputs.size(), 4); + + auto flatBufferLayer = graph->layers()->Get(layerIndex)->layer_as_LstmLayer(); + auto layerName = GetLayerName(graph, layerIndex); + auto flatBufferDescriptor = flatBufferLayer->descriptor(); + auto flatBufferInputParams = flatBufferLayer->inputParams(); + + auto lstmDescriptor = GetLstmDescriptor(flatBufferDescriptor); + + armnn::LstmInputParams lstmInputParams; + + armnn::ConstTensor inputToForgetWeights = ToConstTensor(flatBufferInputParams->inputToForgetWeights()); + armnn::ConstTensor inputToCellWeights = ToConstTensor(flatBufferInputParams->inputToCellWeights()); + armnn::ConstTensor inputToOutputWeights = ToConstTensor(flatBufferInputParams->inputToOutputWeights()); + armnn::ConstTensor recurrentToForgetWeights = ToConstTensor(flatBufferInputParams->recurrentToForgetWeights()); + armnn::ConstTensor recurrentToCellWeights = ToConstTensor(flatBufferInputParams->recurrentToCellWeights()); + armnn::ConstTensor recurrentToOutputWeights = ToConstTensor(flatBufferInputParams->recurrentToOutputWeights()); + armnn::ConstTensor forgetGateBias = ToConstTensor(flatBufferInputParams->forgetGateBias()); + armnn::ConstTensor cellBias = ToConstTensor(flatBufferInputParams->cellBias()); + armnn::ConstTensor outputGateBias = ToConstTensor(flatBufferInputParams->outputGateBias()); + + lstmInputParams.m_InputToForgetWeights = &inputToForgetWeights; + lstmInputParams.m_InputToCellWeights = &inputToCellWeights; + lstmInputParams.m_InputToOutputWeights = &inputToOutputWeights; + lstmInputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeights; + lstmInputParams.m_RecurrentToCellWeights = &recurrentToCellWeights; + lstmInputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeights; + lstmInputParams.m_ForgetGateBias = &forgetGateBias; + lstmInputParams.m_CellBias = &cellBias; + lstmInputParams.m_OutputGateBias = &outputGateBias; + + armnn::ConstTensor inputToInputWeights; + armnn::ConstTensor recurrentToInputWeights; + armnn::ConstTensor cellToInputWeights; + armnn::ConstTensor inputGateBias; + if (!lstmDescriptor.m_CifgEnabled) + { + inputToInputWeights = ToConstTensor(flatBufferInputParams->inputToInputWeights()); + recurrentToInputWeights = ToConstTensor(flatBufferInputParams->recurrentToInputWeights()); + cellToInputWeights = ToConstTensor(flatBufferInputParams->cellToInputWeights()); + inputGateBias = ToConstTensor(flatBufferInputParams->inputGateBias()); + + lstmInputParams.m_InputToInputWeights = &inputToInputWeights; + lstmInputParams.m_RecurrentToInputWeights = &recurrentToInputWeights; + lstmInputParams.m_CellToInputWeights = &cellToInputWeights; + lstmInputParams.m_InputGateBias = &inputGateBias; + } + + armnn::ConstTensor projectionWeights; + armnn::ConstTensor projectionBias; + if (lstmDescriptor.m_ProjectionEnabled) + { + projectionWeights = ToConstTensor(flatBufferInputParams->projectionWeights()); + projectionBias = ToConstTensor(flatBufferInputParams->projectionBias()); + + lstmInputParams.m_ProjectionWeights = &projectionWeights; + lstmInputParams.m_ProjectionBias = &projectionBias; + } + + armnn::ConstTensor cellToForgetWeights; + armnn::ConstTensor cellToOutputWeights; + if (lstmDescriptor.m_PeepholeEnabled) + { + cellToForgetWeights = ToConstTensor(flatBufferInputParams->cellToForgetWeights()); + cellToOutputWeights = ToConstTensor(flatBufferInputParams->cellToOutputWeights()); + + lstmInputParams.m_CellToForgetWeights = &cellToForgetWeights; + lstmInputParams.m_CellToOutputWeights = &cellToOutputWeights; + } + + IConnectableLayer* layer = m_Network->AddLstmLayer(lstmDescriptor, lstmInputParams, layerName.c_str()); + + armnn::TensorInfo outputTensorInfo1 = ToTensorInfo(outputs[0]); + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo1); + + armnn::TensorInfo outputTensorInfo2 = ToTensorInfo(outputs[1]); + layer->GetOutputSlot(1).SetTensorInfo(outputTensorInfo2); + + armnn::TensorInfo outputTensorInfo3 = ToTensorInfo(outputs[2]); + layer->GetOutputSlot(2).SetTensorInfo(outputTensorInfo3); + + armnn::TensorInfo outputTensorInfo4 = ToTensorInfo(outputs[3]); + layer->GetOutputSlot(3).SetTensorInfo(outputTensorInfo4); + + RegisterInputSlots(graph, layerIndex, layer); + RegisterOutputSlots(graph, layerIndex, layer); +} + } // namespace armnnDeserializer diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp index effc7ae144..6454643f98 100644 --- a/src/armnnDeserializer/Deserializer.hpp +++ b/src/armnnDeserializer/Deserializer.hpp @@ -22,6 +22,8 @@ public: using TensorRawPtr = const armnnSerializer::TensorInfo *; using PoolingDescriptor = const armnnSerializer::Pooling2dDescriptor *; using NormalizationDescriptorPtr = const armnnSerializer::NormalizationDescriptor *; + using LstmDescriptorPtr = const armnnSerializer::LstmDescriptor *; + using LstmInputParamsPtr = const armnnSerializer::LstmInputParams *; using TensorRawPtrVector = std::vector; using LayerRawPtr = const armnnSerializer::LayerBase *; using LayerBaseRawPtr = const armnnSerializer::LayerBase *; @@ -58,6 +60,9 @@ public: unsigned int layerIndex); static armnn::NormalizationDescriptor GetNormalizationDescriptor( NormalizationDescriptorPtr normalizationDescriptor, unsigned int layerIndex); + static armnn::LstmDescriptor GetLstmDescriptor(LstmDescriptorPtr lstmDescriptor); + static armnn::LstmInputParams GetLstmInputParams(LstmDescriptorPtr lstmDescriptor, + LstmInputParamsPtr lstmInputParams); static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo & inputTensorInfo, const std::vector & targetDimsIn); @@ -94,6 +99,7 @@ private: void ParseMerger(GraphPtr graph, unsigned int layerIndex); void ParseMultiplication(GraphPtr graph, unsigned int layerIndex); void ParseNormalization(GraphPtr graph, unsigned int layerIndex); + void ParseLstm(GraphPtr graph, unsigned int layerIndex); void ParsePad(GraphPtr graph, unsigned int layerIndex); void ParsePermute(GraphPtr graph, unsigned int layerIndex); void ParsePooling2d(GraphPtr graph, unsigned int layerIndex); diff --git a/src/armnnDeserializer/DeserializerSupport.md b/src/armnnDeserializer/DeserializerSupport.md index 48b8c88103..d53252ec00 100644 --- a/src/armnnDeserializer/DeserializerSupport.md +++ b/src/armnnDeserializer/DeserializerSupport.md @@ -21,6 +21,7 @@ The Arm NN SDK Deserialize parser currently supports the following layers: * Gather * Greater * L2Normalization +* Lstm * Maximum * Mean * Merger diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs index a11eeadf12..2cceaae031 100644 --- a/src/armnnSerializer/ArmnnSchema.fbs +++ b/src/armnnSerializer/ArmnnSchema.fbs @@ -115,7 +115,8 @@ enum LayerType : uint { Merger = 30, L2Normalization = 31, Splitter = 32, - DetectionPostProcess = 33 + DetectionPostProcess = 33, + Lstm = 34 } // Base layer table to be used as part of other layers @@ -475,6 +476,44 @@ table DetectionPostProcessDescriptor { scaleH:float; } +table LstmInputParams { + inputToForgetWeights:ConstTensor; + inputToCellWeights:ConstTensor; + inputToOutputWeights:ConstTensor; + recurrentToForgetWeights:ConstTensor; + recurrentToCellWeights:ConstTensor; + recurrentToOutputWeights:ConstTensor; + forgetGateBias:ConstTensor; + cellBias:ConstTensor; + outputGateBias:ConstTensor; + + inputToInputWeights:ConstTensor; + recurrentToInputWeights:ConstTensor; + cellToInputWeights:ConstTensor; + inputGateBias:ConstTensor; + + projectionWeights:ConstTensor; + projectionBias:ConstTensor; + + cellToForgetWeights:ConstTensor; + cellToOutputWeights:ConstTensor; +} + +table LstmDescriptor { + activationFunc:uint; + clippingThresCell:float; + clippingThresProj:float; + cifgEnabled:bool = true; + peepholeEnabled:bool = false; + projectionEnabled:bool = false; +} + +table LstmLayer { + base:LayerBase; + descriptor:LstmDescriptor; + inputParams:LstmInputParams; +} + union Layer { ActivationLayer, AdditionLayer, @@ -509,7 +548,8 @@ union Layer { MergerLayer, L2NormalizationLayer, SplitterLayer, - DetectionPostProcessLayer + DetectionPostProcessLayer, + LstmLayer } table AnyLayer { diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index a27cbc03ba..2fd840258e 100644 --- a/src/armnnSerializer/Serializer.cpp +++ b/src/armnnSerializer/Serializer.cpp @@ -375,6 +375,90 @@ void SerializerVisitor::VisitL2NormalizationLayer(const armnn::IConnectableLayer CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_L2NormalizationLayer); } +void SerializerVisitor::VisitLstmLayer(const armnn::IConnectableLayer* layer, const armnn::LstmDescriptor& descriptor, + const armnn::LstmInputParams& params, const char* name) +{ + auto fbLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Lstm); + + auto fbLstmDescriptor = serializer::CreateLstmDescriptor( + m_flatBufferBuilder, + descriptor.m_ActivationFunc, + descriptor.m_ClippingThresCell, + descriptor.m_ClippingThresProj, + descriptor.m_CifgEnabled, + descriptor.m_PeepholeEnabled, + descriptor.m_ProjectionEnabled); + + // Get mandatory input parameters + auto inputToForgetWeights = CreateConstTensorInfo(*params.m_InputToForgetWeights); + auto inputToCellWeights = CreateConstTensorInfo(*params.m_InputToCellWeights); + auto inputToOutputWeights = CreateConstTensorInfo(*params.m_InputToOutputWeights); + auto recurrentToForgetWeights = CreateConstTensorInfo(*params.m_RecurrentToForgetWeights); + auto recurrentToCellWeights = CreateConstTensorInfo(*params.m_RecurrentToCellWeights); + auto recurrentToOutputWeights = CreateConstTensorInfo(*params.m_RecurrentToOutputWeights); + auto forgetGateBias = CreateConstTensorInfo(*params.m_ForgetGateBias); + auto cellBias = CreateConstTensorInfo(*params.m_CellBias); + auto outputGateBias = CreateConstTensorInfo(*params.m_OutputGateBias); + + //Define optional parameters, these will be set depending on configuration in Lstm descriptor + flatbuffers::Offset inputToInputWeights; + flatbuffers::Offset recurrentToInputWeights; + flatbuffers::Offset cellToInputWeights; + flatbuffers::Offset inputGateBias; + flatbuffers::Offset projectionWeights; + flatbuffers::Offset projectionBias; + flatbuffers::Offset cellToForgetWeights; + flatbuffers::Offset cellToOutputWeights; + + if (!descriptor.m_CifgEnabled) + { + inputToInputWeights = CreateConstTensorInfo(*params.m_InputToInputWeights); + recurrentToInputWeights = CreateConstTensorInfo(*params.m_RecurrentToInputWeights); + cellToInputWeights = CreateConstTensorInfo(*params.m_CellToInputWeights); + inputGateBias = CreateConstTensorInfo(*params.m_InputGateBias); + } + + if (descriptor.m_ProjectionEnabled) + { + projectionWeights = CreateConstTensorInfo(*params.m_ProjectionWeights); + projectionBias = CreateConstTensorInfo(*params.m_ProjectionBias); + } + + if (descriptor.m_PeepholeEnabled) + { + cellToForgetWeights = CreateConstTensorInfo(*params.m_CellToForgetWeights); + cellToOutputWeights = CreateConstTensorInfo(*params.m_CellToOutputWeights); + } + + auto fbLstmParams = serializer::CreateLstmInputParams( + m_flatBufferBuilder, + inputToForgetWeights, + inputToCellWeights, + inputToOutputWeights, + recurrentToForgetWeights, + recurrentToCellWeights, + recurrentToOutputWeights, + forgetGateBias, + cellBias, + outputGateBias, + inputToInputWeights, + recurrentToInputWeights, + cellToInputWeights, + inputGateBias, + projectionWeights, + projectionBias, + cellToForgetWeights, + cellToOutputWeights); + + auto fbLstmLayer = serializer::CreateLstmLayer( + m_flatBufferBuilder, + fbLstmBaseLayer, + fbLstmDescriptor, + fbLstmParams); + + CreateAnyLayer(fbLstmLayer.o, serializer::Layer::Layer_LstmLayer); +} + void SerializerVisitor::VisitMaximumLayer(const armnn::IConnectableLayer* layer, const char* name) { auto fbMaximumBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Maximum); diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp index 71066d2699..4573bfdd11 100644 --- a/src/armnnSerializer/Serializer.hpp +++ b/src/armnnSerializer/Serializer.hpp @@ -111,6 +111,11 @@ public: const armnn::L2NormalizationDescriptor& l2NormalizationDescriptor, const char* name = nullptr) override; + void VisitLstmLayer(const armnn::IConnectableLayer* layer, + const armnn::LstmDescriptor& descriptor, + const armnn::LstmInputParams& params, + const char* name = nullptr) override; + void VisitMeanLayer(const armnn::IConnectableLayer* layer, const armnn::MeanDescriptor& descriptor, const char* name) override; diff --git a/src/armnnSerializer/SerializerSupport.md b/src/armnnSerializer/SerializerSupport.md index 4e127b3f9f..7686d5c24c 100644 --- a/src/armnnSerializer/SerializerSupport.md +++ b/src/armnnSerializer/SerializerSupport.md @@ -21,6 +21,7 @@ The Arm NN SDK Serializer currently supports the following layers: * Gather * Greater * L2Normalization +* Lstm * Maximum * Mean * Merger diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index f40c02dfde..e3ce6d29d3 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -2047,4 +2047,379 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeNonLinearNetwork) deserializedNetwork->Accept(verifier); } +class VerifyLstmLayer : public LayerVerifierBase +{ +public: + VerifyLstmLayer(const std::string& layerName, + const std::vector& inputInfos, + const std::vector& outputInfos, + const armnn::LstmDescriptor& descriptor, + const armnn::LstmInputParams& inputParams) : + LayerVerifierBase(layerName, inputInfos, outputInfos), m_Descriptor(descriptor), m_InputParams(inputParams) + { + } + void VisitLstmLayer(const armnn::IConnectableLayer* layer, + const armnn::LstmDescriptor& descriptor, + const armnn::LstmInputParams& params, + const char* name) + { + VerifyNameAndConnections(layer, name); + VerifyDescriptor(descriptor); + VerifyInputParameters(params); + } +protected: + void VerifyDescriptor(const armnn::LstmDescriptor& descriptor) + { + BOOST_TEST(m_Descriptor.m_ActivationFunc == descriptor.m_ActivationFunc); + BOOST_TEST(m_Descriptor.m_ClippingThresCell == descriptor.m_ClippingThresCell); + BOOST_TEST(m_Descriptor.m_ClippingThresProj == descriptor.m_ClippingThresProj); + BOOST_TEST(m_Descriptor.m_CifgEnabled == descriptor.m_CifgEnabled); + BOOST_TEST(m_Descriptor.m_PeepholeEnabled = descriptor.m_PeepholeEnabled); + BOOST_TEST(m_Descriptor.m_ProjectionEnabled == descriptor.m_ProjectionEnabled); + } + void VerifyInputParameters(const armnn::LstmInputParams& params) + { + VerifyConstTensors( + "m_InputToInputWeights", m_InputParams.m_InputToInputWeights, params.m_InputToInputWeights); + VerifyConstTensors( + "m_InputToForgetWeights", m_InputParams.m_InputToForgetWeights, params.m_InputToForgetWeights); + VerifyConstTensors( + "m_InputToCellWeights", m_InputParams.m_InputToCellWeights, params.m_InputToCellWeights); + VerifyConstTensors( + "m_InputToOutputWeights", m_InputParams.m_InputToOutputWeights, params.m_InputToOutputWeights); + VerifyConstTensors( + "m_RecurrentToInputWeights", m_InputParams.m_RecurrentToInputWeights, params.m_RecurrentToInputWeights); + VerifyConstTensors( + "m_RecurrentToForgetWeights", m_InputParams.m_RecurrentToForgetWeights, params.m_RecurrentToForgetWeights); + VerifyConstTensors( + "m_RecurrentToCellWeights", m_InputParams.m_RecurrentToCellWeights, params.m_RecurrentToCellWeights); + VerifyConstTensors( + "m_RecurrentToOutputWeights", m_InputParams.m_RecurrentToOutputWeights, params.m_RecurrentToOutputWeights); + VerifyConstTensors( + "m_CellToInputWeights", m_InputParams.m_CellToInputWeights, params.m_CellToInputWeights); + VerifyConstTensors( + "m_CellToForgetWeights", m_InputParams.m_CellToForgetWeights, params.m_CellToForgetWeights); + VerifyConstTensors( + "m_CellToOutputWeights", m_InputParams.m_CellToOutputWeights, params.m_CellToOutputWeights); + VerifyConstTensors( + "m_InputGateBias", m_InputParams.m_InputGateBias, params.m_InputGateBias); + VerifyConstTensors( + "m_ForgetGateBias", m_InputParams.m_ForgetGateBias, params.m_ForgetGateBias); + VerifyConstTensors( + "m_CellBias", m_InputParams.m_CellBias, params.m_CellBias); + VerifyConstTensors( + "m_OutputGateBias", m_InputParams.m_OutputGateBias, params.m_OutputGateBias); + VerifyConstTensors( + "m_ProjectionWeights", m_InputParams.m_ProjectionWeights, params.m_ProjectionWeights); + VerifyConstTensors( + "m_ProjectionBias", m_InputParams.m_ProjectionBias, params.m_ProjectionBias); + } + void VerifyConstTensors(const std::string& tensorName, + const armnn::ConstTensor* expectedPtr, + const armnn::ConstTensor* actualPtr) + { + if (expectedPtr == nullptr) + { + BOOST_CHECK_MESSAGE(actualPtr == nullptr, tensorName + " should not exist"); + } + else + { + BOOST_CHECK_MESSAGE(actualPtr != nullptr, tensorName + " should have been set"); + if (actualPtr != nullptr) + { + const armnn::TensorInfo& expectedInfo = expectedPtr->GetInfo(); + const armnn::TensorInfo& actualInfo = actualPtr->GetInfo(); + + BOOST_CHECK_MESSAGE(expectedInfo.GetShape() == actualInfo.GetShape(), + tensorName + " shapes don't match"); + BOOST_CHECK_MESSAGE( + GetDataTypeName(expectedInfo.GetDataType()) == GetDataTypeName(actualInfo.GetDataType()), + tensorName + " data types don't match"); + + BOOST_CHECK_MESSAGE(expectedPtr->GetNumBytes() == actualPtr->GetNumBytes(), + tensorName + " (GetNumBytes) data sizes do not match"); + if (expectedPtr->GetNumBytes() == actualPtr->GetNumBytes()) + { + //check the data is identical + const char* expectedData = static_cast(expectedPtr->GetMemoryArea()); + const char* actualData = static_cast(actualPtr->GetMemoryArea()); + bool same = true; + for (unsigned int i = 0; i < expectedPtr->GetNumBytes(); ++i) + { + same = expectedData[i] == actualData[i]; + if (!same) + { + break; + } + } + BOOST_CHECK_MESSAGE(same, tensorName + " data does not match"); + } + } + } + } +private: + armnn::LstmDescriptor m_Descriptor; + armnn::LstmInputParams m_InputParams; +}; + +BOOST_AUTO_TEST_CASE(SerializeDeserializeLstmCifgPeepholeNoProjection) +{ + armnn::LstmDescriptor descriptor; + descriptor.m_ActivationFunc = 4; + descriptor.m_ClippingThresProj = 0.0f; + descriptor.m_ClippingThresCell = 0.0f; + descriptor.m_CifgEnabled = true; // if this is true then we DON'T need to set the OptCifgParams + descriptor.m_ProjectionEnabled = false; + descriptor.m_PeepholeEnabled = true; + + const uint32_t batchSize = 1; + const uint32_t inputSize = 2; + const uint32_t numUnits = 4; + const uint32_t outputSize = numUnits; + + armnn::TensorInfo inputWeightsInfo1({numUnits, inputSize}, armnn::DataType::Float32); + std::vector inputToForgetWeightsData = GenerateRandomData(inputWeightsInfo1.GetNumElements()); + armnn::ConstTensor inputToForgetWeights(inputWeightsInfo1, inputToForgetWeightsData); + + std::vector inputToCellWeightsData = GenerateRandomData(inputWeightsInfo1.GetNumElements()); + armnn::ConstTensor inputToCellWeights(inputWeightsInfo1, inputToCellWeightsData); + + std::vector inputToOutputWeightsData = GenerateRandomData(inputWeightsInfo1.GetNumElements()); + armnn::ConstTensor inputToOutputWeights(inputWeightsInfo1, inputToOutputWeightsData); + + armnn::TensorInfo inputWeightsInfo2({numUnits, outputSize}, armnn::DataType::Float32); + std::vector recurrentToForgetWeightsData = GenerateRandomData(inputWeightsInfo2.GetNumElements()); + armnn::ConstTensor recurrentToForgetWeights(inputWeightsInfo2, recurrentToForgetWeightsData); + + std::vector recurrentToCellWeightsData = GenerateRandomData(inputWeightsInfo2.GetNumElements()); + armnn::ConstTensor recurrentToCellWeights(inputWeightsInfo2, recurrentToCellWeightsData); + + std::vector recurrentToOutputWeightsData = GenerateRandomData(inputWeightsInfo2.GetNumElements()); + armnn::ConstTensor recurrentToOutputWeights(inputWeightsInfo2, recurrentToOutputWeightsData); + + armnn::TensorInfo inputWeightsInfo3({numUnits}, armnn::DataType::Float32); + std::vector cellToForgetWeightsData = GenerateRandomData(inputWeightsInfo3.GetNumElements()); + armnn::ConstTensor cellToForgetWeights(inputWeightsInfo3, cellToForgetWeightsData); + + std::vector cellToOutputWeightsData = GenerateRandomData(inputWeightsInfo3.GetNumElements()); + armnn::ConstTensor cellToOutputWeights(inputWeightsInfo3, cellToOutputWeightsData); + + std::vector forgetGateBiasData(numUnits, 1.0f); + armnn::ConstTensor forgetGateBias(inputWeightsInfo3, forgetGateBiasData); + + std::vector cellBiasData(numUnits, 0.0f); + armnn::ConstTensor cellBias(inputWeightsInfo3, cellBiasData); + + std::vector outputGateBiasData(numUnits, 0.0f); + armnn::ConstTensor outputGateBias(inputWeightsInfo3, outputGateBiasData); + + armnn::LstmInputParams params; + params.m_InputToForgetWeights = &inputToForgetWeights; + params.m_InputToCellWeights = &inputToCellWeights; + params.m_InputToOutputWeights = &inputToOutputWeights; + params.m_RecurrentToForgetWeights = &recurrentToForgetWeights; + params.m_RecurrentToCellWeights = &recurrentToCellWeights; + params.m_RecurrentToOutputWeights = &recurrentToOutputWeights; + params.m_ForgetGateBias = &forgetGateBias; + params.m_CellBias = &cellBias; + params.m_OutputGateBias = &outputGateBias; + params.m_CellToForgetWeights = &cellToForgetWeights; + params.m_CellToOutputWeights = &cellToOutputWeights; + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0); + armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(1); + armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(2); + const std::string layerName("lstm"); + armnn::IConnectableLayer* const lstmLayer = network->AddLstmLayer(descriptor, params, layerName.c_str()); + armnn::IConnectableLayer* const scratchBuffer = network->AddOutputLayer(0); + armnn::IConnectableLayer* const outputStateOut = network->AddOutputLayer(1); + armnn::IConnectableLayer* const cellStateOut = network->AddOutputLayer(2); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(3); + + // connect up + armnn::TensorInfo inputTensorInfo({ batchSize, inputSize }, armnn::DataType::Float32); + armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, armnn::DataType::Float32); + armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32); + armnn::TensorInfo lstmTensorInfoScratchBuff({ batchSize, numUnits * 3 }, armnn::DataType::Float32); + + inputLayer->GetOutputSlot(0).Connect(lstmLayer->GetInputSlot(0)); + inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo); + + outputStateIn->GetOutputSlot(0).Connect(lstmLayer->GetInputSlot(1)); + outputStateIn->GetOutputSlot(0).SetTensorInfo(outputStateTensorInfo); + + cellStateIn->GetOutputSlot(0).Connect(lstmLayer->GetInputSlot(2)); + cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo); + + lstmLayer->GetOutputSlot(0).Connect(scratchBuffer->GetInputSlot(0)); + lstmLayer->GetOutputSlot(0).SetTensorInfo(lstmTensorInfoScratchBuff); + + lstmLayer->GetOutputSlot(1).Connect(outputStateOut->GetInputSlot(0)); + lstmLayer->GetOutputSlot(1).SetTensorInfo(outputStateTensorInfo); + + lstmLayer->GetOutputSlot(2).Connect(cellStateOut->GetInputSlot(0)); + lstmLayer->GetOutputSlot(2).SetTensorInfo(cellStateTensorInfo); + + lstmLayer->GetOutputSlot(3).Connect(outputLayer->GetInputSlot(0)); + lstmLayer->GetOutputSlot(3).SetTensorInfo(outputStateTensorInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + BOOST_CHECK(deserializedNetwork); + + VerifyLstmLayer checker( + layerName, + {inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo}, + {lstmTensorInfoScratchBuff, outputStateTensorInfo, cellStateTensorInfo, outputStateTensorInfo}, + descriptor, + params); + deserializedNetwork->Accept(checker); +} + +BOOST_AUTO_TEST_CASE(SerializeDeserializeLstmNoCifgWithPeepholeAndProjection) +{ + armnn::LstmDescriptor descriptor; + descriptor.m_ActivationFunc = 4; + descriptor.m_ClippingThresProj = 0.0f; + descriptor.m_ClippingThresCell = 0.0f; + descriptor.m_CifgEnabled = false; // if this is true then we DON'T need to set the OptCifgParams + descriptor.m_ProjectionEnabled = true; + descriptor.m_PeepholeEnabled = true; + + const uint32_t batchSize = 2; + const uint32_t inputSize = 5; + const uint32_t numUnits = 20; + const uint32_t outputSize = 16; + + armnn::TensorInfo tensorInfo20x5({numUnits, inputSize}, armnn::DataType::Float32); + std::vector inputToInputWeightsData = GenerateRandomData(tensorInfo20x5.GetNumElements()); + armnn::ConstTensor inputToInputWeights(tensorInfo20x5, inputToInputWeightsData); + + std::vector inputToForgetWeightsData = GenerateRandomData(tensorInfo20x5.GetNumElements()); + armnn::ConstTensor inputToForgetWeights(tensorInfo20x5, inputToForgetWeightsData); + + std::vector inputToCellWeightsData = GenerateRandomData(tensorInfo20x5.GetNumElements()); + armnn::ConstTensor inputToCellWeights(tensorInfo20x5, inputToCellWeightsData); + + std::vector inputToOutputWeightsData = GenerateRandomData(tensorInfo20x5.GetNumElements()); + armnn::ConstTensor inputToOutputWeights(tensorInfo20x5, inputToOutputWeightsData); + + armnn::TensorInfo tensorInfo20({numUnits}, armnn::DataType::Float32); + std::vector inputGateBiasData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor inputGateBias(tensorInfo20, inputGateBiasData); + + std::vector forgetGateBiasData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor forgetGateBias(tensorInfo20, forgetGateBiasData); + + std::vector cellBiasData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor cellBias(tensorInfo20, cellBiasData); + + std::vector outputGateBiasData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor outputGateBias(tensorInfo20, outputGateBiasData); + + armnn::TensorInfo tensorInfo20x16({numUnits, outputSize}, armnn::DataType::Float32); + std::vector recurrentToInputWeightsData = GenerateRandomData(tensorInfo20x16.GetNumElements()); + armnn::ConstTensor recurrentToInputWeights(tensorInfo20x16, recurrentToInputWeightsData); + + std::vector recurrentToForgetWeightsData = GenerateRandomData(tensorInfo20x16.GetNumElements()); + armnn::ConstTensor recurrentToForgetWeights(tensorInfo20x16, recurrentToForgetWeightsData); + + std::vector recurrentToCellWeightsData = GenerateRandomData(tensorInfo20x16.GetNumElements()); + armnn::ConstTensor recurrentToCellWeights(tensorInfo20x16, recurrentToCellWeightsData); + + std::vector recurrentToOutputWeightsData = GenerateRandomData(tensorInfo20x16.GetNumElements()); + armnn::ConstTensor recurrentToOutputWeights(tensorInfo20x16, recurrentToOutputWeightsData); + + std::vector cellToInputWeightsData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor cellToInputWeights(tensorInfo20, cellToInputWeightsData); + + std::vector cellToForgetWeightsData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor cellToForgetWeights(tensorInfo20, cellToForgetWeightsData); + + std::vector cellToOutputWeightsData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor cellToOutputWeights(tensorInfo20, cellToOutputWeightsData); + + armnn::TensorInfo tensorInfo16x20({outputSize, numUnits}, armnn::DataType::Float32); + std::vector projectionWeightsData = GenerateRandomData(tensorInfo16x20.GetNumElements()); + armnn::ConstTensor projectionWeights(tensorInfo16x20, projectionWeightsData); + + armnn::TensorInfo tensorInfo16({outputSize}, armnn::DataType::Float32); + std::vector projectionBiasData(outputSize, 0.f); + armnn::ConstTensor projectionBias(tensorInfo16, projectionBiasData); + + armnn::LstmInputParams params; + params.m_InputToForgetWeights = &inputToForgetWeights; + params.m_InputToCellWeights = &inputToCellWeights; + params.m_InputToOutputWeights = &inputToOutputWeights; + params.m_RecurrentToForgetWeights = &recurrentToForgetWeights; + params.m_RecurrentToCellWeights = &recurrentToCellWeights; + params.m_RecurrentToOutputWeights = &recurrentToOutputWeights; + params.m_ForgetGateBias = &forgetGateBias; + params.m_CellBias = &cellBias; + params.m_OutputGateBias = &outputGateBias; + + // additional params because: descriptor.m_CifgEnabled = false + params.m_InputToInputWeights = &inputToInputWeights; + params.m_RecurrentToInputWeights = &recurrentToInputWeights; + params.m_CellToInputWeights = &cellToInputWeights; + params.m_InputGateBias = &inputGateBias; + + // additional params because: descriptor.m_ProjectionEnabled = true + params.m_ProjectionWeights = &projectionWeights; + params.m_ProjectionBias = &projectionBias; + + // additional params because: descriptor.m_PeepholeEnabled = true + params.m_CellToForgetWeights = &cellToForgetWeights; + params.m_CellToOutputWeights = &cellToOutputWeights; + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0); + armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(1); + armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(2); + const std::string layerName("lstm"); + armnn::IConnectableLayer* const lstmLayer = network->AddLstmLayer(descriptor, params, layerName.c_str()); + armnn::IConnectableLayer* const scratchBuffer = network->AddOutputLayer(0); + armnn::IConnectableLayer* const outputStateOut = network->AddOutputLayer(1); + armnn::IConnectableLayer* const cellStateOut = network->AddOutputLayer(2); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(3); + + // connect up + armnn::TensorInfo inputTensorInfo({ batchSize, inputSize }, armnn::DataType::Float32); + armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, armnn::DataType::Float32); + armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32); + armnn::TensorInfo lstmTensorInfoScratchBuff({ batchSize, numUnits * 4 }, armnn::DataType::Float32); + + inputLayer->GetOutputSlot(0).Connect(lstmLayer->GetInputSlot(0)); + inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo); + + outputStateIn->GetOutputSlot(0).Connect(lstmLayer->GetInputSlot(1)); + outputStateIn->GetOutputSlot(0).SetTensorInfo(outputStateTensorInfo); + + cellStateIn->GetOutputSlot(0).Connect(lstmLayer->GetInputSlot(2)); + cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo); + + lstmLayer->GetOutputSlot(0).Connect(scratchBuffer->GetInputSlot(0)); + lstmLayer->GetOutputSlot(0).SetTensorInfo(lstmTensorInfoScratchBuff); + + lstmLayer->GetOutputSlot(1).Connect(outputStateOut->GetInputSlot(0)); + lstmLayer->GetOutputSlot(1).SetTensorInfo(outputStateTensorInfo); + + lstmLayer->GetOutputSlot(2).Connect(cellStateOut->GetInputSlot(0)); + lstmLayer->GetOutputSlot(2).SetTensorInfo(cellStateTensorInfo); + + lstmLayer->GetOutputSlot(3).Connect(outputLayer->GetInputSlot(0)); + lstmLayer->GetOutputSlot(3).SetTensorInfo(outputStateTensorInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + BOOST_CHECK(deserializedNetwork); + + VerifyLstmLayer checker( + layerName, + {inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo}, + {lstmTensorInfoScratchBuff, outputStateTensorInfo, cellStateTensorInfo, outputStateTensorInfo}, + descriptor, + params); + deserializedNetwork->Accept(checker); +} + BOOST_AUTO_TEST_SUITE_END() -- cgit v1.2.1