diff options
Diffstat (limited to 'src/armnnSerializer')
-rw-r--r-- | src/armnnSerializer/Schema.fbs | 16 | ||||
-rw-r--r-- | src/armnnSerializer/Serializer.cpp | 39 | ||||
-rw-r--r-- | src/armnnSerializer/Serializer.hpp | 6 | ||||
-rw-r--r-- | src/armnnSerializer/test/SerializerTests.cpp | 41 |
4 files changed, 100 insertions, 2 deletions
diff --git a/src/armnnSerializer/Schema.fbs b/src/armnnSerializer/Schema.fbs index 94ca23b0cd..dc14069798 100644 --- a/src/armnnSerializer/Schema.fbs +++ b/src/armnnSerializer/Schema.fbs @@ -91,7 +91,8 @@ enum LayerType : uint { Convolution2d = 7, DepthwiseConvolution2d = 8, Activation = 9, - Permute = 10 + Permute = 10, + FullyConnected = 11 } // Base layer table to be used as part of other layers @@ -142,6 +143,18 @@ table Convolution2dDescriptor { dataLayout:DataLayout = NCHW; } +table FullyConnectedLayer { + base:LayerBase; + descriptor:FullyConnectedDescriptor; + weights:ConstTensor; + biases:ConstTensor; +} + +table FullyConnectedDescriptor { + biasEnabled:bool = false; + transposeWeightsMatrix:bool = false; +} + table InputLayer { base:BindableLayerBase; } @@ -240,6 +253,7 @@ union Layer { AdditionLayer, Convolution2dLayer, DepthwiseConvolution2dLayer, + FullyConnectedLayer, InputLayer, MultiplicationLayer, OutputLayer, diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index e1d22ec406..b4afd37b99 100644 --- a/src/armnnSerializer/Serializer.cpp +++ b/src/armnnSerializer/Serializer.cpp @@ -324,7 +324,44 @@ void SerializerVisitor::VisitPooling2dLayer(const armnn::IConnectableLayer* laye CreateAnyLayer(fbPooling2dLayer.o, serializer::Layer::Layer_Pooling2dLayer); } -fb::Offset<serializer::LayerBase> SerializerVisitor::CreateLayerBase(const armnn::IConnectableLayer* layer, +// Build FlatBuffer for FullyConnected Layer +void SerializerVisitor::VisitFullyConnectedLayer(const armnn::IConnectableLayer* layer, + const armnn::FullyConnectedDescriptor& fullyConnectedDescriptor, + const armnn::ConstTensor& weights, + const armnn::Optional<armnn::ConstTensor>& biases, + const char* name) +{ + // Create FlatBuffer BaseLayer + auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_FullyConnected); + + // Create FlatBuffer FullyConnectedDescriptor + auto flatBufferDescriptor = + serializer::CreateFullyConnectedDescriptor(m_flatBufferBuilder, + fullyConnectedDescriptor.m_BiasEnabled, + fullyConnectedDescriptor.m_TransposeWeightMatrix); + + // Create FlatBuffer weights data + auto flatBufferWeights = CreateConstTensorInfo(weights); + + // Create FlatBuffer bias data + flatbuffers::Offset<serializer::ConstTensor> flatBufferBiases; + if (fullyConnectedDescriptor.m_BiasEnabled) + { + flatBufferBiases = CreateConstTensorInfo(biases.value()); + } + + // Create FlatBuffer FullyConnectedLayer + auto flatBufferLayer = serializer::CreateFullyConnectedLayer(m_flatBufferBuilder, + flatBufferBaseLayer, + flatBufferDescriptor, + flatBufferWeights, + flatBufferBiases); + + // Add created FullyConnectedLayer to the FlatBufferLayers + CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_FullyConnectedLayer); +} + +fb::Offset<serializer::LayerBase> SerializerVisitor::CreateLayerBase(const IConnectableLayer* layer, const serializer::LayerType layerType) { std::vector<fb::Offset<serializer::InputSlot>> inputSlots = CreateInputSlots(layer); diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp index 329b005624..0a62732ef2 100644 --- a/src/armnnSerializer/Serializer.hpp +++ b/src/armnnSerializer/Serializer.hpp @@ -61,6 +61,12 @@ public: const armnn::Optional<armnn::ConstTensor>& biases, const char* name = nullptr) override; + void VisitFullyConnectedLayer(const armnn::IConnectableLayer* layer, + const armnn::FullyConnectedDescriptor& fullyConnectedDescriptor, + const armnn::ConstTensor& weights, + const armnn::Optional<armnn::ConstTensor>& biases, + const char* name = nullptr) override; + void VisitInputLayer(const armnn::IConnectableLayer* layer, armnn::LayerBindingId id, const char* name = nullptr) override; diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index 822f9c7e00..ede24baf9e 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -404,4 +404,45 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializePermute) outputTensorInfo.GetShape()); } +BOOST_AUTO_TEST_CASE(SerializeDeserializeFullyConnected) +{ + armnn::TensorInfo inputInfo ({ 2, 5, 1, 1 }, armnn::DataType::Float32); + armnn::TensorInfo outputInfo({ 2, 3 }, armnn::DataType::Float32); + + armnn::TensorInfo weightsInfo({ 5, 3 }, armnn::DataType::Float32); + armnn::TensorInfo biasesInfo ({ 3 }, armnn::DataType::Float32); + + armnn::FullyConnectedDescriptor descriptor; + descriptor.m_BiasEnabled = true; + descriptor.m_TransposeWeightMatrix = false; + + std::vector<float> weightsData = GenerateRandomData<float>(weightsInfo.GetNumElements()); + std::vector<float> biasesData = GenerateRandomData<float>(biasesInfo.GetNumElements()); + + armnn::ConstTensor weights(weightsInfo, weightsData); + armnn::ConstTensor biases(biasesInfo, biasesData); + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0, "input"); + armnn::IConnectableLayer* const fullyConnectedLayer = network->AddFullyConnectedLayer(descriptor, + weights, + biases, + "fully_connected"); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0, "output"); + + inputLayer->GetOutputSlot(0).Connect(fullyConnectedLayer->GetInputSlot(0)); + inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo); + + fullyConnectedLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + fullyConnectedLayer->GetOutputSlot(0).SetTensorInfo(outputInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + BOOST_CHECK(deserializedNetwork); + + CheckDeserializedNetworkAgainstOriginal(*network, + *deserializedNetwork, + inputInfo.GetShape(), + outputInfo.GetShape()); +} + BOOST_AUTO_TEST_SUITE_END() |