aboutsummaryrefslogtreecommitdiff
path: root/samples/ObjectDetection/test
diff options
context:
space:
mode:
Diffstat (limited to 'samples/ObjectDetection/test')
-rw-r--r--samples/ObjectDetection/test/BoundingBoxTests.cpp177
-rw-r--r--samples/ObjectDetection/test/FrameReaderTest.cpp103
-rw-r--r--samples/ObjectDetection/test/ImageUtilsTest.cpp128
-rw-r--r--samples/ObjectDetection/test/NMSTests.cpp90
-rw-r--r--samples/ObjectDetection/test/PipelineTest.cpp60
5 files changed, 558 insertions, 0 deletions
diff --git a/samples/ObjectDetection/test/BoundingBoxTests.cpp b/samples/ObjectDetection/test/BoundingBoxTests.cpp
new file mode 100644
index 0000000000..a8ed29a977
--- /dev/null
+++ b/samples/ObjectDetection/test/BoundingBoxTests.cpp
@@ -0,0 +1,177 @@
+//
+// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include <catch.hpp>
+#include "BoundingBox.hpp"
+
+namespace
+{
+ static constexpr unsigned int s_X = 100u;
+ static constexpr unsigned int s_Y = 200u;
+ static constexpr unsigned int s_W = 300u;
+ static constexpr unsigned int s_H = 400u;
+} // anonymous namespace
+
+TEST_CASE("BoundingBoxTest_Default")
+{
+ od::BoundingBox boundingBox;
+
+ REQUIRE(boundingBox.GetX() == 0u);
+ REQUIRE(boundingBox.GetY() == 0u);
+ REQUIRE(boundingBox.GetWidth() == 0u);
+ REQUIRE(boundingBox.GetHeight() == 0u);
+}
+
+TEST_CASE("BoundingBoxTest_Custom")
+{
+ od::BoundingBox boundingBox(s_X, s_Y, s_W, s_H);
+
+ REQUIRE(boundingBox.GetX() == s_X);
+ REQUIRE(boundingBox.GetY() == s_Y);
+ REQUIRE(boundingBox.GetWidth() == s_W);
+ REQUIRE(boundingBox.GetHeight() == s_H);
+}
+
+TEST_CASE("BoundingBoxTest_Setters")
+{
+ od::BoundingBox boundingBox;
+
+ boundingBox.SetX(s_X);
+ boundingBox.SetY(s_Y);
+ boundingBox.SetWidth(s_W);
+ boundingBox.SetHeight(s_H);
+
+ REQUIRE(boundingBox.GetX() == s_X);
+ REQUIRE(boundingBox.GetY() == s_Y);
+ REQUIRE(boundingBox.GetWidth() == s_W);
+ REQUIRE(boundingBox.GetHeight() == s_H);
+}
+
+static inline bool AreBoxesEqual(od::BoundingBox& b1, od::BoundingBox& b2)
+{
+ return (b1.GetX() == b2.GetX() && b1.GetY() == b2.GetY() &&
+ b1.GetWidth() == b2.GetWidth() && b1.GetHeight() == b2.GetHeight());
+}
+
+TEST_CASE("BoundingBoxTest_GetValidBoundingBox")
+{
+ od::BoundingBox boxIn { 0, 0, 10, 20 };
+ od::BoundingBox boxOut;
+
+ WHEN("Limiting box is completely within the input box")
+ {
+ od::BoundingBox boxLmt{ 1, 1, 9, 18 };
+ GetValidBoundingBox(boxIn, boxOut, boxLmt);
+ REQUIRE(AreBoxesEqual(boxLmt,boxOut));
+ }
+
+ WHEN("Limiting box cuts off the top and left")
+ {
+ od::BoundingBox boxLmt{ 1, 1, 10, 20 };
+ od::BoundingBox boxExp{ 1, 1, 9, 19 };
+ GetValidBoundingBox(boxIn, boxOut, boxLmt);
+ REQUIRE(AreBoxesEqual(boxExp, boxOut));
+ }
+
+ WHEN("Limiting box cuts off the bottom")
+ {
+ od::BoundingBox boxLmt{ 0, 0, 10, 19 };
+ GetValidBoundingBox(boxIn, boxOut, boxLmt);
+ REQUIRE(AreBoxesEqual(boxLmt, boxOut));
+ }
+
+ WHEN("Limiting box cuts off the right")
+ {
+ od::BoundingBox boxLmt{ 0, 0, 9, 20 };
+ GetValidBoundingBox(boxIn, boxOut, boxLmt);
+ REQUIRE(AreBoxesEqual(boxLmt, boxOut));
+ }
+
+ WHEN("Limiting box cuts off the bottom and right")
+ {
+ od::BoundingBox boxLmt{ 0, 0, 9, 19 };
+ GetValidBoundingBox(boxIn, boxOut, boxLmt);
+ REQUIRE(AreBoxesEqual(boxLmt, boxOut));
+ }
+
+ WHEN("Limiting box cuts off the bottom and left")
+ {
+ od::BoundingBox boxLmt{ 1, 0, 10, 19 };
+ od::BoundingBox boxExp{ 1, 0, 9, 19 };
+ GetValidBoundingBox(boxIn, boxOut, boxLmt);
+ REQUIRE(AreBoxesEqual(boxExp, boxOut));
+ }
+
+ WHEN("Limiting box does not impose any limit")
+ {
+ od::BoundingBox boxLmt{ 0, 0, 10, 20 };
+ GetValidBoundingBox(boxIn, boxOut, boxLmt);
+ REQUIRE(AreBoxesEqual(boxIn, boxOut));
+ }
+
+ WHEN("Limiting box zeros out the width")
+ {
+ od::BoundingBox boxLmt{ 0, 0, 0, 20 };
+ od::BoundingBox boxExp{ 0, 0, 0, 0 };
+ GetValidBoundingBox(boxIn, boxOut, boxLmt);
+ REQUIRE(AreBoxesEqual(boxExp, boxOut));
+ }
+
+ WHEN("Limiting box zeros out the height")
+ {
+ od::BoundingBox boxLmt{ 0, 0, 10, 0 };
+ od::BoundingBox boxExp{ 0, 0, 0, 0 };
+ GetValidBoundingBox(boxIn, boxOut, boxLmt);
+ REQUIRE(AreBoxesEqual(boxExp, boxOut));
+ }
+
+ WHEN("Limiting box with negative starts - top and left with 1 sq pixel cut-off")
+ {
+ od::BoundingBox boxLmt{ -1, -1, 10, 20 };
+ od::BoundingBox boxExp{ 0, 0, 9, 19 };
+ GetValidBoundingBox(boxIn, boxOut, boxLmt);
+ REQUIRE(AreBoxesEqual(boxExp, boxOut));
+ }
+
+ WHEN("Limiting box with negative starts - top and left with full overlap")
+ {
+ od::BoundingBox boxLmt{ -1, -1, 11, 21 };
+ GetValidBoundingBox(boxIn, boxOut, boxLmt);
+ REQUIRE(AreBoxesEqual(boxIn, boxOut));
+ }
+
+ WHEN("Limiting box with zero overlap")
+ {
+ od::BoundingBox boxLmt{-10,-20, 10, 20 };
+ od::BoundingBox boxExp{ 0, 0, 0, 0 };
+ GetValidBoundingBox(boxIn, boxOut, boxLmt);
+ REQUIRE(AreBoxesEqual(boxExp, boxOut));
+ }
+
+ WHEN("Limiting box with one square pixel overlap")
+ {
+ od::BoundingBox boxLmt{-9,-19, 10, 20 };
+ od::BoundingBox boxExp{ 0, 0, 1, 1 };
+ GetValidBoundingBox(boxIn, boxOut, boxLmt);
+ REQUIRE(AreBoxesEqual(boxExp, boxOut));
+ }
+
+ WHEN("Limiting box with unrealistically high values in positive quadrant")
+ {
+ od::BoundingBox boxLmt{INT32_MAX, INT32_MAX, UINT32_MAX, UINT32_MAX };
+ od::BoundingBox boxExp{ 0, 0, 0, 0 };
+ GetValidBoundingBox(boxIn, boxOut, boxLmt);
+ REQUIRE(AreBoxesEqual(boxExp, boxOut));
+ }
+
+ /* This should actually return a valid bounding box, currently not handled. */
+ WHEN("Limiting box with unrealistic values spanning 32 bit space")
+ {
+ od::BoundingBox boxLmt{-(INT32_MAX), -(INT32_MAX), UINT32_MAX, UINT32_MAX};
+ od::BoundingBox boxExp{ 0, 0, 0, 0 };
+ GetValidBoundingBox(boxIn, boxOut, boxLmt);
+ REQUIRE(AreBoxesEqual(boxExp, boxOut));
+ }
+} \ No newline at end of file
diff --git a/samples/ObjectDetection/test/FrameReaderTest.cpp b/samples/ObjectDetection/test/FrameReaderTest.cpp
new file mode 100644
index 0000000000..a4bda227b3
--- /dev/null
+++ b/samples/ObjectDetection/test/FrameReaderTest.cpp
@@ -0,0 +1,103 @@
+//
+// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#define CATCH_CONFIG_MAIN
+
+#include <catch.hpp>
+#include <opencv2/opencv.hpp>
+
+#include "IFrameReader.hpp"
+#include "CvVideoFrameReader.hpp"
+
+SCENARIO("Read frames from video file using CV frame reader", "[framereader]") {
+
+ GIVEN("a valid video file") {
+
+ std::string testResources = TEST_RESOURCE_DIR;
+ REQUIRE(testResources != "");
+ std::string file = testResources + "/" + "Megamind.avi";
+ WHEN("Frame reader is initialised") {
+
+ od::CvVideoFrameReader reader;
+ THEN("no exception is thrown") {
+ reader.Init(file);
+
+ AND_WHEN("when source parameters are read") {
+
+ auto fps = reader.GetSourceFps();
+ auto height = reader.GetSourceHeight();
+ auto width = reader.GetSourceWidth();
+ auto encoding = reader.GetSourceEncoding();
+ auto framesCount = reader.GetFrameCount();
+
+ THEN("they are aligned with video file") {
+
+ REQUIRE(height == 528);
+ REQUIRE(width == 720);
+ REQUIRE(encoding == "XVID");
+ REQUIRE(fps == 23.976);
+ REQUIRE(framesCount == 270);
+ }
+
+ }
+
+ AND_WHEN("frame is read") {
+ auto framePtr = reader.ReadFrame();
+
+ THEN("it is not a NULL pointer") {
+ REQUIRE(framePtr != nullptr);
+ }
+
+ AND_THEN("it is not empty") {
+ REQUIRE(!framePtr->empty());
+ REQUIRE(!reader.IsExhausted(framePtr));
+ }
+ }
+
+ AND_WHEN("all frames were read from the file") {
+
+ for (int i = 0; i < 270; i++) {
+ auto framePtr = reader.ReadFrame();
+ }
+
+ THEN("last + 1 frame is empty") {
+ auto framePtr = reader.ReadFrame();
+
+ REQUIRE(framePtr->empty());
+ REQUIRE(reader.IsExhausted(framePtr));
+ }
+
+ }
+
+ AND_WHEN("frames are read from the file, pointers point to the different objects") {
+
+ auto framePtr = reader.ReadFrame();
+
+ cv::Mat *frame = framePtr.get();
+
+ for (int i = 0; i < 30; i++) {
+ REQUIRE(frame != reader.ReadFrame().get());
+ }
+
+ }
+ }
+ }
+ }
+
+ GIVEN("an invalid video file") {
+
+ std::string file = "nosuchfile.avi";
+
+ WHEN("Frame reader is initialised") {
+
+ od::CvVideoFrameReader reader;
+
+ THEN("exception is thrown") {
+ REQUIRE_THROWS(reader.Init(file));
+ }
+ }
+
+ }
+} \ No newline at end of file
diff --git a/samples/ObjectDetection/test/ImageUtilsTest.cpp b/samples/ObjectDetection/test/ImageUtilsTest.cpp
new file mode 100644
index 0000000000..e486ae192b
--- /dev/null
+++ b/samples/ObjectDetection/test/ImageUtilsTest.cpp
@@ -0,0 +1,128 @@
+//
+// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#include <catch.hpp>
+#include <opencv2/opencv.hpp>
+#include "ImageUtils.hpp"
+#include "Types.hpp"
+
+std::vector<std::tuple<int, int>> GetBoundingBoxPoints(std::vector<od::DetectedObject>& decodedResults,
+ cv::Mat imageMat)
+{
+ std::vector<std::tuple<int, int>> bboxes;
+ for(const od::DetectedObject& object : decodedResults)
+ {
+ const od::BoundingBox& bbox = object.GetBoundingBox();
+
+ if (bbox.GetX() + bbox.GetWidth() > imageMat.cols)
+ {
+ for (int y = bbox.GetY(); y < bbox.GetY() + bbox.GetHeight(); ++y)
+ {
+ bboxes.emplace_back(std::tuple<int, int>{bbox.GetX(), y});
+ }
+
+ for (int x = bbox.GetX(); x < imageMat.cols; ++x)
+ {
+ bboxes.emplace_back(std::tuple<int, int>{x, bbox.GetY() + bbox.GetHeight() - 1});
+ }
+
+ for (int y = bbox.GetY(); y < bbox.GetY() + bbox.GetHeight(); ++y)
+ {
+ bboxes.emplace_back(std::tuple<int, int>{imageMat.cols - 1, y});
+ }
+ }
+ else if (bbox.GetY() + bbox.GetHeight() > imageMat.rows)
+ {
+ for (int y = bbox.GetY(); y < imageMat.rows; ++y)
+ {
+ bboxes.emplace_back(std::tuple<int, int>{bbox.GetX(), y});
+ }
+
+ for (int x = bbox.GetX(); x < bbox.GetX() + bbox.GetWidth(); ++x)
+ {
+ bboxes.emplace_back(std::tuple<int, int>{x, imageMat.rows - 1});
+ }
+
+ for (int y = bbox.GetY(); y < imageMat.rows; ++y)
+ {
+ bboxes.emplace_back(std::tuple<int, int>{bbox.GetX() + bbox.GetWidth() - 1, y});
+ }
+ }
+ else
+ {
+ for (int y = bbox.GetY(); y < bbox.GetY() + bbox.GetHeight(); ++y)
+ {
+ bboxes.emplace_back(std::tuple<int, int>{bbox.GetX(), y});
+ }
+
+ for (int x = bbox.GetX(); x < bbox.GetX() + bbox.GetWidth(); ++x)
+ {
+ bboxes.emplace_back(std::tuple<int, int>{x, bbox.GetY() + bbox.GetHeight() - 1});
+ }
+
+ for (int y = bbox.GetY(); y < bbox.GetY() + bbox.GetHeight(); ++y)
+ {
+ bboxes.emplace_back(std::tuple<int, int>{bbox.GetX() + bbox.GetWidth() - 1, y});
+ }
+ }
+ }
+ return bboxes;
+}
+
+static std::string GetResourceFilePath(std::string filename)
+{
+ std::string testResources = TEST_RESOURCE_DIR;
+ if (0 == testResources.size())
+ {
+ throw "Invalid test resources directory provided";
+ }
+ else
+ {
+ if(testResources.back() != '/')
+ {
+ return testResources + "/" + filename;
+ }
+ else
+ {
+ return testResources + filename;
+ }
+ }
+}
+
+TEST_CASE("Test Adding Inference output to frame")
+{
+ //todo: re-write test to use static detections
+
+ std::string testResources = TEST_RESOURCE_DIR;
+ REQUIRE(testResources != "");
+ std::vector<std::tuple<std::string, od::BBoxColor>> labels;
+
+ od::BBoxColor c
+ {
+ .colorCode = std::make_tuple (0, 0, 255)
+ };
+
+ auto bboxInfo = std::make_tuple ("person", c);
+ od::BoundingBox bbox(10, 10, 50, 50);
+ od::DetectedObject detection(0, "person", bbox, 0.75);
+
+ labels.push_back(bboxInfo);
+
+ od::DetectedObjects detections;
+ cv::Mat frame = cv::imread(GetResourceFilePath("basketball1.png"), cv::IMREAD_COLOR);
+ detections.push_back(detection);
+
+ AddInferenceOutputToFrame(detections, frame, labels);
+
+ std::vector<std::tuple<int, int>> bboxes = GetBoundingBoxPoints(detections, frame);
+
+ // Check that every point is the expected color
+ for(std::tuple<int, int> tuple : bboxes)
+ {
+ cv::Point p(std::get<0>(tuple), std::get<1>(tuple));
+ CHECK(static_cast<int>(frame.at<cv::Vec3b>(p)[0]) == 0);
+ CHECK(static_cast<int>(frame.at<cv::Vec3b>(p)[1]) == 0);
+ CHECK(static_cast<int>(frame.at<cv::Vec3b>(p)[2]) == 255);
+ }
+}
diff --git a/samples/ObjectDetection/test/NMSTests.cpp b/samples/ObjectDetection/test/NMSTests.cpp
new file mode 100644
index 0000000000..d8b7c11ae1
--- /dev/null
+++ b/samples/ObjectDetection/test/NMSTests.cpp
@@ -0,0 +1,90 @@
+//
+// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include <catch.hpp>
+
+#include "NonMaxSuppression.hpp"
+
+TEST_CASE("Non_Max_Suppression_1")
+{
+ // Box with iou exactly 0.5.
+ od::DetectedObject detectedObject1;
+ detectedObject1.SetLabel("2");
+ detectedObject1.SetScore(171);
+ detectedObject1.SetBoundingBox({0, 0, 150, 150});
+
+ // Strongest detection.
+ od::DetectedObject detectedObject2;
+ detectedObject2.SetLabel("2");
+ detectedObject2.SetScore(230);
+ detectedObject2.SetBoundingBox({0, 75, 150, 75});
+
+ // Weaker detection with same coordinates of strongest.
+ od::DetectedObject detectedObject3;
+ detectedObject3.SetLabel("2");
+ detectedObject3.SetScore(20);
+ detectedObject3.SetBoundingBox({0, 75, 150, 75});
+
+ // Detection not overlapping strongest.
+ od::DetectedObject detectedObject4;
+ detectedObject4.SetLabel("2");
+ detectedObject4.SetScore(222);
+ detectedObject4.SetBoundingBox({0, 0, 50, 50});
+
+ // Small detection inside strongest.
+ od::DetectedObject detectedObject5;
+ detectedObject5.SetLabel("2");
+ detectedObject5.SetScore(201);
+ detectedObject5.SetBoundingBox({100, 100, 20, 20});
+
+ // Box with iou exactly 0.5 but different label.
+ od::DetectedObject detectedObject6;
+ detectedObject6.SetLabel("1");
+ detectedObject6.SetScore(75);
+ detectedObject6.SetBoundingBox({0, 0, 150, 150});
+
+ od::DetectedObjects expectedResults {detectedObject1,
+ detectedObject2,
+ detectedObject3,
+ detectedObject4,
+ detectedObject5,
+ detectedObject6};
+
+ auto sorted = od::NonMaxSuppression(expectedResults, 0.49);
+
+ // 1st and 3rd detection should be suppressed.
+ REQUIRE(sorted.size() == 4);
+
+ // Final detects should be ordered strongest to weakest.
+ REQUIRE(sorted[0] == 1);
+ REQUIRE(sorted[1] == 3);
+ REQUIRE(sorted[2] == 4);
+ REQUIRE(sorted[3] == 5);
+}
+
+TEST_CASE("Non_Max_Suppression_2")
+{
+ // Real box examples.
+ od::DetectedObject detectedObject1;
+ detectedObject1.SetLabel("2");
+ detectedObject1.SetScore(220);
+ detectedObject1.SetBoundingBox({430, 158, 68, 68});
+
+ od::DetectedObject detectedObject2;
+ detectedObject2.SetLabel("2");
+ detectedObject2.SetScore(171);
+ detectedObject2.SetBoundingBox({438, 158, 68, 68});
+
+ od::DetectedObjects expectedResults {detectedObject1,
+ detectedObject2};
+
+ auto sorted = od::NonMaxSuppression(expectedResults, 0.5);
+
+ // 2nd detect should be suppressed.
+ REQUIRE(sorted.size() == 1);
+
+ // First detect should be strongest and kept.
+ REQUIRE(sorted[0] == 0);
+}
diff --git a/samples/ObjectDetection/test/PipelineTest.cpp b/samples/ObjectDetection/test/PipelineTest.cpp
new file mode 100644
index 0000000000..289f44f5e9
--- /dev/null
+++ b/samples/ObjectDetection/test/PipelineTest.cpp
@@ -0,0 +1,60 @@
+//
+// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#include <catch.hpp>
+#include <opencv2/opencv.hpp>
+#include <NetworkPipeline.hpp>
+#include "Types.hpp"
+
+static std::string GetResourceFilePath(const std::string& filename)
+{
+ std::string testResources = TEST_RESOURCE_DIR;
+ if (0 == testResources.size())
+ {
+ throw "Invalid test resources directory provided";
+ }
+ else
+ {
+ if(testResources.back() != '/')
+ {
+ return testResources + "/" + filename;
+ }
+ else
+ {
+ return testResources + filename;
+ }
+ }
+}
+
+TEST_CASE("Test Network Execution SSD_MOBILE")
+{
+ std::string testResources = TEST_RESOURCE_DIR;
+ REQUIRE(testResources != "");
+ // Create the network options
+ od::ODPipelineOptions options;
+ options.m_ModelFilePath = GetResourceFilePath("detect.tflite");
+ options.m_ModelName = "SSD_MOBILE";
+ options.m_backends = {"CpuAcc", "CpuRef"};
+
+ od::IPipelinePtr objectDetectionPipeline = od::CreatePipeline(options);
+
+ od::InferenceResults results;
+ cv::Mat processed;
+ cv::Mat inputFrame = cv::imread(GetResourceFilePath("basketball1.png"), cv::IMREAD_COLOR);
+ cv::cvtColor(inputFrame, inputFrame, cv::COLOR_BGR2RGB);
+
+ objectDetectionPipeline->PreProcessing(inputFrame, processed);
+
+ CHECK(processed.type() == CV_8UC3);
+ CHECK(processed.cols == 300);
+ CHECK(processed.rows == 300);
+
+ objectDetectionPipeline->Inference(processed, results);
+ objectDetectionPipeline->PostProcessing(results,
+ [](od::DetectedObjects detects) -> void {
+ CHECK(detects.size() == 2);
+ CHECK(detects[0].GetLabel() == "0");
+ });
+
+}