7 #if defined(ARMNN_CAFFE_PARSER) 10 #if defined(ARMNN_ONNX_PARSER) 13 #if defined(ARMNN_SERIALIZER) 16 #if defined(ARMNN_TF_PARSER) 19 #if defined(ARMNN_TF_LITE_PARSER) 26 #include <boost/format.hpp> 27 #include <boost/program_options.hpp> 36 namespace po = boost::program_options;
40 std::vector<unsigned int> result;
43 while (std::getline(stream, line))
46 for (
const std::string& token : tokens)
52 result.push_back(boost::numeric_cast<unsigned int>(std::stoi((token))));
54 catch (
const std::exception&)
56 ARMNN_LOG(error) <<
"'" << token <<
"' is not a valid number. It has been ignored.";
62 return armnn::TensorShape(boost::numeric_cast<unsigned int>(result.size()), result.data());
65 bool CheckOption(
const po::variables_map& vm,
68 if (option ==
nullptr)
74 return vm.find(option) != vm.end();
77 void CheckOptionDependency(
const po::variables_map& vm,
81 if (option ==
nullptr || required ==
nullptr)
83 throw po::error(
"Invalid option to check dependency for");
87 if (CheckOption(vm, option) && !vm[option].defaulted())
89 if (CheckOption(vm, required) == 0 || vm[required].defaulted())
91 throw po::error(std::string(
"Option '") + option +
"' requires option '" + required +
"'.");
96 void CheckOptionDependencies(
const po::variables_map& vm)
98 CheckOptionDependency(vm,
"model-path",
"model-format");
99 CheckOptionDependency(vm,
"model-path",
"input-name");
100 CheckOptionDependency(vm,
"model-path",
"output-name");
101 CheckOptionDependency(vm,
"input-tensor-shape",
"model-path");
104 int ParseCommandLineArgs(
int argc,
const char* argv[],
105 std::string& modelFormat,
106 std::string& modelPath,
107 std::vector<std::string>& inputNames,
108 std::vector<std::string>& inputTensorShapeStrs,
109 std::vector<std::string>& outputNames,
110 std::string& outputPath,
bool& isModelBinary)
112 po::options_description desc(
"Options");
115 (
"help",
"Display usage information")
116 (
"model-format,f", po::value(&modelFormat)->required(),
"Format of the model file" 117 #if defined(ARMNN_CAFFE_PARSER) 118 ", caffe-binary, caffe-text" 120 #if defined(ARMNN_ONNX_PARSER) 121 ", onnx-binary, onnx-text" 123 #if defined(ARMNN_TF_PARSER) 124 ", tensorflow-binary, tensorflow-text" 126 #if defined(ARMNN_TF_LITE_PARSER) 130 (
"model-path,m", po::value(&modelPath)->required(),
"Path to model file.")
131 (
"input-name,i", po::value<std::vector<std::string>>()->multitoken(),
132 "Identifier of the input tensors in the network, separated by whitespace.")
133 (
"input-tensor-shape,s", po::value<std::vector<std::string>>()->multitoken(),
134 "The shape of the input tensor in the network as a flat array of integers, separated by comma." 135 " Multiple shapes are separated by whitespace." 136 " This parameter is optional, depending on the network.")
137 (
"output-name,o", po::value<std::vector<std::string>>()->multitoken(),
138 "Identifier of the output tensor in the network.")
139 (
"output-path,p", po::value(&outputPath)->required(),
"Path to serialize the network to.");
141 po::variables_map vm;
144 po::store(po::parse_command_line(argc, argv, desc), vm);
146 if (CheckOption(vm,
"help") || argc <= 1)
148 std::cout <<
"Convert a neural network model from provided file to ArmNN format." << std::endl;
149 std::cout << std::endl;
150 std::cout << desc << std::endl;
155 catch (
const po::error& e)
157 std::cerr << e.what() << std::endl << std::endl;
158 std::cerr << desc << std::endl;
164 CheckOptionDependencies(vm);
166 catch (
const po::error& e)
168 std::cerr << e.what() << std::endl << std::endl;
169 std::cerr << desc << std::endl;
173 if (modelFormat.find(
"bin") != std::string::npos)
175 isModelBinary =
true;
177 else if (modelFormat.find(
"text") != std::string::npos)
179 isModelBinary =
false;
183 ARMNN_LOG(fatal) <<
"Unknown model format: '" << modelFormat <<
"'. Please include 'binary' or 'text'";
187 if (!vm[
"input-tensor-shape"].empty())
189 inputTensorShapeStrs = vm[
"input-tensor-shape"].as<std::vector<std::string>>();
192 inputNames = vm[
"input-name"].as<std::vector<std::string>>();
193 outputNames = vm[
"output-name"].as<std::vector<std::string>>();
201 typedef T parserType;
207 ArmnnConverter(
const std::string& modelPath,
208 const std::vector<std::string>& inputNames,
209 const std::vector<armnn::TensorShape>& inputShapes,
210 const std::vector<std::string>& outputNames,
211 const std::string& outputPath,
214 m_ModelPath(modelPath),
215 m_InputNames(inputNames),
216 m_InputShapes(inputShapes),
217 m_OutputNames(outputNames),
218 m_OutputPath(outputPath),
219 m_IsModelBinary(isModelBinary) {}
223 if (m_NetworkPtr.get() ==
nullptr)
232 std::ofstream file(m_OutputPath, std::ios::out | std::ios::binary);
234 bool retVal =
serializer->SaveSerializedToStream(file);
239 template <
typename IParser>
240 bool CreateNetwork ()
242 return CreateNetwork (ParserType<IParser>());
247 std::string m_ModelPath;
248 std::vector<std::string> m_InputNames;
249 std::vector<armnn::TensorShape> m_InputShapes;
250 std::vector<std::string> m_OutputNames;
251 std::string m_OutputPath;
252 bool m_IsModelBinary;
254 template <
typename IParser>
255 bool CreateNetwork (ParserType<IParser>)
258 auto parser(IParser::Create());
260 std::map<std::string, armnn::TensorShape> inputShapes;
261 if (!m_InputShapes.empty())
263 const size_t numInputShapes = m_InputShapes.size();
264 const size_t numInputBindings = m_InputNames.size();
265 if (numInputShapes < numInputBindings)
268 "Not every input has its tensor shape specified: expected=%1%, got=%2%")
269 % numInputBindings % numInputShapes));
272 for (
size_t i = 0; i < numInputShapes; i++)
274 inputShapes[m_InputNames[i]] = m_InputShapes[i];
280 m_NetworkPtr = (m_IsModelBinary ?
281 parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str(), inputShapes, m_OutputNames) :
282 parser->CreateNetworkFromTextFile(m_ModelPath.c_str(), inputShapes, m_OutputNames));
285 return m_NetworkPtr.get() !=
nullptr;
#if defined(ARMNN_TF_LITE_PARSER)
    /// TfLite-specific path: the TfLite parser takes neither input shapes nor
    /// output names, so shapes are only validated (for a consistent error) and
    /// the model is loaded from the binary file alone.
    /// NOTE(review): the parser-creation and profiling-scope lines fell in
    /// gaps of this capture and were reconstructed — confirm.
    bool CreateNetwork(ParserType<armnnTfLiteParser::ITfLiteParser>)
    {
        // Create a network from a file on disk.
        auto parser(armnnTfLiteParser::ITfLiteParser::Create());

        if (!m_InputShapes.empty())
        {
            const size_t numInputShapes   = m_InputShapes.size();
            const size_t numInputBindings = m_InputNames.size();
            if (numInputShapes < numInputBindings)
            {
                throw armnn::Exception(boost::str(boost::format(
                    "Not every input has its tensor shape specified: expected=%1%, got=%2%")
                    % numInputBindings % numInputShapes));
            }
        }

        {
            ARMNN_SCOPED_HEAP_PROFILING("Parsing");
            m_NetworkPtr = parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str());
        }

        return m_NetworkPtr.get() != nullptr;
    }
#endif
#if defined(ARMNN_ONNX_PARSER)
    /// Onnx-specific path: the Onnx parser takes only the file path, so input
    /// shapes are only validated here; binary vs. text selects the load call.
    /// NOTE(review): the parser-creation and profiling-scope lines fell in
    /// gaps of this capture and were reconstructed — confirm.
    bool CreateNetwork(ParserType<armnnOnnxParser::IOnnxParser>)
    {
        // Create a network from a file on disk.
        auto parser(armnnOnnxParser::IOnnxParser::Create());

        if (!m_InputShapes.empty())
        {
            const size_t numInputShapes   = m_InputShapes.size();
            const size_t numInputBindings = m_InputNames.size();
            if (numInputShapes < numInputBindings)
            {
                throw armnn::Exception(boost::str(boost::format(
                    "Not every input has its tensor shape specified: expected=%1%, got=%2%")
                    % numInputBindings % numInputShapes));
            }
        }

        {
            ARMNN_SCOPED_HEAP_PROFILING("Parsing");
            m_NetworkPtr = (m_IsModelBinary ?
                parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str()) :
                parser->CreateNetworkFromTextFile(m_ModelPath.c_str()));
        }

        return m_NetworkPtr.get() != nullptr;
    }
#endif
348 int main(
int argc,
const char* argv[])
351 #if (!defined(ARMNN_CAFFE_PARSER) \ 352 && !defined(ARMNN_ONNX_PARSER) \ 353 && !defined(ARMNN_TF_PARSER) \ 354 && !defined(ARMNN_TF_LITE_PARSER)) 355 ARMNN_LOG(fatal) <<
"Not built with any of the supported parsers, Caffe, Onnx, Tensorflow, or TfLite.";
359 #if !defined(ARMNN_SERIALIZER) 360 ARMNN_LOG(fatal) <<
"Not built with Serializer support.";
372 std::string modelFormat;
373 std::string modelPath;
375 std::vector<std::string> inputNames;
376 std::vector<std::string> inputTensorShapeStrs;
377 std::vector<armnn::TensorShape> inputTensorShapes;
379 std::vector<std::string> outputNames;
380 std::string outputPath;
382 bool isModelBinary =
true;
384 if (ParseCommandLineArgs(
385 argc, argv, modelFormat, modelPath, inputNames, inputTensorShapeStrs, outputNames, outputPath, isModelBinary)
391 for (
const std::string& shapeStr : inputTensorShapeStrs)
393 if (!shapeStr.empty())
395 std::stringstream ss(shapeStr);
400 inputTensorShapes.push_back(shape);
404 ARMNN_LOG(fatal) <<
"Cannot create tensor shape: " << e.
what();
410 ArmnnConverter converter(modelPath, inputNames, inputTensorShapes, outputNames, outputPath, isModelBinary);
414 if (modelFormat.find(
"caffe") != std::string::npos)
416 #if defined(ARMNN_CAFFE_PARSER) 419 ARMNN_LOG(fatal) <<
"Failed to load model from file";
423 ARMNN_LOG(fatal) <<
"Not built with Caffe parser support.";
427 else if (modelFormat.find(
"onnx") != std::string::npos)
429 #if defined(ARMNN_ONNX_PARSER) 432 ARMNN_LOG(fatal) <<
"Failed to load model from file";
436 ARMNN_LOG(fatal) <<
"Not built with Onnx parser support.";
440 else if (modelFormat.find(
"tensorflow") != std::string::npos)
442 #if defined(ARMNN_TF_PARSER) 445 ARMNN_LOG(fatal) <<
"Failed to load model from file";
449 ARMNN_LOG(fatal) <<
"Not built with Tensorflow parser support.";
453 else if (modelFormat.find(
"tflite") != std::string::npos)
455 #if defined(ARMNN_TF_LITE_PARSER) 458 ARMNN_LOG(fatal) <<
"Unknown model format: '" << modelFormat <<
"'. Only 'binary' format supported \ 465 ARMNN_LOG(fatal) <<
"Failed to load model from file";
469 ARMNN_LOG(fatal) <<
"Not built with TfLite parser support.";
475 ARMNN_LOG(fatal) <<
"Unknown model format: '" << modelFormat <<
"'";
481 ARMNN_LOG(fatal) <<
"Failed to load model from file: " << e.
what();
485 if (!converter.Serialize())
487 ARMNN_LOG(fatal) <<
"Failed to serialize model";
std::vector< std::string > StringTokenizer(const std::string &str, const char *delimiters, bool tokenCompression=true)
Function to take a string and a list of delimiters and split the string into tokens based on those de...
void ConfigureLogging(bool printToStandardOutput, bool printToDebugOutput, LogSeverity severity)
Configures the logging behaviour of the ARMNN library.
virtual const char * what() const noexcept override
#define ARMNN_LOG(severity)
Copyright (c) 2020 ARM Limited.
static ITfLiteParserPtr Create(const armnn::Optional< TfLiteParserOptions > &options=armnn::EmptyOptional())
#define ARMNN_SCOPED_HEAP_PROFILING(TAG)
static IOnnxParserPtr Create()
Parses a directed acyclic graph from a tensorflow protobuf file.
Base class for all ArmNN exceptions so that users can filter to just those.
static ISerializerPtr Create()
std::unique_ptr< INetwork, void(*)(INetwork *network)> INetworkPtr
int main(int argc, const char *argv[])