aboutsummaryrefslogtreecommitdiff
path: root/src/armnnSerializer/Serializer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnSerializer/Serializer.cpp')
-rw-r--r--src/armnnSerializer/Serializer.cpp33
1 files changed, 20 insertions, 13 deletions
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index 865ed7af51..c49f6f9227 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -515,17 +515,24 @@ void SerializerVisitor::VisitMergeLayer(const armnn::IConnectableLayer* layer, c
}
void SerializerVisitor::VisitMergerLayer(const armnn::IConnectableLayer* layer,
- const armnn::OriginsDescriptor& mergerDescriptor,
+ const armnn::MergerDescriptor& mergerDescriptor,
const char* name)
{
- auto flatBufferMergerBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Merger);
+ VisitConcatLayer(layer, mergerDescriptor, name);
+}
+
+void SerializerVisitor::VisitConcatLayer(const armnn::IConnectableLayer* layer,
+ const armnn::ConcatDescriptor& concatDescriptor,
+ const char* name)
+{
+ auto flatBufferConcatBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Concat);
std::vector<flatbuffers::Offset<UintVector>> views;
- for (unsigned int v = 0; v < mergerDescriptor.GetNumViews(); ++v)
+ for (unsigned int v = 0; v < concatDescriptor.GetNumViews(); ++v)
{
- const uint32_t* origin = mergerDescriptor.GetViewOrigin(v);
+ const uint32_t* origin = concatDescriptor.GetViewOrigin(v);
std::vector<uint32_t> origins;
- for (unsigned int d = 0; d < mergerDescriptor.GetNumDimensions(); ++d)
+ for (unsigned int d = 0; d < concatDescriptor.GetNumDimensions(); ++d)
{
origins.push_back(origin[d]);
}
@@ -534,17 +541,17 @@ void SerializerVisitor::VisitMergerLayer(const armnn::IConnectableLayer* layer,
views.push_back(uintVector);
}
- auto flatBufferMergerDescriptor = CreateOriginsDescriptor(m_flatBufferBuilder,
- mergerDescriptor.GetConcatAxis(),
- mergerDescriptor.GetNumViews(),
- mergerDescriptor.GetNumDimensions(),
+ auto flatBufferConcatDescriptor = CreateOriginsDescriptor(m_flatBufferBuilder,
+ concatDescriptor.GetConcatAxis(),
+ concatDescriptor.GetNumViews(),
+ concatDescriptor.GetNumDimensions(),
m_flatBufferBuilder.CreateVector(views));
- auto flatBufferLayer = CreateMergerLayer(m_flatBufferBuilder,
- flatBufferMergerBaseLayer,
- flatBufferMergerDescriptor);
+ auto flatBufferLayer = CreateConcatLayer(m_flatBufferBuilder,
+ flatBufferConcatBaseLayer,
+ flatBufferConcatDescriptor);
- CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_MergerLayer);
+ CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ConcatLayer);
}
void SerializerVisitor::VisitMultiplicationLayer(const armnn::IConnectableLayer* layer, const char* name)