From 3166c3edeb64d834ba27031ddd39b5b1f940b604 Mon Sep 17 00:00:00 2001 From: Saoirse Stewart Date: Mon, 18 Feb 2019 15:24:53 +0000 Subject: IVGCVSW-2645 Add Serializer & Deserializer for Pooling2d Change-Id: Iba41da3cccd539a0175f2ed0ff9a8b6a23c5fb6f Signed-off-by: Aron Virginas-Tar Signed-off-by: Saoirse Stewart --- src/armnnSerializer/Serializer.cpp | 52 ++++++++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 19 deletions(-) (limited to 'src/armnnSerializer/Serializer.cpp') diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index ba4b36934c..57228c406e 100644 --- a/src/armnnSerializer/Serializer.cpp +++ b/src/armnnSerializer/Serializer.cpp @@ -4,9 +4,15 @@ // #include "Serializer.hpp" + +#include "SerializerUtils.hpp" + #include + #include + #include + #include using namespace armnn; @@ -16,25 +22,6 @@ namespace serializer = armnn::armnnSerializer; namespace armnnSerializer { -serializer::DataType GetFlatBufferDataType(DataType dataType) -{ - switch (dataType) - { - case DataType::Float32: - return serializer::DataType::DataType_Float32; - case DataType::Float16: - return serializer::DataType::DataType_Float16; - case DataType::Signed32: - return serializer::DataType::DataType_Signed32; - case DataType::QuantisedAsymm8: - return serializer::DataType::DataType_QuantisedAsymm8; - case DataType::Boolean: - return serializer::DataType::DataType_Boolean; - default: - return serializer::DataType::DataType_Float16; - } -} - uint32_t SerializerVisitor::GetSerializedId(unsigned int guid) { std::pair guidPair(guid, m_layerId); @@ -140,6 +127,33 @@ void SerializerVisitor::VisitSoftmaxLayer(const IConnectableLayer* layer, CreateAnyLayer(flatBufferSoftmaxLayer.o, serializer::Layer::Layer_SoftmaxLayer); } +void SerializerVisitor::VisitPooling2dLayer(const IConnectableLayer* layer, + const Pooling2dDescriptor& pooling2dDescriptor, + const char* name) +{ + auto fbPooling2dBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Pooling2d); + auto fbPooling2dDescriptor = serializer::CreatePooling2dDescriptor( + m_flatBufferBuilder, + GetFlatBufferPoolingAlgorithm(pooling2dDescriptor.m_PoolType), + pooling2dDescriptor.m_PadLeft, + pooling2dDescriptor.m_PadRight, + pooling2dDescriptor.m_PadTop, + pooling2dDescriptor.m_PadBottom, + pooling2dDescriptor.m_PoolWidth, + pooling2dDescriptor.m_PoolHeight, + pooling2dDescriptor.m_StrideX, + pooling2dDescriptor.m_StrideY, + GetFlatBufferOutputShapeRounding(pooling2dDescriptor.m_OutputShapeRounding), + GetFlatBufferPaddingMethod(pooling2dDescriptor.m_PaddingMethod), + GetFlatBufferDataLayout(pooling2dDescriptor.m_DataLayout)); + + auto fbPooling2dLayer = serializer::CreatePooling2dLayer(m_flatBufferBuilder, + fbPooling2dBaseLayer, + fbPooling2dDescriptor); + + CreateAnyLayer(fbPooling2dLayer.o, serializer::Layer::Layer_Pooling2dLayer); +} + fb::Offset SerializerVisitor::CreateLayerBase(const IConnectableLayer* layer, const serializer::LayerType layerType) { -- cgit v1.2.1