diff options
Diffstat (limited to 'src/armnnSerializer')
-rw-r--r-- | src/armnnSerializer/ArmnnSchema.fbs | 10 | ||||
-rw-r--r-- | src/armnnSerializer/Serializer.cpp | 8 | ||||
-rw-r--r-- | src/armnnSerializer/Serializer.hpp | 3 | ||||
-rw-r--r-- | src/armnnSerializer/SerializerSupport.md | 1 | ||||
-rw-r--r-- | src/armnnSerializer/test/SerializerTests.cpp | 40 |
5 files changed, 60 insertions, 2 deletions
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs index 3aa644dbe5..8b275b6f17 100644 --- a/src/armnnSerializer/ArmnnSchema.fbs +++ b/src/armnnSerializer/ArmnnSchema.fbs @@ -118,7 +118,8 @@ enum LayerType : uint { DetectionPostProcess = 33, Lstm = 34, Quantize = 35, - Dequantize = 36 + Dequantize = 36, + Merge = 37 } // Base layer table to be used as part of other layers @@ -524,6 +525,10 @@ table DequantizeLayer { base:LayerBase; } +table MergeLayer { + base:LayerBase; +} + union Layer { ActivationLayer, AdditionLayer, @@ -561,7 +566,8 @@ union Layer { DetectionPostProcessLayer, LstmLayer, QuantizeLayer, - DequantizeLayer + DequantizeLayer, + MergeLayer } table AnyLayer { diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index 7181f01e6b..fe30c3eee5 100644 --- a/src/armnnSerializer/Serializer.cpp +++ b/src/armnnSerializer/Serializer.cpp @@ -500,6 +500,14 @@ void SerializerVisitor::VisitMinimumLayer(const armnn::IConnectableLayer* layer, CreateAnyLayer(fbMinimumLayer.o, serializer::Layer::Layer_MinimumLayer); } +void SerializerVisitor::VisitMergeLayer(const armnn::IConnectableLayer* layer, const char* name) +{ + auto fbMergeBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Merge); + auto fbMergeLayer = serializer::CreateMergeLayer(m_flatBufferBuilder, fbMergeBaseLayer); + + CreateAnyLayer(fbMergeLayer.o, serializer::Layer::Layer_MergeLayer); +} + void SerializerVisitor::VisitMergerLayer(const armnn::IConnectableLayer* layer, const armnn::OriginsDescriptor& mergerDescriptor, const char* name) diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp index 5c3e48a695..775df83966 100644 --- a/src/armnnSerializer/Serializer.hpp +++ b/src/armnnSerializer/Serializer.hpp @@ -129,6 +129,9 @@ public: void VisitMaximumLayer(const armnn::IConnectableLayer* layer, const char* name = nullptr) override; + void VisitMergeLayer(const armnn::IConnectableLayer* layer, + const char* name = nullptr) override; + void VisitMergerLayer(const armnn::IConnectableLayer* layer, const armnn::OriginsDescriptor& mergerDescriptor, const char* name = nullptr) override; diff --git a/src/armnnSerializer/SerializerSupport.md b/src/armnnSerializer/SerializerSupport.md index a3c5852bd2..a8335e1e68 100644 --- a/src/armnnSerializer/SerializerSupport.md +++ b/src/armnnSerializer/SerializerSupport.md @@ -25,6 +25,7 @@ The Arm NN SDK Serializer currently supports the following layers: * Lstm * Maximum * Mean +* Merge * Merger * Minimum * Multiplication diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index 0979076476..a1ef9eef59 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -1185,6 +1185,46 @@ BOOST_AUTO_TEST_CASE(SerializeMean) deserializedNetwork->Accept(verifier); } +BOOST_AUTO_TEST_CASE(SerializeMerge) +{ + class MergeLayerVerifier : public LayerVerifierBase + { + public: + MergeLayerVerifier(const std::string& layerName, + const std::vector<armnn::TensorInfo>& inputInfos, + const std::vector<armnn::TensorInfo>& outputInfos) + : LayerVerifierBase(layerName, inputInfos, outputInfos) {} + + void VisitMergeLayer(const armnn::IConnectableLayer* layer, const char* name) override + { + VerifyNameAndConnections(layer, name); + } + }; + + const std::string layerName("merge"); + const armnn::TensorInfo info({ 1, 2, 2, 3 }, armnn::DataType::Float32); + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer* const inputLayer0 = network->AddInputLayer(0); + armnn::IConnectableLayer* const inputLayer1 = network->AddInputLayer(1); + armnn::IConnectableLayer* const mergeLayer = network->AddMergeLayer(layerName.c_str()); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0); + + inputLayer0->GetOutputSlot(0).Connect(mergeLayer->GetInputSlot(0)); + inputLayer1->GetOutputSlot(0).Connect(mergeLayer->GetInputSlot(1)); + mergeLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + + inputLayer0->GetOutputSlot(0).SetTensorInfo(info); + inputLayer1->GetOutputSlot(0).SetTensorInfo(info); + mergeLayer->GetOutputSlot(0).SetTensorInfo(info); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + BOOST_CHECK(deserializedNetwork); + + MergeLayerVerifier verifier(layerName, {info, info}, {info}); + deserializedNetwork->Accept(verifier); +} + BOOST_AUTO_TEST_CASE(SerializeMerger) { class MergerLayerVerifier : public LayerVerifierBase |