diff options
Diffstat (limited to 'src/armnnSerializer/test')
-rw-r--r-- | src/armnnSerializer/test/SerializerTests.cpp | 84 |
1 files changed, 84 insertions, 0 deletions
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index 41f5d14ce3..d0586c988c 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -1595,4 +1595,88 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializeSplitter) {0, 1, 2}); } +BOOST_AUTO_TEST_CASE(SerializeDeserializeDetectionPostProcess) +{ + class VerifyDetectionPostProcessName : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy> + { + public: + void VisitDetectionPostProcessLayer(const armnn::IConnectableLayer* layer, + const armnn::DetectionPostProcessDescriptor& descriptor, + const armnn::ConstTensor& anchors, + const char* name) override + { + BOOST_TEST(name == "DetectionPostProcessLayer"); + } + }; + + const armnn::TensorInfo inputInfos[] = { + armnn::TensorInfo({ 1, 6, 4 }, armnn::DataType::Float32), + armnn::TensorInfo({ 1, 6, 3}, armnn::DataType::Float32) + }; + + const armnn::TensorInfo outputInfos[] = { + armnn::TensorInfo({ 1, 3, 4 }, armnn::DataType::Float32), + armnn::TensorInfo({ 1, 3 }, armnn::DataType::Float32), + armnn::TensorInfo({ 1, 3 }, armnn::DataType::Float32), + armnn::TensorInfo({ 1 }, armnn::DataType::Float32) + }; + + armnn::DetectionPostProcessDescriptor desc; + desc.m_UseRegularNms = true; + desc.m_MaxDetections = 3; + desc.m_MaxClassesPerDetection = 1; + desc.m_DetectionsPerClass =1; + desc.m_NmsScoreThreshold = 0.0; + desc.m_NmsIouThreshold = 0.5; + desc.m_NumClasses = 2; + desc.m_ScaleY = 10.0; + desc.m_ScaleX = 10.0; + desc.m_ScaleH = 5.0; + desc.m_ScaleW = 5.0; + + const armnn::TensorInfo anchorsInfo({ 6, 4 }, armnn::DataType::Float32); + const std::vector<float> anchorsData({ + 0.5f, 0.5f, 1.0f, 1.0f, + 0.5f, 0.5f, 1.0f, 1.0f, + 0.5f, 0.5f, 1.0f, 1.0f, + 0.5f, 10.5f, 1.0f, 1.0f, + 0.5f, 10.5f, 1.0f, 1.0f, + 0.5f, 100.5f, 1.0f, 1.0f + }); + armnn::ConstTensor anchors(anchorsInfo, anchorsData); + + armnn::INetworkPtr network = armnn::INetwork::Create(); + + armnn::IConnectableLayer* const detectionLayer = + network->AddDetectionPostProcessLayer(desc, anchors, "DetectionPostProcessLayer"); + + for (unsigned int i = 0; i < 2; i++) + { + armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(static_cast<int>(i)); + inputLayer->GetOutputSlot(0).Connect(detectionLayer->GetInputSlot(i)); + inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfos[i]); + } + + for (unsigned int i = 0; i < 4; i++) + { + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(static_cast<int>(i)); + detectionLayer->GetOutputSlot(i).Connect(outputLayer->GetInputSlot(0)); + detectionLayer->GetOutputSlot(i).SetTensorInfo(outputInfos[i]); + } + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + BOOST_CHECK(deserializedNetwork); + + VerifyDetectionPostProcessName nameChecker; + deserializedNetwork->Accept(nameChecker); + + CheckDeserializedNetworkAgainstOriginal<float>( + *network, + *deserializedNetwork, + {inputInfos[0].GetShape(), inputInfos[1].GetShape()}, + {outputInfos[0].GetShape(), outputInfos[1].GetShape(), outputInfos[2].GetShape(), outputInfos[3].GetShape()}, + {0, 1}, + {0, 1, 2, 3}); +} + BOOST_AUTO_TEST_SUITE_END() |