diff options
Diffstat (limited to 'src/armnnSerializer')
-rw-r--r-- | src/armnnSerializer/ArmnnSchema.fbs | 17 | ||||
-rw-r--r-- | src/armnnSerializer/Serializer.cpp | 13 | ||||
-rw-r--r-- | src/armnnSerializer/SerializerSupport.md | 1 | ||||
-rw-r--r-- | src/armnnSerializer/test/SerializerTests.cpp | 29 |
4 files changed, 53 insertions, 7 deletions
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs index 532c12c706..18415ce785 100644 --- a/src/armnnSerializer/ArmnnSchema.fbs +++ b/src/armnnSerializer/ArmnnSchema.fbs @@ -156,7 +156,8 @@ enum LayerType : uint { StandIn = 53, ElementwiseUnary = 54, Transpose = 55, - QLstm = 56 + QLstm = 56, + Fill = 57 } // Base layer table to be used as part of other layers @@ -284,6 +285,15 @@ table EqualLayer { base:LayerBase; } +table FillLayer { + base:LayerBase; + descriptor:FillDescriptor; +} + +table FillDescriptor { + value:float; +} + table FloorLayer{ base:LayerBase; } @@ -901,7 +911,8 @@ union Layer { StandInLayer, ElementwiseUnaryLayer, TransposeLayer, - QLstmLayer + QLstmLayer, + FillLayer } table AnyLayer { @@ -920,4 +931,4 @@ table SerializedGraph { featureVersions:FeatureCompatibilityVersions; } -root_type SerializedGraph; +root_type SerializedGraph;
\ No newline at end of file diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index ddd38e18ef..17076c62ab 100644 --- a/src/armnnSerializer/Serializer.cpp +++ b/src/armnnSerializer/Serializer.cpp @@ -472,10 +472,15 @@ void SerializerVisitor::VisitFillLayer(const armnn::IConnectableLayer* layer, const armnn::FillDescriptor& fillDescriptor, const char* name) { - throw UnimplementedException("SerializerVisitor::VisitFillLayer is not implemented"); IgnoreUnused(name); - IgnoreUnused(layer); - IgnoreUnused(fillDescriptor); + + auto fbFillBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Fill); + + auto fbDescriptor = serializer::CreateFillDescriptor(m_flatBufferBuilder, fillDescriptor.m_Value); + + auto fbFillLayer = serializer::CreateFillLayer(m_flatBufferBuilder, fbFillBaseLayer, fbDescriptor); + + CreateAnyLayer(fbFillLayer.o, serializer::Layer::Layer_FillLayer); } void SerializerVisitor::VisitFloorLayer(const armnn::IConnectableLayer *layer, const char *name) @@ -1726,4 +1731,4 @@ bool Serializer::SaveSerializedToStream(std::ostream& stream) return !stream.bad(); } -} // namespace armnnSerializer +} // namespace armnnSerializer
\ No newline at end of file diff --git a/src/armnnSerializer/SerializerSupport.md b/src/armnnSerializer/SerializerSupport.md index f8d86551ae..4f7868bee7 100644 --- a/src/armnnSerializer/SerializerSupport.md +++ b/src/armnnSerializer/SerializerSupport.md @@ -21,6 +21,7 @@ The Arm NN SDK Serializer currently supports the following layers: * DetectionPostProcess * Division * ElementwiseUnary +* Fill * Floor * FullyConnected * Gather diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index e7f93c6740..fa43e09647 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -1199,6 +1199,35 @@ BOOST_AUTO_TEST_CASE(EnsureEqualBackwardCompatibility) deserializedNetwork->Accept(verifier); } +BOOST_AUTO_TEST_CASE(SerializeFill) +{ + DECLARE_LAYER_VERIFIER_CLASS_WITH_DESCRIPTOR(Fill) + + const std::string layerName("fill"); + const armnn::TensorInfo inputInfo({4}, armnn::DataType::Float32); + const armnn::TensorInfo outputInfo({1, 3, 3, 1}, armnn::DataType::Float32); + + armnn::FillDescriptor descriptor(1.0f); + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0); + armnn::IConnectableLayer* const fillLayer = network->AddFillLayer(descriptor, layerName.c_str()); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0); + + inputLayer->GetOutputSlot(0).Connect(fillLayer->GetInputSlot(0)); + fillLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + + inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo); + fillLayer->GetOutputSlot(0).SetTensorInfo(outputInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + BOOST_CHECK(deserializedNetwork); + + FillLayerVerifier verifier(layerName, {inputInfo}, {outputInfo}, descriptor); + + deserializedNetwork->Accept(verifier); +} + BOOST_AUTO_TEST_CASE(SerializeFloor) { DECLARE_LAYER_VERIFIER_CLASS(Floor) |