diff options
Diffstat (limited to 'src/armnnSerializer')
-rw-r--r-- | src/armnnSerializer/ArmnnSchema.fbs | 31 | ||||
-rw-r--r-- | src/armnnSerializer/Serializer.cpp | 23 | ||||
-rw-r--r-- | src/armnnSerializer/Serializer.hpp | 3 | ||||
-rw-r--r-- | src/armnnSerializer/SerializerSupport.md | 1 | ||||
-rw-r--r-- | src/armnnSerializer/SerializerUtils.cpp | 28 | ||||
-rw-r--r-- | src/armnnSerializer/SerializerUtils.hpp | 6 | ||||
-rw-r--r-- | src/armnnSerializer/test/SerializerTests.cpp | 49 |
7 files changed, 138 insertions, 3 deletions
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs index b59adcf82b..cde0087d6f 100644 --- a/src/armnnSerializer/ArmnnSchema.fbs +++ b/src/armnnSerializer/ArmnnSchema.fbs @@ -99,7 +99,8 @@ enum LayerType : uint { Division = 15, Minimum = 16, Equal = 17, - Maximum = 18 + Maximum = 18, + Normalization = 19 } // Base layer table to be used as part of other layers @@ -298,6 +299,31 @@ table BatchToSpaceNdDescriptor { dataLayout:DataLayout; } +enum NormalizationAlgorithmChannel : byte { + Across = 0, + Within = 1 +} + +enum NormalizationAlgorithmMethod : byte { + LocalBrightness = 0, + LocalContrast = 1 +} + +table NormalizationLayer { + base:LayerBase; + descriptor:NormalizationDescriptor; +} + +table NormalizationDescriptor { + normChannelType:NormalizationAlgorithmChannel = Across; + normMethodType:NormalizationAlgorithmMethod = LocalBrightness; + normSize:uint; + alpha:float; + beta:float; + k:float; + dataLayout:DataLayout = NCHW; +} + union Layer { ActivationLayer, AdditionLayer, @@ -317,7 +343,8 @@ union Layer { DivisionLayer, MinimumLayer, EqualLayer, - MaximumLayer + MaximumLayer, + NormalizationLayer } table AnyLayer { diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index a94a319a4c..2000726526 100644 --- a/src/armnnSerializer/Serializer.cpp +++ b/src/armnnSerializer/Serializer.cpp @@ -470,6 +470,29 @@ void SerializerVisitor::VisitSpaceToBatchNdLayer(const armnn::IConnectableLayer* CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_SpaceToBatchNdLayer); } +void SerializerVisitor::VisitNormalizationLayer(const armnn::IConnectableLayer* layer, + const armnn::NormalizationDescriptor& descriptor, + const char* name) +{ + auto fbNormalizationBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Normalization); + + auto fbNormalizationDescriptor = serializer::CreateNormalizationDescriptor( + m_flatBufferBuilder, + GetFlatBufferNormalizationAlgorithmChannel(descriptor.m_NormChannelType), + GetFlatBufferNormalizationAlgorithmMethod(descriptor.m_NormMethodType), + descriptor.m_NormSize, + descriptor.m_Alpha, + descriptor.m_Beta, + descriptor.m_K, + GetFlatBufferDataLayout(descriptor.m_DataLayout)); + + auto flatBufferLayer = serializer::CreateNormalizationLayer(m_flatBufferBuilder, + fbNormalizationBaseLayer, + fbNormalizationDescriptor); + + CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_NormalizationLayer); +} + fb::Offset<serializer::LayerBase> SerializerVisitor::CreateLayerBase(const IConnectableLayer* layer, const serializer::LayerType layerType) { diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp index 3d6f1b5700..7e6097c465 100644 --- a/src/armnnSerializer/Serializer.hpp +++ b/src/armnnSerializer/Serializer.hpp @@ -118,6 +118,9 @@ public: const armnn::SpaceToBatchNdDescriptor& spaceToBatchNdDescriptor, const char* name = nullptr) override; + void VisitNormalizationLayer(const armnn::IConnectableLayer* layer, + const armnn::NormalizationDescriptor& normalizationDescriptor, + const char* name = nullptr) override; private: /// Creates the Input Slots and Output Slots and LayerBase for the layer. diff --git a/src/armnnSerializer/SerializerSupport.md b/src/armnnSerializer/SerializerSupport.md index 83987f98d7..d018a35c3a 100644 --- a/src/armnnSerializer/SerializerSupport.md +++ b/src/armnnSerializer/SerializerSupport.md @@ -18,6 +18,7 @@ The Arm NN SDK Serializer currently supports the following layers: * Maximum * Minimum * Multiplication +* Normalization * Permute * Pooling2d * Reshape diff --git a/src/armnnSerializer/SerializerUtils.cpp b/src/armnnSerializer/SerializerUtils.cpp index 592f258b81..bfe795c8c4 100644 --- a/src/armnnSerializer/SerializerUtils.cpp +++ b/src/armnnSerializer/SerializerUtils.cpp @@ -96,4 +96,32 @@ armnnSerializer::PaddingMethod GetFlatBufferPaddingMethod(armnn::PaddingMethod p } } +armnnSerializer::NormalizationAlgorithmChannel GetFlatBufferNormalizationAlgorithmChannel( + armnn::NormalizationAlgorithmChannel normalizationAlgorithmChannel) +{ + switch (normalizationAlgorithmChannel) + { + case armnn::NormalizationAlgorithmChannel::Across: + return armnnSerializer::NormalizationAlgorithmChannel::NormalizationAlgorithmChannel_Across; + case armnn::NormalizationAlgorithmChannel::Within: + return armnnSerializer::NormalizationAlgorithmChannel::NormalizationAlgorithmChannel_Within; + default: + return armnnSerializer::NormalizationAlgorithmChannel::NormalizationAlgorithmChannel_Across; + } +} + +armnnSerializer::NormalizationAlgorithmMethod GetFlatBufferNormalizationAlgorithmMethod( + armnn::NormalizationAlgorithmMethod normalizationAlgorithmMethod) +{ + switch (normalizationAlgorithmMethod) + { + case armnn::NormalizationAlgorithmMethod::LocalBrightness: + return armnnSerializer::NormalizationAlgorithmMethod::NormalizationAlgorithmMethod_LocalBrightness; + case armnn::NormalizationAlgorithmMethod::LocalContrast: + return armnnSerializer::NormalizationAlgorithmMethod::NormalizationAlgorithmMethod_LocalContrast; + default: + return armnnSerializer::NormalizationAlgorithmMethod::NormalizationAlgorithmMethod_LocalBrightness; + } +} + } // namespace armnnSerializer
\ No newline at end of file diff --git a/src/armnnSerializer/SerializerUtils.hpp b/src/armnnSerializer/SerializerUtils.hpp index 9b1dff9112..29cda0d629 100644 --- a/src/armnnSerializer/SerializerUtils.hpp +++ b/src/armnnSerializer/SerializerUtils.hpp @@ -24,4 +24,10 @@ armnnSerializer::OutputShapeRounding GetFlatBufferOutputShapeRounding( armnnSerializer::PaddingMethod GetFlatBufferPaddingMethod(armnn::PaddingMethod paddingMethod); +armnnSerializer::NormalizationAlgorithmChannel GetFlatBufferNormalizationAlgorithmChannel( + armnn::NormalizationAlgorithmChannel normalizationAlgorithmChannel); + +armnnSerializer::NormalizationAlgorithmMethod GetFlatBufferNormalizationAlgorithmMethod( + armnn::NormalizationAlgorithmMethod normalizationAlgorithmMethod); + } // namespace armnnSerializer diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index 7e4ff8c614..271b3e71bd 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -891,6 +891,54 @@ BOOST_AUTO_TEST_CASE(SerializeDivision) deserializedNetwork->Accept(nameChecker); } +BOOST_AUTO_TEST_CASE(SerializeDeserializeNormalization) +{ + class VerifyNormalizationName : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy> + { + public: + void VisitNormalizationLayer(const armnn::IConnectableLayer*, + const armnn::NormalizationDescriptor& normalizationDescriptor, + const char* name) override + { + BOOST_TEST(name == "NormalizationLayer"); + } + }; + + unsigned int inputShape[] = {2, 1, 2, 2}; + unsigned int outputShape[] = {2, 1, 2, 2}; + + armnn::NormalizationDescriptor desc; + desc.m_DataLayout = armnn::DataLayout::NCHW; + desc.m_NormSize = 3; + desc.m_Alpha = 1; + desc.m_Beta = 1; + desc.m_K = 1; + + auto inputTensorInfo = armnn::TensorInfo(4, inputShape, armnn::DataType::Float32); + auto outputTensorInfo = armnn::TensorInfo(4, outputShape, armnn::DataType::Float32); + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0); + armnn::IConnectableLayer* const normalizationLayer = network->AddNormalizationLayer(desc, "NormalizationLayer"); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0); + + inputLayer->GetOutputSlot(0).Connect(normalizationLayer->GetInputSlot(0)); + inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo); + normalizationLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + normalizationLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + BOOST_CHECK(deserializedNetwork); + + VerifyNormalizationName nameChecker; + deserializedNetwork->Accept(nameChecker); + + CheckDeserializedNetworkAgainstOriginal(*network, + *deserializedNetwork, + {inputTensorInfo.GetShape()}, + {outputTensorInfo.GetShape()}); +} + BOOST_AUTO_TEST_CASE(SerializeDeserializeEqual) { class VerifyEqualName : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy> @@ -932,5 +980,4 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeEqual) {outputTensorInfo.GetShape()}, {0, 1}); } - BOOST_AUTO_TEST_SUITE_END() |