ArmNN
 21.02
NMS.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2020 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 
7 #include "NMS.hpp"
8 
9 #include <cmath>
10 #include <algorithm>
11 #include <cstddef>
12 #include <numeric>
13 #include <ostream>
14 
15 namespace yolov3 {
16 namespace {
/** Count of consecutive values that describe one bounding box in a record */
constexpr int box_elements{4};
/** Count of consecutive values that describe one confidence score in a record */
constexpr int confidence_elements{1};
21 
22 /** Calculate Intersection Over Union of two boxes
23  *
24  * @param[in] box1 First box
25  * @param[in] box2 Second box
26  *
27  * @return The IoU of the two boxes
28  */
29 float iou(const Box& box1, const Box& box2)
30 {
31  const float area1 = (box1.xmax - box1.xmin) * (box1.ymax - box1.ymin);
32  const float area2 = (box2.xmax - box2.xmin) * (box2.ymax - box2.ymin);
33  float overlap;
34  if (area1 <= 0 || area2 <= 0)
35  {
36  overlap = 0.0f;
37  }
38  else
39  {
40  const auto y_min_intersection = std::max<float>(box1.ymin, box2.ymin);
41  const auto x_min_intersection = std::max<float>(box1.xmin, box2.xmin);
42  const auto y_max_intersection = std::min<float>(box1.ymax, box2.ymax);
43  const auto x_max_intersection = std::min<float>(box1.xmax, box2.xmax);
44  const auto area_intersection =
45  std::max<float>(y_max_intersection - y_min_intersection, 0.0f) *
46  std::max<float>(x_max_intersection - x_min_intersection, 0.0f);
47  overlap = area_intersection / (area1 + area2 - area_intersection);
48  }
49  return overlap;
50 }
51 
52 std::vector<Detection> convert_to_detections(const NMSConfig& config,
53  const std::vector<float>& detected_boxes)
54 {
55  const size_t element_step = static_cast<size_t>(
56  box_elements + confidence_elements + config.num_classes);
57  std::vector<Detection> detections;
58 
59  for (unsigned int i = 0; i < config.num_boxes; ++i)
60  {
61  const float* cur_box = &detected_boxes[i * element_step];
62  if (cur_box[4] > config.confidence_threshold)
63  {
64  Detection det;
65  det.box = {cur_box[0], cur_box[0] + cur_box[2], cur_box[1],
66  cur_box[1] + cur_box[3]};
67  det.confidence = cur_box[4];
68  det.classes.resize(static_cast<size_t>(config.num_classes), 0);
69  for (unsigned int c = 0; c < config.num_classes; ++c)
70  {
71  const float class_prob = det.confidence * cur_box[5 + c];
72  if (class_prob > config.confidence_threshold)
73  {
74  det.classes[c] = class_prob;
75  }
76  }
77  detections.emplace_back(std::move(det));
78  }
79  }
80  return detections;
81 }
82 } // namespace
83 
84 bool compare_detection(const yolov3::Detection& detection,
85  const std::vector<float>& expected)
86 {
87  float tolerance = 0.001f;
88  return (std::fabs(detection.classes[0] - expected[0]) < tolerance &&
89  std::fabs(detection.box.xmin - expected[1]) < tolerance &&
90  std::fabs(detection.box.ymin - expected[2]) < tolerance &&
91  std::fabs(detection.box.xmax - expected[3]) < tolerance &&
92  std::fabs(detection.box.ymax - expected[4]) < tolerance &&
93  std::fabs(detection.confidence - expected[5]) < tolerance );
94 }
95 
96 void print_detection(std::ostream& os,
97  const std::vector<Detection>& detections)
98 {
99  for (const auto& detection : detections)
100  {
101  for (unsigned int c = 0; c < detection.classes.size(); ++c)
102  {
103  if (detection.classes[c] != 0.0f)
104  {
105  os << c << " " << detection.classes[c] << " " << detection.box.xmin
106  << " " << detection.box.ymin << " " << detection.box.xmax << " "
107  << detection.box.ymax << std::endl;
108  }
109  }
110  }
111 }
112 
113 std::vector<Detection> nms(const NMSConfig& config,
114  const std::vector<float>& detected_boxes) {
115  // Get detections that comply with the expected confidence threshold
116  std::vector<Detection> detections =
117  convert_to_detections(config, detected_boxes);
118 
119  const unsigned int num_detections = static_cast<unsigned int>(detections.size());
120  for (unsigned int c = 0; c < config.num_classes; ++c)
121  {
122  // Sort classes
123  std::sort(detections.begin(), detections.begin() + static_cast<std::ptrdiff_t>(num_detections),
124  [c](Detection& detection1, Detection& detection2)
125  {
126  return (detection1.classes[c] - detection2.classes[c]) > 0;
127  });
128  // Clear detections with high IoU
129  for (unsigned int d = 0; d < num_detections; ++d)
130  {
131  // Check if class is already cleared/invalidated
132  if (detections[d].classes[c] == 0.f)
133  {
134  continue;
135  }
136 
137  // Filter out boxes on IoU threshold
138  const Box& box1 = detections[d].box;
139  for (unsigned int b = d + 1; b < num_detections; ++b)
140  {
141  const Box& box2 = detections[b].box;
142  if (iou(box1, box2) > config.iou_threshold)
143  {
144  detections[b].classes[c] = 0.f;
145  }
146  }
147  }
148  }
149  return detections;
150 }
151 } // namespace yolov3
float confidence
Confidence of detection.
Definition: NMS.hpp:31
float ymin
Y coordinate of the lower-left corner.
Definition: NMS.hpp:24
unsigned int num_boxes
Number of detected boxes.
Definition: NMS.hpp:15
Definition: NMS.cpp:15
float xmin
X coordinate of the lower-left corner.
Definition: NMS.hpp:22
float xmax
X coordinate of the top-right corner.
Definition: NMS.hpp:23
void print_detection(std::ostream &os, const std::vector< Detection > &detections)
Print identified yolo detections.
Definition: NMS.cpp:96
Detection structure.
Definition: NMS.hpp:29
std::vector< float > classes
Probability of classes.
Definition: NMS.hpp:32
Box box
Detection box.
Definition: NMS.hpp:30
Box representation structure.
Definition: NMS.hpp:21
float ymax
Y coordinate of the top-right corner.
Definition: NMS.hpp:25
float iou_threshold
Inclusion threshold for Intersection-Over-Union.
Definition: NMS.hpp:17
std::vector< Detection > nms(const NMSConfig &config, const std::vector< float > &detected_boxes)
Perform Non-Maxima Suppression on a list of given detections.
Definition: NMS.cpp:113
Non-Maxima Suppression configuration metadata.
Definition: NMS.hpp:13
float confidence_threshold
Inclusion confidence threshold for a box.
Definition: NMS.hpp:16
bool compare_detection(const yolov3::Detection &detection, const std::vector< float > &expected)
Compare a detection object with a vector of float values.
Definition: NMS.cpp:84
unsigned int num_classes
Number of classes in the detected boxes.
Definition: NMS.hpp:14