diff options
author | James Conroy <james.conroy@arm.com> | 2020-05-13 10:27:58 +0100 |
---|---|---|
committer | James Conroy <james.conroy@arm.com> | 2020-05-13 23:06:38 +0000 |
commit | 8d33318a7ac33d90ed79701ff717de8d9940cc67 (patch) | |
tree | 2cf4140ec37b5b0a43b9618bab7f4f8076b5f4ab /src/armnnSerializer/test/SerializerTests.cpp | |
parent | 5061601fb6833dda20a6097af6a92e5e07310f25 (diff) | |
download | armnn-8d33318a7ac33d90ed79701ff717de8d9940cc67.tar.gz |
IVGCVSW-4777 Add QLstm serialization support
* Adds serialization/deserialization for QLstm.
* 3 unit tests: basic, layer norm and advanced.
Signed-off-by: James Conroy <james.conroy@arm.com>
Change-Id: I97d825e06b0d4a1257713cdd71ff06afa10d4380
Diffstat (limited to 'src/armnnSerializer/test/SerializerTests.cpp')
-rw-r--r-- | src/armnnSerializer/test/SerializerTests.cpp | 657 |
1 file changed, 657 insertions, 0 deletions
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index db89430439..76ac5a4de2 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -4326,4 +4326,661 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeQuantizedLstm) deserializedNetwork->Accept(checker); } +class VerifyQLstmLayer : public LayerVerifierBaseWithDescriptor<armnn::QLstmDescriptor> +{ +public: + VerifyQLstmLayer(const std::string& layerName, + const std::vector<armnn::TensorInfo>& inputInfos, + const std::vector<armnn::TensorInfo>& outputInfos, + const armnn::QLstmDescriptor& descriptor, + const armnn::LstmInputParams& inputParams) + : LayerVerifierBaseWithDescriptor<armnn::QLstmDescriptor>(layerName, inputInfos, outputInfos, descriptor) + , m_InputParams(inputParams) {} + + void VisitQLstmLayer(const armnn::IConnectableLayer* layer, + const armnn::QLstmDescriptor& descriptor, + const armnn::LstmInputParams& params, + const char* name) + { + VerifyNameAndConnections(layer, name); + VerifyDescriptor(descriptor); + VerifyInputParameters(params); + } + +protected: + void VerifyInputParameters(const armnn::LstmInputParams& params) + { + VerifyConstTensors( + "m_InputToInputWeights", m_InputParams.m_InputToInputWeights, params.m_InputToInputWeights); + VerifyConstTensors( + "m_InputToForgetWeights", m_InputParams.m_InputToForgetWeights, params.m_InputToForgetWeights); + VerifyConstTensors( + "m_InputToCellWeights", m_InputParams.m_InputToCellWeights, params.m_InputToCellWeights); + VerifyConstTensors( + "m_InputToOutputWeights", m_InputParams.m_InputToOutputWeights, params.m_InputToOutputWeights); + VerifyConstTensors( + "m_RecurrentToInputWeights", m_InputParams.m_RecurrentToInputWeights, params.m_RecurrentToInputWeights); + VerifyConstTensors( + "m_RecurrentToForgetWeights", m_InputParams.m_RecurrentToForgetWeights, params.m_RecurrentToForgetWeights); + VerifyConstTensors( + "m_RecurrentToCellWeights", 
m_InputParams.m_RecurrentToCellWeights, params.m_RecurrentToCellWeights); + VerifyConstTensors( + "m_RecurrentToOutputWeights", m_InputParams.m_RecurrentToOutputWeights, params.m_RecurrentToOutputWeights); + VerifyConstTensors( + "m_CellToInputWeights", m_InputParams.m_CellToInputWeights, params.m_CellToInputWeights); + VerifyConstTensors( + "m_CellToForgetWeights", m_InputParams.m_CellToForgetWeights, params.m_CellToForgetWeights); + VerifyConstTensors( + "m_CellToOutputWeights", m_InputParams.m_CellToOutputWeights, params.m_CellToOutputWeights); + VerifyConstTensors( + "m_InputGateBias", m_InputParams.m_InputGateBias, params.m_InputGateBias); + VerifyConstTensors( + "m_ForgetGateBias", m_InputParams.m_ForgetGateBias, params.m_ForgetGateBias); + VerifyConstTensors( + "m_CellBias", m_InputParams.m_CellBias, params.m_CellBias); + VerifyConstTensors( + "m_OutputGateBias", m_InputParams.m_OutputGateBias, params.m_OutputGateBias); + VerifyConstTensors( + "m_ProjectionWeights", m_InputParams.m_ProjectionWeights, params.m_ProjectionWeights); + VerifyConstTensors( + "m_ProjectionBias", m_InputParams.m_ProjectionBias, params.m_ProjectionBias); + VerifyConstTensors( + "m_InputLayerNormWeights", m_InputParams.m_InputLayerNormWeights, params.m_InputLayerNormWeights); + VerifyConstTensors( + "m_ForgetLayerNormWeights", m_InputParams.m_ForgetLayerNormWeights, params.m_ForgetLayerNormWeights); + VerifyConstTensors( + "m_CellLayerNormWeights", m_InputParams.m_CellLayerNormWeights, params.m_CellLayerNormWeights); + VerifyConstTensors( + "m_OutputLayerNormWeights", m_InputParams.m_OutputLayerNormWeights, params.m_OutputLayerNormWeights); + } + +private: + armnn::LstmInputParams m_InputParams; +}; + +BOOST_AUTO_TEST_CASE(SerializeDeserializeQLstmBasic) +{ + armnn::QLstmDescriptor descriptor; + + descriptor.m_CifgEnabled = true; + descriptor.m_ProjectionEnabled = false; + descriptor.m_PeepholeEnabled = false; + descriptor.m_LayerNormEnabled = false; + + descriptor.m_CellClip = 0.0f; 
+ descriptor.m_ProjectionClip = 0.0f; + + descriptor.m_InputIntermediateScale = 0.00001f; + descriptor.m_ForgetIntermediateScale = 0.00001f; + descriptor.m_CellIntermediateScale = 0.00001f; + descriptor.m_OutputIntermediateScale = 0.00001f; + + descriptor.m_HiddenStateScale = 0.07f; + descriptor.m_HiddenStateZeroPoint = 0; + + const unsigned int numBatches = 2; + const unsigned int inputSize = 5; + const unsigned int outputSize = 4; + const unsigned int numUnits = 4; + + // Scale/Offset quantization info + float inputScale = 0.0078f; + int32_t inputOffset = 0; + + float outputScale = 0.0078f; + int32_t outputOffset = 0; + + float cellStateScale = 3.5002e-05f; + int32_t cellStateOffset = 0; + + float weightsScale = 0.007f; + int32_t weightsOffset = 0; + + float biasScale = 3.5002e-05f / 1024; + int32_t biasOffset = 0; + + // Weights and bias tensor and quantization info + armnn::TensorInfo inputWeightsInfo({numUnits, inputSize}, + armnn::DataType::QSymmS8, + weightsScale, + weightsOffset); + + armnn::TensorInfo recurrentWeightsInfo({numUnits, outputSize}, + armnn::DataType::QSymmS8, + weightsScale, + weightsOffset); + + armnn::TensorInfo biasInfo({numUnits}, armnn::DataType::Signed32, biasScale, biasOffset); + + std::vector<int8_t> inputToForgetWeightsData = GenerateRandomData<int8_t>(inputWeightsInfo.GetNumElements()); + std::vector<int8_t> inputToCellWeightsData = GenerateRandomData<int8_t>(inputWeightsInfo.GetNumElements()); + std::vector<int8_t> inputToOutputWeightsData = GenerateRandomData<int8_t>(inputWeightsInfo.GetNumElements()); + + armnn::ConstTensor inputToForgetWeights(inputWeightsInfo, inputToForgetWeightsData); + armnn::ConstTensor inputToCellWeights(inputWeightsInfo, inputToCellWeightsData); + armnn::ConstTensor inputToOutputWeights(inputWeightsInfo, inputToOutputWeightsData); + + std::vector<int8_t> recurrentToForgetWeightsData = + GenerateRandomData<int8_t>(recurrentWeightsInfo.GetNumElements()); + std::vector<int8_t> recurrentToCellWeightsData = + 
GenerateRandomData<int8_t>(recurrentWeightsInfo.GetNumElements()); + std::vector<int8_t> recurrentToOutputWeightsData = + GenerateRandomData<int8_t>(recurrentWeightsInfo.GetNumElements()); + + armnn::ConstTensor recurrentToForgetWeights(recurrentWeightsInfo, recurrentToForgetWeightsData); + armnn::ConstTensor recurrentToCellWeights(recurrentWeightsInfo, recurrentToCellWeightsData); + armnn::ConstTensor recurrentToOutputWeights(recurrentWeightsInfo, recurrentToOutputWeightsData); + + std::vector<int32_t> forgetGateBiasData(numUnits, 1); + std::vector<int32_t> cellBiasData(numUnits, 0); + std::vector<int32_t> outputGateBiasData(numUnits, 0); + + armnn::ConstTensor forgetGateBias(biasInfo, forgetGateBiasData); + armnn::ConstTensor cellBias(biasInfo, cellBiasData); + armnn::ConstTensor outputGateBias(biasInfo, outputGateBiasData); + + // Set up params + armnn::LstmInputParams params; + params.m_InputToForgetWeights = &inputToForgetWeights; + params.m_InputToCellWeights = &inputToCellWeights; + params.m_InputToOutputWeights = &inputToOutputWeights; + + params.m_RecurrentToForgetWeights = &recurrentToForgetWeights; + params.m_RecurrentToCellWeights = &recurrentToCellWeights; + params.m_RecurrentToOutputWeights = &recurrentToOutputWeights; + + params.m_ForgetGateBias = &forgetGateBias; + params.m_CellBias = &cellBias; + params.m_OutputGateBias = &outputGateBias; + + // Create network + armnn::INetworkPtr network = armnn::INetwork::Create(); + const std::string layerName("qLstm"); + + armnn::IConnectableLayer* const input = network->AddInputLayer(0); + armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(1); + armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(2); + + armnn::IConnectableLayer* const qLstmLayer = network->AddQLstmLayer(descriptor, params, layerName.c_str()); + + armnn::IConnectableLayer* const outputStateOut = network->AddOutputLayer(0); + armnn::IConnectableLayer* const cellStateOut = network->AddOutputLayer(1); + 
armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(2); + + // Input/Output tensor info + armnn::TensorInfo inputInfo({numBatches , inputSize}, + armnn::DataType::QAsymmS8, + inputScale, + inputOffset); + + armnn::TensorInfo cellStateInfo({numBatches , numUnits}, + armnn::DataType::QSymmS16, + cellStateScale, + cellStateOffset); + + armnn::TensorInfo outputStateInfo({numBatches , outputSize}, + armnn::DataType::QAsymmS8, + outputScale, + outputOffset); + + // Connect input/output slots + input->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(0)); + input->GetOutputSlot(0).SetTensorInfo(inputInfo); + + outputStateIn->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(1)); + outputStateIn->GetOutputSlot(0).SetTensorInfo(cellStateInfo); + + cellStateIn->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(2)); + cellStateIn->GetOutputSlot(0).SetTensorInfo(outputStateInfo); + + qLstmLayer->GetOutputSlot(0).Connect(outputStateOut->GetInputSlot(0)); + qLstmLayer->GetOutputSlot(0).SetTensorInfo(outputStateInfo); + + qLstmLayer->GetOutputSlot(1).Connect(cellStateOut->GetInputSlot(0)); + qLstmLayer->GetOutputSlot(1).SetTensorInfo(cellStateInfo); + + qLstmLayer->GetOutputSlot(2).Connect(outputLayer->GetInputSlot(0)); + qLstmLayer->GetOutputSlot(2).SetTensorInfo(outputStateInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + BOOST_CHECK(deserializedNetwork); + + VerifyQLstmLayer checker(layerName, + {inputInfo, cellStateInfo, outputStateInfo}, + {outputStateInfo, cellStateInfo, outputStateInfo}, + descriptor, + params); + + deserializedNetwork->Accept(checker); +} + +BOOST_AUTO_TEST_CASE(SerializeDeserializeQLstmCifgLayerNorm) +{ + armnn::QLstmDescriptor descriptor; + + // CIFG params are used when CIFG is disabled + descriptor.m_CifgEnabled = true; + descriptor.m_ProjectionEnabled = false; + descriptor.m_PeepholeEnabled = false; + descriptor.m_LayerNormEnabled = true; + + descriptor.m_CellClip = 0.0f; + 
descriptor.m_ProjectionClip = 0.0f; + + descriptor.m_InputIntermediateScale = 0.00001f; + descriptor.m_ForgetIntermediateScale = 0.00001f; + descriptor.m_CellIntermediateScale = 0.00001f; + descriptor.m_OutputIntermediateScale = 0.00001f; + + descriptor.m_HiddenStateScale = 0.07f; + descriptor.m_HiddenStateZeroPoint = 0; + + const unsigned int numBatches = 2; + const unsigned int inputSize = 5; + const unsigned int outputSize = 4; + const unsigned int numUnits = 4; + + // Scale/Offset quantization info + float inputScale = 0.0078f; + int32_t inputOffset = 0; + + float outputScale = 0.0078f; + int32_t outputOffset = 0; + + float cellStateScale = 3.5002e-05f; + int32_t cellStateOffset = 0; + + float weightsScale = 0.007f; + int32_t weightsOffset = 0; + + float layerNormScale = 3.5002e-05f; + int32_t layerNormOffset = 0; + + float biasScale = layerNormScale / 1024; + int32_t biasOffset = 0; + + // Weights and bias tensor and quantization info + armnn::TensorInfo inputWeightsInfo({numUnits, inputSize}, + armnn::DataType::QSymmS8, + weightsScale, + weightsOffset); + + armnn::TensorInfo recurrentWeightsInfo({numUnits, outputSize}, + armnn::DataType::QSymmS8, + weightsScale, + weightsOffset); + + armnn::TensorInfo biasInfo({numUnits}, + armnn::DataType::Signed32, + biasScale, + biasOffset); + + armnn::TensorInfo layerNormWeightsInfo({numUnits}, + armnn::DataType::QSymmS16, + layerNormScale, + layerNormOffset); + + // Mandatory params + std::vector<int8_t> inputToForgetWeightsData = GenerateRandomData<int8_t>(inputWeightsInfo.GetNumElements()); + std::vector<int8_t> inputToCellWeightsData = GenerateRandomData<int8_t>(inputWeightsInfo.GetNumElements()); + std::vector<int8_t> inputToOutputWeightsData = GenerateRandomData<int8_t>(inputWeightsInfo.GetNumElements()); + + armnn::ConstTensor inputToForgetWeights(inputWeightsInfo, inputToForgetWeightsData); + armnn::ConstTensor inputToCellWeights(inputWeightsInfo, inputToCellWeightsData); + armnn::ConstTensor 
inputToOutputWeights(inputWeightsInfo, inputToOutputWeightsData); + + std::vector<int8_t> recurrentToForgetWeightsData = + GenerateRandomData<int8_t>(recurrentWeightsInfo.GetNumElements()); + std::vector<int8_t> recurrentToCellWeightsData = + GenerateRandomData<int8_t>(recurrentWeightsInfo.GetNumElements()); + std::vector<int8_t> recurrentToOutputWeightsData = + GenerateRandomData<int8_t>(recurrentWeightsInfo.GetNumElements()); + + armnn::ConstTensor recurrentToForgetWeights(recurrentWeightsInfo, recurrentToForgetWeightsData); + armnn::ConstTensor recurrentToCellWeights(recurrentWeightsInfo, recurrentToCellWeightsData); + armnn::ConstTensor recurrentToOutputWeights(recurrentWeightsInfo, recurrentToOutputWeightsData); + + std::vector<int32_t> forgetGateBiasData(numUnits, 1); + std::vector<int32_t> cellBiasData(numUnits, 0); + std::vector<int32_t> outputGateBiasData(numUnits, 0); + + armnn::ConstTensor forgetGateBias(biasInfo, forgetGateBiasData); + armnn::ConstTensor cellBias(biasInfo, cellBiasData); + armnn::ConstTensor outputGateBias(biasInfo, outputGateBiasData); + + // Layer Norm + std::vector<int16_t> forgetLayerNormWeightsData = + GenerateRandomData<int16_t>(layerNormWeightsInfo.GetNumElements()); + std::vector<int16_t> cellLayerNormWeightsData = + GenerateRandomData<int16_t>(layerNormWeightsInfo.GetNumElements()); + std::vector<int16_t> outputLayerNormWeightsData = + GenerateRandomData<int16_t>(layerNormWeightsInfo.GetNumElements()); + + armnn::ConstTensor forgetLayerNormWeights(layerNormWeightsInfo, forgetLayerNormWeightsData); + armnn::ConstTensor cellLayerNormWeights(layerNormWeightsInfo, cellLayerNormWeightsData); + armnn::ConstTensor outputLayerNormWeights(layerNormWeightsInfo, outputLayerNormWeightsData); + + // Set up params + armnn::LstmInputParams params; + + // Mandatory params + params.m_InputToForgetWeights = &inputToForgetWeights; + params.m_InputToCellWeights = &inputToCellWeights; + params.m_InputToOutputWeights = &inputToOutputWeights; + + 
params.m_RecurrentToForgetWeights = &recurrentToForgetWeights; + params.m_RecurrentToCellWeights = &recurrentToCellWeights; + params.m_RecurrentToOutputWeights = &recurrentToOutputWeights; + + params.m_ForgetGateBias = &forgetGateBias; + params.m_CellBias = &cellBias; + params.m_OutputGateBias = &outputGateBias; + + // Layer Norm + params.m_ForgetLayerNormWeights = &forgetLayerNormWeights; + params.m_CellLayerNormWeights = &cellLayerNormWeights; + params.m_OutputLayerNormWeights = &outputLayerNormWeights; + + // Create network + armnn::INetworkPtr network = armnn::INetwork::Create(); + const std::string layerName("qLstm"); + + armnn::IConnectableLayer* const input = network->AddInputLayer(0); + armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(1); + armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(2); + + armnn::IConnectableLayer* const qLstmLayer = network->AddQLstmLayer(descriptor, params, layerName.c_str()); + + armnn::IConnectableLayer* const outputStateOut = network->AddOutputLayer(0); + armnn::IConnectableLayer* const cellStateOut = network->AddOutputLayer(1); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(2); + + // Input/Output tensor info + armnn::TensorInfo inputInfo({numBatches , inputSize}, + armnn::DataType::QAsymmS8, + inputScale, + inputOffset); + + armnn::TensorInfo cellStateInfo({numBatches , numUnits}, + armnn::DataType::QSymmS16, + cellStateScale, + cellStateOffset); + + armnn::TensorInfo outputStateInfo({numBatches , outputSize}, + armnn::DataType::QAsymmS8, + outputScale, + outputOffset); + + // Connect input/output slots + input->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(0)); + input->GetOutputSlot(0).SetTensorInfo(inputInfo); + + outputStateIn->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(1)); + outputStateIn->GetOutputSlot(0).SetTensorInfo(cellStateInfo); + + cellStateIn->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(2)); + 
cellStateIn->GetOutputSlot(0).SetTensorInfo(outputStateInfo); + + qLstmLayer->GetOutputSlot(0).Connect(outputStateOut->GetInputSlot(0)); + qLstmLayer->GetOutputSlot(0).SetTensorInfo(outputStateInfo); + + qLstmLayer->GetOutputSlot(1).Connect(cellStateOut->GetInputSlot(0)); + qLstmLayer->GetOutputSlot(1).SetTensorInfo(cellStateInfo); + + qLstmLayer->GetOutputSlot(2).Connect(outputLayer->GetInputSlot(0)); + qLstmLayer->GetOutputSlot(2).SetTensorInfo(outputStateInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + BOOST_CHECK(deserializedNetwork); + + VerifyQLstmLayer checker(layerName, + {inputInfo, cellStateInfo, outputStateInfo}, + {outputStateInfo, cellStateInfo, outputStateInfo}, + descriptor, + params); + + deserializedNetwork->Accept(checker); +} + +BOOST_AUTO_TEST_CASE(SerializeDeserializeQLstmAdvanced) +{ + armnn::QLstmDescriptor descriptor; + + descriptor.m_CifgEnabled = false; + descriptor.m_ProjectionEnabled = true; + descriptor.m_PeepholeEnabled = true; + descriptor.m_LayerNormEnabled = true; + + descriptor.m_CellClip = 0.1f; + descriptor.m_ProjectionClip = 0.1f; + + descriptor.m_InputIntermediateScale = 0.00001f; + descriptor.m_ForgetIntermediateScale = 0.00001f; + descriptor.m_CellIntermediateScale = 0.00001f; + descriptor.m_OutputIntermediateScale = 0.00001f; + + descriptor.m_HiddenStateScale = 0.07f; + descriptor.m_HiddenStateZeroPoint = 0; + + const unsigned int numBatches = 2; + const unsigned int inputSize = 5; + const unsigned int outputSize = 4; + const unsigned int numUnits = 4; + + // Scale/Offset quantization info + float inputScale = 0.0078f; + int32_t inputOffset = 0; + + float outputScale = 0.0078f; + int32_t outputOffset = 0; + + float cellStateScale = 3.5002e-05f; + int32_t cellStateOffset = 0; + + float weightsScale = 0.007f; + int32_t weightsOffset = 0; + + float layerNormScale = 3.5002e-05f; + int32_t layerNormOffset = 0; + + float biasScale = layerNormScale / 1024; + int32_t biasOffset = 
0; + + // Weights and bias tensor and quantization info + armnn::TensorInfo inputWeightsInfo({numUnits, inputSize}, + armnn::DataType::QSymmS8, + weightsScale, + weightsOffset); + + armnn::TensorInfo recurrentWeightsInfo({numUnits, outputSize}, + armnn::DataType::QSymmS8, + weightsScale, + weightsOffset); + + armnn::TensorInfo biasInfo({numUnits}, + armnn::DataType::Signed32, + biasScale, + biasOffset); + + armnn::TensorInfo peepholeWeightsInfo({numUnits}, + armnn::DataType::QSymmS16, + weightsScale, + weightsOffset); + + armnn::TensorInfo layerNormWeightsInfo({numUnits}, + armnn::DataType::QSymmS16, + layerNormScale, + layerNormOffset); + + armnn::TensorInfo projectionWeightsInfo({outputSize, numUnits}, + armnn::DataType::QSymmS8, + weightsScale, + weightsOffset); + + // Mandatory params + std::vector<int8_t> inputToForgetWeightsData = GenerateRandomData<int8_t>(inputWeightsInfo.GetNumElements()); + std::vector<int8_t> inputToCellWeightsData = GenerateRandomData<int8_t>(inputWeightsInfo.GetNumElements()); + std::vector<int8_t> inputToOutputWeightsData = GenerateRandomData<int8_t>(inputWeightsInfo.GetNumElements()); + + armnn::ConstTensor inputToForgetWeights(inputWeightsInfo, inputToForgetWeightsData); + armnn::ConstTensor inputToCellWeights(inputWeightsInfo, inputToCellWeightsData); + armnn::ConstTensor inputToOutputWeights(inputWeightsInfo, inputToOutputWeightsData); + + std::vector<int8_t> recurrentToForgetWeightsData = + GenerateRandomData<int8_t>(recurrentWeightsInfo.GetNumElements()); + std::vector<int8_t> recurrentToCellWeightsData = + GenerateRandomData<int8_t>(recurrentWeightsInfo.GetNumElements()); + std::vector<int8_t> recurrentToOutputWeightsData = + GenerateRandomData<int8_t>(recurrentWeightsInfo.GetNumElements()); + + armnn::ConstTensor recurrentToForgetWeights(recurrentWeightsInfo, recurrentToForgetWeightsData); + armnn::ConstTensor recurrentToCellWeights(recurrentWeightsInfo, recurrentToCellWeightsData); + armnn::ConstTensor 
recurrentToOutputWeights(recurrentWeightsInfo, recurrentToOutputWeightsData); + + std::vector<int32_t> forgetGateBiasData(numUnits, 1); + std::vector<int32_t> cellBiasData(numUnits, 0); + std::vector<int32_t> outputGateBiasData(numUnits, 0); + + armnn::ConstTensor forgetGateBias(biasInfo, forgetGateBiasData); + armnn::ConstTensor cellBias(biasInfo, cellBiasData); + armnn::ConstTensor outputGateBias(biasInfo, outputGateBiasData); + + // CIFG + std::vector<int8_t> inputToInputWeightsData = GenerateRandomData<int8_t>(inputWeightsInfo.GetNumElements()); + std::vector<int8_t> recurrentToInputWeightsData = + GenerateRandomData<int8_t>(recurrentWeightsInfo.GetNumElements()); + std::vector<int32_t> inputGateBiasData(numUnits, 1); + + armnn::ConstTensor inputToInputWeights(inputWeightsInfo, inputToInputWeightsData); + armnn::ConstTensor recurrentToInputWeights(recurrentWeightsInfo, recurrentToInputWeightsData); + armnn::ConstTensor inputGateBias(biasInfo, inputGateBiasData); + + // Peephole + std::vector<int16_t> cellToInputWeightsData = GenerateRandomData<int16_t>(peepholeWeightsInfo.GetNumElements()); + std::vector<int16_t> cellToForgetWeightsData = GenerateRandomData<int16_t>(peepholeWeightsInfo.GetNumElements()); + std::vector<int16_t> cellToOutputWeightsData = GenerateRandomData<int16_t>(peepholeWeightsInfo.GetNumElements()); + + armnn::ConstTensor cellToInputWeights(peepholeWeightsInfo, cellToInputWeightsData); + armnn::ConstTensor cellToForgetWeights(peepholeWeightsInfo, cellToForgetWeightsData); + armnn::ConstTensor cellToOutputWeights(peepholeWeightsInfo, cellToOutputWeightsData); + + // Projection + std::vector<int8_t> projectionWeightsData = GenerateRandomData<int8_t>(projectionWeightsInfo.GetNumElements()); + std::vector<int32_t> projectionBiasData(outputSize, 1); + + armnn::ConstTensor projectionWeights(projectionWeightsInfo, projectionWeightsData); + armnn::ConstTensor projectionBias(biasInfo, projectionBiasData); + + // Layer Norm + std::vector<int16_t> 
inputLayerNormWeightsData = + GenerateRandomData<int16_t>(layerNormWeightsInfo.GetNumElements()); + std::vector<int16_t> forgetLayerNormWeightsData = + GenerateRandomData<int16_t>(layerNormWeightsInfo.GetNumElements()); + std::vector<int16_t> cellLayerNormWeightsData = + GenerateRandomData<int16_t>(layerNormWeightsInfo.GetNumElements()); + std::vector<int16_t> outputLayerNormWeightsData = + GenerateRandomData<int16_t>(layerNormWeightsInfo.GetNumElements()); + + armnn::ConstTensor inputLayerNormWeights(layerNormWeightsInfo, inputLayerNormWeightsData); + armnn::ConstTensor forgetLayerNormWeights(layerNormWeightsInfo, forgetLayerNormWeightsData); + armnn::ConstTensor cellLayerNormWeights(layerNormWeightsInfo, cellLayerNormWeightsData); + armnn::ConstTensor outputLayerNormWeights(layerNormWeightsInfo, outputLayerNormWeightsData); + + // Set up params + armnn::LstmInputParams params; + + // Mandatory params + params.m_InputToForgetWeights = &inputToForgetWeights; + params.m_InputToCellWeights = &inputToCellWeights; + params.m_InputToOutputWeights = &inputToOutputWeights; + + params.m_RecurrentToForgetWeights = &recurrentToForgetWeights; + params.m_RecurrentToCellWeights = &recurrentToCellWeights; + params.m_RecurrentToOutputWeights = &recurrentToOutputWeights; + + params.m_ForgetGateBias = &forgetGateBias; + params.m_CellBias = &cellBias; + params.m_OutputGateBias = &outputGateBias; + + // CIFG + params.m_InputToInputWeights = &inputToInputWeights; + params.m_RecurrentToInputWeights = &recurrentToInputWeights; + params.m_InputGateBias = &inputGateBias; + + // Peephole + params.m_CellToInputWeights = &cellToInputWeights; + params.m_CellToForgetWeights = &cellToForgetWeights; + params.m_CellToOutputWeights = &cellToOutputWeights; + + // Projection + params.m_ProjectionWeights = &projectionWeights; + params.m_ProjectionBias = &projectionBias; + + // Layer Norm + params.m_InputLayerNormWeights = &inputLayerNormWeights; + params.m_ForgetLayerNormWeights = 
&forgetLayerNormWeights; + params.m_CellLayerNormWeights = &cellLayerNormWeights; + params.m_OutputLayerNormWeights = &outputLayerNormWeights; + + // Create network + armnn::INetworkPtr network = armnn::INetwork::Create(); + const std::string layerName("qLstm"); + + armnn::IConnectableLayer* const input = network->AddInputLayer(0); + armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(1); + armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(2); + + armnn::IConnectableLayer* const qLstmLayer = network->AddQLstmLayer(descriptor, params, layerName.c_str()); + + armnn::IConnectableLayer* const outputStateOut = network->AddOutputLayer(0); + armnn::IConnectableLayer* const cellStateOut = network->AddOutputLayer(1); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(2); + + // Input/Output tensor info + armnn::TensorInfo inputInfo({numBatches , inputSize}, + armnn::DataType::QAsymmS8, + inputScale, + inputOffset); + + armnn::TensorInfo cellStateInfo({numBatches , numUnits}, + armnn::DataType::QSymmS16, + cellStateScale, + cellStateOffset); + + armnn::TensorInfo outputStateInfo({numBatches , outputSize}, + armnn::DataType::QAsymmS8, + outputScale, + outputOffset); + + // Connect input/output slots + input->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(0)); + input->GetOutputSlot(0).SetTensorInfo(inputInfo); + + outputStateIn->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(1)); + outputStateIn->GetOutputSlot(0).SetTensorInfo(cellStateInfo); + + cellStateIn->GetOutputSlot(0).Connect(qLstmLayer->GetInputSlot(2)); + cellStateIn->GetOutputSlot(0).SetTensorInfo(outputStateInfo); + + qLstmLayer->GetOutputSlot(0).Connect(outputStateOut->GetInputSlot(0)); + qLstmLayer->GetOutputSlot(0).SetTensorInfo(outputStateInfo); + + qLstmLayer->GetOutputSlot(1).Connect(cellStateOut->GetInputSlot(0)); + qLstmLayer->GetOutputSlot(1).SetTensorInfo(cellStateInfo); + + 
qLstmLayer->GetOutputSlot(2).Connect(outputLayer->GetInputSlot(0)); + qLstmLayer->GetOutputSlot(2).SetTensorInfo(outputStateInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + BOOST_CHECK(deserializedNetwork); + + VerifyQLstmLayer checker(layerName, + {inputInfo, cellStateInfo, outputStateInfo}, + {outputStateInfo, cellStateInfo, outputStateInfo}, + descriptor, + params); + + deserializedNetwork->Accept(checker); +} + BOOST_AUTO_TEST_SUITE_END() |