diff options
author | Derek Lamberti <derek.lamberti@arm.com> | 2020-04-28 13:31:29 +0100 |
---|---|---|
committer | TeresaARM <teresa.charlinreyes@arm.com> | 2020-05-12 16:45:12 +0000 |
commit | d6cb30e7052891996efa41f608fffe4fa62d2094 (patch) | |
tree | 802517e0c3963b047504af4fed7aa832cdf67c44 /tests/TfLiteYoloV3Big-Armnn/NMS.cpp | |
parent | 1e9892e0d8ad63e9014910509ad2704fdf141c6c (diff) | |
download | armnn-d6cb30e7052891996efa41f608fffe4fa62d2094.tar.gz |
Add test app to drive YoloV3 split model
Change-Id: I4877f170003d3a82f97311046c52571792ad735a
Signed-off-by: Derek Lamberti <derek.lamberti@arm.com>
Diffstat (limited to 'tests/TfLiteYoloV3Big-Armnn/NMS.cpp')
-rw-r--r-- | tests/TfLiteYoloV3Big-Armnn/NMS.cpp | 138 |
1 files changed, 138 insertions, 0 deletions
diff --git a/tests/TfLiteYoloV3Big-Armnn/NMS.cpp b/tests/TfLiteYoloV3Big-Armnn/NMS.cpp new file mode 100644 index 0000000000..3ef840f875 --- /dev/null +++ b/tests/TfLiteYoloV3Big-Armnn/NMS.cpp @@ -0,0 +1,138 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + + +#include "NMS.hpp" + +#include <algorithm> +#include <cstddef> +#include <numeric> +#include <ostream> + +namespace yolov3 { +namespace { +/** Number of elements needed to represent a box */ +constexpr int box_elements = 4; +/** Number of elements needed to represent a confidence factor */ +constexpr int confidence_elements = 1; + +/** Calculate Intersection Over Union of two boxes + * + * @param[in] box1 First box + * @param[in] box2 Second box + * + * @return The IoU of the two boxes + */ +float iou(const Box& box1, const Box& box2) +{ + const float area1 = (box1.xmax - box1.xmin) * (box1.ymax - box1.ymin); + const float area2 = (box2.xmax - box2.xmin) * (box2.ymax - box2.ymin); + float overlap; + if (area1 <= 0 || area2 <= 0) + { + overlap = 0.0f; + } + else + { + const auto y_min_intersection = std::max<float>(box1.ymin, box2.ymin); + const auto x_min_intersection = std::max<float>(box1.xmin, box2.xmin); + const auto y_max_intersection = std::min<float>(box1.ymax, box2.ymax); + const auto x_max_intersection = std::min<float>(box1.xmax, box2.xmax); + const auto area_intersection = + std::max<float>(y_max_intersection - y_min_intersection, 0.0f) * + std::max<float>(x_max_intersection - x_min_intersection, 0.0f); + overlap = area_intersection / (area1 + area2 - area_intersection); + } + return overlap; +} + +std::vector<Detection> convert_to_detections(const NMSConfig& config, + const std::vector<float>& detected_boxes) +{ + const size_t element_step = static_cast<size_t>( + box_elements + confidence_elements + config.num_classes); + std::vector<Detection> detections; + + for (unsigned int i = 0; i < config.num_boxes; ++i) + { + const float* cur_box = &detected_boxes[i * element_step]; + if (cur_box[4] > config.confidence_threshold) + { + Detection det; + det.box = {cur_box[0], cur_box[0] + cur_box[2], cur_box[1], + cur_box[1] + cur_box[3]}; + det.confidence = cur_box[4]; + det.classes.resize(static_cast<size_t>(config.num_classes), 0); + for (unsigned int c = 0; c < config.num_classes; ++c) + { + const float class_prob = det.confidence * cur_box[5 + c]; + if (class_prob > config.confidence_threshold) + { + det.classes[c] = class_prob; + } + } + detections.emplace_back(std::move(det)); + } + } + return detections; +} +} // namespace + +void print_detection(std::ostream& os, + const std::vector<Detection>& detections) +{ + for (const auto& detection : detections) + { + for (unsigned int c = 0; c < detection.classes.size(); ++c) + { + if (detection.classes[c] != 0.0f) + { + os << c << " " << detection.classes[c] << " " << detection.box.xmin + << " " << detection.box.ymin << " " << detection.box.xmax << " " + << detection.box.ymax << std::endl; + } + } + } +} + +std::vector<Detection> nms(const NMSConfig& config, + const std::vector<float>& detected_boxes) { + // Get detections that comply with the expected confidence threshold + std::vector<Detection> detections = + convert_to_detections(config, detected_boxes); + + const unsigned int num_detections = static_cast<unsigned int>(detections.size()); + for (unsigned int c = 0; c < config.num_classes; ++c) + { + // Sort classes + std::sort(detections.begin(), detections.begin() + static_cast<std::ptrdiff_t>(num_detections), + [c](Detection& detection1, Detection& detection2) + { + return (detection1.classes[c] - detection2.classes[c]) > 0; + }); + // Clear detections with high IoU + for (unsigned int d = 0; d < num_detections; ++d) + { + // Check if class is already cleared/invalidated + if (detections[d].classes[c] == 0.f) + { + continue; + } + + // Filter out boxes on IoU threshold + const Box& box1 = detections[d].box; + for (unsigned int b = d + 1; b < num_detections; ++b) + { + const Box& box2 = detections[b].box; + if (iou(box1, box2) > config.iou_threshold) + { + detections[b].classes[c] = 0.f; + } + } + } + } + return detections; +} +} // namespace yolov3 |