diff options
Diffstat (limited to 'tests/ExecuteNetwork')
-rw-r--r-- | tests/ExecuteNetwork/ExecuteNetwork.cpp | 32 |
1 files changed, 25 insertions, 7 deletions
diff --git a/tests/ExecuteNetwork/ExecuteNetwork.cpp b/tests/ExecuteNetwork/ExecuteNetwork.cpp index bbab70b39a..a97d6da3d5 100644 --- a/tests/ExecuteNetwork/ExecuteNetwork.cpp +++ b/tests/ExecuteNetwork/ExecuteNetwork.cpp @@ -5,6 +5,9 @@ #include <armnn/ArmNN.hpp> #include <armnn/TypesUtils.hpp> +#if defined(ARMNN_SERIALIZER) +#include "armnnDeserializeParser/IDeserializeParser.hpp" +#endif #if defined(ARMNN_CAFFE_PARSER) #include "armnnCaffeParser/ICaffeParser.hpp" #endif @@ -361,7 +364,20 @@ int RunTest(const std::string& format, } // Forward to implementation based on the parser type - if (modelFormat.find("caffe") != std::string::npos) + if (modelFormat.find("armnn") != std::string::npos) + { +#if defined(ARMNN_SERIALIZER) + return MainImpl<armnnDeserializeParser::IDeserializeParser, float>( + modelPath.c_str(), isModelBinary, computeDevice, + inputNamesVector, inputTensorShapes, + inputTensorDataFilePathsVector, inputTypesVector, + outputNamesVector, enableProfiling, subgraphId, runtime); +#else + BOOST_LOG_TRIVIAL(fatal) << "Not built with serialization support."; + return EXIT_FAILURE; +#endif + } + else if (modelFormat.find("caffe") != std::string::npos) { #if defined(ARMNN_CAFFE_PARSER) return MainImpl<armnnCaffeParser::ICaffeParser, float>(modelPath.c_str(), isModelBinary, computeDevice, @@ -447,9 +463,10 @@ int RunCsvTest(const armnnUtils::CsvRow &csvRow, { desc.add_options() ("model-format,f", po::value(&modelFormat), - "caffe-binary, caffe-text, tflite-binary, onnx-binary, onnx-text, tensorflow-binary or tensorflow-text.") - ("model-path,m", po::value(&modelPath), "Path to model file, e.g. .caffemodel, .prototxt, .tflite," - " .onnx") + "armnn-binary, caffe-binary, caffe-text, tflite-binary, onnx-binary, onnx-text, tensorflow-binary or " + "tensorflow-text.") + ("model-path,m", po::value(&modelPath), "Path to model file, e.g. .armnn, .caffemodel, .prototxt, " + ".tflite, .onnx") ("compute,c", po::value<std::vector<armnn::BackendId>>()->multitoken(), backendsMessage.c_str()) ("input-name,i", po::value(&inputNames), "Identifier of the input tensors in the network separated by comma.") @@ -557,9 +574,10 @@ int main(int argc, const char* argv[]) ("concurrent,n", po::bool_switch()->default_value(false), "Whether or not the test cases should be executed in parallel") ("model-format,f", po::value(&modelFormat)->required(), - "caffe-binary, caffe-text, onnx-binary, onnx-text, tflite-binary, tensorflow-binary or tensorflow-text.") - ("model-path,m", po::value(&modelPath)->required(), "Path to model file, e.g. .caffemodel, .prototxt," - " .tflite, .onnx") + "armnn-binary, caffe-binary, caffe-text, onnx-binary, onnx-text, tflite-binary, tensorflow-binary or " + "tensorflow-text.") + ("model-path,m", po::value(&modelPath)->required(), "Path to model file, e.g. .armnn, .caffemodel, " + ".prototxt, .tflite, .onnx") ("compute,c", po::value<std::vector<std::string>>()->multitoken(), backendsMessage.c_str()) ("input-name,i", po::value(&inputNames), |