diff options
Diffstat (limited to 'src/armnnConverter/ArmnnConverter.cpp')
-rw-r--r-- | src/armnnConverter/ArmnnConverter.cpp | 176 |
1 files changed, 162 insertions, 14 deletions
diff --git a/src/armnnConverter/ArmnnConverter.cpp b/src/armnnConverter/ArmnnConverter.cpp index fbec1449a8..9bc6cc841f 100644 --- a/src/armnnConverter/ArmnnConverter.cpp +++ b/src/armnnConverter/ArmnnConverter.cpp @@ -4,8 +4,21 @@ // #include <armnn/ArmNN.hpp> +#if defined(ARMNN_CAFFE_PARSER) +#include <armnnCaffeParser/ICaffeParser.hpp> +#endif +#if defined(ARMNN_ONNX_PARSER) +#include <armnnOnnxParser/IOnnxParser.hpp> +#endif +#if defined(ARMNN_SERIALIZER) #include <armnnSerializer/ISerializer.hpp> +#endif +#if defined(ARMNN_TF_PARSER) #include <armnnTfParser/ITfParser.hpp> +#endif +#if defined(ARMNN_TF_LITE_PARSER) +#include <armnnTfLiteParser/ITfLiteParser.hpp> +#endif #include <Logging.hpp> #include <HeapProfiling.hpp> @@ -111,7 +124,20 @@ int ParseCommandLineArgs(int argc, const char* argv[], desc.add_options() ("help", "Display usage information") - ("model-format,f", po::value(&modelFormat)->required(),"tensorflow-binary or tensorflow-text.") + ("model-format,f", po::value(&modelFormat)->required(),"Format of the model file" +#if defined(ARMNN_CAFFE_PARSER) + ", caffe-binary, caffe-text" +#endif +#if defined(ARMNN_ONNX_PARSER) + ", onnx-binary, onnx-text" +#endif +#if defined(ARMNN_TENSORFLOW_PARSER) + ", tensorflow-binary, tensorflow-text" +#endif +#if defined(ARMNN_TF_LITE_PARSER) + ", tflite-binary" +#endif + ".") ("model-path,m", po::value(&modelPath)->required(), "Path to model file") ("input-name,i", po::value<std::vector<std::string>>()->multitoken(), "Identifier of the input tensors in the network separated by whitespace") @@ -135,7 +161,6 @@ int ParseCommandLineArgs(int argc, const char* argv[], std::cout << desc << std::endl; return EXIT_SUCCESS; } - po::notify(vm); } catch (const po::error& e) @@ -160,7 +185,7 @@ int ParseCommandLineArgs(int argc, const char* argv[], { isModelBinary = true; } - else if (modelFormat.find("txt") != std::string::npos || modelFormat.find("text") != std::string::npos) + else if (modelFormat.find("text") != std::string::npos) { isModelBinary = false; } @@ -177,6 +202,12 @@ int ParseCommandLineArgs(int argc, const char* argv[], return EXIT_SUCCESS; } +template<typename T> +struct ParserType +{ + typedef T parserType; +}; + class ArmnnConverter { public: @@ -215,6 +246,21 @@ public: template <typename IParser> bool CreateNetwork () { + return CreateNetwork (ParserType<IParser>()); + } + +private: + armnn::INetworkPtr m_NetworkPtr; + std::string m_ModelPath; + std::vector<std::string> m_InputNames; + std::vector<armnn::TensorShape> m_InputShapes; + std::vector<std::string> m_OutputNames; + std::string m_OutputPath; + bool m_IsModelBinary; + + template <typename IParser> + bool CreateNetwork (ParserType<IParser>) + { // Create a network from a file on disk auto parser(IParser::Create()); @@ -246,14 +292,62 @@ public: return m_NetworkPtr.get() != nullptr; } -private: - armnn::INetworkPtr m_NetworkPtr; - std::string m_ModelPath; - std::vector<std::string> m_InputNames; - std::vector<armnn::TensorShape> m_InputShapes; - std::vector<std::string> m_OutputNames; - std::string m_OutputPath; - bool m_IsModelBinary; +#if defined(ARMNN_TF_LITE_PARSER) + bool CreateNetwork (ParserType<armnnTfLiteParser::ITfLiteParser>) + { + // Create a network from a file on disk + auto parser(armnnTfLiteParser::ITfLiteParser::Create()); + + if (!m_InputShapes.empty()) + { + const size_t numInputShapes = m_InputShapes.size(); + const size_t numInputBindings = m_InputNames.size(); + if (numInputShapes < numInputBindings) + { + throw armnn::Exception(boost::str(boost::format( + "Not every input has its tensor shape specified: expected=%1%, got=%2%") + % numInputBindings % numInputShapes)); + } + } + + { + ARMNN_SCOPED_HEAP_PROFILING("Parsing"); + m_NetworkPtr = parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str()); + } + + return m_NetworkPtr.get() != nullptr; + } +#endif + +#if defined(ARMNN_ONNX_PARSER) + bool CreateNetwork (ParserType<armnnOnnxParser::IOnnxParser>) + { + // Create a network from a file on disk + auto parser(armnnOnnxParser::IOnnxParser::Create()); + + if (!m_InputShapes.empty()) + { + const size_t numInputShapes = m_InputShapes.size(); + const size_t numInputBindings = m_InputNames.size(); + if (numInputShapes < numInputBindings) + { + throw armnn::Exception(boost::str(boost::format( + "Not every input has its tensor shape specified: expected=%1%, got=%2%") + % numInputBindings % numInputShapes)); + } + } + + { + ARMNN_SCOPED_HEAP_PROFILING("Parsing"); + m_NetworkPtr = (m_IsModelBinary ? + parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str()) : + parser->CreateNetworkFromTextFile(m_ModelPath.c_str())); + } + + return m_NetworkPtr.get() != nullptr; + } +#endif + }; } // anonymous namespace @@ -261,8 +355,11 @@ private: int main(int argc, const char* argv[]) { -#if !defined(ARMNN_TF_PARSER) - BOOST_LOG_TRIVIAL(fatal) << "Not built with Tensorflow parser support."; +#if (!defined(ARMNN_CAFFE_PARSER) \ + && !defined(ARMNN_ONNX_PARSER) \ + && !defined(ARMNN_TF_PARSER) \ + && !defined(ARMNN_TF_LITE_PARSER)) + BOOST_LOG_TRIVIAL(fatal) << "Not built with any of the supported parsers, Caffe, Onnx, Tensorflow, or TfLite."; return EXIT_FAILURE; #endif @@ -320,13 +417,64 @@ int main(int argc, const char* argv[]) ArmnnConverter converter(modelPath, inputNames, inputTensorShapes, outputNames, outputPath, isModelBinary); - if (modelFormat.find("tensorflow") != std::string::npos) + if (modelFormat.find("caffe") != std::string::npos) { +#if defined(ARMNN_CAFFE_PARSER) + if (!converter.CreateNetwork<armnnCaffeParser::ICaffeParser>()) + { + BOOST_LOG_TRIVIAL(fatal) << "Failed to load model from file"; + return EXIT_FAILURE; + } +#else + BOOST_LOG_TRIVIAL(fatal) << "Not built with Caffe parser support."; + return EXIT_FAILURE; +#endif + } + else if (modelFormat.find("onnx") != std::string::npos) + { +#if defined(ARMNN_ONNX_PARSER) + if (!converter.CreateNetwork<armnnOnnxParser::IOnnxParser>()) + { + BOOST_LOG_TRIVIAL(fatal) << "Failed to load model from file"; + return EXIT_FAILURE; + } +#else + BOOST_LOG_TRIVIAL(fatal) << "Not built with Onnx parser support."; + return EXIT_FAILURE; +#endif + } + else if (modelFormat.find("tensorflow") != std::string::npos) + { +#if defined(ARMNN_TF_PARSER) if (!converter.CreateNetwork<armnnTfParser::ITfParser>()) { BOOST_LOG_TRIVIAL(fatal) << "Failed to load model from file"; return EXIT_FAILURE; } +#else + BOOST_LOG_TRIVIAL(fatal) << "Not built with Tensorflow parser support."; + return EXIT_FAILURE; +#endif + } + else if (modelFormat.find("tflite") != std::string::npos) + { +#if defined(ARMNN_TF_LITE_PARSER) + if (!isModelBinary) + { + BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat << "'. Only 'binary' format supported \ + for tflite files"; + return EXIT_FAILURE; + } + + if (!converter.CreateNetwork<armnnTfLiteParser::ITfLiteParser>()) + { + BOOST_LOG_TRIVIAL(fatal) << "Failed to load model from file"; + return EXIT_FAILURE; + } +#else + BOOST_LOG_TRIVIAL(fatal) << "Not built with TfLite parser support."; + return EXIT_FAILURE; +#endif } else { |