diff options
Diffstat (limited to 'src/armnnSerializer/test/SerializerTests.cpp')
-rw-r--r-- | src/armnnSerializer/test/SerializerTests.cpp | 375 |
1 files changed, 375 insertions, 0 deletions
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index f40c02dfde..e3ce6d29d3 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -2047,4 +2047,379 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeNonLinearNetwork) deserializedNetwork->Accept(verifier); } +class VerifyLstmLayer : public LayerVerifierBase +{ +public: + VerifyLstmLayer(const std::string& layerName, + const std::vector<armnn::TensorInfo>& inputInfos, + const std::vector<armnn::TensorInfo>& outputInfos, + const armnn::LstmDescriptor& descriptor, + const armnn::LstmInputParams& inputParams) : + LayerVerifierBase(layerName, inputInfos, outputInfos), m_Descriptor(descriptor), m_InputParams(inputParams) + { + } + void VisitLstmLayer(const armnn::IConnectableLayer* layer, + const armnn::LstmDescriptor& descriptor, + const armnn::LstmInputParams& params, + const char* name) + { + VerifyNameAndConnections(layer, name); + VerifyDescriptor(descriptor); + VerifyInputParameters(params); + } +protected: + void VerifyDescriptor(const armnn::LstmDescriptor& descriptor) + { + BOOST_TEST(m_Descriptor.m_ActivationFunc == descriptor.m_ActivationFunc); + BOOST_TEST(m_Descriptor.m_ClippingThresCell == descriptor.m_ClippingThresCell); + BOOST_TEST(m_Descriptor.m_ClippingThresProj == descriptor.m_ClippingThresProj); + BOOST_TEST(m_Descriptor.m_CifgEnabled == descriptor.m_CifgEnabled); + BOOST_TEST(m_Descriptor.m_PeepholeEnabled = descriptor.m_PeepholeEnabled); + BOOST_TEST(m_Descriptor.m_ProjectionEnabled == descriptor.m_ProjectionEnabled); + } + void VerifyInputParameters(const armnn::LstmInputParams& params) + { + VerifyConstTensors( + "m_InputToInputWeights", m_InputParams.m_InputToInputWeights, params.m_InputToInputWeights); + VerifyConstTensors( + "m_InputToForgetWeights", m_InputParams.m_InputToForgetWeights, params.m_InputToForgetWeights); + VerifyConstTensors( + "m_InputToCellWeights", m_InputParams.m_InputToCellWeights, params.m_InputToCellWeights); + VerifyConstTensors( + "m_InputToOutputWeights", m_InputParams.m_InputToOutputWeights, params.m_InputToOutputWeights); + VerifyConstTensors( + "m_RecurrentToInputWeights", m_InputParams.m_RecurrentToInputWeights, params.m_RecurrentToInputWeights); + VerifyConstTensors( + "m_RecurrentToForgetWeights", m_InputParams.m_RecurrentToForgetWeights, params.m_RecurrentToForgetWeights); + VerifyConstTensors( + "m_RecurrentToCellWeights", m_InputParams.m_RecurrentToCellWeights, params.m_RecurrentToCellWeights); + VerifyConstTensors( + "m_RecurrentToOutputWeights", m_InputParams.m_RecurrentToOutputWeights, params.m_RecurrentToOutputWeights); + VerifyConstTensors( + "m_CellToInputWeights", m_InputParams.m_CellToInputWeights, params.m_CellToInputWeights); + VerifyConstTensors( + "m_CellToForgetWeights", m_InputParams.m_CellToForgetWeights, params.m_CellToForgetWeights); + VerifyConstTensors( + "m_CellToOutputWeights", m_InputParams.m_CellToOutputWeights, params.m_CellToOutputWeights); + VerifyConstTensors( + "m_InputGateBias", m_InputParams.m_InputGateBias, params.m_InputGateBias); + VerifyConstTensors( + "m_ForgetGateBias", m_InputParams.m_ForgetGateBias, params.m_ForgetGateBias); + VerifyConstTensors( + "m_CellBias", m_InputParams.m_CellBias, params.m_CellBias); + VerifyConstTensors( + "m_OutputGateBias", m_InputParams.m_OutputGateBias, params.m_OutputGateBias); + VerifyConstTensors( + "m_ProjectionWeights", m_InputParams.m_ProjectionWeights, params.m_ProjectionWeights); + VerifyConstTensors( + "m_ProjectionBias", m_InputParams.m_ProjectionBias, params.m_ProjectionBias); + } + void VerifyConstTensors(const std::string& tensorName, + const armnn::ConstTensor* expectedPtr, + const armnn::ConstTensor* actualPtr) + { + if (expectedPtr == nullptr) + { + BOOST_CHECK_MESSAGE(actualPtr == nullptr, tensorName + " should not exist"); + } + else + { + BOOST_CHECK_MESSAGE(actualPtr != nullptr, tensorName + " should have been set"); + if (actualPtr != nullptr) + { + const armnn::TensorInfo& expectedInfo = expectedPtr->GetInfo(); + const armnn::TensorInfo& actualInfo = actualPtr->GetInfo(); + + BOOST_CHECK_MESSAGE(expectedInfo.GetShape() == actualInfo.GetShape(), + tensorName + " shapes don't match"); + BOOST_CHECK_MESSAGE( + GetDataTypeName(expectedInfo.GetDataType()) == GetDataTypeName(actualInfo.GetDataType()), + tensorName + " data types don't match"); + + BOOST_CHECK_MESSAGE(expectedPtr->GetNumBytes() == actualPtr->GetNumBytes(), + tensorName + " (GetNumBytes) data sizes do not match"); + if (expectedPtr->GetNumBytes() == actualPtr->GetNumBytes()) + { + //check the data is identical + const char* expectedData = static_cast<const char*>(expectedPtr->GetMemoryArea()); + const char* actualData = static_cast<const char*>(actualPtr->GetMemoryArea()); + bool same = true; + for (unsigned int i = 0; i < expectedPtr->GetNumBytes(); ++i) + { + same = expectedData[i] == actualData[i]; + if (!same) + { + break; + } + } + BOOST_CHECK_MESSAGE(same, tensorName + " data does not match"); + } + } + } + } +private: + armnn::LstmDescriptor m_Descriptor; + armnn::LstmInputParams m_InputParams; +}; + +BOOST_AUTO_TEST_CASE(SerializeDeserializeLstmCifgPeepholeNoProjection) +{ + armnn::LstmDescriptor descriptor; + descriptor.m_ActivationFunc = 4; + descriptor.m_ClippingThresProj = 0.0f; + descriptor.m_ClippingThresCell = 0.0f; + descriptor.m_CifgEnabled = true; // if this is true then we DON'T need to set the OptCifgParams + descriptor.m_ProjectionEnabled = false; + descriptor.m_PeepholeEnabled = true; + + const uint32_t batchSize = 1; + const uint32_t inputSize = 2; + const uint32_t numUnits = 4; + const uint32_t outputSize = numUnits; + + armnn::TensorInfo inputWeightsInfo1({numUnits, inputSize}, armnn::DataType::Float32); + std::vector<float> inputToForgetWeightsData = GenerateRandomData<float>(inputWeightsInfo1.GetNumElements()); + armnn::ConstTensor inputToForgetWeights(inputWeightsInfo1, inputToForgetWeightsData); + + std::vector<float> inputToCellWeightsData = GenerateRandomData<float>(inputWeightsInfo1.GetNumElements()); + armnn::ConstTensor inputToCellWeights(inputWeightsInfo1, inputToCellWeightsData); + + std::vector<float> inputToOutputWeightsData = GenerateRandomData<float>(inputWeightsInfo1.GetNumElements()); + armnn::ConstTensor inputToOutputWeights(inputWeightsInfo1, inputToOutputWeightsData); + + armnn::TensorInfo inputWeightsInfo2({numUnits, outputSize}, armnn::DataType::Float32); + std::vector<float> recurrentToForgetWeightsData = GenerateRandomData<float>(inputWeightsInfo2.GetNumElements()); + armnn::ConstTensor recurrentToForgetWeights(inputWeightsInfo2, recurrentToForgetWeightsData); + + std::vector<float> recurrentToCellWeightsData = GenerateRandomData<float>(inputWeightsInfo2.GetNumElements()); + armnn::ConstTensor recurrentToCellWeights(inputWeightsInfo2, recurrentToCellWeightsData); + + std::vector<float> recurrentToOutputWeightsData = GenerateRandomData<float>(inputWeightsInfo2.GetNumElements()); + armnn::ConstTensor recurrentToOutputWeights(inputWeightsInfo2, recurrentToOutputWeightsData); + + armnn::TensorInfo inputWeightsInfo3({numUnits}, armnn::DataType::Float32); + std::vector<float> cellToForgetWeightsData = GenerateRandomData<float>(inputWeightsInfo3.GetNumElements()); + armnn::ConstTensor cellToForgetWeights(inputWeightsInfo3, cellToForgetWeightsData); + + std::vector<float> cellToOutputWeightsData = GenerateRandomData<float>(inputWeightsInfo3.GetNumElements()); + armnn::ConstTensor cellToOutputWeights(inputWeightsInfo3, cellToOutputWeightsData); + + std::vector<float> forgetGateBiasData(numUnits, 1.0f); + armnn::ConstTensor forgetGateBias(inputWeightsInfo3, forgetGateBiasData); + + std::vector<float> cellBiasData(numUnits, 0.0f); + armnn::ConstTensor cellBias(inputWeightsInfo3, cellBiasData); + + std::vector<float> outputGateBiasData(numUnits, 0.0f); + armnn::ConstTensor outputGateBias(inputWeightsInfo3, outputGateBiasData); + + armnn::LstmInputParams params; + params.m_InputToForgetWeights = &inputToForgetWeights; + params.m_InputToCellWeights = &inputToCellWeights; + params.m_InputToOutputWeights = &inputToOutputWeights; + params.m_RecurrentToForgetWeights = &recurrentToForgetWeights; + params.m_RecurrentToCellWeights = &recurrentToCellWeights; + params.m_RecurrentToOutputWeights = &recurrentToOutputWeights; + params.m_ForgetGateBias = &forgetGateBias; + params.m_CellBias = &cellBias; + params.m_OutputGateBias = &outputGateBias; + params.m_CellToForgetWeights = &cellToForgetWeights; + params.m_CellToOutputWeights = &cellToOutputWeights; + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0); + armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(1); + armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(2); + const std::string layerName("lstm"); + armnn::IConnectableLayer* const lstmLayer = network->AddLstmLayer(descriptor, params, layerName.c_str()); + armnn::IConnectableLayer* const scratchBuffer = network->AddOutputLayer(0); + armnn::IConnectableLayer* const outputStateOut = network->AddOutputLayer(1); + armnn::IConnectableLayer* const cellStateOut = network->AddOutputLayer(2); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(3); + + // connect up + armnn::TensorInfo inputTensorInfo({ batchSize, inputSize }, armnn::DataType::Float32); + armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, armnn::DataType::Float32); + armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32); + armnn::TensorInfo lstmTensorInfoScratchBuff({ batchSize, numUnits * 3 }, armnn::DataType::Float32); + + inputLayer->GetOutputSlot(0).Connect(lstmLayer->GetInputSlot(0)); + inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo); + + outputStateIn->GetOutputSlot(0).Connect(lstmLayer->GetInputSlot(1)); + outputStateIn->GetOutputSlot(0).SetTensorInfo(outputStateTensorInfo); + + cellStateIn->GetOutputSlot(0).Connect(lstmLayer->GetInputSlot(2)); + cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo); + + lstmLayer->GetOutputSlot(0).Connect(scratchBuffer->GetInputSlot(0)); + lstmLayer->GetOutputSlot(0).SetTensorInfo(lstmTensorInfoScratchBuff); + + lstmLayer->GetOutputSlot(1).Connect(outputStateOut->GetInputSlot(0)); + lstmLayer->GetOutputSlot(1).SetTensorInfo(outputStateTensorInfo); + + lstmLayer->GetOutputSlot(2).Connect(cellStateOut->GetInputSlot(0)); + lstmLayer->GetOutputSlot(2).SetTensorInfo(cellStateTensorInfo); + + lstmLayer->GetOutputSlot(3).Connect(outputLayer->GetInputSlot(0)); + lstmLayer->GetOutputSlot(3).SetTensorInfo(outputStateTensorInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + BOOST_CHECK(deserializedNetwork); + + VerifyLstmLayer checker( + layerName, + {inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo}, + {lstmTensorInfoScratchBuff, outputStateTensorInfo, cellStateTensorInfo, outputStateTensorInfo}, + descriptor, + params); + deserializedNetwork->Accept(checker); +} + +BOOST_AUTO_TEST_CASE(SerializeDeserializeLstmNoCifgWithPeepholeAndProjection) +{ + armnn::LstmDescriptor descriptor; + descriptor.m_ActivationFunc = 4; + descriptor.m_ClippingThresProj = 0.0f; + descriptor.m_ClippingThresCell = 0.0f; + descriptor.m_CifgEnabled = false; // if this is true then we DON'T need to set the OptCifgParams + descriptor.m_ProjectionEnabled = true; + descriptor.m_PeepholeEnabled = true; + + const uint32_t batchSize = 2; + const uint32_t inputSize = 5; + const uint32_t numUnits = 20; + const uint32_t outputSize = 16; + + armnn::TensorInfo tensorInfo20x5({numUnits, inputSize}, armnn::DataType::Float32); + std::vector<float> inputToInputWeightsData = GenerateRandomData<float>(tensorInfo20x5.GetNumElements()); + armnn::ConstTensor inputToInputWeights(tensorInfo20x5, inputToInputWeightsData); + + std::vector<float> inputToForgetWeightsData = GenerateRandomData<float>(tensorInfo20x5.GetNumElements()); + armnn::ConstTensor inputToForgetWeights(tensorInfo20x5, inputToForgetWeightsData); + + std::vector<float> inputToCellWeightsData = GenerateRandomData<float>(tensorInfo20x5.GetNumElements()); + armnn::ConstTensor inputToCellWeights(tensorInfo20x5, inputToCellWeightsData); + + std::vector<float> inputToOutputWeightsData = GenerateRandomData<float>(tensorInfo20x5.GetNumElements()); + armnn::ConstTensor inputToOutputWeights(tensorInfo20x5, inputToOutputWeightsData); + + armnn::TensorInfo tensorInfo20({numUnits}, armnn::DataType::Float32); + std::vector<float> inputGateBiasData = GenerateRandomData<float>(tensorInfo20.GetNumElements()); + armnn::ConstTensor inputGateBias(tensorInfo20, inputGateBiasData); + + std::vector<float> forgetGateBiasData = GenerateRandomData<float>(tensorInfo20.GetNumElements()); + armnn::ConstTensor forgetGateBias(tensorInfo20, forgetGateBiasData); + + std::vector<float> cellBiasData = GenerateRandomData<float>(tensorInfo20.GetNumElements()); + armnn::ConstTensor cellBias(tensorInfo20, cellBiasData); + + std::vector<float> outputGateBiasData = GenerateRandomData<float>(tensorInfo20.GetNumElements()); + armnn::ConstTensor outputGateBias(tensorInfo20, outputGateBiasData); + + armnn::TensorInfo tensorInfo20x16({numUnits, outputSize}, armnn::DataType::Float32); + std::vector<float> recurrentToInputWeightsData = GenerateRandomData<float>(tensorInfo20x16.GetNumElements()); + armnn::ConstTensor recurrentToInputWeights(tensorInfo20x16, recurrentToInputWeightsData); + + std::vector<float> recurrentToForgetWeightsData = GenerateRandomData<float>(tensorInfo20x16.GetNumElements()); + armnn::ConstTensor recurrentToForgetWeights(tensorInfo20x16, recurrentToForgetWeightsData); + + std::vector<float> recurrentToCellWeightsData = GenerateRandomData<float>(tensorInfo20x16.GetNumElements()); + armnn::ConstTensor recurrentToCellWeights(tensorInfo20x16, recurrentToCellWeightsData); + + std::vector<float> recurrentToOutputWeightsData = GenerateRandomData<float>(tensorInfo20x16.GetNumElements()); + armnn::ConstTensor recurrentToOutputWeights(tensorInfo20x16, recurrentToOutputWeightsData); + + std::vector<float> cellToInputWeightsData = GenerateRandomData<float>(tensorInfo20.GetNumElements()); + armnn::ConstTensor cellToInputWeights(tensorInfo20, cellToInputWeightsData); + + std::vector<float> cellToForgetWeightsData = GenerateRandomData<float>(tensorInfo20.GetNumElements()); + armnn::ConstTensor cellToForgetWeights(tensorInfo20, cellToForgetWeightsData); + + std::vector<float> cellToOutputWeightsData = GenerateRandomData<float>(tensorInfo20.GetNumElements()); + armnn::ConstTensor cellToOutputWeights(tensorInfo20, cellToOutputWeightsData); + + armnn::TensorInfo tensorInfo16x20({outputSize, numUnits}, armnn::DataType::Float32); + std::vector<float> projectionWeightsData = GenerateRandomData<float>(tensorInfo16x20.GetNumElements()); + armnn::ConstTensor projectionWeights(tensorInfo16x20, projectionWeightsData); + + armnn::TensorInfo tensorInfo16({outputSize}, armnn::DataType::Float32); + std::vector<float> projectionBiasData(outputSize, 0.f); + armnn::ConstTensor projectionBias(tensorInfo16, projectionBiasData); + + armnn::LstmInputParams params; + params.m_InputToForgetWeights = &inputToForgetWeights; + params.m_InputToCellWeights = &inputToCellWeights; + params.m_InputToOutputWeights = &inputToOutputWeights; + params.m_RecurrentToForgetWeights = &recurrentToForgetWeights; + params.m_RecurrentToCellWeights = &recurrentToCellWeights; + params.m_RecurrentToOutputWeights = &recurrentToOutputWeights; + params.m_ForgetGateBias = &forgetGateBias; + params.m_CellBias = &cellBias; + params.m_OutputGateBias = &outputGateBias; + + // additional params because: descriptor.m_CifgEnabled = false + params.m_InputToInputWeights = &inputToInputWeights; + params.m_RecurrentToInputWeights = &recurrentToInputWeights; + params.m_CellToInputWeights = &cellToInputWeights; + params.m_InputGateBias = &inputGateBias; + + // additional params because: descriptor.m_ProjectionEnabled = true + params.m_ProjectionWeights = &projectionWeights; + params.m_ProjectionBias = &projectionBias; + + // additional params because: descriptor.m_PeepholeEnabled = true + params.m_CellToForgetWeights = &cellToForgetWeights; + params.m_CellToOutputWeights = &cellToOutputWeights; + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0); + armnn::IConnectableLayer* const cellStateIn = network->AddInputLayer(1); + armnn::IConnectableLayer* const outputStateIn = network->AddInputLayer(2); + const std::string layerName("lstm"); + armnn::IConnectableLayer* const lstmLayer = network->AddLstmLayer(descriptor, params, layerName.c_str()); + armnn::IConnectableLayer* const scratchBuffer = network->AddOutputLayer(0); + armnn::IConnectableLayer* const outputStateOut = network->AddOutputLayer(1); + armnn::IConnectableLayer* const cellStateOut = network->AddOutputLayer(2); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(3); + + // connect up + armnn::TensorInfo inputTensorInfo({ batchSize, inputSize }, armnn::DataType::Float32); + armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, armnn::DataType::Float32); + armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32); + armnn::TensorInfo lstmTensorInfoScratchBuff({ batchSize, numUnits * 4 }, armnn::DataType::Float32); + + inputLayer->GetOutputSlot(0).Connect(lstmLayer->GetInputSlot(0)); + inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo); + + outputStateIn->GetOutputSlot(0).Connect(lstmLayer->GetInputSlot(1)); + outputStateIn->GetOutputSlot(0).SetTensorInfo(outputStateTensorInfo); + + cellStateIn->GetOutputSlot(0).Connect(lstmLayer->GetInputSlot(2)); + cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo); + + lstmLayer->GetOutputSlot(0).Connect(scratchBuffer->GetInputSlot(0)); + lstmLayer->GetOutputSlot(0).SetTensorInfo(lstmTensorInfoScratchBuff); + + lstmLayer->GetOutputSlot(1).Connect(outputStateOut->GetInputSlot(0)); + lstmLayer->GetOutputSlot(1).SetTensorInfo(outputStateTensorInfo); + + lstmLayer->GetOutputSlot(2).Connect(cellStateOut->GetInputSlot(0)); + lstmLayer->GetOutputSlot(2).SetTensorInfo(cellStateTensorInfo); + + lstmLayer->GetOutputSlot(3).Connect(outputLayer->GetInputSlot(0)); + lstmLayer->GetOutputSlot(3).SetTensorInfo(outputStateTensorInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + BOOST_CHECK(deserializedNetwork); + + VerifyLstmLayer checker( + layerName, + {inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo}, + {lstmTensorInfoScratchBuff, outputStateTensorInfo, cellStateTensorInfo, outputStateTensorInfo}, + descriptor, + params); + deserializedNetwork->Accept(checker); +} + BOOST_AUTO_TEST_SUITE_END() |