From 5b01a8994caea2857f3b991dc69a814f12ab7743 Mon Sep 17 00:00:00 2001
From: Jan Eilers
Date: Tue, 23 Jul 2019 09:47:43 +0100
Subject: IVGCVSW-3471 Add Serialization support for Quantized_LSTM

* Adds serialization/deserialization support
* Adds related Unit test

Signed-off-by: Jan Eilers
Change-Id: Iaf271aa7d848bc3a69dbbf182389f2241c0ced5f
---
 src/armnnDeserializer/Deserializer.cpp       |  57 ++++++
 src/armnnDeserializer/Deserializer.hpp       |   2 +
 src/armnnDeserializer/DeserializerSupport.md |   1 +
 src/armnnSerializer/ArmnnSchema.fbs          |  26 ++-
 src/armnnSerializer/Serializer.cpp           |  40 +++-
 src/armnnSerializer/SerializerSupport.md     |   1 +
 src/armnnSerializer/test/SerializerTests.cpp | 275 ++++++++++++++++++++++-----
 7 files changed, 357 insertions(+), 45 deletions(-)

diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp
index 47ed3a65ed..ef1235745c 100644
--- a/src/armnnDeserializer/Deserializer.cpp
+++ b/src/armnnDeserializer/Deserializer.cpp
@@ -215,6 +215,7 @@ m_ParserFunctions(Layer_MAX+1, &Deserializer::ParseUnsupportedLayer)
     m_ParserFunctions[Layer_Pooling2dLayer] = &Deserializer::ParsePooling2d;
     m_ParserFunctions[Layer_PreluLayer] = &Deserializer::ParsePrelu;
     m_ParserFunctions[Layer_QuantizeLayer] = &Deserializer::ParseQuantize;
+    m_ParserFunctions[Layer_QuantizedLstmLayer] = &Deserializer::ParseQuantizedLstm;
     m_ParserFunctions[Layer_ReshapeLayer] = &Deserializer::ParseReshape;
     m_ParserFunctions[Layer_ResizeBilinearLayer] = &Deserializer::ParseResizeBilinear;
     m_ParserFunctions[Layer_ResizeLayer] = &Deserializer::ParseResize;
@@ -300,6 +301,8 @@ Deserializer::LayerBaseRawPtr Deserializer::GetBaseLayer(const GraphPtr& graphPt
             return graphPtr->layers()->Get(layerIndex)->layer_as_PreluLayer()->base();
         case Layer::Layer_QuantizeLayer:
             return graphPtr->layers()->Get(layerIndex)->layer_as_QuantizeLayer()->base();
+        case Layer::Layer_QuantizedLstmLayer:
+            return graphPtr->layers()->Get(layerIndex)->layer_as_QuantizedLstmLayer()->base();
         case Layer::Layer_ReshapeLayer:
             return graphPtr->layers()->Get(layerIndex)->layer_as_ReshapeLayer()->base();
         case Layer::Layer_ResizeBilinearLayer:
@@ -2210,6 +2213,60 @@ void Deserializer::ParseLstm(GraphPtr graph, unsigned int layerIndex)
     RegisterOutputSlots(graph, layerIndex, layer);
 }
 
+void Deserializer::ParseQuantizedLstm(GraphPtr graph, unsigned int layerIndex)
+{
+    CHECK_LAYERS(graph, 0, layerIndex);
+
+    auto inputs = GetInputs(graph, layerIndex);
+    CHECK_VALID_SIZE(inputs.size(), 3);
+
+    auto outputs = GetOutputs(graph, layerIndex);
+    CHECK_VALID_SIZE(outputs.size(), 2);
+
+    auto flatBufferLayer = graph->layers()->Get(layerIndex)->layer_as_QuantizedLstmLayer();
+    auto layerName = GetLayerName(graph, layerIndex);
+    auto flatBufferInputParams = flatBufferLayer->inputParams();
+
+    armnn::QuantizedLstmInputParams lstmInputParams;
+
+    armnn::ConstTensor inputToInputWeights = ToConstTensor(flatBufferInputParams->inputToInputWeights());
+    armnn::ConstTensor inputToForgetWeights = ToConstTensor(flatBufferInputParams->inputToForgetWeights());
+    armnn::ConstTensor inputToCellWeights = ToConstTensor(flatBufferInputParams->inputToCellWeights());
+    armnn::ConstTensor inputToOutputWeights = ToConstTensor(flatBufferInputParams->inputToOutputWeights());
+    armnn::ConstTensor recurrentToInputWeights = ToConstTensor(flatBufferInputParams->recurrentToInputWeights());
+    armnn::ConstTensor recurrentToForgetWeights = ToConstTensor(flatBufferInputParams->recurrentToForgetWeights());
+    armnn::ConstTensor recurrentToCellWeights = ToConstTensor(flatBufferInputParams->recurrentToCellWeights());
+    armnn::ConstTensor recurrentToOutputWeights = ToConstTensor(flatBufferInputParams->recurrentToOutputWeights());
+    armnn::ConstTensor inputGateBias = ToConstTensor(flatBufferInputParams->inputGateBias());
+    armnn::ConstTensor forgetGateBias = ToConstTensor(flatBufferInputParams->forgetGateBias());
+    armnn::ConstTensor cellBias = ToConstTensor(flatBufferInputParams->cellBias());
+    armnn::ConstTensor outputGateBias = ToConstTensor(flatBufferInputParams->outputGateBias());
+
+    lstmInputParams.m_InputToInputWeights = &inputToInputWeights;
+    lstmInputParams.m_InputToForgetWeights = &inputToForgetWeights;
+    lstmInputParams.m_InputToCellWeights = &inputToCellWeights;
+    lstmInputParams.m_InputToOutputWeights = &inputToOutputWeights;
+    lstmInputParams.m_RecurrentToInputWeights = &recurrentToInputWeights;
+    lstmInputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
+    lstmInputParams.m_RecurrentToCellWeights = &recurrentToCellWeights;
+    lstmInputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
+    lstmInputParams.m_InputGateBias = &inputGateBias;
+    lstmInputParams.m_ForgetGateBias = &forgetGateBias;
+    lstmInputParams.m_CellBias = &cellBias;
+    lstmInputParams.m_OutputGateBias = &outputGateBias;
+
+    IConnectableLayer* layer = m_Network->AddQuantizedLstmLayer(lstmInputParams, layerName.c_str());
+
+    armnn::TensorInfo outputTensorInfo1 = ToTensorInfo(outputs[0]);
+    layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo1);
+
+    armnn::TensorInfo outputTensorInfo2 = ToTensorInfo(outputs[1]);
+    layer->GetOutputSlot(1).SetTensorInfo(outputTensorInfo2);
+
+    RegisterInputSlots(graph, layerIndex, layer);
+    RegisterOutputSlots(graph, layerIndex, layer);
+}
+
 void Deserializer::ParseDequantize(GraphPtr graph, unsigned int layerIndex)
 {
     CHECK_LAYERS(graph, 0, layerIndex);
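Note: the twelve constant tensors read in ParseQuantizedLstm above appear to follow the operand layout of the Android NNAPI QUANTIZED_16BIT_LSTM operation (an assumption here, not stated by the patch). As rough orientation, realistically shaped parameter infos would look like the sketch below; the helper names and quantisation values are invented for illustration.

    #include <armnn/Tensor.hpp>

    // Illustrative only: input-to-* weights are [outputSize, inputSize] and
    // recurrent-to-* weights are [outputSize, outputSize], both QAsymm8.
    armnn::TensorInfo MakeQuantizedLstmWeightInfo(unsigned int outputSize, unsigned int inputSize)
    {
        // Scale and zero point are purely illustrative values.
        return armnn::TensorInfo({ outputSize, inputSize },
                                 armnn::DataType::QuantisedAsymm8,
                                 0.0049f, 128);
    }

    armnn::TensorInfo MakeQuantizedLstmBiasInfo(unsigned int outputSize)
    {
        // Gate biases are 1-D Signed32 tensors, as in the unit test below.
        return armnn::TensorInfo({ outputSize }, armnn::DataType::Signed32);
    }

The (de)serializer round-trips whatever shapes it is given; shape semantics are enforced by the backends, not here.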
diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp
index b9d6a170a1..591447de21 100644
--- a/src/armnnDeserializer/Deserializer.hpp
+++ b/src/armnnDeserializer/Deserializer.hpp
@@ -24,6 +24,7 @@ public:
     using NormalizationDescriptorPtr = const armnnSerializer::NormalizationDescriptor *;
     using LstmDescriptorPtr = const armnnSerializer::LstmDescriptor *;
     using LstmInputParamsPtr = const armnnSerializer::LstmInputParams *;
+    using QuantizedLstmInputParamsPtr = const armnnSerializer::QuantizedLstmInputParams *;
     using TensorRawPtrVector = std::vector<TensorRawPtr>;
     using LayerRawPtr = const armnnSerializer::LayerBase *;
     using LayerBaseRawPtr = const armnnSerializer::LayerBase *;
@@ -100,6 +101,7 @@ private:
     void ParseMultiplication(GraphPtr graph, unsigned int layerIndex);
     void ParseNormalization(GraphPtr graph, unsigned int layerIndex);
     void ParseLstm(GraphPtr graph, unsigned int layerIndex);
+    void ParseQuantizedLstm(GraphPtr graph, unsigned int layerIndex);
     void ParsePad(GraphPtr graph, unsigned int layerIndex);
     void ParsePermute(GraphPtr graph, unsigned int layerIndex);
     void ParsePooling2d(GraphPtr graph, unsigned int layerIndex);
diff --git a/src/armnnDeserializer/DeserializerSupport.md b/src/armnnDeserializer/DeserializerSupport.md
index 698340bd31..1bda123284 100644
--- a/src/armnnDeserializer/DeserializerSupport.md
+++ b/src/armnnDeserializer/DeserializerSupport.md
@@ -35,6 +35,7 @@ The Arm NN SDK Deserialize parser currently supports the following layers:
 * Pooling2d
 * Prelu
 * Quantize
+* QuantizedLstm
 * Reshape
 * ResizeBilinear
 * Rsqrt
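With the parser registered and documented, a stored network containing a QuantizedLstm layer is rebuilt through the existing public deserializer entry point. A minimal load sketch (the file handling and helper name are illustrative, not part of this patch):

    #include <armnnDeserializer/IDeserializer.hpp>

    #include <cstdint>
    #include <fstream>
    #include <iterator>
    #include <string>
    #include <vector>

    // Layer_QuantizedLstmLayer entries in the flatbuffer are dispatched to
    // Deserializer::ParseQuantizedLstm internally.
    armnn::INetworkPtr LoadSerializedNetwork(const std::string& path)
    {
        std::ifstream file(path, std::ios::binary);
        std::vector<std::uint8_t> content(std::istreambuf_iterator<char>(file), {});
        auto parser = armnnDeserializer::IDeserializer::Create();
        return parser->CreateNetworkFromBinary(content);
    }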
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs
index 0fd8da7e8f..513c74e82d 100644
--- a/src/armnnSerializer/ArmnnSchema.fbs
+++ b/src/armnnSerializer/ArmnnSchema.fbs
@@ -131,7 +131,8 @@ enum LayerType : uint {
     Prelu = 41,
     TransposeConvolution2d = 42,
     Resize = 43,
-    Stack = 44
+    Stack = 44,
+    QuantizedLstm = 45
 }
 
 // Base layer table to be used as part of other layers
@@ -544,6 +545,23 @@ table LstmInputParams {
     outputLayerNormWeights:ConstTensor;
 }
 
+table QuantizedLstmInputParams {
+    inputToInputWeights:ConstTensor;
+    inputToForgetWeights:ConstTensor;
+    inputToCellWeights:ConstTensor;
+    inputToOutputWeights:ConstTensor;
+
+    recurrentToInputWeights:ConstTensor;
+    recurrentToForgetWeights:ConstTensor;
+    recurrentToCellWeights:ConstTensor;
+    recurrentToOutputWeights:ConstTensor;
+
+    inputGateBias:ConstTensor;
+    forgetGateBias:ConstTensor;
+    cellBias:ConstTensor;
+    outputGateBias:ConstTensor;
+}
+
 table LstmDescriptor {
     activationFunc:uint;
     clippingThresCell:float;
@@ -560,6 +578,11 @@ table LstmLayer {
     inputParams:LstmInputParams;
 }
 
+table QuantizedLstmLayer {
+    base:LayerBase;
+    inputParams:QuantizedLstmInputParams;
+}
+
 table DequantizeLayer {
     base:LayerBase;
 }
@@ -653,6 +676,7 @@ union Layer {
     SplitterLayer,
     DetectionPostProcessLayer,
     LstmLayer,
+    QuantizedLstmLayer,
     QuantizeLayer,
     DequantizeLayer,
     MergeLayer,
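For anyone inspecting the serialized format directly, the new tables come with the usual flatc-generated accessors, and the union entry is what makes the layer discoverable by type. A small sketch (assumes the flatc-generated ArmnnSchema_generated.h and the SerializedGraph root type; illustrative, not part of this patch):

    #include <ArmnnSchema_generated.h>

    // Returns the first QuantizedLstm layer found in a serialized graph buffer, or nullptr.
    const armnnSerializer::QuantizedLstmLayer* FindQuantizedLstm(const void* buffer)
    {
        auto graph = armnnSerializer::GetSerializedGraph(buffer);
        for (auto anyLayer : *graph->layers())
        {
            if (anyLayer->layer_type() == armnnSerializer::Layer_QuantizedLstmLayer)
            {
                return anyLayer->layer_as_QuantizedLstmLayer();
            }
        }
        return nullptr;
    }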
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index 67c2f053e6..af4dc7a926 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -1046,7 +1046,45 @@ void SerializerVisitor::VisitQuantizedLstmLayer(const armnn::IConnectableLayer*
                                                 const armnn::QuantizedLstmInputParams& params,
                                                 const char* name)
 {
-    throw UnimplementedException("SerializerVisitor::VisitQuantizedLstmLayer not yet implemented");
+    auto fbQuantizedLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_QuantizedLstm);
+
+    // Get input parameters
+    auto inputToInputWeights = CreateConstTensorInfo(params.GetInputToInputWeights());
+    auto inputToForgetWeights = CreateConstTensorInfo(params.GetInputToForgetWeights());
+    auto inputToCellWeights = CreateConstTensorInfo(params.GetInputToCellWeights());
+    auto inputToOutputWeights = CreateConstTensorInfo(params.GetInputToOutputWeights());
+
+    auto recurrentToInputWeights = CreateConstTensorInfo(params.GetRecurrentToInputWeights());
+    auto recurrentToForgetWeights = CreateConstTensorInfo(params.GetRecurrentToForgetWeights());
+    auto recurrentToCellWeights = CreateConstTensorInfo(params.GetRecurrentToCellWeights());
+    auto recurrentToOutputWeights = CreateConstTensorInfo(params.GetRecurrentToOutputWeights());
+
+    auto inputGateBias = CreateConstTensorInfo(params.GetInputGateBias());
+    auto forgetGateBias = CreateConstTensorInfo(params.GetForgetGateBias());
+    auto cellBias = CreateConstTensorInfo(params.GetCellBias());
+    auto outputGateBias = CreateConstTensorInfo(params.GetOutputGateBias());
+
+    auto fbQuantizedLstmParams = serializer::CreateQuantizedLstmInputParams(
+        m_flatBufferBuilder,
+        inputToInputWeights,
+        inputToForgetWeights,
+        inputToCellWeights,
+        inputToOutputWeights,
+        recurrentToInputWeights,
+        recurrentToForgetWeights,
+        recurrentToCellWeights,
+        recurrentToOutputWeights,
+        inputGateBias,
+        forgetGateBias,
+        cellBias,
+        outputGateBias);
+
+    auto fbQuantizedLstmLayer = serializer::CreateQuantizedLstmLayer(
+        m_flatBufferBuilder,
+        fbQuantizedLstmBaseLayer,
+        fbQuantizedLstmParams);
+
+    CreateAnyLayer(fbQuantizedLstmLayer.o, serializer::Layer::Layer_QuantizedLstmLayer);
 }
 
 fb::Offset<serializer::LayerBase> SerializerVisitor::CreateLayerBase(const IConnectableLayer* layer,
diff --git a/src/armnnSerializer/SerializerSupport.md b/src/armnnSerializer/SerializerSupport.md
index 1d47399376..39d96f79a9 100644
--- a/src/armnnSerializer/SerializerSupport.md
+++ b/src/armnnSerializer/SerializerSupport.md
@@ -35,6 +35,7 @@ The Arm NN SDK Serializer currently supports the following layers:
 * Pooling2d
 * Prelu
 * Quantize
+* QuantizedLstm
 * Reshape
 * ResizeBilinear
 * Rsqrt
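The serializer side is reached through the public ISerializer API; with VisitQuantizedLstmLayer implemented, networks containing this layer no longer throw during serialization. A minimal save sketch (the stream handling and helper name are illustrative, not part of this patch):

    #include <armnnSerializer/ISerializer.hpp>

    #include <fstream>
    #include <string>

    void SaveSerializedNetwork(const armnn::INetwork& network, const std::string& path)
    {
        auto serializer = armnnSerializer::ISerializer::Create();
        serializer->Serialize(network);            // builds the flatbuffer in memory
        std::ofstream file(path, std::ios::binary);
        serializer->SaveSerializedToStream(file);  // writes it out
    }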
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp
index 74ad2db4e5..79d83f054f 100644
--- a/src/armnnSerializer/test/SerializerTests.cpp
+++ b/src/armnnSerializer/test/SerializerTests.cpp
@@ -76,6 +76,50 @@ protected:
         }
     }
 
+    void VerifyConstTensors(const std::string& tensorName,
+                            const armnn::ConstTensor* expectedPtr,
+                            const armnn::ConstTensor* actualPtr)
+    {
+        if (expectedPtr == nullptr)
+        {
+            BOOST_CHECK_MESSAGE(actualPtr == nullptr, tensorName + " should not exist");
+        }
+        else
+        {
+            BOOST_CHECK_MESSAGE(actualPtr != nullptr, tensorName + " should have been set");
+            if (actualPtr != nullptr)
+            {
+                const armnn::TensorInfo& expectedInfo = expectedPtr->GetInfo();
+                const armnn::TensorInfo& actualInfo = actualPtr->GetInfo();
+
+                BOOST_CHECK_MESSAGE(expectedInfo.GetShape() == actualInfo.GetShape(),
+                                    tensorName + " shapes don't match");
+                BOOST_CHECK_MESSAGE(
+                    GetDataTypeName(expectedInfo.GetDataType()) == GetDataTypeName(actualInfo.GetDataType()),
+                    tensorName + " data types don't match");
+
+                BOOST_CHECK_MESSAGE(expectedPtr->GetNumBytes() == actualPtr->GetNumBytes(),
+                                    tensorName + " (GetNumBytes) data sizes do not match");
+                if (expectedPtr->GetNumBytes() == actualPtr->GetNumBytes())
+                {
+                    // check the data is identical
+                    const char* expectedData = static_cast<const char*>(expectedPtr->GetMemoryArea());
+                    const char* actualData = static_cast<const char*>(actualPtr->GetMemoryArea());
+                    bool same = true;
+                    for (unsigned int i = 0; i < expectedPtr->GetNumBytes(); ++i)
+                    {
+                        same = expectedData[i] == actualData[i];
+                        if (!same)
+                        {
+                            break;
+                        }
+                    }
+                    BOOST_CHECK_MESSAGE(same, tensorName + " data does not match");
+                }
+            }
+        }
+    }
+
 private:
     std::string m_LayerName;
     std::vector<armnn::TensorInfo> m_InputTensorInfos;
@@ -2825,49 +2869,6 @@ protected:
         VerifyConstTensors(
             "m_OutputLayerNormWeights", m_InputParams.m_OutputLayerNormWeights, params.m_OutputLayerNormWeights);
     }
-    void VerifyConstTensors(const std::string& tensorName,
-                            const armnn::ConstTensor* expectedPtr,
-                            const armnn::ConstTensor* actualPtr)
-    {
-        if (expectedPtr == nullptr)
-        {
-            BOOST_CHECK_MESSAGE(actualPtr == nullptr, tensorName + " should not exist");
-        }
-        else
-        {
-            BOOST_CHECK_MESSAGE(actualPtr != nullptr, tensorName + " should have been set");
-            if (actualPtr != nullptr)
-            {
-                const armnn::TensorInfo& expectedInfo = expectedPtr->GetInfo();
-                const armnn::TensorInfo& actualInfo = actualPtr->GetInfo();
-
-                BOOST_CHECK_MESSAGE(expectedInfo.GetShape() == actualInfo.GetShape(),
-                                    tensorName + " shapes don't match");
-                BOOST_CHECK_MESSAGE(
-                    GetDataTypeName(expectedInfo.GetDataType()) == GetDataTypeName(actualInfo.GetDataType()),
-                    tensorName + " data types don't match");
-
-                BOOST_CHECK_MESSAGE(expectedPtr->GetNumBytes() == actualPtr->GetNumBytes(),
-                                    tensorName + " (GetNumBytes) data sizes do not match");
-                if (expectedPtr->GetNumBytes() == actualPtr->GetNumBytes())
-                {
-                    // check the data is identical
-                    const char* expectedData = static_cast<const char*>(expectedPtr->GetMemoryArea());
-                    const char* actualData = static_cast<const char*>(actualPtr->GetMemoryArea());
-                    bool same = true;
-                    for (unsigned int i = 0; i < expectedPtr->GetNumBytes(); ++i)
-                    {
-                        same = expectedData[i] == actualData[i];
-                        if (!same)
-                        {
-                            break;
-                        }
-                    }
-                    BOOST_CHECK_MESSAGE(same, tensorName + " data does not match");
-                }
-            }
-        }
-    }
 private:
     armnn::LstmDescriptor m_Descriptor;
     armnn::LstmInputParams m_InputParams;
@@ -3972,4 +3973,192 @@ BOOST_AUTO_TEST_CASE(EnsureLstmLayersBackwardCompatibility)
     deserializedNetwork->Accept(checker);
 }
 
+class VerifyQuantizedLstmLayer : public LayerVerifierBase
+{
+public:
+    VerifyQuantizedLstmLayer(const std::string& layerName,
+                             const std::vector<armnn::TensorInfo>& inputInfos,
+                             const std::vector<armnn::TensorInfo>& outputInfos,
+                             const armnn::QuantizedLstmInputParams& inputParams) :
+        LayerVerifierBase(layerName, inputInfos, outputInfos), m_InputParams(inputParams)
+    {
+    }
+
+    void VisitQuantizedLstmLayer(const armnn::IConnectableLayer* layer,
+                                 const armnn::QuantizedLstmInputParams& params,
+                                 const char* name)
+    {
+        VerifyNameAndConnections(layer, name);
+        VerifyInputParameters(params);
+    }
+
+protected:
+    void VerifyInputParameters(const armnn::QuantizedLstmInputParams& params)
+    {
+        VerifyConstTensors("m_InputToInputWeights",
+                           m_InputParams.m_InputToInputWeights, params.m_InputToInputWeights);
+        VerifyConstTensors("m_InputToForgetWeights",
+                           m_InputParams.m_InputToForgetWeights, params.m_InputToForgetWeights);
+        VerifyConstTensors("m_InputToCellWeights",
+                           m_InputParams.m_InputToCellWeights, params.m_InputToCellWeights);
+        VerifyConstTensors("m_InputToOutputWeights",
+                           m_InputParams.m_InputToOutputWeights, params.m_InputToOutputWeights);
+        VerifyConstTensors("m_RecurrentToInputWeights",
+                           m_InputParams.m_RecurrentToInputWeights, params.m_RecurrentToInputWeights);
+        VerifyConstTensors("m_RecurrentToForgetWeights",
+                           m_InputParams.m_RecurrentToForgetWeights, params.m_RecurrentToForgetWeights);
+        VerifyConstTensors("m_RecurrentToCellWeights",
+                           m_InputParams.m_RecurrentToCellWeights, params.m_RecurrentToCellWeights);
+        VerifyConstTensors("m_RecurrentToOutputWeights",
+                           m_InputParams.m_RecurrentToOutputWeights, params.m_RecurrentToOutputWeights);
+        VerifyConstTensors("m_InputGateBias",
+                           m_InputParams.m_InputGateBias, params.m_InputGateBias);
+        VerifyConstTensors("m_ForgetGateBias",
+                           m_InputParams.m_ForgetGateBias, params.m_ForgetGateBias);
+        VerifyConstTensors("m_CellBias",
+                           m_InputParams.m_CellBias, params.m_CellBias);
+        VerifyConstTensors("m_OutputGateBias",
+                           m_InputParams.m_OutputGateBias, params.m_OutputGateBias);
+    }
+
+private:
+    armnn::QuantizedLstmInputParams m_InputParams;
+};
+
+BOOST_AUTO_TEST_CASE(SerializeDeserializeQuantizedLstm)
+{
+    const uint32_t batchSize = 1;
+    const uint32_t inputSize = 2;
+    const uint32_t numUnits = 4;
+    const uint32_t outputSize = numUnits;
+
+    std::vector<uint8_t> inputToInputWeightsData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+    std::vector<unsigned int> inputToInputWeightsDimensions = {1, 1, 3, 3};
+    armnn::ConstTensor inputToInputWeights(armnn::TensorInfo(
+        4, inputToInputWeightsDimensions.data(),
+        armnn::DataType::QuantisedAsymm8), inputToInputWeightsData);
+
+    std::vector<uint8_t> inputToForgetWeightsData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+    std::vector<unsigned int> inputToForgetWeightsDimensions = {1, 1, 3, 3};
+    armnn::ConstTensor inputToForgetWeights(armnn::TensorInfo(
+        4, inputToForgetWeightsDimensions.data(),
+        armnn::DataType::QuantisedAsymm8), inputToForgetWeightsData);
+
+    std::vector<uint8_t> inputToCellWeightsData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+    std::vector<unsigned int> inputToCellWeightsDimensions = {1, 1, 3, 3};
+    armnn::ConstTensor inputToCellWeights(armnn::TensorInfo(
+        4, inputToCellWeightsDimensions.data(),
+        armnn::DataType::QuantisedAsymm8), inputToCellWeightsData);
+
+    std::vector<uint8_t> inputToOutputWeightsData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+    std::vector<unsigned int> inputToOutputWeightsDimensions = {1, 1, 3, 3};
+    armnn::ConstTensor inputToOutputWeights(armnn::TensorInfo(
+        4, inputToOutputWeightsDimensions.data(),
+        armnn::DataType::QuantisedAsymm8), inputToOutputWeightsData);
+
+    std::vector<uint8_t> recurrentToInputWeightsData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+    std::vector<unsigned int> recurrentToInputWeightsDimensions = {1, 1, 3, 3};
+    armnn::ConstTensor recurrentToInputWeights(armnn::TensorInfo(
+        4, recurrentToInputWeightsDimensions.data(),
+        armnn::DataType::QuantisedAsymm8), recurrentToInputWeightsData);
+
+    std::vector<uint8_t> recurrentToForgetWeightsData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+    std::vector<unsigned int> recurrentToForgetWeightsDimensions = {1, 1, 3, 3};
+    armnn::ConstTensor recurrentToForgetWeights(armnn::TensorInfo(
+        4, recurrentToForgetWeightsDimensions.data(),
+        armnn::DataType::QuantisedAsymm8), recurrentToForgetWeightsData);
+
+    std::vector<uint8_t> recurrentToCellWeightsData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+    std::vector<unsigned int> recurrentToCellWeightsDimensions = {1, 1, 3, 3};
+    armnn::ConstTensor recurrentToCellWeights(armnn::TensorInfo(
+        4, recurrentToCellWeightsDimensions.data(),
+        armnn::DataType::QuantisedAsymm8), recurrentToCellWeightsData);
+
+    std::vector<uint8_t> recurrentToOutputWeightsData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+    std::vector<unsigned int> recurrentToOutputWeightsDimensions = {1, 1, 3, 3};
+    armnn::ConstTensor recurrentToOutputWeights(armnn::TensorInfo(
+        4, recurrentToOutputWeightsDimensions.data(),
+        armnn::DataType::QuantisedAsymm8), recurrentToOutputWeightsData);
+
+    std::vector<int32_t> inputGateBiasData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+    std::vector<unsigned int> inputGateBiasDimensions = {1, 1, 3, 3};
+    armnn::ConstTensor inputGateBias(armnn::TensorInfo(
+        4, inputGateBiasDimensions.data(),
+        armnn::DataType::Signed32), inputGateBiasData);
+
+    std::vector<int32_t> forgetGateBiasData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+    std::vector<unsigned int> forgetGateBiasDimensions = {1, 1, 3, 3};
+    armnn::ConstTensor forgetGateBias(armnn::TensorInfo(
+        4, forgetGateBiasDimensions.data(),
+        armnn::DataType::Signed32), forgetGateBiasData);
+
+    std::vector<int32_t> cellBiasData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+    std::vector<unsigned int> cellBiasDimensions = {1, 1, 3, 3};
+    armnn::ConstTensor cellBias(armnn::TensorInfo(
+        4, cellBiasDimensions.data(),
+        armnn::DataType::Signed32), cellBiasData);
+
+    std::vector<int32_t> outputGateBiasData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+    std::vector<unsigned int> outputGateBiasDimensions = {1, 1, 3, 3};
+    armnn::ConstTensor outputGateBias(armnn::TensorInfo(
+        4, outputGateBiasDimensions.data(),
+        armnn::DataType::Signed32), outputGateBiasData);
+
+    armnn::QuantizedLstmInputParams params;
+    params.m_InputToInputWeights = &inputToInputWeights;
+    params.m_InputToForgetWeights = &inputToForgetWeights;
+    params.m_InputToCellWeights = &inputToCellWeights;
+    params.m_InputToOutputWeights = &inputToOutputWeights;
+    params.m_RecurrentToInputWeights = &recurrentToInputWeights;
+    params.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
+    params.m_RecurrentToCellWeights = &recurrentToCellWeights;
+    params.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
+    params.m_InputGateBias = &inputGateBias;
+    params.m_ForgetGateBias = &forgetGateBias;
+    params.m_CellBias = &cellBias;
+    params.m_OutputGateBias = &outputGateBias;
+
+    armnn::INetworkPtr network = armnn::INetwork::Create();
+    armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
+    armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(1);
+    armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(2);
+    const std::string layerName("QuantizedLstm");
+    armnn::IConnectableLayer* const quantizedLstmLayer = network->AddQuantizedLstmLayer(params, layerName.c_str());
+    armnn::IConnectableLayer* const cellStateOut = network->AddOutputLayer(0);
+    armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(1);
+
+    // connect up
+    armnn::TensorInfo inputTensorInfo({ batchSize, inputSize }, armnn::DataType::QuantisedAsymm8);
+    armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits }, armnn::DataType::Signed32);
+    armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::QuantisedAsymm8);
+
+    inputLayer->GetOutputSlot(0).Connect(quantizedLstmLayer->GetInputSlot(0));
+    inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
+
+    cellStateIn->GetOutputSlot(0).Connect(quantizedLstmLayer->GetInputSlot(1));
+    cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo);
+
+    outputStateIn->GetOutputSlot(0).Connect(quantizedLstmLayer->GetInputSlot(2));
+    outputStateIn->GetOutputSlot(0).SetTensorInfo(outputStateTensorInfo);
+
+    quantizedLstmLayer->GetOutputSlot(0).Connect(cellStateOut->GetInputSlot(0));
+    quantizedLstmLayer->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo);
+
+    quantizedLstmLayer->GetOutputSlot(1).Connect(outputLayer->GetInputSlot(0));
+    quantizedLstmLayer->GetOutputSlot(1).SetTensorInfo(outputStateTensorInfo);
+
+    armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+    BOOST_CHECK(deserializedNetwork);
+
+    VerifyQuantizedLstmLayer checker(
+        layerName,
+        {inputTensorInfo, cellStateTensorInfo, outputStateTensorInfo},
+        {cellStateTensorInfo, outputStateTensorInfo},
+        params);
+
+    deserializedNetwork->Accept(checker);
+}
+
 BOOST_AUTO_TEST_SUITE_END()
-- 
cgit v1.2.1