diff options
Diffstat (limited to 'tests/TfLiteYoloV3Big-Armnn/NMS.cpp')
-rw-r--r-- | tests/TfLiteYoloV3Big-Armnn/NMS.cpp | 13 |
1 files changed, 13 insertions, 0 deletions
diff --git a/tests/TfLiteYoloV3Big-Armnn/NMS.cpp b/tests/TfLiteYoloV3Big-Armnn/NMS.cpp index 3ef840f875..d067f1a004 100644 --- a/tests/TfLiteYoloV3Big-Armnn/NMS.cpp +++ b/tests/TfLiteYoloV3Big-Armnn/NMS.cpp @@ -6,6 +6,7 @@ #include "NMS.hpp" +#include <cmath> #include <algorithm> #include <cstddef> #include <numeric> @@ -80,6 +81,18 @@ std::vector<Detection> convert_to_detections(const NMSConfig& config, } } // namespace +bool compare_detection(const yolov3::Detection& detection, + const std::vector<float>& expected) +{ + float tolerance = 0.001f; + return (std::fabs(detection.classes[0] - expected[0]) < tolerance && + std::fabs(detection.box.xmin - expected[1]) < tolerance && + std::fabs(detection.box.ymin - expected[2]) < tolerance && + std::fabs(detection.box.xmax - expected[3]) < tolerance && + std::fabs(detection.box.ymax - expected[4]) < tolerance && + std::fabs(detection.confidence - expected[5]) < tolerance ); +} + void print_detection(std::ostream& os, const std::vector<Detection>& detections) { |