aboutsummaryrefslogtreecommitdiff
path: root/tests/TfLiteYoloV3Big-Armnn/NMS.cpp
blob: d067f1a004b394ec208137bc09ee6bb738c072f0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//


#include "NMS.hpp"

#include <cmath>
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <ostream>

namespace yolov3 {
namespace {
/** Number of elements needed to represent a box */
constexpr int box_elements = 4;
/** Number of elements needed to represent a confidence factor */
constexpr int confidence_elements = 1;

/** Calculate Intersection Over Union of two boxes
 *
 * @param[in] box1 First box
 * @param[in] box2 Second box
 *
 * @return The IoU of the two boxes
 */
float iou(const Box& box1, const Box& box2)
{
    const float area1 = (box1.xmax - box1.xmin) * (box1.ymax - box1.ymin);
    const float area2 = (box2.xmax - box2.xmin) * (box2.ymax - box2.ymin);
    float overlap;
    if (area1 <= 0 || area2 <= 0)
    {
        overlap = 0.0f;
    }
    else
    {
        const auto y_min_intersection = std::max<float>(box1.ymin, box2.ymin);
        const auto x_min_intersection = std::max<float>(box1.xmin, box2.xmin);
        const auto y_max_intersection = std::min<float>(box1.ymax, box2.ymax);
        const auto x_max_intersection = std::min<float>(box1.xmax, box2.xmax);
        const auto area_intersection =
            std::max<float>(y_max_intersection - y_min_intersection, 0.0f) *
            std::max<float>(x_max_intersection - x_min_intersection, 0.0f);
        overlap = area_intersection / (area1 + area2 - area_intersection);
    }
    return overlap;
}

std::vector<Detection> convert_to_detections(const NMSConfig& config,
                                             const std::vector<float>& detected_boxes)
{
    const size_t element_step = static_cast<size_t>(
        box_elements + confidence_elements + config.num_classes);
    std::vector<Detection> detections;

    for (unsigned int i = 0; i < config.num_boxes; ++i)
    {
        const float* cur_box = &detected_boxes[i * element_step];
        if (cur_box[4] > config.confidence_threshold)
        {
            Detection det;
            det.box = {cur_box[0], cur_box[0] + cur_box[2], cur_box[1],
                       cur_box[1] + cur_box[3]};
            det.confidence = cur_box[4];
            det.classes.resize(static_cast<size_t>(config.num_classes), 0);
            for (unsigned int c = 0; c < config.num_classes; ++c)
            {
                const float class_prob = det.confidence * cur_box[5 + c];
                if (class_prob > config.confidence_threshold)
                {
                    det.classes[c] = class_prob;
                }
            }
            detections.emplace_back(std::move(det));
        }
    }
    return detections;
}
} // namespace

bool compare_detection(const yolov3::Detection& detection,
                       const std::vector<float>& expected)
{
    float tolerance = 0.001f;
    return (std::fabs(detection.classes[0] - expected[0]) < tolerance  &&
            std::fabs(detection.box.xmin   - expected[1]) < tolerance  &&
            std::fabs(detection.box.ymin   - expected[2]) < tolerance  &&
            std::fabs(detection.box.xmax   - expected[3]) < tolerance  &&
            std::fabs(detection.box.ymax   - expected[4]) < tolerance  &&
            std::fabs(detection.confidence - expected[5]) < tolerance  );
}

void print_detection(std::ostream& os,
                     const std::vector<Detection>& detections)
{
    for (const auto& detection : detections)
    {
        for (unsigned int c = 0; c < detection.classes.size(); ++c)
        {
            if (detection.classes[c] != 0.0f)
            {
                os << c << " " << detection.classes[c] << " " << detection.box.xmin
                   << " " << detection.box.ymin << " " << detection.box.xmax << " "
                   << detection.box.ymax << std::endl;
            }
        }
    }
}

std::vector<Detection> nms(const NMSConfig& config,
                           const std::vector<float>& detected_boxes) {
    // Get detections that comply with the expected confidence threshold
    std::vector<Detection> detections =
        convert_to_detections(config, detected_boxes);

    const unsigned int num_detections = static_cast<unsigned int>(detections.size());
    for (unsigned int c = 0; c < config.num_classes; ++c)
    {
        // Sort classes
        std::sort(detections.begin(), detections.begin() + static_cast<std::ptrdiff_t>(num_detections),
                  [c](Detection& detection1, Detection& detection2)
                    {
                        return (detection1.classes[c] - detection2.classes[c]) > 0;
                    });
        // Clear detections with high IoU
        for (unsigned int d = 0; d < num_detections; ++d)
        {
            // Check if class is already cleared/invalidated
            if (detections[d].classes[c] == 0.f)
            {
                continue;
            }

            // Filter out boxes on IoU threshold
            const Box& box1 = detections[d].box;
            for (unsigned int b = d + 1; b < num_detections; ++b)
            {
                const Box& box2 = detections[b].box;
                if (iou(box1, box2) > config.iou_threshold)
                {
                    detections[b].classes[c] = 0.f;
                }
            }
        }
    }
    return detections;
}
} // namespace yolov3