aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/test/RefDetectionPostProcessTests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/test/RefDetectionPostProcessTests.cpp')
-rw-r--r--src/backends/reference/test/RefDetectionPostProcessTests.cpp68
1 files changed, 50 insertions, 18 deletions
diff --git a/src/backends/reference/test/RefDetectionPostProcessTests.cpp b/src/backends/reference/test/RefDetectionPostProcessTests.cpp
index a9faff70b1..fab6e00bad 100644
--- a/src/backends/reference/test/RefDetectionPostProcessTests.cpp
+++ b/src/backends/reference/test/RefDetectionPostProcessTests.cpp
@@ -3,7 +3,7 @@
// SPDX-License-Identifier: MIT
//
-#include "reference/workloads/DetectionPostProcess.cpp"
+#include <reference/workloads/DetectionPostProcess.hpp>
#include <armnn/Descriptors.hpp>
#include <armnn/Types.hpp>
@@ -12,13 +12,12 @@
BOOST_AUTO_TEST_SUITE(RefDetectionPostProcess)
-
BOOST_AUTO_TEST_CASE(TopKSortTest)
{
unsigned int k = 3;
unsigned int indices[8] = { 0, 1, 2, 3, 4, 5, 6, 7 };
float values[8] = { 0, 7, 6, 5, 4, 3, 2, 500 };
- TopKSort(k, indices, values, 8);
+ armnn::TopKSort(k, indices, values, 8);
BOOST_TEST(indices[0] == 7);
BOOST_TEST(indices[1] == 1);
BOOST_TEST(indices[2] == 2);
@@ -29,7 +28,7 @@ BOOST_AUTO_TEST_CASE(FullTopKSortTest)
unsigned int k = 8;
unsigned int indices[8] = { 0, 1, 2, 3, 4, 5, 6, 7 };
float values[8] = { 0, 7, 6, 5, 4, 3, 2, 500 };
- TopKSort(k, indices, values, 8);
+ armnn::TopKSort(k, indices, values, 8);
BOOST_TEST(indices[0] == 7);
BOOST_TEST(indices[1] == 1);
BOOST_TEST(indices[2] == 2);
@@ -44,7 +43,7 @@ BOOST_AUTO_TEST_CASE(IouTest)
{
float boxI[4] = { 0.0f, 0.0f, 10.0f, 10.0f };
float boxJ[4] = { 1.0f, 1.0f, 11.0f, 11.0f };
- float iou = IntersectionOverUnion(boxI, boxJ);
+ float iou = armnn::IntersectionOverUnion(boxI, boxJ);
BOOST_TEST(iou == 0.68, boost::test_tools::tolerance(0.001));
}
@@ -61,14 +60,17 @@ BOOST_AUTO_TEST_CASE(NmsFunction)
std::vector<float> scores({ 0.9f, 0.75f, 0.6f, 0.93f, 0.5f, 0.3f });
- std::vector<unsigned int> result = NonMaxSuppression(6, boxCorners, scores, 0.0, 3, 0.5);
+ std::vector<unsigned int> result =
+ armnn::NonMaxSuppression(6, boxCorners, scores, 0.0, 3, 0.5);
+
BOOST_TEST(result.size() == 3);
BOOST_TEST(result[0] == 3);
BOOST_TEST(result[1] == 0);
BOOST_TEST(result[2] == 5);
}
-void DetectionPostProcessTestImpl(bool useRegularNms, const std::vector<float>& expectedDetectionBoxes,
+void DetectionPostProcessTestImpl(bool useRegularNms,
+ const std::vector<float>& expectedDetectionBoxes,
const std::vector<float>& expectedDetectionClasses,
const std::vector<float>& expectedDetectionScores,
const std::vector<float>& expectedNumDetections)
@@ -103,6 +105,7 @@ void DetectionPostProcessTestImpl(bool useRegularNms, const std::vector<float>&
0.0f, 1.0f, 0.0f, 0.0f,
0.0f, 0.0f, 0.0f, 0.0f
});
+
std::vector<float> scores({
0.0f, 0.9f, 0.8f,
0.0f, 0.75f, 0.72f,
@@ -111,6 +114,7 @@ void DetectionPostProcessTestImpl(bool useRegularNms, const std::vector<float>&
0.0f, 0.5f, 0.4f,
0.0f, 0.3f, 0.2f
});
+
std::vector<float> anchors({
0.5f, 0.5f, 1.0f, 1.0f,
0.5f, 0.5f, 1.0f, 1.0f,
@@ -120,22 +124,50 @@ void DetectionPostProcessTestImpl(bool useRegularNms, const std::vector<float>&
0.5f, 100.5f, 1.0f, 1.0f
});
+ auto boxEncodingsDecoder = armnn::MakeDecoder<float>(boxEncodingsInfo, boxEncodings.data());
+ auto scoresDecoder = armnn::MakeDecoder<float>(scoresInfo, scores.data());
+ auto anchorsDecoder = armnn::MakeDecoder<float>(anchorsInfo, anchors.data());
+
std::vector<float> detectionBoxes(detectionBoxesInfo.GetNumElements());
std::vector<float> detectionScores(detectionScoresInfo.GetNumElements());
std::vector<float> detectionClasses(detectionClassesInfo.GetNumElements());
std::vector<float> numDetections(1);
- armnn::DetectionPostProcess(boxEncodingsInfo, scoresInfo, anchorsInfo,
- detectionBoxesInfo, detectionClassesInfo,
- detectionScoresInfo, numDetectionInfo, desc,
- boxEncodings.data(), scores.data(), anchors.data(),
- detectionBoxes.data(), detectionClasses.data(),
- detectionScores.data(), numDetections.data());
-
- BOOST_TEST(detectionBoxes == expectedDetectionBoxes);
- BOOST_TEST(detectionScores == expectedDetectionScores);
- BOOST_TEST(detectionClasses == expectedDetectionClasses);
- BOOST_TEST(numDetections == expectedNumDetections);
+ armnn::DetectionPostProcess(boxEncodingsInfo,
+ scoresInfo,
+ anchorsInfo,
+ detectionBoxesInfo,
+ detectionClassesInfo,
+ detectionScoresInfo,
+ numDetectionInfo,
+ desc,
+ *boxEncodingsDecoder,
+ *scoresDecoder,
+ *anchorsDecoder,
+ detectionBoxes.data(),
+ detectionClasses.data(),
+ detectionScores.data(),
+ numDetections.data());
+
+ BOOST_CHECK_EQUAL_COLLECTIONS(detectionBoxes.begin(),
+ detectionBoxes.end(),
+ expectedDetectionBoxes.begin(),
+ expectedDetectionBoxes.end());
+
+ BOOST_CHECK_EQUAL_COLLECTIONS(detectionScores.begin(),
+ detectionScores.end(),
+ expectedDetectionScores.begin(),
+ expectedDetectionScores.end());
+
+ BOOST_CHECK_EQUAL_COLLECTIONS(detectionClasses.begin(),
+ detectionClasses.end(),
+ expectedDetectionClasses.begin(),
+ expectedDetectionClasses.end());
+
+ BOOST_CHECK_EQUAL_COLLECTIONS(numDetections.begin(),
+ numDetections.end(),
+ expectedNumDetections.begin(),
+ expectedNumDetections.end());
}
BOOST_AUTO_TEST_CASE(RegularNmsDetectionPostProcess)