diff options
author | keidav01 <keith.davis@arm.com> | 2019-03-14 17:12:10 +0000 |
---|---|---|
committer | keidav01 <keith.davis@arm.com> | 2019-03-14 17:12:10 +0000 |
commit | 222c753ba83bec5dc915f01d999ff76826ec45d0 (patch) | |
tree | 0e70783e23442d1ee5cb79afcb240fc839eaba9d /src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp | |
parent | 232cfc2a544c874b35cc2a949e9fa71f41690565 (diff) | |
download | armnn-222c753ba83bec5dc915f01d999ff76826ec45d0.tar.gz |
IVGCVSW-2429 Add Detection PostProcess Parser to TensorFlow Lite Parser
* Add helper function to generate custom data for detectPostProcess
* Test helper function within current test suite
Change-Id: I9e66d0a28d69b1376da67723f03b112d17e97281
Signed-off-by: keidav01 <keith.davis@arm.com>
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Diffstat (limited to 'src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp')
-rw-r--r-- | src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp | 37 |
1 files changed, 34 insertions, 3 deletions
diff --git a/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp b/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp index 50e674ef2c..9eb2f2b93d 100644 --- a/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp +++ b/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp @@ -5,13 +5,16 @@ #pragma once +#include <armnn/Descriptors.hpp> +#include <armnn/IRuntime.hpp> +#include <armnn/TypesUtils.hpp> + #include "Schema.hpp" #include <boost/filesystem.hpp> #include <boost/assert.hpp> #include <boost/format.hpp> #include <experimental/filesystem> -#include <armnn/IRuntime.hpp> -#include <armnn/TypesUtils.hpp> + #include "test/TensorHelpers.hpp" #include "TypeUtils.hpp" @@ -21,6 +24,7 @@ #include "flatbuffers/idl.h" #include "flatbuffers/util.h" +#include "flatbuffers/flexbuffers.h" #include <schema_generated.h> #include <iostream> @@ -159,6 +163,33 @@ struct ParserFlatbuffersFixture const std::map<std::string, std::vector<DataType1>>& inputData, const std::map<std::string, std::vector<DataType2>>& expectedOutputData); + static inline std::string GenerateDetectionPostProcessJsonString( + const armnn::DetectionPostProcessDescriptor& descriptor) + { + flexbuffers::Builder detectPostProcess; + detectPostProcess.Map([&]() { + detectPostProcess.Bool("use_regular_nms", descriptor.m_UseRegularNms); + detectPostProcess.Int("max_detections", descriptor.m_MaxDetections); + detectPostProcess.Int("max_classes_per_detection", descriptor.m_MaxClassesPerDetection); + detectPostProcess.Int("detections_per_class", descriptor.m_DetectionsPerClass); + detectPostProcess.Int("num_classes", descriptor.m_NumClasses); + detectPostProcess.Float("nms_score_threshold", descriptor.m_NmsScoreThreshold); + detectPostProcess.Float("nms_iou_threshold", descriptor.m_NmsIouThreshold); + detectPostProcess.Float("h_scale", descriptor.m_ScaleH); + detectPostProcess.Float("w_scale", descriptor.m_ScaleW); + detectPostProcess.Float("x_scale", descriptor.m_ScaleX); + detectPostProcess.Float("y_scale", descriptor.m_ScaleY); + }); + detectPostProcess.Finish(); + + // Create JSON string + std::stringstream strStream; + std::vector<uint8_t> buffer = detectPostProcess.GetBuffer(); + std::copy(buffer.begin(), buffer.end(),std::ostream_iterator<int>(strStream,",")); + + return strStream.str(); + } + void CheckTensors(const TensorRawPtr& tensors, size_t shapeSize, const std::vector<int32_t>& shape, tflite::TensorType tensorType, uint32_t buffer, const std::string& name, const std::vector<float>& min, const std::vector<float>& max, @@ -310,4 +341,4 @@ void ParserFlatbuffersFixture::RunTest(std::size_t subgraphId, } } } -} +}
\ No newline at end of file |