aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/reference/NonMaxSuppression.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/reference/NonMaxSuppression.cpp')
-rw-r--r--tests/validation/reference/NonMaxSuppression.cpp26
1 files changed, 16 insertions, 10 deletions
diff --git a/tests/validation/reference/NonMaxSuppression.cpp b/tests/validation/reference/NonMaxSuppression.cpp
index 75929085b3..5b7980d2f0 100644
--- a/tests/validation/reference/NonMaxSuppression.cpp
+++ b/tests/validation/reference/NonMaxSuppression.cpp
@@ -53,13 +53,14 @@ inline Box get_box(const SimpleTensor<float> &boxes, size_t id)
get_elem_by_coordinate(boxes, Coordinates(3, id)));
}
+// returns a pair (minX, minY)
inline std::pair<float, float> get_min_yx(Box b)
{
return std::make_pair(
std::min<float>(std::get<0>(b), std::get<2>(b)),
std::min<float>(std::get<1>(b), std::get<3>(b)));
}
-
+// returns a pair (maxX, maxY)
inline std::pair<float, float> get_max_yx(Box b)
{
return std::make_pair(
@@ -96,7 +97,8 @@ inline bool reject_box(Box b0, Box b1, float threshold)
}
else
{
- return compute_intersection(b0_min, b0_max, b1_min, b1_max, b0_size, b1_size) > threshold;
+ const float box_weight = compute_intersection(b0_min, b0_max, b1_min, b1_max, b0_size, b1_size);
+ return box_weight > threshold;
}
}
@@ -111,7 +113,7 @@ inline std::vector<CandidateBox> get_candidates(const SimpleTensor<float> &score
candidates_vector.push_back(cb);
}
}
- std::sort(candidates_vector.begin(), candidates_vector.end(), [](const CandidateBox bb0, const CandidateBox bb1)
+ std::stable_sort(candidates_vector.begin(), candidates_vector.end(), [](const CandidateBox bb0, const CandidateBox bb1)
{
return bb0.second >= bb1.second;
});
@@ -122,7 +124,10 @@ inline bool is_box_selected(const CandidateBox &cb, const SimpleTensor<float> &b
{
for(int j = selected_boxes.size() - 1; j >= 0; --j)
{
- if(reject_box(get_box(bboxes, cb.first), get_box(bboxes, selected_boxes[j]), threshold))
+ const auto selected_box_jth = get_box(bboxes, selected_boxes[j]);
+ const auto candidate_box = get_box(bboxes, cb.first);
+ const bool candidate_rejected = reject_box(candidate_box, selected_box_jth, threshold);
+ if(candidate_rejected)
{
return false;
}
@@ -138,19 +143,20 @@ SimpleTensor<int> non_max_suppression(const SimpleTensor<float> &bboxes, const S
const size_t output_size = std::min(static_cast<size_t>(max_output_size), num_boxes);
const std::vector<CandidateBox> candidates_vector = get_candidates(scores, score_threshold);
std::vector<int> selected;
- size_t p(0);
- while(selected.size() < output_size && p < candidates_vector.size() && selected.size() < candidates_vector.size())
+ for(const auto c : candidates_vector)
{
- const auto nc = candidates_vector[p++];
- if(is_box_selected(nc, bboxes, selected, nms_threshold))
+ if(selected.size() == output_size)
+ {
+ break;
+ }
+ if(is_box_selected(c, bboxes, selected, nms_threshold))
{
- selected.push_back(nc.first);
+ selected.push_back(c.first);
}
}
std::copy_n(selected.begin(), selected.size(), indices.data());
return indices;
}
-
} // namespace reference
} // namespace validation
} // namespace test