ArmNN
 20.08
ArmnnConverter.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
#include <armnn/Logging.hpp>
#include <armnn/utility/StringUtils.hpp>

#if defined(ARMNN_CAFFE_PARSER)
#include "armnnCaffeParser/ICaffeParser.hpp"
#endif
#if defined(ARMNN_ONNX_PARSER)
#include "armnnOnnxParser/IOnnxParser.hpp"
#endif
#if defined(ARMNN_SERIALIZER)
#include "armnnSerializer/ISerializer.hpp"
#endif
#if defined(ARMNN_TF_PARSER)
#include "armnnTfParser/ITfParser.hpp"
#endif
#if defined(ARMNN_TF_LITE_PARSER)
#include "armnnTfLiteParser/ITfLiteParser.hpp"
#endif

#include <HeapProfiling.hpp>

#include <boost/format.hpp>
#include <boost/program_options.hpp>

#include <cstdlib>
#include <fstream>
#include <iostream>
#include <map>
#include <sstream>
32 
33 namespace
34 {
35 
36 namespace po = boost::program_options;
37 
38 armnn::TensorShape ParseTensorShape(std::istream& stream)
39 {
40  std::vector<unsigned int> result;
41  std::string line;
42 
43  while (std::getline(stream, line))
44  {
45  std::vector<std::string> tokens = armnn::stringUtils::StringTokenizer(line, ",");
46  for (const std::string& token : tokens)
47  {
48  if (!token.empty())
49  {
50  try
51  {
52  result.push_back(boost::numeric_cast<unsigned int>(std::stoi((token))));
53  }
54  catch (const std::exception&)
55  {
56  ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored.";
57  }
58  }
59  }
60  }
61 
62  return armnn::TensorShape(boost::numeric_cast<unsigned int>(result.size()), result.data());
63 }
64 
65 bool CheckOption(const po::variables_map& vm,
66  const char* option)
67 {
68  if (option == nullptr)
69  {
70  return false;
71  }
72 
73  // Check whether 'option' is provided.
74  return vm.find(option) != vm.end();
75 }
76 
77 void CheckOptionDependency(const po::variables_map& vm,
78  const char* option,
79  const char* required)
80 {
81  if (option == nullptr || required == nullptr)
82  {
83  throw po::error("Invalid option to check dependency for");
84  }
85 
86  // Check that if 'option' is provided, 'required' is also provided.
87  if (CheckOption(vm, option) && !vm[option].defaulted())
88  {
89  if (CheckOption(vm, required) == 0 || vm[required].defaulted())
90  {
91  throw po::error(std::string("Option '") + option + "' requires option '" + required + "'.");
92  }
93  }
94 }
95 
96 void CheckOptionDependencies(const po::variables_map& vm)
97 {
98  CheckOptionDependency(vm, "model-path", "model-format");
99  CheckOptionDependency(vm, "model-path", "input-name");
100  CheckOptionDependency(vm, "model-path", "output-name");
101  CheckOptionDependency(vm, "input-tensor-shape", "model-path");
102 }
103 
104 int ParseCommandLineArgs(int argc, const char* argv[],
105  std::string& modelFormat,
106  std::string& modelPath,
107  std::vector<std::string>& inputNames,
108  std::vector<std::string>& inputTensorShapeStrs,
109  std::vector<std::string>& outputNames,
110  std::string& outputPath, bool& isModelBinary)
111 {
112  po::options_description desc("Options");
113 
114  desc.add_options()
115  ("help", "Display usage information")
116  ("model-format,f", po::value(&modelFormat)->required(),"Format of the model file"
117 #if defined(ARMNN_CAFFE_PARSER)
118  ", caffe-binary, caffe-text"
119 #endif
120 #if defined(ARMNN_ONNX_PARSER)
121  ", onnx-binary, onnx-text"
122 #endif
123 #if defined(ARMNN_TF_PARSER)
124  ", tensorflow-binary, tensorflow-text"
125 #endif
126 #if defined(ARMNN_TF_LITE_PARSER)
127  ", tflite-binary"
128 #endif
129  ".")
130  ("model-path,m", po::value(&modelPath)->required(), "Path to model file.")
131  ("input-name,i", po::value<std::vector<std::string>>()->multitoken(),
132  "Identifier of the input tensors in the network, separated by whitespace.")
133  ("input-tensor-shape,s", po::value<std::vector<std::string>>()->multitoken(),
134  "The shape of the input tensor in the network as a flat array of integers, separated by comma."
135  " Multiple shapes are separated by whitespace."
136  " This parameter is optional, depending on the network.")
137  ("output-name,o", po::value<std::vector<std::string>>()->multitoken(),
138  "Identifier of the output tensor in the network.")
139  ("output-path,p", po::value(&outputPath)->required(), "Path to serialize the network to.");
140 
141  po::variables_map vm;
142  try
143  {
144  po::store(po::parse_command_line(argc, argv, desc), vm);
145 
146  if (CheckOption(vm, "help") || argc <= 1)
147  {
148  std::cout << "Convert a neural network model from provided file to ArmNN format." << std::endl;
149  std::cout << std::endl;
150  std::cout << desc << std::endl;
151  exit(EXIT_SUCCESS);
152  }
153  po::notify(vm);
154  }
155  catch (const po::error& e)
156  {
157  std::cerr << e.what() << std::endl << std::endl;
158  std::cerr << desc << std::endl;
159  return EXIT_FAILURE;
160  }
161 
162  try
163  {
164  CheckOptionDependencies(vm);
165  }
166  catch (const po::error& e)
167  {
168  std::cerr << e.what() << std::endl << std::endl;
169  std::cerr << desc << std::endl;
170  return EXIT_FAILURE;
171  }
172 
173  if (modelFormat.find("bin") != std::string::npos)
174  {
175  isModelBinary = true;
176  }
177  else if (modelFormat.find("text") != std::string::npos)
178  {
179  isModelBinary = false;
180  }
181  else
182  {
183  ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'. Please include 'binary' or 'text'";
184  return EXIT_FAILURE;
185  }
186 
187  if (!vm["input-tensor-shape"].empty())
188  {
189  inputTensorShapeStrs = vm["input-tensor-shape"].as<std::vector<std::string>>();
190  }
191 
192  inputNames = vm["input-name"].as<std::vector<std::string>>();
193  outputNames = vm["output-name"].as<std::vector<std::string>>();
194 
195  return EXIT_SUCCESS;
196 }
197 
// Tag type used to dispatch ArmnnConverter::CreateNetwork to the overload
// matching a particular parser interface.
template<typename T>
struct ParserType
{
    using parserType = T; // 'using' over 'typedef'; legacy member name kept
};
203 
204 class ArmnnConverter
205 {
206 public:
207  ArmnnConverter(const std::string& modelPath,
208  const std::vector<std::string>& inputNames,
209  const std::vector<armnn::TensorShape>& inputShapes,
210  const std::vector<std::string>& outputNames,
211  const std::string& outputPath,
212  bool isModelBinary)
213  : m_NetworkPtr(armnn::INetworkPtr(nullptr, [](armnn::INetwork *){})),
214  m_ModelPath(modelPath),
215  m_InputNames(inputNames),
216  m_InputShapes(inputShapes),
217  m_OutputNames(outputNames),
218  m_OutputPath(outputPath),
219  m_IsModelBinary(isModelBinary) {}
220 
221  bool Serialize()
222  {
223  if (m_NetworkPtr.get() == nullptr)
224  {
225  return false;
226  }
227 
229 
230  serializer->Serialize(*m_NetworkPtr);
231 
232  std::ofstream file(m_OutputPath, std::ios::out | std::ios::binary);
233 
234  bool retVal = serializer->SaveSerializedToStream(file);
235 
236  return retVal;
237  }
238 
239  template <typename IParser>
240  bool CreateNetwork ()
241  {
242  return CreateNetwork (ParserType<IParser>());
243  }
244 
245 private:
246  armnn::INetworkPtr m_NetworkPtr;
247  std::string m_ModelPath;
248  std::vector<std::string> m_InputNames;
249  std::vector<armnn::TensorShape> m_InputShapes;
250  std::vector<std::string> m_OutputNames;
251  std::string m_OutputPath;
252  bool m_IsModelBinary;
253 
254  template <typename IParser>
255  bool CreateNetwork (ParserType<IParser>)
256  {
257  // Create a network from a file on disk
258  auto parser(IParser::Create());
259 
260  std::map<std::string, armnn::TensorShape> inputShapes;
261  if (!m_InputShapes.empty())
262  {
263  const size_t numInputShapes = m_InputShapes.size();
264  const size_t numInputBindings = m_InputNames.size();
265  if (numInputShapes < numInputBindings)
266  {
267  throw armnn::Exception(boost::str(boost::format(
268  "Not every input has its tensor shape specified: expected=%1%, got=%2%")
269  % numInputBindings % numInputShapes));
270  }
271 
272  for (size_t i = 0; i < numInputShapes; i++)
273  {
274  inputShapes[m_InputNames[i]] = m_InputShapes[i];
275  }
276  }
277 
278  {
279  ARMNN_SCOPED_HEAP_PROFILING("Parsing");
280  m_NetworkPtr = (m_IsModelBinary ?
281  parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str(), inputShapes, m_OutputNames) :
282  parser->CreateNetworkFromTextFile(m_ModelPath.c_str(), inputShapes, m_OutputNames));
283  }
284 
285  return m_NetworkPtr.get() != nullptr;
286  }
287 
288 #if defined(ARMNN_TF_LITE_PARSER)
289  bool CreateNetwork (ParserType<armnnTfLiteParser::ITfLiteParser>)
290  {
291  // Create a network from a file on disk
293 
294  if (!m_InputShapes.empty())
295  {
296  const size_t numInputShapes = m_InputShapes.size();
297  const size_t numInputBindings = m_InputNames.size();
298  if (numInputShapes < numInputBindings)
299  {
300  throw armnn::Exception(boost::str(boost::format(
301  "Not every input has its tensor shape specified: expected=%1%, got=%2%")
302  % numInputBindings % numInputShapes));
303  }
304  }
305 
306  {
307  ARMNN_SCOPED_HEAP_PROFILING("Parsing");
308  m_NetworkPtr = parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str());
309  }
310 
311  return m_NetworkPtr.get() != nullptr;
312  }
313 #endif
314 
315 #if defined(ARMNN_ONNX_PARSER)
316  bool CreateNetwork (ParserType<armnnOnnxParser::IOnnxParser>)
317  {
318  // Create a network from a file on disk
320 
321  if (!m_InputShapes.empty())
322  {
323  const size_t numInputShapes = m_InputShapes.size();
324  const size_t numInputBindings = m_InputNames.size();
325  if (numInputShapes < numInputBindings)
326  {
327  throw armnn::Exception(boost::str(boost::format(
328  "Not every input has its tensor shape specified: expected=%1%, got=%2%")
329  % numInputBindings % numInputShapes));
330  }
331  }
332 
333  {
334  ARMNN_SCOPED_HEAP_PROFILING("Parsing");
335  m_NetworkPtr = (m_IsModelBinary ?
336  parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str()) :
337  parser->CreateNetworkFromTextFile(m_ModelPath.c_str()));
338  }
339 
340  return m_NetworkPtr.get() != nullptr;
341  }
342 #endif
343 
344 };
345 
346 } // anonymous namespace
347 
348 int main(int argc, const char* argv[])
349 {
350 
351 #if (!defined(ARMNN_CAFFE_PARSER) \
352  && !defined(ARMNN_ONNX_PARSER) \
353  && !defined(ARMNN_TF_PARSER) \
354  && !defined(ARMNN_TF_LITE_PARSER))
355  ARMNN_LOG(fatal) << "Not built with any of the supported parsers, Caffe, Onnx, Tensorflow, or TfLite.";
356  return EXIT_FAILURE;
357 #endif
358 
359 #if !defined(ARMNN_SERIALIZER)
360  ARMNN_LOG(fatal) << "Not built with Serializer support.";
361  return EXIT_FAILURE;
362 #endif
363 
364 #ifdef NDEBUG
366 #else
368 #endif
369 
370  armnn::ConfigureLogging(true, true, level);
371 
372  std::string modelFormat;
373  std::string modelPath;
374 
375  std::vector<std::string> inputNames;
376  std::vector<std::string> inputTensorShapeStrs;
377  std::vector<armnn::TensorShape> inputTensorShapes;
378 
379  std::vector<std::string> outputNames;
380  std::string outputPath;
381 
382  bool isModelBinary = true;
383 
384  if (ParseCommandLineArgs(
385  argc, argv, modelFormat, modelPath, inputNames, inputTensorShapeStrs, outputNames, outputPath, isModelBinary)
386  != EXIT_SUCCESS)
387  {
388  return EXIT_FAILURE;
389  }
390 
391  for (const std::string& shapeStr : inputTensorShapeStrs)
392  {
393  if (!shapeStr.empty())
394  {
395  std::stringstream ss(shapeStr);
396 
397  try
398  {
399  armnn::TensorShape shape = ParseTensorShape(ss);
400  inputTensorShapes.push_back(shape);
401  }
402  catch (const armnn::InvalidArgumentException& e)
403  {
404  ARMNN_LOG(fatal) << "Cannot create tensor shape: " << e.what();
405  return EXIT_FAILURE;
406  }
407  }
408  }
409 
410  ArmnnConverter converter(modelPath, inputNames, inputTensorShapes, outputNames, outputPath, isModelBinary);
411 
412  try
413  {
414  if (modelFormat.find("caffe") != std::string::npos)
415  {
416 #if defined(ARMNN_CAFFE_PARSER)
417  if (!converter.CreateNetwork<armnnCaffeParser::ICaffeParser>())
418  {
419  ARMNN_LOG(fatal) << "Failed to load model from file";
420  return EXIT_FAILURE;
421  }
422 #else
423  ARMNN_LOG(fatal) << "Not built with Caffe parser support.";
424  return EXIT_FAILURE;
425 #endif
426  }
427  else if (modelFormat.find("onnx") != std::string::npos)
428  {
429 #if defined(ARMNN_ONNX_PARSER)
430  if (!converter.CreateNetwork<armnnOnnxParser::IOnnxParser>())
431  {
432  ARMNN_LOG(fatal) << "Failed to load model from file";
433  return EXIT_FAILURE;
434  }
435 #else
436  ARMNN_LOG(fatal) << "Not built with Onnx parser support.";
437  return EXIT_FAILURE;
438 #endif
439  }
440  else if (modelFormat.find("tensorflow") != std::string::npos)
441  {
442 #if defined(ARMNN_TF_PARSER)
443  if (!converter.CreateNetwork<armnnTfParser::ITfParser>())
444  {
445  ARMNN_LOG(fatal) << "Failed to load model from file";
446  return EXIT_FAILURE;
447  }
448 #else
449  ARMNN_LOG(fatal) << "Not built with Tensorflow parser support.";
450  return EXIT_FAILURE;
451 #endif
452  }
453  else if (modelFormat.find("tflite") != std::string::npos)
454  {
455 #if defined(ARMNN_TF_LITE_PARSER)
456  if (!isModelBinary)
457  {
458  ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'. Only 'binary' format supported \
459  for tflite files";
460  return EXIT_FAILURE;
461  }
462 
463  if (!converter.CreateNetwork<armnnTfLiteParser::ITfLiteParser>())
464  {
465  ARMNN_LOG(fatal) << "Failed to load model from file";
466  return EXIT_FAILURE;
467  }
468 #else
469  ARMNN_LOG(fatal) << "Not built with TfLite parser support.";
470  return EXIT_FAILURE;
471 #endif
472  }
473  else
474  {
475  ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'";
476  return EXIT_FAILURE;
477  }
478  }
479  catch(armnn::Exception& e)
480  {
481  ARMNN_LOG(fatal) << "Failed to load model from file: " << e.what();
482  return EXIT_FAILURE;
483  }
484 
485  if (!converter.Serialize())
486  {
487  ARMNN_LOG(fatal) << "Failed to serialize model";
488  return EXIT_FAILURE;
489  }
490 
491  return EXIT_SUCCESS;
492 }
std::vector< std::string > StringTokenizer(const std::string &str, const char *delimiters, bool tokenCompression=true)
Function to take a string and a list of delimiters and split the string into tokens based on those delimiters.
Definition: StringUtils.hpp:20
void ConfigureLogging(bool printToStandardOutput, bool printToDebugOutput, LogSeverity severity)
Configures the logging behaviour of the ARMNN library.
Definition: Utils.cpp:10
virtual const char * what() const noexcept override
Definition: Exceptions.cpp:32
#define ARMNN_LOG(severity)
Definition: Logging.hpp:163
Copyright (c) 2020 ARM Limited.
static ITfLiteParserPtr Create(const armnn::Optional< TfLiteParserOptions > &options=armnn::EmptyOptional())
#define ARMNN_SCOPED_HEAP_PROFILING(TAG)
static IOnnxParserPtr Create()
Definition: OnnxParser.cpp:439
Parses a directed acyclic graph from a tensorflow protobuf file.
Definition: ITfParser.hpp:25
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:46
static ISerializerPtr Create()
std::unique_ptr< INetwork, void(*)(INetwork *network)> INetworkPtr
Definition: INetwork.hpp:101
LogSeverity
Definition: Utils.hpp:12
int main(int argc, const char *argv[])