diff options
Diffstat (limited to 'src/armnnSerializer/Serializer.cpp')
-rw-r--r-- | src/armnnSerializer/Serializer.cpp | 50 |
1 files changed, 50 insertions, 0 deletions
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index b85c45aa10..bee1a3cdb5 100644 --- a/src/armnnSerializer/Serializer.cpp +++ b/src/armnnSerializer/Serializer.cpp @@ -22,6 +22,33 @@ namespace serializer = armnnSerializer; namespace armnnSerializer { +serializer::ActivationFunction GetFlatBufferActivationFunction(armnn::ActivationFunction function) +{ + switch (function) + { + case armnn::ActivationFunction::Sigmoid: + return serializer::ActivationFunction::ActivationFunction_Sigmoid; + case armnn::ActivationFunction::TanH: + return serializer::ActivationFunction::ActivationFunction_TanH; + case armnn::ActivationFunction::Linear: + return serializer::ActivationFunction::ActivationFunction_Linear; + case armnn::ActivationFunction::ReLu: + return serializer::ActivationFunction::ActivationFunction_ReLu; + case armnn::ActivationFunction::BoundedReLu: + return serializer::ActivationFunction::ActivationFunction_BoundedReLu; + case armnn::ActivationFunction::LeakyReLu: + return serializer::ActivationFunction::ActivationFunction_LeakyReLu; + case armnn::ActivationFunction::Abs: + return serializer::ActivationFunction::ActivationFunction_Abs; + case armnn::ActivationFunction::Sqrt: + return serializer::ActivationFunction::ActivationFunction_Sqrt; + case armnn::ActivationFunction::Square: + return serializer::ActivationFunction::ActivationFunction_Square; + default: + return serializer::ActivationFunction::ActivationFunction_Sigmoid; + } +} + uint32_t SerializerVisitor::GetSerializedId(unsigned int guid) { std::pair<unsigned int, uint32_t> guidPair(guid, m_layerId); @@ -78,6 +105,29 @@ void SerializerVisitor::VisitOutputLayer(const armnn::IConnectableLayer* layer, CreateAnyLayer(flatBufferOutputLayer.o, serializer::Layer::Layer_OutputLayer); } +// Build FlatBuffer for Activation Layer +void SerializerVisitor::VisitActivationLayer(const armnn::IConnectableLayer* layer, + const armnn::ActivationDescriptor& descriptor, + const char* name) +{ + // Create FlatBuffer BaseLayer + auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Activation); + + // Create the FlatBuffer ActivationDescriptor + auto flatBufferDescriptor = CreateActivationDescriptor(m_flatBufferBuilder, + GetFlatBufferActivationFunction(descriptor.m_Function), + descriptor.m_A, + descriptor.m_B); + + // Create the FlatBuffer ActivationLayer + auto flatBufferAdditionLayer = CreateActivationLayer(m_flatBufferBuilder, + flatBufferBaseLayer, + flatBufferDescriptor); + + // Add the AnyLayer to the FlatBufferLayers + CreateAnyLayer(flatBufferAdditionLayer.o, serializer::Layer::Layer_ActivationLayer); +} + // Build FlatBuffer for Addition Layer void SerializerVisitor::VisitAdditionLayer(const armnn::IConnectableLayer* layer, const char* name) { |