aboutsummaryrefslogtreecommitdiff
path: root/src/armnnSerializer/Serializer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnSerializer/Serializer.cpp')
-rw-r--r--src/armnnSerializer/Serializer.cpp52
1 files changed, 33 insertions, 19 deletions
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index ba4b36934c..57228c406e 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -4,9 +4,15 @@
//
#include "Serializer.hpp"
+
+#include "SerializerUtils.hpp"
+
#include <armnn/ArmNN.hpp>
+
#include <iostream>
+
#include <Schema_generated.h>
+
#include <flatbuffers/util.h>
using namespace armnn;
@@ -16,25 +22,6 @@ namespace serializer = armnn::armnnSerializer;
namespace armnnSerializer
{
-serializer::DataType GetFlatBufferDataType(DataType dataType)
-{
- switch (dataType)
- {
- case DataType::Float32:
- return serializer::DataType::DataType_Float32;
- case DataType::Float16:
- return serializer::DataType::DataType_Float16;
- case DataType::Signed32:
- return serializer::DataType::DataType_Signed32;
- case DataType::QuantisedAsymm8:
- return serializer::DataType::DataType_QuantisedAsymm8;
- case DataType::Boolean:
- return serializer::DataType::DataType_Boolean;
- default:
- return serializer::DataType::DataType_Float16;
- }
-}
-
uint32_t SerializerVisitor::GetSerializedId(unsigned int guid)
{
std::pair<unsigned int, uint32_t> guidPair(guid, m_layerId);
@@ -140,6 +127,33 @@ void SerializerVisitor::VisitSoftmaxLayer(const IConnectableLayer* layer,
CreateAnyLayer(flatBufferSoftmaxLayer.o, serializer::Layer::Layer_SoftmaxLayer);
}
+void SerializerVisitor::VisitPooling2dLayer(const IConnectableLayer* layer,
+ const Pooling2dDescriptor& pooling2dDescriptor,
+ const char* name)
+{
+ auto fbPooling2dBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Pooling2d);
+ auto fbPooling2dDescriptor = serializer::CreatePooling2dDescriptor(
+ m_flatBufferBuilder,
+ GetFlatBufferPoolingAlgorithm(pooling2dDescriptor.m_PoolType),
+ pooling2dDescriptor.m_PadLeft,
+ pooling2dDescriptor.m_PadRight,
+ pooling2dDescriptor.m_PadTop,
+ pooling2dDescriptor.m_PadBottom,
+ pooling2dDescriptor.m_PoolWidth,
+ pooling2dDescriptor.m_PoolHeight,
+ pooling2dDescriptor.m_StrideX,
+ pooling2dDescriptor.m_StrideY,
+ GetFlatBufferOutputShapeRounding(pooling2dDescriptor.m_OutputShapeRounding),
+ GetFlatBufferPaddingMethod(pooling2dDescriptor.m_PaddingMethod),
+ GetFlatBufferDataLayout(pooling2dDescriptor.m_DataLayout));
+
+ auto fbPooling2dLayer = serializer::CreatePooling2dLayer(m_flatBufferBuilder,
+ fbPooling2dBaseLayer,
+ fbPooling2dDescriptor);
+
+ CreateAnyLayer(fbPooling2dLayer.o, serializer::Layer::Layer_Pooling2dLayer);
+}
+
fb::Offset<serializer::LayerBase> SerializerVisitor::CreateLayerBase(const IConnectableLayer* layer,
const serializer::LayerType layerType)
{