From fc413c0c977e6c9680a2aa6546e977be0a2efdb9 Mon Sep 17 00:00:00 2001 From: Aron Virginas-Tar Date: Wed, 13 Feb 2019 15:41:52 +0000 Subject: IVGCVSW-2644 Add Serializer & Deserializer for Softmax Change-Id: Ifea2108e173d2b602162fe53b880a68e1c715510 Signed-off-by: Aron Virginas-Tar --- src/armnnSerializer/test/SerializerTests.cpp | 91 ++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) (limited to 'src/armnnSerializer/test') diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index ab4bc0fe0b..5b55682dfa 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -5,12 +5,23 @@ #include #include + #include "../Serializer.hpp" + +#include + +#include #include +#include + #include +#include + BOOST_AUTO_TEST_SUITE(SerializerTests) +armnnDeserializeParser::IDeserializeParserPtr g_Parser = armnnDeserializeParser::IDeserializeParser::Create(); + BOOST_AUTO_TEST_CASE(SimpleNetworkSerialization) { armnn::INetworkPtr network = armnn::INetwork::Create(); @@ -58,4 +69,84 @@ BOOST_AUTO_TEST_CASE(SimpleNetworkWithMultiplicationSerialization) BOOST_TEST(stream.str().find(multLayerName) != stream.str().npos); } +BOOST_AUTO_TEST_CASE(SimpleSoftmaxIntegration) +{ + armnn::TensorInfo tensorInfo({1, 10}, armnn::DataType::Float32); + + armnn::SoftmaxDescriptor descriptor; + descriptor.m_Beta = 1.0f; + + // Create test network + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer *const inputLayer = network->AddInputLayer(0); + armnn::IConnectableLayer *const softmaxLayer = network->AddSoftmaxLayer(descriptor, "softmax"); + armnn::IConnectableLayer *const outputLayer = network->AddOutputLayer(0); + + inputLayer->GetOutputSlot(0).Connect(softmaxLayer->GetInputSlot(0)); + inputLayer->GetOutputSlot(0).SetTensorInfo(tensorInfo); + softmaxLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + softmaxLayer->GetOutputSlot(0).SetTensorInfo(tensorInfo); + + // Serialize + armnnSerializer::Serializer serializer; + serializer.Serialize(*network); + std::stringstream stream; + serializer.SaveSerializedToStream(stream); + const std::string serializerString{stream.str()}; + + // Deserialize + armnn::INetworkPtr deserializedNetwork = + g_Parser->CreateNetworkFromBinary({serializerString.begin(), serializerString.end()}); + BOOST_CHECK(deserializedNetwork); + + armnn::IRuntime::CreationOptions options; + armnn::IRuntimePtr run = armnn::IRuntime::Create(options); + + armnn::IOptimizedNetworkPtr optimizedNetwork = + armnn::Optimize(*network, {armnn::Compute::CpuRef}, run->GetDeviceSpec()); + BOOST_CHECK(optimizedNetwork); + + armnn::IOptimizedNetworkPtr deserializedOptimizedNetwork = + armnn::Optimize(*deserializedNetwork, {armnn::Compute::CpuRef}, run->GetDeviceSpec()); + BOOST_CHECK(deserializedOptimizedNetwork); + + armnn::NetworkId networkId1; + armnn::NetworkId networkId2; + + run->LoadNetwork(networkId1, std::move(optimizedNetwork)); + run->LoadNetwork(networkId2, std::move(deserializedOptimizedNetwork)); + + std::vector inputData(tensorInfo.GetNumElements()); + std::iota(inputData.begin(), inputData.end(), 0); + + armnn::InputTensors inputTensors1 + { + {0, armnn::ConstTensor(run->GetInputTensorInfo(networkId1, 0), inputData.data())} + }; + + armnn::InputTensors inputTensors2 + { + {0, armnn::ConstTensor(run->GetInputTensorInfo(networkId2, 0), inputData.data())} + }; + + std::vector outputData1(inputData.size()); + std::vector outputData2(inputData.size()); + + armnn::OutputTensors outputTensors1 + { + {0, armnn::Tensor(run->GetOutputTensorInfo(networkId1, 0), outputData1.data())} + }; + + armnn::OutputTensors outputTensors2 + { + {0, armnn::Tensor(run->GetOutputTensorInfo(networkId2, 0), outputData2.data())} + }; + + run->EnqueueWorkload(networkId1, inputTensors1, outputTensors1); + run->EnqueueWorkload(networkId2, inputTensors2, outputTensors2); + + BOOST_CHECK_EQUAL_COLLECTIONS(outputData1.begin(), outputData1.end(), + outputData2.begin(), outputData2.end()); +} + BOOST_AUTO_TEST_SUITE_END() -- cgit v1.2.1