aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser
diff options
context:
space:
mode:
authorkeidav01 <keith.davis@arm.com>2019-03-14 17:12:10 +0000
committerkeidav01 <keith.davis@arm.com>2019-03-14 17:12:10 +0000
commit222c753ba83bec5dc915f01d999ff76826ec45d0 (patch)
tree0e70783e23442d1ee5cb79afcb240fc839eaba9d /src/armnnTfLiteParser
parent232cfc2a544c874b35cc2a949e9fa71f41690565 (diff)
downloadarmnn-222c753ba83bec5dc915f01d999ff76826ec45d0.tar.gz
IVGCVSW-2429 Add Detection PostProcess Parser to TensorFlow Lite Parser
* Add helper function to generate custom data for detectPostProcess * Test helper function within current test suite Change-Id: I9e66d0a28d69b1376da67723f03b112d17e97281 Signed-off-by: keidav01 <keith.davis@arm.com> Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Diffstat (limited to 'src/armnnTfLiteParser')
-rw-r--r--src/armnnTfLiteParser/test/DetectionPostProcess.cpp50
-rw-r--r--src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp37
2 files changed, 68 insertions, 19 deletions
diff --git a/src/armnnTfLiteParser/test/DetectionPostProcess.cpp b/src/armnnTfLiteParser/test/DetectionPostProcess.cpp
index 3002885016..638238db57 100644
--- a/src/armnnTfLiteParser/test/DetectionPostProcess.cpp
+++ b/src/armnnTfLiteParser/test/DetectionPostProcess.cpp
@@ -10,17 +10,19 @@
#include "ParserFlatbuffersFixture.hpp"
#include "ParserPrototxtFixture.hpp"
+#include "ParserHelper.hpp"
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
struct DetectionPostProcessFixture : ParserFlatbuffersFixture
{
- explicit DetectionPostProcessFixture()
+ explicit DetectionPostProcessFixture(const std::string& custom_options)
{
/*
The following values were used for the custom_options:
use_regular_nms = true
max_classes_per_detection = 1
+ detections_per_class = 1
nms_score_threshold = 0.0
nms_iou_threshold = 0.5
max_detections = 3
@@ -107,19 +109,7 @@ struct DetectionPostProcessFixture : ParserFlatbuffersFixture
"inputs": [0, 1, 2],
"outputs": [3, 4, 5, 6],
"builtin_options_type": 0,
- "custom_options": [
- 109, 97, 120, 95, 100, 101, 116, 101, 99, 116, 105, 111, 110, 115, 0, 109, 97, 120,
- 95, 99, 108, 97, 115, 115, 101, 115, 95, 112, 101, 114, 95, 100, 101, 116, 101, 99,
- 116, 105, 111, 110, 0, 110, 109, 115, 95, 115, 99, 111, 114, 101, 95, 116, 104, 114,
- 101, 115, 104, 111, 108, 100, 0, 110, 109, 115, 95, 105, 111, 117, 95, 116, 104, 114,
- 101, 115, 104, 111, 108, 100, 0, 110, 117, 109, 95, 99, 108, 97, 115, 115, 101, 115,
- 0, 104, 95, 115, 99, 97, 108, 101, 0, 119, 95, 115, 99, 97, 108, 101, 0, 120, 95, 115,
- 99, 97, 108, 101, 0, 121, 95, 115, 99, 97, 108, 101, 0, 117, 115, 101, 95, 114, 101,
- 103, 117, 108, 97, 114, 95, 110, 109, 115, 0, 10, 49, 126, 142, 82, 103, 66, 23, 48,
- 41, 34, 0, 0, 12, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 0, 0, 160, 64, 1, 0, 0, 0, 3, 0,
- 0, 0, 0, 0, 0, 63, 0, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 160, 64, 0, 0, 32, 65, 0,
- 0, 32, 65, 14, 6, 6, 14, 14, 6, 106, 14, 14, 14, 50, 38, 1
- ],
+ "custom_options": [)" + custom_options + R"(],
"custom_options_format": "FLEXBUFFERS"
}]
}],
@@ -141,7 +131,35 @@ struct DetectionPostProcessFixture : ParserFlatbuffersFixture
}
};
-BOOST_FIXTURE_TEST_CASE( ParseDetectionPostProcess, DetectionPostProcessFixture )
+struct ParseDetectionPostProcessCustomOptions : DetectionPostProcessFixture
+{
+private:
+ static armnn::DetectionPostProcessDescriptor GenerateDescriptor()
+ {
+ static armnn::DetectionPostProcessDescriptor descriptor;
+ descriptor.m_UseRegularNms = true;
+ descriptor.m_MaxDetections = 3u;
+ descriptor.m_MaxClassesPerDetection = 1u;
+ descriptor.m_DetectionsPerClass = 1u;
+ descriptor.m_NumClasses = 2u;
+ descriptor.m_NmsScoreThreshold = 0.0f;
+ descriptor.m_NmsIouThreshold = 0.5f;
+ descriptor.m_ScaleH = 5.0f;
+ descriptor.m_ScaleW = 5.0f;
+ descriptor.m_ScaleX = 10.0f;
+ descriptor.m_ScaleY = 10.0f;
+
+ return descriptor;
+ }
+
+public:
+ ParseDetectionPostProcessCustomOptions()
+ : DetectionPostProcessFixture(
+ GenerateDetectionPostProcessJsonString(GenerateDescriptor()))
+ {}
+};
+
+BOOST_FIXTURE_TEST_CASE( ParseDetectionPostProcess, ParseDetectionPostProcessCustomOptions )
{
Setup();
@@ -202,7 +220,7 @@ BOOST_FIXTURE_TEST_CASE( ParseDetectionPostProcess, DetectionPostProcessFixture
RunTest<armnn::DataType::QuantisedAsymm8, armnn::DataType::Float32>(0, input, output);
}
-BOOST_FIXTURE_TEST_CASE(DetectionPostProcessGraphStructureTest, DetectionPostProcessFixture)
+BOOST_FIXTURE_TEST_CASE(DetectionPostProcessGraphStructureTest, ParseDetectionPostProcessCustomOptions)
{
/*
Inputs: box_encodings scores
diff --git a/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp b/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp
index 50e674ef2c..9eb2f2b93d 100644
--- a/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp
+++ b/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp
@@ -5,13 +5,16 @@
#pragma once
+#include <armnn/Descriptors.hpp>
+#include <armnn/IRuntime.hpp>
+#include <armnn/TypesUtils.hpp>
+
#include "Schema.hpp"
#include <boost/filesystem.hpp>
#include <boost/assert.hpp>
#include <boost/format.hpp>
#include <experimental/filesystem>
-#include <armnn/IRuntime.hpp>
-#include <armnn/TypesUtils.hpp>
+
#include "test/TensorHelpers.hpp"
#include "TypeUtils.hpp"
@@ -21,6 +24,7 @@
#include "flatbuffers/idl.h"
#include "flatbuffers/util.h"
+#include "flatbuffers/flexbuffers.h"
#include <schema_generated.h>
#include <iostream>
@@ -159,6 +163,33 @@ struct ParserFlatbuffersFixture
const std::map<std::string, std::vector<DataType1>>& inputData,
const std::map<std::string, std::vector<DataType2>>& expectedOutputData);
+ static inline std::string GenerateDetectionPostProcessJsonString(
+ const armnn::DetectionPostProcessDescriptor& descriptor)
+ {
+ flexbuffers::Builder detectPostProcess;
+ detectPostProcess.Map([&]() {
+ detectPostProcess.Bool("use_regular_nms", descriptor.m_UseRegularNms);
+ detectPostProcess.Int("max_detections", descriptor.m_MaxDetections);
+ detectPostProcess.Int("max_classes_per_detection", descriptor.m_MaxClassesPerDetection);
+ detectPostProcess.Int("detections_per_class", descriptor.m_DetectionsPerClass);
+ detectPostProcess.Int("num_classes", descriptor.m_NumClasses);
+ detectPostProcess.Float("nms_score_threshold", descriptor.m_NmsScoreThreshold);
+ detectPostProcess.Float("nms_iou_threshold", descriptor.m_NmsIouThreshold);
+ detectPostProcess.Float("h_scale", descriptor.m_ScaleH);
+ detectPostProcess.Float("w_scale", descriptor.m_ScaleW);
+ detectPostProcess.Float("x_scale", descriptor.m_ScaleX);
+ detectPostProcess.Float("y_scale", descriptor.m_ScaleY);
+ });
+ detectPostProcess.Finish();
+
+ // Create JSON string
+ std::stringstream strStream;
+ std::vector<uint8_t> buffer = detectPostProcess.GetBuffer();
+ std::copy(buffer.begin(), buffer.end(),std::ostream_iterator<int>(strStream,","));
+
+ return strStream.str();
+ }
+
void CheckTensors(const TensorRawPtr& tensors, size_t shapeSize, const std::vector<int32_t>& shape,
tflite::TensorType tensorType, uint32_t buffer, const std::string& name,
const std::vector<float>& min, const std::vector<float>& max,
@@ -310,4 +341,4 @@ void ParserFlatbuffersFixture::RunTest(std::size_t subgraphId,
}
}
}
-}
+} \ No newline at end of file