// NOTE(review): this listing is a source-browser extraction. Original line numbers are fused into the
// code, and many lines (including this function's signature and its braces) are not shown.
// The first line below carries the tail of the file's include/preprocessor section.
// CXXOPTS_VECTOR_DELIMITER is set to '.' before <cxxopts/cxxopts.hpp> is included. Presumably this
// stops cxxopts from splitting the comma-separated --input-tensor-shape values into separate items.
7 #if defined(ARMNN_ONNX_PARSER) 10 #if defined(ARMNN_SERIALIZER) 13 #if defined(ARMNN_TF_LITE_PARSER) 27 #define CXXOPTS_VECTOR_DELIMITER '.' 28 #include <cxxopts/cxxopts.hpp> 30 #include <fmt/format.h> 41 std::vector<unsigned int> result;
// Tensor-shape parsing: reads `stream` line by line, converts each token to an unsigned dimension
// and builds an armnn::TensorShape from the collected values.
// `tokens` is produced on an omitted line, presumably by splitting `line`.
44 while (std::getline(stream, line))
47 for (
const std::string& token : tokens)
// NOTE(review): std::stoi accepts a leading numeric prefix (e.g. "3x" parses as 3). Such tokens are
// therefore accepted rather than reported by the catch below.
53 result.push_back(armnn::numeric_cast<unsigned int>(std::stoi((token))));
55 catch (
const std::exception&)
// Tokens that fail to convert are logged and skipped; the parse continues.
57 ARMNN_LOG(error) <<
"'" << token <<
"' is not a valid number. It has been ignored.";
// The rank of the resulting shape is the number of successfully converted tokens.
63 return armnn::TensorShape(armnn::numeric_cast<unsigned int>(result.size()), result.data());
// Defines the converter's command-line options with cxxopts and parses argc/argv into the output
// reference parameters. It then validates the required options and derives isModelBinary from
// modelFormat.
// NOTE(review): the return statements are on lines omitted from this listing, so the exact returned
// values are not visible here.
66 int ParseCommandLineArgs(
int argc,
char* argv[],
67 std::string& modelFormat,
68 std::string& modelPath,
69 std::vector<std::string>& inputNames,
70 std::vector<std::string>& inputTensorShapeStrs,
71 std::vector<std::string>& outputNames,
72 std::string& outputPath,
bool& isModelBinary)
74 cxxopts::Options options(
"ArmNNConverter",
"Convert a neural network model from provided file to ArmNN format.");
// The model-format help text lists only the formats whose parsers were enabled at build time.
77 std::string modelFormatDescription(
"Format of the model file");
78 #if defined(ARMNN_ONNX_PARSER) 79 modelFormatDescription +=
", onnx-binary, onnx-text";
81 #if defined(ARMNN_TF_PARSER) 82 modelFormatDescription +=
", tensorflow-binary, tensorflow-text";
84 #if defined(ARMNN_TF_LITE_PARSER) 85 modelFormatDescription +=
", tflite-binary";
87 modelFormatDescription +=
".";
// Option table. Two original lines are omitted from this listing:
//  - line 88, which opens the table;
//  - line 104, which names the output-path option. Only that option's description and target
//    (outputPath) appear on the last entry.
// Input names, input shapes and output names are vector options: each occurrence adds one value.
89 (
"help",
"Display usage information")
90 (
"f,model-format", modelFormatDescription, cxxopts::value<std::string>(modelFormat))
91 (
"m,model-path",
"Path to model file.", cxxopts::value<std::string>(modelPath))
93 (
"i,input-name",
"Identifier of the input tensors in the network. " 94 "Each input must be specified separately.",
95 cxxopts::value<std::vector<std::string>>(inputNames))
96 (
"s,input-tensor-shape",
97 "The shape of the input tensor in the network as a flat array of integers, " 98 "separated by comma. Each input shape must be specified separately after the input name. " 99 "This parameter is optional, depending on the network.",
100 cxxopts::value<std::vector<std::string>>(inputTensorShapeStrs))
102 (
"o,output-name",
"Identifier of the output tensor in the network.",
103 cxxopts::value<std::vector<std::string>>(outputNames))
105 "Path to serialize the network to.", cxxopts::value<std::string>(outputPath));
// Errors while defining the options are reported together with the usage text.
107 catch (
const std::exception& e)
109 std::cerr << e.what() << std::endl << options.help() << std::endl;
114 cxxopts::ParseResult result = options.parse(argc, argv);
115 if (result.count(
"help"))
117 std::cerr << options.help() << std::endl;
// Each of these options must appear exactly once.
// NOTE(review): output-name is registered as a vector option, but this check still limits it to a
// single occurrence, so only one output name is accepted.
121 std::string mandatorySingleParameters[] = {
"model-format",
"model-path",
"output-name",
"output-path" };
// Every missing option is reported before failing, rather than stopping at the first one.
122 bool somethingsMissing =
false;
123 for (
auto param : mandatorySingleParameters)
125 if (result.count(param) != 1)
127 std::cerr <<
"Parameter \'--" << param <<
"\' is required but missing." << std::endl;
128 somethingsMissing =
true;
// At least one --input-name is required.
132 if (result.count(
"input-name") == 0)
134 std::cerr <<
"Parameter \'--" <<
"input-name" <<
"\' must be specified at least once." << std::endl;
135 somethingsMissing =
true;
// Shapes are optional. When any are given, there must be exactly one per --input-name.
138 if (result.count(
"input-tensor-shape") > 0)
140 if (result.count(
"input-tensor-shape") != result.count(
"input-name"))
142 std::cerr <<
"When specifying \'input-tensor-shape\' a matching number of \'input-name\' parameters " 143 "must be specified." << std::endl;
144 somethingsMissing =
true;
// The usage text is printed once, after all missing parameters have been reported.
148 if (somethingsMissing)
150 std::cerr << options.help() << std::endl;
// Command-line parse errors raised by cxxopts are reported without the usage text.
154 catch (
const cxxopts::OptionException& e)
156 std::cerr << e.what() << std::endl << std::endl;
// The model encoding is inferred from substrings of --model-format:
//  - a format containing "bin" is binary (checked first);
//  - otherwise a format containing "text" is text;
//  - anything else is logged as fatal.
160 if (modelFormat.find(
"bin") != std::string::npos)
162 isModelBinary =
true;
164 else if (modelFormat.find(
"text") != std::string::npos)
166 isModelBinary =
false;
170 ARMNN_LOG(fatal) <<
"Unknown model format: '" << modelFormat <<
"'. Please include 'binary' or 'text'";
// Exposes the template parameter T as `parserType`.
// NOTE(review): the enclosing template declaration is omitted from this listing. This appears to be
// the ParserType<T> tag type that the ArmnnConverter::CreateNetwork overloads dispatch on — confirm.
180 typedef T parserType;
// Stores the conversion settings by copying each argument into the matching m_ member.
// NOTE(review): the declaration of the isModelBinary parameter and the start of the member
// initializer list (original lines 191-192) are omitted from this listing.
186 ArmnnConverter(
const std::string& modelPath,
187 const std::vector<std::string>& inputNames,
188 const std::vector<armnn::TensorShape>& inputShapes,
189 const std::vector<std::string>& outputNames,
190 const std::string& outputPath,
193 m_ModelPath(modelPath),
194 m_InputNames(inputNames),
195 m_InputShapes(inputShapes),
196 m_OutputNames(outputNames),
197 m_OutputPath(outputPath),
198 m_IsModelBinary(isModelBinary) {}
// Writes the created network to m_OutputPath. This is presumably the body of Serialize(), which
// main() calls; its signature (original line 200) is omitted from this listing.
// Guard: there is nothing to write if no network has been created. The branch body is omitted.
202 if (m_NetworkPtr.get() ==
nullptr)
// The output file is opened in binary mode, and the serializer streams the network into it.
// `serializer` is created on an omitted line, presumably via ISerializer::Create().
// NOTE(review): no check that `file` opened successfully is visible before the write. Confirm that
// the omitted line 212 handles this.
211 std::ofstream file(m_OutputPath, std::ios::out | std::ios::binary);
213 bool retVal =
serializer->SaveSerializedToStream(file);
// Creates the network with parser type IParser. It forwards to the CreateNetwork overload selected by
// a ParserType<IParser> tag value, so overload resolution can pick a parser-specific overload at
// compile time.
218 template <
typename IParser>
219 bool CreateNetwork ()
221 return CreateNetwork (ParserType<IParser>());
// Conversion settings captured by the constructor.
226 std::string m_ModelPath;
227 std::vector<std::string> m_InputNames;
// Index-aligned with m_InputNames: m_InputShapes[i] is used as the shape for m_InputNames[i].
228 std::vector<armnn::TensorShape> m_InputShapes;
229 std::vector<std::string> m_OutputNames;
230 std::string m_OutputPath;
// Selects CreateNetworkFromBinaryFile or CreateNetworkFromTextFile in the overloads that offer both.
231 bool m_IsModelBinary;
// Generic overload. It creates a parser via IParser::Create(), builds a name->shape map from the
// stored input names and shapes, and parses m_ModelPath as binary or text according to
// m_IsModelBinary. Returns true if a network was produced.
// Presumably used for parser types without a dedicated overload — confirm.
233 template <
typename IParser>
234 bool CreateNetwork (ParserType<IParser>)
237 auto parser(IParser::Create());
// Shapes are optional. When any are given, every input name must have one.
239 std::map<std::string, armnn::TensorShape> inputShapes;
240 if (!m_InputShapes.empty())
242 const size_t numInputShapes = m_InputShapes.size();
243 const size_t numInputBindings = m_InputNames.size();
// The statement that consumes the formatted message (original lines 245-246) is omitted; it is
// presumably a throw.
// NOTE(review): only "fewer shapes than names" is rejected here. If there were more shapes than
// names, the loop below would index m_InputNames out of range. ParseCommandLineArgs requires matching
// option counts, but this class does not enforce it itself.
244 if (numInputShapes < numInputBindings)
247 "Not every input has its tensor shape specified: expected={0}, got={1}",
248 numInputBindings, numInputShapes));
251 for (
size_t i = 0; i < numInputShapes; i++)
253 inputShapes[m_InputNames[i]] = m_InputShapes[i];
// Both the optional input shapes and the requested output names are passed to the parser.
259 m_NetworkPtr = (m_IsModelBinary ?
260 parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str(), inputShapes, m_OutputNames) :
261 parser->CreateNetworkFromTextFile(m_ModelPath.c_str(), inputShapes, m_OutputNames));
264 return m_NetworkPtr.get() !=
nullptr;
// TfLite overload, compiled only when ARMNN_TF_LITE_PARSER is defined. Returns true if a network was
// produced. `parser` is created on omitted lines.
// NOTE(review): the input shapes are validated below, but the visible call passes neither them nor
// m_OutputNames to the parser. The model is also always read as a binary file, whatever the value of
// m_IsModelBinary.
267 #if defined(ARMNN_TF_LITE_PARSER) 268 bool CreateNetwork (ParserType<armnnTfLiteParser::ITfLiteParser>)
273 if (!m_InputShapes.empty())
275 const size_t numInputShapes = m_InputShapes.size();
276 const size_t numInputBindings = m_InputNames.size();
// The statement that consumes this message is omitted; it is presumably a throw.
277 if (numInputShapes < numInputBindings)
280 "Not every input has its tensor shape specified: expected={0}, got={1}",
281 numInputBindings, numInputShapes));
287 m_NetworkPtr = parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str());
290 return m_NetworkPtr.get() !=
nullptr;
// ONNX overload, compiled only when ARMNN_ONNX_PARSER is defined. It parses m_ModelPath as binary or
// text according to m_IsModelBinary and returns true if a network was produced. `parser` is created
// on omitted lines.
// NOTE(review): the input shapes are validated below, but the visible calls pass neither them nor
// m_OutputNames to the parser.
294 #if defined(ARMNN_ONNX_PARSER) 295 bool CreateNetwork (ParserType<armnnOnnxParser::IOnnxParser>)
300 if (!m_InputShapes.empty())
302 const size_t numInputShapes = m_InputShapes.size();
303 const size_t numInputBindings = m_InputNames.size();
// The statement that consumes this message is omitted; it is presumably a throw.
304 if (numInputShapes < numInputBindings)
307 "Not every input has its tensor shape specified: expected={0}, got={1}",
308 numInputBindings, numInputShapes));
314 m_NetworkPtr = (m_IsModelBinary ?
315 parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str()) :
316 parser->CreateNetworkFromTextFile(m_ModelPath.c_str()));
319 return m_NetworkPtr.get() !=
nullptr;
// Entry point. In order, it:
//  1. checks that a parser and the serializer were compiled in;
//  2. parses the command line;
//  3. converts the --input-tensor-shape strings to armnn::TensorShape;
//  4. loads the model with the parser chosen from --model-format;
//  5. serializes the network to the output path.
// NOTE(review): the early-exit/return statements, the try/catch openers and several other statements
// are on lines omitted from this listing.
327 int main(
int argc,
char* argv[])
// Fail fast when the binary was built without any parser or without the serializer.
330 #if (!defined(ARMNN_ONNX_PARSER) \ 331 && !defined(ARMNN_TF_PARSER) \ 332 && !defined(ARMNN_TF_LITE_PARSER)) 333 ARMNN_LOG(fatal) <<
"Not built with any of the supported parsers Onnx, Tensorflow, or TfLite.";
337 #if !defined(ARMNN_SERIALIZER) 338 ARMNN_LOG(fatal) <<
"Not built with Serializer support.";
// Targets filled in by ParseCommandLineArgs.
350 std::string modelFormat;
351 std::string modelPath;
353 std::vector<std::string> inputNames;
354 std::vector<std::string> inputTensorShapeStrs;
355 std::vector<armnn::TensorShape> inputTensorShapes;
357 std::vector<std::string> outputNames;
358 std::string outputPath;
360 bool isModelBinary =
true;
362 if (ParseCommandLineArgs(
363 argc, argv, modelFormat, modelPath, inputNames, inputTensorShapeStrs, outputNames, outputPath, isModelBinary)
// Only non-empty shape strings are converted. Each one goes through a stringstream to the shape
// parser; the conversion call (original lines 374-377) is omitted.
369 for (
const std::string& shapeStr : inputTensorShapeStrs)
371 if (!shapeStr.empty())
373 std::stringstream ss(shapeStr);
378 inputTensorShapes.push_back(shape);
// Shape-conversion failures are logged as fatal. The catch clause opener is omitted.
382 ARMNN_LOG(fatal) <<
"Cannot create tensor shape: " << e.
what();
388 ArmnnConverter converter(modelPath, inputNames, inputTensorShapes, outputNames, outputPath, isModelBinary);
// Dispatch on substrings of --model-format: "onnx" is checked first, then "tflite". For tflite, only
// binary models are accepted; the isModelBinary check sits on omitted lines.
// NOTE(review): the help text advertises tensorflow-binary/-text when ARMNN_TF_PARSER is defined, but
// no "tensorflow" branch is visible here. Those formats appear to end up as an unknown model format.
392 if (modelFormat.find(
"onnx") != std::string::npos)
394 #if defined(ARMNN_ONNX_PARSER) 397 ARMNN_LOG(fatal) <<
"Failed to load model from file";
401 ARMNN_LOG(fatal) <<
"Not built with Onnx parser support.";
405 else if (modelFormat.find(
"tflite") != std::string::npos)
407 #if defined(ARMNN_TF_LITE_PARSER) 410 ARMNN_LOG(fatal) <<
"Unknown model format: '" << modelFormat <<
"'. Only 'binary' format supported \ 417 ARMNN_LOG(fatal) <<
"Failed to load model from file";
421 ARMNN_LOG(fatal) <<
"Not built with TfLite parser support.";
// Any other --model-format value is rejected.
427 ARMNN_LOG(fatal) <<
"Unknown model format: '" << modelFormat <<
"'";
// Exceptions raised while loading the model are logged as fatal. The catch clause opener is omitted.
433 ARMNN_LOG(fatal) <<
"Failed to load model from file: " << e.
what();
// Write the created network to the output path.
437 if (!converter.Serialize())
439 ARMNN_LOG(fatal) <<
"Failed to serialize model";
std::vector< std::string > StringTokenizer(const std::string &str, const char *delimiters, bool tokenCompression=true)
Function to take a string and a list of delimiters and split the string into tokens based on those delimiters.
void ConfigureLogging(bool printToStandardOutput, bool printToDebugOutput, LogSeverity severity)
Configures the logging behaviour of the ARMNN library.
virtual const char * what() const noexcept override
#define ARMNN_LOG(severity)
Main network class which provides the interface for building up a neural network. ...
static ITfLiteParserPtr Create(const armnn::Optional< TfLiteParserOptions > &options=armnn::EmptyOptional())
#define ARMNN_SCOPED_HEAP_PROFILING(TAG)
static IOnnxParserPtr Create()
Base class for all ArmNN exceptions so that users can filter to just those.
static ISerializerPtr Create()
int main(int argc, char *argv[])
std::unique_ptr< INetwork, void(*)(INetwork *network)> INetworkPtr