From 64e4ccb4546473e922b4ddd699ff6b77a5c2527d Mon Sep 17 00:00:00 2001
From: Aron Virginas-Tar
Date: Tue, 12 Feb 2019 11:27:53 +0000
Subject: IVGCVSW-2663 Enable ExecuteNetwork to load ArmNN files

Change-Id: I1a61a1da2258bd07b39da6063d22a5bd22c1884d
Signed-off-by: Aron Virginas-Tar
---
 tests/CMakeLists.txt                    |  5 +++-
 tests/ExecuteNetwork/ExecuteNetwork.cpp | 32 ++++++++++++++++-----
 tests/InferenceModel.hpp                | 50 +++++++++++++++++++++++++++++++--
 3 files changed, 77 insertions(+), 10 deletions(-)

diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 1fc89da016..9913321295 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -241,7 +241,7 @@ if (BUILD_ONNX_PARSER)
     OnnxParserTest(OnnxMobileNet-Armnn "${OnnxMobileNet-Armnn_sources}")
 endif()
 
-if (BUILD_CAFFE_PARSER OR BUILD_TF_PARSER OR BUILD_TF_LITE_PARSER OR BUILD_ONNX_PARSER)
+if (BUILD_ARMNN_SERIALIZER OR BUILD_CAFFE_PARSER OR BUILD_TF_PARSER OR BUILD_TF_LITE_PARSER OR BUILD_ONNX_PARSER)
     set(ExecuteNetwork_sources
         ExecuteNetwork/ExecuteNetwork.cpp)
 
@@ -250,6 +250,9 @@ if (BUILD_CAFFE_PARSER OR BUILD_TF_PARSER OR BUILD_TF_LITE_PARSER OR BUILD_ONNX_
     target_include_directories(ExecuteNetwork PRIVATE ../src/armnnUtils)
     target_include_directories(ExecuteNetwork PRIVATE ../src/backends)
 
+    if (BUILD_ARMNN_SERIALIZER)
+        target_link_libraries(ExecuteNetwork armnnSerializer)
+    endif()
     if (BUILD_CAFFE_PARSER)
         target_link_libraries(ExecuteNetwork armnnCaffeParser)
     endif()
diff --git a/tests/ExecuteNetwork/ExecuteNetwork.cpp b/tests/ExecuteNetwork/ExecuteNetwork.cpp
index bbab70b39a..a97d6da3d5 100644
--- a/tests/ExecuteNetwork/ExecuteNetwork.cpp
+++ b/tests/ExecuteNetwork/ExecuteNetwork.cpp
@@ -5,6 +5,9 @@
 #include <armnn/ArmNN.hpp>
 #include <armnn/TypesUtils.hpp>
 
+#if defined(ARMNN_SERIALIZER)
+#include "armnnDeserializeParser/IDeserializeParser.hpp"
+#endif
 #if defined(ARMNN_CAFFE_PARSER)
 #include "armnnCaffeParser/ICaffeParser.hpp"
 #endif
@@ -361,7 +364,20 @@ int RunTest(const std::string& format,
     }
 
     // Forward to implementation based on the parser type
-    if (modelFormat.find("caffe") != std::string::npos)
+    if (modelFormat.find("armnn") != std::string::npos)
+    {
+#if defined(ARMNN_SERIALIZER)
+        return MainImpl<armnnDeserializeParser::IDeserializeParser, float>(
+            modelPath.c_str(), isModelBinary, computeDevice,
+            inputNamesVector, inputTensorShapes,
+            inputTensorDataFilePathsVector, inputTypesVector,
+            outputNamesVector, enableProfiling, subgraphId, runtime);
+#else
+        BOOST_LOG_TRIVIAL(fatal) << "Not built with serialization support.";
+        return EXIT_FAILURE;
+#endif
+    }
+    else if (modelFormat.find("caffe") != std::string::npos)
     {
 #if defined(ARMNN_CAFFE_PARSER)
         return MainImpl<armnnCaffeParser::ICaffeParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
@@ -447,9 +463,10 @@ int RunCsvTest(const armnnUtils::CsvRow &csvRow,
     {
         desc.add_options()
         ("model-format,f", po::value<std::string>(&modelFormat),
-         "caffe-binary, caffe-text, tflite-binary, onnx-binary, onnx-text, tensorflow-binary or tensorflow-text.")
-        ("model-path,m", po::value<std::string>(&modelPath), "Path to model file, e.g. .caffemodel, .prototxt, .tflite,"
-         " .onnx")
+         "armnn-binary, caffe-binary, caffe-text, tflite-binary, onnx-binary, onnx-text, tensorflow-binary or "
+         "tensorflow-text.")
+        ("model-path,m", po::value<std::string>(&modelPath), "Path to model file, e.g. .armnn, .caffemodel, .prototxt, "
+         ".tflite, .onnx")
         ("compute,c", po::value<std::vector<armnn::BackendId>>()->multitoken(), backendsMessage.c_str())
         ("input-name,i", po::value<std::string>(&inputNames),
          "Identifier of the input tensors in the network separated by comma.")
@@ -557,9 +574,10 @@ int main(int argc, const char* argv[])
         ("concurrent,n", po::bool_switch()->default_value(false),
          "Whether or not the test cases should be executed in parallel")
         ("model-format,f", po::value<std::string>(&modelFormat)->required(),
-         "caffe-binary, caffe-text, onnx-binary, onnx-text, tflite-binary, tensorflow-binary or tensorflow-text.")
-        ("model-path,m", po::value<std::string>(&modelPath)->required(), "Path to model file, e.g. .caffemodel, .prototxt,"
-         " .tflite, .onnx")
+         "armnn-binary, caffe-binary, caffe-text, onnx-binary, onnx-text, tflite-binary, tensorflow-binary or "
+         "tensorflow-text.")
+        ("model-path,m", po::value<std::string>(&modelPath)->required(), "Path to model file, e.g. .armnn, .caffemodel, "
+         ".prototxt, .tflite, .onnx")
         ("compute,c", po::value<std::vector<armnn::BackendId>>()->multitoken(), backendsMessage.c_str())
         ("input-name,i", po::value<std::string>(&inputNames),
          "Identifier of the input tensors in the network separated by comma.")
diff --git a/tests/InferenceModel.hpp b/tests/InferenceModel.hpp
index eb5f708c81..eb3b2ccd42 100644
--- a/tests/InferenceModel.hpp
+++ b/tests/InferenceModel.hpp
@@ -5,15 +5,18 @@
 #pragma once
 #include <armnn/ArmNN.hpp>
+#if defined(ARMNN_SERIALIZER)
+#include "armnnDeserializeParser/IDeserializeParser.hpp"
+#endif
 
 #if defined(ARMNN_TF_LITE_PARSER)
 #include <armnnTfLiteParser/ITfLiteParser.hpp>
 #endif
-
-#include <HeapProfiling.hpp>
 #if defined(ARMNN_ONNX_PARSER)
 #include <armnnOnnxParser/IOnnxParser.hpp>
 #endif
 
+#include <HeapProfiling.hpp>
+
 #include
 
 #include
@@ -160,6 +163,49 @@ public:
     }
 };
 
+#if defined(ARMNN_SERIALIZER)
+template <>
+struct CreateNetworkImpl<armnnDeserializeParser::IDeserializeParser>
+{
+public:
+    using IParser          = armnnDeserializeParser::IDeserializeParser;
+    using Params           = InferenceModelInternal::Params;
+    using BindingPointInfo = InferenceModelInternal::BindingPointInfo;
+
+    static armnn::INetworkPtr Create(const Params& params,
+                                     std::vector<BindingPointInfo>& inputBindings,
+                                     std::vector<BindingPointInfo>& outputBindings)
+    {
+        auto parser(IParser::Create());
+        BOOST_ASSERT(parser);
+
+        armnn::INetworkPtr network{nullptr, [](armnn::INetwork *){}};
+
+        {
+            ARMNN_SCOPED_HEAP_PROFILING("Parsing");
+            const std::string& modelPath = params.m_ModelPath;
+            network = parser->CreateNetworkFromBinaryFile(modelPath.c_str());
+        }
+
+        unsigned int subGraphId = boost::numeric_cast<unsigned int>(params.m_SubgraphId);
+
+        for (const std::string& inputLayerName : params.m_InputBindings)
+        {
+            BindingPointInfo inputBinding = parser->GetNetworkInputBindingInfo(subGraphId, inputLayerName);
+            inputBindings.push_back(inputBinding);
+        }
+
+        for (const std::string& outputLayerName : params.m_OutputBindings)
+        {
+            BindingPointInfo outputBinding = parser->GetNetworkOutputBindingInfo(subGraphId, outputLayerName);
+            outputBindings.push_back(outputBinding);
+        }
+
+        return network;
+    }
+};
+#endif
+
 #if defined(ARMNN_TF_LITE_PARSER)
 template <>
 struct CreateNetworkImpl<armnnTfLiteParser::ITfLiteParser>
-- 
cgit v1.2.1
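
Usage sketch (not part of the commit): the InferenceModel.hpp hunk above wires the armnnDeserializeParser API into a CreateNetworkImpl specialization. The minimal program below strings the same calls together on their own; the model path ("model.armnn"), the layer names ("input", "output") and subgraph id 0 are illustrative placeholders, and only calls that appear in the diff itself are used.

    #include <armnn/ArmNN.hpp>
    #include "armnnDeserializeParser/IDeserializeParser.hpp"

    int main()
    {
        // Create the deserialize parser, as CreateNetworkImpl::Create does above
        auto parser = armnnDeserializeParser::IDeserializeParser::Create();

        // Load a serialized ArmNN network from disk (placeholder path)
        armnn::INetworkPtr network = parser->CreateNetworkFromBinaryFile("model.armnn");

        // Look up input/output bindings by layer name (placeholder names, subgraph 0)
        auto inputBinding  = parser->GetNetworkInputBindingInfo(0, "input");
        auto outputBinding = parser->GetNetworkOutputBindingInfo(0, "output");

        return network ? 0 : 1;
    }

With the patch applied, the same path is reachable from the ExecuteNetwork command line through the newly advertised armnn-binary format, e.g. (backend and file name are placeholders; the remaining options follow the tool's existing help text):

    ExecuteNetwork --model-format armnn-binary --model-path model.armnn --compute CpuRef ...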