From 8d33318a7ac33d90ed79701ff717de8d9940cc67 Mon Sep 17 00:00:00 2001
From: James Conroy
Date: Wed, 13 May 2020 10:27:58 +0100
Subject: IVGCVSW-4777 Add QLstm serialization support

* Adds serialization/deserialization for QLstm.
* 3 unit tests: basic, layer norm and advanced.

Signed-off-by: James Conroy
Change-Id: I97d825e06b0d4a1257713cdd71ff06afa10d4380
---
 src/armnnDeserializer/Deserializer.cpp        | 152 +++++++
 src/armnnDeserializer/Deserializer.hpp        |   3 +
 src/armnnDeserializer/DeserializerSupport.md  |   1 +
 src/armnnSerializer/ArmnnSchema.fbs           |  93 +++-
 src/armnnSerializer/Serializer.cpp            | 119 ++++-
 src/armnnSerializer/SerializerSupport.md      |   1 +
 src/armnnSerializer/test/SerializerTests.cpp  | 657 +++++++++++++++++++++++++++
 7 files changed, 1008 insertions(+), 18 deletions(-)

diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp
index 42b0052b03..36beebc1cd 100644
--- a/src/armnnDeserializer/Deserializer.cpp
+++ b/src/armnnDeserializer/Deserializer.cpp
@@ -222,6 +222,7 @@ m_ParserFunctions(Layer_MAX+1, &Deserializer::ParseUnsupportedLayer)
     m_ParserFunctions[Layer_PermuteLayer] = &Deserializer::ParsePermute;
     m_ParserFunctions[Layer_Pooling2dLayer] = &Deserializer::ParsePooling2d;
     m_ParserFunctions[Layer_PreluLayer] = &Deserializer::ParsePrelu;
+    m_ParserFunctions[Layer_QLstmLayer] = &Deserializer::ParseQLstm;
     m_ParserFunctions[Layer_QuantizeLayer] = &Deserializer::ParseQuantize;
     m_ParserFunctions[Layer_QuantizedLstmLayer] = &Deserializer::ParseQuantizedLstm;
     m_ParserFunctions[Layer_ReshapeLayer] = &Deserializer::ParseReshape;
@@ -322,6 +323,8 @@ Deserializer::LayerBaseRawPtr Deserializer::GetBaseLayer(const GraphPtr& graphPt
             return graphPtr->layers()->Get(layerIndex)->layer_as_Pooling2dLayer()->base();
         case Layer::Layer_PreluLayer:
             return graphPtr->layers()->Get(layerIndex)->layer_as_PreluLayer()->base();
+        case Layer::Layer_QLstmLayer:
+            return graphPtr->layers()->Get(layerIndex)->layer_as_QLstmLayer()->base();
         case Layer::Layer_QuantizeLayer:
             return graphPtr->layers()->Get(layerIndex)->layer_as_QuantizeLayer()->base();
         case Layer::Layer_QuantizedLstmLayer:
@@ -2610,6 +2613,155 @@ void Deserializer::ParseLstm(GraphPtr graph, unsigned int layerIndex)
     RegisterOutputSlots(graph, layerIndex, layer);
 }
 
+armnn::QLstmDescriptor Deserializer::GetQLstmDescriptor(Deserializer::QLstmDescriptorPtr qLstmDescriptor)
+{
+    armnn::QLstmDescriptor desc;
+
+    desc.m_CifgEnabled = qLstmDescriptor->cifgEnabled();
+    desc.m_PeepholeEnabled = qLstmDescriptor->peepholeEnabled();
+    desc.m_ProjectionEnabled = qLstmDescriptor->projectionEnabled();
+    desc.m_LayerNormEnabled = qLstmDescriptor->layerNormEnabled();
+
+    desc.m_CellClip = qLstmDescriptor->cellClip();
+    desc.m_ProjectionClip = qLstmDescriptor->projectionClip();
+
+    desc.m_InputIntermediateScale = qLstmDescriptor->inputIntermediateScale();
+    desc.m_ForgetIntermediateScale = qLstmDescriptor->forgetIntermediateScale();
+    desc.m_CellIntermediateScale = qLstmDescriptor->cellIntermediateScale();
+    desc.m_OutputIntermediateScale = qLstmDescriptor->outputIntermediateScale();
+
+    desc.m_HiddenStateScale = qLstmDescriptor->hiddenStateScale();
+    desc.m_HiddenStateZeroPoint = qLstmDescriptor->hiddenStateZeroPoint();
+
+    return desc;
+}
+
+void Deserializer::ParseQLstm(GraphPtr graph, unsigned int layerIndex)
+{
+    CHECK_LAYERS(graph, 0, layerIndex);
+
+    auto inputs = GetInputs(graph, layerIndex);
+    CHECK_VALID_SIZE(inputs.size(), 3);
+
+    auto outputs = GetOutputs(graph, layerIndex);
+    CHECK_VALID_SIZE(outputs.size(), 3);
+
+    auto
flatBufferLayer = graph->layers()->Get(layerIndex)->layer_as_QLstmLayer(); + auto layerName = GetLayerName(graph, layerIndex); + auto flatBufferDescriptor = flatBufferLayer->descriptor(); + auto flatBufferInputParams = flatBufferLayer->inputParams(); + + auto qLstmDescriptor = GetQLstmDescriptor(flatBufferDescriptor); + armnn::LstmInputParams qLstmInputParams; + + // Mandatory params + armnn::ConstTensor inputToForgetWeights = ToConstTensor(flatBufferInputParams->inputToForgetWeights()); + armnn::ConstTensor inputToCellWeights = ToConstTensor(flatBufferInputParams->inputToCellWeights()); + armnn::ConstTensor inputToOutputWeights = ToConstTensor(flatBufferInputParams->inputToOutputWeights()); + armnn::ConstTensor recurrentToForgetWeights = ToConstTensor(flatBufferInputParams->recurrentToForgetWeights()); + armnn::ConstTensor recurrentToCellWeights = ToConstTensor(flatBufferInputParams->recurrentToCellWeights()); + armnn::ConstTensor recurrentToOutputWeights = ToConstTensor(flatBufferInputParams->recurrentToOutputWeights()); + armnn::ConstTensor forgetGateBias = ToConstTensor(flatBufferInputParams->forgetGateBias()); + armnn::ConstTensor cellBias = ToConstTensor(flatBufferInputParams->cellBias()); + armnn::ConstTensor outputGateBias = ToConstTensor(flatBufferInputParams->outputGateBias()); + + qLstmInputParams.m_InputToForgetWeights = &inputToForgetWeights; + qLstmInputParams.m_InputToCellWeights = &inputToCellWeights; + qLstmInputParams.m_InputToOutputWeights = &inputToOutputWeights; + qLstmInputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeights; + qLstmInputParams.m_RecurrentToCellWeights = &recurrentToCellWeights; + qLstmInputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeights; + qLstmInputParams.m_ForgetGateBias = &forgetGateBias; + qLstmInputParams.m_CellBias = &cellBias; + qLstmInputParams.m_OutputGateBias = &outputGateBias; + + // Optional CIFG params + armnn::ConstTensor inputToInputWeights; + armnn::ConstTensor recurrentToInputWeights; + armnn::ConstTensor inputGateBias; + + if (!qLstmDescriptor.m_CifgEnabled) + { + inputToInputWeights = ToConstTensor(flatBufferInputParams->inputToInputWeights()); + recurrentToInputWeights = ToConstTensor(flatBufferInputParams->recurrentToInputWeights()); + inputGateBias = ToConstTensor(flatBufferInputParams->inputGateBias()); + + qLstmInputParams.m_InputToInputWeights = &inputToInputWeights; + qLstmInputParams.m_RecurrentToInputWeights = &recurrentToInputWeights; + qLstmInputParams.m_InputGateBias = &inputGateBias; + } + + // Optional projection params + armnn::ConstTensor projectionWeights; + armnn::ConstTensor projectionBias; + + if (qLstmDescriptor.m_ProjectionEnabled) + { + projectionWeights = ToConstTensor(flatBufferInputParams->projectionWeights()); + projectionBias = ToConstTensor(flatBufferInputParams->projectionBias()); + + qLstmInputParams.m_ProjectionWeights = &projectionWeights; + qLstmInputParams.m_ProjectionBias = &projectionBias; + } + + // Optional peephole params + armnn::ConstTensor cellToInputWeights; + armnn::ConstTensor cellToForgetWeights; + armnn::ConstTensor cellToOutputWeights; + + if (qLstmDescriptor.m_PeepholeEnabled) + { + if (!qLstmDescriptor.m_CifgEnabled) + { + cellToInputWeights = ToConstTensor(flatBufferInputParams->cellToInputWeights()); + qLstmInputParams.m_CellToInputWeights = &cellToInputWeights; + } + + cellToForgetWeights = ToConstTensor(flatBufferInputParams->cellToForgetWeights()); + cellToOutputWeights = ToConstTensor(flatBufferInputParams->cellToOutputWeights()); + + 
qLstmInputParams.m_CellToForgetWeights = &cellToForgetWeights; + qLstmInputParams.m_CellToOutputWeights = &cellToOutputWeights; + } + + // Optional layer norm params + armnn::ConstTensor inputLayerNormWeights; + armnn::ConstTensor forgetLayerNormWeights; + armnn::ConstTensor cellLayerNormWeights; + armnn::ConstTensor outputLayerNormWeights; + + if (qLstmDescriptor.m_LayerNormEnabled) + { + if (!qLstmDescriptor.m_CifgEnabled) + { + inputLayerNormWeights = ToConstTensor(flatBufferInputParams->inputLayerNormWeights()); + qLstmInputParams.m_InputLayerNormWeights = &inputLayerNormWeights; + } + + forgetLayerNormWeights = ToConstTensor(flatBufferInputParams->forgetLayerNormWeights()); + cellLayerNormWeights = ToConstTensor(flatBufferInputParams->cellLayerNormWeights()); + outputLayerNormWeights = ToConstTensor(flatBufferInputParams->outputLayerNormWeights()); + + qLstmInputParams.m_ForgetLayerNormWeights = &forgetLayerNormWeights; + qLstmInputParams.m_CellLayerNormWeights = &cellLayerNormWeights; + qLstmInputParams.m_OutputLayerNormWeights = &outputLayerNormWeights; + } + + IConnectableLayer* layer = m_Network->AddQLstmLayer(qLstmDescriptor, qLstmInputParams, layerName.c_str()); + + armnn::TensorInfo outputStateOutInfo = ToTensorInfo(outputs[0]); + layer->GetOutputSlot(0).SetTensorInfo(outputStateOutInfo); + + armnn::TensorInfo cellStateOutInfo = ToTensorInfo(outputs[1]); + layer->GetOutputSlot(1).SetTensorInfo(cellStateOutInfo); + + armnn::TensorInfo outputInfo = ToTensorInfo(outputs[2]); + layer->GetOutputSlot(2).SetTensorInfo(outputInfo); + + RegisterInputSlots(graph, layerIndex, layer); + RegisterOutputSlots(graph, layerIndex, layer); +} + void Deserializer::ParseQuantizedLstm(GraphPtr graph, unsigned int layerIndex) { CHECK_LAYERS(graph, 0, layerIndex); diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp index f7e47cc8c2..d6ceced7c6 100644 --- a/src/armnnDeserializer/Deserializer.hpp +++ b/src/armnnDeserializer/Deserializer.hpp @@ -24,6 +24,7 @@ public: using NormalizationDescriptorPtr = const armnnSerializer::NormalizationDescriptor *; using LstmDescriptorPtr = const armnnSerializer::LstmDescriptor *; using LstmInputParamsPtr = const armnnSerializer::LstmInputParams *; + using QLstmDescriptorPtr = const armnnSerializer::QLstmDescriptor *; using QunatizedLstmInputParamsPtr = const armnnSerializer::QuantizedLstmInputParams *; using TensorRawPtrVector = std::vector; using LayerRawPtr = const armnnSerializer::LayerBase *; @@ -62,6 +63,7 @@ public: static armnn::LstmDescriptor GetLstmDescriptor(LstmDescriptorPtr lstmDescriptor); static armnn::LstmInputParams GetLstmInputParams(LstmDescriptorPtr lstmDescriptor, LstmInputParamsPtr lstmInputParams); + static armnn::QLstmDescriptor GetQLstmDescriptor(QLstmDescriptorPtr qLstmDescriptorPtr); static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo & inputTensorInfo, const std::vector & targetDimsIn); @@ -113,6 +115,7 @@ private: void ParsePermute(GraphPtr graph, unsigned int layerIndex); void ParsePooling2d(GraphPtr graph, unsigned int layerIndex); void ParsePrelu(GraphPtr graph, unsigned int layerIndex); + void ParseQLstm(GraphPtr graph, unsigned int layerIndex); void ParseQuantize(GraphPtr graph, unsigned int layerIndex); void ParseReshape(GraphPtr graph, unsigned int layerIndex); void ParseResize(GraphPtr graph, unsigned int layerIndex); diff --git a/src/armnnDeserializer/DeserializerSupport.md b/src/armnnDeserializer/DeserializerSupport.md index 491cd1de42..c47810c11e 100644 --- 
a/src/armnnDeserializer/DeserializerSupport.md +++ b/src/armnnDeserializer/DeserializerSupport.md @@ -42,6 +42,7 @@ The Arm NN SDK Deserialize parser currently supports the following layers: * Pooling2d * Prelu * Quantize +* QLstm * QuantizedLstm * Reshape * Resize diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs index ff79f6cffe..6e5ee3f3d3 100644 --- a/src/armnnSerializer/ArmnnSchema.fbs +++ b/src/armnnSerializer/ArmnnSchema.fbs @@ -155,7 +155,8 @@ enum LayerType : uint { Comparison = 52, StandIn = 53, ElementwiseUnary = 54, - Transpose = 55 + Transpose = 55, + QLstm = 56 } // Base layer table to be used as part of other layers @@ -666,37 +667,96 @@ table LstmInputParams { outputLayerNormWeights:ConstTensor; } -table QuantizedLstmInputParams { - inputToInputWeights:ConstTensor; +table LstmDescriptor { + activationFunc:uint; + clippingThresCell:float; + clippingThresProj:float; + cifgEnabled:bool = true; + peepholeEnabled:bool = false; + projectionEnabled:bool = false; + layerNormEnabled:bool = false; +} + +table LstmLayer { + base:LayerBase; + descriptor:LstmDescriptor; + inputParams:LstmInputParams; +} + +table QLstmInputParams { + // Mandatory inputToForgetWeights:ConstTensor; inputToCellWeights:ConstTensor; inputToOutputWeights:ConstTensor; - recurrentToInputWeights:ConstTensor; recurrentToForgetWeights:ConstTensor; recurrentToCellWeights:ConstTensor; recurrentToOutputWeights:ConstTensor; - inputGateBias:ConstTensor; forgetGateBias:ConstTensor; cellBias:ConstTensor; outputGateBias:ConstTensor; + + // CIFG + inputToInputWeights:ConstTensor; + recurrentToInputWeights:ConstTensor; + inputGateBias:ConstTensor; + + // Projection + projectionWeights:ConstTensor; + projectionBias:ConstTensor; + + // Peephole + cellToInputWeights:ConstTensor; + cellToForgetWeights:ConstTensor; + cellToOutputWeights:ConstTensor; + + // Layer norm + inputLayerNormWeights:ConstTensor; + forgetLayerNormWeights:ConstTensor; + cellLayerNormWeights:ConstTensor; + outputLayerNormWeights:ConstTensor; } -table LstmDescriptor { - activationFunc:uint; - clippingThresCell:float; - clippingThresProj:float; - cifgEnabled:bool = true; - peepholeEnabled:bool = false; +table QLstmDescriptor { + cifgEnabled:bool = true; + peepholeEnabled:bool = false; projectionEnabled:bool = false; - layerNormEnabled:bool = false; + layerNormEnabled:bool = false; + + cellClip:float; + projectionClip:float; + + inputIntermediateScale:float; + forgetIntermediateScale:float; + cellIntermediateScale:float; + outputIntermediateScale:float; + + hiddenStateZeroPoint:int; + hiddenStateScale:float; } -table LstmLayer { +table QLstmLayer { base:LayerBase; - descriptor:LstmDescriptor; - inputParams:LstmInputParams; + descriptor:QLstmDescriptor; + inputParams:QLstmInputParams; +} + +table QuantizedLstmInputParams { + inputToInputWeights:ConstTensor; + inputToForgetWeights:ConstTensor; + inputToCellWeights:ConstTensor; + inputToOutputWeights:ConstTensor; + + recurrentToInputWeights:ConstTensor; + recurrentToForgetWeights:ConstTensor; + recurrentToCellWeights:ConstTensor; + recurrentToOutputWeights:ConstTensor; + + inputGateBias:ConstTensor; + forgetGateBias:ConstTensor; + cellBias:ConstTensor; + outputGateBias:ConstTensor; } table QuantizedLstmLayer { @@ -836,7 +896,8 @@ union Layer { ComparisonLayer, StandInLayer, ElementwiseUnaryLayer, - TransposeLayer + TransposeLayer, + QLstmLayer } table AnyLayer { diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index 355673697f..c4d3cfb5dd 
100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -1335,9 +1335,124 @@ void SerializerVisitor::VisitQLstmLayer(const armnn::IConnectableLayer* layer,
                                         const armnn::QLstmDescriptor& descriptor,
                                         const armnn::LstmInputParams& params,
                                         const char* name)
 {
-    IgnoreUnused(layer, descriptor, params, name);
+    IgnoreUnused(name);
+
+    auto fbQLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_QLstm);
+
+    auto fbQLstmDescriptor = serializer::CreateQLstmDescriptor(
+        m_flatBufferBuilder,
+        descriptor.m_CifgEnabled,
+        descriptor.m_PeepholeEnabled,
+        descriptor.m_ProjectionEnabled,
+        descriptor.m_LayerNormEnabled,
+        descriptor.m_CellClip,
+        descriptor.m_ProjectionClip,
+        descriptor.m_InputIntermediateScale,
+        descriptor.m_ForgetIntermediateScale,
+        descriptor.m_CellIntermediateScale,
+        descriptor.m_OutputIntermediateScale,
+        descriptor.m_HiddenStateZeroPoint,
+        descriptor.m_HiddenStateScale
+    );
+
+    // Mandatory params
+    auto inputToForgetWeights = CreateConstTensorInfo(*params.m_InputToForgetWeights);
+    auto inputToCellWeights = CreateConstTensorInfo(*params.m_InputToCellWeights);
+    auto inputToOutputWeights = CreateConstTensorInfo(*params.m_InputToOutputWeights);
+    auto recurrentToForgetWeights = CreateConstTensorInfo(*params.m_RecurrentToForgetWeights);
+    auto recurrentToCellWeights = CreateConstTensorInfo(*params.m_RecurrentToCellWeights);
+    auto recurrentToOutputWeights = CreateConstTensorInfo(*params.m_RecurrentToOutputWeights);
+    auto forgetGateBias = CreateConstTensorInfo(*params.m_ForgetGateBias);
+    auto cellBias = CreateConstTensorInfo(*params.m_CellBias);
+    auto outputGateBias = CreateConstTensorInfo(*params.m_OutputGateBias);
+
+    // CIFG
+    flatbuffers::Offset<serializer::ConstTensor> inputToInputWeights;
+    flatbuffers::Offset<serializer::ConstTensor> recurrentToInputWeights;
+    flatbuffers::Offset<serializer::ConstTensor> inputGateBias;
+
+    if (!descriptor.m_CifgEnabled)
+    {
+        inputToInputWeights = CreateConstTensorInfo(*params.m_InputToInputWeights);
+        recurrentToInputWeights = CreateConstTensorInfo(*params.m_RecurrentToInputWeights);
+        inputGateBias = CreateConstTensorInfo(*params.m_InputGateBias);
+    }
+
+    // Projection
+    flatbuffers::Offset<serializer::ConstTensor> projectionWeights;
+    flatbuffers::Offset<serializer::ConstTensor> projectionBias;
+
+    if (descriptor.m_ProjectionEnabled)
+    {
+        projectionWeights = CreateConstTensorInfo(*params.m_ProjectionWeights);
+        projectionBias = CreateConstTensorInfo(*params.m_ProjectionBias);
+    }
+
+    // Peephole
+    flatbuffers::Offset<serializer::ConstTensor> cellToInputWeights;
+    flatbuffers::Offset<serializer::ConstTensor> cellToForgetWeights;
+    flatbuffers::Offset<serializer::ConstTensor> cellToOutputWeights;
+
+    if (descriptor.m_PeepholeEnabled)
+    {
+        if (!descriptor.m_CifgEnabled)
+        {
+            cellToInputWeights = CreateConstTensorInfo(*params.m_CellToInputWeights);
+        }
+
+        cellToForgetWeights = CreateConstTensorInfo(*params.m_CellToForgetWeights);
+        cellToOutputWeights = CreateConstTensorInfo(*params.m_CellToOutputWeights);
+    }
+
+    // Layer norm
+    flatbuffers::Offset<serializer::ConstTensor> inputLayerNormWeights;
+    flatbuffers::Offset<serializer::ConstTensor> forgetLayerNormWeights;
+    flatbuffers::Offset<serializer::ConstTensor> cellLayerNormWeights;
+    flatbuffers::Offset<serializer::ConstTensor> outputLayerNormWeights;
+
+    if (descriptor.m_LayerNormEnabled)
+    {
+        if (!descriptor.m_CifgEnabled)
+        {
+            inputLayerNormWeights = CreateConstTensorInfo((*params.m_InputLayerNormWeights));
+        }
+
+        forgetLayerNormWeights = CreateConstTensorInfo(*params.m_ForgetLayerNormWeights);
+        cellLayerNormWeights = CreateConstTensorInfo(*params.m_CellLayerNormWeights);
+        outputLayerNormWeights = CreateConstTensorInfo(*params.m_OutputLayerNormWeights);
+    }
+
+    auto fbQLstmParams = serializer::CreateQLstmInputParams(
+        m_flatBufferBuilder,
+        inputToForgetWeights,
+        inputToCellWeights,
+        inputToOutputWeights,
+        recurrentToForgetWeights,
+        recurrentToCellWeights,
+        recurrentToOutputWeights,
+        forgetGateBias,
+        cellBias,
+        outputGateBias,
+        inputToInputWeights,
+        recurrentToInputWeights,
+        inputGateBias,
+        projectionWeights,
+        projectionBias,
+        cellToInputWeights,
+        cellToForgetWeights,
+        cellToOutputWeights,
+        inputLayerNormWeights,
+        forgetLayerNormWeights,
+        cellLayerNormWeights,
+        outputLayerNormWeights);
+
+    auto fbQLstmLayer = serializer::CreateQLstmLayer(
+        m_flatBufferBuilder,
+        fbQLstmBaseLayer,
+        fbQLstmDescriptor,
+        fbQLstmParams);
 
-    throw UnimplementedException("SerializerVisitor::VisitQLstmLayer not yet implemented");
+    CreateAnyLayer(fbQLstmLayer.o, serializer::Layer::Layer_QLstmLayer);
 }
 
 void SerializerVisitor::VisitQuantizedLstmLayer(const armnn::IConnectableLayer* layer,
diff --git a/src/armnnSerializer/SerializerSupport.md b/src/armnnSerializer/SerializerSupport.md
index 79b551f03a..8ba164c91e 100644
--- a/src/armnnSerializer/SerializerSupport.md
+++ b/src/armnnSerializer/SerializerSupport.md
@@ -40,6 +40,7 @@ The Arm NN SDK Serializer currently supports the following layers:
 * Permute
 * Pooling2d
 * Prelu
+* QLstm
 * Quantize
 * QuantizedLstm
 * Reshape
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp
index db89430439..76ac5a4de2 100644
--- a/src/armnnSerializer/test/SerializerTests.cpp
+++ b/src/armnnSerializer/test/SerializerTests.cpp
@@ -4326,4 +4326,661 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeQuantizedLstm)
     deserializedNetwork->Accept(checker);
 }
 
+class VerifyQLstmLayer : public LayerVerifierBaseWithDescriptor<armnn::QLstmDescriptor>
+{
+public:
+    VerifyQLstmLayer(const std::string& layerName,
+                     const std::vector<armnn::TensorInfo>& inputInfos,
+                     const std::vector<armnn::TensorInfo>& outputInfos,
+                     const armnn::QLstmDescriptor& descriptor,
+                     const armnn::LstmInputParams& inputParams)
+        : LayerVerifierBaseWithDescriptor<armnn::QLstmDescriptor>(layerName, inputInfos, outputInfos, descriptor)
+        , m_InputParams(inputParams) {}
+
+    void VisitQLstmLayer(const armnn::IConnectableLayer* layer,
+                         const armnn::QLstmDescriptor& descriptor,
+                         const armnn::LstmInputParams& params,
+                         const char* name)
+    {
+        VerifyNameAndConnections(layer, name);
+        VerifyDescriptor(descriptor);
+        VerifyInputParameters(params);
+    }
+
+protected:
+    void VerifyInputParameters(const armnn::LstmInputParams& params)
+    {
+        VerifyConstTensors(
+            "m_InputToInputWeights", m_InputParams.m_InputToInputWeights, params.m_InputToInputWeights);
+        VerifyConstTensors(
+            "m_InputToForgetWeights", m_InputParams.m_InputToForgetWeights, params.m_InputToForgetWeights);
+        VerifyConstTensors(
+            "m_InputToCellWeights", m_InputParams.m_InputToCellWeights, params.m_InputToCellWeights);
+        VerifyConstTensors(
+            "m_InputToOutputWeights", m_InputParams.m_InputToOutputWeights, params.m_InputToOutputWeights);
+        VerifyConstTensors(
+            "m_RecurrentToInputWeights", m_InputParams.m_RecurrentToInputWeights, params.m_RecurrentToInputWeights);
+        VerifyConstTensors(
+            "m_RecurrentToForgetWeights", m_InputParams.m_RecurrentToForgetWeights, params.m_RecurrentToForgetWeights);
+        VerifyConstTensors(
+            "m_RecurrentToCellWeights", m_InputParams.m_RecurrentToCellWeights, params.m_RecurrentToCellWeights);
+        VerifyConstTensors(
+            "m_RecurrentToOutputWeights", m_InputParams.m_RecurrentToOutputWeights, params.m_RecurrentToOutputWeights);
+        VerifyConstTensors(
+            "m_CellToInputWeights", m_InputParams.m_CellToInputWeights, params.m_CellToInputWeights);
+        VerifyConstTensors(
+            "m_CellToForgetWeights",
m_InputParams.m_CellToForgetWeights, params.m_CellToForgetWeights); + VerifyConstTensors( + "m_CellToOutputWeights", m_InputParams.m_CellToOutputWeights, params.m_CellToOutputWeights); + VerifyConstTensors( + "m_InputGateBias", m_InputParams.m_InputGateBias, params.m_InputGateBias); + VerifyConstTensors( + "m_ForgetGateBias", m_InputParams.m_ForgetGateBias, params.m_ForgetGateBias); + VerifyConstTensors( + "m_CellBias", m_InputParams.m_CellBias, params.m_CellBias); + VerifyConstTensors( + "m_OutputGateBias", m_InputParams.m_OutputGateBias, params.m_OutputGateBias); + VerifyConstTensors( + "m_ProjectionWeights", m_InputParams.m_ProjectionWeights, params.m_ProjectionWeights); + VerifyConstTensors( + "m_ProjectionBias", m_InputParams.m_ProjectionBias, params.m_ProjectionBias); + VerifyConstTensors( + "m_InputLayerNormWeights", m_InputParams.m_InputLayerNormWeights, params.m_InputLayerNormWeights); + VerifyConstTensors( + "m_ForgetLayerNormWeights", m_InputParams.m_ForgetLayerNormWeights, params.m_ForgetLayerNormWeights); + VerifyConstTensors( + "m_CellLayerNormWeights", m_InputParams.m_CellLayerNormWeights, params.m_CellLayerNormWeights); + VerifyConstTensors( + "m_OutputLayerNormWeights", m_InputParams.m_OutputLayerNormWeights, params.m_OutputLayerNormWeights); + } + +private: + armnn::LstmInputParams m_InputParams; +}; + +BOOST_AUTO_TEST_CASE(SerializeDeserializeQLstmBasic) +{ + armnn::QLstmDescriptor descriptor; + + descriptor.m_CifgEnabled = true; + descriptor.m_ProjectionEnabled = false; + descriptor.m_PeepholeEnabled = false; + descriptor.m_LayerNormEnabled = false; + + descriptor.m_CellClip = 0.0f; + descriptor.m_ProjectionClip = 0.0f; + + descriptor.m_InputIntermediateScale = 0.00001f; + descriptor.m_ForgetIntermediateScale = 0.00001f; + descriptor.m_CellIntermediateScale = 0.00001f; + descriptor.m_OutputIntermediateScale = 0.00001f; + + descriptor.m_HiddenStateScale = 0.07f; + descriptor.m_HiddenStateZeroPoint = 0; + + const unsigned int numBatches = 2; + const unsigned int inputSize = 5; + const unsigned int outputSize = 4; + const unsigned int numUnits = 4; + + // Scale/Offset quantization info + float inputScale = 0.0078f; + int32_t inputOffset = 0; + + float outputScale = 0.0078f; + int32_t outputOffset = 0; + + float cellStateScale = 3.5002e-05f; + int32_t cellStateOffset = 0; + + float weightsScale = 0.007f; + int32_t weightsOffset = 0; + + float biasScale = 3.5002e-05f / 1024; + int32_t biasOffset = 0; + + // Weights and bias tensor and quantization info + armnn::TensorInfo inputWeightsInfo({numUnits, inputSize}, + armnn::DataType::QSymmS8, + weightsScale, + weightsOffset); + + armnn::TensorInfo recurrentWeightsInfo({numUnits, outputSize}, + armnn::DataType::QSymmS8, + weightsScale, + weightsOffset); + + armnn::TensorInfo biasInfo({numUnits}, armnn::DataType::Signed32, biasScale, biasOffset); + + std::vector inputToForgetWeightsData = GenerateRandomData(inputWeightsInfo.GetNumElements()); + std::vector inputToCellWeightsData = GenerateRandomData(inputWeightsInfo.GetNumElements()); + std::vector inputToOutputWeightsData = GenerateRandomData(inputWeightsInfo.GetNumElements()); + + armnn::ConstTensor inputToForgetWeights(inputWeightsInfo, inputToForgetWeightsData); + armnn::ConstTensor inputToCellWeights(inputWeightsInfo, inputToCellWeightsData); + armnn::ConstTensor inputToOutputWeights(inputWeightsInfo, inputToOutputWeightsData); + + std::vector recurrentToForgetWeightsData = + GenerateRandomData(recurrentWeightsInfo.GetNumElements()); + std::vector recurrentToCellWeightsData = 
+ GenerateRandomData(recurrentWeightsInfo.GetNumElements()); + std::vector recurrentToOutputWeightsData = + GenerateRandomData(recurrentWeightsInfo.GetNumElements()); + + armnn::ConstTensor recurrentToForgetWeights(recurrentWeightsInfo, recurrentToForgetWeightsData); + armnn::ConstTensor recurrentToCellWeights(recurrentWeightsInfo, recurrentToCellWeightsData); + armnn::ConstTensor recurrentToOutputWeights(recurrentWeightsInfo, recurrentToOutputWeightsData); + + std::vector forgetGateBiasData(numUnits, 1); + std::vector cellBiasData(numUnits, 0); + std::vector outputGateBiasData(numUnits, 0); + + armnn::ConstTensor forgetGateBias(biasInfo, forgetGateBiasData); + armnn::ConstTensor cellBias(biasInfo, cellBiasData); + armnn::ConstTensor outputGateBias(biasInfo, outputGateBiasData); + + // Set up params + armnn::LstmInputParams params; + params.m_InputToForgetWeights = &inputToForgetWeights; + params.m_InputToCellWeights = &inputToCellWeights; + params.m_InputToOutputWeights = &inputToOutputWeights; + + params.m_RecurrentToForgetWeights = &recurrentToForgetWeights; + params.m_RecurrentToCellWeights = &recurrentToCellWeights; + params.m_RecurrentToOutputWeights = &recurrentToOutputWeights; + + params.m_ForgetGateBias = &forgetGateBias; + params.m_CellBias = &cellBias; + params.m_OutputGateBias = &outputGateBias; + + // Create network + armnn::INetworkPtr network = armnn::INetwork::Create(); + const std::string layerName("qLstm"); + + armnn::IConnectableLayer* const input = network->AddInputLayer(0); + armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(1); + armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(2); + + armnn::IConnectableLayer* const qLstmLayer = network->AddQLstmLayer(descriptor, params, layerName.c_str()); + + armnn::IConnectableLayer* const outputStateOut = network->AddOutputLayer(0); + armnn::IConnectableLayer* const cellStateOut = network->AddOutputLayer(1); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(2); + + // Input/Output tensor info + armnn::TensorInfo inputInfo({numBatches , inputSize}, + armnn::DataType::QAsymmS8, + inputScale, + inputOffset); + + armnn::TensorInfo cellStateInfo({numBatches , numUnits}, + armnn::DataType::QSymmS16, + cellStateScale, + cellStateOffset); + + armnn::TensorInfo outputStateInfo({numBatches , outputSize}, + armnn::DataType::QAsymmS8, + outputScale, + outputOffset); + + // Connect input/output slots + input->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(0)); + input->GetOutputSlot(0).SetTensorInfo(inputInfo); + + outputStateIn->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(1)); + outputStateIn->GetOutputSlot(0).SetTensorInfo(cellStateInfo); + + cellStateIn->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(2)); + cellStateIn->GetOutputSlot(0).SetTensorInfo(outputStateInfo); + + qLstmLayer->GetOutputSlot(0).Connect(outputStateOut->GetInputSlot(0)); + qLstmLayer->GetOutputSlot(0).SetTensorInfo(outputStateInfo); + + qLstmLayer->GetOutputSlot(1).Connect(cellStateOut->GetInputSlot(0)); + qLstmLayer->GetOutputSlot(1).SetTensorInfo(cellStateInfo); + + qLstmLayer->GetOutputSlot(2).Connect(outputLayer->GetInputSlot(0)); + qLstmLayer->GetOutputSlot(2).SetTensorInfo(outputStateInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + BOOST_CHECK(deserializedNetwork); + + VerifyQLstmLayer checker(layerName, + {inputInfo, cellStateInfo, outputStateInfo}, + {outputStateInfo, cellStateInfo, outputStateInfo}, + descriptor, + params); + + 
deserializedNetwork->Accept(checker); +} + +BOOST_AUTO_TEST_CASE(SerializeDeserializeQLstmCifgLayerNorm) +{ + armnn::QLstmDescriptor descriptor; + + // CIFG params are used when CIFG is disabled + descriptor.m_CifgEnabled = true; + descriptor.m_ProjectionEnabled = false; + descriptor.m_PeepholeEnabled = false; + descriptor.m_LayerNormEnabled = true; + + descriptor.m_CellClip = 0.0f; + descriptor.m_ProjectionClip = 0.0f; + + descriptor.m_InputIntermediateScale = 0.00001f; + descriptor.m_ForgetIntermediateScale = 0.00001f; + descriptor.m_CellIntermediateScale = 0.00001f; + descriptor.m_OutputIntermediateScale = 0.00001f; + + descriptor.m_HiddenStateScale = 0.07f; + descriptor.m_HiddenStateZeroPoint = 0; + + const unsigned int numBatches = 2; + const unsigned int inputSize = 5; + const unsigned int outputSize = 4; + const unsigned int numUnits = 4; + + // Scale/Offset quantization info + float inputScale = 0.0078f; + int32_t inputOffset = 0; + + float outputScale = 0.0078f; + int32_t outputOffset = 0; + + float cellStateScale = 3.5002e-05f; + int32_t cellStateOffset = 0; + + float weightsScale = 0.007f; + int32_t weightsOffset = 0; + + float layerNormScale = 3.5002e-05f; + int32_t layerNormOffset = 0; + + float biasScale = layerNormScale / 1024; + int32_t biasOffset = 0; + + // Weights and bias tensor and quantization info + armnn::TensorInfo inputWeightsInfo({numUnits, inputSize}, + armnn::DataType::QSymmS8, + weightsScale, + weightsOffset); + + armnn::TensorInfo recurrentWeightsInfo({numUnits, outputSize}, + armnn::DataType::QSymmS8, + weightsScale, + weightsOffset); + + armnn::TensorInfo biasInfo({numUnits}, + armnn::DataType::Signed32, + biasScale, + biasOffset); + + armnn::TensorInfo layerNormWeightsInfo({numUnits}, + armnn::DataType::QSymmS16, + layerNormScale, + layerNormOffset); + + // Mandatory params + std::vector inputToForgetWeightsData = GenerateRandomData(inputWeightsInfo.GetNumElements()); + std::vector inputToCellWeightsData = GenerateRandomData(inputWeightsInfo.GetNumElements()); + std::vector inputToOutputWeightsData = GenerateRandomData(inputWeightsInfo.GetNumElements()); + + armnn::ConstTensor inputToForgetWeights(inputWeightsInfo, inputToForgetWeightsData); + armnn::ConstTensor inputToCellWeights(inputWeightsInfo, inputToCellWeightsData); + armnn::ConstTensor inputToOutputWeights(inputWeightsInfo, inputToOutputWeightsData); + + std::vector recurrentToForgetWeightsData = + GenerateRandomData(recurrentWeightsInfo.GetNumElements()); + std::vector recurrentToCellWeightsData = + GenerateRandomData(recurrentWeightsInfo.GetNumElements()); + std::vector recurrentToOutputWeightsData = + GenerateRandomData(recurrentWeightsInfo.GetNumElements()); + + armnn::ConstTensor recurrentToForgetWeights(recurrentWeightsInfo, recurrentToForgetWeightsData); + armnn::ConstTensor recurrentToCellWeights(recurrentWeightsInfo, recurrentToCellWeightsData); + armnn::ConstTensor recurrentToOutputWeights(recurrentWeightsInfo, recurrentToOutputWeightsData); + + std::vector forgetGateBiasData(numUnits, 1); + std::vector cellBiasData(numUnits, 0); + std::vector outputGateBiasData(numUnits, 0); + + armnn::ConstTensor forgetGateBias(biasInfo, forgetGateBiasData); + armnn::ConstTensor cellBias(biasInfo, cellBiasData); + armnn::ConstTensor outputGateBias(biasInfo, outputGateBiasData); + + // Layer Norm + std::vector forgetLayerNormWeightsData = + GenerateRandomData(layerNormWeightsInfo.GetNumElements()); + std::vector cellLayerNormWeightsData = + GenerateRandomData(layerNormWeightsInfo.GetNumElements()); + 
std::vector outputLayerNormWeightsData = + GenerateRandomData(layerNormWeightsInfo.GetNumElements()); + + armnn::ConstTensor forgetLayerNormWeights(layerNormWeightsInfo, forgetLayerNormWeightsData); + armnn::ConstTensor cellLayerNormWeights(layerNormWeightsInfo, cellLayerNormWeightsData); + armnn::ConstTensor outputLayerNormWeights(layerNormWeightsInfo, outputLayerNormWeightsData); + + // Set up params + armnn::LstmInputParams params; + + // Mandatory params + params.m_InputToForgetWeights = &inputToForgetWeights; + params.m_InputToCellWeights = &inputToCellWeights; + params.m_InputToOutputWeights = &inputToOutputWeights; + + params.m_RecurrentToForgetWeights = &recurrentToForgetWeights; + params.m_RecurrentToCellWeights = &recurrentToCellWeights; + params.m_RecurrentToOutputWeights = &recurrentToOutputWeights; + + params.m_ForgetGateBias = &forgetGateBias; + params.m_CellBias = &cellBias; + params.m_OutputGateBias = &outputGateBias; + + // Layer Norm + params.m_ForgetLayerNormWeights = &forgetLayerNormWeights; + params.m_CellLayerNormWeights = &cellLayerNormWeights; + params.m_OutputLayerNormWeights = &outputLayerNormWeights; + + // Create network + armnn::INetworkPtr network = armnn::INetwork::Create(); + const std::string layerName("qLstm"); + + armnn::IConnectableLayer* const input = network->AddInputLayer(0); + armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(1); + armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(2); + + armnn::IConnectableLayer* const qLstmLayer = network->AddQLstmLayer(descriptor, params, layerName.c_str()); + + armnn::IConnectableLayer* const outputStateOut = network->AddOutputLayer(0); + armnn::IConnectableLayer* const cellStateOut = network->AddOutputLayer(1); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(2); + + // Input/Output tensor info + armnn::TensorInfo inputInfo({numBatches , inputSize}, + armnn::DataType::QAsymmS8, + inputScale, + inputOffset); + + armnn::TensorInfo cellStateInfo({numBatches , numUnits}, + armnn::DataType::QSymmS16, + cellStateScale, + cellStateOffset); + + armnn::TensorInfo outputStateInfo({numBatches , outputSize}, + armnn::DataType::QAsymmS8, + outputScale, + outputOffset); + + // Connect input/output slots + input->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(0)); + input->GetOutputSlot(0).SetTensorInfo(inputInfo); + + outputStateIn->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(1)); + outputStateIn->GetOutputSlot(0).SetTensorInfo(cellStateInfo); + + cellStateIn->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(2)); + cellStateIn->GetOutputSlot(0).SetTensorInfo(outputStateInfo); + + qLstmLayer->GetOutputSlot(0).Connect(outputStateOut->GetInputSlot(0)); + qLstmLayer->GetOutputSlot(0).SetTensorInfo(outputStateInfo); + + qLstmLayer->GetOutputSlot(1).Connect(cellStateOut->GetInputSlot(0)); + qLstmLayer->GetOutputSlot(1).SetTensorInfo(cellStateInfo); + + qLstmLayer->GetOutputSlot(2).Connect(outputLayer->GetInputSlot(0)); + qLstmLayer->GetOutputSlot(2).SetTensorInfo(outputStateInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + BOOST_CHECK(deserializedNetwork); + + VerifyQLstmLayer checker(layerName, + {inputInfo, cellStateInfo, outputStateInfo}, + {outputStateInfo, cellStateInfo, outputStateInfo}, + descriptor, + params); + + deserializedNetwork->Accept(checker); +} + +BOOST_AUTO_TEST_CASE(SerializeDeserializeQLstmAdvanced) +{ + armnn::QLstmDescriptor descriptor; + + descriptor.m_CifgEnabled = false; + 
descriptor.m_ProjectionEnabled = true; + descriptor.m_PeepholeEnabled = true; + descriptor.m_LayerNormEnabled = true; + + descriptor.m_CellClip = 0.1f; + descriptor.m_ProjectionClip = 0.1f; + + descriptor.m_InputIntermediateScale = 0.00001f; + descriptor.m_ForgetIntermediateScale = 0.00001f; + descriptor.m_CellIntermediateScale = 0.00001f; + descriptor.m_OutputIntermediateScale = 0.00001f; + + descriptor.m_HiddenStateScale = 0.07f; + descriptor.m_HiddenStateZeroPoint = 0; + + const unsigned int numBatches = 2; + const unsigned int inputSize = 5; + const unsigned int outputSize = 4; + const unsigned int numUnits = 4; + + // Scale/Offset quantization info + float inputScale = 0.0078f; + int32_t inputOffset = 0; + + float outputScale = 0.0078f; + int32_t outputOffset = 0; + + float cellStateScale = 3.5002e-05f; + int32_t cellStateOffset = 0; + + float weightsScale = 0.007f; + int32_t weightsOffset = 0; + + float layerNormScale = 3.5002e-05f; + int32_t layerNormOffset = 0; + + float biasScale = layerNormScale / 1024; + int32_t biasOffset = 0; + + // Weights and bias tensor and quantization info + armnn::TensorInfo inputWeightsInfo({numUnits, inputSize}, + armnn::DataType::QSymmS8, + weightsScale, + weightsOffset); + + armnn::TensorInfo recurrentWeightsInfo({numUnits, outputSize}, + armnn::DataType::QSymmS8, + weightsScale, + weightsOffset); + + armnn::TensorInfo biasInfo({numUnits}, + armnn::DataType::Signed32, + biasScale, + biasOffset); + + armnn::TensorInfo peepholeWeightsInfo({numUnits}, + armnn::DataType::QSymmS16, + weightsScale, + weightsOffset); + + armnn::TensorInfo layerNormWeightsInfo({numUnits}, + armnn::DataType::QSymmS16, + layerNormScale, + layerNormOffset); + + armnn::TensorInfo projectionWeightsInfo({outputSize, numUnits}, + armnn::DataType::QSymmS8, + weightsScale, + weightsOffset); + + // Mandatory params + std::vector inputToForgetWeightsData = GenerateRandomData(inputWeightsInfo.GetNumElements()); + std::vector inputToCellWeightsData = GenerateRandomData(inputWeightsInfo.GetNumElements()); + std::vector inputToOutputWeightsData = GenerateRandomData(inputWeightsInfo.GetNumElements()); + + armnn::ConstTensor inputToForgetWeights(inputWeightsInfo, inputToForgetWeightsData); + armnn::ConstTensor inputToCellWeights(inputWeightsInfo, inputToCellWeightsData); + armnn::ConstTensor inputToOutputWeights(inputWeightsInfo, inputToOutputWeightsData); + + std::vector recurrentToForgetWeightsData = + GenerateRandomData(recurrentWeightsInfo.GetNumElements()); + std::vector recurrentToCellWeightsData = + GenerateRandomData(recurrentWeightsInfo.GetNumElements()); + std::vector recurrentToOutputWeightsData = + GenerateRandomData(recurrentWeightsInfo.GetNumElements()); + + armnn::ConstTensor recurrentToForgetWeights(recurrentWeightsInfo, recurrentToForgetWeightsData); + armnn::ConstTensor recurrentToCellWeights(recurrentWeightsInfo, recurrentToCellWeightsData); + armnn::ConstTensor recurrentToOutputWeights(recurrentWeightsInfo, recurrentToOutputWeightsData); + + std::vector forgetGateBiasData(numUnits, 1); + std::vector cellBiasData(numUnits, 0); + std::vector outputGateBiasData(numUnits, 0); + + armnn::ConstTensor forgetGateBias(biasInfo, forgetGateBiasData); + armnn::ConstTensor cellBias(biasInfo, cellBiasData); + armnn::ConstTensor outputGateBias(biasInfo, outputGateBiasData); + + // CIFG + std::vector inputToInputWeightsData = GenerateRandomData(inputWeightsInfo.GetNumElements()); + std::vector recurrentToInputWeightsData = + GenerateRandomData(recurrentWeightsInfo.GetNumElements()); + 
std::vector inputGateBiasData(numUnits, 1); + + armnn::ConstTensor inputToInputWeights(inputWeightsInfo, inputToInputWeightsData); + armnn::ConstTensor recurrentToInputWeights(recurrentWeightsInfo, recurrentToInputWeightsData); + armnn::ConstTensor inputGateBias(biasInfo, inputGateBiasData); + + // Peephole + std::vector cellToInputWeightsData = GenerateRandomData(peepholeWeightsInfo.GetNumElements()); + std::vector cellToForgetWeightsData = GenerateRandomData(peepholeWeightsInfo.GetNumElements()); + std::vector cellToOutputWeightsData = GenerateRandomData(peepholeWeightsInfo.GetNumElements()); + + armnn::ConstTensor cellToInputWeights(peepholeWeightsInfo, cellToInputWeightsData); + armnn::ConstTensor cellToForgetWeights(peepholeWeightsInfo, cellToForgetWeightsData); + armnn::ConstTensor cellToOutputWeights(peepholeWeightsInfo, cellToOutputWeightsData); + + // Projection + std::vector projectionWeightsData = GenerateRandomData(projectionWeightsInfo.GetNumElements()); + std::vector projectionBiasData(outputSize, 1); + + armnn::ConstTensor projectionWeights(projectionWeightsInfo, projectionWeightsData); + armnn::ConstTensor projectionBias(biasInfo, projectionBiasData); + + // Layer Norm + std::vector inputLayerNormWeightsData = + GenerateRandomData(layerNormWeightsInfo.GetNumElements()); + std::vector forgetLayerNormWeightsData = + GenerateRandomData(layerNormWeightsInfo.GetNumElements()); + std::vector cellLayerNormWeightsData = + GenerateRandomData(layerNormWeightsInfo.GetNumElements()); + std::vector outputLayerNormWeightsData = + GenerateRandomData(layerNormWeightsInfo.GetNumElements()); + + armnn::ConstTensor inputLayerNormWeights(layerNormWeightsInfo, inputLayerNormWeightsData); + armnn::ConstTensor forgetLayerNormWeights(layerNormWeightsInfo, forgetLayerNormWeightsData); + armnn::ConstTensor cellLayerNormWeights(layerNormWeightsInfo, cellLayerNormWeightsData); + armnn::ConstTensor outputLayerNormWeights(layerNormWeightsInfo, outputLayerNormWeightsData); + + // Set up params + armnn::LstmInputParams params; + + // Mandatory params + params.m_InputToForgetWeights = &inputToForgetWeights; + params.m_InputToCellWeights = &inputToCellWeights; + params.m_InputToOutputWeights = &inputToOutputWeights; + + params.m_RecurrentToForgetWeights = &recurrentToForgetWeights; + params.m_RecurrentToCellWeights = &recurrentToCellWeights; + params.m_RecurrentToOutputWeights = &recurrentToOutputWeights; + + params.m_ForgetGateBias = &forgetGateBias; + params.m_CellBias = &cellBias; + params.m_OutputGateBias = &outputGateBias; + + // CIFG + params.m_InputToInputWeights = &inputToInputWeights; + params.m_RecurrentToInputWeights = &recurrentToInputWeights; + params.m_InputGateBias = &inputGateBias; + + // Peephole + params.m_CellToInputWeights = &cellToInputWeights; + params.m_CellToForgetWeights = &cellToForgetWeights; + params.m_CellToOutputWeights = &cellToOutputWeights; + + // Projection + params.m_ProjectionWeights = &projectionWeights; + params.m_ProjectionBias = &projectionBias; + + // Layer Norm + params.m_InputLayerNormWeights = &inputLayerNormWeights; + params.m_ForgetLayerNormWeights = &forgetLayerNormWeights; + params.m_CellLayerNormWeights = &cellLayerNormWeights; + params.m_OutputLayerNormWeights = &outputLayerNormWeights; + + // Create network + armnn::INetworkPtr network = armnn::INetwork::Create(); + const std::string layerName("qLstm"); + + armnn::IConnectableLayer* const input = network->AddInputLayer(0); + armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(1); + 
armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(2); + + armnn::IConnectableLayer* const qLstmLayer = network->AddQLstmLayer(descriptor, params, layerName.c_str()); + + armnn::IConnectableLayer* const outputStateOut = network->AddOutputLayer(0); + armnn::IConnectableLayer* const cellStateOut = network->AddOutputLayer(1); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(2); + + // Input/Output tensor info + armnn::TensorInfo inputInfo({numBatches , inputSize}, + armnn::DataType::QAsymmS8, + inputScale, + inputOffset); + + armnn::TensorInfo cellStateInfo({numBatches , numUnits}, + armnn::DataType::QSymmS16, + cellStateScale, + cellStateOffset); + + armnn::TensorInfo outputStateInfo({numBatches , outputSize}, + armnn::DataType::QAsymmS8, + outputScale, + outputOffset); + + // Connect input/output slots + input->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(0)); + input->GetOutputSlot(0).SetTensorInfo(inputInfo); + + outputStateIn->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(1)); + outputStateIn->GetOutputSlot(0).SetTensorInfo(cellStateInfo); + + cellStateIn->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(2)); + cellStateIn->GetOutputSlot(0).SetTensorInfo(outputStateInfo); + + qLstmLayer->GetOutputSlot(0).Connect(outputStateOut->GetInputSlot(0)); + qLstmLayer->GetOutputSlot(0).SetTensorInfo(outputStateInfo); + + qLstmLayer->GetOutputSlot(1).Connect(cellStateOut->GetInputSlot(0)); + qLstmLayer->GetOutputSlot(1).SetTensorInfo(cellStateInfo); + + qLstmLayer->GetOutputSlot(2).Connect(outputLayer->GetInputSlot(0)); + qLstmLayer->GetOutputSlot(2).SetTensorInfo(outputStateInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + BOOST_CHECK(deserializedNetwork); + + VerifyQLstmLayer checker(layerName, + {inputInfo, cellStateInfo, outputStateInfo}, + {outputStateInfo, cellStateInfo, outputStateInfo}, + descriptor, + params); + + deserializedNetwork->Accept(checker); +} + BOOST_AUTO_TEST_SUITE_END() -- cgit v1.2.1
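---
Usage note (not part of the patch): a minimal sketch of how the round trip exercised by the tests above might look from user code. It assumes the public armnnSerializer::ISerializer / armnnDeserializer::IDeserializer interfaces of this Arm NN release; RoundTripQLstmNetwork is a hypothetical helper, and descriptor/params would be populated and the layer connected as in SerializeDeserializeQLstmBasic.

// Sketch only: assumes the ISerializer/IDeserializer public API headers below exist
// in this Arm NN release; this is not code from the patch itself.
#include <armnn/INetwork.hpp>
#include <armnnSerializer/ISerializer.hpp>
#include <armnnDeserializer/IDeserializer.hpp>
#include <sstream>
#include <string>
#include <vector>

armnn::INetworkPtr RoundTripQLstmNetwork(const armnn::QLstmDescriptor& descriptor,
                                         const armnn::LstmInputParams& params)
{
    // Build a network containing a QLstm layer. In real code the three inputs and
    // three outputs would also be added and connected, as the unit tests above do.
    armnn::INetworkPtr network = armnn::INetwork::Create();
    network->AddQLstmLayer(descriptor, params, "qLstm");

    // Serialize to a FlatBuffer stream using the new QLstm support.
    auto serializer = armnnSerializer::ISerializer::Create();
    serializer->Serialize(*network);
    std::stringstream stream;
    serializer->SaveSerializedToStream(stream);

    // Deserialize it back; the QLstm layer is rebuilt by Deserializer::ParseQLstm.
    std::string data = stream.str();
    std::vector<uint8_t> binary(data.begin(), data.end());
    return armnnDeserializer::IDeserializer::Create()->CreateNetworkFromBinary(binary);
}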