diff options
Diffstat (limited to 'tests/validation/reference/NonMaxSuppression.cpp')
-rw-r--r-- | tests/validation/reference/NonMaxSuppression.cpp | 14 |
1 files changed, 10 insertions, 4 deletions
diff --git a/tests/validation/reference/NonMaxSuppression.cpp b/tests/validation/reference/NonMaxSuppression.cpp index 5b7980d2f0..8fc370b7af 100644 --- a/tests/validation/reference/NonMaxSuppression.cpp +++ b/tests/validation/reference/NonMaxSuppression.cpp @@ -76,10 +76,10 @@ inline float compute_size(const std::pair<float, float> &min, const std::pair<fl inline float compute_intersection(const std::pair<float, float> &b0_min, const std::pair<float, float> &b0_max, const std::pair<float, float> &b1_min, const std::pair<float, float> &b1_max, float b0_size, float b1_size) { - const float inter = std::max<float>(std::min<float>(b0_max.first, b1_max.first) - std::max<float>(b0_min.first, b1_min.first), 0.0) * std::max<float>(std::min<float>(b0_max.second, + const float inter = std::max<float>(std::min<float>(b0_max.first, b1_max.first) - std::max<float>(b0_min.first, b1_min.first), 0.0f) * std::max<float>(std::min<float>(b0_max.second, b1_max.second) - std::max<float>(b0_min.second, b1_min.second), - 0.0); + 0.0f); return inter / (b0_size + b1_size - inter); } @@ -107,7 +107,7 @@ inline std::vector<CandidateBox> get_candidates(const SimpleTensor<float> &score std::vector<CandidateBox> candidates_vector; for(int i = 0; i < scores.num_elements(); ++i) { - if(scores[i] > threshold) + if(scores[i] >= threshold) { const auto cb = CandidateBox({ i, scores[i] }); candidates_vector.push_back(cb); @@ -115,7 +115,7 @@ inline std::vector<CandidateBox> get_candidates(const SimpleTensor<float> &score } std::stable_sort(candidates_vector.begin(), candidates_vector.end(), [](const CandidateBox bb0, const CandidateBox bb1) { - return bb0.second >= bb1.second; + return bb0.second > bb1.second; }); return candidates_vector; } @@ -155,6 +155,12 @@ SimpleTensor<int> non_max_suppression(const SimpleTensor<float> &bboxes, const S } } std::copy_n(selected.begin(), selected.size(), indices.data()); + + for(unsigned int i = selected.size(); i < max_output_size; ++i) + { + indices[i] = -1; + } + return indices; } } // namespace reference |