From a0162e17c56538ee6d72ecce4c3e0836cbb34c56 Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Fri, 23 Jul 2021 14:47:49 +0100 Subject: MLCE-530 Add Serializer and Deserializer for UnidirectionalSequenceLstm Signed-off-by: Narumol Prangnawarat Change-Id: Ic1c56a57941ebede19ab8b9032e7f9df1221be7a --- .../test/LstmSerializationTests.cpp | 506 ++++++++++++++++++++- 1 file changed, 505 insertions(+), 1 deletion(-) (limited to 'src/armnnSerializer/test') diff --git a/src/armnnSerializer/test/LstmSerializationTests.cpp b/src/armnnSerializer/test/LstmSerializationTests.cpp index c2bc8737b4..bdc37877f7 100644 --- a/src/armnnSerializer/test/LstmSerializationTests.cpp +++ b/src/armnnSerializer/test/LstmSerializationTests.cpp @@ -74,7 +74,7 @@ armnn::LstmInputParams ConstantVector2LstmInputParams(const std::vector class VerifyLstmLayer : public LayerVerifierBaseWithDescriptor { @@ -99,6 +99,7 @@ public: case armnn::LayerType::Input: break; case armnn::LayerType::Output: break; case armnn::LayerType::Lstm: + case armnn::LayerType::UnidirectionalSequenceLstm: { this->VerifyNameAndConnections(layer, name); const Descriptor& internalDescriptor = static_cast(descriptor); @@ -2195,4 +2196,507 @@ TEST_CASE("SerializeDeserializeQLstmAdvanced") deserializedNetwork->ExecuteStrategy(checker); } +TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmCifgPeepholeNoProjection") +{ + armnn::UnidirectionalSequenceLstmDescriptor descriptor; + descriptor.m_ActivationFunc = 4; + descriptor.m_ClippingThresProj = 0.0f; + descriptor.m_ClippingThresCell = 0.0f; + descriptor.m_CifgEnabled = true; // if this is true then we DON'T need to set the OptCifgParams + descriptor.m_ProjectionEnabled = false; + descriptor.m_PeepholeEnabled = true; + descriptor.m_TimeMajor = false; + + const uint32_t batchSize = 1; + const uint32_t timeSize = 2; + const uint32_t inputSize = 2; + const uint32_t numUnits = 4; + const uint32_t outputSize = numUnits; + + armnn::TensorInfo inputWeightsInfo1({numUnits, inputSize}, armnn::DataType::Float32); + std::vector inputToForgetWeightsData = GenerateRandomData(inputWeightsInfo1.GetNumElements()); + armnn::ConstTensor inputToForgetWeights(inputWeightsInfo1, inputToForgetWeightsData); + + std::vector inputToCellWeightsData = GenerateRandomData(inputWeightsInfo1.GetNumElements()); + armnn::ConstTensor inputToCellWeights(inputWeightsInfo1, inputToCellWeightsData); + + std::vector inputToOutputWeightsData = GenerateRandomData(inputWeightsInfo1.GetNumElements()); + armnn::ConstTensor inputToOutputWeights(inputWeightsInfo1, inputToOutputWeightsData); + + armnn::TensorInfo inputWeightsInfo2({numUnits, outputSize}, armnn::DataType::Float32); + std::vector recurrentToForgetWeightsData = GenerateRandomData(inputWeightsInfo2.GetNumElements()); + armnn::ConstTensor recurrentToForgetWeights(inputWeightsInfo2, recurrentToForgetWeightsData); + + std::vector recurrentToCellWeightsData = GenerateRandomData(inputWeightsInfo2.GetNumElements()); + armnn::ConstTensor recurrentToCellWeights(inputWeightsInfo2, recurrentToCellWeightsData); + + std::vector recurrentToOutputWeightsData = GenerateRandomData(inputWeightsInfo2.GetNumElements()); + armnn::ConstTensor recurrentToOutputWeights(inputWeightsInfo2, recurrentToOutputWeightsData); + + armnn::TensorInfo inputWeightsInfo3({numUnits}, armnn::DataType::Float32); + std::vector cellToForgetWeightsData = GenerateRandomData(inputWeightsInfo3.GetNumElements()); + armnn::ConstTensor cellToForgetWeights(inputWeightsInfo3, cellToForgetWeightsData); + + std::vector cellToOutputWeightsData = GenerateRandomData(inputWeightsInfo3.GetNumElements()); + armnn::ConstTensor cellToOutputWeights(inputWeightsInfo3, cellToOutputWeightsData); + + std::vector forgetGateBiasData(numUnits, 1.0f); + armnn::ConstTensor forgetGateBias(inputWeightsInfo3, forgetGateBiasData); + + std::vector cellBiasData(numUnits, 0.0f); + armnn::ConstTensor cellBias(inputWeightsInfo3, cellBiasData); + + std::vector outputGateBiasData(numUnits, 0.0f); + armnn::ConstTensor outputGateBias(inputWeightsInfo3, outputGateBiasData); + + armnn::LstmInputParams params; + params.m_InputToForgetWeights = &inputToForgetWeights; + params.m_InputToCellWeights = &inputToCellWeights; + params.m_InputToOutputWeights = &inputToOutputWeights; + params.m_RecurrentToForgetWeights = &recurrentToForgetWeights; + params.m_RecurrentToCellWeights = &recurrentToCellWeights; + params.m_RecurrentToOutputWeights = &recurrentToOutputWeights; + params.m_ForgetGateBias = &forgetGateBias; + params.m_CellBias = &cellBias; + params.m_OutputGateBias = &outputGateBias; + params.m_CellToForgetWeights = &cellToForgetWeights; + params.m_CellToOutputWeights = &cellToOutputWeights; + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0); + armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(1); + armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(2); + const std::string layerName("UnidirectionalSequenceLstm"); + armnn::IConnectableLayer* const unidirectionalSequenceLstmLayer = + network->AddUnidirectionalSequenceLstmLayer(descriptor, params, layerName.c_str()); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0); + + // connect up + armnn::TensorInfo inputTensorInfo({ batchSize, timeSize, inputSize }, armnn::DataType::Float32); + armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, armnn::DataType::Float32); + armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32); + armnn::TensorInfo outputTensorInfo({ batchSize, timeSize, outputSize }, armnn::DataType::Float32); + + inputLayer->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(0)); + inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo); + + outputStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(1)); + outputStateIn->GetOutputSlot(0).SetTensorInfo(outputStateTensorInfo); + + cellStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(2)); + cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo); + + unidirectionalSequenceLstmLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + CHECK(deserializedNetwork); + + VerifyLstmLayer checker( + layerName, + {inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo}, + {outputTensorInfo}, + descriptor, + params); + deserializedNetwork->ExecuteStrategy(checker); +} + +TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmNoCifgWithPeepholeAndProjection") +{ + armnn::UnidirectionalSequenceLstmDescriptor descriptor; + descriptor.m_ActivationFunc = 4; + descriptor.m_ClippingThresProj = 0.0f; + descriptor.m_ClippingThresCell = 0.0f; + descriptor.m_CifgEnabled = false; // if this is true then we DON'T need to set the OptCifgParams + descriptor.m_ProjectionEnabled = true; + descriptor.m_PeepholeEnabled = true; + descriptor.m_TimeMajor = false; + + const uint32_t batchSize = 2; + const uint32_t timeSize = 2; + const uint32_t inputSize = 5; + const uint32_t numUnits = 20; + const uint32_t outputSize = 16; + + armnn::TensorInfo tensorInfo20x5({numUnits, inputSize}, armnn::DataType::Float32); + std::vector inputToInputWeightsData = GenerateRandomData(tensorInfo20x5.GetNumElements()); + armnn::ConstTensor inputToInputWeights(tensorInfo20x5, inputToInputWeightsData); + + std::vector inputToForgetWeightsData = GenerateRandomData(tensorInfo20x5.GetNumElements()); + armnn::ConstTensor inputToForgetWeights(tensorInfo20x5, inputToForgetWeightsData); + + std::vector inputToCellWeightsData = GenerateRandomData(tensorInfo20x5.GetNumElements()); + armnn::ConstTensor inputToCellWeights(tensorInfo20x5, inputToCellWeightsData); + + std::vector inputToOutputWeightsData = GenerateRandomData(tensorInfo20x5.GetNumElements()); + armnn::ConstTensor inputToOutputWeights(tensorInfo20x5, inputToOutputWeightsData); + + armnn::TensorInfo tensorInfo20({numUnits}, armnn::DataType::Float32); + std::vector inputGateBiasData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor inputGateBias(tensorInfo20, inputGateBiasData); + + std::vector forgetGateBiasData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor forgetGateBias(tensorInfo20, forgetGateBiasData); + + std::vector cellBiasData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor cellBias(tensorInfo20, cellBiasData); + + std::vector outputGateBiasData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor outputGateBias(tensorInfo20, outputGateBiasData); + + armnn::TensorInfo tensorInfo20x16({numUnits, outputSize}, armnn::DataType::Float32); + std::vector recurrentToInputWeightsData = GenerateRandomData(tensorInfo20x16.GetNumElements()); + armnn::ConstTensor recurrentToInputWeights(tensorInfo20x16, recurrentToInputWeightsData); + + std::vector recurrentToForgetWeightsData = GenerateRandomData(tensorInfo20x16.GetNumElements()); + armnn::ConstTensor recurrentToForgetWeights(tensorInfo20x16, recurrentToForgetWeightsData); + + std::vector recurrentToCellWeightsData = GenerateRandomData(tensorInfo20x16.GetNumElements()); + armnn::ConstTensor recurrentToCellWeights(tensorInfo20x16, recurrentToCellWeightsData); + + std::vector recurrentToOutputWeightsData = GenerateRandomData(tensorInfo20x16.GetNumElements()); + armnn::ConstTensor recurrentToOutputWeights(tensorInfo20x16, recurrentToOutputWeightsData); + + std::vector cellToInputWeightsData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor cellToInputWeights(tensorInfo20, cellToInputWeightsData); + + std::vector cellToForgetWeightsData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor cellToForgetWeights(tensorInfo20, cellToForgetWeightsData); + + std::vector cellToOutputWeightsData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor cellToOutputWeights(tensorInfo20, cellToOutputWeightsData); + + armnn::TensorInfo tensorInfo16x20({outputSize, numUnits}, armnn::DataType::Float32); + std::vector projectionWeightsData = GenerateRandomData(tensorInfo16x20.GetNumElements()); + armnn::ConstTensor projectionWeights(tensorInfo16x20, projectionWeightsData); + + armnn::TensorInfo tensorInfo16({outputSize}, armnn::DataType::Float32); + std::vector projectionBiasData(outputSize, 0.f); + armnn::ConstTensor projectionBias(tensorInfo16, projectionBiasData); + + armnn::LstmInputParams params; + params.m_InputToForgetWeights = &inputToForgetWeights; + params.m_InputToCellWeights = &inputToCellWeights; + params.m_InputToOutputWeights = &inputToOutputWeights; + params.m_RecurrentToForgetWeights = &recurrentToForgetWeights; + params.m_RecurrentToCellWeights = &recurrentToCellWeights; + params.m_RecurrentToOutputWeights = &recurrentToOutputWeights; + params.m_ForgetGateBias = &forgetGateBias; + params.m_CellBias = &cellBias; + params.m_OutputGateBias = &outputGateBias; + + // additional params because: descriptor.m_CifgEnabled = false + params.m_InputToInputWeights = &inputToInputWeights; + params.m_RecurrentToInputWeights = &recurrentToInputWeights; + params.m_CellToInputWeights = &cellToInputWeights; + params.m_InputGateBias = &inputGateBias; + + // additional params because: descriptor.m_ProjectionEnabled = true + params.m_ProjectionWeights = &projectionWeights; + params.m_ProjectionBias = &projectionBias; + + // additional params because: descriptor.m_PeepholeEnabled = true + params.m_CellToForgetWeights = &cellToForgetWeights; + params.m_CellToOutputWeights = &cellToOutputWeights; + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0); + armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(1); + armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(2); + const std::string layerName("unidirectionalSequenceLstm"); + armnn::IConnectableLayer* const unidirectionalSequenceLstmLayer = + network->AddUnidirectionalSequenceLstmLayer(descriptor, params, layerName.c_str()); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0); + + // connect up + armnn::TensorInfo inputTensorInfo({ batchSize, timeSize, inputSize }, armnn::DataType::Float32); + armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, armnn::DataType::Float32); + armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32); + armnn::TensorInfo outputTensorInfo({ batchSize, timeSize, outputSize }, armnn::DataType::Float32); + + inputLayer->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(0)); + inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo); + + outputStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(1)); + outputStateIn->GetOutputSlot(0).SetTensorInfo(outputStateTensorInfo); + + cellStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(2)); + cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo); + + unidirectionalSequenceLstmLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + CHECK(deserializedNetwork); + + VerifyLstmLayer checker( + layerName, + {inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo}, + {outputTensorInfo}, + descriptor, + params); + deserializedNetwork->ExecuteStrategy(checker); +} + +TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmNoCifgWithPeepholeWithProjectionWithLayerNorm") +{ + armnn::UnidirectionalSequenceLstmDescriptor descriptor; + descriptor.m_ActivationFunc = 4; + descriptor.m_ClippingThresProj = 0.0f; + descriptor.m_ClippingThresCell = 0.0f; + descriptor.m_CifgEnabled = false; // if this is true then we DON'T need to set the OptCifgParams + descriptor.m_ProjectionEnabled = true; + descriptor.m_PeepholeEnabled = true; + descriptor.m_LayerNormEnabled = true; + descriptor.m_TimeMajor = false; + + const uint32_t batchSize = 2; + const uint32_t timeSize = 2; + const uint32_t inputSize = 5; + const uint32_t numUnits = 20; + const uint32_t outputSize = 16; + + armnn::TensorInfo tensorInfo20x5({numUnits, inputSize}, armnn::DataType::Float32); + std::vector inputToInputWeightsData = GenerateRandomData(tensorInfo20x5.GetNumElements()); + armnn::ConstTensor inputToInputWeights(tensorInfo20x5, inputToInputWeightsData); + + std::vector inputToForgetWeightsData = GenerateRandomData(tensorInfo20x5.GetNumElements()); + armnn::ConstTensor inputToForgetWeights(tensorInfo20x5, inputToForgetWeightsData); + + std::vector inputToCellWeightsData = GenerateRandomData(tensorInfo20x5.GetNumElements()); + armnn::ConstTensor inputToCellWeights(tensorInfo20x5, inputToCellWeightsData); + + std::vector inputToOutputWeightsData = GenerateRandomData(tensorInfo20x5.GetNumElements()); + armnn::ConstTensor inputToOutputWeights(tensorInfo20x5, inputToOutputWeightsData); + + armnn::TensorInfo tensorInfo20({numUnits}, armnn::DataType::Float32); + std::vector inputGateBiasData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor inputGateBias(tensorInfo20, inputGateBiasData); + + std::vector forgetGateBiasData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor forgetGateBias(tensorInfo20, forgetGateBiasData); + + std::vector cellBiasData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor cellBias(tensorInfo20, cellBiasData); + + std::vector outputGateBiasData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor outputGateBias(tensorInfo20, outputGateBiasData); + + armnn::TensorInfo tensorInfo20x16({numUnits, outputSize}, armnn::DataType::Float32); + std::vector recurrentToInputWeightsData = GenerateRandomData(tensorInfo20x16.GetNumElements()); + armnn::ConstTensor recurrentToInputWeights(tensorInfo20x16, recurrentToInputWeightsData); + + std::vector recurrentToForgetWeightsData = GenerateRandomData(tensorInfo20x16.GetNumElements()); + armnn::ConstTensor recurrentToForgetWeights(tensorInfo20x16, recurrentToForgetWeightsData); + + std::vector recurrentToCellWeightsData = GenerateRandomData(tensorInfo20x16.GetNumElements()); + armnn::ConstTensor recurrentToCellWeights(tensorInfo20x16, recurrentToCellWeightsData); + + std::vector recurrentToOutputWeightsData = GenerateRandomData(tensorInfo20x16.GetNumElements()); + armnn::ConstTensor recurrentToOutputWeights(tensorInfo20x16, recurrentToOutputWeightsData); + + std::vector cellToInputWeightsData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor cellToInputWeights(tensorInfo20, cellToInputWeightsData); + + std::vector cellToForgetWeightsData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor cellToForgetWeights(tensorInfo20, cellToForgetWeightsData); + + std::vector cellToOutputWeightsData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor cellToOutputWeights(tensorInfo20, cellToOutputWeightsData); + + armnn::TensorInfo tensorInfo16x20({outputSize, numUnits}, armnn::DataType::Float32); + std::vector projectionWeightsData = GenerateRandomData(tensorInfo16x20.GetNumElements()); + armnn::ConstTensor projectionWeights(tensorInfo16x20, projectionWeightsData); + + armnn::TensorInfo tensorInfo16({outputSize}, armnn::DataType::Float32); + std::vector projectionBiasData(outputSize, 0.f); + armnn::ConstTensor projectionBias(tensorInfo16, projectionBiasData); + + std::vector inputLayerNormWeightsData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor inputLayerNormWeights(tensorInfo20, forgetGateBiasData); + + std::vector forgetLayerNormWeightsData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor forgetLayerNormWeights(tensorInfo20, forgetGateBiasData); + + std::vector cellLayerNormWeightsData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor cellLayerNormWeights(tensorInfo20, forgetGateBiasData); + + std::vector outLayerNormWeightsData = GenerateRandomData(tensorInfo20.GetNumElements()); + armnn::ConstTensor outLayerNormWeights(tensorInfo20, forgetGateBiasData); + + armnn::LstmInputParams params; + params.m_InputToForgetWeights = &inputToForgetWeights; + params.m_InputToCellWeights = &inputToCellWeights; + params.m_InputToOutputWeights = &inputToOutputWeights; + params.m_RecurrentToForgetWeights = &recurrentToForgetWeights; + params.m_RecurrentToCellWeights = &recurrentToCellWeights; + params.m_RecurrentToOutputWeights = &recurrentToOutputWeights; + params.m_ForgetGateBias = &forgetGateBias; + params.m_CellBias = &cellBias; + params.m_OutputGateBias = &outputGateBias; + + // additional params because: descriptor.m_CifgEnabled = false + params.m_InputToInputWeights = &inputToInputWeights; + params.m_RecurrentToInputWeights = &recurrentToInputWeights; + params.m_CellToInputWeights = &cellToInputWeights; + params.m_InputGateBias = &inputGateBias; + + // additional params because: descriptor.m_ProjectionEnabled = true + params.m_ProjectionWeights = &projectionWeights; + params.m_ProjectionBias = &projectionBias; + + // additional params because: descriptor.m_PeepholeEnabled = true + params.m_CellToForgetWeights = &cellToForgetWeights; + params.m_CellToOutputWeights = &cellToOutputWeights; + + // additional params because: despriptor.m_LayerNormEnabled = true + params.m_InputLayerNormWeights = &inputLayerNormWeights; + params.m_ForgetLayerNormWeights = &forgetLayerNormWeights; + params.m_CellLayerNormWeights = &cellLayerNormWeights; + params.m_OutputLayerNormWeights = &outLayerNormWeights; + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0); + armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(1); + armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(2); + const std::string layerName("unidirectionalSequenceLstm"); + armnn::IConnectableLayer* const unidirectionalSequenceLstmLayer = + network->AddUnidirectionalSequenceLstmLayer(descriptor, params, layerName.c_str()); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0); + + // connect up + armnn::TensorInfo inputTensorInfo({ batchSize, timeSize, inputSize }, armnn::DataType::Float32); + armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, armnn::DataType::Float32); + armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32); + armnn::TensorInfo outputTensorInfo({ batchSize, timeSize, outputSize }, armnn::DataType::Float32); + + inputLayer->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(0)); + inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo); + + outputStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(1)); + outputStateIn->GetOutputSlot(0).SetTensorInfo(outputStateTensorInfo); + + cellStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(2)); + cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo); + + unidirectionalSequenceLstmLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + CHECK(deserializedNetwork); + + VerifyLstmLayer checker( + layerName, + {inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo}, + {outputTensorInfo}, + descriptor, + params); + deserializedNetwork->ExecuteStrategy(checker); +} + +TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmCifgPeepholeNoProjectionTimeMajor") +{ + armnn::UnidirectionalSequenceLstmDescriptor descriptor; + descriptor.m_ActivationFunc = 4; + descriptor.m_ClippingThresProj = 0.0f; + descriptor.m_ClippingThresCell = 0.0f; + descriptor.m_CifgEnabled = true; // if this is true then we DON'T need to set the OptCifgParams + descriptor.m_ProjectionEnabled = false; + descriptor.m_PeepholeEnabled = true; + descriptor.m_TimeMajor = true; + + const uint32_t batchSize = 1; + const uint32_t timeSize = 2; + const uint32_t inputSize = 2; + const uint32_t numUnits = 4; + const uint32_t outputSize = numUnits; + + armnn::TensorInfo inputWeightsInfo1({numUnits, inputSize}, armnn::DataType::Float32); + std::vector inputToForgetWeightsData = GenerateRandomData(inputWeightsInfo1.GetNumElements()); + armnn::ConstTensor inputToForgetWeights(inputWeightsInfo1, inputToForgetWeightsData); + + std::vector inputToCellWeightsData = GenerateRandomData(inputWeightsInfo1.GetNumElements()); + armnn::ConstTensor inputToCellWeights(inputWeightsInfo1, inputToCellWeightsData); + + std::vector inputToOutputWeightsData = GenerateRandomData(inputWeightsInfo1.GetNumElements()); + armnn::ConstTensor inputToOutputWeights(inputWeightsInfo1, inputToOutputWeightsData); + + armnn::TensorInfo inputWeightsInfo2({numUnits, outputSize}, armnn::DataType::Float32); + std::vector recurrentToForgetWeightsData = GenerateRandomData(inputWeightsInfo2.GetNumElements()); + armnn::ConstTensor recurrentToForgetWeights(inputWeightsInfo2, recurrentToForgetWeightsData); + + std::vector recurrentToCellWeightsData = GenerateRandomData(inputWeightsInfo2.GetNumElements()); + armnn::ConstTensor recurrentToCellWeights(inputWeightsInfo2, recurrentToCellWeightsData); + + std::vector recurrentToOutputWeightsData = GenerateRandomData(inputWeightsInfo2.GetNumElements()); + armnn::ConstTensor recurrentToOutputWeights(inputWeightsInfo2, recurrentToOutputWeightsData); + + armnn::TensorInfo inputWeightsInfo3({numUnits}, armnn::DataType::Float32); + std::vector cellToForgetWeightsData = GenerateRandomData(inputWeightsInfo3.GetNumElements()); + armnn::ConstTensor cellToForgetWeights(inputWeightsInfo3, cellToForgetWeightsData); + + std::vector cellToOutputWeightsData = GenerateRandomData(inputWeightsInfo3.GetNumElements()); + armnn::ConstTensor cellToOutputWeights(inputWeightsInfo3, cellToOutputWeightsData); + + std::vector forgetGateBiasData(numUnits, 1.0f); + armnn::ConstTensor forgetGateBias(inputWeightsInfo3, forgetGateBiasData); + + std::vector cellBiasData(numUnits, 0.0f); + armnn::ConstTensor cellBias(inputWeightsInfo3, cellBiasData); + + std::vector outputGateBiasData(numUnits, 0.0f); + armnn::ConstTensor outputGateBias(inputWeightsInfo3, outputGateBiasData); + + armnn::LstmInputParams params; + params.m_InputToForgetWeights = &inputToForgetWeights; + params.m_InputToCellWeights = &inputToCellWeights; + params.m_InputToOutputWeights = &inputToOutputWeights; + params.m_RecurrentToForgetWeights = &recurrentToForgetWeights; + params.m_RecurrentToCellWeights = &recurrentToCellWeights; + params.m_RecurrentToOutputWeights = &recurrentToOutputWeights; + params.m_ForgetGateBias = &forgetGateBias; + params.m_CellBias = &cellBias; + params.m_OutputGateBias = &outputGateBias; + params.m_CellToForgetWeights = &cellToForgetWeights; + params.m_CellToOutputWeights = &cellToOutputWeights; + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0); + armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(1); + armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(2); + const std::string layerName("UnidirectionalSequenceLstm"); + armnn::IConnectableLayer* const unidirectionalSequenceLstmLayer = + network->AddUnidirectionalSequenceLstmLayer(descriptor, params, layerName.c_str()); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0); + + // connect up + armnn::TensorInfo inputTensorInfo({ timeSize, batchSize, inputSize }, armnn::DataType::Float32); + armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, armnn::DataType::Float32); + armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32); + armnn::TensorInfo outputTensorInfo({ timeSize, batchSize, outputSize }, armnn::DataType::Float32); + + inputLayer->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(0)); + inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo); + + outputStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(1)); + outputStateIn->GetOutputSlot(0).SetTensorInfo(outputStateTensorInfo); + + cellStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(2)); + cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo); + + unidirectionalSequenceLstmLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + CHECK(deserializedNetwork); + + VerifyLstmLayer checker( + layerName, + {inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo}, + {outputTensorInfo}, + descriptor, + params); + deserializedNetwork->ExecuteStrategy(checker); +} + } -- cgit v1.2.1