diff options
Diffstat (limited to 'src/armnnSerializer/Serializer.cpp')
-rw-r--r-- | src/armnnSerializer/Serializer.cpp | 33 |
1 files changed, 20 insertions, 13 deletions
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index 865ed7af51..c49f6f9227 100644 --- a/src/armnnSerializer/Serializer.cpp +++ b/src/armnnSerializer/Serializer.cpp @@ -515,17 +515,24 @@ void SerializerVisitor::VisitMergeLayer(const armnn::IConnectableLayer* layer, c } void SerializerVisitor::VisitMergerLayer(const armnn::IConnectableLayer* layer, - const armnn::OriginsDescriptor& mergerDescriptor, + const armnn::MergerDescriptor& mergerDescriptor, const char* name) { - auto flatBufferMergerBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Merger); + VisitConcatLayer(layer, mergerDescriptor, name); +} + +void SerializerVisitor::VisitConcatLayer(const armnn::IConnectableLayer* layer, + const armnn::ConcatDescriptor& concatDescriptor, + const char* name) +{ + auto flatBufferConcatBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Concat); std::vector<flatbuffers::Offset<UintVector>> views; - for (unsigned int v = 0; v < mergerDescriptor.GetNumViews(); ++v) + for (unsigned int v = 0; v < concatDescriptor.GetNumViews(); ++v) { - const uint32_t* origin = mergerDescriptor.GetViewOrigin(v); + const uint32_t* origin = concatDescriptor.GetViewOrigin(v); std::vector<uint32_t> origins; - for (unsigned int d = 0; d < mergerDescriptor.GetNumDimensions(); ++d) + for (unsigned int d = 0; d < concatDescriptor.GetNumDimensions(); ++d) { origins.push_back(origin[d]); } @@ -534,17 +541,17 @@ void SerializerVisitor::VisitMergerLayer(const armnn::IConnectableLayer* layer, views.push_back(uintVector); } - auto flatBufferMergerDescriptor = CreateOriginsDescriptor(m_flatBufferBuilder, - mergerDescriptor.GetConcatAxis(), - mergerDescriptor.GetNumViews(), - mergerDescriptor.GetNumDimensions(), + auto flatBufferConcatDescriptor = CreateOriginsDescriptor(m_flatBufferBuilder, + concatDescriptor.GetConcatAxis(), + concatDescriptor.GetNumViews(), + concatDescriptor.GetNumDimensions(), m_flatBufferBuilder.CreateVector(views)); - auto flatBufferLayer = CreateMergerLayer(m_flatBufferBuilder, - flatBufferMergerBaseLayer, - flatBufferMergerDescriptor); + auto flatBufferLayer = CreateConcatLayer(m_flatBufferBuilder, + flatBufferConcatBaseLayer, + flatBufferConcatDescriptor); - CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_MergerLayer); + CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ConcatLayer); } void SerializerVisitor::VisitMultiplicationLayer(const armnn::IConnectableLayer* layer, const char* name) |