From 919c14ef132986aa1514b2070ce6d19b5579a6ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89anna=20=C3=93=20Cath=C3=A1in?= Date: Mon, 14 Sep 2020 17:36:49 +0100 Subject: MLECO-929 Add Object Detection sample application using the public ArmNN C++ API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Change-Id: I14aa1b4b726212cffbefd6687203f93f936fa872 Signed-off-by: Éanna Ó Catháin --- samples/ObjectDetection/test/PipelineTest.cpp | 60 +++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 samples/ObjectDetection/test/PipelineTest.cpp (limited to 'samples/ObjectDetection/test/PipelineTest.cpp') diff --git a/samples/ObjectDetection/test/PipelineTest.cpp b/samples/ObjectDetection/test/PipelineTest.cpp new file mode 100644 index 0000000000..289f44f5e9 --- /dev/null +++ b/samples/ObjectDetection/test/PipelineTest.cpp @@ -0,0 +1,60 @@ +// +// Copyright © 2020 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// +#include +#include +#include +#include "Types.hpp" + +static std::string GetResourceFilePath(const std::string& filename) +{ + std::string testResources = TEST_RESOURCE_DIR; + if (0 == testResources.size()) + { + throw "Invalid test resources directory provided"; + } + else + { + if(testResources.back() != '/') + { + return testResources + "/" + filename; + } + else + { + return testResources + filename; + } + } +} + +TEST_CASE("Test Network Execution SSD_MOBILE") +{ + std::string testResources = TEST_RESOURCE_DIR; + REQUIRE(testResources != ""); + // Create the network options + od::ODPipelineOptions options; + options.m_ModelFilePath = GetResourceFilePath("detect.tflite"); + options.m_ModelName = "SSD_MOBILE"; + options.m_backends = {"CpuAcc", "CpuRef"}; + + od::IPipelinePtr objectDetectionPipeline = od::CreatePipeline(options); + + od::InferenceResults results; + cv::Mat processed; + cv::Mat inputFrame = cv::imread(GetResourceFilePath("basketball1.png"), cv::IMREAD_COLOR); + cv::cvtColor(inputFrame, inputFrame, cv::COLOR_BGR2RGB); + + objectDetectionPipeline->PreProcessing(inputFrame, processed); + + CHECK(processed.type() == CV_8UC3); + CHECK(processed.cols == 300); + CHECK(processed.rows == 300); + + objectDetectionPipeline->Inference(processed, results); + objectDetectionPipeline->PostProcessing(results, + [](od::DetectedObjects detects) -> void { + CHECK(detects.size() == 2); + CHECK(detects[0].GetLabel() == "0"); + }); + +} -- cgit v1.2.1