ArmNN
 20.11
OnnxParser Class Reference

#include <OnnxParser.hpp>

Inheritance: OnnxParser derives from IOnnxParser (the graphical inheritance diagram is not reproduced in this text version).

Public Types

using GraphPtr = std::unique_ptr< onnx::GraphProto >
 

Public Member Functions

virtual armnn::INetworkPtr CreateNetworkFromBinaryFile (const char *graphFile) override
 Create the network from a protobuf binary file on disk. More...
 
virtual armnn::INetworkPtr CreateNetworkFromTextFile (const char *graphFile) override
 Create the network from a protobuf text file on disk. More...
 
virtual armnn::INetworkPtr CreateNetworkFromString (const std::string &protoText) override
 Create the network directly from protobuf text in a string. Useful for debugging/testing. More...
 
virtual BindingPointInfo GetNetworkInputBindingInfo (const std::string &name) const override
 Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name. More...
 
virtual BindingPointInfo GetNetworkOutputBindingInfo (const std::string &name) const override
 Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name. More...
 
 OnnxParser ()
 
template<typename TypePair , typename Location >
void ValidateInputs (const onnx::NodeProto &node, TypePair validInputs, const Location &location)
 

Static Public Member Functions

static ModelPtr LoadModelFromBinaryFile (const char *fileName)
 
static ModelPtr LoadModelFromTextFile (const char *fileName)
 
static ModelPtr LoadModelFromString (const std::string &inputString)
 
static std::vector< std::string > GetInputs (ModelPtr &model)
 Retrieve input names. More...
 
static std::vector< std::string > GetOutputs (ModelPtr &model)
 Retrieve output names. More...
 
- Static Public Member Functions inherited from IOnnxParser
static IOnnxParser * CreateRaw ()
 
static IOnnxParserPtr Create ()
 
static void Destroy (IOnnxParser *parser)
 

Additional Inherited Members

- Protected Member Functions inherited from IOnnxParser
virtual ~IOnnxParser ()
 

Detailed Description

Definition at line 25 of file OnnxParser.hpp.

Member Typedef Documentation

◆ GraphPtr

using GraphPtr = std::unique_ptr<onnx::GraphProto>

Definition at line 32 of file OnnxParser.hpp.

Constructor & Destructor Documentation

◆ OnnxParser()

Definition at line 445 of file OnnxParser.cpp.

References CHECK_LOCATION, TensorInfo::GetNumBytes(), and TensorInfo::GetNumElements().

446  : m_Network(nullptr, nullptr)
447 {
448 }

Member Function Documentation

◆ CreateNetworkFromBinaryFile()

INetworkPtr CreateNetworkFromBinaryFile ( const char *  graphFile)
[override, virtual]

Create the network from a protobuf binary file on disk.

Implements IOnnxParser.

Definition at line 563 of file OnnxParser.cpp.

References OnnxParser::LoadModelFromBinaryFile().

564 {
565  ResetParser();
566  ModelPtr modelProto = LoadModelFromBinaryFile(graphFile);
567  return CreateNetworkFromModel(*modelProto);
568 }
Cross-references in the listing above:
- ModelPtr — std::unique_ptr< onnx::ModelProto >
- static ModelPtr LoadModelFromBinaryFile(const char *fileName) — defined at OnnxParser.cpp:535

◆ CreateNetworkFromString()

INetworkPtr CreateNetworkFromString ( const std::string &  protoText)
[override, virtual]

Create the network directly from protobuf text in a string. Useful for debugging/testing.

Implements IOnnxParser.

Definition at line 589 of file OnnxParser.cpp.

References ARMNN_ASSERT, armnnTfParser::CalcPadding(), CHECK_LOCATION, CHECK_VALID_DATATYPE, CHECK_VALID_SIZE, CHECKED_INT32, IConnectableLayer::GetInputSlot(), TensorShape::GetNumDimensions(), IConnectableLayer::GetNumInputSlots(), IConnectableLayer::GetNumOutputSlots(), IConnectableLayer::GetOutputSlot(), TensorInfo::GetShape(), OnnxParser::LoadModelFromString(), ActivationDescriptor::m_A, ActivationDescriptor::m_B, FullyConnectedDescriptor::m_BiasEnabled, Convolution2dDescriptor::m_BiasEnabled, BatchNormalizationDescriptor::m_Eps, ActivationDescriptor::m_Function, Pooling2dDescriptor::m_OutputShapeRounding, Pooling2dDescriptor::m_PadBottom, Convolution2dDescriptor::m_PadBottom, Pooling2dDescriptor::m_PaddingMethod, Pooling2dDescriptor::m_PadLeft, Convolution2dDescriptor::m_PadLeft, DepthwiseConvolution2dDescriptor::m_PadLeft, Pooling2dDescriptor::m_PadRight, Convolution2dDescriptor::m_PadRight, Pooling2dDescriptor::m_PadTop, Convolution2dDescriptor::m_PadTop, Pooling2dDescriptor::m_PoolHeight, Pooling2dDescriptor::m_PoolType, Pooling2dDescriptor::m_PoolWidth, Pooling2dDescriptor::m_StrideX, Convolution2dDescriptor::m_StrideX, Pooling2dDescriptor::m_StrideY, Convolution2dDescriptor::m_StrideY, ReshapeDescriptor::m_TargetShape, TensorInfo::SetShape(), IOutputSlot::SetTensorInfo(), STR_LIST, armnnDeserializer::ToTensorInfo(), and VALID_INPUTS.

590 {
591  ResetParser();
592  ModelPtr modelProto = LoadModelFromString(protoText);
593  return CreateNetworkFromModel(*modelProto);
594 }
Cross-references in the listing above:
- ModelPtr — std::unique_ptr< onnx::ModelProto >
- static ModelPtr LoadModelFromString(const std::string &inputString) — defined at OnnxParser.cpp:570

◆ CreateNetworkFromTextFile()

INetworkPtr CreateNetworkFromTextFile ( const char *  graphFile)
[override, virtual]

Create the network from a protobuf text file on disk.

Implements IOnnxParser.

Definition at line 527 of file OnnxParser.cpp.

References OnnxParser::LoadModelFromTextFile().

528 {
529  ResetParser();
530  ModelPtr modelProto = LoadModelFromTextFile(graphFile);
531  return CreateNetworkFromModel(*modelProto);
532 }
Cross-references in the listing above:
- ModelPtr — std::unique_ptr< onnx::ModelProto >
- static ModelPtr LoadModelFromTextFile(const char *fileName) — defined at OnnxParser.cpp:502

◆ GetInputs()

std::vector< std::string > GetInputs ( ModelPtr &  model)
static

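Retrieve input names: returns, in graph order, the names of the model's graph inputs that do not also appear as initializers (i.e. the non-constant inputs). Throws InvalidArgumentException if model is null.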

Definition at line 1765 of file OnnxParser.cpp.

References CHECK_LOCATION.

Referenced by BOOST_FIXTURE_TEST_CASE().

1766 {
1767  if(model == nullptr) {
1768  throw InvalidArgumentException(fmt::format("The given model cannot be null {}",
1769  CHECK_LOCATION().AsString()));
1770  }
1771 
1772  std::vector<std::string> inputNames;
1773  std::map<std::string, bool> isConstant;
1774  for(auto tensor : model->graph().initializer())
1775  {
1776  isConstant[tensor.name()] = true;
1777  }
1778  for(auto input : model->graph().input())
1779  {
1780  auto it = isConstant.find(input.name());
1781  if(it == isConstant.end())
1782  {
1783  inputNames.push_back(input.name());
1784  }
1785  }
1786  return inputNames;
1787 }
Cross-reference in the listing above:
- CHECK_LOCATION() — macro defined at Exceptions.hpp:197

◆ GetNetworkInputBindingInfo()

BindingPointInfo GetNetworkInputBindingInfo ( const std::string &  name) const
[override, virtual]

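Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name. The returned LayerBindingId is the input's index in the ONNX graph's input list; InvalidArgumentException is thrown if no graph input has the given name.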

Implements IOnnxParser.

Definition at line 1737 of file OnnxParser.cpp.

References CHECK_LOCATION, and armnnDeserializer::ToTensorInfo().

1738 {
1739  for(int i = 0; i < m_Graph->input_size(); ++i)
1740  {
1741  auto input = m_Graph->input(i);
1742  if(input.name() == name)
1743  {
1744  return std::make_pair(static_cast<armnn::LayerBindingId>(i), ToTensorInfo(input));
1745  }
1746  }
1747  throw InvalidArgumentException(fmt::format("The input layer '{}' does not exist {}",
1748  name, CHECK_LOCATION().AsString()));
1749 }
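Cross-references in the listing above:
- ToTensorInfo — Doxygen links this call to armnn::TensorInfo ToTensorInfo(Deserializer::TensorRawPtr tensorPtr), but the argument here is an entry of the ONNX graph's input list, so the call presumably resolves to a parser-local overload instead (verify in OnnxParser.cpp).
- CHECK_LOCATION() — macro defined at Exceptions.hpp:197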

◆ GetNetworkOutputBindingInfo()

BindingPointInfo GetNetworkOutputBindingInfo ( const std::string &  name) const
[override, virtual]

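Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name. The returned LayerBindingId is the output's index in the ONNX graph's output list; InvalidArgumentException is thrown if no graph output has the given name.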

Implements IOnnxParser.

Definition at line 1751 of file OnnxParser.cpp.

References CHECK_LOCATION, and armnnDeserializer::ToTensorInfo().

1752 {
1753  for(int i = 0; i < m_Graph->output_size(); ++i)
1754  {
1755  auto output = m_Graph->output(i);
1756  if(output.name() == name)
1757  {
1758  return std::make_pair(static_cast<armnn::LayerBindingId>(i), ToTensorInfo(output));
1759  }
1760  }
1761  throw InvalidArgumentException(fmt::format("The output layer '{}' does not exist {}",
1762  name, CHECK_LOCATION().AsString()));
1763 }
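Cross-references in the listing above:
- ToTensorInfo — Doxygen links this call to armnn::TensorInfo ToTensorInfo(Deserializer::TensorRawPtr tensorPtr), but the argument here is an entry of the ONNX graph's output list, so the call presumably resolves to a parser-local overload instead (verify in OnnxParser.cpp).
- CHECK_LOCATION() — macro defined at Exceptions.hpp:197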

◆ GetOutputs()

std::vector< std::string > GetOutputs ( ModelPtr &  model)
static

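Retrieve output names: returns, in graph order, the names of the model's graph outputs. Throws InvalidArgumentException if model is null.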

Definition at line 1789 of file OnnxParser.cpp.

References CHECK_LOCATION.

Referenced by BOOST_FIXTURE_TEST_CASE().

1790 {
1791  if(model == nullptr) {
1792  throw InvalidArgumentException(fmt::format("The given model cannot be null {}",
1793  CHECK_LOCATION().AsString()));
1794  }
1795 
1796  std::vector<std::string> outputNames;
1797  for(auto output : model->graph().output())
1798  {
1799  outputNames.push_back(output.name());
1800  }
1801  return outputNames;
1802 }
Cross-reference in the listing above:
- CHECK_LOCATION() — macro defined at Exceptions.hpp:197

◆ LoadModelFromBinaryFile()

ModelPtr LoadModelFromBinaryFile ( const char *  fileName)
static

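Parse an onnx::ModelProto from a protobuf binary file on disk. The protobuf coded-stream total-bytes limit is raised to INT_MAX before parsing. Throws FileNotFoundException if the file cannot be opened (the exception message reads "Invalid (null) filename" whatever the cause), and ParseException if the file contents cannot be parsed. Note: the definition in OnnxParser.cpp names the parameter graphFile rather than fileName, as the listing below shows.

Definition at line 535 of file OnnxParser.cpp.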

References CHECK_LOCATION, and armnn::error.

Referenced by OnnxParser::CreateNetworkFromBinaryFile().

536 {
537  FILE* fd = fopen(graphFile, "rb");
538 
539  if (fd == nullptr)
540  {
541  throw FileNotFoundException(fmt::format("Invalid (null) filename {}", CHECK_LOCATION().AsString()));
542  }
543 
544  // Parse the file into a message
545  ModelPtr modelProto = std::make_unique<onnx::ModelProto>();
546 
547  google::protobuf::io::FileInputStream inStream(fileno(fd));
548  google::protobuf::io::CodedInputStream codedStream(&inStream);
549  codedStream.SetTotalBytesLimit(INT_MAX);
550  bool success = modelProto.get()->ParseFromCodedStream(&codedStream);
551  fclose(fd);
552 
553  if (!success)
554  {
555  std::stringstream error;
556  error << "Failed to parse graph file";
557  throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString()));
558  }
559  return modelProto;
560 
561 }
Cross-references in the listing above:
- ModelPtr — std::unique_ptr< onnx::ModelProto >
- CHECK_LOCATION() — macro defined at Exceptions.hpp:197

◆ LoadModelFromString()

ModelPtr LoadModelFromString ( const std::string &  inputString)
static

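Parse an onnx::ModelProto from protobuf text held in a string. Throws InvalidArgumentException if the string is empty, and ParseException if the text cannot be parsed. Note: the definition in OnnxParser.cpp names the parameter protoText rather than inputString, as the listing below shows.

Definition at line 570 of file OnnxParser.cpp.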

References CHECK_LOCATION, and armnn::error.

Referenced by BOOST_AUTO_TEST_CASE(), BOOST_FIXTURE_TEST_CASE(), and OnnxParser::CreateNetworkFromString().

571 {
572  if (protoText == "")
573  {
574  throw InvalidArgumentException(fmt::format("Invalid (empty) string for model parameter {}",
575  CHECK_LOCATION().AsString()));
576  }
577  // Parse the string into a message
578  ModelPtr modelProto = std::make_unique<onnx::ModelProto>();
579  bool success = google::protobuf::TextFormat::ParseFromString(protoText, modelProto.get());
580  if (!success)
581  {
582  std::stringstream error;
583  error << "Failed to parse graph file";
584  throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString()));
585  }
586  return modelProto;
587 }
Cross-references in the listing above:
- ModelPtr — std::unique_ptr< onnx::ModelProto >
- CHECK_LOCATION() — macro defined at Exceptions.hpp:197

◆ LoadModelFromTextFile()

ModelPtr LoadModelFromTextFile ( const char *  fileName)
static

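Parse an onnx::ModelProto from a protobuf text file on disk. Throws FileNotFoundException if the file cannot be opened (the exception message reads "Invalid (null) filename" whatever the cause), and ParseException if the file contents cannot be parsed. Note: the definition in OnnxParser.cpp names the parameter graphFile rather than fileName, as the listing below shows.

Definition at line 502 of file OnnxParser.cpp.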

References CHECK_LOCATION, and armnn::error.

Referenced by OnnxParser::CreateNetworkFromTextFile().

503 {
504  FILE* fd = fopen(graphFile, "r");
505 
506  if (fd == nullptr)
507  {
508  throw FileNotFoundException(fmt::format("Invalid (null) filename {}", CHECK_LOCATION().AsString()));
509  }
510 
511  // Parse the file into a message
512  ModelPtr modelProto = std::make_unique<onnx::ModelProto>();
513  using google::protobuf::io::FileInputStream;
514  std::unique_ptr<FileInputStream> input = std::make_unique<FileInputStream>(fileno(fd));
515  bool success = google::protobuf::TextFormat::Parse(input.get(), modelProto.get());
516  fclose(fd);
517 
518  if (!success)
519  {
520  std::stringstream error;
521  error << "Failed to parse graph file";
522  throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString()));
523  }
524  return modelProto;
525 }
Cross-references in the listing above:
- ModelPtr — std::unique_ptr< onnx::ModelProto >
- CHECK_LOCATION() — macro defined at Exceptions.hpp:197

◆ ValidateInputs()

template<typename TypePair , typename Location >
void ValidateInputs ( const onnx::NodeProto &  node, TypePair  validInputs, const Location &  location )

Definition at line 378 of file OnnxParser.cpp.

381 {
382  for(auto input : node.input())
383  {
384  CheckValidDataType(validInputs.second,
385  m_TensorsInfo[input].m_dtype,
386  validInputs.first,
387  node.name(),
388  input,
389  location);
390  }
391 }

The documentation for this class was generated from the following files: