diff options
Diffstat (limited to 'samples/ObjectDetection/test')
-rw-r--r-- | samples/ObjectDetection/test/BoundingBoxTests.cpp | 177 | ||||
-rw-r--r-- | samples/ObjectDetection/test/FrameReaderTest.cpp | 103 | ||||
-rw-r--r-- | samples/ObjectDetection/test/ImageUtilsTest.cpp | 128 | ||||
-rw-r--r-- | samples/ObjectDetection/test/NMSTests.cpp | 90 | ||||
-rw-r--r-- | samples/ObjectDetection/test/PipelineTest.cpp | 60 |
5 files changed, 558 insertions, 0 deletions
diff --git a/samples/ObjectDetection/test/BoundingBoxTests.cpp b/samples/ObjectDetection/test/BoundingBoxTests.cpp new file mode 100644 index 0000000000..a8ed29a977 --- /dev/null +++ b/samples/ObjectDetection/test/BoundingBoxTests.cpp @@ -0,0 +1,177 @@ +// +// Copyright © 2020 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include <catch.hpp> +#include "BoundingBox.hpp" + +namespace +{ + static constexpr unsigned int s_X = 100u; + static constexpr unsigned int s_Y = 200u; + static constexpr unsigned int s_W = 300u; + static constexpr unsigned int s_H = 400u; +} // anonymous namespace + +TEST_CASE("BoundingBoxTest_Default") +{ + od::BoundingBox boundingBox; + + REQUIRE(boundingBox.GetX() == 0u); + REQUIRE(boundingBox.GetY() == 0u); + REQUIRE(boundingBox.GetWidth() == 0u); + REQUIRE(boundingBox.GetHeight() == 0u); +} + +TEST_CASE("BoundingBoxTest_Custom") +{ + od::BoundingBox boundingBox(s_X, s_Y, s_W, s_H); + + REQUIRE(boundingBox.GetX() == s_X); + REQUIRE(boundingBox.GetY() == s_Y); + REQUIRE(boundingBox.GetWidth() == s_W); + REQUIRE(boundingBox.GetHeight() == s_H); +} + +TEST_CASE("BoundingBoxTest_Setters") +{ + od::BoundingBox boundingBox; + + boundingBox.SetX(s_X); + boundingBox.SetY(s_Y); + boundingBox.SetWidth(s_W); + boundingBox.SetHeight(s_H); + + REQUIRE(boundingBox.GetX() == s_X); + REQUIRE(boundingBox.GetY() == s_Y); + REQUIRE(boundingBox.GetWidth() == s_W); + REQUIRE(boundingBox.GetHeight() == s_H); +} + +static inline bool AreBoxesEqual(od::BoundingBox& b1, od::BoundingBox& b2) +{ + return (b1.GetX() == b2.GetX() && b1.GetY() == b2.GetY() && + b1.GetWidth() == b2.GetWidth() && b1.GetHeight() == b2.GetHeight()); +} + +TEST_CASE("BoundingBoxTest_GetValidBoundingBox") +{ + od::BoundingBox boxIn { 0, 0, 10, 20 }; + od::BoundingBox boxOut; + + WHEN("Limiting box is completely within the input box") + { + od::BoundingBox boxLmt{ 1, 1, 9, 18 }; + GetValidBoundingBox(boxIn, boxOut, boxLmt); + REQUIRE(AreBoxesEqual(boxLmt,boxOut)); + } + + WHEN("Limiting box cuts off the top and left") + { + od::BoundingBox boxLmt{ 1, 1, 10, 20 }; + od::BoundingBox boxExp{ 1, 1, 9, 19 }; + GetValidBoundingBox(boxIn, boxOut, boxLmt); + REQUIRE(AreBoxesEqual(boxExp, boxOut)); + } + + WHEN("Limiting box cuts off the bottom") + { + od::BoundingBox boxLmt{ 0, 0, 10, 19 }; + GetValidBoundingBox(boxIn, boxOut, boxLmt); + REQUIRE(AreBoxesEqual(boxLmt, boxOut)); + } + + WHEN("Limiting box cuts off the right") + { + od::BoundingBox boxLmt{ 0, 0, 9, 20 }; + GetValidBoundingBox(boxIn, boxOut, boxLmt); + REQUIRE(AreBoxesEqual(boxLmt, boxOut)); + } + + WHEN("Limiting box cuts off the bottom and right") + { + od::BoundingBox boxLmt{ 0, 0, 9, 19 }; + GetValidBoundingBox(boxIn, boxOut, boxLmt); + REQUIRE(AreBoxesEqual(boxLmt, boxOut)); + } + + WHEN("Limiting box cuts off the bottom and left") + { + od::BoundingBox boxLmt{ 1, 0, 10, 19 }; + od::BoundingBox boxExp{ 1, 0, 9, 19 }; + GetValidBoundingBox(boxIn, boxOut, boxLmt); + REQUIRE(AreBoxesEqual(boxExp, boxOut)); + } + + WHEN("Limiting box does not impose any limit") + { + od::BoundingBox boxLmt{ 0, 0, 10, 20 }; + GetValidBoundingBox(boxIn, boxOut, boxLmt); + REQUIRE(AreBoxesEqual(boxIn, boxOut)); + } + + WHEN("Limiting box zeros out the width") + { + od::BoundingBox boxLmt{ 0, 0, 0, 20 }; + od::BoundingBox boxExp{ 0, 0, 0, 0 }; + GetValidBoundingBox(boxIn, boxOut, boxLmt); + REQUIRE(AreBoxesEqual(boxExp, boxOut)); + } + + WHEN("Limiting box zeros out the height") + { + od::BoundingBox boxLmt{ 0, 0, 10, 0 }; + od::BoundingBox boxExp{ 0, 0, 0, 0 }; + GetValidBoundingBox(boxIn, boxOut, boxLmt); + REQUIRE(AreBoxesEqual(boxExp, boxOut)); + } + + WHEN("Limiting box with negative starts - top and left with 1 sq pixel cut-off") + { + od::BoundingBox boxLmt{ -1, -1, 10, 20 }; + od::BoundingBox boxExp{ 0, 0, 9, 19 }; + GetValidBoundingBox(boxIn, boxOut, boxLmt); + REQUIRE(AreBoxesEqual(boxExp, boxOut)); + } + + WHEN("Limiting box with negative starts - top and left with full overlap") + { + od::BoundingBox boxLmt{ -1, -1, 11, 21 }; + GetValidBoundingBox(boxIn, boxOut, boxLmt); + REQUIRE(AreBoxesEqual(boxIn, boxOut)); + } + + WHEN("Limiting box with zero overlap") + { + od::BoundingBox boxLmt{-10,-20, 10, 20 }; + od::BoundingBox boxExp{ 0, 0, 0, 0 }; + GetValidBoundingBox(boxIn, boxOut, boxLmt); + REQUIRE(AreBoxesEqual(boxExp, boxOut)); + } + + WHEN("Limiting box with one square pixel overlap") + { + od::BoundingBox boxLmt{-9,-19, 10, 20 }; + od::BoundingBox boxExp{ 0, 0, 1, 1 }; + GetValidBoundingBox(boxIn, boxOut, boxLmt); + REQUIRE(AreBoxesEqual(boxExp, boxOut)); + } + + WHEN("Limiting box with unrealistically high values in positive quadrant") + { + od::BoundingBox boxLmt{INT32_MAX, INT32_MAX, UINT32_MAX, UINT32_MAX }; + od::BoundingBox boxExp{ 0, 0, 0, 0 }; + GetValidBoundingBox(boxIn, boxOut, boxLmt); + REQUIRE(AreBoxesEqual(boxExp, boxOut)); + } + + /* This should actually return a valid bounding box, currently not handled. */ + WHEN("Limiting box with unrealistic values spanning 32 bit space") + { + od::BoundingBox boxLmt{-(INT32_MAX), -(INT32_MAX), UINT32_MAX, UINT32_MAX}; + od::BoundingBox boxExp{ 0, 0, 0, 0 }; + GetValidBoundingBox(boxIn, boxOut, boxLmt); + REQUIRE(AreBoxesEqual(boxExp, boxOut)); + } +}
\ No newline at end of file diff --git a/samples/ObjectDetection/test/FrameReaderTest.cpp b/samples/ObjectDetection/test/FrameReaderTest.cpp new file mode 100644 index 0000000000..a4bda227b3 --- /dev/null +++ b/samples/ObjectDetection/test/FrameReaderTest.cpp @@ -0,0 +1,103 @@ +// +// Copyright © 2020 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#define CATCH_CONFIG_MAIN + +#include <catch.hpp> +#include <opencv2/opencv.hpp> + +#include "IFrameReader.hpp" +#include "CvVideoFrameReader.hpp" + +SCENARIO("Read frames from video file using CV frame reader", "[framereader]") { + + GIVEN("a valid video file") { + + std::string testResources = TEST_RESOURCE_DIR; + REQUIRE(testResources != ""); + std::string file = testResources + "/" + "Megamind.avi"; + WHEN("Frame reader is initialised") { + + od::CvVideoFrameReader reader; + THEN("no exception is thrown") { + reader.Init(file); + + AND_WHEN("when source parameters are read") { + + auto fps = reader.GetSourceFps(); + auto height = reader.GetSourceHeight(); + auto width = reader.GetSourceWidth(); + auto encoding = reader.GetSourceEncoding(); + auto framesCount = reader.GetFrameCount(); + + THEN("they are aligned with video file") { + + REQUIRE(height == 528); + REQUIRE(width == 720); + REQUIRE(encoding == "XVID"); + REQUIRE(fps == 23.976); + REQUIRE(framesCount == 270); + } + + } + + AND_WHEN("frame is read") { + auto framePtr = reader.ReadFrame(); + + THEN("it is not a NULL pointer") { + REQUIRE(framePtr != nullptr); + } + + AND_THEN("it is not empty") { + REQUIRE(!framePtr->empty()); + REQUIRE(!reader.IsExhausted(framePtr)); + } + } + + AND_WHEN("all frames were read from the file") { + + for (int i = 0; i < 270; i++) { + auto framePtr = reader.ReadFrame(); + } + + THEN("last + 1 frame is empty") { + auto framePtr = reader.ReadFrame(); + + REQUIRE(framePtr->empty()); + REQUIRE(reader.IsExhausted(framePtr)); + } + + } + + AND_WHEN("frames are read from the file, pointers point to the different objects") { + + auto framePtr = reader.ReadFrame(); + + cv::Mat *frame = framePtr.get(); + + for (int i = 0; i < 30; i++) { + REQUIRE(frame != reader.ReadFrame().get()); + } + + } + } + } + } + + GIVEN("an invalid video file") { + + std::string file = "nosuchfile.avi"; + + WHEN("Frame reader is initialised") { + + od::CvVideoFrameReader reader; + + THEN("exception is thrown") { + REQUIRE_THROWS(reader.Init(file)); + } + } + + } +}
\ No newline at end of file diff --git a/samples/ObjectDetection/test/ImageUtilsTest.cpp b/samples/ObjectDetection/test/ImageUtilsTest.cpp new file mode 100644 index 0000000000..e486ae192b --- /dev/null +++ b/samples/ObjectDetection/test/ImageUtilsTest.cpp @@ -0,0 +1,128 @@ +// +// Copyright © 2020 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// +#include <catch.hpp> +#include <opencv2/opencv.hpp> +#include "ImageUtils.hpp" +#include "Types.hpp" + +std::vector<std::tuple<int, int>> GetBoundingBoxPoints(std::vector<od::DetectedObject>& decodedResults, + cv::Mat imageMat) +{ + std::vector<std::tuple<int, int>> bboxes; + for(const od::DetectedObject& object : decodedResults) + { + const od::BoundingBox& bbox = object.GetBoundingBox(); + + if (bbox.GetX() + bbox.GetWidth() > imageMat.cols) + { + for (int y = bbox.GetY(); y < bbox.GetY() + bbox.GetHeight(); ++y) + { + bboxes.emplace_back(std::tuple<int, int>{bbox.GetX(), y}); + } + + for (int x = bbox.GetX(); x < imageMat.cols; ++x) + { + bboxes.emplace_back(std::tuple<int, int>{x, bbox.GetY() + bbox.GetHeight() - 1}); + } + + for (int y = bbox.GetY(); y < bbox.GetY() + bbox.GetHeight(); ++y) + { + bboxes.emplace_back(std::tuple<int, int>{imageMat.cols - 1, y}); + } + } + else if (bbox.GetY() + bbox.GetHeight() > imageMat.rows) + { + for (int y = bbox.GetY(); y < imageMat.rows; ++y) + { + bboxes.emplace_back(std::tuple<int, int>{bbox.GetX(), y}); + } + + for (int x = bbox.GetX(); x < bbox.GetX() + bbox.GetWidth(); ++x) + { + bboxes.emplace_back(std::tuple<int, int>{x, imageMat.rows - 1}); + } + + for (int y = bbox.GetY(); y < imageMat.rows; ++y) + { + bboxes.emplace_back(std::tuple<int, int>{bbox.GetX() + bbox.GetWidth() - 1, y}); + } + } + else + { + for (int y = bbox.GetY(); y < bbox.GetY() + bbox.GetHeight(); ++y) + { + bboxes.emplace_back(std::tuple<int, int>{bbox.GetX(), y}); + } + + for (int x = bbox.GetX(); x < bbox.GetX() + bbox.GetWidth(); ++x) + { + bboxes.emplace_back(std::tuple<int, int>{x, bbox.GetY() + bbox.GetHeight() - 1}); + } + + for (int y = bbox.GetY(); y < bbox.GetY() + bbox.GetHeight(); ++y) + { + bboxes.emplace_back(std::tuple<int, int>{bbox.GetX() + bbox.GetWidth() - 1, y}); + } + } + } + return bboxes; +} + +static std::string GetResourceFilePath(std::string filename) +{ + std::string testResources = TEST_RESOURCE_DIR; + if (0 == testResources.size()) + { + throw "Invalid test resources directory provided"; + } + else + { + if(testResources.back() != '/') + { + return testResources + "/" + filename; + } + else + { + return testResources + filename; + } + } +} + +TEST_CASE("Test Adding Inference output to frame") +{ + //todo: re-write test to use static detections + + std::string testResources = TEST_RESOURCE_DIR; + REQUIRE(testResources != ""); + std::vector<std::tuple<std::string, od::BBoxColor>> labels; + + od::BBoxColor c + { + .colorCode = std::make_tuple (0, 0, 255) + }; + + auto bboxInfo = std::make_tuple ("person", c); + od::BoundingBox bbox(10, 10, 50, 50); + od::DetectedObject detection(0, "person", bbox, 0.75); + + labels.push_back(bboxInfo); + + od::DetectedObjects detections; + cv::Mat frame = cv::imread(GetResourceFilePath("basketball1.png"), cv::IMREAD_COLOR); + detections.push_back(detection); + + AddInferenceOutputToFrame(detections, frame, labels); + + std::vector<std::tuple<int, int>> bboxes = GetBoundingBoxPoints(detections, frame); + + // Check that every point is the expected color + for(std::tuple<int, int> tuple : bboxes) + { + cv::Point p(std::get<0>(tuple), std::get<1>(tuple)); + CHECK(static_cast<int>(frame.at<cv::Vec3b>(p)[0]) == 0); + CHECK(static_cast<int>(frame.at<cv::Vec3b>(p)[1]) == 0); + CHECK(static_cast<int>(frame.at<cv::Vec3b>(p)[2]) == 255); + } +} diff --git a/samples/ObjectDetection/test/NMSTests.cpp b/samples/ObjectDetection/test/NMSTests.cpp new file mode 100644 index 0000000000..d8b7c11ae1 --- /dev/null +++ b/samples/ObjectDetection/test/NMSTests.cpp @@ -0,0 +1,90 @@ +// +// Copyright © 2020 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include <catch.hpp> + +#include "NonMaxSuppression.hpp" + +TEST_CASE("Non_Max_Suppression_1") +{ + // Box with iou exactly 0.5. + od::DetectedObject detectedObject1; + detectedObject1.SetLabel("2"); + detectedObject1.SetScore(171); + detectedObject1.SetBoundingBox({0, 0, 150, 150}); + + // Strongest detection. + od::DetectedObject detectedObject2; + detectedObject2.SetLabel("2"); + detectedObject2.SetScore(230); + detectedObject2.SetBoundingBox({0, 75, 150, 75}); + + // Weaker detection with same coordinates of strongest. + od::DetectedObject detectedObject3; + detectedObject3.SetLabel("2"); + detectedObject3.SetScore(20); + detectedObject3.SetBoundingBox({0, 75, 150, 75}); + + // Detection not overlapping strongest. + od::DetectedObject detectedObject4; + detectedObject4.SetLabel("2"); + detectedObject4.SetScore(222); + detectedObject4.SetBoundingBox({0, 0, 50, 50}); + + // Small detection inside strongest. + od::DetectedObject detectedObject5; + detectedObject5.SetLabel("2"); + detectedObject5.SetScore(201); + detectedObject5.SetBoundingBox({100, 100, 20, 20}); + + // Box with iou exactly 0.5 but different label. + od::DetectedObject detectedObject6; + detectedObject6.SetLabel("1"); + detectedObject6.SetScore(75); + detectedObject6.SetBoundingBox({0, 0, 150, 150}); + + od::DetectedObjects expectedResults {detectedObject1, + detectedObject2, + detectedObject3, + detectedObject4, + detectedObject5, + detectedObject6}; + + auto sorted = od::NonMaxSuppression(expectedResults, 0.49); + + // 1st and 3rd detection should be suppressed. + REQUIRE(sorted.size() == 4); + + // Final detects should be ordered strongest to weakest. + REQUIRE(sorted[0] == 1); + REQUIRE(sorted[1] == 3); + REQUIRE(sorted[2] == 4); + REQUIRE(sorted[3] == 5); +} + +TEST_CASE("Non_Max_Suppression_2") +{ + // Real box examples. + od::DetectedObject detectedObject1; + detectedObject1.SetLabel("2"); + detectedObject1.SetScore(220); + detectedObject1.SetBoundingBox({430, 158, 68, 68}); + + od::DetectedObject detectedObject2; + detectedObject2.SetLabel("2"); + detectedObject2.SetScore(171); + detectedObject2.SetBoundingBox({438, 158, 68, 68}); + + od::DetectedObjects expectedResults {detectedObject1, + detectedObject2}; + + auto sorted = od::NonMaxSuppression(expectedResults, 0.5); + + // 2nd detect should be suppressed. + REQUIRE(sorted.size() == 1); + + // First detect should be strongest and kept. + REQUIRE(sorted[0] == 0); +} diff --git a/samples/ObjectDetection/test/PipelineTest.cpp b/samples/ObjectDetection/test/PipelineTest.cpp new file mode 100644 index 0000000000..289f44f5e9 --- /dev/null +++ b/samples/ObjectDetection/test/PipelineTest.cpp @@ -0,0 +1,60 @@ +// +// Copyright © 2020 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// +#include <catch.hpp> +#include <opencv2/opencv.hpp> +#include <NetworkPipeline.hpp> +#include "Types.hpp" + +static std::string GetResourceFilePath(const std::string& filename) +{ + std::string testResources = TEST_RESOURCE_DIR; + if (0 == testResources.size()) + { + throw "Invalid test resources directory provided"; + } + else + { + if(testResources.back() != '/') + { + return testResources + "/" + filename; + } + else + { + return testResources + filename; + } + } +} + +TEST_CASE("Test Network Execution SSD_MOBILE") +{ + std::string testResources = TEST_RESOURCE_DIR; + REQUIRE(testResources != ""); + // Create the network options + od::ODPipelineOptions options; + options.m_ModelFilePath = GetResourceFilePath("detect.tflite"); + options.m_ModelName = "SSD_MOBILE"; + options.m_backends = {"CpuAcc", "CpuRef"}; + + od::IPipelinePtr objectDetectionPipeline = od::CreatePipeline(options); + + od::InferenceResults results; + cv::Mat processed; + cv::Mat inputFrame = cv::imread(GetResourceFilePath("basketball1.png"), cv::IMREAD_COLOR); + cv::cvtColor(inputFrame, inputFrame, cv::COLOR_BGR2RGB); + + objectDetectionPipeline->PreProcessing(inputFrame, processed); + + CHECK(processed.type() == CV_8UC3); + CHECK(processed.cols == 300); + CHECK(processed.rows == 300); + + objectDetectionPipeline->Inference(processed, results); + objectDetectionPipeline->PostProcessing(results, + [](od::DetectedObjects detects) -> void { + CHECK(detects.size() == 2); + CHECK(detects[0].GetLabel() == "0"); + }); + +} |