aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/reference/NonMaxSuppression.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/reference/NonMaxSuppression.cpp')
-rw-r--r--tests/validation/reference/NonMaxSuppression.cpp14
1 files changed, 10 insertions, 4 deletions
diff --git a/tests/validation/reference/NonMaxSuppression.cpp b/tests/validation/reference/NonMaxSuppression.cpp
index 5b7980d2f0..8fc370b7af 100644
--- a/tests/validation/reference/NonMaxSuppression.cpp
+++ b/tests/validation/reference/NonMaxSuppression.cpp
@@ -76,10 +76,10 @@ inline float compute_size(const std::pair<float, float> &min, const std::pair<fl
inline float compute_intersection(const std::pair<float, float> &b0_min, const std::pair<float, float> &b0_max,
const std::pair<float, float> &b1_min, const std::pair<float, float> &b1_max, float b0_size, float b1_size)
{
- const float inter = std::max<float>(std::min<float>(b0_max.first, b1_max.first) - std::max<float>(b0_min.first, b1_min.first), 0.0) * std::max<float>(std::min<float>(b0_max.second,
+ const float inter = std::max<float>(std::min<float>(b0_max.first, b1_max.first) - std::max<float>(b0_min.first, b1_min.first), 0.0f) * std::max<float>(std::min<float>(b0_max.second,
b1_max.second)
- std::max<float>(b0_min.second, b1_min.second),
- 0.0);
+ 0.0f);
return inter / (b0_size + b1_size - inter);
}
@@ -107,7 +107,7 @@ inline std::vector<CandidateBox> get_candidates(const SimpleTensor<float> &score
std::vector<CandidateBox> candidates_vector;
for(int i = 0; i < scores.num_elements(); ++i)
{
- if(scores[i] > threshold)
+ if(scores[i] >= threshold)
{
const auto cb = CandidateBox({ i, scores[i] });
candidates_vector.push_back(cb);
@@ -115,7 +115,7 @@ inline std::vector<CandidateBox> get_candidates(const SimpleTensor<float> &score
}
std::stable_sort(candidates_vector.begin(), candidates_vector.end(), [](const CandidateBox bb0, const CandidateBox bb1)
{
- return bb0.second >= bb1.second;
+ return bb0.second > bb1.second;
});
return candidates_vector;
}
@@ -155,6 +155,12 @@ SimpleTensor<int> non_max_suppression(const SimpleTensor<float> &bboxes, const S
}
}
std::copy_n(selected.begin(), selected.size(), indices.data());
+
+ for(unsigned int i = selected.size(); i < max_output_size; ++i)
+ {
+ indices[i] = -1;
+ }
+
return indices;
}
} // namespace reference