diff options
-rw-r--r--  tests/CMakeLists.txt                    |  5 +++-
-rw-r--r--  tests/ExecuteNetwork/ExecuteNetwork.cpp | 32 +++++++++++-----
-rw-r--r--  tests/InferenceModel.hpp                | 50 +++++++++++++++++++++-
3 files changed, 77 insertions(+), 10 deletions(-)
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 1fc89da016..9913321295 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -241,7 +241,7 @@ if (BUILD_ONNX_PARSER)
     OnnxParserTest(OnnxMobileNet-Armnn "${OnnxMobileNet-Armnn_sources}")
 endif()
 
-if (BUILD_CAFFE_PARSER OR BUILD_TF_PARSER OR BUILD_TF_LITE_PARSER OR BUILD_ONNX_PARSER)
+if (BUILD_ARMNN_SERIALIZER OR BUILD_CAFFE_PARSER OR BUILD_TF_PARSER OR BUILD_TF_LITE_PARSER OR BUILD_ONNX_PARSER)
     set(ExecuteNetwork_sources
         ExecuteNetwork/ExecuteNetwork.cpp)
 
@@ -250,6 +250,9 @@ if (BUILD_CAFFE_PARSER OR BUILD_TF_PARSER OR BUILD_TF_LITE_PARSER OR BUILD_ONNX_
     target_include_directories(ExecuteNetwork PRIVATE ../src/armnnUtils)
     target_include_directories(ExecuteNetwork PRIVATE ../src/backends)
 
+    if (BUILD_ARMNN_SERIALIZER)
+        target_link_libraries(ExecuteNetwork armnnSerializer)
+    endif()
     if (BUILD_CAFFE_PARSER)
         target_link_libraries(ExecuteNetwork armnnCaffeParser)
     endif()
diff --git a/tests/ExecuteNetwork/ExecuteNetwork.cpp b/tests/ExecuteNetwork/ExecuteNetwork.cpp
index bbab70b39a..a97d6da3d5 100644
--- a/tests/ExecuteNetwork/ExecuteNetwork.cpp
+++ b/tests/ExecuteNetwork/ExecuteNetwork.cpp
@@ -5,6 +5,9 @@
 #include <armnn/ArmNN.hpp>
 #include <armnn/TypesUtils.hpp>
 
+#if defined(ARMNN_SERIALIZER)
+#include "armnnDeserializeParser/IDeserializeParser.hpp"
+#endif
 #if defined(ARMNN_CAFFE_PARSER)
 #include "armnnCaffeParser/ICaffeParser.hpp"
 #endif
@@ -361,7 +364,20 @@ int RunTest(const std::string& format,
     }
 
     // Forward to implementation based on the parser type
-    if (modelFormat.find("caffe") != std::string::npos)
+    if (modelFormat.find("armnn") != std::string::npos)
+    {
+#if defined(ARMNN_SERIALIZER)
+        return MainImpl<armnnDeserializeParser::IDeserializeParser, float>(
+            modelPath.c_str(), isModelBinary, computeDevice,
+            inputNamesVector, inputTensorShapes,
+            inputTensorDataFilePathsVector, inputTypesVector,
+            outputNamesVector, enableProfiling, subgraphId, runtime);
+#else
+        BOOST_LOG_TRIVIAL(fatal) << "Not built with serialization support.";
+        return EXIT_FAILURE;
+#endif
+    }
+    else if (modelFormat.find("caffe") != std::string::npos)
     {
 #if defined(ARMNN_CAFFE_PARSER)
         return MainImpl<armnnCaffeParser::ICaffeParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
@@ -447,9 +463,10 @@ int RunCsvTest(const armnnUtils::CsvRow &csvRow,
     {
         desc.add_options()
         ("model-format,f", po::value(&modelFormat),
-         "caffe-binary, caffe-text, tflite-binary, onnx-binary, onnx-text, tensorflow-binary or tensorflow-text.")
-        ("model-path,m", po::value(&modelPath), "Path to model file, e.g. .caffemodel, .prototxt, .tflite,"
-         " .onnx")
+         "armnn-binary, caffe-binary, caffe-text, tflite-binary, onnx-binary, onnx-text, tensorflow-binary or "
+         "tensorflow-text.")
+        ("model-path,m", po::value(&modelPath), "Path to model file, e.g. .armnn, .caffemodel, .prototxt, "
+         ".tflite, .onnx")
         ("compute,c", po::value<std::vector<armnn::BackendId>>()->multitoken(),
          backendsMessage.c_str())
         ("input-name,i", po::value(&inputNames), "Identifier of the input tensors in the network separated by comma.")
@@ -557,9 +574,10 @@ int main(int argc, const char* argv[])
         ("concurrent,n", po::bool_switch()->default_value(false),
          "Whether or not the test cases should be executed in parallel")
         ("model-format,f", po::value(&modelFormat)->required(),
-         "caffe-binary, caffe-text, onnx-binary, onnx-text, tflite-binary, tensorflow-binary or tensorflow-text.")
-        ("model-path,m", po::value(&modelPath)->required(), "Path to model file, e.g. .caffemodel, .prototxt,"
-         " .tflite, .onnx")
+         "armnn-binary, caffe-binary, caffe-text, onnx-binary, onnx-text, tflite-binary, tensorflow-binary or "
+         "tensorflow-text.")
+        ("model-path,m", po::value(&modelPath)->required(), "Path to model file, e.g. .armnn, .caffemodel, "
+         ".prototxt, .tflite, .onnx")
         ("compute,c", po::value<std::vector<std::string>>()->multitoken(),
          backendsMessage.c_str())
         ("input-name,i", po::value(&inputNames),
diff --git a/tests/InferenceModel.hpp b/tests/InferenceModel.hpp
index eb5f708c81..eb3b2ccd42 100644
--- a/tests/InferenceModel.hpp
+++ b/tests/InferenceModel.hpp
@@ -5,15 +5,18 @@
 #pragma once
 
 #include <armnn/ArmNN.hpp>
 
+#if defined(ARMNN_SERIALIZER)
+#include "armnnDeserializeParser/IDeserializeParser.hpp"
+#endif
 #if defined(ARMNN_TF_LITE_PARSER)
 #include <armnnTfLiteParser/ITfLiteParser.hpp>
 #endif
-
-#include <HeapProfiling.hpp>
 #if defined(ARMNN_ONNX_PARSER)
 #include <armnnOnnxParser/IOnnxParser.hpp>
 #endif
+#include <HeapProfiling.hpp>
+
 #include <backendsCommon/BackendRegistry.hpp>
 
 #include <boost/algorithm/string/join.hpp>
@@ -160,6 +163,49 @@ public:
     }
 };
 
+#if defined(ARMNN_SERIALIZER)
+template <>
+struct CreateNetworkImpl<armnnDeserializeParser::IDeserializeParser>
+{
+public:
+    using IParser = armnnDeserializeParser::IDeserializeParser;
+    using Params = InferenceModelInternal::Params;
+    using BindingPointInfo = InferenceModelInternal::BindingPointInfo;
+
+    static armnn::INetworkPtr Create(const Params& params,
+                                     std::vector<BindingPointInfo>& inputBindings,
+                                     std::vector<BindingPointInfo>& outputBindings)
+    {
+        auto parser(IParser::Create());
+        BOOST_ASSERT(parser);
+
+        armnn::INetworkPtr network{nullptr, [](armnn::INetwork *){}};
+
+        {
+            ARMNN_SCOPED_HEAP_PROFILING("Parsing");
+            const std::string& modelPath = params.m_ModelPath;
+            network = parser->CreateNetworkFromBinaryFile(modelPath.c_str());
+        }
+
+        unsigned int subGraphId = boost::numeric_cast<unsigned int>(params.m_SubgraphId);
+
+        for (const std::string& inputLayerName : params.m_InputBindings)
+        {
+            BindingPointInfo inputBinding = parser->GetNetworkInputBindingInfo(subGraphId, inputLayerName);
+            inputBindings.push_back(inputBinding);
+        }
+
+        for (const std::string& outputLayerName : params.m_OutputBindings)
+        {
+            BindingPointInfo outputBinding = parser->GetNetworkOutputBindingInfo(subGraphId, outputLayerName);
+            outputBindings.push_back(outputBinding);
+        }
+
+        return network;
+    }
+};
+#endif
+
 #if defined(ARMNN_TF_LITE_PARSER)
 template <>
 struct CreateNetworkImpl<armnnTfLiteParser::ITfLiteParser>