aboutsummaryrefslogtreecommitdiff
path: root/samples/ObjectDetection/src/NonMaxSuppression.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'samples/ObjectDetection/src/NonMaxSuppression.cpp')
-rw-r--r--samples/ObjectDetection/src/NonMaxSuppression.cpp92
1 files changed, 92 insertions, 0 deletions
diff --git a/samples/ObjectDetection/src/NonMaxSuppression.cpp b/samples/ObjectDetection/src/NonMaxSuppression.cpp
new file mode 100644
index 0000000000..7bcd9045a5
--- /dev/null
+++ b/samples/ObjectDetection/src/NonMaxSuppression.cpp
@@ -0,0 +1,92 @@
+//
+// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#include "NonMaxSuppression.hpp"
+
+#include <algorithm>
+
+namespace od
+{
+
+static std::vector<unsigned int> GenerateRangeK(unsigned int k)
+{
+ std::vector<unsigned int> range(k);
+ std::iota(range.begin(), range.end(), 0);
+ return range;
+}
+
+
+/**
+* @brief Returns the intersection over union for two bounding boxes
+*
+* @param[in] First detect containing bounding box.
+* @param[in] Second detect containing bounding box.
+* @return Calculated intersection over union.
+*
+*/
+static double IntersectionOverUnion(DetectedObject& detect1, DetectedObject& detect2)
+{
+ uint32_t area1 = (detect1.GetBoundingBox().GetHeight() * detect1.GetBoundingBox().GetWidth());
+ uint32_t area2 = (detect2.GetBoundingBox().GetHeight() * detect2.GetBoundingBox().GetWidth());
+
+ float yMinIntersection = std::max(detect1.GetBoundingBox().GetY(), detect2.GetBoundingBox().GetY());
+ float xMinIntersection = std::max(detect1.GetBoundingBox().GetX(), detect2.GetBoundingBox().GetX());
+
+ float yMaxIntersection = std::min(detect1.GetBoundingBox().GetY() + detect1.GetBoundingBox().GetHeight(),
+ detect2.GetBoundingBox().GetY() + detect2.GetBoundingBox().GetHeight());
+ float xMaxIntersection = std::min(detect1.GetBoundingBox().GetX() + detect1.GetBoundingBox().GetWidth(),
+ detect2.GetBoundingBox().GetX() + detect2.GetBoundingBox().GetWidth());
+
+ double areaIntersection = std::max(yMaxIntersection - yMinIntersection, 0.0f) *
+ std::max(xMaxIntersection - xMinIntersection, 0.0f);
+ double areaUnion = area1 + area2 - areaIntersection;
+
+ return areaIntersection / areaUnion;
+}
+
+std::vector<int> NonMaxSuppression(DetectedObjects& inputDetections, float iouThresh)
+{
+ // Sort indicies of detections by highest score to lowest.
+ std::vector<unsigned int> sortedIndicies = GenerateRangeK(inputDetections.size());
+ std::sort(sortedIndicies.begin(), sortedIndicies.end(),
+ [&inputDetections](int idx1, int idx2)
+ {
+ return inputDetections[idx1].GetScore() > inputDetections[idx2].GetScore();
+ });
+
+ std::vector<bool> visited(inputDetections.size(), false);
+ std::vector<int> outputIndiciesAfterNMS;
+
+ for (int i=0; i < inputDetections.size(); ++i)
+ {
+ // Each new unvisited detect should be kept.
+ if (!visited[sortedIndicies[i]])
+ {
+ outputIndiciesAfterNMS.emplace_back(sortedIndicies[i]);
+ visited[sortedIndicies[i]] = true;
+ }
+
+ // Look for detections to suppress.
+ for (int j=i+1; j<inputDetections.size(); ++j)
+ {
+ // Skip if already kept or suppressed.
+ if (!visited[sortedIndicies[j]])
+ {
+ // Detects must have the same label to be suppressed.
+ if (inputDetections[sortedIndicies[j]].GetLabel() == inputDetections[sortedIndicies[i]].GetLabel())
+ {
+ auto iou = IntersectionOverUnion(inputDetections[sortedIndicies[i]],
+ inputDetections[sortedIndicies[j]]);
+ if (iou > iouThresh)
+ {
+ visited[sortedIndicies[j]] = true;
+ }
+ }
+ }
+ }
+ }
+ return outputIndiciesAfterNMS;
+}
+
+} // namespace od