ArmNN
 20.02
OnnxParser Class Reference

#include <OnnxParser.hpp>

Inheritance diagram for OnnxParser:
IOnnxParser

Public Types

using GraphPtr = std::unique_ptr< onnx::GraphProto >
 

Public Member Functions

virtual armnn::INetworkPtr CreateNetworkFromBinaryFile (const char *graphFile) override
 Create the network from a protobuf binary file on disk. More...
 
virtual armnn::INetworkPtr CreateNetworkFromTextFile (const char *graphFile) override
 Create the network from a protobuf text file on disk. More...
 
virtual armnn::INetworkPtr CreateNetworkFromString (const std::string &protoText) override
 Create the network directly from protobuf text in a string. Useful for debugging/testing. More...
 
virtual BindingPointInfo GetNetworkInputBindingInfo (const std::string &name) const override
 Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name. More...
 
virtual BindingPointInfo GetNetworkOutputBindingInfo (const std::string &name) const override
 Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name. More...
 
 OnnxParser ()
 
template<typename TypePair , typename Location >
void ValidateInputs (const onnx::NodeProto &node, TypePair validInputs, const Location &location)
 

Static Public Member Functions

static ModelPtr LoadModelFromBinaryFile (const char *fileName)
 
static ModelPtr LoadModelFromTextFile (const char *fileName)
 
static ModelPtr LoadModelFromString (const std::string &inputString)
 
static std::vector< std::string > GetInputs (ModelPtr &model)
 Retrieve input names. More...
 
static std::vector< std::string > GetOutputs (ModelPtr &model)
 Retrieve output names. More...
 
- Static Public Member Functions inherited from IOnnxParser
static IOnnxParser * CreateRaw ()
 
static IOnnxParserPtr Create ()
 
static void Destroy (IOnnxParser *parser)
 

Additional Inherited Members

- Protected Member Functions inherited from IOnnxParser
virtual ~IOnnxParser ()
 

Detailed Description

Definition at line 25 of file OnnxParser.hpp.

Member Typedef Documentation

◆ GraphPtr

using GraphPtr = std::unique_ptr<onnx::GraphProto>

Definition at line 32 of file OnnxParser.hpp.

Constructor & Destructor Documentation

◆ OnnxParser()

Definition at line 434 of file OnnxParser.cpp.

References CHECK_LOCATION, TensorInfo::GetNumBytes(), and TensorInfo::GetNumElements().

435  : m_Network(nullptr, nullptr)
436 {
437 }

Member Function Documentation

◆ CreateNetworkFromBinaryFile()

INetworkPtr CreateNetworkFromBinaryFile ( const char *  graphFile)
override virtual

Create the network from a protobuf binary file on disk.

Implements IOnnxParser.

Definition at line 557 of file OnnxParser.cpp.

References OnnxParser::LoadModelFromBinaryFile().

558 {
559  ResetParser();
560  ModelPtr modelProto = LoadModelFromBinaryFile(graphFile);
561  return CreateNetworkFromModel(*modelProto);
562 }
std::unique_ptr< onnx::ModelProto > ModelPtr
static ModelPtr LoadModelFromBinaryFile(const char *fileName)
Definition: OnnxParser.cpp:527

◆ CreateNetworkFromString()

INetworkPtr CreateNetworkFromString ( const std::string &  protoText)
override virtual

Create the network directly from protobuf text in a string. Useful for debugging/testing.

Implements IOnnxParser.

Definition at line 584 of file OnnxParser.cpp.

References armnnTfParser::CalcPadding(), CHECK_LOCATION, CHECK_VALID_DATATYPE, CHECK_VALID_SIZE, IConnectableLayer::GetInputSlot(), TensorShape::GetNumDimensions(), IConnectableLayer::GetNumInputSlots(), IConnectableLayer::GetNumOutputSlots(), IConnectableLayer::GetOutputSlot(), TensorInfo::GetShape(), OnnxParser::LoadModelFromString(), FullyConnectedDescriptor::m_BiasEnabled, Convolution2dDescriptor::m_BiasEnabled, BatchNormalizationDescriptor::m_Eps, ActivationDescriptor::m_Function, Pooling2dDescriptor::m_OutputShapeRounding, Pooling2dDescriptor::m_PadBottom, Convolution2dDescriptor::m_PadBottom, Pooling2dDescriptor::m_PaddingMethod, Pooling2dDescriptor::m_PadLeft, Convolution2dDescriptor::m_PadLeft, Pooling2dDescriptor::m_PadRight, Convolution2dDescriptor::m_PadRight, Pooling2dDescriptor::m_PadTop, Convolution2dDescriptor::m_PadTop, Pooling2dDescriptor::m_PoolHeight, Pooling2dDescriptor::m_PoolType, Pooling2dDescriptor::m_PoolWidth, Pooling2dDescriptor::m_StrideX, Convolution2dDescriptor::m_StrideX, Pooling2dDescriptor::m_StrideY, Convolution2dDescriptor::m_StrideY, ReshapeDescriptor::m_TargetShape, TensorInfo::SetShape(), IOutputSlot::SetTensorInfo(), STR_LIST, armnnDeserializer::ToTensorInfo(), and VALID_INPUTS.

585 {
586  ResetParser();
587  ModelPtr modelProto = LoadModelFromString(protoText);
588  return CreateNetworkFromModel(*modelProto);
589 }
std::unique_ptr< onnx::ModelProto > ModelPtr
static ModelPtr LoadModelFromString(const std::string &inputString)
Definition: OnnxParser.cpp:564

◆ CreateNetworkFromTextFile()

INetworkPtr CreateNetworkFromTextFile ( const char *  graphFile)
override virtual

Create the network from a protobuf text file on disk.

Implements IOnnxParser.

Definition at line 519 of file OnnxParser.cpp.

References OnnxParser::LoadModelFromTextFile().

520 {
521  ResetParser();
522  ModelPtr modelProto = LoadModelFromTextFile(graphFile);
523  return CreateNetworkFromModel(*modelProto);
524 }
std::unique_ptr< onnx::ModelProto > ModelPtr
static ModelPtr LoadModelFromTextFile(const char *fileName)
Definition: OnnxParser.cpp:492

◆ GetInputs()

std::vector< std::string > GetInputs ( ModelPtr & model)
static

Retrieve input names.

Definition at line 1708 of file OnnxParser.cpp.

References CHECK_LOCATION.

Referenced by BOOST_FIXTURE_TEST_CASE().

1709 {
1710  if(model == nullptr) {
1711  throw InvalidArgumentException(boost::str(
1712  boost::format("The given model cannot be null %1%")
1713  % CHECK_LOCATION().AsString()));
1714  }
1715 
1716  std::vector<std::string> inputNames;
1717  std::map<std::string, bool> isConstant;
1718  for(auto tensor : model->graph().initializer())
1719  {
1720  isConstant[tensor.name()] = true;
1721  }
1722  for(auto input : model->graph().input())
1723  {
1724  auto it = isConstant.find(input.name());
1725  if(it == isConstant.end())
1726  {
1727  inputNames.push_back(input.name());
1728  }
1729  }
1730  return inputNames;
1731 }
#define CHECK_LOCATION()
Definition: Exceptions.hpp:192

◆ GetNetworkInputBindingInfo()

BindingPointInfo GetNetworkInputBindingInfo ( const std::string &  name) const
override virtual

Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name.

Implements IOnnxParser.

Definition at line 1680 of file OnnxParser.cpp.

References CHECK_LOCATION, and armnnDeserializer::ToTensorInfo().

1681 {
1682  for(int i = 0; i < m_Graph->input_size(); ++i)
1683  {
1684  auto input = m_Graph->input(i);
1685  if(input.name() == name)
1686  {
1687  return std::make_pair(static_cast<armnn::LayerBindingId>(i), ToTensorInfo(input));
1688  }
1689  }
1690  throw InvalidArgumentException(boost::str(boost::format("The input layer '%1%' does not exist %2%")
1691  % name % CHECK_LOCATION().AsString()));
1692 }
armnn::TensorInfo ToTensorInfo(Deserializer::TensorRawPtr tensorPtr)
#define CHECK_LOCATION()
Definition: Exceptions.hpp:192

◆ GetNetworkOutputBindingInfo()

BindingPointInfo GetNetworkOutputBindingInfo ( const std::string &  name) const
override virtual

Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name.

Implements IOnnxParser.

Definition at line 1694 of file OnnxParser.cpp.

References CHECK_LOCATION, and armnnDeserializer::ToTensorInfo().

1695 {
1696  for(int i = 0; i < m_Graph->output_size(); ++i)
1697  {
1698  auto output = m_Graph->output(i);
1699  if(output.name() == name)
1700  {
1701  return std::make_pair(static_cast<armnn::LayerBindingId>(i), ToTensorInfo(output));
1702  }
1703  }
1704  throw InvalidArgumentException(boost::str(boost::format("The output layer '%1%' does not exist %2%")
1705  % name % CHECK_LOCATION().AsString()));
1706 }
armnn::TensorInfo ToTensorInfo(Deserializer::TensorRawPtr tensorPtr)
#define CHECK_LOCATION()
Definition: Exceptions.hpp:192

◆ GetOutputs()

std::vector< std::string > GetOutputs ( ModelPtr & model)
static

Retrieve output names.

Definition at line 1733 of file OnnxParser.cpp.

References CHECK_LOCATION.

Referenced by BOOST_FIXTURE_TEST_CASE().

1734 {
1735  if(model == nullptr) {
1736  throw InvalidArgumentException(boost::str(
1737  boost::format("The given model cannot be null %1%")
1738  % CHECK_LOCATION().AsString()));
1739  }
1740 
1741  std::vector<std::string> outputNames;
1742  for(auto output : model->graph().output())
1743  {
1744  outputNames.push_back(output.name());
1745  }
1746  return outputNames;
1747 }
#define CHECK_LOCATION()
Definition: Exceptions.hpp:192

◆ LoadModelFromBinaryFile()

ModelPtr LoadModelFromBinaryFile ( const char *  fileName)
static

Definition at line 527 of file OnnxParser.cpp.

References CHECK_LOCATION, and armnn::error.

Referenced by OnnxParser::CreateNetworkFromBinaryFile().

528 {
529  FILE* fd = fopen(graphFile, "rb");
530 
531  if (fd == nullptr)
532  {
533  throw FileNotFoundException(boost::str(
534  boost::format("Invalid (null) filename %1%") % CHECK_LOCATION().AsString()));
535  }
536 
537  // Parse the file into a message
538  ModelPtr modelProto = std::make_unique<onnx::ModelProto>();
539 
540  google::protobuf::io::FileInputStream inStream(fileno(fd));
541  google::protobuf::io::CodedInputStream codedStream(&inStream);
542  codedStream.SetTotalBytesLimit(INT_MAX, INT_MAX);
543  bool success = modelProto.get()->ParseFromCodedStream(&codedStream);
544  fclose(fd);
545 
546  if (!success)
547  {
548  std::stringstream error;
549  error << "Failed to parse graph file";
550  throw ParseException(boost::str(
551  boost::format("%1% %2%") % error.str() % CHECK_LOCATION().AsString()));
552  }
553  return modelProto;
554 
555 }
std::unique_ptr< onnx::ModelProto > ModelPtr
#define CHECK_LOCATION()
Definition: Exceptions.hpp:192

◆ LoadModelFromString()

ModelPtr LoadModelFromString ( const std::string &  inputString)
static

Definition at line 564 of file OnnxParser.cpp.

References CHECK_LOCATION, and armnn::error.

Referenced by BOOST_AUTO_TEST_CASE(), BOOST_FIXTURE_TEST_CASE(), and OnnxParser::CreateNetworkFromString().

565 {
566  if (protoText == "")
567  {
568  throw InvalidArgumentException(boost::str(
569  boost::format("Invalid (empty) string for model parameter %1%") % CHECK_LOCATION().AsString()));
570  }
571  // Parse the string into a message
572  ModelPtr modelProto = std::make_unique<onnx::ModelProto>();
573  bool success = google::protobuf::TextFormat::ParseFromString(protoText, modelProto.get());
574  if (!success)
575  {
576  std::stringstream error;
577  error << "Failed to parse graph file";
578  throw ParseException(boost::str(
579  boost::format("%1% %2%") % error.str() % CHECK_LOCATION().AsString()));
580  }
581  return modelProto;
582 }
std::unique_ptr< onnx::ModelProto > ModelPtr
#define CHECK_LOCATION()
Definition: Exceptions.hpp:192

◆ LoadModelFromTextFile()

ModelPtr LoadModelFromTextFile ( const char *  fileName)
static

Definition at line 492 of file OnnxParser.cpp.

References CHECK_LOCATION, and armnn::error.

Referenced by OnnxParser::CreateNetworkFromTextFile().

493 {
494  FILE* fd = fopen(graphFile, "r");
495 
496  if (fd == nullptr)
497  {
498  throw FileNotFoundException(boost::str(
499  boost::format("Invalid (null) filename %1%") % CHECK_LOCATION().AsString()));
500  }
501 
502  // Parse the file into a message
503  ModelPtr modelProto = std::make_unique<onnx::ModelProto>();
504  using google::protobuf::io::FileInputStream;
505  std::unique_ptr<FileInputStream> input = std::make_unique<FileInputStream>(fileno(fd));
506  bool success = google::protobuf::TextFormat::Parse(input.get(), modelProto.get());
507  fclose(fd);
508 
509  if (!success)
510  {
511  std::stringstream error;
512  error << "Failed to parse graph file";
513  throw ParseException(boost::str(
514  boost::format("%1% %2%") % error.str() % CHECK_LOCATION().AsString()));
515  }
516  return modelProto;
517 }
std::unique_ptr< onnx::ModelProto > ModelPtr
#define CHECK_LOCATION()
Definition: Exceptions.hpp:192

◆ ValidateInputs()

void ValidateInputs ( const onnx::NodeProto &  node,
TypePair  validInputs,
const Location &  location 
)

Definition at line 367 of file OnnxParser.cpp.

370 {
371  for(auto input : node.input())
372  {
373  CheckValidDataType(validInputs.second,
374  m_TensorsInfo[input].m_dtype,
375  validInputs.first,
376  node.name(),
377  input,
378  location);
379  }
380 }

The documentation for this class was generated from the following files: