From 18ce338711fc3ea44a7731eac795964256beac6c Mon Sep 17 00:00:00 2001 From: Jim Flynn Date: Fri, 8 Mar 2019 11:08:30 +0000 Subject: IVGCVSW-2709 Serialize / de-serialize the Splitter layer * fixed typo in Ref Merger Workload comment * fixed typo in ViewsDescriptor comment * made the origins descriptor accessable in the ViewsDescriptor (needed for serialization) * based the unit test on the use of the splitter in the CaffeParser Change-Id: I3e716839adb4eee5a695633377b49e7e18ec2aa9 Signed-off-by: Ferran Balaguer Signed-off-by: Francis Murtagh Signed-off-by: Jim Flynn --- src/armnnSerializer/test/SerializerTests.cpp | 79 +++++++++++++++++++++++++++- 1 file changed, 77 insertions(+), 2 deletions(-) (limited to 'src/armnnSerializer/test') diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index 069b9d699c..41f5d14ce3 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -473,8 +473,8 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeL2Normalization) VerifyL2NormalizationName nameChecker(l2NormLayerName); deserializedNetwork->Accept(nameChecker); - CheckDeserializedNetworkAgainstOriginal(*network, - *deserializedNetwork, + CheckDeserializedNetworkAgainstOriginal(*deserializedNetwork, + *network, { info.GetShape() }, { info.GetShape() }); } @@ -1520,4 +1520,79 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeMerger) {0, 1}); } +BOOST_AUTO_TEST_CASE(SerializeDeserializeSplitter) +{ + class VerifySplitterName : public armnn::LayerVisitorBase + { + public: + void VisitSplitterLayer(const armnn::IConnectableLayer*, + const armnn::ViewsDescriptor& viewsDescriptor, + const char* name) override + { + BOOST_TEST(name == "splitter"); + } + }; + + unsigned int numViews = 3; + unsigned int numDimensions = 4; + unsigned int inputShape[] = {1,18, 4, 4}; + unsigned int outputShape[] = {1, 6, 4, 4}; + + auto inputTensorInfo = armnn::TensorInfo(numDimensions, inputShape, armnn::DataType::Float32); + auto outputTensorInfo = armnn::TensorInfo(numDimensions, outputShape, armnn::DataType::Float32); + + // This is modelled on how the caffe parser sets up a splitter layer to partition an input + // along dimension one. + unsigned int splitterDimSizes[4] = {static_cast(inputShape[0]), + static_cast(inputShape[1]), + static_cast(inputShape[2]), + static_cast(inputShape[3])}; + splitterDimSizes[1] /= numViews; + armnn::ViewsDescriptor desc(numViews, numDimensions); + + for (unsigned int g = 0; g < numViews; ++g) + { + desc.SetViewOriginCoord(g, 1, splitterDimSizes[1] * g); + + // Set the size of the views. + for (unsigned int dimIdx=0; dimIdx < 4; dimIdx++) + { + desc.SetViewSize(g, dimIdx, splitterDimSizes[dimIdx]); + } + } + + const char* splitterLayerName = "splitter"; + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0); + armnn::IConnectableLayer* const splitterLayer = network->AddSplitterLayer(desc, splitterLayerName); + armnn::IConnectableLayer* const outputLayer0 = network->AddOutputLayer(0); + armnn::IConnectableLayer* const outputLayer1 = network->AddOutputLayer(1); + armnn::IConnectableLayer* const outputLayer2 = network->AddOutputLayer(2); + + inputLayer->GetOutputSlot(0).Connect(splitterLayer->GetInputSlot(0)); + inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo); + splitterLayer->GetOutputSlot(0).Connect(outputLayer0->GetInputSlot(0)); + splitterLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + splitterLayer->GetOutputSlot(1).Connect(outputLayer1->GetInputSlot(0)); + splitterLayer->GetOutputSlot(1).SetTensorInfo(outputTensorInfo); + splitterLayer->GetOutputSlot(2).Connect(outputLayer2->GetInputSlot(0)); + splitterLayer->GetOutputSlot(2).SetTensorInfo(outputTensorInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + BOOST_CHECK(deserializedNetwork); + + VerifySplitterName nameChecker; + deserializedNetwork->Accept(nameChecker); + + CheckDeserializedNetworkAgainstOriginal(*deserializedNetwork, + *network, + {inputTensorInfo.GetShape()}, + {outputTensorInfo.GetShape(), + outputTensorInfo.GetShape(), + outputTensorInfo.GetShape()}, + {0}, + {0, 1, 2}); +} + BOOST_AUTO_TEST_SUITE_END() -- cgit v1.2.1