aboutsummaryrefslogtreecommitdiff
path: root/tests/TfLiteYoloV3Big-Armnn/NMS.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/TfLiteYoloV3Big-Armnn/NMS.cpp')
-rw-r--r--tests/TfLiteYoloV3Big-Armnn/NMS.cpp13
1 files changed, 13 insertions, 0 deletions
diff --git a/tests/TfLiteYoloV3Big-Armnn/NMS.cpp b/tests/TfLiteYoloV3Big-Armnn/NMS.cpp
index 3ef840f875..d067f1a004 100644
--- a/tests/TfLiteYoloV3Big-Armnn/NMS.cpp
+++ b/tests/TfLiteYoloV3Big-Armnn/NMS.cpp
@@ -6,6 +6,7 @@
#include "NMS.hpp"
+#include <cmath>
#include <algorithm>
#include <cstddef>
#include <numeric>
@@ -80,6 +81,18 @@ std::vector<Detection> convert_to_detections(const NMSConfig& config,
}
} // namespace
+bool compare_detection(const yolov3::Detection& detection,
+ const std::vector<float>& expected)
+{
+ float tolerance = 0.001f;
+ return (std::fabs(detection.classes[0] - expected[0]) < tolerance &&
+ std::fabs(detection.box.xmin - expected[1]) < tolerance &&
+ std::fabs(detection.box.ymin - expected[2]) < tolerance &&
+ std::fabs(detection.box.xmax - expected[3]) < tolerance &&
+ std::fabs(detection.box.ymax - expected[4]) < tolerance &&
+ std::fabs(detection.confidence - expected[5]) < tolerance );
+}
+
void print_detection(std::ostream& os,
const std::vector<Detection>& detections)
{