diff options
Diffstat (limited to 'src/armnnSerializer/test/SerializerTests.cpp')
-rw-r--r-- | src/armnnSerializer/test/SerializerTests.cpp | 79 |
1 files changed, 77 insertions, 2 deletions
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index 069b9d699c..41f5d14ce3 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -473,8 +473,8 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeL2Normalization) VerifyL2NormalizationName nameChecker(l2NormLayerName); deserializedNetwork->Accept(nameChecker); - CheckDeserializedNetworkAgainstOriginal<float>(*network, - *deserializedNetwork, + CheckDeserializedNetworkAgainstOriginal<float>(*deserializedNetwork, + *network, { info.GetShape() }, { info.GetShape() }); } @@ -1520,4 +1520,79 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeMerger) {0, 1}); } +BOOST_AUTO_TEST_CASE(SerializeDeserializeSplitter) +{ + class VerifySplitterName : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy> + { + public: + void VisitSplitterLayer(const armnn::IConnectableLayer*, + const armnn::ViewsDescriptor& viewsDescriptor, + const char* name) override + { + BOOST_TEST(name == "splitter"); + } + }; + + unsigned int numViews = 3; + unsigned int numDimensions = 4; + unsigned int inputShape[] = {1,18, 4, 4}; + unsigned int outputShape[] = {1, 6, 4, 4}; + + auto inputTensorInfo = armnn::TensorInfo(numDimensions, inputShape, armnn::DataType::Float32); + auto outputTensorInfo = armnn::TensorInfo(numDimensions, outputShape, armnn::DataType::Float32); + + // This is modelled on how the caffe parser sets up a splitter layer to partition an input + // along dimension one. + unsigned int splitterDimSizes[4] = {static_cast<unsigned int>(inputShape[0]), + static_cast<unsigned int>(inputShape[1]), + static_cast<unsigned int>(inputShape[2]), + static_cast<unsigned int>(inputShape[3])}; + splitterDimSizes[1] /= numViews; + armnn::ViewsDescriptor desc(numViews, numDimensions); + + for (unsigned int g = 0; g < numViews; ++g) + { + desc.SetViewOriginCoord(g, 1, splitterDimSizes[1] * g); + + // Set the size of the views. + for (unsigned int dimIdx=0; dimIdx < 4; dimIdx++) + { + desc.SetViewSize(g, dimIdx, splitterDimSizes[dimIdx]); + } + } + + const char* splitterLayerName = "splitter"; + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0); + armnn::IConnectableLayer* const splitterLayer = network->AddSplitterLayer(desc, splitterLayerName); + armnn::IConnectableLayer* const outputLayer0 = network->AddOutputLayer(0); + armnn::IConnectableLayer* const outputLayer1 = network->AddOutputLayer(1); + armnn::IConnectableLayer* const outputLayer2 = network->AddOutputLayer(2); + + inputLayer->GetOutputSlot(0).Connect(splitterLayer->GetInputSlot(0)); + inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo); + splitterLayer->GetOutputSlot(0).Connect(outputLayer0->GetInputSlot(0)); + splitterLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + splitterLayer->GetOutputSlot(1).Connect(outputLayer1->GetInputSlot(0)); + splitterLayer->GetOutputSlot(1).SetTensorInfo(outputTensorInfo); + splitterLayer->GetOutputSlot(2).Connect(outputLayer2->GetInputSlot(0)); + splitterLayer->GetOutputSlot(2).SetTensorInfo(outputTensorInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + BOOST_CHECK(deserializedNetwork); + + VerifySplitterName nameChecker; + deserializedNetwork->Accept(nameChecker); + + CheckDeserializedNetworkAgainstOriginal<float>(*deserializedNetwork, + *network, + {inputTensorInfo.GetShape()}, + {outputTensorInfo.GetShape(), + outputTensorInfo.GetShape(), + outputTensorInfo.GetShape()}, + {0}, + {0, 1, 2}); +} + BOOST_AUTO_TEST_SUITE_END() |