about summary refs log tree commit diff
path: root/src/armnnSerializer
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnSerializer')
-rw-r--r--src/armnnSerializer/ArmnnSchema.fbs11
-rw-r--r--src/armnnSerializer/Serializer.cpp33
-rw-r--r--src/armnnSerializer/Serializer.hpp7
-rw-r--r--src/armnnSerializer/SerializerSupport.md2
-rw-r--r--src/armnnSerializer/test/SerializerTests.cpp58
5 files changed, 84 insertions, 27 deletions
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs
index 0419c4b883..5a001de545 100644
--- a/src/armnnSerializer/ArmnnSchema.fbs
+++ b/src/armnnSerializer/ArmnnSchema.fbs
@@ -120,7 +120,8 @@ enum LayerType : uint {
Quantize = 35,
Dequantize = 36,
Merge = 37,
- Switch = 38
+ Switch = 38,
+ Concat = 39
}
// Base layer table to be used as part of other layers
@@ -442,6 +443,11 @@ table StridedSliceDescriptor {
dataLayout:DataLayout;
}
+table ConcatLayer {
+ base:LayerBase;
+ descriptor:OriginsDescriptor;
+}
+
table MergerLayer {
base:LayerBase;
descriptor:OriginsDescriptor;
@@ -577,7 +583,8 @@ union Layer {
QuantizeLayer,
DequantizeLayer,
MergeLayer,
- SwitchLayer
+ SwitchLayer,
+ ConcatLayer
}
table AnyLayer {
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index 865ed7af51..c49f6f9227 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -515,17 +515,24 @@ void SerializerVisitor::VisitMergeLayer(const armnn::IConnectableLayer* layer, c
}
void SerializerVisitor::VisitMergerLayer(const armnn::IConnectableLayer* layer,
- const armnn::OriginsDescriptor& mergerDescriptor,
+ const armnn::MergerDescriptor& mergerDescriptor,
const char* name)
{
- auto flatBufferMergerBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Merger);
+ VisitConcatLayer(layer, mergerDescriptor, name);
+}
+
+void SerializerVisitor::VisitConcatLayer(const armnn::IConnectableLayer* layer,
+ const armnn::ConcatDescriptor& concatDescriptor,
+ const char* name)
+{
+ auto flatBufferConcatBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Concat);
std::vector<flatbuffers::Offset<UintVector>> views;
- for (unsigned int v = 0; v < mergerDescriptor.GetNumViews(); ++v)
+ for (unsigned int v = 0; v < concatDescriptor.GetNumViews(); ++v)
{
- const uint32_t* origin = mergerDescriptor.GetViewOrigin(v);
+ const uint32_t* origin = concatDescriptor.GetViewOrigin(v);
std::vector<uint32_t> origins;
- for (unsigned int d = 0; d < mergerDescriptor.GetNumDimensions(); ++d)
+ for (unsigned int d = 0; d < concatDescriptor.GetNumDimensions(); ++d)
{
origins.push_back(origin[d]);
}
@@ -534,17 +541,17 @@ void SerializerVisitor::VisitMergerLayer(const armnn::IConnectableLayer* layer,
views.push_back(uintVector);
}
- auto flatBufferMergerDescriptor = CreateOriginsDescriptor(m_flatBufferBuilder,
- mergerDescriptor.GetConcatAxis(),
- mergerDescriptor.GetNumViews(),
- mergerDescriptor.GetNumDimensions(),
+ auto flatBufferConcatDescriptor = CreateOriginsDescriptor(m_flatBufferBuilder,
+ concatDescriptor.GetConcatAxis(),
+ concatDescriptor.GetNumViews(),
+ concatDescriptor.GetNumDimensions(),
m_flatBufferBuilder.CreateVector(views));
- auto flatBufferLayer = CreateMergerLayer(m_flatBufferBuilder,
- flatBufferMergerBaseLayer,
- flatBufferMergerDescriptor);
+ auto flatBufferLayer = CreateConcatLayer(m_flatBufferBuilder,
+ flatBufferConcatBaseLayer,
+ flatBufferConcatDescriptor);
- CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_MergerLayer);
+ CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ConcatLayer);
}
void SerializerVisitor::VisitMultiplicationLayer(const armnn::IConnectableLayer* layer, const char* name)
diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp
index 4a718378b5..2e2816a182 100644
--- a/src/armnnSerializer/Serializer.hpp
+++ b/src/armnnSerializer/Serializer.hpp
@@ -61,6 +61,10 @@ public:
const armnn::ConstTensor& gamma,
const char* name = nullptr) override;
+ void VisitConcatLayer(const armnn::IConnectableLayer* layer,
+ const armnn::ConcatDescriptor& concatDescriptor,
+ const char* name = nullptr) override;
+
void VisitConstantLayer(const armnn::IConnectableLayer* layer,
const armnn::ConstTensor& input,
const char* = nullptr) override;
@@ -132,8 +136,9 @@ public:
void VisitMergeLayer(const armnn::IConnectableLayer* layer,
const char* name = nullptr) override;
+ ARMNN_DEPRECATED_MSG("Use VisitConcatLayer instead")
void VisitMergerLayer(const armnn::IConnectableLayer* layer,
- const armnn::OriginsDescriptor& mergerDescriptor,
+ const armnn::MergerDescriptor& mergerDescriptor,
const char* name = nullptr) override;
void VisitMultiplicationLayer(const armnn::IConnectableLayer* layer,
diff --git a/src/armnnSerializer/SerializerSupport.md b/src/armnnSerializer/SerializerSupport.md
index f1b3365aca..832c1a7cca 100644
--- a/src/armnnSerializer/SerializerSupport.md
+++ b/src/armnnSerializer/SerializerSupport.md
@@ -26,7 +26,7 @@ The Arm NN SDK Serializer currently supports the following layers:
* Maximum
* Mean
* Merge
-* Merger
+* Concat
* Minimum
* Multiplication
* Normalization
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp
index b21ae5841d..752cf0c27a 100644
--- a/src/armnnSerializer/test/SerializerTests.cpp
+++ b/src/armnnSerializer/test/SerializerTests.cpp
@@ -1248,6 +1248,13 @@ public:
const armnn::OriginsDescriptor& descriptor,
const char* name) override
{
+ throw armnn::Exception("MergerLayer should have translated to ConcatLayer");
+ }
+
+ void VisitConcatLayer(const armnn::IConnectableLayer* layer,
+ const armnn::OriginsDescriptor& descriptor,
+ const char* name) override
+ {
VerifyNameAndConnections(layer, name);
VerifyDescriptor(descriptor);
}
@@ -1271,6 +1278,9 @@ private:
armnn::OriginsDescriptor m_Descriptor;
};
+// NOTE: until the deprecated AddMergerLayer disappears this test checks that calling
+// AddMergerLayer places a ConcatLayer into the serialized format and that
+// when this deserialises we have a ConcatLayer
BOOST_AUTO_TEST_CASE(SerializeMerger)
{
const std::string layerName("merger");
@@ -1309,17 +1319,10 @@ BOOST_AUTO_TEST_CASE(SerializeMerger)
BOOST_AUTO_TEST_CASE(EnsureMergerLayerBackwardCompatibility)
{
// The hex array below is a flat buffer containing a simple network with two inputs
- // a merger layer (soon to be a thing of the past) and an output layer with dimensions
- // as per the tensor infos below.
- // The intention is that this test will be repurposed as soon as the MergerLayer
- // is replaced by a ConcatLayer to verify that we can still read back these old style
+ // a merger layer (now deprecated) and an output layer with dimensions as per the tensor infos below.
+ //
+ // This test verifies that we can still read back these old style
// models replacing the MergerLayers with ConcatLayers with the same parameters.
- // To do this the MergerLayerVerifier will be changed to have a VisitConcatLayer
- // which will do the work that the VisitMergerLayer currently does and the VisitMergerLayer
- // so long as it remains (public API will drop Merger Layer at some future point)
- // will throw an error if invoked because none of the graphs we create should contain
- // Merger layers now regardless of whether we attempt to insert the Merger layer via
- // the INetwork.AddMergerLayer call or by deserializing an old style flatbuffer file.
unsigned int size = 760;
const unsigned char mergerModel[] = {
0x10,0x00,0x00,0x00,0x00,0x00,0x0A,0x00,0x10,0x00,0x04,0x00,0x08,0x00,0x0C,0x00,0x0A,0x00,0x00,0x00,
@@ -1381,6 +1384,41 @@ BOOST_AUTO_TEST_CASE(EnsureMergerLayerBackwardCompatibility)
deserializedNetwork->Accept(verifier);
}
+BOOST_AUTO_TEST_CASE(SerializeConcat)
+{
+ const std::string layerName("concat");
+ const armnn::TensorInfo inputInfo = armnn::TensorInfo({2, 3, 2, 2}, armnn::DataType::Float32);
+ const armnn::TensorInfo outputInfo = armnn::TensorInfo({4, 3, 2, 2}, armnn::DataType::Float32);
+
+ const std::vector<armnn::TensorShape> shapes({inputInfo.GetShape(), inputInfo.GetShape()});
+
+ armnn::OriginsDescriptor descriptor =
+ armnn::CreateDescriptorForConcatenation(shapes.begin(), shapes.end(), 0);
+
+ armnn::INetworkPtr network = armnn::INetwork::Create();
+ armnn::IConnectableLayer* const inputLayerOne = network->AddInputLayer(0);
+ armnn::IConnectableLayer* const inputLayerTwo = network->AddInputLayer(1);
+ armnn::IConnectableLayer* const concatLayer = network->AddConcatLayer(descriptor, layerName.c_str());
+ armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
+
+ inputLayerOne->GetOutputSlot(0).Connect(concatLayer->GetInputSlot(0));
+ inputLayerTwo->GetOutputSlot(0).Connect(concatLayer->GetInputSlot(1));
+ concatLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
+
+ inputLayerOne->GetOutputSlot(0).SetTensorInfo(inputInfo);
+ inputLayerTwo->GetOutputSlot(0).SetTensorInfo(inputInfo);
+ concatLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
+
+ std::string concatLayerNetwork = SerializeNetwork(*network);
+ armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(concatLayerNetwork);
+ BOOST_CHECK(deserializedNetwork);
+
+ // NOTE: using the MergerLayerVerifier to ensure that it is a concat layer and not a
+ // merger layer that gets placed into the graph.
+ MergerLayerVerifier verifier(layerName, {inputInfo, inputInfo}, {outputInfo}, descriptor);
+ deserializedNetwork->Accept(verifier);
+}
+
BOOST_AUTO_TEST_CASE(SerializeMinimum)
{
class MinimumLayerVerifier : public LayerVerifierBase