aboutsummaryrefslogtreecommitdiff
path: root/src/armnnSerializer/Serializer.cpp
diff options
context:
space:
mode:
authorSaoirse Stewart <saoirse.stewart@arm.com>2019-02-18 15:24:53 +0000
committerAron Virginas-Tar <aron.virginas-tar@arm.com>2019-02-19 11:52:27 +0000
commit3166c3edeb64d834ba27031ddd39b5b1f940b604 (patch)
tree2789010d0878d64442f51ba0edbd8f159d1a32a0 /src/armnnSerializer/Serializer.cpp
parenta6b504a8925174739f5a064cf77d1563cca38708 (diff)
downloadarmnn-3166c3edeb64d834ba27031ddd39b5b1f940b604.tar.gz
IVGCVSW-2645 Add Serializer & Deserializer for Pooling2d
Change-Id: Iba41da3cccd539a0175f2ed0ff9a8b6a23c5fb6f Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> Signed-off-by: Saoirse Stewart <saoirse.stewart@arm.com>
Diffstat (limited to 'src/armnnSerializer/Serializer.cpp')
-rw-r--r--src/armnnSerializer/Serializer.cpp52
1 files changed, 33 insertions, 19 deletions
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index ba4b36934c..57228c406e 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -4,9 +4,15 @@
//
#include "Serializer.hpp"
+
+#include "SerializerUtils.hpp"
+
#include <armnn/ArmNN.hpp>
+
#include <iostream>
+
#include <Schema_generated.h>
+
#include <flatbuffers/util.h>
using namespace armnn;
@@ -16,25 +22,6 @@ namespace serializer = armnn::armnnSerializer;
namespace armnnSerializer
{
-serializer::DataType GetFlatBufferDataType(DataType dataType)
-{
- switch (dataType)
- {
- case DataType::Float32:
- return serializer::DataType::DataType_Float32;
- case DataType::Float16:
- return serializer::DataType::DataType_Float16;
- case DataType::Signed32:
- return serializer::DataType::DataType_Signed32;
- case DataType::QuantisedAsymm8:
- return serializer::DataType::DataType_QuantisedAsymm8;
- case DataType::Boolean:
- return serializer::DataType::DataType_Boolean;
- default:
- return serializer::DataType::DataType_Float16;
- }
-}
-
uint32_t SerializerVisitor::GetSerializedId(unsigned int guid)
{
std::pair<unsigned int, uint32_t> guidPair(guid, m_layerId);
@@ -140,6 +127,33 @@ void SerializerVisitor::VisitSoftmaxLayer(const IConnectableLayer* layer,
CreateAnyLayer(flatBufferSoftmaxLayer.o, serializer::Layer::Layer_SoftmaxLayer);
}
+void SerializerVisitor::VisitPooling2dLayer(const IConnectableLayer* layer,
+ const Pooling2dDescriptor& pooling2dDescriptor,
+ const char* name)
+{
+ auto fbPooling2dBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Pooling2d);
+ auto fbPooling2dDescriptor = serializer::CreatePooling2dDescriptor(
+ m_flatBufferBuilder,
+ GetFlatBufferPoolingAlgorithm(pooling2dDescriptor.m_PoolType),
+ pooling2dDescriptor.m_PadLeft,
+ pooling2dDescriptor.m_PadRight,
+ pooling2dDescriptor.m_PadTop,
+ pooling2dDescriptor.m_PadBottom,
+ pooling2dDescriptor.m_PoolWidth,
+ pooling2dDescriptor.m_PoolHeight,
+ pooling2dDescriptor.m_StrideX,
+ pooling2dDescriptor.m_StrideY,
+ GetFlatBufferOutputShapeRounding(pooling2dDescriptor.m_OutputShapeRounding),
+ GetFlatBufferPaddingMethod(pooling2dDescriptor.m_PaddingMethod),
+ GetFlatBufferDataLayout(pooling2dDescriptor.m_DataLayout));
+
+ auto fbPooling2dLayer = serializer::CreatePooling2dLayer(m_flatBufferBuilder,
+ fbPooling2dBaseLayer,
+ fbPooling2dDescriptor);
+
+ CreateAnyLayer(fbPooling2dLayer.o, serializer::Layer::Layer_Pooling2dLayer);
+}
+
fb::Offset<serializer::LayerBase> SerializerVisitor::CreateLayerBase(const IConnectableLayer* layer,
const serializer::LayerType layerType)
{