ArmNN
 20.02
ArmnnConverter.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include <armnn/Logging.hpp>
6 
7 #if defined(ARMNN_CAFFE_PARSER)
9 #endif
10 #if defined(ARMNN_ONNX_PARSER)
12 #endif
13 #if defined(ARMNN_SERIALIZER)
15 #endif
16 #if defined(ARMNN_TF_PARSER)
18 #endif
19 #if defined(ARMNN_TF_LITE_PARSER)
21 #endif
22 
23 #include <HeapProfiling.hpp>
24 
25 #include <boost/format.hpp>
26 #include <boost/algorithm/string/split.hpp>
27 #include <boost/algorithm/string/classification.hpp>
28 #include <boost/program_options.hpp>
29 
30 #include <cstdlib>
31 #include <fstream>
32 #include <iostream>
33 
34 namespace
35 {
36 
37 namespace po = boost::program_options;
38 
39 armnn::TensorShape ParseTensorShape(std::istream& stream)
40 {
41  std::vector<unsigned int> result;
42  std::string line;
43 
44  while (std::getline(stream, line))
45  {
46  std::vector<std::string> tokens;
47  try
48  {
49  // Coverity fix: boost::split() may throw an exception of type boost::bad_function_call.
50  boost::split(tokens, line, boost::algorithm::is_any_of(","), boost::token_compress_on);
51  }
52  catch (const std::exception& e)
53  {
54  ARMNN_LOG(error) << "An error occurred when splitting tokens: " << e.what();
55  continue;
56  }
57  for (const std::string& token : tokens)
58  {
59  if (!token.empty())
60  {
61  try
62  {
63  result.push_back(boost::numeric_cast<unsigned int>(std::stoi((token))));
64  }
65  catch (const std::exception&)
66  {
67  ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored.";
68  }
69  }
70  }
71  }
72 
73  return armnn::TensorShape(boost::numeric_cast<unsigned int>(result.size()), result.data());
74 }
75 
76 bool CheckOption(const po::variables_map& vm,
77  const char* option)
78 {
79  if (option == nullptr)
80  {
81  return false;
82  }
83 
84  // Check whether 'option' is provided.
85  return vm.find(option) != vm.end();
86 }
87 
88 void CheckOptionDependency(const po::variables_map& vm,
89  const char* option,
90  const char* required)
91 {
92  if (option == nullptr || required == nullptr)
93  {
94  throw po::error("Invalid option to check dependency for");
95  }
96 
97  // Check that if 'option' is provided, 'required' is also provided.
98  if (CheckOption(vm, option) && !vm[option].defaulted())
99  {
100  if (CheckOption(vm, required) == 0 || vm[required].defaulted())
101  {
102  throw po::error(std::string("Option '") + option + "' requires option '" + required + "'.");
103  }
104  }
105 }
106 
107 void CheckOptionDependencies(const po::variables_map& vm)
108 {
109  CheckOptionDependency(vm, "model-path", "model-format");
110  CheckOptionDependency(vm, "model-path", "input-name");
111  CheckOptionDependency(vm, "model-path", "output-name");
112  CheckOptionDependency(vm, "input-tensor-shape", "model-path");
113 }
114 
// Parses the converter's command line into the given output parameters.
// Returns EXIT_SUCCESS on success, EXIT_FAILURE on a parse/validation error.
// Prints usage and calls exit(EXIT_SUCCESS) when --help is given or no
// arguments are supplied. On success 'isModelBinary' reflects whether the
// model format names a binary ("bin") or text ("text") encoding.
int ParseCommandLineArgs(int argc, const char* argv[],
                         std::string& modelFormat,
                         std::string& modelPath,
                         std::vector<std::string>& inputNames,
                         std::vector<std::string>& inputTensorShapeStrs,
                         std::vector<std::string>& outputNames,
                         std::string& outputPath, bool& isModelBinary)
{
    po::options_description desc("Options");

    // The accepted model formats depend on which parsers this binary was
    // built with; the help text is assembled accordingly.
    desc.add_options()
        ("help", "Display usage information")
        ("model-format,f", po::value(&modelFormat)->required(),"Format of the model file"
#if defined(ARMNN_CAFFE_PARSER)
         ", caffe-binary, caffe-text"
#endif
#if defined(ARMNN_ONNX_PARSER)
         ", onnx-binary, onnx-text"
#endif
#if defined(ARMNN_TF_PARSER)
         ", tensorflow-binary, tensorflow-text"
#endif
#if defined(ARMNN_TF_LITE_PARSER)
         ", tflite-binary"
#endif
         ".")
        ("model-path,m", po::value(&modelPath)->required(), "Path to model file.")
        ("input-name,i", po::value<std::vector<std::string>>()->multitoken(),
         "Identifier of the input tensors in the network, separated by whitespace.")
        ("input-tensor-shape,s", po::value<std::vector<std::string>>()->multitoken(),
         "The shape of the input tensor in the network as a flat array of integers, separated by comma."
         " Multiple shapes are separated by whitespace."
         " This parameter is optional, depending on the network.")
        ("output-name,o", po::value<std::vector<std::string>>()->multitoken(),
         "Identifier of the output tensor in the network.")
        ("output-path,p", po::value(&outputPath)->required(), "Path to serialize the network to.");

    po::variables_map vm;
    try
    {
        po::store(po::parse_command_line(argc, argv, desc), vm);

        // Handle --help before notify() so required options are not enforced
        // when the user only asked for usage information.
        if (CheckOption(vm, "help") || argc <= 1)
        {
            std::cout << "Convert a neural network model from provided file to ArmNN format." << std::endl;
            std::cout << std::endl;
            std::cout << desc << std::endl;
            exit(EXIT_SUCCESS);
        }
        // notify() runs required-option checks and stores values.
        po::notify(vm);
    }
    catch (const po::error& e)
    {
        std::cerr << e.what() << std::endl << std::endl;
        std::cerr << desc << std::endl;
        return EXIT_FAILURE;
    }

    try
    {
        CheckOptionDependencies(vm);
    }
    catch (const po::error& e)
    {
        std::cerr << e.what() << std::endl << std::endl;
        std::cerr << desc << std::endl;
        return EXIT_FAILURE;
    }

    // Derive binary/text mode from a substring of the format name, e.g.
    // "caffe-binary" -> binary, "onnx-text" -> text.
    if (modelFormat.find("bin") != std::string::npos)
    {
        isModelBinary = true;
    }
    else if (modelFormat.find("text") != std::string::npos)
    {
        isModelBinary = false;
    }
    else
    {
        ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'. Please include 'binary' or 'text'";
        return EXIT_FAILURE;
    }

    // input-tensor-shape is optional; input-name and output-name are
    // guaranteed present here by CheckOptionDependencies (model-path is
    // required and depends on both).
    if (!vm["input-tensor-shape"].empty())
    {
        inputTensorShapeStrs = vm["input-tensor-shape"].as<std::vector<std::string>>();
    }

    inputNames = vm["input-name"].as<std::vector<std::string>>();
    outputNames = vm["output-name"].as<std::vector<std::string>>();

    return EXIT_SUCCESS;
}
208 
// Empty tag type used to dispatch ArmnnConverter::CreateNetwork overloads
// on the parser interface type at compile time.
template<typename T>
struct ParserType
{
    using parserType = T;
};
214 
// Drives the conversion: parses a model file with a format-specific parser
// into an armnn::INetwork (CreateNetwork), then writes that network to the
// output path via the ArmNN serializer (Serialize).
class ArmnnConverter
{
public:
    // All parameters are copied; no parsing or I/O happens until
    // CreateNetwork()/Serialize() are called. The network pointer starts
    // null with a no-op deleter.
    ArmnnConverter(const std::string& modelPath,
                   const std::vector<std::string>& inputNames,
                   const std::vector<armnn::TensorShape>& inputShapes,
                   const std::vector<std::string>& outputNames,
                   const std::string& outputPath,
                   bool isModelBinary)
    : m_NetworkPtr(armnn::INetworkPtr(nullptr, [](armnn::INetwork *){})),
    m_ModelPath(modelPath),
    m_InputNames(inputNames),
    m_InputShapes(inputShapes),
    m_OutputNames(outputNames),
    m_OutputPath(outputPath),
    m_IsModelBinary(isModelBinary) {}

    // Serializes the previously created network to m_OutputPath in binary
    // form. Returns false when no network has been created yet, otherwise
    // the result of writing the serialized stream.
    bool Serialize()
    {
        if (m_NetworkPtr.get() == nullptr)
        {
            return false;
        }

        // NOTE(review): the creation of 'serializer' (ISerializer::Create())
        // is not visible in this view of the file — presumably dropped by the
        // documentation extraction; confirm against the original source.

        serializer->Serialize(*m_NetworkPtr);

        std::ofstream file(m_OutputPath, std::ios::out | std::ios::binary);

        bool retVal = serializer->SaveSerializedToStream(file);

        return retVal;
    }

    // Public entry point: dispatches to the private overload that matches
    // the requested parser interface via the ParserType tag.
    template <typename IParser>
    bool CreateNetwork ()
    {
        return CreateNetwork (ParserType<IParser>());
    }

private:
    armnn::INetworkPtr m_NetworkPtr;               // Parse result; null until CreateNetwork succeeds.
    std::string m_ModelPath;                       // Input model file path.
    std::vector<std::string> m_InputNames;         // Input tensor identifiers.
    std::vector<armnn::TensorShape> m_InputShapes; // Optional shapes, parallel to m_InputNames.
    std::vector<std::string> m_OutputNames;        // Output tensor identifiers.
    std::string m_OutputPath;                      // Destination for the serialized network.
    bool m_IsModelBinary;                          // Binary vs. text model file.

    // Generic parser path: for parsers whose CreateNetworkFrom*File API takes
    // explicit input shapes and output names (e.g. Caffe, TensorFlow).
    template <typename IParser>
    bool CreateNetwork (ParserType<IParser>)
    {
        // Create a network from a file on disk
        auto parser(IParser::Create());

        // Map each input name to its user-supplied shape, if any were given.
        // Extra shapes beyond the number of names are tolerated; too few is
        // an error.
        std::map<std::string, armnn::TensorShape> inputShapes;
        if (!m_InputShapes.empty())
        {
            const size_t numInputShapes = m_InputShapes.size();
            const size_t numInputBindings = m_InputNames.size();
            if (numInputShapes < numInputBindings)
            {
                throw armnn::Exception(boost::str(boost::format(
                    "Not every input has its tensor shape specified: expected=%1%, got=%2%")
                    % numInputBindings % numInputShapes));
            }

            for (size_t i = 0; i < numInputShapes; i++)
            {
                inputShapes[m_InputNames[i]] = m_InputShapes[i];
            }
        }

        {
            ARMNN_SCOPED_HEAP_PROFILING("Parsing");
            m_NetworkPtr = (m_IsModelBinary ?
                parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str(), inputShapes, m_OutputNames) :
                parser->CreateNetworkFromTextFile(m_ModelPath.c_str(), inputShapes, m_OutputNames));
        }

        return m_NetworkPtr.get() != nullptr;
    }

#if defined(ARMNN_TF_LITE_PARSER)
    // TfLite-specific overload: the parser takes no shape/name arguments, so
    // only a shape/name count consistency check is performed here.
    bool CreateNetwork (ParserType<armnnTfLiteParser::ITfLiteParser>)
    {
        // Create a network from a file on disk
        // NOTE(review): the creation of 'parser' (ITfLiteParser::Create()) is
        // not visible in this view of the file — confirm against the original
        // source.

        if (!m_InputShapes.empty())
        {
            const size_t numInputShapes = m_InputShapes.size();
            const size_t numInputBindings = m_InputNames.size();
            if (numInputShapes < numInputBindings)
            {
                throw armnn::Exception(boost::str(boost::format(
                    "Not every input has its tensor shape specified: expected=%1%, got=%2%")
                    % numInputBindings % numInputShapes));
            }
        }

        {
            ARMNN_SCOPED_HEAP_PROFILING("Parsing");
            m_NetworkPtr = parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str());
        }

        return m_NetworkPtr.get() != nullptr;
    }
#endif

#if defined(ARMNN_ONNX_PARSER)
    // ONNX-specific overload: supports binary and text files; like TfLite,
    // the parser takes only the file path, so shapes are just count-checked.
    bool CreateNetwork (ParserType<armnnOnnxParser::IOnnxParser>)
    {
        // Create a network from a file on disk
        // NOTE(review): the creation of 'parser' (IOnnxParser::Create()) is
        // not visible in this view of the file — confirm against the original
        // source.

        if (!m_InputShapes.empty())
        {
            const size_t numInputShapes = m_InputShapes.size();
            const size_t numInputBindings = m_InputNames.size();
            if (numInputShapes < numInputBindings)
            {
                throw armnn::Exception(boost::str(boost::format(
                    "Not every input has its tensor shape specified: expected=%1%, got=%2%")
                    % numInputBindings % numInputShapes));
            }
        }

        {
            ARMNN_SCOPED_HEAP_PROFILING("Parsing");
            m_NetworkPtr = (m_IsModelBinary ?
                parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str()) :
                parser->CreateNetworkFromTextFile(m_ModelPath.c_str()));
        }

        return m_NetworkPtr.get() != nullptr;
    }
#endif

};
356 
357 } // anonymous namespace
358 
// Entry point: parse the command line, read the optional input tensor
// shapes, load the model with the parser matching the requested format, and
// serialize the resulting network to the output path.
int main(int argc, const char* argv[])
{

// Bail out early if the binary was built without any parser at all, or
// without serializer support — nothing useful can be done.
#if (!defined(ARMNN_CAFFE_PARSER) \
    && !defined(ARMNN_ONNX_PARSER) \
    && !defined(ARMNN_TF_PARSER) \
    && !defined(ARMNN_TF_LITE_PARSER))
    ARMNN_LOG(fatal) << "Not built with any of the supported parsers, Caffe, Onnx, Tensorflow, or TfLite.";
    return EXIT_FAILURE;
#endif

#if !defined(ARMNN_SERIALIZER)
    ARMNN_LOG(fatal) << "Not built with Serializer support.";
    return EXIT_FAILURE;
#endif

// NOTE(review): the declaration of 'level' (a LogSeverity chosen per
// NDEBUG/debug build) is not visible in this view of the file — presumably
// dropped by the documentation extraction; confirm against the original
// source.
#ifdef NDEBUG
#else
#endif

    armnn::ConfigureLogging(true, true, level);

    std::string modelFormat;
    std::string modelPath;

    std::vector<std::string> inputNames;
    std::vector<std::string> inputTensorShapeStrs;
    std::vector<armnn::TensorShape> inputTensorShapes;

    std::vector<std::string> outputNames;
    std::string outputPath;

    bool isModelBinary = true;

    if (ParseCommandLineArgs(
        argc, argv, modelFormat, modelPath, inputNames, inputTensorShapeStrs, outputNames, outputPath, isModelBinary)
        != EXIT_SUCCESS)
    {
        return EXIT_FAILURE;
    }

    // Convert each comma-separated shape string into a TensorShape.
    for (const std::string& shapeStr : inputTensorShapeStrs)
    {
        if (!shapeStr.empty())
        {
            std::stringstream ss(shapeStr);

            try
            {
                armnn::TensorShape shape = ParseTensorShape(ss);
                inputTensorShapes.push_back(shape);
            }
            catch (const armnn::InvalidArgumentException& e)
            {
                ARMNN_LOG(fatal) << "Cannot create tensor shape: " << e.what();
                return EXIT_FAILURE;
            }
        }
    }

    ArmnnConverter converter(modelPath, inputNames, inputTensorShapes, outputNames, outputPath, isModelBinary);

    // Select the parser from a substring of the format name; each branch is
    // compiled only when the corresponding parser support was built in.
    try
    {
        if (modelFormat.find("caffe") != std::string::npos)
        {
#if defined(ARMNN_CAFFE_PARSER)
            if (!converter.CreateNetwork<armnnCaffeParser::ICaffeParser>())
            {
                ARMNN_LOG(fatal) << "Failed to load model from file";
                return EXIT_FAILURE;
            }
#else
            ARMNN_LOG(fatal) << "Not built with Caffe parser support.";
            return EXIT_FAILURE;
#endif
        }
        else if (modelFormat.find("onnx") != std::string::npos)
        {
#if defined(ARMNN_ONNX_PARSER)
            if (!converter.CreateNetwork<armnnOnnxParser::IOnnxParser>())
            {
                ARMNN_LOG(fatal) << "Failed to load model from file";
                return EXIT_FAILURE;
            }
#else
            ARMNN_LOG(fatal) << "Not built with Onnx parser support.";
            return EXIT_FAILURE;
#endif
        }
        else if (modelFormat.find("tensorflow") != std::string::npos)
        {
#if defined(ARMNN_TF_PARSER)
            if (!converter.CreateNetwork<armnnTfParser::ITfParser>())
            {
                ARMNN_LOG(fatal) << "Failed to load model from file";
                return EXIT_FAILURE;
            }
#else
            ARMNN_LOG(fatal) << "Not built with Tensorflow parser support.";
            return EXIT_FAILURE;
#endif
        }
        else if (modelFormat.find("tflite") != std::string::npos)
        {
#if defined(ARMNN_TF_LITE_PARSER)
            // TfLite models are only ever binary flatbuffers.
            if (!isModelBinary)
            {
                ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'. Only 'binary' format supported \
  for tflite files";
                return EXIT_FAILURE;
            }

            if (!converter.CreateNetwork<armnnTfLiteParser::ITfLiteParser>())
            {
                ARMNN_LOG(fatal) << "Failed to load model from file";
                return EXIT_FAILURE;
            }
#else
            ARMNN_LOG(fatal) << "Not built with TfLite parser support.";
            return EXIT_FAILURE;
#endif
        }
        else
        {
            ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'";
            return EXIT_FAILURE;
        }
    }
    catch(armnn::Exception& e)
    {
        ARMNN_LOG(fatal) << "Failed to load model from file: " << e.what();
        return EXIT_FAILURE;
    }

    if (!converter.Serialize())
    {
        ARMNN_LOG(fatal) << "Failed to serialize model";
        return EXIT_FAILURE;
    }

    return EXIT_SUCCESS;
}
void ConfigureLogging(bool printToStandardOutput, bool printToDebugOutput, LogSeverity severity)
Configures the logging behaviour of the ARMNN library.
Definition: Utils.cpp:10
virtual const char * what() const noexcept override
Definition: Exceptions.cpp:32
#define ARMNN_LOG(severity)
Definition: Logging.hpp:163
Copyright (c) 2020 ARM Limited.
static ITfLiteParserPtr Create(const armnn::Optional< TfLiteParserOptions > &options=armnn::EmptyOptional())
#define ARMNN_SCOPED_HEAP_PROFILING(TAG)
static IOnnxParserPtr Create()
Definition: OnnxParser.cpp:424
Parses a directed acyclic graph from a tensorflow protobuf file.
Definition: ITfParser.hpp:25
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:46
static ISerializerPtr Create()
std::unique_ptr< INetwork, void(*)(INetwork *network)> INetworkPtr
Definition: INetwork.hpp:101
LogSeverity
Definition: Utils.hpp:12
int main(int argc, const char *argv[])