diff options
Diffstat (limited to 'src/armnnSerializer/Serializer.cpp')
-rw-r--r-- | src/armnnSerializer/Serializer.cpp | 52 |
1 files changed, 33 insertions, 19 deletions
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index ba4b36934c..57228c406e 100644 --- a/src/armnnSerializer/Serializer.cpp +++ b/src/armnnSerializer/Serializer.cpp @@ -4,9 +4,15 @@ // #include "Serializer.hpp" + +#include "SerializerUtils.hpp" + #include <armnn/ArmNN.hpp> + #include <iostream> + #include <Schema_generated.h> + #include <flatbuffers/util.h> using namespace armnn; @@ -16,25 +22,6 @@ namespace serializer = armnn::armnnSerializer; namespace armnnSerializer { -serializer::DataType GetFlatBufferDataType(DataType dataType) -{ - switch (dataType) - { - case DataType::Float32: - return serializer::DataType::DataType_Float32; - case DataType::Float16: - return serializer::DataType::DataType_Float16; - case DataType::Signed32: - return serializer::DataType::DataType_Signed32; - case DataType::QuantisedAsymm8: - return serializer::DataType::DataType_QuantisedAsymm8; - case DataType::Boolean: - return serializer::DataType::DataType_Boolean; - default: - return serializer::DataType::DataType_Float16; - } -} - uint32_t SerializerVisitor::GetSerializedId(unsigned int guid) { std::pair<unsigned int, uint32_t> guidPair(guid, m_layerId); @@ -140,6 +127,33 @@ void SerializerVisitor::VisitSoftmaxLayer(const IConnectableLayer* layer, CreateAnyLayer(flatBufferSoftmaxLayer.o, serializer::Layer::Layer_SoftmaxLayer); } +void SerializerVisitor::VisitPooling2dLayer(const IConnectableLayer* layer, + const Pooling2dDescriptor& pooling2dDescriptor, + const char* name) +{ + auto fbPooling2dBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Pooling2d); + auto fbPooling2dDescriptor = serializer::CreatePooling2dDescriptor( + m_flatBufferBuilder, + GetFlatBufferPoolingAlgorithm(pooling2dDescriptor.m_PoolType), + pooling2dDescriptor.m_PadLeft, + pooling2dDescriptor.m_PadRight, + pooling2dDescriptor.m_PadTop, + pooling2dDescriptor.m_PadBottom, + pooling2dDescriptor.m_PoolWidth, + pooling2dDescriptor.m_PoolHeight, + pooling2dDescriptor.m_StrideX, + pooling2dDescriptor.m_StrideY, + GetFlatBufferOutputShapeRounding(pooling2dDescriptor.m_OutputShapeRounding), + GetFlatBufferPaddingMethod(pooling2dDescriptor.m_PaddingMethod), + GetFlatBufferDataLayout(pooling2dDescriptor.m_DataLayout)); + + auto fbPooling2dLayer = serializer::CreatePooling2dLayer(m_flatBufferBuilder, + fbPooling2dBaseLayer, + fbPooling2dDescriptor); + + CreateAnyLayer(fbPooling2dLayer.o, serializer::Layer::Layer_Pooling2dLayer); +} + fb::Offset<serializer::LayerBase> SerializerVisitor::CreateLayerBase(const IConnectableLayer* layer, const serializer::LayerType layerType) { |