diff options
author | Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> | 2019-06-21 13:53:38 +0100 |
---|---|---|
committer | Áron Virginás-Tar <aron.virginas-tar@arm.com> | 2019-06-21 14:30:08 +0000 |
commit | cb549301bc4c5a405e02c1f433537557423d2e36 (patch) | |
tree | 8f9bf05911c05e673e0b90b68c29164e47e8609c /src/armnnSerializer | |
parent | 0dcffec80292cd2e0e7c2736fd3db63abd7c3f64 (diff) | |
download | armnn-cb549301bc4c5a405e02c1f433537557423d2e36.tar.gz |
IVGCVSW-3321 Add serialization support for TransposeConvolution2dLayer
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Change-Id: If0c8f3662d5e03696f97040abed784c0fbcdbc6f
Diffstat (limited to 'src/armnnSerializer')
-rw-r--r-- | src/armnnSerializer/ArmnnSchema.fbs | 24 | ||||
-rw-r--r-- | src/armnnSerializer/Serializer.cpp | 27 | ||||
-rw-r--r-- | src/armnnSerializer/SerializerSupport.md | 1 | ||||
-rw-r--r-- | src/armnnSerializer/test/SerializerTests.cpp | 104 |
4 files changed, 153 insertions, 3 deletions
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs index 794789390a..83275ca248 100644 --- a/src/armnnSerializer/ArmnnSchema.fbs +++ b/src/armnnSerializer/ArmnnSchema.fbs @@ -123,7 +123,8 @@ enum LayerType : uint { Switch = 38, Concat = 39, SpaceToDepth = 40, - Prelu = 41 + Prelu = 41, + TransposeConvolution2d = 42 } // Base layer table to be used as part of other layers @@ -561,6 +562,24 @@ table PreluLayer { base:LayerBase; } +table TransposeConvolution2dLayer { + base:LayerBase; + descriptor:TransposeConvolution2dDescriptor; + weights:ConstTensor; + biases:ConstTensor; +} + +table TransposeConvolution2dDescriptor { + padLeft:uint; + padRight:uint; + padTop:uint; + padBottom:uint; + strideX:uint; + strideY:uint; + biasEnabled:bool = false; + dataLayout:DataLayout = NCHW; +} + union Layer { ActivationLayer, AdditionLayer, @@ -603,7 +622,8 @@ union Layer { SwitchLayer, ConcatLayer, SpaceToDepthLayer, - PreluLayer + PreluLayer, + TransposeConvolution2dLayer } table AnyLayer { diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index efadbb38ca..126247bb8c 100644 --- a/src/armnnSerializer/Serializer.cpp +++ b/src/armnnSerializer/Serializer.cpp @@ -952,7 +952,32 @@ void SerializerVisitor::VisitTransposeConvolution2dLayer( const armnn::Optional<armnn::ConstTensor>& biases, const char* name) { - throw UnimplementedException("SerializerVisitor::VisitTransposeConvolution2dLayer is not implemented"); + auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Convolution2d); + auto fbDescriptor = CreateTransposeConvolution2dDescriptor(m_flatBufferBuilder, + descriptor.m_PadLeft, + descriptor.m_PadRight, + descriptor.m_PadTop, + descriptor.m_PadBottom, + descriptor.m_StrideX, + descriptor.m_StrideY, + descriptor.m_BiasEnabled, + GetFlatBufferDataLayout(descriptor.m_DataLayout)); + + // weights & biases + auto fbWeightsConstTensorInfo = CreateConstTensorInfo(weights); + flatbuffers::Offset<serializer::ConstTensor> fbBiasesConstTensorInfo; + if (biases.has_value()) + { + fbBiasesConstTensorInfo = CreateConstTensorInfo(biases.value()); + } + + auto fbLayer = CreateTransposeConvolution2dLayer(m_flatBufferBuilder, + fbBaseLayer, + fbDescriptor, + fbWeightsConstTensorInfo, + fbBiasesConstTensorInfo); + + CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_TransposeConvolution2dLayer); } fb::Offset<serializer::LayerBase> SerializerVisitor::CreateLayerBase(const IConnectableLayer* layer, diff --git a/src/armnnSerializer/SerializerSupport.md b/src/armnnSerializer/SerializerSupport.md index e19eb32639..99bc332ac8 100644 --- a/src/armnnSerializer/SerializerSupport.md +++ b/src/armnnSerializer/SerializerSupport.md @@ -45,6 +45,7 @@ The Arm NN SDK Serializer currently supports the following layers: * StridedSlice * Subtraction * Switch +* TransposeConvolution2d More machine learning layers will be supported in future releases. diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index ddebd1435c..448778b118 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -2446,6 +2446,110 @@ BOOST_AUTO_TEST_CASE(SerializeSwitch) deserializedNetwork->Accept(verifier); } +BOOST_AUTO_TEST_CASE(SerializeTransposeConvolution2d) +{ + class TransposeConvolution2dLayerVerifier : public LayerVerifierBase + { + public: + TransposeConvolution2dLayerVerifier(const std::string& layerName, + const std::vector<armnn::TensorInfo>& inputInfos, + const std::vector<armnn::TensorInfo>& outputInfos, + const armnn::TransposeConvolution2dDescriptor& descriptor, + const armnn::ConstTensor& weights, + const armnn::Optional<armnn::ConstTensor>& biases) : + LayerVerifierBase(layerName, inputInfos, outputInfos), + m_Descriptor(descriptor), + m_Weights(weights), + m_Biases(biases) + {} + + void VisitTransposeConvolution2dLayer(const armnn::IConnectableLayer* layer, + const armnn::TransposeConvolution2dDescriptor& descriptor, + const armnn::ConstTensor& weights, + const armnn::Optional<armnn::ConstTensor>& biases, + const char* name) override + { + VerifyNameAndConnections(layer, name); + VerifyDescriptor(descriptor); + + // check weights + CompareConstTensor(weights, m_Weights); + + // check biases + BOOST_CHECK(biases.has_value() == descriptor.m_BiasEnabled); + BOOST_CHECK(m_Biases.has_value() == m_Descriptor.m_BiasEnabled); + + BOOST_CHECK(biases.has_value() == m_Biases.has_value()); + + if (biases.has_value() && m_Biases.has_value()) + { + CompareConstTensor(biases.value(), m_Biases.value()); + } + } + + private: + void VerifyDescriptor(const armnn::TransposeConvolution2dDescriptor& descriptor) + { + BOOST_CHECK(descriptor.m_PadLeft == m_Descriptor.m_PadLeft); + BOOST_CHECK(descriptor.m_PadRight == m_Descriptor.m_PadRight); + BOOST_CHECK(descriptor.m_PadTop == m_Descriptor.m_PadTop); + BOOST_CHECK(descriptor.m_PadBottom == m_Descriptor.m_PadBottom); + BOOST_CHECK(descriptor.m_StrideX == m_Descriptor.m_StrideX); + BOOST_CHECK(descriptor.m_StrideY == m_Descriptor.m_StrideY); + BOOST_CHECK(descriptor.m_BiasEnabled == m_Descriptor.m_BiasEnabled); + BOOST_CHECK(descriptor.m_DataLayout == m_Descriptor.m_DataLayout); + } + + armnn::TransposeConvolution2dDescriptor m_Descriptor; + armnn::ConstTensor m_Weights; + armnn::Optional<armnn::ConstTensor> m_Biases; + }; + + const std::string layerName("transposeConvolution2d"); + const armnn::TensorInfo inputInfo ({ 1, 7, 7, 1 }, armnn::DataType::Float32); + const armnn::TensorInfo outputInfo({ 1, 9, 9, 1 }, armnn::DataType::Float32); + + const armnn::TensorInfo weightsInfo({ 1, 3, 3, 1 }, armnn::DataType::Float32); + const armnn::TensorInfo biasesInfo ({ 1 }, armnn::DataType::Float32); + + std::vector<float> weightsData = GenerateRandomData<float>(weightsInfo.GetNumElements()); + armnn::ConstTensor weights(weightsInfo, weightsData); + + std::vector<float> biasesData = GenerateRandomData<float>(biasesInfo.GetNumElements()); + armnn::ConstTensor biases(biasesInfo, biasesData); + + armnn::TransposeConvolution2dDescriptor descriptor; + descriptor.m_PadLeft = 1; + descriptor.m_PadRight = 1; + descriptor.m_PadTop = 1; + descriptor.m_PadBottom = 1; + descriptor.m_StrideX = 1; + descriptor.m_StrideY = 1; + descriptor.m_BiasEnabled = true; + descriptor.m_DataLayout = armnn::DataLayout::NHWC; + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0); + armnn::IConnectableLayer* const convLayer = + network->AddTransposeConvolution2dLayer(descriptor, + weights, + armnn::Optional<armnn::ConstTensor>(biases), + layerName.c_str()); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0); + + inputLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(0)); + convLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + + inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo); + convLayer->GetOutputSlot(0).SetTensorInfo(outputInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + BOOST_CHECK(deserializedNetwork); + + TransposeConvolution2dLayerVerifier verifier(layerName, {inputInfo}, {outputInfo}, descriptor, weights, biases); + deserializedNetwork->Accept(verifier); +} + BOOST_AUTO_TEST_CASE(SerializeDeserializeNonLinearNetwork) { class ConstantLayerVerifier : public LayerVerifierBase |