author    Jim Flynn <jim.flynn@arm.com>  2019-03-19 17:22:29 +0000
committer Jim Flynn <jim.flynn@arm.com>  2019-03-21 16:09:19 +0000
commit    11af375a5a6bf88b4f3b933a86d53000b0d91ed0 (patch)
tree      f4f4db5192b275be44d96d96c7f3c8c10f15b3f1 /src/armnnSerializer/test/SerializerTests.cpp
parent    db059fd50f9afb398b8b12cd4592323fc8f60d7f (diff)
IVGCVSW-2694: serialize/deserialize LSTM
* added serialize/deserialize methods for LSTM and tests

Change-Id: Ic59557f03001c496008c4bef92c2e0406e1fbc6c
Signed-off-by: Nina Drozd <nina.drozd@arm.com>
Signed-off-by: Jim Flynn <jim.flynn@arm.com>
Diffstat (limited to 'src/armnnSerializer/test/SerializerTests.cpp')
-rw-r--r--  src/armnnSerializer/test/SerializerTests.cpp  375
1 file changed, 375 insertions, 0 deletions
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp
index f40c02dfde..e3ce6d29d3 100644
--- a/src/armnnSerializer/test/SerializerTests.cpp
+++ b/src/armnnSerializer/test/SerializerTests.cpp
@@ -2047,4 +2047,379 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeNonLinearNetwork)
deserializedNetwork->Accept(verifier);
}
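+
+// Visitor that checks a deserialized LSTM layer against the name, connections,
+// descriptor and input parameters of the layer that was originally serialized.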
+class VerifyLstmLayer : public LayerVerifierBase
+{
+public:
+ VerifyLstmLayer(const std::string& layerName,
+ const std::vector<armnn::TensorInfo>& inputInfos,
+ const std::vector<armnn::TensorInfo>& outputInfos,
+ const armnn::LstmDescriptor& descriptor,
+ const armnn::LstmInputParams& inputParams) :
+ LayerVerifierBase(layerName, inputInfos, outputInfos), m_Descriptor(descriptor), m_InputParams(inputParams)
+ {
+ }
+ void VisitLstmLayer(const armnn::IConnectableLayer* layer,
+ const armnn::LstmDescriptor& descriptor,
+ const armnn::LstmInputParams& params,
+ const char* name)
+ {
+ VerifyNameAndConnections(layer, name);
+ VerifyDescriptor(descriptor);
+ VerifyInputParameters(params);
+ }
+protected:
+ void VerifyDescriptor(const armnn::LstmDescriptor& descriptor)
+ {
+ BOOST_TEST(m_Descriptor.m_ActivationFunc == descriptor.m_ActivationFunc);
+ BOOST_TEST(m_Descriptor.m_ClippingThresCell == descriptor.m_ClippingThresCell);
+ BOOST_TEST(m_Descriptor.m_ClippingThresProj == descriptor.m_ClippingThresProj);
+ BOOST_TEST(m_Descriptor.m_CifgEnabled == descriptor.m_CifgEnabled);
+ BOOST_TEST(m_Descriptor.m_PeepholeEnabled == descriptor.m_PeepholeEnabled);
+ BOOST_TEST(m_Descriptor.m_ProjectionEnabled == descriptor.m_ProjectionEnabled);
+ }
+ void VerifyInputParameters(const armnn::LstmInputParams& params)
+ {
+ VerifyConstTensors(
+ "m_InputToInputWeights", m_InputParams.m_InputToInputWeights, params.m_InputToInputWeights);
+ VerifyConstTensors(
+ "m_InputToForgetWeights", m_InputParams.m_InputToForgetWeights, params.m_InputToForgetWeights);
+ VerifyConstTensors(
+ "m_InputToCellWeights", m_InputParams.m_InputToCellWeights, params.m_InputToCellWeights);
+ VerifyConstTensors(
+ "m_InputToOutputWeights", m_InputParams.m_InputToOutputWeights, params.m_InputToOutputWeights);
+ VerifyConstTensors(
+ "m_RecurrentToInputWeights", m_InputParams.m_RecurrentToInputWeights, params.m_RecurrentToInputWeights);
+ VerifyConstTensors(
+ "m_RecurrentToForgetWeights", m_InputParams.m_RecurrentToForgetWeights, params.m_RecurrentToForgetWeights);
+ VerifyConstTensors(
+ "m_RecurrentToCellWeights", m_InputParams.m_RecurrentToCellWeights, params.m_RecurrentToCellWeights);
+ VerifyConstTensors(
+ "m_RecurrentToOutputWeights", m_InputParams.m_RecurrentToOutputWeights, params.m_RecurrentToOutputWeights);
+ VerifyConstTensors(
+ "m_CellToInputWeights", m_InputParams.m_CellToInputWeights, params.m_CellToInputWeights);
+ VerifyConstTensors(
+ "m_CellToForgetWeights", m_InputParams.m_CellToForgetWeights, params.m_CellToForgetWeights);
+ VerifyConstTensors(
+ "m_CellToOutputWeights", m_InputParams.m_CellToOutputWeights, params.m_CellToOutputWeights);
+ VerifyConstTensors(
+ "m_InputGateBias", m_InputParams.m_InputGateBias, params.m_InputGateBias);
+ VerifyConstTensors(
+ "m_ForgetGateBias", m_InputParams.m_ForgetGateBias, params.m_ForgetGateBias);
+ VerifyConstTensors(
+ "m_CellBias", m_InputParams.m_CellBias, params.m_CellBias);
+ VerifyConstTensors(
+ "m_OutputGateBias", m_InputParams.m_OutputGateBias, params.m_OutputGateBias);
+ VerifyConstTensors(
+ "m_ProjectionWeights", m_InputParams.m_ProjectionWeights, params.m_ProjectionWeights);
+ VerifyConstTensors(
+ "m_ProjectionBias", m_InputParams.m_ProjectionBias, params.m_ProjectionBias);
+ }
+ void VerifyConstTensors(const std::string& tensorName,
+ const armnn::ConstTensor* expectedPtr,
+ const armnn::ConstTensor* actualPtr)
+ {
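+ // A null expected pointer means the optional tensor was not set on the
+ // original layer, so it must also be absent after deserialization.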
+ if (expectedPtr == nullptr)
+ {
+ BOOST_CHECK_MESSAGE(actualPtr == nullptr, tensorName + " should not exist");
+ }
+ else
+ {
+ BOOST_CHECK_MESSAGE(actualPtr != nullptr, tensorName + " should have been set");
+ if (actualPtr != nullptr)
+ {
+ const armnn::TensorInfo& expectedInfo = expectedPtr->GetInfo();
+ const armnn::TensorInfo& actualInfo = actualPtr->GetInfo();
+
+ BOOST_CHECK_MESSAGE(expectedInfo.GetShape() == actualInfo.GetShape(),
+ tensorName + " shapes don't match");
+ BOOST_CHECK_MESSAGE(
+ GetDataTypeName(expectedInfo.GetDataType()) == GetDataTypeName(actualInfo.GetDataType()),
+ tensorName + " data types don't match");
+
+ BOOST_CHECK_MESSAGE(expectedPtr->GetNumBytes() == actualPtr->GetNumBytes(),
+ tensorName + " (GetNumBytes) data sizes do not match");
+ if (expectedPtr->GetNumBytes() == actualPtr->GetNumBytes())
+ {
+ //check the data is identical
+ const char* expectedData = static_cast<const char*>(expectedPtr->GetMemoryArea());
+ const char* actualData = static_cast<const char*>(actualPtr->GetMemoryArea());
+ bool same = true;
+ for (unsigned int i = 0; i < expectedPtr->GetNumBytes(); ++i)
+ {
+ same = expectedData[i] == actualData[i];
+ if (!same)
+ {
+ break;
+ }
+ }
+ BOOST_CHECK_MESSAGE(same, tensorName + " data does not match");
+ }
+ }
+ }
+ }
+private:
+ armnn::LstmDescriptor m_Descriptor;
+ armnn::LstmInputParams m_InputParams;
+};
+
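+// CIFG enabled with peephole and no projection: only the mandatory tensors plus
+// the peephole weights are set; the input gate and projection tensors stay null.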
+BOOST_AUTO_TEST_CASE(SerializeDeserializeLstmCifgPeepholeNoProjection)
+{
+ armnn::LstmDescriptor descriptor;
+ descriptor.m_ActivationFunc = 4;
+ descriptor.m_ClippingThresProj = 0.0f;
+ descriptor.m_ClippingThresCell = 0.0f;
+ descriptor.m_CifgEnabled = true; // if this is true then we DON'T need to set the OptCifgParams
+ descriptor.m_ProjectionEnabled = false;
+ descriptor.m_PeepholeEnabled = true;
+
+ const uint32_t batchSize = 1;
+ const uint32_t inputSize = 2;
+ const uint32_t numUnits = 4;
+ const uint32_t outputSize = numUnits;
+
+ armnn::TensorInfo inputWeightsInfo1({numUnits, inputSize}, armnn::DataType::Float32);
+ std::vector<float> inputToForgetWeightsData = GenerateRandomData<float>(inputWeightsInfo1.GetNumElements());
+ armnn::ConstTensor inputToForgetWeights(inputWeightsInfo1, inputToForgetWeightsData);
+
+ std::vector<float> inputToCellWeightsData = GenerateRandomData<float>(inputWeightsInfo1.GetNumElements());
+ armnn::ConstTensor inputToCellWeights(inputWeightsInfo1, inputToCellWeightsData);
+
+ std::vector<float> inputToOutputWeightsData = GenerateRandomData<float>(inputWeightsInfo1.GetNumElements());
+ armnn::ConstTensor inputToOutputWeights(inputWeightsInfo1, inputToOutputWeightsData);
+
+ armnn::TensorInfo inputWeightsInfo2({numUnits, outputSize}, armnn::DataType::Float32);
+ std::vector<float> recurrentToForgetWeightsData = GenerateRandomData<float>(inputWeightsInfo2.GetNumElements());
+ armnn::ConstTensor recurrentToForgetWeights(inputWeightsInfo2, recurrentToForgetWeightsData);
+
+ std::vector<float> recurrentToCellWeightsData = GenerateRandomData<float>(inputWeightsInfo2.GetNumElements());
+ armnn::ConstTensor recurrentToCellWeights(inputWeightsInfo2, recurrentToCellWeightsData);
+
+ std::vector<float> recurrentToOutputWeightsData = GenerateRandomData<float>(inputWeightsInfo2.GetNumElements());
+ armnn::ConstTensor recurrentToOutputWeights(inputWeightsInfo2, recurrentToOutputWeightsData);
+
+ armnn::TensorInfo inputWeightsInfo3({numUnits}, armnn::DataType::Float32);
+ std::vector<float> cellToForgetWeightsData = GenerateRandomData<float>(inputWeightsInfo3.GetNumElements());
+ armnn::ConstTensor cellToForgetWeights(inputWeightsInfo3, cellToForgetWeightsData);
+
+ std::vector<float> cellToOutputWeightsData = GenerateRandomData<float>(inputWeightsInfo3.GetNumElements());
+ armnn::ConstTensor cellToOutputWeights(inputWeightsInfo3, cellToOutputWeightsData);
+
+ std::vector<float> forgetGateBiasData(numUnits, 1.0f);
+ armnn::ConstTensor forgetGateBias(inputWeightsInfo3, forgetGateBiasData);
+
+ std::vector<float> cellBiasData(numUnits, 0.0f);
+ armnn::ConstTensor cellBias(inputWeightsInfo3, cellBiasData);
+
+ std::vector<float> outputGateBiasData(numUnits, 0.0f);
+ armnn::ConstTensor outputGateBias(inputWeightsInfo3, outputGateBiasData);
+
+ armnn::LstmInputParams params;
+ params.m_InputToForgetWeights = &inputToForgetWeights;
+ params.m_InputToCellWeights = &inputToCellWeights;
+ params.m_InputToOutputWeights = &inputToOutputWeights;
+ params.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
+ params.m_RecurrentToCellWeights = &recurrentToCellWeights;
+ params.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
+ params.m_ForgetGateBias = &forgetGateBias;
+ params.m_CellBias = &cellBias;
+ params.m_OutputGateBias = &outputGateBias;
+ params.m_CellToForgetWeights = &cellToForgetWeights;
+ params.m_CellToOutputWeights = &cellToOutputWeights;
+
+ armnn::INetworkPtr network = armnn::INetwork::Create();
+ armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
+ armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(1);
+ armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(2);
+ const std::string layerName("lstm");
+ armnn::IConnectableLayer* const lstmLayer = network->AddLstmLayer(descriptor, params, layerName.c_str());
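+ // an LSTM layer has four outputs: scratch buffer, output state, cell state and output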
+ armnn::IConnectableLayer* const scratchBuffer = network->AddOutputLayer(0);
+ armnn::IConnectableLayer* const outputStateOut = network->AddOutputLayer(1);
+ armnn::IConnectableLayer* const cellStateOut = network->AddOutputLayer(2);
+ armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(3);
+
+ // connect up
+ armnn::TensorInfo inputTensorInfo({ batchSize, inputSize }, armnn::DataType::Float32);
+ armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, armnn::DataType::Float32);
+ armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32);
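+ // with CIFG enabled the scratch buffer only holds three gate buffers (numUnits * 3)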
+ armnn::TensorInfo lstmTensorInfoScratchBuff({ batchSize, numUnits * 3 }, armnn::DataType::Float32);
+
+ inputLayer->GetOutputSlot(0).Connect(lstmLayer->GetInputSlot(0));
+ inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
+
+ outputStateIn->GetOutputSlot(0).Connect(lstmLayer->GetInputSlot(1));
+ outputStateIn->GetOutputSlot(0).SetTensorInfo(outputStateTensorInfo);
+
+ cellStateIn->GetOutputSlot(0).Connect(lstmLayer->GetInputSlot(2));
+ cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo);
+
+ lstmLayer->GetOutputSlot(0).Connect(scratchBuffer->GetInputSlot(0));
+ lstmLayer->GetOutputSlot(0).SetTensorInfo(lstmTensorInfoScratchBuff);
+
+ lstmLayer->GetOutputSlot(1).Connect(outputStateOut->GetInputSlot(0));
+ lstmLayer->GetOutputSlot(1).SetTensorInfo(outputStateTensorInfo);
+
+ lstmLayer->GetOutputSlot(2).Connect(cellStateOut->GetInputSlot(0));
+ lstmLayer->GetOutputSlot(2).SetTensorInfo(cellStateTensorInfo);
+
+ lstmLayer->GetOutputSlot(3).Connect(outputLayer->GetInputSlot(0));
+ lstmLayer->GetOutputSlot(3).SetTensorInfo(outputStateTensorInfo);
+
+ armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+ BOOST_CHECK(deserializedNetwork);
+
+ VerifyLstmLayer checker(
+ layerName,
+ {inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo},
+ {lstmTensorInfoScratchBuff, outputStateTensorInfo, cellStateTensorInfo, outputStateTensorInfo},
+ descriptor,
+ params);
+ deserializedNetwork->Accept(checker);
+}
+
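+// CIFG disabled with peephole and projection: the input gate and projection
+// tensors must be supplied in addition to the mandatory parameters.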
+BOOST_AUTO_TEST_CASE(SerializeDeserializeLstmNoCifgWithPeepholeAndProjection)
+{
+ armnn::LstmDescriptor descriptor;
+ descriptor.m_ActivationFunc = 4;
+ descriptor.m_ClippingThresProj = 0.0f;
+ descriptor.m_ClippingThresCell = 0.0f;
+ descriptor.m_CifgEnabled = false; // CIFG disabled, so the optional input gate parameters must be set below
+ descriptor.m_ProjectionEnabled = true;
+ descriptor.m_PeepholeEnabled = true;
+
+ const uint32_t batchSize = 2;
+ const uint32_t inputSize = 5;
+ const uint32_t numUnits = 20;
+ const uint32_t outputSize = 16;
+
+ armnn::TensorInfo tensorInfo20x5({numUnits, inputSize}, armnn::DataType::Float32);
+ std::vector<float> inputToInputWeightsData = GenerateRandomData<float>(tensorInfo20x5.GetNumElements());
+ armnn::ConstTensor inputToInputWeights(tensorInfo20x5, inputToInputWeightsData);
+
+ std::vector<float> inputToForgetWeightsData = GenerateRandomData<float>(tensorInfo20x5.GetNumElements());
+ armnn::ConstTensor inputToForgetWeights(tensorInfo20x5, inputToForgetWeightsData);
+
+ std::vector<float> inputToCellWeightsData = GenerateRandomData<float>(tensorInfo20x5.GetNumElements());
+ armnn::ConstTensor inputToCellWeights(tensorInfo20x5, inputToCellWeightsData);
+
+ std::vector<float> inputToOutputWeightsData = GenerateRandomData<float>(tensorInfo20x5.GetNumElements());
+ armnn::ConstTensor inputToOutputWeights(tensorInfo20x5, inputToOutputWeightsData);
+
+ armnn::TensorInfo tensorInfo20({numUnits}, armnn::DataType::Float32);
+ std::vector<float> inputGateBiasData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+ armnn::ConstTensor inputGateBias(tensorInfo20, inputGateBiasData);
+
+ std::vector<float> forgetGateBiasData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+ armnn::ConstTensor forgetGateBias(tensorInfo20, forgetGateBiasData);
+
+ std::vector<float> cellBiasData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+ armnn::ConstTensor cellBias(tensorInfo20, cellBiasData);
+
+ std::vector<float> outputGateBiasData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+ armnn::ConstTensor outputGateBias(tensorInfo20, outputGateBiasData);
+
+ armnn::TensorInfo tensorInfo20x16({numUnits, outputSize}, armnn::DataType::Float32);
+ std::vector<float> recurrentToInputWeightsData = GenerateRandomData<float>(tensorInfo20x16.GetNumElements());
+ armnn::ConstTensor recurrentToInputWeights(tensorInfo20x16, recurrentToInputWeightsData);
+
+ std::vector<float> recurrentToForgetWeightsData = GenerateRandomData<float>(tensorInfo20x16.GetNumElements());
+ armnn::ConstTensor recurrentToForgetWeights(tensorInfo20x16, recurrentToForgetWeightsData);
+
+ std::vector<float> recurrentToCellWeightsData = GenerateRandomData<float>(tensorInfo20x16.GetNumElements());
+ armnn::ConstTensor recurrentToCellWeights(tensorInfo20x16, recurrentToCellWeightsData);
+
+ std::vector<float> recurrentToOutputWeightsData = GenerateRandomData<float>(tensorInfo20x16.GetNumElements());
+ armnn::ConstTensor recurrentToOutputWeights(tensorInfo20x16, recurrentToOutputWeightsData);
+
+ std::vector<float> cellToInputWeightsData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+ armnn::ConstTensor cellToInputWeights(tensorInfo20, cellToInputWeightsData);
+
+ std::vector<float> cellToForgetWeightsData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+ armnn::ConstTensor cellToForgetWeights(tensorInfo20, cellToForgetWeightsData);
+
+ std::vector<float> cellToOutputWeightsData = GenerateRandomData<float>(tensorInfo20.GetNumElements());
+ armnn::ConstTensor cellToOutputWeights(tensorInfo20, cellToOutputWeightsData);
+
+ armnn::TensorInfo tensorInfo16x20({outputSize, numUnits}, armnn::DataType::Float32);
+ std::vector<float> projectionWeightsData = GenerateRandomData<float>(tensorInfo16x20.GetNumElements());
+ armnn::ConstTensor projectionWeights(tensorInfo16x20, projectionWeightsData);
+
+ armnn::TensorInfo tensorInfo16({outputSize}, armnn::DataType::Float32);
+ std::vector<float> projectionBiasData(outputSize, 0.f);
+ armnn::ConstTensor projectionBias(tensorInfo16, projectionBiasData);
+
+ armnn::LstmInputParams params;
+ params.m_InputToForgetWeights = &inputToForgetWeights;
+ params.m_InputToCellWeights = &inputToCellWeights;
+ params.m_InputToOutputWeights = &inputToOutputWeights;
+ params.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
+ params.m_RecurrentToCellWeights = &recurrentToCellWeights;
+ params.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
+ params.m_ForgetGateBias = &forgetGateBias;
+ params.m_CellBias = &cellBias;
+ params.m_OutputGateBias = &outputGateBias;
+
+ // additional params because: descriptor.m_CifgEnabled = false
+ params.m_InputToInputWeights = &inputToInputWeights;
+ params.m_RecurrentToInputWeights = &recurrentToInputWeights;
+ params.m_CellToInputWeights = &cellToInputWeights;
+ params.m_InputGateBias = &inputGateBias;
+
+ // additional params because: descriptor.m_ProjectionEnabled = true
+ params.m_ProjectionWeights = &projectionWeights;
+ params.m_ProjectionBias = &projectionBias;
+
+ // additional params because: descriptor.m_PeepholeEnabled = true
+ params.m_CellToForgetWeights = &cellToForgetWeights;
+ params.m_CellToOutputWeights = &cellToOutputWeights;
+
+ armnn::INetworkPtr network = armnn::INetwork::Create();
+ armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
+ armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(1);
+ armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(2);
+ const std::string layerName("lstm");
+ armnn::IConnectableLayer* const lstmLayer = network->AddLstmLayer(descriptor, params, layerName.c_str());
+ armnn::IConnectableLayer* const scratchBuffer = network->AddOutputLayer(0);
+ armnn::IConnectableLayer* const outputStateOut = network->AddOutputLayer(1);
+ armnn::IConnectableLayer* const cellStateOut = network->AddOutputLayer(2);
+ armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(3);
+
+ // connect up
+ armnn::TensorInfo inputTensorInfo({ batchSize, inputSize }, armnn::DataType::Float32);
+ armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, armnn::DataType::Float32);
+ armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32);
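+ // without CIFG the scratch buffer holds four gate buffers (numUnits * 4)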
+ armnn::TensorInfo lstmTensorInfoScratchBuff({ batchSize, numUnits * 4 }, armnn::DataType::Float32);
+
+ inputLayer->GetOutputSlot(0).Connect(lstmLayer->GetInputSlot(0));
+ inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
+
+ outputStateIn->GetOutputSlot(0).Connect(lstmLayer->GetInputSlot(1));
+ outputStateIn->GetOutputSlot(0).SetTensorInfo(outputStateTensorInfo);
+
+ cellStateIn->GetOutputSlot(0).Connect(lstmLayer->GetInputSlot(2));
+ cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo);
+
+ lstmLayer->GetOutputSlot(0).Connect(scratchBuffer->GetInputSlot(0));
+ lstmLayer->GetOutputSlot(0).SetTensorInfo(lstmTensorInfoScratchBuff);
+
+ lstmLayer->GetOutputSlot(1).Connect(outputStateOut->GetInputSlot(0));
+ lstmLayer->GetOutputSlot(1).SetTensorInfo(outputStateTensorInfo);
+
+ lstmLayer->GetOutputSlot(2).Connect(cellStateOut->GetInputSlot(0));
+ lstmLayer->GetOutputSlot(2).SetTensorInfo(cellStateTensorInfo);
+
+ lstmLayer->GetOutputSlot(3).Connect(outputLayer->GetInputSlot(0));
+ lstmLayer->GetOutputSlot(3).SetTensorInfo(outputStateTensorInfo);
+
+ armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+ BOOST_CHECK(deserializedNetwork);
+
+ VerifyLstmLayer checker(
+ layerName,
+ {inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo},
+ {lstmTensorInfoScratchBuff, outputStateTensorInfo, cellStateTensorInfo, outputStateTensorInfo},
+ descriptor,
+ params);
+ deserializedNetwork->Accept(checker);
+}
+
BOOST_AUTO_TEST_SUITE_END()