diff options
Diffstat (limited to 'src/armnnSerializer')
-rw-r--r-- | src/armnnSerializer/ArmnnSchema.fbs | 11 | ||||
-rw-r--r-- | src/armnnSerializer/Serializer.cpp | 33 | ||||
-rw-r--r-- | src/armnnSerializer/Serializer.hpp | 7 | ||||
-rw-r--r-- | src/armnnSerializer/SerializerSupport.md | 2 | ||||
-rw-r--r-- | src/armnnSerializer/test/SerializerTests.cpp | 58 |
5 files changed, 84 insertions, 27 deletions
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs index 0419c4b883..5a001de545 100644 --- a/src/armnnSerializer/ArmnnSchema.fbs +++ b/src/armnnSerializer/ArmnnSchema.fbs @@ -120,7 +120,8 @@ enum LayerType : uint { Quantize = 35, Dequantize = 36, Merge = 37, - Switch = 38 + Switch = 38, + Concat = 39 } // Base layer table to be used as part of other layers @@ -442,6 +443,11 @@ table StridedSliceDescriptor { dataLayout:DataLayout; } +table ConcatLayer { + base:LayerBase; + descriptor:OriginsDescriptor; +} + table MergerLayer { base:LayerBase; descriptor:OriginsDescriptor; @@ -577,7 +583,8 @@ union Layer { QuantizeLayer, DequantizeLayer, MergeLayer, - SwitchLayer + SwitchLayer, + ConcatLayer } table AnyLayer { diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index 865ed7af51..c49f6f9227 100644 --- a/src/armnnSerializer/Serializer.cpp +++ b/src/armnnSerializer/Serializer.cpp @@ -515,17 +515,24 @@ void SerializerVisitor::VisitMergeLayer(const armnn::IConnectableLayer* layer, c } void SerializerVisitor::VisitMergerLayer(const armnn::IConnectableLayer* layer, - const armnn::OriginsDescriptor& mergerDescriptor, + const armnn::MergerDescriptor& mergerDescriptor, const char* name) { - auto flatBufferMergerBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Merger); + VisitConcatLayer(layer, mergerDescriptor, name); +} + +void SerializerVisitor::VisitConcatLayer(const armnn::IConnectableLayer* layer, + const armnn::ConcatDescriptor& concatDescriptor, + const char* name) +{ + auto flatBufferConcatBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Concat); std::vector<flatbuffers::Offset<UintVector>> views; - for (unsigned int v = 0; v < mergerDescriptor.GetNumViews(); ++v) + for (unsigned int v = 0; v < concatDescriptor.GetNumViews(); ++v) { - const uint32_t* origin = mergerDescriptor.GetViewOrigin(v); + const uint32_t* origin = concatDescriptor.GetViewOrigin(v); std::vector<uint32_t> origins; - for (unsigned int d = 0; d < mergerDescriptor.GetNumDimensions(); ++d) + for (unsigned int d = 0; d < concatDescriptor.GetNumDimensions(); ++d) { origins.push_back(origin[d]); } @@ -534,17 +541,17 @@ void SerializerVisitor::VisitMergerLayer(const armnn::IConnectableLayer* layer, views.push_back(uintVector); } - auto flatBufferMergerDescriptor = CreateOriginsDescriptor(m_flatBufferBuilder, - mergerDescriptor.GetConcatAxis(), - mergerDescriptor.GetNumViews(), - mergerDescriptor.GetNumDimensions(), + auto flatBufferConcatDescriptor = CreateOriginsDescriptor(m_flatBufferBuilder, + concatDescriptor.GetConcatAxis(), + concatDescriptor.GetNumViews(), + concatDescriptor.GetNumDimensions(), m_flatBufferBuilder.CreateVector(views)); - auto flatBufferLayer = CreateMergerLayer(m_flatBufferBuilder, - flatBufferMergerBaseLayer, - flatBufferMergerDescriptor); + auto flatBufferLayer = CreateConcatLayer(m_flatBufferBuilder, + flatBufferConcatBaseLayer, + flatBufferConcatDescriptor); - CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_MergerLayer); + CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ConcatLayer); } void SerializerVisitor::VisitMultiplicationLayer(const armnn::IConnectableLayer* layer, const char* name) diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp index 4a718378b5..2e2816a182 100644 --- a/src/armnnSerializer/Serializer.hpp +++ b/src/armnnSerializer/Serializer.hpp @@ -61,6 +61,10 @@ public: const armnn::ConstTensor& gamma, const char* name = nullptr) override; + void VisitConcatLayer(const armnn::IConnectableLayer* layer, + const armnn::ConcatDescriptor& concatDescriptor, + const char* name = nullptr) override; + void VisitConstantLayer(const armnn::IConnectableLayer* layer, const armnn::ConstTensor& input, const char* = nullptr) override; @@ -132,8 +136,9 @@ public: void VisitMergeLayer(const armnn::IConnectableLayer* layer, const char* name = nullptr) override; + ARMNN_DEPRECATED_MSG("Use VisitConcatLayer instead") void VisitMergerLayer(const armnn::IConnectableLayer* layer, - const armnn::OriginsDescriptor& mergerDescriptor, + const armnn::MergerDescriptor& mergerDescriptor, const char* name = nullptr) override; void VisitMultiplicationLayer(const armnn::IConnectableLayer* layer, diff --git a/src/armnnSerializer/SerializerSupport.md b/src/armnnSerializer/SerializerSupport.md index f1b3365aca..832c1a7cca 100644 --- a/src/armnnSerializer/SerializerSupport.md +++ b/src/armnnSerializer/SerializerSupport.md @@ -26,7 +26,7 @@ The Arm NN SDK Serializer currently supports the following layers: * Maximum * Mean * Merge -* Merger +* Concat * Minimum * Multiplication * Normalization diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index b21ae5841d..752cf0c27a 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -1248,6 +1248,13 @@ public: const armnn::OriginsDescriptor& descriptor, const char* name) override { + throw armnn::Exception("MergerLayer should have translated to ConcatLayer"); + } + + void VisitConcatLayer(const armnn::IConnectableLayer* layer, + const armnn::OriginsDescriptor& descriptor, + const char* name) override + { VerifyNameAndConnections(layer, name); VerifyDescriptor(descriptor); } @@ -1271,6 +1278,9 @@ private: armnn::OriginsDescriptor m_Descriptor; }; +// NOTE: until the deprecated AddMergerLayer disappears this test checks that calling +// AddMergerLayer places a ConcatLayer into the serialized format and that +// when this deserialises we have a ConcatLayer BOOST_AUTO_TEST_CASE(SerializeMerger) { const std::string layerName("merger"); @@ -1309,17 +1319,10 @@ BOOST_AUTO_TEST_CASE(SerializeMerger) BOOST_AUTO_TEST_CASE(EnsureMergerLayerBackwardCompatibility) { // The hex array below is a flat buffer containing a simple network with two inputs - // a merger layer (soon to be a thing of the past) and an output layer with dimensions - // as per the tensor infos below. - // The intention is that this test will be repurposed as soon as the MergerLayer - // is replaced by a ConcatLayer to verify that we can still read back these old style + // a merger layer (now deprecated) and an output layer with dimensions as per the tensor infos below. + // + // This test verifies that we can still read back these old style // models replacing the MergerLayers with ConcatLayers with the same parameters. - // To do this the MergerLayerVerifier will be changed to have a VisitConcatLayer - // which will do the work that the VisitMergerLayer currently does and the VisitMergerLayer - // so long as it remains (public API will drop Merger Layer at some future point) - // will throw an error if invoked because none of the graphs we create should contain - // Merger layers now regardless of whether we attempt to insert the Merger layer via - // the INetwork.AddMergerLayer call or by deserializing an old style flatbuffer file. unsigned int size = 760; const unsigned char mergerModel[] = { 0x10,0x00,0x00,0x00,0x00,0x00,0x0A,0x00,0x10,0x00,0x04,0x00,0x08,0x00,0x0C,0x00,0x0A,0x00,0x00,0x00, @@ -1381,6 +1384,41 @@ BOOST_AUTO_TEST_CASE(EnsureMergerLayerBackwardCompatibility) deserializedNetwork->Accept(verifier); } +BOOST_AUTO_TEST_CASE(SerializeConcat) +{ + const std::string layerName("concat"); + const armnn::TensorInfo inputInfo = armnn::TensorInfo({2, 3, 2, 2}, armnn::DataType::Float32); + const armnn::TensorInfo outputInfo = armnn::TensorInfo({4, 3, 2, 2}, armnn::DataType::Float32); + + const std::vector<armnn::TensorShape> shapes({inputInfo.GetShape(), inputInfo.GetShape()}); + + armnn::OriginsDescriptor descriptor = + armnn::CreateDescriptorForConcatenation(shapes.begin(), shapes.end(), 0); + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer* const inputLayerOne = network->AddInputLayer(0); + armnn::IConnectableLayer* const inputLayerTwo = network->AddInputLayer(1); + armnn::IConnectableLayer* const concatLayer = network->AddConcatLayer(descriptor, layerName.c_str()); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0); + + inputLayerOne->GetOutputSlot(0).Connect(concatLayer->GetInputSlot(0)); + inputLayerTwo->GetOutputSlot(0).Connect(concatLayer->GetInputSlot(1)); + concatLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + + inputLayerOne->GetOutputSlot(0).SetTensorInfo(inputInfo); + inputLayerTwo->GetOutputSlot(0).SetTensorInfo(inputInfo); + concatLayer->GetOutputSlot(0).SetTensorInfo(outputInfo); + + std::string concatLayerNetwork = SerializeNetwork(*network); + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(concatLayerNetwork); + BOOST_CHECK(deserializedNetwork); + + // NOTE: using the MergerLayerVerifier to ensure that it is a concat layer and not a + // merger layer that gets placed into the graph. + MergerLayerVerifier verifier(layerName, {inputInfo, inputInfo}, {outputInfo}, descriptor); + deserializedNetwork->Accept(verifier); +} + BOOST_AUTO_TEST_CASE(SerializeMinimum) { class MinimumLayerVerifier : public LayerVerifierBase |