From dbb0c0ca0c8425886ee3a2095e0ced07099134f9 Mon Sep 17 00:00:00 2001 From: Sadik Armagan Date: Thu, 21 Feb 2019 09:01:41 +0000 Subject: IVGCVSW-2639 Add Serializer & Deserializer for Fully Connected * Added FullyConnectedLayer to Serializer Schema Schema.fbs * Added FullyConnected serialization and deserialization support * Added FullyConnected serialization and deserialization unit tests Change-Id: I8ef14f9728158f849fa4d1a8d05a1a4170cd5b41 Signed-off-by: Sadik Armagan Signed-off-by: Aron Virginas-Tar --- src/armnnSerializer/test/SerializerTests.cpp | 41 ++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) (limited to 'src/armnnSerializer/test') diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index 822f9c7e00..ede24baf9e 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -404,4 +404,45 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializePermute) outputTensorInfo.GetShape()); } +BOOST_AUTO_TEST_CASE(SerializeDeserializeFullyConnected) +{ + armnn::TensorInfo inputInfo ({ 2, 5, 1, 1 }, armnn::DataType::Float32); + armnn::TensorInfo outputInfo({ 2, 3 }, armnn::DataType::Float32); + + armnn::TensorInfo weightsInfo({ 5, 3 }, armnn::DataType::Float32); + armnn::TensorInfo biasesInfo ({ 3 }, armnn::DataType::Float32); + + armnn::FullyConnectedDescriptor descriptor; + descriptor.m_BiasEnabled = true; + descriptor.m_TransposeWeightMatrix = false; + + std::vector weightsData = GenerateRandomData(weightsInfo.GetNumElements()); + std::vector biasesData = GenerateRandomData(biasesInfo.GetNumElements()); + + armnn::ConstTensor weights(weightsInfo, weightsData); + armnn::ConstTensor biases(biasesInfo, biasesData); + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0, "input"); + armnn::IConnectableLayer* const fullyConnectedLayer = network->AddFullyConnectedLayer(descriptor, + weights, + biases, + "fully_connected"); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0, "output"); + + inputLayer->GetOutputSlot(0).Connect(fullyConnectedLayer->GetInputSlot(0)); + inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo); + + fullyConnectedLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + fullyConnectedLayer->GetOutputSlot(0).SetTensorInfo(outputInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + BOOST_CHECK(deserializedNetwork); + + CheckDeserializedNetworkAgainstOriginal(*network, + *deserializedNetwork, + inputInfo.GetShape(), + outputInfo.GetShape()); +} + BOOST_AUTO_TEST_SUITE_END() -- cgit v1.2.1