diff options
author | Mike Kelly <mike.kelly@arm.com> | 2019-02-11 17:01:27 +0000 |
---|---|---|
committer | Mike Kelly <mike.kelly@arm.com> | 2019-02-11 17:01:27 +0000 |
commit | 8c1701a2d9c1da0e1decb2afdc2093aa88810242 (patch) | |
tree | 870ee9af506bb468c513214ab539f41aeb1e34dc /src | |
parent | a40521a70e73d20a060fa2df0e83b02c4f1c6139 (diff) | |
download | armnn-8c1701a2d9c1da0e1decb2afdc2093aa88810242.tar.gz |
IVGCVSW-2531 Serialize a simple ArmNN Network
Change-Id: I68cf5072aca6e3a8b3b8c57e19b6d417cd5813fc
Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Diffstat (limited to 'src')
-rw-r--r-- | src/armnn/Layer.cpp | 6 | ||||
-rw-r--r-- | src/armnn/Layer.hpp | 4 | ||||
-rw-r--r-- | src/armnn/LayerVisitorBase.hpp | 175 | ||||
-rw-r--r-- | src/armnn/Network.cpp | 8 | ||||
-rw-r--r-- | src/armnn/Network.hpp | 2 | ||||
-rw-r--r-- | src/armnn/OverrideInputRangeVisitor.hpp | 2 | ||||
-rw-r--r-- | src/armnn/QuantizerVisitor.hpp | 2 | ||||
-rw-r--r-- | src/armnn/StaticRangeVisitor.hpp | 2 | ||||
-rw-r--r-- | src/armnn/test/QuantizerTest.cpp | 2 | ||||
-rw-r--r-- | src/armnnDeserializeParser/DeserializeParser.cpp | 4 | ||||
-rw-r--r-- | src/armnnSerializer/README.md | 6 | ||||
-rw-r--r-- | src/armnnSerializer/Schema.fbs | 4 | ||||
-rw-r--r-- | src/armnnSerializer/SeralizerSupport.md | 11 | ||||
-rw-r--r-- | src/armnnSerializer/Serializer.cpp | 186 | ||||
-rw-r--r-- | src/armnnSerializer/Serializer.hpp | 72 | ||||
-rw-r--r-- | src/armnnSerializer/test/SerializerTests.cpp | 35 |
16 files changed, 336 insertions, 185 deletions
diff --git a/src/armnn/Layer.cpp b/src/armnn/Layer.cpp index c49dd61786..0a6328ba3d 100644 --- a/src/armnn/Layer.cpp +++ b/src/armnn/Layer.cpp @@ -143,6 +143,11 @@ void OutputSlot::ValidateConnectionIndex(unsigned int index) const } } +LayerGuid OutputSlot::GetOwningLayerGuid() const +{ + return GetOwningLayer().GetGuid(); +} + namespace { LayerGuid GenerateLayerGuid() { @@ -335,5 +340,4 @@ std::vector<TensorShape> Layer::InferOutputShapes(const std::vector<TensorShape> } return inputShapes; } - } // namespace armnn diff --git a/src/armnn/Layer.hpp b/src/armnn/Layer.hpp index c08c6b0631..507b37bf95 100644 --- a/src/armnn/Layer.hpp +++ b/src/armnn/Layer.hpp @@ -105,6 +105,8 @@ public: Layer& GetOwningLayer() const { return m_OwningLayer; } + LayerGuid GetOwningLayerGuid() const override; + const OutputHandler& GetOutputHandler() const { return m_OutputHandler; } OutputHandler& GetOutputHandler() { return m_OutputHandler; } @@ -141,7 +143,7 @@ public: return Disconnect(*boost::polymorphic_downcast<InputSlot*>(&slot)); } - unsigned int CalculateIndexOnOwner() const; + unsigned int CalculateIndexOnOwner() const override; bool operator==(const OutputSlot& other) const; diff --git a/src/armnn/LayerVisitorBase.hpp b/src/armnn/LayerVisitorBase.hpp deleted file mode 100644 index 2c37a21786..0000000000 --- a/src/armnn/LayerVisitorBase.hpp +++ /dev/null @@ -1,175 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#pragma once - -#include <armnn/ILayerVisitor.hpp> - -namespace armnn -{ - -struct VisitorThrowingPolicy -{ - static void Apply() { throw UnimplementedException(); } -}; - -struct VisitorNoThrowPolicy -{ - static void Apply() {} -}; - -// Visitor base class with empty implementations. -template<typename DefaultPolicy> -class LayerVisitorBase : public ILayerVisitor -{ -protected: - LayerVisitorBase() {} - virtual ~LayerVisitorBase() {} - -public: - void VisitInputLayer(const IConnectableLayer*, - LayerBindingId, - const char*) override { DefaultPolicy::Apply(); } - - void VisitConvolution2dLayer(const IConnectableLayer*, - const Convolution2dDescriptor&, - const ConstTensor&, - const Optional<ConstTensor>&, - const char*) override { DefaultPolicy::Apply(); } - - void VisitDepthwiseConvolution2dLayer(const IConnectableLayer*, - const DepthwiseConvolution2dDescriptor&, - const ConstTensor&, - const Optional<ConstTensor>&, - const char*) override { DefaultPolicy::Apply(); } - - void VisitDetectionPostProcessLayer(const IConnectableLayer*, - const DetectionPostProcessDescriptor&, - const ConstTensor&, - const char*) override { DefaultPolicy::Apply(); } - - void VisitFullyConnectedLayer(const IConnectableLayer*, - const FullyConnectedDescriptor&, - const ConstTensor&, - const Optional<ConstTensor>&, - const char*) override { DefaultPolicy::Apply(); } - - void VisitPermuteLayer(const IConnectableLayer*, - const PermuteDescriptor&, - const char*) override { DefaultPolicy::Apply(); } - - void VisitBatchToSpaceNdLayer(const IConnectableLayer*, - const BatchToSpaceNdDescriptor&, - const char*) override { DefaultPolicy::Apply(); } - - void VisitPooling2dLayer(const IConnectableLayer*, - const Pooling2dDescriptor&, - const char*) override { DefaultPolicy::Apply(); } - - void VisitActivationLayer(const IConnectableLayer*, - const ActivationDescriptor&, - const char*) override { DefaultPolicy::Apply(); } - - void VisitNormalizationLayer(const IConnectableLayer*, - const NormalizationDescriptor&, - const char*) override { DefaultPolicy::Apply(); } - - void VisitSoftmaxLayer(const IConnectableLayer*, - const SoftmaxDescriptor&, - const char*) override { DefaultPolicy::Apply(); } - - void VisitSplitterLayer(const IConnectableLayer*, - const ViewsDescriptor&, - const char*) override { DefaultPolicy::Apply(); } - - void VisitMergerLayer(const IConnectableLayer*, - const OriginsDescriptor&, - const char*) override { DefaultPolicy::Apply(); } - - void VisitAdditionLayer(const IConnectableLayer*, - const char*) override { DefaultPolicy::Apply(); } - - void VisitMultiplicationLayer(const IConnectableLayer*, - const char*) override { DefaultPolicy::Apply(); } - - void VisitBatchNormalizationLayer(const IConnectableLayer*, - const BatchNormalizationDescriptor&, - const ConstTensor&, - const ConstTensor&, - const ConstTensor&, - const ConstTensor&, - const char*) override { DefaultPolicy::Apply(); } - - void VisitResizeBilinearLayer(const IConnectableLayer*, - const ResizeBilinearDescriptor&, - const char*) override { DefaultPolicy::Apply(); } - - void VisitL2NormalizationLayer(const IConnectableLayer*, - const L2NormalizationDescriptor&, - const char*) override { DefaultPolicy::Apply(); } - - void VisitConstantLayer(const IConnectableLayer*, - const ConstTensor&, - const char*) override { DefaultPolicy::Apply(); } - - void VisitReshapeLayer(const IConnectableLayer*, - const ReshapeDescriptor&, - const char*) override { DefaultPolicy::Apply(); } - - void VisitSpaceToBatchNdLayer(const IConnectableLayer*, - const SpaceToBatchNdDescriptor&, - const char*) override { DefaultPolicy::Apply(); } - - void VisitFloorLayer(const IConnectableLayer*, - const char*) override { DefaultPolicy::Apply(); } - - void VisitOutputLayer(const IConnectableLayer*, - LayerBindingId id, - const char*) override { DefaultPolicy::Apply(); } - - void VisitLstmLayer(const IConnectableLayer*, - const LstmDescriptor&, - const LstmInputParams&, - const char*) override { DefaultPolicy::Apply(); } - - void VisitDivisionLayer(const IConnectableLayer*, - const char*) override { DefaultPolicy::Apply(); } - - void VisitSubtractionLayer(const IConnectableLayer*, - const char*) override { DefaultPolicy::Apply(); } - - void VisitMaximumLayer(const IConnectableLayer*, - const char*) override { DefaultPolicy::Apply(); } - - void VisitMeanLayer(const IConnectableLayer*, - const MeanDescriptor&, - const char*) override { DefaultPolicy::Apply(); } - - void VisitPadLayer(const IConnectableLayer*, - const PadDescriptor&, - const char*) override { DefaultPolicy::Apply(); } - - void VisitStridedSliceLayer(const IConnectableLayer*, - const StridedSliceDescriptor&, - const char*) override { DefaultPolicy::Apply(); } - - void VisitMinimumLayer(const IConnectableLayer*, - const char*) override { DefaultPolicy::Apply(); } - - void VisitGreaterLayer(const IConnectableLayer*, - const char*) override { DefaultPolicy::Apply(); } - - void VisitEqualLayer(const IConnectableLayer*, - const char*) override { DefaultPolicy::Apply(); } - - void VisitRsqrtLayer(const IConnectableLayer*, - const char*) override { DefaultPolicy::Apply(); } - - void VisitGatherLayer(const IConnectableLayer*, - const char*) override { DefaultPolicy::Apply(); } -}; - -} //namespace armnn - diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp index 5c70003785..cad1690cbd 100644 --- a/src/armnn/Network.cpp +++ b/src/armnn/Network.cpp @@ -917,6 +917,14 @@ IConnectableLayer* Network::AddGatherLayer(const char* name) return m_Graph->AddLayer<GatherLayer>(name); } +void Network::Accept(ILayerVisitor& visitor) const +{ + for (auto layer : GetGraph()) + { + layer->Accept(visitor); + }; +} + OptimizedNetwork::OptimizedNetwork(std::unique_ptr<Graph> graph) : m_Graph(std::move(graph)) { diff --git a/src/armnn/Network.hpp b/src/armnn/Network.hpp index 66fb240979..3754c2e6d1 100644 --- a/src/armnn/Network.hpp +++ b/src/armnn/Network.hpp @@ -148,6 +148,8 @@ public: IConnectableLayer* AddRsqrtLayer(const char* name = nullptr) override; + void Accept(ILayerVisitor& visitor) const override; + private: IConnectableLayer* AddFullyConnectedLayerImpl(const FullyConnectedDescriptor& fullyConnectedDescriptor, const ConstTensor& weights, diff --git a/src/armnn/OverrideInputRangeVisitor.hpp b/src/armnn/OverrideInputRangeVisitor.hpp index 0b1999f1f8..72396b4652 100644 --- a/src/armnn/OverrideInputRangeVisitor.hpp +++ b/src/armnn/OverrideInputRangeVisitor.hpp @@ -6,7 +6,7 @@ #pragma once #include "NetworkQuantizer.hpp" -#include "LayerVisitorBase.hpp" +#include "armnn/LayerVisitorBase.hpp" #include <unordered_map> diff --git a/src/armnn/QuantizerVisitor.hpp b/src/armnn/QuantizerVisitor.hpp index 44ebc052b6..cf151baf3c 100644 --- a/src/armnn/QuantizerVisitor.hpp +++ b/src/armnn/QuantizerVisitor.hpp @@ -5,7 +5,7 @@ #pragma once -#include "LayerVisitorBase.hpp" +#include "armnn/LayerVisitorBase.hpp" #include "StaticRangeVisitor.hpp" #include <armnn/INetwork.hpp> diff --git a/src/armnn/StaticRangeVisitor.hpp b/src/armnn/StaticRangeVisitor.hpp index 9c3a4f32c1..d834d0449d 100644 --- a/src/armnn/StaticRangeVisitor.hpp +++ b/src/armnn/StaticRangeVisitor.hpp @@ -5,7 +5,7 @@ #pragma once -#include "LayerVisitorBase.hpp" +#include "armnn/LayerVisitorBase.hpp" #include <armnn/INetwork.hpp> #include <armnn/INetworkQuantizer.hpp> diff --git a/src/armnn/test/QuantizerTest.cpp b/src/armnn/test/QuantizerTest.cpp index c83d179961..24c130c372 100644 --- a/src/armnn/test/QuantizerTest.cpp +++ b/src/armnn/test/QuantizerTest.cpp @@ -8,7 +8,7 @@ #include <armnn/INetworkQuantizer.hpp> #include <armnn/Types.hpp> -#include "../LayerVisitorBase.hpp" +#include "armnn/LayerVisitorBase.hpp" #include "../Network.hpp" #include "../Graph.hpp" #include "../NetworkQuantizerUtils.hpp" diff --git a/src/armnnDeserializeParser/DeserializeParser.cpp b/src/armnnDeserializeParser/DeserializeParser.cpp index ca2e7e3167..5ba92d51e2 100644 --- a/src/armnnDeserializeParser/DeserializeParser.cpp +++ b/src/armnnDeserializeParser/DeserializeParser.cpp @@ -226,7 +226,7 @@ DeserializeParser::LayerBaseRawPtrVector DeserializeParser::GetGraphInputs(const for (unsigned int i=0; i<numInputs; ++i) { - uint32_t inputId = CHECKED_NON_NEGATIVE(graphPtr->inputIds()->Get(i)); + uint32_t inputId = graphPtr->inputIds()->Get(i); result[i] = GetBaseLayer(graphPtr, static_cast<uint32_t>(inputId)); } return result; @@ -241,7 +241,7 @@ DeserializeParser::LayerBaseRawPtrVector DeserializeParser::GetGraphOutputs(cons for (unsigned int i=0; i<numOutputs; ++i) { - uint32_t outputId = CHECKED_NON_NEGATIVE(graphPtr->outputIds()->Get(i)); + uint32_t outputId = graphPtr->outputIds()->Get(i); result[i] = GetBaseLayer(graphPtr, static_cast<uint32_t>(outputId)); } return result; diff --git a/src/armnnSerializer/README.md b/src/armnnSerializer/README.md new file mode 100644 index 0000000000..61478b1470 --- /dev/null +++ b/src/armnnSerializer/README.md @@ -0,0 +1,6 @@ +# The Arm NN Serializer + +The `armnnSerializer` is a library for serializing an Arm NN network to a stream. + +For more information about the layers that are supported, and the networks that have been tested, +see [SerializerSupport.md](./SerializerSupport.md) diff --git a/src/armnnSerializer/Schema.fbs b/src/armnnSerializer/Schema.fbs index 2a5fbcd2ad..2527f6d0f6 100644 --- a/src/armnnSerializer/Schema.fbs +++ b/src/armnnSerializer/Schema.fbs @@ -108,8 +108,8 @@ table AnyLayer { // Root type for serialized data is the graph of the network table SerializedGraph { layers:[AnyLayer]; - inputIds:[int]; - outputIds:[int]; + inputIds:[uint]; + outputIds:[uint]; } root_type SerializedGraph; diff --git a/src/armnnSerializer/SeralizerSupport.md b/src/armnnSerializer/SeralizerSupport.md new file mode 100644 index 0000000000..16d1940be0 --- /dev/null +++ b/src/armnnSerializer/SeralizerSupport.md @@ -0,0 +1,11 @@ +# The layers that ArmNN SDK Serializer currently supports. + +This reference guide provides a list of layers which can be serialized currently by the Arm NN SDK. + +## Fully supported + +The Arm NN SDK Serializer currently supports the following layers: + +* Addition + +More machine learning layers will be supported in future releases.
\ No newline at end of file diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp new file mode 100644 index 0000000000..57baf0e28c --- /dev/null +++ b/src/armnnSerializer/Serializer.cpp @@ -0,0 +1,186 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "Serializer.hpp" +#include <armnn/ArmNN.hpp> +#include <iostream> +#include <Schema_generated.h> +#include <flatbuffers/util.h> + +using namespace armnn; +namespace fb = flatbuffers; +namespace serializer = armnn::armnnSerializer; + +namespace armnnSerializer +{ + +serializer::DataType GetFlatBufferDataType(DataType dataType) +{ + switch (dataType) + { + case DataType::Float32: + return serializer::DataType::DataType_Float32; + case DataType::Float16: + return serializer::DataType::DataType_Float16; + case DataType::Signed32: + return serializer::DataType::DataType_Signed32; + case DataType::QuantisedAsymm8: + return serializer::DataType::DataType_QuantisedAsymm8; + case DataType::Boolean: + return serializer::DataType::DataType_Boolean; + default: + return serializer::DataType::DataType_Float16; + } +} + +// Build FlatBuffer for Input Layer +void Serializer::VisitInputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name) +{ + // Create FlatBuffer BaseLayer + auto flatBufferInputBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Input); + + // Create FlatBuffer BindableBaseLayer + auto flatBufferInputBindableBaseLayer = serializer::CreateBindableLayerBase(m_flatBufferBuilder, + flatBufferInputBaseLayer, + id); + + // Push layer Guid to outputIds. + m_inputIds.push_back(layer->GetGuid()); + + // Create the FlatBuffer InputLayer + auto flatBufferInputLayer = serializer::CreateInputLayer(m_flatBufferBuilder, flatBufferInputBindableBaseLayer); + + // Add the AnyLayer to the FlatBufferLayers + CreateAnyLayer(flatBufferInputLayer.o, serializer::Layer::Layer_InputLayer); +} + +// Build FlatBuffer for Output Layer +void Serializer::VisitOutputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name) +{ + // Create FlatBuffer BaseLayer + auto flatBufferOutputBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Output); + + // Create FlatBuffer BindableBaseLayer + auto flatBufferOutputBindableBaseLayer = serializer::CreateBindableLayerBase(m_flatBufferBuilder, + flatBufferOutputBaseLayer, + id); + // Push layer Guid to outputIds. + m_outputIds.push_back(layer->GetGuid()); + + // Create the FlatBuffer OutputLayer + auto flatBufferOutputLayer = serializer::CreateOutputLayer(m_flatBufferBuilder, flatBufferOutputBindableBaseLayer); + // Add the AnyLayer to the FlatBufferLayers + CreateAnyLayer(flatBufferOutputLayer.o, serializer::Layer::Layer_OutputLayer); +} + +// Build FlatBuffer for Addition Layer +void Serializer::VisitAdditionLayer(const IConnectableLayer* layer, const char* name) +{ + // Create FlatBuffer BaseLayer + auto flatBufferAdditionBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Addition); + + // Create the FlatBuffer AdditionLayer + auto flatBufferAdditionLayer = serializer::CreateAdditionLayer(m_flatBufferBuilder, flatBufferAdditionBaseLayer); + + // Add the AnyLayer to the FlatBufferLayers + CreateAnyLayer(flatBufferAdditionLayer.o, serializer::Layer::Layer_AdditionLayer); +} + +void Serializer::Serialize(const INetwork& inNetwork) +{ + // Iterate through to network + inNetwork.Accept(*this); + + // Create FlatBuffer SerializedGraph + auto serializedGraph = serializer::CreateSerializedGraph(m_flatBufferBuilder, + m_flatBufferBuilder.CreateVector(m_serializedLayers), + m_flatBufferBuilder.CreateVector(m_inputIds), + m_flatBufferBuilder.CreateVector(m_outputIds)); + + // Serialize the graph + m_flatBufferBuilder.Finish(serializedGraph); +} + +bool Serializer::SaveSerializedToStream(std::ostream& stream) +{ + stream.write(reinterpret_cast<const char*>(m_flatBufferBuilder.GetBufferPointer()), m_flatBufferBuilder.GetSize()); + return !stream.bad(); +} + +fb::Offset<serializer::LayerBase> Serializer::CreateLayerBase(const IConnectableLayer* layer, + const serializer::LayerType layerType) +{ + std::vector<fb::Offset<serializer::InputSlot>> inputSlots = CreateInputSlots(layer); + std::vector<fb::Offset<serializer::OutputSlot>> outputSlots = CreateOutputSlots(layer); + + return serializer::CreateLayerBase(m_flatBufferBuilder, + layer->GetGuid(), + m_flatBufferBuilder.CreateString(layer->GetName()), + layerType, + m_flatBufferBuilder.CreateVector(inputSlots), + m_flatBufferBuilder.CreateVector(outputSlots)); +} + +void Serializer::CreateAnyLayer(const flatbuffers::Offset<void>& layer, const serializer::Layer serializerLayer) +{ + auto anyLayer = armnn::armnnSerializer::CreateAnyLayer(m_flatBufferBuilder, + serializerLayer, + layer); + m_serializedLayers.push_back(anyLayer); +} + +std::vector<fb::Offset<serializer::InputSlot>> Serializer::CreateInputSlots(const IConnectableLayer* layer) +{ + std::vector<fb::Offset <serializer::InputSlot>> inputSlots; + + // Get the InputSlots + for (unsigned int slotIndex = 0; slotIndex<layer->GetNumInputSlots(); ++slotIndex) + { + const IInputSlot& inputSlot = layer->GetInputSlot(slotIndex); + + // Get the Connection for the InputSlot + const IOutputSlot* connection = inputSlot.GetConnection(); + + // Create FlatBuffer Connection + serializer::Connection conn(connection->GetOwningLayerGuid(), connection->CalculateIndexOnOwner()); + // Create FlatBuffer InputSlot + inputSlots.push_back(serializer::CreateInputSlot(m_flatBufferBuilder, slotIndex, &conn)); + } + return inputSlots; +} + +std::vector<fb::Offset<serializer::OutputSlot>> Serializer::CreateOutputSlots(const IConnectableLayer* layer) +{ + std::vector<fb::Offset<serializer::OutputSlot>> outputSlots; + + // Get the OutputSlots + for (unsigned int slotIndex = 0; slotIndex < layer->GetNumOutputSlots(); ++slotIndex) + { + const IOutputSlot& outputSlot = layer->GetOutputSlot(slotIndex); + const TensorInfo& tensorInfo = outputSlot.GetTensorInfo(); + + // Get the dimensions + std::vector<unsigned int> shape; + for(unsigned int dim = 0; dim < tensorInfo.GetShape().GetNumDimensions(); ++dim) + { + shape.push_back(tensorInfo.GetShape()[dim]); + } + + // Create FlatBuffer TensorInfo + auto flatBufferTensorInfo = serializer::CreateTensorInfo(m_flatBufferBuilder, + m_flatBufferBuilder.CreateVector(shape), + GetFlatBufferDataType(tensorInfo.GetDataType()), + tensorInfo.GetQuantizationScale(), + tensorInfo.GetQuantizationOffset()); + + // Create FlatBuffer Outputslot + outputSlots.push_back(serializer::CreateOutputSlot(m_flatBufferBuilder, + slotIndex, + flatBufferTensorInfo)); + } + return outputSlots; +} + +} //namespace armnnSerializer diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp new file mode 100644 index 0000000000..697e5cfaa7 --- /dev/null +++ b/src/armnnSerializer/Serializer.hpp @@ -0,0 +1,72 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// +#pragma once + +#include <armnn/ILayerVisitor.hpp> +#include <armnn/LayerVisitorBase.hpp> +#include <iostream> +#include <Schema_generated.h> + +namespace armnnSerializer +{ + +class Serializer : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy> +{ +public: + Serializer() {}; + ~Serializer() {}; + + void VisitAdditionLayer(const armnn::IConnectableLayer* layer, + const char* name = nullptr) override; + + void VisitInputLayer(const armnn::IConnectableLayer* layer, + armnn::LayerBindingId id, + const char* name = nullptr) override; + + void VisitOutputLayer(const armnn::IConnectableLayer* layer, + armnn::LayerBindingId id, + const char* name = nullptr) override; + + /// Serializes the network to ArmNN SerializedGraph. + /// @param [in] inNetwork The network to be serialized. + void Serialize(const armnn::INetwork& inNetwork); + + /// Serializes the SerializedGraph to the stream. + /// @param [stream] the stream to save to + /// @return true if graph is Serialized to the Stream, false otherwise + bool SaveSerializedToStream(std::ostream& stream); + +private: + + /// Creates the Input Slots and Output Slots and LayerBase for the layer. + flatbuffers::Offset<armnn::armnnSerializer::LayerBase> CreateLayerBase( + const armnn::IConnectableLayer* layer, + const armnn::armnnSerializer::LayerType layerType); + + /// Creates the serializer AnyLayer for the layer and adds it to m_serializedLayers. + void CreateAnyLayer(const flatbuffers::Offset<void>& layer, const armnn::armnnSerializer::Layer serializerLayer); + + /// Creates the serializer InputSlots for the layer. + std::vector<flatbuffers::Offset<armnn::armnnSerializer::InputSlot>> CreateInputSlots( + const armnn::IConnectableLayer* layer); + + /// Creates the serializer OutputSlots for the layer. + std::vector<flatbuffers::Offset<armnn::armnnSerializer::OutputSlot>> CreateOutputSlots( + const armnn::IConnectableLayer* layer); + + /// FlatBufferBuilder to create our layers' FlatBuffers. + flatbuffers::FlatBufferBuilder m_flatBufferBuilder; + + /// AnyLayers required by the SerializedGraph. + std::vector<flatbuffers::Offset<armnn::armnnSerializer::AnyLayer>> m_serializedLayers; + + /// Guids of all Input Layers required by the SerializedGraph. + std::vector<unsigned int> m_inputIds; + + /// Guids of all Output Layers required by the SerializedGraph. + std::vector<unsigned int> m_outputIds; +}; + +} //namespace armnnSerializer diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp new file mode 100644 index 0000000000..17ad6e3695 --- /dev/null +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -0,0 +1,35 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include <armnn/ArmNN.hpp> +#include <armnn/INetwork.hpp> +#include "../Serializer.hpp" +#include <sstream> +#include <boost/test/unit_test.hpp> + +BOOST_AUTO_TEST_SUITE(SerializerTests) + +BOOST_AUTO_TEST_CASE(SimpleNetworkSerialization) +{ + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer* const inputLayer0 = network->AddInputLayer(0); + armnn::IConnectableLayer* const inputLayer1 = network->AddInputLayer(1); + + armnn::IConnectableLayer* const additionLayer0 = network->AddAdditionLayer(); + inputLayer0->GetOutputSlot(0).Connect(additionLayer0->GetInputSlot(0)); + inputLayer1->GetOutputSlot(0).Connect(additionLayer0->GetInputSlot(1)); + + armnn::IConnectableLayer* const outputLayer0 = network->AddOutputLayer(0); + additionLayer0->GetOutputSlot(0).Connect(outputLayer0->GetInputSlot(0)); + + armnnSerializer::Serializer serializer; + serializer.Serialize(*network); + + std::stringstream stream; + serializer.SaveSerializedToStream(stream); + BOOST_TEST(stream.str().length() > 0); +} + +BOOST_AUTO_TEST_SUITE_END() |