aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAron Virginas-Tar <Aron.Virginas-Tar@arm.com>2019-02-12 11:27:53 +0000
committerAron Virginas-Tar <Aron.Virginas-Tar@arm.com>2019-02-12 11:40:24 +0000
commit64e4ccb4546473e922b4ddd699ff6b77a5c2527d (patch)
tree3c5fbd6be6706d7450919030e7c91355d0b3507a
parent424951560f2948b49506f178352e788cbe680fd8 (diff)
downloadarmnn-64e4ccb4546473e922b4ddd699ff6b77a5c2527d.tar.gz
IVGCVSW-2663 Enable ExecuteNetwork to load ArmNN files
Change-Id: I1a61a1da2258bd07b39da6063d22a5bd22c1884d Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
-rw-r--r--tests/CMakeLists.txt5
-rw-r--r--tests/ExecuteNetwork/ExecuteNetwork.cpp32
-rw-r--r--tests/InferenceModel.hpp50
3 files changed, 77 insertions, 10 deletions
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 1fc89da016..9913321295 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -241,7 +241,7 @@ if (BUILD_ONNX_PARSER)
OnnxParserTest(OnnxMobileNet-Armnn "${OnnxMobileNet-Armnn_sources}")
endif()
-if (BUILD_CAFFE_PARSER OR BUILD_TF_PARSER OR BUILD_TF_LITE_PARSER OR BUILD_ONNX_PARSER)
+if (BUILD_ARMNN_SERIALIZER OR BUILD_CAFFE_PARSER OR BUILD_TF_PARSER OR BUILD_TF_LITE_PARSER OR BUILD_ONNX_PARSER)
set(ExecuteNetwork_sources
ExecuteNetwork/ExecuteNetwork.cpp)
@@ -250,6 +250,9 @@ if (BUILD_CAFFE_PARSER OR BUILD_TF_PARSER OR BUILD_TF_LITE_PARSER OR BUILD_ONNX_
target_include_directories(ExecuteNetwork PRIVATE ../src/armnnUtils)
target_include_directories(ExecuteNetwork PRIVATE ../src/backends)
+ if (BUILD_ARMNN_SERIALIZER)
+ target_link_libraries(ExecuteNetwork armnnSerializer)
+ endif()
if (BUILD_CAFFE_PARSER)
target_link_libraries(ExecuteNetwork armnnCaffeParser)
endif()
diff --git a/tests/ExecuteNetwork/ExecuteNetwork.cpp b/tests/ExecuteNetwork/ExecuteNetwork.cpp
index bbab70b39a..a97d6da3d5 100644
--- a/tests/ExecuteNetwork/ExecuteNetwork.cpp
+++ b/tests/ExecuteNetwork/ExecuteNetwork.cpp
@@ -5,6 +5,9 @@
#include <armnn/ArmNN.hpp>
#include <armnn/TypesUtils.hpp>
+#if defined(ARMNN_SERIALIZER)
+#include "armnnDeserializeParser/IDeserializeParser.hpp"
+#endif
#if defined(ARMNN_CAFFE_PARSER)
#include "armnnCaffeParser/ICaffeParser.hpp"
#endif
@@ -361,7 +364,20 @@ int RunTest(const std::string& format,
}
// Forward to implementation based on the parser type
- if (modelFormat.find("caffe") != std::string::npos)
+ if (modelFormat.find("armnn") != std::string::npos)
+ {
+#if defined(ARMNN_SERIALIZER)
+ return MainImpl<armnnDeserializeParser::IDeserializeParser, float>(
+ modelPath.c_str(), isModelBinary, computeDevice,
+ inputNamesVector, inputTensorShapes,
+ inputTensorDataFilePathsVector, inputTypesVector,
+ outputNamesVector, enableProfiling, subgraphId, runtime);
+#else
+ BOOST_LOG_TRIVIAL(fatal) << "Not built with serialization support.";
+ return EXIT_FAILURE;
+#endif
+ }
+ else if (modelFormat.find("caffe") != std::string::npos)
{
#if defined(ARMNN_CAFFE_PARSER)
return MainImpl<armnnCaffeParser::ICaffeParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
@@ -447,9 +463,10 @@ int RunCsvTest(const armnnUtils::CsvRow &csvRow,
{
desc.add_options()
("model-format,f", po::value(&modelFormat),
- "caffe-binary, caffe-text, tflite-binary, onnx-binary, onnx-text, tensorflow-binary or tensorflow-text.")
- ("model-path,m", po::value(&modelPath), "Path to model file, e.g. .caffemodel, .prototxt, .tflite,"
- " .onnx")
+ "armnn-binary, caffe-binary, caffe-text, tflite-binary, onnx-binary, onnx-text, tensorflow-binary or "
+ "tensorflow-text.")
+ ("model-path,m", po::value(&modelPath), "Path to model file, e.g. .armnn, .caffemodel, .prototxt, "
+ ".tflite, .onnx")
("compute,c", po::value<std::vector<armnn::BackendId>>()->multitoken(),
backendsMessage.c_str())
("input-name,i", po::value(&inputNames), "Identifier of the input tensors in the network separated by comma.")
@@ -557,9 +574,10 @@ int main(int argc, const char* argv[])
("concurrent,n", po::bool_switch()->default_value(false),
"Whether or not the test cases should be executed in parallel")
("model-format,f", po::value(&modelFormat)->required(),
- "caffe-binary, caffe-text, onnx-binary, onnx-text, tflite-binary, tensorflow-binary or tensorflow-text.")
- ("model-path,m", po::value(&modelPath)->required(), "Path to model file, e.g. .caffemodel, .prototxt,"
- " .tflite, .onnx")
+ "armnn-binary, caffe-binary, caffe-text, onnx-binary, onnx-text, tflite-binary, tensorflow-binary or "
+ "tensorflow-text.")
+ ("model-path,m", po::value(&modelPath)->required(), "Path to model file, e.g. .armnn, .caffemodel, "
+ ".prototxt, .tflite, .onnx")
("compute,c", po::value<std::vector<std::string>>()->multitoken(),
backendsMessage.c_str())
("input-name,i", po::value(&inputNames),
diff --git a/tests/InferenceModel.hpp b/tests/InferenceModel.hpp
index eb5f708c81..eb3b2ccd42 100644
--- a/tests/InferenceModel.hpp
+++ b/tests/InferenceModel.hpp
@@ -5,15 +5,18 @@
#pragma once
#include <armnn/ArmNN.hpp>
+#if defined(ARMNN_SERIALIZER)
+#include "armnnDeserializeParser/IDeserializeParser.hpp"
+#endif
#if defined(ARMNN_TF_LITE_PARSER)
#include <armnnTfLiteParser/ITfLiteParser.hpp>
#endif
-
-#include <HeapProfiling.hpp>
#if defined(ARMNN_ONNX_PARSER)
#include <armnnOnnxParser/IOnnxParser.hpp>
#endif
+#include <HeapProfiling.hpp>
+
#include <backendsCommon/BackendRegistry.hpp>
#include <boost/algorithm/string/join.hpp>
@@ -160,6 +163,49 @@ public:
}
};
+#if defined(ARMNN_SERIALIZER)
+template <>
+struct CreateNetworkImpl<armnnDeserializeParser::IDeserializeParser>
+{
+public:
+ using IParser = armnnDeserializeParser::IDeserializeParser;
+ using Params = InferenceModelInternal::Params;
+ using BindingPointInfo = InferenceModelInternal::BindingPointInfo;
+
+ static armnn::INetworkPtr Create(const Params& params,
+ std::vector<BindingPointInfo>& inputBindings,
+ std::vector<BindingPointInfo>& outputBindings)
+ {
+ auto parser(IParser::Create());
+ BOOST_ASSERT(parser);
+
+ armnn::INetworkPtr network{nullptr, [](armnn::INetwork *){}};
+
+ {
+ ARMNN_SCOPED_HEAP_PROFILING("Parsing");
+ const std::string& modelPath = params.m_ModelPath;
+ network = parser->CreateNetworkFromBinaryFile(modelPath.c_str());
+ }
+
+ unsigned int subGraphId = boost::numeric_cast<unsigned int>(params.m_SubgraphId);
+
+ for (const std::string& inputLayerName : params.m_InputBindings)
+ {
+ BindingPointInfo inputBinding = parser->GetNetworkInputBindingInfo(subGraphId, inputLayerName);
+ inputBindings.push_back(inputBinding);
+ }
+
+ for (const std::string& outputLayerName : params.m_OutputBindings)
+ {
+ BindingPointInfo outputBinding = parser->GetNetworkOutputBindingInfo(subGraphId, outputLayerName);
+ outputBindings.push_back(outputBinding);
+ }
+
+ return network;
+ }
+};
+#endif
+
#if defined(ARMNN_TF_LITE_PARSER)
template <>
struct CreateNetworkImpl<armnnTfLiteParser::ITfLiteParser>