aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBruno Goncalves <bruno.slackware@gmail.com>2018-12-28 10:08:26 -0200
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-02-08 08:54:35 +0000
commit61980d472006abdf3778d23903fb3bec5916f1f2 (patch)
tree27e452ad331630c54c07ac4637259cad9d5cdc53
parenta00a4ec08d998d8ec0cfc3c0bf4788f0d6a99693 (diff)
downloadarmnn-61980d472006abdf3778d23903fb3bec5916f1f2.tar.gz
Added TfLiteParser test for MnasNet
Change-Id: Ie31eee48cc14ada37526130998da7a482d56b1ea
-rw-r--r--tests/CMakeLists.txt6
-rw-r--r--tests/TfLiteMnasNet-Armnn/TfLiteMnasNet-Armnn.cpp58
2 files changed, 64 insertions, 0 deletions
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index de40137af7..1fc89da016 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -200,6 +200,12 @@ if (BUILD_TF_LITE_PARSER)
ImagePreprocessor.hpp
ImagePreprocessor.cpp)
TfLiteParserTest(TfLiteResNetV2-Armnn "${TfLiteResNetV2-Armnn_sources}")
+
+ set(TfLiteMnasNet-Armnn_sources
+ TfLiteMnasNet-Armnn/TfLiteMnasNet-Armnn.cpp
+ ImagePreprocessor.hpp
+ ImagePreprocessor.cpp)
+ TfLiteParserTest(TfLiteMnasNet-Armnn "${TfLiteMnasNet-Armnn_sources}")
endif()
if (BUILD_ONNX_PARSER)
diff --git a/tests/TfLiteMnasNet-Armnn/TfLiteMnasNet-Armnn.cpp b/tests/TfLiteMnasNet-Armnn/TfLiteMnasNet-Armnn.cpp
new file mode 100644
index 0000000000..c676cd7355
--- /dev/null
+++ b/tests/TfLiteMnasNet-Armnn/TfLiteMnasNet-Armnn.cpp
@@ -0,0 +1,58 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#include "../InferenceTest.hpp"
+#include "../ImagePreprocessor.hpp"
+#include "armnnTfLiteParser/ITfLiteParser.hpp"
+
+using namespace armnnTfLiteParser;
+
+int main(int argc, char* argv[])
+{
+ int retVal = EXIT_FAILURE;
+ try
+ {
+ // Coverity fix: The following code may throw an exception of type std::length_error.
+ std::vector<ImageSet> imageSet =
+ {
+ {"Dog.jpg", 209},
+ {"Cat.jpg", 283},
+ {"shark.jpg", 3},
+ };
+
+ armnn::TensorShape inputTensorShape({ 2, 224, 224, 3 });
+
+ using DataType = float;
+ using DatabaseType = ImagePreprocessor<DataType>;
+ using ParserType = armnnTfLiteParser::ITfLiteParser;
+ using ModelType = InferenceModel<ParserType, DataType>;
+
+ // Coverity fix: ClassifierInferenceTestMain() may throw uncaught exceptions.
+ retVal = armnn::test::ClassifierInferenceTestMain<DatabaseType,
+ ParserType>(
+ argc, argv,
+ "mnasnet_1.3_224.tflite", // model name
+ true, // model is binary
+ "input", // input tensor name
+ "output", // output tensor name
+ { 0, 1, 2 }, // test images to test with as above
+ [&imageSet](const char* dataDir, const ModelType & model) {
+ return DatabaseType(
+ dataDir,
+ 224,
+ 224,
+ imageSet);
+ },
+ &inputTensorShape);
+ }
+ catch (const std::exception& e)
+ {
+ // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
+ // exception of type std::length_error.
+ // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
+ std::cerr << "WARNING: " << *argv << ": An error has occurred when running "
+ "the classifier inference tests: " << e.what() << std::endl;
+ }
+ return retVal;
+}