diff options
author | Mike Kelly <mike.kelly@arm.com> | 2019-02-20 16:53:11 +0000 |
---|---|---|
committer | Mike Kelly <mike.kelly@arm.com> | 2019-02-20 16:53:11 +0000 |
commit | af484013329a8ca5b3c4c9d16395fb79dd19b1b2 (patch) | |
tree | 90a6e08d99a3856403c79395cd4b58bad8755e4a /src/armnnSerializer/test/ActivationSerializationTests.cpp | |
parent | 0028d1b0ce5f4c2c6a6eb3c66f38111c21eb47a3 (diff) | |
download | armnn-af484013329a8ca5b3c4c9d16395fb79dd19b1b2.tar.gz |
IVGCVSW-2643 Add Serializer & Deserializer for Activation
* Added ActivationLayer to Schema.fbs
* Added Activation serialization and deserialization support
* Added serialization and deserialization unit tests
Change-Id: Ib5df45f123674988b994ffe3f111d3fb57864912
Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Diffstat (limited to 'src/armnnSerializer/test/ActivationSerializationTests.cpp')
-rw-r--r-- | src/armnnSerializer/test/ActivationSerializationTests.cpp | 78 |
1 files changed, 78 insertions, 0 deletions
diff --git a/src/armnnSerializer/test/ActivationSerializationTests.cpp b/src/armnnSerializer/test/ActivationSerializationTests.cpp new file mode 100644 index 0000000000..c20f2864f9 --- /dev/null +++ b/src/armnnSerializer/test/ActivationSerializationTests.cpp @@ -0,0 +1,78 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include <armnnDeserializer/IDeserializer.hpp> +#include <armnn/ArmNN.hpp> +#include <armnn/INetwork.hpp> +#include "../Serializer.hpp" +#include <sstream> +#include <boost/test/unit_test.hpp> + +BOOST_AUTO_TEST_SUITE(SerializerTests) + +BOOST_AUTO_TEST_CASE(ActivationSerialization) +{ + armnnDeserializer::IDeserializerPtr parser = armnnDeserializer::IDeserializer::Create(); + + armnn::TensorInfo inputInfo(armnn::TensorShape({1, 2, 2, 1}), armnn::DataType::Float32, 1.0f, 0); + armnn::TensorInfo outputInfo(armnn::TensorShape({1, 2, 2, 1}), armnn::DataType::Float32, 4.0f, 0); + + // Construct network + armnn::INetworkPtr network = armnn::INetwork::Create(); + + armnn::ActivationDescriptor descriptor; + descriptor.m_Function = armnn::ActivationFunction::ReLu; + descriptor.m_A = 0; + descriptor.m_B = 0; + + armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0, "input"); + armnn::IConnectableLayer* const activationLayer = network->AddActivationLayer(descriptor, "activation"); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0, "output"); + + inputLayer->GetOutputSlot(0).Connect(activationLayer->GetInputSlot(0)); + inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo); + + activationLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + activationLayer->GetOutputSlot(0).SetTensorInfo(outputInfo); + + armnnSerializer::Serializer serializer; + serializer.Serialize(*network); + + std::stringstream stream; + serializer.SaveSerializedToStream(stream); + + std::string const serializerString{stream.str()}; + std::vector<std::uint8_t> const serializerVector{serializerString.begin(), serializerString.end()}; + + armnn::INetworkPtr deserializedNetwork = parser->CreateNetworkFromBinary(serializerVector); + + armnn::IRuntime::CreationOptions options; // default options + armnn::IRuntimePtr run = armnn::IRuntime::Create(options); + auto deserializedOptimized = Optimize(*deserializedNetwork, { armnn::Compute::CpuRef }, run->GetDeviceSpec()); + + armnn::NetworkId networkIdentifier; + + // Load graph into runtime + run->LoadNetwork(networkIdentifier, std::move(deserializedOptimized)); + + std::vector<float> inputData {0.0f, -5.3f, 42.0f, -42.0f}; + armnn::InputTensors inputTensors + { + {0, armnn::ConstTensor(run->GetInputTensorInfo(networkIdentifier, 0), inputData.data())} + }; + + std::vector<float> expectedOutputData {0.0f, 0.0f, 42.0f, 0.0f}; + + std::vector<float> outputData(4); + armnn::OutputTensors outputTensors + { + {0, armnn::Tensor(run->GetOutputTensorInfo(networkIdentifier, 0), outputData.data())} + }; + run->EnqueueWorkload(networkIdentifier, inputTensors, outputTensors); + BOOST_CHECK_EQUAL_COLLECTIONS(outputData.begin(), outputData.end(), + expectedOutputData.begin(), expectedOutputData.end()); +} + +BOOST_AUTO_TEST_SUITE_END() |