From 6331f91a4a1cb1ad16c569d98bb9ddf704788464 Mon Sep 17 00:00:00 2001 From: Aron Virginas-Tar Date: Mon, 3 Jun 2019 17:10:02 +0100 Subject: IVGCVSW-2971 Support QSymm16 for DetectionPostProcess workloads Signed-off-by: Aron Virginas-Tar Change-Id: I8af45afe851a9ccbf8bce54727147fcd52ac9a1f --- .../test/RefDetectionPostProcessTests.cpp | 68 ++++++++++++++++------ src/backends/reference/test/RefLayerTests.cpp | 16 ++++- 2 files changed, 64 insertions(+), 20 deletions(-) (limited to 'src/backends/reference/test') diff --git a/src/backends/reference/test/RefDetectionPostProcessTests.cpp b/src/backends/reference/test/RefDetectionPostProcessTests.cpp index a9faff70b1..fab6e00bad 100644 --- a/src/backends/reference/test/RefDetectionPostProcessTests.cpp +++ b/src/backends/reference/test/RefDetectionPostProcessTests.cpp @@ -3,7 +3,7 @@ // SPDX-License-Identifier: MIT // -#include "reference/workloads/DetectionPostProcess.cpp" +#include #include #include @@ -12,13 +12,12 @@ BOOST_AUTO_TEST_SUITE(RefDetectionPostProcess) - BOOST_AUTO_TEST_CASE(TopKSortTest) { unsigned int k = 3; unsigned int indices[8] = { 0, 1, 2, 3, 4, 5, 6, 7 }; float values[8] = { 0, 7, 6, 5, 4, 3, 2, 500 }; - TopKSort(k, indices, values, 8); + armnn::TopKSort(k, indices, values, 8); BOOST_TEST(indices[0] == 7); BOOST_TEST(indices[1] == 1); BOOST_TEST(indices[2] == 2); @@ -29,7 +28,7 @@ BOOST_AUTO_TEST_CASE(FullTopKSortTest) unsigned int k = 8; unsigned int indices[8] = { 0, 1, 2, 3, 4, 5, 6, 7 }; float values[8] = { 0, 7, 6, 5, 4, 3, 2, 500 }; - TopKSort(k, indices, values, 8); + armnn::TopKSort(k, indices, values, 8); BOOST_TEST(indices[0] == 7); BOOST_TEST(indices[1] == 1); BOOST_TEST(indices[2] == 2); @@ -44,7 +43,7 @@ BOOST_AUTO_TEST_CASE(IouTest) { float boxI[4] = { 0.0f, 0.0f, 10.0f, 10.0f }; float boxJ[4] = { 1.0f, 1.0f, 11.0f, 11.0f }; - float iou = IntersectionOverUnion(boxI, boxJ); + float iou = armnn::IntersectionOverUnion(boxI, boxJ); BOOST_TEST(iou == 0.68, boost::test_tools::tolerance(0.001)); } @@ -61,14 +60,17 @@ BOOST_AUTO_TEST_CASE(NmsFunction) std::vector scores({ 0.9f, 0.75f, 0.6f, 0.93f, 0.5f, 0.3f }); - std::vector result = NonMaxSuppression(6, boxCorners, scores, 0.0, 3, 0.5); + std::vector result = + armnn::NonMaxSuppression(6, boxCorners, scores, 0.0, 3, 0.5); + BOOST_TEST(result.size() == 3); BOOST_TEST(result[0] == 3); BOOST_TEST(result[1] == 0); BOOST_TEST(result[2] == 5); } -void DetectionPostProcessTestImpl(bool useRegularNms, const std::vector& expectedDetectionBoxes, +void DetectionPostProcessTestImpl(bool useRegularNms, + const std::vector& expectedDetectionBoxes, const std::vector& expectedDetectionClasses, const std::vector& expectedDetectionScores, const std::vector& expectedNumDetections) @@ -103,6 +105,7 @@ void DetectionPostProcessTestImpl(bool useRegularNms, const std::vector& 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f }); + std::vector scores({ 0.0f, 0.9f, 0.8f, 0.0f, 0.75f, 0.72f, @@ -111,6 +114,7 @@ void DetectionPostProcessTestImpl(bool useRegularNms, const std::vector& 0.0f, 0.5f, 0.4f, 0.0f, 0.3f, 0.2f }); + std::vector anchors({ 0.5f, 0.5f, 1.0f, 1.0f, 0.5f, 0.5f, 1.0f, 1.0f, @@ -120,22 +124,50 @@ void DetectionPostProcessTestImpl(bool useRegularNms, const std::vector& 0.5f, 100.5f, 1.0f, 1.0f }); + auto boxEncodingsDecoder = armnn::MakeDecoder(boxEncodingsInfo, boxEncodings.data()); + auto scoresDecoder = armnn::MakeDecoder(scoresInfo, scores.data()); + auto anchorsDecoder = armnn::MakeDecoder(anchorsInfo, anchors.data()); + std::vector detectionBoxes(detectionBoxesInfo.GetNumElements()); std::vector detectionScores(detectionScoresInfo.GetNumElements()); std::vector detectionClasses(detectionClassesInfo.GetNumElements()); std::vector numDetections(1); - armnn::DetectionPostProcess(boxEncodingsInfo, scoresInfo, anchorsInfo, - detectionBoxesInfo, detectionClassesInfo, - detectionScoresInfo, numDetectionInfo, desc, - boxEncodings.data(), scores.data(), anchors.data(), - detectionBoxes.data(), detectionClasses.data(), - detectionScores.data(), numDetections.data()); - - BOOST_TEST(detectionBoxes == expectedDetectionBoxes); - BOOST_TEST(detectionScores == expectedDetectionScores); - BOOST_TEST(detectionClasses == expectedDetectionClasses); - BOOST_TEST(numDetections == expectedNumDetections); + armnn::DetectionPostProcess(boxEncodingsInfo, + scoresInfo, + anchorsInfo, + detectionBoxesInfo, + detectionClassesInfo, + detectionScoresInfo, + numDetectionInfo, + desc, + *boxEncodingsDecoder, + *scoresDecoder, + *anchorsDecoder, + detectionBoxes.data(), + detectionClasses.data(), + detectionScores.data(), + numDetections.data()); + + BOOST_CHECK_EQUAL_COLLECTIONS(detectionBoxes.begin(), + detectionBoxes.end(), + expectedDetectionBoxes.begin(), + expectedDetectionBoxes.end()); + + BOOST_CHECK_EQUAL_COLLECTIONS(detectionScores.begin(), + detectionScores.end(), + expectedDetectionScores.begin(), + expectedDetectionScores.end()); + + BOOST_CHECK_EQUAL_COLLECTIONS(detectionClasses.begin(), + detectionClasses.end(), + expectedDetectionClasses.begin(), + expectedDetectionClasses.end()); + + BOOST_CHECK_EQUAL_COLLECTIONS(numDetections.begin(), + numDetections.end(), + expectedNumDetections.begin(), + expectedNumDetections.end()); } BOOST_AUTO_TEST_CASE(RegularNmsDetectionPostProcess) diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp index b2f71a8920..f54a8d067c 100644 --- a/src/backends/reference/test/RefLayerTests.cpp +++ b/src/backends/reference/test/RefLayerTests.cpp @@ -624,11 +624,23 @@ BOOST_AUTO_TEST_CASE(DetectionPostProcessFastNmsFloat) } BOOST_AUTO_TEST_CASE(DetectionPostProcessRegularNmsUint8) { - DetectionPostProcessRegularNmsUint8Test(); + DetectionPostProcessRegularNmsQuantizedTest< + armnn::RefWorkloadFactory, armnn::DataType::QuantisedAsymm8>(); } BOOST_AUTO_TEST_CASE(DetectionPostProcessFastNmsUint8) { - DetectionPostProcessFastNmsUint8Test(); + DetectionPostProcessRegularNmsQuantizedTest< + armnn::RefWorkloadFactory, armnn::DataType::QuantisedAsymm8>(); +} +BOOST_AUTO_TEST_CASE(DetectionPostProcessRegularNmsInt16) +{ + DetectionPostProcessRegularNmsQuantizedTest< + armnn::RefWorkloadFactory, armnn::DataType::QuantisedSymm16>(); +} +BOOST_AUTO_TEST_CASE(DetectionPostProcessFastNmsInt16) +{ + DetectionPostProcessFastNmsQuantizedTest< + armnn::RefWorkloadFactory, armnn::DataType::QuantisedSymm16>(); } // Dequantize -- cgit v1.2.1