From 032fb7e91dfd7a5b53a0ffb31890147148c5d598 Mon Sep 17 00:00:00 2001 From: Pablo Tello Date: Wed, 27 Feb 2019 13:32:51 +0000 Subject: COMPMID-1766: Fixed NonMaxSuppression reference. Change-Id: Id1cb964e35ee3e524ffa2db6b112a8cc37853124 Signed-off-by: Pablo Tello Reviewed-on: https://review.mlplatform.org/c/801 Reviewed-by: Georgios Pinitas Tested-by: Arm Jenkins --- tests/validation/reference/NonMaxSuppression.cpp | 26 +++++++++++++++--------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/tests/validation/reference/NonMaxSuppression.cpp b/tests/validation/reference/NonMaxSuppression.cpp index 75929085b3..5b7980d2f0 100644 --- a/tests/validation/reference/NonMaxSuppression.cpp +++ b/tests/validation/reference/NonMaxSuppression.cpp @@ -53,13 +53,14 @@ inline Box get_box(const SimpleTensor &boxes, size_t id) get_elem_by_coordinate(boxes, Coordinates(3, id))); } +// returns a pair (minX, minY) inline std::pair get_min_yx(Box b) { return std::make_pair( std::min(std::get<0>(b), std::get<2>(b)), std::min(std::get<1>(b), std::get<3>(b))); } - +// returns a pair (maxX, maxY) inline std::pair get_max_yx(Box b) { return std::make_pair( @@ -96,7 +97,8 @@ inline bool reject_box(Box b0, Box b1, float threshold) } else { - return compute_intersection(b0_min, b0_max, b1_min, b1_max, b0_size, b1_size) > threshold; + const float box_weight = compute_intersection(b0_min, b0_max, b1_min, b1_max, b0_size, b1_size); + return box_weight > threshold; } } @@ -111,7 +113,7 @@ inline std::vector get_candidates(const SimpleTensor &score candidates_vector.push_back(cb); } } - std::sort(candidates_vector.begin(), candidates_vector.end(), [](const CandidateBox bb0, const CandidateBox bb1) + std::stable_sort(candidates_vector.begin(), candidates_vector.end(), [](const CandidateBox bb0, const CandidateBox bb1) { return bb0.second >= bb1.second; }); @@ -122,7 +124,10 @@ inline bool is_box_selected(const CandidateBox &cb, const SimpleTensor &b { for(int j = selected_boxes.size() - 1; j >= 0; --j) { - if(reject_box(get_box(bboxes, cb.first), get_box(bboxes, selected_boxes[j]), threshold)) + const auto selected_box_jth = get_box(bboxes, selected_boxes[j]); + const auto candidate_box = get_box(bboxes, cb.first); + const bool candidate_rejected = reject_box(candidate_box, selected_box_jth, threshold); + if(candidate_rejected) { return false; } @@ -138,19 +143,20 @@ SimpleTensor non_max_suppression(const SimpleTensor &bboxes, const S const size_t output_size = std::min(static_cast(max_output_size), num_boxes); const std::vector candidates_vector = get_candidates(scores, score_threshold); std::vector selected; - size_t p(0); - while(selected.size() < output_size && p < candidates_vector.size() && selected.size() < candidates_vector.size()) + for(const auto c : candidates_vector) { - const auto nc = candidates_vector[p++]; - if(is_box_selected(nc, bboxes, selected, nms_threshold)) + if(selected.size() == output_size) + { + break; + } + if(is_box_selected(c, bboxes, selected, nms_threshold)) { - selected.push_back(nc.first); + selected.push_back(c.first); } } std::copy_n(selected.begin(), selected.size(), indices.data()); return indices; } - } // namespace reference } // namespace validation } // namespace test -- cgit v1.2.1