aboutsummaryrefslogtreecommitdiff
path: root/samples/ObjectDetection/test/PipelineTest.cpp
diff options
context:
space:
mode:
authorÉanna Ó Catháin <eanna.ocathain@arm.com>2020-09-14 17:36:49 +0100
committerJim Flynn <jim.flynn@arm.com>2020-09-14 18:40:01 +0000
commit919c14ef132986aa1514b2070ce6d19b5579a6ab (patch)
tree5c281e02a083768f65871cb861ab9b32ac7d8767 /samples/ObjectDetection/test/PipelineTest.cpp
parent589e3e81a86c83456580e112978bf7a0ed5f43ac (diff)
downloadarmnn-919c14ef132986aa1514b2070ce6d19b5579a6ab.tar.gz
MLECO-929 Add Object Detection sample application using the public ArmNN C++ API
Change-Id: I14aa1b4b726212cffbefd6687203f93f936fa872 Signed-off-by: Éanna Ó Catháin <eanna.ocathain@arm.com>
Diffstat (limited to 'samples/ObjectDetection/test/PipelineTest.cpp')
-rw-r--r--samples/ObjectDetection/test/PipelineTest.cpp60
1 files changed, 60 insertions, 0 deletions
diff --git a/samples/ObjectDetection/test/PipelineTest.cpp b/samples/ObjectDetection/test/PipelineTest.cpp
new file mode 100644
index 0000000000..289f44f5e9
--- /dev/null
+++ b/samples/ObjectDetection/test/PipelineTest.cpp
@@ -0,0 +1,60 @@
+//
+// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#include <catch.hpp>
+#include <opencv2/opencv.hpp>
+#include <NetworkPipeline.hpp>
+#include "Types.hpp"
+
+static std::string GetResourceFilePath(const std::string& filename)
+{
+ std::string testResources = TEST_RESOURCE_DIR;
+ if (0 == testResources.size())
+ {
+ throw "Invalid test resources directory provided";
+ }
+ else
+ {
+ if(testResources.back() != '/')
+ {
+ return testResources + "/" + filename;
+ }
+ else
+ {
+ return testResources + filename;
+ }
+ }
+}
+
+TEST_CASE("Test Network Execution SSD_MOBILE")
+{
+ std::string testResources = TEST_RESOURCE_DIR;
+ REQUIRE(testResources != "");
+ // Create the network options
+ od::ODPipelineOptions options;
+ options.m_ModelFilePath = GetResourceFilePath("detect.tflite");
+ options.m_ModelName = "SSD_MOBILE";
+ options.m_backends = {"CpuAcc", "CpuRef"};
+
+ od::IPipelinePtr objectDetectionPipeline = od::CreatePipeline(options);
+
+ od::InferenceResults results;
+ cv::Mat processed;
+ cv::Mat inputFrame = cv::imread(GetResourceFilePath("basketball1.png"), cv::IMREAD_COLOR);
+ cv::cvtColor(inputFrame, inputFrame, cv::COLOR_BGR2RGB);
+
+ objectDetectionPipeline->PreProcessing(inputFrame, processed);
+
+ CHECK(processed.type() == CV_8UC3);
+ CHECK(processed.cols == 300);
+ CHECK(processed.rows == 300);
+
+ objectDetectionPipeline->Inference(processed, results);
+ objectDetectionPipeline->PostProcessing(results,
+ [](od::DetectedObjects detects) -> void {
+ CHECK(detects.size() == 2);
+ CHECK(detects[0].GetLabel() == "0");
+ });
+
+}