Diffstat (limited to 'src/armnnConverter/ArmnnConverter.cpp')
-rw-r--r-- src/armnnConverter/ArmnnConverter.cpp | 176
1 file changed, 162 insertions(+), 14 deletions(-)
diff --git a/src/armnnConverter/ArmnnConverter.cpp b/src/armnnConverter/ArmnnConverter.cpp
index fbec1449a8..9bc6cc841f 100644
--- a/src/armnnConverter/ArmnnConverter.cpp
+++ b/src/armnnConverter/ArmnnConverter.cpp
@@ -4,8 +4,21 @@
//
#include <armnn/ArmNN.hpp>
+#if defined(ARMNN_CAFFE_PARSER)
+#include <armnnCaffeParser/ICaffeParser.hpp>
+#endif
+#if defined(ARMNN_ONNX_PARSER)
+#include <armnnOnnxParser/IOnnxParser.hpp>
+#endif
+#if defined(ARMNN_SERIALIZER)
#include <armnnSerializer/ISerializer.hpp>
+#endif
+#if defined(ARMNN_TF_PARSER)
#include <armnnTfParser/ITfParser.hpp>
+#endif
+#if defined(ARMNN_TF_LITE_PARSER)
+#include <armnnTfLiteParser/ITfLiteParser.hpp>
+#endif
#include <Logging.hpp>
#include <HeapProfiling.hpp>
@@ -111,7 +124,20 @@ int ParseCommandLineArgs(int argc, const char* argv[],
desc.add_options()
("help", "Display usage information")
- ("model-format,f", po::value(&modelFormat)->required(),"tensorflow-binary or tensorflow-text.")
+ ("model-format,f", po::value(&modelFormat)->required(),"Format of the model file"
+#if defined(ARMNN_CAFFE_PARSER)
+ ", caffe-binary, caffe-text"
+#endif
+#if defined(ARMNN_ONNX_PARSER)
+ ", onnx-binary, onnx-text"
+#endif
+#if defined(ARMNN_TF_PARSER)
+ ", tensorflow-binary, tensorflow-text"
+#endif
+#if defined(ARMNN_TF_LITE_PARSER)
+ ", tflite-binary"
+#endif
+ ".")
("model-path,m", po::value(&modelPath)->required(), "Path to model file")
("input-name,i", po::value<std::vector<std::string>>()->multitoken(),
"Identifier of the input tensors in the network separated by whitespace")
@@ -135,7 +161,6 @@ int ParseCommandLineArgs(int argc, const char* argv[],
std::cout << desc << std::endl;
return EXIT_SUCCESS;
}
-
po::notify(vm);
}
catch (const po::error& e)
@@ -160,7 +185,7 @@ int ParseCommandLineArgs(int argc, const char* argv[],
{
isModelBinary = true;
}
- else if (modelFormat.find("txt") != std::string::npos || modelFormat.find("text") != std::string::npos)
+ else if (modelFormat.find("text") != std::string::npos)
{
isModelBinary = false;
}
@@ -177,6 +202,12 @@ int ParseCommandLineArgs(int argc, const char* argv[],
return EXIT_SUCCESS;
}
+template<typename T>
+struct ParserType
+{
+ typedef T parserType;
+};
+
class ArmnnConverter
{
public:
@@ -215,6 +246,21 @@ public:
template <typename IParser>
bool CreateNetwork ()
{
+ return CreateNetwork (ParserType<IParser>());
+ }
+
+private:
+ armnn::INetworkPtr m_NetworkPtr;
+ std::string m_ModelPath;
+ std::vector<std::string> m_InputNames;
+ std::vector<armnn::TensorShape> m_InputShapes;
+ std::vector<std::string> m_OutputNames;
+ std::string m_OutputPath;
+ bool m_IsModelBinary;
+
+ template <typename IParser>
+ bool CreateNetwork (ParserType<IParser>)
+ {
// Create a network from a file on disk
auto parser(IParser::Create());
@@ -246,14 +292,62 @@ public:
return m_NetworkPtr.get() != nullptr;
}
-private:
- armnn::INetworkPtr m_NetworkPtr;
- std::string m_ModelPath;
- std::vector<std::string> m_InputNames;
- std::vector<armnn::TensorShape> m_InputShapes;
- std::vector<std::string> m_OutputNames;
- std::string m_OutputPath;
- bool m_IsModelBinary;
+#if defined(ARMNN_TF_LITE_PARSER)
+ bool CreateNetwork (ParserType<armnnTfLiteParser::ITfLiteParser>)
+ {
+ // Create a network from a file on disk
+ auto parser(armnnTfLiteParser::ITfLiteParser::Create());
+
+ if (!m_InputShapes.empty())
+ {
+ const size_t numInputShapes = m_InputShapes.size();
+ const size_t numInputBindings = m_InputNames.size();
+ if (numInputShapes < numInputBindings)
+ {
+ throw armnn::Exception(boost::str(boost::format(
+ "Not every input has its tensor shape specified: expected=%1%, got=%2%")
+ % numInputBindings % numInputShapes));
+ }
+ }
+
+ {
+ ARMNN_SCOPED_HEAP_PROFILING("Parsing");
+ m_NetworkPtr = parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str());
+ }
+
+ return m_NetworkPtr.get() != nullptr;
+ }
+#endif
+
+#if defined(ARMNN_ONNX_PARSER)
+ bool CreateNetwork (ParserType<armnnOnnxParser::IOnnxParser>)
+ {
+ // Create a network from a file on disk
+ auto parser(armnnOnnxParser::IOnnxParser::Create());
+
+ if (!m_InputShapes.empty())
+ {
+ const size_t numInputShapes = m_InputShapes.size();
+ const size_t numInputBindings = m_InputNames.size();
+ if (numInputShapes < numInputBindings)
+ {
+ throw armnn::Exception(boost::str(boost::format(
+ "Not every input has its tensor shape specified: expected=%1%, got=%2%")
+ % numInputBindings % numInputShapes));
+ }
+ }
+
+ {
+ ARMNN_SCOPED_HEAP_PROFILING("Parsing");
+ m_NetworkPtr = (m_IsModelBinary ?
+ parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str()) :
+ parser->CreateNetworkFromTextFile(m_ModelPath.c_str()));
+ }
+
+ return m_NetworkPtr.get() != nullptr;
+ }
+#endif
+
};
} // anonymous namespace
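
Reviewer note: the ParserType<T> tag exists for overload selection. The public CreateNetwork<IParser>() forwards to a private overload picked by the tag, so parsers whose Create* signatures differ from the generic path (ITfLiteParser has no CreateNetworkFromTextFile, and IOnnxParser's loaders take only a path) get dedicated overloads without specializing the class. A minimal sketch of the dispatch, with hypothetical parser types:

    #include <iostream>

    template<typename T>
    struct ParserType { typedef T parserType; };

    struct GenericParser {};     // hypothetical: supports binary and text models
    struct BinaryOnlyParser {};  // hypothetical: binary only, like ITfLiteParser

    class Converter
    {
    public:
        template <typename IParser>
        bool CreateNetwork() { return CreateNetwork(ParserType<IParser>()); }

    private:
        // Fallback used by any parser without a dedicated overload.
        template <typename IParser>
        bool CreateNetwork(ParserType<IParser>)
        {
            std::cout << "generic path: binary or text\n";
            return true;
        }

        // Non-template exact match wins over the template for this parser.
        bool CreateNetwork(ParserType<BinaryOnlyParser>)
        {
            std::cout << "binary-only path\n";
            return true;
        }
    };

    int main()
    {
        Converter c;
        c.CreateNetwork<GenericParser>();     // prints: generic path: binary or text
        c.CreateNetwork<BinaryOnlyParser>();  // prints: binary-only path
        return 0;
    }
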
@@ -261,8 +355,11 @@ private:
int main(int argc, const char* argv[])
{
-#if !defined(ARMNN_TF_PARSER)
- BOOST_LOG_TRIVIAL(fatal) << "Not built with Tensorflow parser support.";
+#if (!defined(ARMNN_CAFFE_PARSER) \
+ && !defined(ARMNN_ONNX_PARSER) \
+ && !defined(ARMNN_TF_PARSER) \
+ && !defined(ARMNN_TF_LITE_PARSER))
+ BOOST_LOG_TRIVIAL(fatal) << "Not built with any of the supported parsers: Caffe, Onnx, Tensorflow, or TfLite.";
return EXIT_FAILURE;
#endif
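
Reviewer note: with every parser disabled the tool still compiles; the negated conjunction turns the start of main() into an immediate fatal log plus EXIT_FAILURE. A #error directive would fail the build instead; the patch opts for a runtime report so the converter target always builds. A tiny standalone sketch of the same guard shape (flag names hypothetical):

    #include <cstdlib>
    #include <iostream>

    // Hypothetical stand-ins for the real build flags (ARMNN_CAFFE_PARSER,
    // ARMNN_ONNX_PARSER, ARMNN_TF_PARSER, ARMNN_TF_LITE_PARSER).
    //#define HAS_PARSER_A
    //#define HAS_PARSER_B

    int main()
    {
    #if (!defined(HAS_PARSER_A) && !defined(HAS_PARSER_B))
        // The binary still builds and links; it just reports the missing
        // feature up front instead of failing inside a format branch later.
        std::cerr << "Not built with any of the supported parsers." << std::endl;
        return EXIT_FAILURE;
    #else
        return EXIT_SUCCESS;
    #endif
    }
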
@@ -320,13 +417,64 @@ int main(int argc, const char* argv[])
ArmnnConverter converter(modelPath, inputNames, inputTensorShapes, outputNames, outputPath, isModelBinary);
- if (modelFormat.find("tensorflow") != std::string::npos)
+ if (modelFormat.find("caffe") != std::string::npos)
{
+#if defined(ARMNN_CAFFE_PARSER)
+ if (!converter.CreateNetwork<armnnCaffeParser::ICaffeParser>())
+ {
+ BOOST_LOG_TRIVIAL(fatal) << "Failed to load model from file";
+ return EXIT_FAILURE;
+ }
+#else
+ BOOST_LOG_TRIVIAL(fatal) << "Not built with Caffe parser support.";
+ return EXIT_FAILURE;
+#endif
+ }
+ else if (modelFormat.find("onnx") != std::string::npos)
+ {
+#if defined(ARMNN_ONNX_PARSER)
+ if (!converter.CreateNetwork<armnnOnnxParser::IOnnxParser>())
+ {
+ BOOST_LOG_TRIVIAL(fatal) << "Failed to load model from file";
+ return EXIT_FAILURE;
+ }
+#else
+ BOOST_LOG_TRIVIAL(fatal) << "Not built with Onnx parser support.";
+ return EXIT_FAILURE;
+#endif
+ }
+ else if (modelFormat.find("tensorflow") != std::string::npos)
+ {
+#if defined(ARMNN_TF_PARSER)
if (!converter.CreateNetwork<armnnTfParser::ITfParser>())
{
BOOST_LOG_TRIVIAL(fatal) << "Failed to load model from file";
return EXIT_FAILURE;
}
+#else
+ BOOST_LOG_TRIVIAL(fatal) << "Not built with Tensorflow parser support.";
+ return EXIT_FAILURE;
+#endif
+ }
+ else if (modelFormat.find("tflite") != std::string::npos)
+ {
+#if defined(ARMNN_TF_LITE_PARSER)
+ if (!isModelBinary)
+ {
+ BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat
+ << "'. Only 'binary' format supported for tflite files";
+ return EXIT_FAILURE;
+ }
+
+ if (!converter.CreateNetwork<armnnTfLiteParser::ITfLiteParser>())
+ {
+ BOOST_LOG_TRIVIAL(fatal) << "Failed to load model from file";
+ return EXIT_FAILURE;
+ }
+#else
+ BOOST_LOG_TRIVIAL(fatal) << "Not built with TfLite parser support.";
+ return EXIT_FAILURE;
+#endif
}
else
{
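
Reviewer note: main() routes on substrings of the single --model-format value: "binary"/"text" decided isModelBinary earlier in the patch, and the family name ("caffe", "onnx", "tensorflow", "tflite") picks the parser here. The branches must be ordered so that no earlier family name is a substring of a later format string; with these names the sets are disjoint. A standalone sketch of that two-level dispatch, using the format names from this patch:

    #include <iostream>
    #include <string>

    int main()
    {
        const std::string modelFormat = "tflite-binary";

        // Level 1 (earlier in the patch): binary vs. text.
        const bool isModelBinary = modelFormat.find("binary") != std::string::npos;

        // Level 2: pick the parser by family name.
        if (modelFormat.find("caffe") != std::string::npos)
            std::cout << "caffe, binary=" << isModelBinary << "\n";
        else if (modelFormat.find("onnx") != std::string::npos)
            std::cout << "onnx, binary=" << isModelBinary << "\n";
        else if (modelFormat.find("tensorflow") != std::string::npos)
            std::cout << "tensorflow, binary=" << isModelBinary << "\n";
        else if (modelFormat.find("tflite") != std::string::npos)
            std::cout << "tflite, binary=" << isModelBinary << "\n";  // taken here
        else
            std::cout << "unknown format '" << modelFormat << "'\n";
        return 0;
    }
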