From 222c753ba83bec5dc915f01d999ff76826ec45d0 Mon Sep 17 00:00:00 2001 From: keidav01 Date: Thu, 14 Mar 2019 17:12:10 +0000 Subject: IVGCVSW-2429 Add Detection PostProcess Parser to TensorFlow Lite Parser * Add helper function to generate custom data for detectPostProcess * Test helper function within current test suite Change-Id: I9e66d0a28d69b1376da67723f03b112d17e97281 Signed-off-by: keidav01 Signed-off-by: Aron Virginas-Tar --- .../test/DetectionPostProcess.cpp | 50 +++++++++++++++------- .../test/ParserFlatbuffersFixture.hpp | 37 ++++++++++++++-- 2 files changed, 68 insertions(+), 19 deletions(-) (limited to 'src') diff --git a/src/armnnTfLiteParser/test/DetectionPostProcess.cpp b/src/armnnTfLiteParser/test/DetectionPostProcess.cpp index 3002885016..638238db57 100644 --- a/src/armnnTfLiteParser/test/DetectionPostProcess.cpp +++ b/src/armnnTfLiteParser/test/DetectionPostProcess.cpp @@ -10,17 +10,19 @@ #include "ParserFlatbuffersFixture.hpp" #include "ParserPrototxtFixture.hpp" +#include "ParserHelper.hpp" BOOST_AUTO_TEST_SUITE(TensorflowLiteParser) struct DetectionPostProcessFixture : ParserFlatbuffersFixture { - explicit DetectionPostProcessFixture() + explicit DetectionPostProcessFixture(const std::string& custom_options) { /* The following values were used for the custom_options: use_regular_nms = true max_classes_per_detection = 1 + detections_per_class = 1 nms_score_threshold = 0.0 nms_iou_threshold = 0.5 max_detections = 3 @@ -107,19 +109,7 @@ struct DetectionPostProcessFixture : ParserFlatbuffersFixture "inputs": [0, 1, 2], "outputs": [3, 4, 5, 6], "builtin_options_type": 0, - "custom_options": [ - 109, 97, 120, 95, 100, 101, 116, 101, 99, 116, 105, 111, 110, 115, 0, 109, 97, 120, - 95, 99, 108, 97, 115, 115, 101, 115, 95, 112, 101, 114, 95, 100, 101, 116, 101, 99, - 116, 105, 111, 110, 0, 110, 109, 115, 95, 115, 99, 111, 114, 101, 95, 116, 104, 114, - 101, 115, 104, 111, 108, 100, 0, 110, 109, 115, 95, 105, 111, 117, 95, 116, 104, 114, - 101, 115, 104, 111, 108, 100, 0, 110, 117, 109, 95, 99, 108, 97, 115, 115, 101, 115, - 0, 104, 95, 115, 99, 97, 108, 101, 0, 119, 95, 115, 99, 97, 108, 101, 0, 120, 95, 115, - 99, 97, 108, 101, 0, 121, 95, 115, 99, 97, 108, 101, 0, 117, 115, 101, 95, 114, 101, - 103, 117, 108, 97, 114, 95, 110, 109, 115, 0, 10, 49, 126, 142, 82, 103, 66, 23, 48, - 41, 34, 0, 0, 12, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 0, 0, 160, 64, 1, 0, 0, 0, 3, 0, - 0, 0, 0, 0, 0, 63, 0, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 160, 64, 0, 0, 32, 65, 0, - 0, 32, 65, 14, 6, 6, 14, 14, 6, 106, 14, 14, 14, 50, 38, 1 - ], + "custom_options": [)" + custom_options + R"(], "custom_options_format": "FLEXBUFFERS" }] }], @@ -141,7 +131,35 @@ struct DetectionPostProcessFixture : ParserFlatbuffersFixture } }; -BOOST_FIXTURE_TEST_CASE( ParseDetectionPostProcess, DetectionPostProcessFixture ) +struct ParseDetectionPostProcessCustomOptions : DetectionPostProcessFixture +{ +private: + static armnn::DetectionPostProcessDescriptor GenerateDescriptor() + { + static armnn::DetectionPostProcessDescriptor descriptor; + descriptor.m_UseRegularNms = true; + descriptor.m_MaxDetections = 3u; + descriptor.m_MaxClassesPerDetection = 1u; + descriptor.m_DetectionsPerClass = 1u; + descriptor.m_NumClasses = 2u; + descriptor.m_NmsScoreThreshold = 0.0f; + descriptor.m_NmsIouThreshold = 0.5f; + descriptor.m_ScaleH = 5.0f; + descriptor.m_ScaleW = 5.0f; + descriptor.m_ScaleX = 10.0f; + descriptor.m_ScaleY = 10.0f; + + return descriptor; + } + +public: + ParseDetectionPostProcessCustomOptions() + : DetectionPostProcessFixture( + GenerateDetectionPostProcessJsonString(GenerateDescriptor())) + {} +}; + +BOOST_FIXTURE_TEST_CASE( ParseDetectionPostProcess, ParseDetectionPostProcessCustomOptions ) { Setup(); @@ -202,7 +220,7 @@ BOOST_FIXTURE_TEST_CASE( ParseDetectionPostProcess, DetectionPostProcessFixture RunTest(0, input, output); } -BOOST_FIXTURE_TEST_CASE(DetectionPostProcessGraphStructureTest, DetectionPostProcessFixture) +BOOST_FIXTURE_TEST_CASE(DetectionPostProcessGraphStructureTest, ParseDetectionPostProcessCustomOptions) { /* Inputs: box_encodings scores diff --git a/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp b/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp index 50e674ef2c..9eb2f2b93d 100644 --- a/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp +++ b/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp @@ -5,13 +5,16 @@ #pragma once +#include +#include +#include + #include "Schema.hpp" #include #include #include #include -#include -#include + #include "test/TensorHelpers.hpp" #include "TypeUtils.hpp" @@ -21,6 +24,7 @@ #include "flatbuffers/idl.h" #include "flatbuffers/util.h" +#include "flatbuffers/flexbuffers.h" #include #include @@ -159,6 +163,33 @@ struct ParserFlatbuffersFixture const std::map>& inputData, const std::map>& expectedOutputData); + static inline std::string GenerateDetectionPostProcessJsonString( + const armnn::DetectionPostProcessDescriptor& descriptor) + { + flexbuffers::Builder detectPostProcess; + detectPostProcess.Map([&]() { + detectPostProcess.Bool("use_regular_nms", descriptor.m_UseRegularNms); + detectPostProcess.Int("max_detections", descriptor.m_MaxDetections); + detectPostProcess.Int("max_classes_per_detection", descriptor.m_MaxClassesPerDetection); + detectPostProcess.Int("detections_per_class", descriptor.m_DetectionsPerClass); + detectPostProcess.Int("num_classes", descriptor.m_NumClasses); + detectPostProcess.Float("nms_score_threshold", descriptor.m_NmsScoreThreshold); + detectPostProcess.Float("nms_iou_threshold", descriptor.m_NmsIouThreshold); + detectPostProcess.Float("h_scale", descriptor.m_ScaleH); + detectPostProcess.Float("w_scale", descriptor.m_ScaleW); + detectPostProcess.Float("x_scale", descriptor.m_ScaleX); + detectPostProcess.Float("y_scale", descriptor.m_ScaleY); + }); + detectPostProcess.Finish(); + + // Create JSON string + std::stringstream strStream; + std::vector buffer = detectPostProcess.GetBuffer(); + std::copy(buffer.begin(), buffer.end(),std::ostream_iterator(strStream,",")); + + return strStream.str(); + } + void CheckTensors(const TensorRawPtr& tensors, size_t shapeSize, const std::vector& shape, tflite::TensorType tensorType, uint32_t buffer, const std::string& name, const std::vector& min, const std::vector& max, @@ -310,4 +341,4 @@ void ParserFlatbuffersFixture::RunTest(std::size_t subgraphId, } } } -} +} \ No newline at end of file -- cgit v1.2.1