diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/validation/reference/NonMaxSuppression.cpp | 26 |
1 files changed, 16 insertions, 10 deletions
diff --git a/tests/validation/reference/NonMaxSuppression.cpp b/tests/validation/reference/NonMaxSuppression.cpp index 75929085b3..5b7980d2f0 100644 --- a/tests/validation/reference/NonMaxSuppression.cpp +++ b/tests/validation/reference/NonMaxSuppression.cpp @@ -53,13 +53,14 @@ inline Box get_box(const SimpleTensor<float> &boxes, size_t id) get_elem_by_coordinate(boxes, Coordinates(3, id))); } +// returns a pair (minX, minY) inline std::pair<float, float> get_min_yx(Box b) { return std::make_pair( std::min<float>(std::get<0>(b), std::get<2>(b)), std::min<float>(std::get<1>(b), std::get<3>(b))); } - +// returns a pair (maxX, maxY) inline std::pair<float, float> get_max_yx(Box b) { return std::make_pair( @@ -96,7 +97,8 @@ inline bool reject_box(Box b0, Box b1, float threshold) } else { - return compute_intersection(b0_min, b0_max, b1_min, b1_max, b0_size, b1_size) > threshold; + const float box_weight = compute_intersection(b0_min, b0_max, b1_min, b1_max, b0_size, b1_size); + return box_weight > threshold; } } @@ -111,7 +113,7 @@ inline std::vector<CandidateBox> get_candidates(const SimpleTensor<float> &score candidates_vector.push_back(cb); } } - std::sort(candidates_vector.begin(), candidates_vector.end(), [](const CandidateBox bb0, const CandidateBox bb1) + std::stable_sort(candidates_vector.begin(), candidates_vector.end(), [](const CandidateBox bb0, const CandidateBox bb1) { return bb0.second >= bb1.second; }); @@ -122,7 +124,10 @@ inline bool is_box_selected(const CandidateBox &cb, const SimpleTensor<float> &b { for(int j = selected_boxes.size() - 1; j >= 0; --j) { - if(reject_box(get_box(bboxes, cb.first), get_box(bboxes, selected_boxes[j]), threshold)) + const auto selected_box_jth = get_box(bboxes, selected_boxes[j]); + const auto candidate_box = get_box(bboxes, cb.first); + const bool candidate_rejected = reject_box(candidate_box, selected_box_jth, threshold); + if(candidate_rejected) { return false; } @@ -138,19 +143,20 @@ SimpleTensor<int> non_max_suppression(const SimpleTensor<float> &bboxes, const S const size_t output_size = std::min(static_cast<size_t>(max_output_size), num_boxes); const std::vector<CandidateBox> candidates_vector = get_candidates(scores, score_threshold); std::vector<int> selected; - size_t p(0); - while(selected.size() < output_size && p < candidates_vector.size() && selected.size() < candidates_vector.size()) + for(const auto c : candidates_vector) { - const auto nc = candidates_vector[p++]; - if(is_box_selected(nc, bboxes, selected, nms_threshold)) + if(selected.size() == output_size) + { + break; + } + if(is_box_selected(c, bboxes, selected, nms_threshold)) { - selected.push_back(nc.first); + selected.push_back(c.first); } } std::copy_n(selected.begin(), selected.size(), indices.data()); return indices; } - } // namespace reference } // namespace validation } // namespace test |