diff options
author | Matthew Jackson <matthew.jackson@arm.com> | 2019-07-11 15:54:20 +0100 |
---|---|---|
committer | Matthew Jackson <matthew.jackson@arm.com> | 2019-07-16 09:05:14 +0000 |
commit | b5433ee34fd9d38c1453dc062b36348d65677002 (patch) | |
tree | f48ed9f6ffb25aed908c5aa775b49bef78fd906f /src/armnnSerializer/test/SerializerTests.cpp | |
parent | 15a9a8f8f5e05c9967c3a52ecbfb7e173e9e61dd (diff) | |
download | armnn-b5433ee34fd9d38c1453dc062b36348d65677002.tar.gz |
IVGCVSW-3420 Add Serialization support for the new Stack layer
* Adds serialization/deserialization support
* Adds related unit test
Signed-off-by: Matthew Jackson <matthew.jackson@arm.com>
Change-Id: I69deb5397b8a06c679715e24971e9bb1c282140d
Diffstat (limited to 'src/armnnSerializer/test/SerializerTests.cpp')
-rw-r--r-- | src/armnnSerializer/test/SerializerTests.cpp | 59 |
1 files changed, 59 insertions, 0 deletions
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index 33f10ef435..3d74d88e30 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -2379,6 +2379,65 @@ BOOST_AUTO_TEST_CASE(SerializeSplitter) deserializedNetwork->Accept(verifier); } +BOOST_AUTO_TEST_CASE(SerializeStack) +{ + class StackLayerVerifier : public LayerVerifierBase + { + public: + StackLayerVerifier(const std::string& layerName, + const std::vector<armnn::TensorInfo>& inputInfos, + const std::vector<armnn::TensorInfo>& outputInfos, + const armnn::StackDescriptor& descriptor) + : LayerVerifierBase(layerName, inputInfos, outputInfos) + , m_Descriptor(descriptor) {} + + void VisitStackLayer(const armnn::IConnectableLayer* layer, + const armnn::StackDescriptor& descriptor, + const char* name) override + { + VerifyNameAndConnections(layer, name); + VerifyDescriptor(descriptor); + } + + private: + void VerifyDescriptor(const armnn::StackDescriptor& descriptor) + { + BOOST_TEST(descriptor.m_Axis == m_Descriptor.m_Axis); + BOOST_TEST(descriptor.m_InputShape == m_Descriptor.m_InputShape); + BOOST_TEST(descriptor.m_NumInputs == m_Descriptor.m_NumInputs); + } + + armnn::StackDescriptor m_Descriptor; + }; + + const std::string layerName("stack"); + + armnn::TensorInfo inputTensorInfo ({4, 3, 5}, armnn::DataType::Float32); + armnn::TensorInfo outputTensorInfo({4, 3, 2, 5}, armnn::DataType::Float32); + + armnn::StackDescriptor descriptor(2, 2, {4, 3, 5}); + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer* const inputLayer1 = network->AddInputLayer(0); + armnn::IConnectableLayer* const inputLayer2 = network->AddInputLayer(1); + armnn::IConnectableLayer* const stackLayer = network->AddStackLayer(descriptor, layerName.c_str()); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0); + + inputLayer1->GetOutputSlot(0).Connect(stackLayer->GetInputSlot(0)); + inputLayer2->GetOutputSlot(0).Connect(stackLayer->GetInputSlot(1)); + stackLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + + inputLayer1->GetOutputSlot(0).SetTensorInfo(inputTensorInfo); + inputLayer2->GetOutputSlot(0).SetTensorInfo(inputTensorInfo); + stackLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + BOOST_CHECK(deserializedNetwork); + + StackLayerVerifier verifier(layerName, {inputTensorInfo, inputTensorInfo}, {outputTensorInfo}, descriptor); + deserializedNetwork->Accept(verifier); +} + BOOST_AUTO_TEST_CASE(SerializeStridedSlice) { class StridedSliceLayerVerifier : public LayerVerifierBase |