// NOTE(review): this text appears to be extracted from a rendered source listing.
// Original line numbers are fused into the code and many lines (braces, blank
// lines, the function header, the tokenizer call, the include bodies) are
// missing, so it does not compile as shown; restore from the real ArmNNConverter
// source before making code changes.
//
// Preamble: optional parser/serializer support is selected via ARMNN_* macros,
// and CXXOPTS_VECTOR_DELIMITER is set to '.' before cxxopts is included,
// presumably so that the comma-separated dimension lists given to
// --input-tensor-shape are not split into separate vector elements (TODO confirm).
//
// Tensor-shape parsing helper (header not visible): collects unsigned dimension
// values into result and returns them as an armnn::TensorShape.
7 #if defined(ARMNN_CAFFE_PARSER) 10 #if defined(ARMNN_ONNX_PARSER) 13 #if defined(ARMNN_SERIALIZER) 16 #if defined(ARMNN_TF_PARSER) 19 #if defined(ARMNN_TF_LITE_PARSER) 33 #define CXXOPTS_VECTOR_DELIMITER '.' 34 #include <cxxopts/cxxopts.hpp> 36 #include <fmt/format.h> 47 std::vector<unsigned int> result;
// Consume the input stream one line at a time.
50 while (std::getline(stream, line))
// tokens is produced by a line missing from this listing (presumably the
// current line split on commas; TODO confirm).
53 for (
const std::string& token : tokens)
// Convert each token with std::stoi and narrow it to unsigned int.
// NOTE(review): std::stoi accepts a numeric prefix (a token such as 3x parses
// as 3); whether a negative value is rejected depends on armnn::numeric_cast.
59 result.push_back(armnn::numeric_cast<unsigned int>(std::stoi((token))));
// A token that fails to convert is logged and skipped instead of aborting.
61 catch (
const std::exception&)
63 ARMNN_LOG(error) <<
"'" << token <<
"' is not a valid number. It has been ignored.";
// The rank of the returned shape equals the number of successfully parsed values.
69 return armnn::TensorShape(armnn::numeric_cast<unsigned int>(result.size()), result.data());
// Parses the converter command line with cxxopts, writing results into the
// reference out-parameters. Visible validation: --model-format, --model-path,
// --output-name and --output-path must each appear exactly once; --input-name
// at least once; and if any --input-tensor-shape is given, its count must equal
// the number of --input-name entries. isModelBinary is then derived from the
// format string (contains bin -> true, else contains text -> false, otherwise a
// fatal log).
// NOTE(review): the return statements are missing from this listing; check the
// real source for the code each early exit returns.
72 int ParseCommandLineArgs(
int argc,
char* argv[],
73 std::string& modelFormat,
74 std::string& modelPath,
75 std::vector<std::string>& inputNames,
76 std::vector<std::string>& inputTensorShapeStrs,
77 std::vector<std::string>& outputNames,
78 std::string& outputPath,
bool& isModelBinary)
80 cxxopts::Options options(
"ArmNNConverter",
"Convert a neural network model from provided file to ArmNN format.");
// The format help text lists only the formats whose parsers are compiled in.
83 std::string modelFormatDescription(
"Format of the model file");
84 #if defined(ARMNN_CAFFE_PARSER) 85 modelFormatDescription +=
", caffe-binary, caffe-text";
87 #if defined(ARMNN_ONNX_PARSER) 88 modelFormatDescription +=
", onnx-binary, onnx-text";
90 #if defined(ARMNN_TF_PARSER) 91 modelFormatDescription +=
", tensorflow-binary, tensorflow-text";
93 #if defined(ARMNN_TF_LITE_PARSER) 94 modelFormatDescription +=
", tflite-binary";
96 modelFormatDescription +=
".";
// Option table (the call that opens it is missing here); each entry binds an
// option to one of the out-parameters.
98 (
"help",
"Display usage information")
99 (
"f,model-format", modelFormatDescription, cxxopts::value<std::string>(modelFormat))
100 (
"m,model-path",
"Path to model file.", cxxopts::value<std::string>(modelPath))
102 (
"i,input-name",
"Identifier of the input tensors in the network. " 103 "Each input must be specified separately.",
104 cxxopts::value<std::vector<std::string>>(inputNames))
105 (
"s,input-tensor-shape",
106 "The shape of the input tensor in the network as a flat array of integers, " 107 "separated by comma. Each input shape must be specified separately after the input name. " 108 "This parameter is optional, depending on the network.",
109 cxxopts::value<std::vector<std::string>>(inputTensorShapeStrs))
111 (
"o,output-name",
"Identifier of the output tensor in the network.",
112 cxxopts::value<std::vector<std::string>>(outputNames))
// The option-name line for the output-path entry is missing from this listing;
// the entry binds outputPath.
114 "Path to serialize the network to.", cxxopts::value<std::string>(outputPath));
// A failure while declaring options prints the error followed by the usage text.
116 catch (
const std::exception& e)
118 std::cerr << e.what() << std::endl << options.help() << std::endl;
// Parse argv; --help prints the usage text to stderr.
123 cxxopts::ParseResult result = options.parse(argc, argv);
124 if (result.count(
"help"))
126 std::cerr << options.help() << std::endl;
// Each of these options must appear exactly once. Every missing one is reported
// before failing, rather than stopping at the first.
// NOTE(review): output-name is bound to a std::vector option above, yet it is
// required exactly once here, so several output names are rejected; confirm.
130 std::string mandatorySingleParameters[] = {
"model-format",
"model-path",
"output-name",
"output-path" };
131 bool somethingsMissing =
false;
132 for (
auto param : mandatorySingleParameters)
134 if (result.count(param) != 1)
136 std::cerr <<
"Parameter \'--" << param <<
"\' is required but missing." << std::endl;
137 somethingsMissing =
true;
// At least one input name is required.
141 if (result.count(
"input-name") == 0)
143 std::cerr <<
"Parameter \'--" <<
"input-name" <<
"\' must be specified at least once." << std::endl;
144 somethingsMissing =
true;
// Shapes are optional, but when any is given there must be one per input name.
147 if (result.count(
"input-tensor-shape") > 0)
149 if (result.count(
"input-tensor-shape") != result.count(
"input-name"))
151 std::cerr <<
"When specifying \'input-tensor-shape\' a matching number of \'input-name\' parameters " 152 "must be specified." << std::endl;
153 somethingsMissing =
true;
// Any validation failure prints the usage text once.
157 if (somethingsMissing)
159 std::cerr << options.help() << std::endl;
// cxxopts parsing errors are reported on stderr.
163 catch (
const cxxopts::OptionException& e)
165 std::cerr << e.what() << std::endl << std::endl;
// Substring match on the format: anything containing bin is binary (checked
// first), otherwise anything containing text is text, otherwise a fatal error.
169 if (modelFormat.find(
"bin") != std::string::npos)
171 isModelBinary =
true;
173 else if (modelFormat.find(
"text") != std::string::npos)
175 isModelBinary =
false;
179 ARMNN_LOG(fatal) <<
"Unknown model format: '" << modelFormat <<
"'. Please include 'binary' or 'text'";
// Presumably the body of the ParserType<T> tag struct; its head is missing from
// this listing. The tag is used to select a CreateNetwork overload by parser
// type, and it exposes T as parserType.
189 typedef T parserType;
// Converter that loads a model with a chosen ArmNN parser and serializes the
// resulting network to m_OutputPath. The class head and several lines are
// missing from this listing, including the isModelBinary constructor parameter.
// Constructor: stores the model path, input/output names, input shapes,
// output path and the binary/text flag for later use.
195 ArmnnConverter(
const std::string& modelPath,
196 const std::vector<std::string>& inputNames,
197 const std::vector<armnn::TensorShape>& inputShapes,
198 const std::vector<std::string>& outputNames,
199 const std::string& outputPath,
202 m_ModelPath(modelPath),
203 m_InputNames(inputNames),
204 m_InputShapes(inputShapes),
205 m_OutputNames(outputNames),
206 m_OutputPath(outputPath),
207 m_IsModelBinary(isModelBinary) {}
// Serialize: checks that a network was created (the branch taken when
// m_NetworkPtr is null is not visible), opens m_OutputPath in binary mode and
// writes the serialized network to it; retVal holds the serializer result
// (presumably returned; the return line is missing).
211 if (m_NetworkPtr.get() ==
nullptr)
220 std::ofstream file(m_OutputPath, std::ios::out | std::ios::binary);
222 bool retVal =
serializer->SaveSerializedToStream(file);
// Public entry point: tag-dispatches to the CreateNetwork overload for IParser.
227 template <
typename IParser>
228 bool CreateNetwork ()
230 return CreateNetwork (ParserType<IParser>());
// Conversion inputs captured by the constructor.
235 std::string m_ModelPath;
236 std::vector<std::string> m_InputNames;
237 std::vector<armnn::TensorShape> m_InputShapes;
238 std::vector<std::string> m_OutputNames;
239 std::string m_OutputPath;
240 bool m_IsModelBinary;
// Generic overload, used for parsers that lack a non-template overload below.
// It maps each input name to its shape, then builds the network from a binary
// or text file according to m_IsModelBinary. Returns whether a network was
// produced.
242 template <
typename IParser>
243 bool CreateNetwork (ParserType<IParser>)
246 auto parser(IParser::Create());
248 std::map<std::string, armnn::TensorShape> inputShapes;
249 if (!m_InputShapes.empty())
251 const size_t numInputShapes = m_InputShapes.size();
252 const size_t numInputBindings = m_InputNames.size();
// Fewer shapes than input names is rejected; the throw expression is only
// partially visible (presumably a formatted exception message).
253 if (numInputShapes < numInputBindings)
256 "Not every input has its tensor shape specified: expected={0}, got={1}",
257 numInputBindings, numInputShapes));
// NOTE(review): only the fewer-shapes case is guarded. With more shapes than
// names, m_InputNames[i] below would index past the end. main() appears unable
// to produce that case today, but a != check would make this class safe alone.
260 for (
size_t i = 0; i < numInputShapes; i++)
262 inputShapes[m_InputNames[i]] = m_InputShapes[i];
268 m_NetworkPtr = (m_IsModelBinary ?
269 parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str(), inputShapes, m_OutputNames) :
270 parser->CreateNetworkFromTextFile(m_ModelPath.c_str(), inputShapes, m_OutputNames));
273 return m_NetworkPtr.get() !=
nullptr;
// TfLite overload: validates the shape count in the same way, but always loads
// a binary file. The visible call passes neither the shapes nor the output
// names to the parser.
276 #if defined(ARMNN_TF_LITE_PARSER) 277 bool CreateNetwork (ParserType<armnnTfLiteParser::ITfLiteParser>)
282 if (!m_InputShapes.empty())
284 const size_t numInputShapes = m_InputShapes.size();
285 const size_t numInputBindings = m_InputNames.size();
286 if (numInputShapes < numInputBindings)
289 "Not every input has its tensor shape specified: expected={0}, got={1}",
290 numInputBindings, numInputShapes));
296 m_NetworkPtr = parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str());
299 return m_NetworkPtr.get() !=
nullptr;
// ONNX overload: same shape-count validation. Loads a binary or text file
// according to m_IsModelBinary, without passing shapes or output names to the
// parser.
303 #if defined(ARMNN_ONNX_PARSER) 304 bool CreateNetwork (ParserType<armnnOnnxParser::IOnnxParser>)
309 if (!m_InputShapes.empty())
311 const size_t numInputShapes = m_InputShapes.size();
312 const size_t numInputBindings = m_InputNames.size();
313 if (numInputShapes < numInputBindings)
316 "Not every input has its tensor shape specified: expected={0}, got={1}",
317 numInputBindings, numInputShapes));
323 m_NetworkPtr = (m_IsModelBinary ?
324 parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str()) :
325 parser->CreateNetworkFromTextFile(m_ModelPath.c_str()));
328 return m_NetworkPtr.get() !=
nullptr;
// Program entry point. The steps are:
// - Check that at least one parser and the serializer are compiled in.
// - Parse the command line.
// - Convert each non-empty --input-tensor-shape string to an armnn::TensorShape.
//   A conversion failure is logged as fatal.
// - Pick the parser by substring of the model format: caffe, onnx, tensorflow
//   or tflite. TfLite accepts the binary format only.
// - Create the network, then serialize it. A serialization failure is logged
//   as "Failed to serialize model".
// NOTE(review): several lines are missing from this listing: returns, try
// blocks, the shape-parsing call and the CreateNetwork calls.
336 int main(
int argc,
char* argv[])
339 #if (!defined(ARMNN_CAFFE_PARSER) \ 340 && !defined(ARMNN_ONNX_PARSER) \ 341 && !defined(ARMNN_TF_PARSER) \ 342 && !defined(ARMNN_TF_LITE_PARSER)) 343 ARMNN_LOG(fatal) <<
"Not built with any of the supported parsers, Caffe, Onnx, Tensorflow, or TfLite.";
347 #if !defined(ARMNN_SERIALIZER) 348 ARMNN_LOG(fatal) <<
"Not built with Serializer support.";
// Command-line results, filled in by ParseCommandLineArgs.
360 std::string modelFormat;
361 std::string modelPath;
363 std::vector<std::string> inputNames;
364 std::vector<std::string> inputTensorShapeStrs;
365 std::vector<armnn::TensorShape> inputTensorShapes;
367 std::vector<std::string> outputNames;
368 std::string outputPath;
370 bool isModelBinary =
true;
372 if (ParseCommandLineArgs(
373 argc, argv, modelFormat, modelPath, inputNames, inputTensorShapeStrs, outputNames, outputPath, isModelBinary)
// Convert each non-empty shape string via a stringstream; empty strings are
// skipped, so fewer shapes than input names can result.
379 for (
const std::string& shapeStr : inputTensorShapeStrs)
381 if (!shapeStr.empty())
383 std::stringstream ss(shapeStr);
388 inputTensorShapes.push_back(shape);
392 ARMNN_LOG(fatal) <<
"Cannot create tensor shape: " << e.
what();
398 ArmnnConverter converter(modelPath, inputNames, inputTensorShapes, outputNames, outputPath, isModelBinary);
// Select the parser by substring of the model format. A parser that was not
// compiled in produces a fatal "Not built with ..." log.
402 if (modelFormat.find(
"caffe") != std::string::npos)
404 #if defined(ARMNN_CAFFE_PARSER) 407 ARMNN_LOG(fatal) <<
"Failed to load model from file";
411 ARMNN_LOG(fatal) <<
"Not built with Caffe parser support.";
415 else if (modelFormat.find(
"onnx") != std::string::npos)
417 #if defined(ARMNN_ONNX_PARSER) 420 ARMNN_LOG(fatal) <<
"Failed to load model from file";
424 ARMNN_LOG(fatal) <<
"Not built with Onnx parser support.";
428 else if (modelFormat.find(
"tensorflow") != std::string::npos)
430 #if defined(ARMNN_TF_PARSER) 433 ARMNN_LOG(fatal) <<
"Failed to load model from file";
437 ARMNN_LOG(fatal) <<
"Not built with Tensorflow parser support.";
// TfLite: a non-binary format is rejected with a fatal log before loading.
441 else if (modelFormat.find(
"tflite") != std::string::npos)
443 #if defined(ARMNN_TF_LITE_PARSER) 446 ARMNN_LOG(fatal) <<
"Unknown model format: '" << modelFormat <<
"'. Only 'binary' format supported \ 453 ARMNN_LOG(fatal) <<
"Failed to load model from file";
457 ARMNN_LOG(fatal) <<
"Not built with TfLite parser support.";
463 ARMNN_LOG(fatal) <<
"Unknown model format: '" << modelFormat <<
"'";
469 ARMNN_LOG(fatal) <<
"Failed to load model from file: " << e.
what();
473 if (!converter.Serialize())
475 ARMNN_LOG(fatal) <<
"Failed to serialize model";
std::vector< std::string > StringTokenizer(const std::string &str, const char *delimiters, bool tokenCompression=true)
Function to take a string and a list of delimiters and split the string into tokens based on those delimiters.
void ConfigureLogging(bool printToStandardOutput, bool printToDebugOutput, LogSeverity severity)
Configures the logging behaviour of the ARMNN library.
virtual const char * what() const noexcept override
#define ARMNN_LOG(severity)
Main network class which provides the interface for building up a neural network. ...
static ITfLiteParserPtr Create(const armnn::Optional< TfLiteParserOptions > &options=armnn::EmptyOptional())
#define ARMNN_SCOPED_HEAP_PROFILING(TAG)
static IOnnxParserPtr Create()
Parses a directed acyclic graph from a tensorflow protobuf file.
Base class for all ArmNN exceptions so that users can filter to just those.
static ISerializerPtr Create()
int main(int argc, char *argv[])
std::unique_ptr< INetwork, void(*)(INetwork *network)> INetworkPtr