Diffstat (limited to 'tests/ExecuteNetwork/ExecuteNetwork.cpp')
-rw-r--r--  tests/ExecuteNetwork/ExecuteNetwork.cpp  32
1 file changed, 25 insertions(+), 7 deletions(-)
diff --git a/tests/ExecuteNetwork/ExecuteNetwork.cpp b/tests/ExecuteNetwork/ExecuteNetwork.cpp
index bbab70b39a..a97d6da3d5 100644
--- a/tests/ExecuteNetwork/ExecuteNetwork.cpp
+++ b/tests/ExecuteNetwork/ExecuteNetwork.cpp
@@ -5,6 +5,9 @@
#include <armnn/ArmNN.hpp>
#include <armnn/TypesUtils.hpp>
+#if defined(ARMNN_SERIALIZER)
+#include "armnnDeserializeParser/IDeserializeParser.hpp"
+#endif
#if defined(ARMNN_CAFFE_PARSER)
#include "armnnCaffeParser/ICaffeParser.hpp"
#endif
@@ -361,7 +364,20 @@ int RunTest(const std::string& format,
}
// Forward to implementation based on the parser type
- if (modelFormat.find("caffe") != std::string::npos)
+ if (modelFormat.find("armnn") != std::string::npos)
+ {
+#if defined(ARMNN_SERIALIZER)
+ return MainImpl<armnnDeserializeParser::IDeserializeParser, float>(
+ modelPath.c_str(), isModelBinary, computeDevice,
+ inputNamesVector, inputTensorShapes,
+ inputTensorDataFilePathsVector, inputTypesVector,
+ outputNamesVector, enableProfiling, subgraphId, runtime);
+#else
+ BOOST_LOG_TRIVIAL(fatal) << "Not built with serialization support.";
+ return EXIT_FAILURE;
+#endif
+ }
+ else if (modelFormat.find("caffe") != std::string::npos)
{
#if defined(ARMNN_CAFFE_PARSER)
return MainImpl<armnnCaffeParser::ICaffeParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
@@ -447,9 +463,10 @@ int RunCsvTest(const armnnUtils::CsvRow &csvRow,
{
desc.add_options()
("model-format,f", po::value(&modelFormat),
- "caffe-binary, caffe-text, tflite-binary, onnx-binary, onnx-text, tensorflow-binary or tensorflow-text.")
- ("model-path,m", po::value(&modelPath), "Path to model file, e.g. .caffemodel, .prototxt, .tflite,"
- " .onnx")
+ "armnn-binary, caffe-binary, caffe-text, tflite-binary, onnx-binary, onnx-text, tensorflow-binary or "
+ "tensorflow-text.")
+ ("model-path,m", po::value(&modelPath), "Path to model file, e.g. .armnn, .caffemodel, .prototxt, "
+ ".tflite, .onnx")
("compute,c", po::value<std::vector<armnn::BackendId>>()->multitoken(),
backendsMessage.c_str())
("input-name,i", po::value(&inputNames), "Identifier of the input tensors in the network separated by comma.")
@@ -557,9 +574,10 @@ int main(int argc, const char* argv[])
("concurrent,n", po::bool_switch()->default_value(false),
"Whether or not the test cases should be executed in parallel")
("model-format,f", po::value(&modelFormat)->required(),
- "caffe-binary, caffe-text, onnx-binary, onnx-text, tflite-binary, tensorflow-binary or tensorflow-text.")
- ("model-path,m", po::value(&modelPath)->required(), "Path to model file, e.g. .caffemodel, .prototxt,"
- " .tflite, .onnx")
+ "armnn-binary, caffe-binary, caffe-text, onnx-binary, onnx-text, tflite-binary, tensorflow-binary or "
+ "tensorflow-text.")
+ ("model-path,m", po::value(&modelPath)->required(), "Path to model file, e.g. .armnn, .caffemodel, "
+ ".prototxt, .tflite, .onnx")
("compute,c", po::value<std::vector<std::string>>()->multitoken(),
backendsMessage.c_str())
("input-name,i", po::value(&inputNames),