ArmNN
 20.08
OnnxParser Class Reference

#include <OnnxParser.hpp>

Inheritance diagram for OnnxParser:
IOnnxParser

Public Types

using GraphPtr = std::unique_ptr< onnx::GraphProto >
 

Public Member Functions

virtual armnn::INetworkPtr CreateNetworkFromBinaryFile (const char *graphFile) override
 Create the network from a protobuf binary file on disk. More...
 
virtual armnn::INetworkPtr CreateNetworkFromTextFile (const char *graphFile) override
 Create the network from a protobuf text file on disk. More...
 
virtual armnn::INetworkPtr CreateNetworkFromString (const std::string &protoText) override
 Create the network directly from protobuf text in a string. Useful for debugging/testing. More...
 
virtual BindingPointInfo GetNetworkInputBindingInfo (const std::string &name) const override
 Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name. More...
 
virtual BindingPointInfo GetNetworkOutputBindingInfo (const std::string &name) const override
 Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name. More...
 
 OnnxParser ()
 
template<typename TypePair , typename Location >
void ValidateInputs (const onnx::NodeProto &node, TypePair validInputs, const Location &location)
 

Static Public Member Functions

static ModelPtr LoadModelFromBinaryFile (const char *fileName)
 
static ModelPtr LoadModelFromTextFile (const char *fileName)
 
static ModelPtr LoadModelFromString (const std::string &inputString)
 
static std::vector< std::string > GetInputs (ModelPtr &model)
 Retrieve input names. More...
 
static std::vector< std::string > GetOutputs (ModelPtr &model)
 Retrieve output names. More...
 
- Static Public Member Functions inherited from IOnnxParser
static IOnnxParser * CreateRaw ()
 
static IOnnxParserPtr Create ()
 
static void Destroy (IOnnxParser *parser)
 

Additional Inherited Members

- Protected Member Functions inherited from IOnnxParser
virtual ~IOnnxParser ()
 

Detailed Description

Definition at line 25 of file OnnxParser.hpp.

Member Typedef Documentation

◆ GraphPtr

using GraphPtr = std::unique_ptr<onnx::GraphProto>

Definition at line 32 of file OnnxParser.hpp.

Constructor & Destructor Documentation

◆ OnnxParser()

Definition at line 449 of file OnnxParser.cpp.

References CHECK_LOCATION, TensorInfo::GetNumBytes(), and TensorInfo::GetNumElements().

450  : m_Network(nullptr, nullptr)
451 {
452 }

Member Function Documentation

◆ CreateNetworkFromBinaryFile()

INetworkPtr CreateNetworkFromBinaryFile ( const char *  graphFile)
override virtual

Create the network from a protobuf binary file on disk.

Implements IOnnxParser.

Definition at line 572 of file OnnxParser.cpp.

References OnnxParser::LoadModelFromBinaryFile().

573 {
574  ResetParser();
575  ModelPtr modelProto = LoadModelFromBinaryFile(graphFile);
576  return CreateNetworkFromModel(*modelProto);
577 }
std::unique_ptr< onnx::ModelProto > ModelPtr
static ModelPtr LoadModelFromBinaryFile(const char *fileName)
Definition: OnnxParser.cpp:542

◆ CreateNetworkFromString()

INetworkPtr CreateNetworkFromString ( const std::string &  protoText)
override virtual

Create the network directly from protobuf text in a string. Useful for debugging/testing.

Implements IOnnxParser.

Definition at line 599 of file OnnxParser.cpp.

References ARMNN_ASSERT, armnnTfParser::CalcPadding(), CHECK_LOCATION, CHECK_VALID_DATATYPE, CHECK_VALID_SIZE, CHECKED_INT32, IConnectableLayer::GetInputSlot(), TensorShape::GetNumDimensions(), IConnectableLayer::GetNumInputSlots(), IConnectableLayer::GetNumOutputSlots(), IConnectableLayer::GetOutputSlot(), TensorInfo::GetShape(), OnnxParser::LoadModelFromString(), ActivationDescriptor::m_A, ActivationDescriptor::m_B, FullyConnectedDescriptor::m_BiasEnabled, Convolution2dDescriptor::m_BiasEnabled, BatchNormalizationDescriptor::m_Eps, ActivationDescriptor::m_Function, Pooling2dDescriptor::m_OutputShapeRounding, Pooling2dDescriptor::m_PadBottom, Convolution2dDescriptor::m_PadBottom, Pooling2dDescriptor::m_PaddingMethod, Pooling2dDescriptor::m_PadLeft, Convolution2dDescriptor::m_PadLeft, DepthwiseConvolution2dDescriptor::m_PadLeft, Pooling2dDescriptor::m_PadRight, Convolution2dDescriptor::m_PadRight, Pooling2dDescriptor::m_PadTop, Convolution2dDescriptor::m_PadTop, Pooling2dDescriptor::m_PoolHeight, Pooling2dDescriptor::m_PoolType, Pooling2dDescriptor::m_PoolWidth, Pooling2dDescriptor::m_StrideX, Convolution2dDescriptor::m_StrideX, Pooling2dDescriptor::m_StrideY, Convolution2dDescriptor::m_StrideY, ReshapeDescriptor::m_TargetShape, TensorInfo::SetShape(), IOutputSlot::SetTensorInfo(), STR_LIST, armnnDeserializer::ToTensorInfo(), and VALID_INPUTS.

600 {
601  ResetParser();
602  ModelPtr modelProto = LoadModelFromString(protoText);
603  return CreateNetworkFromModel(*modelProto);
604 }
std::unique_ptr< onnx::ModelProto > ModelPtr
static ModelPtr LoadModelFromString(const std::string &inputString)
Definition: OnnxParser.cpp:579

◆ CreateNetworkFromTextFile()

INetworkPtr CreateNetworkFromTextFile ( const char *  graphFile)
override virtual

Create the network from a protobuf text file on disk.

Implements IOnnxParser.

Definition at line 534 of file OnnxParser.cpp.

References OnnxParser::LoadModelFromTextFile().

535 {
536  ResetParser();
537  ModelPtr modelProto = LoadModelFromTextFile(graphFile);
538  return CreateNetworkFromModel(*modelProto);
539 }
std::unique_ptr< onnx::ModelProto > ModelPtr
static ModelPtr LoadModelFromTextFile(const char *fileName)
Definition: OnnxParser.cpp:507

◆ GetInputs()

std::vector< std::string > GetInputs ( ModelPtr &  model)
static

Retrieve input names.

Definition at line 1791 of file OnnxParser.cpp.

References CHECK_LOCATION.

Referenced by BOOST_FIXTURE_TEST_CASE().

1792 {
1793  if(model == nullptr) {
1794  throw InvalidArgumentException(boost::str(
1795  boost::format("The given model cannot be null %1%")
1796  % CHECK_LOCATION().AsString()));
1797  }
1798 
1799  std::vector<std::string> inputNames;
1800  std::map<std::string, bool> isConstant;
1801  for(auto tensor : model->graph().initializer())
1802  {
1803  isConstant[tensor.name()] = true;
1804  }
1805  for(auto input : model->graph().input())
1806  {
1807  auto it = isConstant.find(input.name());
1808  if(it == isConstant.end())
1809  {
1810  inputNames.push_back(input.name());
1811  }
1812  }
1813  return inputNames;
1814 }
#define CHECK_LOCATION()
Definition: Exceptions.hpp:197

◆ GetNetworkInputBindingInfo()

BindingPointInfo GetNetworkInputBindingInfo ( const std::string &  name) const
override virtual

Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name.

Implements IOnnxParser.

Definition at line 1763 of file OnnxParser.cpp.

References CHECK_LOCATION, and armnnDeserializer::ToTensorInfo().

1764 {
1765  for(int i = 0; i < m_Graph->input_size(); ++i)
1766  {
1767  auto input = m_Graph->input(i);
1768  if(input.name() == name)
1769  {
1770  return std::make_pair(static_cast<armnn::LayerBindingId>(i), ToTensorInfo(input));
1771  }
1772  }
1773  throw InvalidArgumentException(boost::str(boost::format("The input layer '%1%' does not exist %2%")
1774  % name % CHECK_LOCATION().AsString()));
1775 }
armnn::TensorInfo ToTensorInfo(Deserializer::TensorRawPtr tensorPtr)
#define CHECK_LOCATION()
Definition: Exceptions.hpp:197

◆ GetNetworkOutputBindingInfo()

BindingPointInfo GetNetworkOutputBindingInfo ( const std::string &  name) const
override virtual

Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name.

Implements IOnnxParser.

Definition at line 1777 of file OnnxParser.cpp.

References CHECK_LOCATION, and armnnDeserializer::ToTensorInfo().

1778 {
1779  for(int i = 0; i < m_Graph->output_size(); ++i)
1780  {
1781  auto output = m_Graph->output(i);
1782  if(output.name() == name)
1783  {
1784  return std::make_pair(static_cast<armnn::LayerBindingId>(i), ToTensorInfo(output));
1785  }
1786  }
1787  throw InvalidArgumentException(boost::str(boost::format("The output layer '%1%' does not exist %2%")
1788  % name % CHECK_LOCATION().AsString()));
1789 }
armnn::TensorInfo ToTensorInfo(Deserializer::TensorRawPtr tensorPtr)
#define CHECK_LOCATION()
Definition: Exceptions.hpp:197

◆ GetOutputs()

std::vector< std::string > GetOutputs ( ModelPtr &  model)
static

Retrieve output names.

Definition at line 1816 of file OnnxParser.cpp.

References CHECK_LOCATION.

Referenced by BOOST_FIXTURE_TEST_CASE().

1817 {
1818  if(model == nullptr) {
1819  throw InvalidArgumentException(boost::str(
1820  boost::format("The given model cannot be null %1%")
1821  % CHECK_LOCATION().AsString()));
1822  }
1823 
1824  std::vector<std::string> outputNames;
1825  for(auto output : model->graph().output())
1826  {
1827  outputNames.push_back(output.name());
1828  }
1829  return outputNames;
1830 }
#define CHECK_LOCATION()
Definition: Exceptions.hpp:197

◆ LoadModelFromBinaryFile()

ModelPtr LoadModelFromBinaryFile ( const char *  fileName)
static

Definition at line 542 of file OnnxParser.cpp.

References CHECK_LOCATION, and armnn::error.

Referenced by OnnxParser::CreateNetworkFromBinaryFile().

543 {
544  FILE* fd = fopen(graphFile, "rb");
545 
546  if (fd == nullptr)
547  {
548  throw FileNotFoundException(boost::str(
549  boost::format("Invalid (null) filename %1%") % CHECK_LOCATION().AsString()));
550  }
551 
552  // Parse the file into a message
553  ModelPtr modelProto = std::make_unique<onnx::ModelProto>();
554 
555  google::protobuf::io::FileInputStream inStream(fileno(fd));
556  google::protobuf::io::CodedInputStream codedStream(&inStream);
557  codedStream.SetTotalBytesLimit(INT_MAX, INT_MAX);
558  bool success = modelProto.get()->ParseFromCodedStream(&codedStream);
559  fclose(fd);
560 
561  if (!success)
562  {
563  std::stringstream error;
564  error << "Failed to parse graph file";
565  throw ParseException(boost::str(
566  boost::format("%1% %2%") % error.str() % CHECK_LOCATION().AsString()));
567  }
568  return modelProto;
569 
570 }
std::unique_ptr< onnx::ModelProto > ModelPtr
#define CHECK_LOCATION()
Definition: Exceptions.hpp:197

◆ LoadModelFromString()

ModelPtr LoadModelFromString ( const std::string &  inputString)
static

Definition at line 579 of file OnnxParser.cpp.

References CHECK_LOCATION, and armnn::error.

Referenced by BOOST_AUTO_TEST_CASE(), BOOST_FIXTURE_TEST_CASE(), and OnnxParser::CreateNetworkFromString().

580 {
581  if (protoText == "")
582  {
583  throw InvalidArgumentException(boost::str(
584  boost::format("Invalid (empty) string for model parameter %1%") % CHECK_LOCATION().AsString()));
585  }
586  // Parse the string into a message
587  ModelPtr modelProto = std::make_unique<onnx::ModelProto>();
588  bool success = google::protobuf::TextFormat::ParseFromString(protoText, modelProto.get());
589  if (!success)
590  {
591  std::stringstream error;
592  error << "Failed to parse graph file";
593  throw ParseException(boost::str(
594  boost::format("%1% %2%") % error.str() % CHECK_LOCATION().AsString()));
595  }
596  return modelProto;
597 }
std::unique_ptr< onnx::ModelProto > ModelPtr
#define CHECK_LOCATION()
Definition: Exceptions.hpp:197

◆ LoadModelFromTextFile()

ModelPtr LoadModelFromTextFile ( const char *  fileName)
static

Definition at line 507 of file OnnxParser.cpp.

References CHECK_LOCATION, and armnn::error.

Referenced by OnnxParser::CreateNetworkFromTextFile().

508 {
509  FILE* fd = fopen(graphFile, "r");
510 
511  if (fd == nullptr)
512  {
513  throw FileNotFoundException(boost::str(
514  boost::format("Invalid (null) filename %1%") % CHECK_LOCATION().AsString()));
515  }
516 
517  // Parse the file into a message
518  ModelPtr modelProto = std::make_unique<onnx::ModelProto>();
519  using google::protobuf::io::FileInputStream;
520  std::unique_ptr<FileInputStream> input = std::make_unique<FileInputStream>(fileno(fd));
521  bool success = google::protobuf::TextFormat::Parse(input.get(), modelProto.get());
522  fclose(fd);
523 
524  if (!success)
525  {
526  std::stringstream error;
527  error << "Failed to parse graph file";
528  throw ParseException(boost::str(
529  boost::format("%1% %2%") % error.str() % CHECK_LOCATION().AsString()));
530  }
531  return modelProto;
532 }
std::unique_ptr< onnx::ModelProto > ModelPtr
#define CHECK_LOCATION()
Definition: Exceptions.hpp:197

◆ ValidateInputs()

void ValidateInputs ( const onnx::NodeProto &  node,
TypePair  validInputs,
const Location &  location 
)

Definition at line 382 of file OnnxParser.cpp.

385 {
386  for(auto input : node.input())
387  {
388  CheckValidDataType(validInputs.second,
389  m_TensorsInfo[input].m_dtype,
390  validInputs.first,
391  node.name(),
392  input,
393  location);
394  }
395 }

The documentation for this class was generated from the following files: