aboutsummaryrefslogtreecommitdiff
path: root/tests/TfLiteYoloV3Big-Armnn/NMS.hpp
blob: f5e3cf38af2455c86107e4f44009ec5a2db622f1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <ostream>
#include <vector>

namespace yolov3 {
/** Non-Maximum Suppression (NMS) configuration meta-data.
 *
 * Every member has an in-class default, so a default-constructed NMSConfig
 * has zero classes, zero boxes and both thresholds set to 0.8.
 */
struct NMSConfig {
    unsigned int num_classes{0};      /**< Number of classes in the detected boxes */
    unsigned int num_boxes{0};        /**< Number of detected boxes */
    float confidence_threshold{0.8f}; /**< Inclusion confidence threshold for a box */
    float iou_threshold{0.8f};        /**< Intersection-Over-Union threshold used by the suppression step */
};

/** Axis-aligned box representation structure.
 *
 * Members default to 0 so a default-constructed Box never holds
 * indeterminate values. Aggregate initialisation keeps working, and follows
 * declaration order: Box{xmin, xmax, ymin, ymax} (note: NOT xmin, ymin, ...).
 *
 * NOTE(review): whether ymin is the top or the bottom edge depends on the
 * coordinate convention of the code that fills the box - confirm with callers.
 */
struct Box {
    float xmin{0.0f}; /**< Minimum x coordinate of the box */
    float xmax{0.0f}; /**< Maximum x coordinate of the box */
    float ymin{0.0f}; /**< Minimum y coordinate of the box */
    float ymax{0.0f}; /**< Maximum y coordinate of the box */
};

/** Detection structure
 *
 * The box is value-initialised and the confidence defaults to 0, so a
 * default-constructed Detection never carries indeterminate scalars.
 * Aggregate initialisation keeps working as before.
 */
struct Detection {
    Box box{};                  /**< Detection box */
    float confidence{0.0f};     /**< Confidence of detection */
    std::vector<float> classes; /**< Per-class probability scores */
};

/** Write the given YOLO detections to an output stream.
 *
 * @param[in, out] os          Stream that receives the printed detections
 * @param[in]      detections  Collection of detections to print
 */
void print_detection(std::ostream& os, const std::vector<Detection>& detections);

/** Compare a detection object with a vector of float values
 *
 * @param detection [in] Detection object
 * @param expected  [in] Vector of expected float values
 * @return Boolean to represent if they match or not
 */
bool compare_detection(const yolov3::Detection& detection,
                       const std::vector<float>& expected);

/** Perform Non-Maximum Suppression on a list of given detections
 *
 * @param[in] config         Configuration metadata for NMS (class/box counts
 *                           and the confidence / IoU thresholds)
 * @param[in] detected_boxes Detected boxes, passed as a flat buffer of floats.
 *                           NOTE(review): the per-box element layout is defined
 *                           by the implementation; presumably config.num_boxes
 *                           records of box coordinates, a confidence and
 *                           config.num_classes class scores - confirm.
 *
 * @return A vector with the final detections
 */
std::vector<Detection> nms(const NMSConfig& config,
                           const std::vector<float>& detected_boxes);
} // namespace yolov3