ArmNN
 21.02
OnnxParserImpl Class Reference

#include <OnnxParser.hpp>

Public Types

using GraphPtr = std::unique_ptr< onnx::GraphProto >
 

Public Member Functions

armnn::INetworkPtr CreateNetworkFromBinaryFile (const char *graphFile)
 Create the network from a protobuf binary file on disk. More...
 
armnn::INetworkPtr CreateNetworkFromTextFile (const char *graphFile)
 Create the network from a protobuf text file on disk. More...
 
armnn::INetworkPtr CreateNetworkFromString (const std::string &protoText)
 Create the network directly from protobuf text in a string. Useful for debugging/testing. More...
 
BindingPointInfo GetNetworkInputBindingInfo (const std::string &name) const
 Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name. More...
 
BindingPointInfo GetNetworkOutputBindingInfo (const std::string &name) const
 Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name. More...
 
 OnnxParserImpl ()
 
 ~OnnxParserImpl ()=default
 
template<typename TypePair , typename Location >
void ValidateInputs (const onnx::NodeProto &node, TypePair validInputs, const Location &location)
 

Static Public Member Functions

static ModelPtr LoadModelFromBinaryFile (const char *fileName)
 
static ModelPtr LoadModelFromTextFile (const char *fileName)
 
static ModelPtr LoadModelFromString (const std::string &inputString)
 
static std::vector< std::string > GetInputs (ModelPtr &model)
 Retrieve input names. More...
 
static std::vector< std::string > GetOutputs (ModelPtr &model)
 Retrieve output names. More...
 
static const std::string GetVersion ()
 Retrieve version in X.Y.Z form. More...
 

Detailed Description

Definition at line 25 of file OnnxParser.hpp.

Member Typedef Documentation

◆ GraphPtr

using GraphPtr = std::unique_ptr<onnx::GraphProto>

Definition at line 32 of file OnnxParser.hpp.

Constructor & Destructor Documentation

◆ OnnxParserImpl()

Definition at line 484 of file OnnxParser.cpp.

References CHECK_LOCATION, TensorInfo::GetNumBytes(), and TensorInfo::GetNumElements().

485  : m_Network(nullptr, nullptr)
486 {
487 }

◆ ~OnnxParserImpl()

~OnnxParserImpl ( )
default

Member Function Documentation

◆ CreateNetworkFromBinaryFile()

INetworkPtr CreateNetworkFromBinaryFile ( const char *  graphFile)

Create the network from a protobuf binary file on disk.

Definition at line 602 of file OnnxParser.cpp.

References OnnxParserImpl::LoadModelFromBinaryFile().

603 {
604  ResetParser();
605  ModelPtr modelProto = LoadModelFromBinaryFile(graphFile);
606  return CreateNetworkFromModel(*modelProto);
607 }
std::unique_ptr< onnx::ModelProto > ModelPtr
static ModelPtr LoadModelFromBinaryFile(const char *fileName)
Definition: OnnxParser.cpp:574

◆ CreateNetworkFromString()

INetworkPtr CreateNetworkFromString ( const std::string &  protoText)

Create the network directly from protobuf text in a string. Useful for debugging/testing.

Definition at line 628 of file OnnxParser.cpp.

References ARMNN_ASSERT, armnnTfParser::CalcPadding(), CHECK_LOCATION, CHECK_VALID_DATATYPE, CHECK_VALID_SIZE, CHECKED_INT32, IConnectableLayer::GetInputSlot(), TensorShape::GetNumDimensions(), IConnectableLayer::GetNumInputSlots(), IConnectableLayer::GetNumOutputSlots(), IConnectableLayer::GetOutputSlot(), TensorInfo::GetShape(), OnnxParserImpl::LoadModelFromString(), ActivationDescriptor::m_A, ActivationDescriptor::m_B, FullyConnectedDescriptor::m_BiasEnabled, Convolution2dDescriptor::m_BiasEnabled, Convolution2dDescriptor::m_DilationX, Convolution2dDescriptor::m_DilationY, BatchNormalizationDescriptor::m_Eps, ActivationDescriptor::m_Function, Pooling2dDescriptor::m_OutputShapeRounding, Pooling2dDescriptor::m_PadBottom, Convolution2dDescriptor::m_PadBottom, Pooling2dDescriptor::m_PaddingMethod, Pooling2dDescriptor::m_PadLeft, Convolution2dDescriptor::m_PadLeft, DepthwiseConvolution2dDescriptor::m_PadLeft, Pooling2dDescriptor::m_PadRight, Convolution2dDescriptor::m_PadRight, Pooling2dDescriptor::m_PadTop, Convolution2dDescriptor::m_PadTop, Pooling2dDescriptor::m_PoolHeight, Pooling2dDescriptor::m_PoolType, Pooling2dDescriptor::m_PoolWidth, Pooling2dDescriptor::m_StrideX, Convolution2dDescriptor::m_StrideX, Pooling2dDescriptor::m_StrideY, Convolution2dDescriptor::m_StrideY, ReshapeDescriptor::m_TargetShape, TensorInfo::SetShape(), IOutputSlot::SetTensorInfo(), STR_LIST, armnnDeserializer::ToTensorInfo(), and VALID_INPUTS.

629 {
630  ResetParser();
631  ModelPtr modelProto = LoadModelFromString(protoText);
632  return CreateNetworkFromModel(*modelProto);
633 }
std::unique_ptr< onnx::ModelProto > ModelPtr
static ModelPtr LoadModelFromString(const std::string &inputString)
Definition: OnnxParser.cpp:609

◆ CreateNetworkFromTextFile()

INetworkPtr CreateNetworkFromTextFile ( const char *  graphFile)

Create the network from a protobuf text file on disk.

Definition at line 566 of file OnnxParser.cpp.

References OnnxParserImpl::LoadModelFromTextFile().

567 {
568  ResetParser();
569  ModelPtr modelProto = LoadModelFromTextFile(graphFile);
570  return CreateNetworkFromModel(*modelProto);
571 }
std::unique_ptr< onnx::ModelProto > ModelPtr
static ModelPtr LoadModelFromTextFile(const char *fileName)
Definition: OnnxParser.cpp:541

◆ GetInputs()

std::vector< std::string > GetInputs ( ModelPtr model)
static

Retrieve input names.

Definition at line 1816 of file OnnxParser.cpp.

References CHECK_LOCATION.

Referenced by BOOST_FIXTURE_TEST_CASE().

1817 {
1818  if(model == nullptr) {
1819  throw InvalidArgumentException(fmt::format("The given model cannot be null {}",
1820  CHECK_LOCATION().AsString()));
1821  }
1822 
1823  std::vector<std::string> inputNames;
1824  std::map<std::string, bool> isConstant;
1825  for(auto tensor : model->graph().initializer())
1826  {
1827  isConstant[tensor.name()] = true;
1828  }
1829  for(auto input : model->graph().input())
1830  {
1831  auto it = isConstant.find(input.name());
1832  if(it == isConstant.end())
1833  {
1834  inputNames.push_back(input.name());
1835  }
1836  }
1837  return inputNames;
1838 }
#define CHECK_LOCATION()
Definition: Exceptions.hpp:197

◆ GetNetworkInputBindingInfo()

BindingPointInfo GetNetworkInputBindingInfo ( const std::string &  name) const

Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name.

Definition at line 1788 of file OnnxParser.cpp.

References CHECK_LOCATION, and armnnDeserializer::ToTensorInfo().

1789 {
1790  for(int i = 0; i < m_Graph->input_size(); ++i)
1791  {
1792  auto input = m_Graph->input(i);
1793  if(input.name() == name)
1794  {
1795  return std::make_pair(static_cast<armnn::LayerBindingId>(i), ToTensorInfo(input));
1796  }
1797  }
1798  throw InvalidArgumentException(fmt::format("The input layer '{}' does not exist {}",
1799  name, CHECK_LOCATION().AsString()));
1800 }
#define CHECK_LOCATION()
Definition: Exceptions.hpp:197
armnn::TensorInfo ToTensorInfo(TensorRawPtr tensorPtr)

◆ GetNetworkOutputBindingInfo()

BindingPointInfo GetNetworkOutputBindingInfo ( const std::string &  name) const

Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name.

Definition at line 1802 of file OnnxParser.cpp.

References CHECK_LOCATION, and armnnDeserializer::ToTensorInfo().

1803 {
1804  for(int i = 0; i < m_Graph->output_size(); ++i)
1805  {
1806  auto output = m_Graph->output(i);
1807  if(output.name() == name)
1808  {
1809  return std::make_pair(static_cast<armnn::LayerBindingId>(i), ToTensorInfo(output));
1810  }
1811  }
1812  throw InvalidArgumentException(fmt::format("The output layer '{}' does not exist {}",
1813  name, CHECK_LOCATION().AsString()));
1814 }
#define CHECK_LOCATION()
Definition: Exceptions.hpp:197
armnn::TensorInfo ToTensorInfo(TensorRawPtr tensorPtr)

◆ GetOutputs()

std::vector< std::string > GetOutputs ( ModelPtr model)
static

Retrieve output names.

Definition at line 1840 of file OnnxParser.cpp.

References CHECK_LOCATION.

Referenced by BOOST_FIXTURE_TEST_CASE().

1841 {
1842  if(model == nullptr) {
1843  throw InvalidArgumentException(fmt::format("The given model cannot be null {}",
1844  CHECK_LOCATION().AsString()));
1845  }
1846 
1847  std::vector<std::string> outputNames;
1848  for(auto output : model->graph().output())
1849  {
1850  outputNames.push_back(output.name());
1851  }
1852  return outputNames;
1853 }
#define CHECK_LOCATION()
Definition: Exceptions.hpp:197

◆ GetVersion()

const std::string GetVersion ( )
static

Retrieve version in X.Y.Z form.

Definition at line 1855 of file OnnxParser.cpp.

References ONNX_PARSER_VERSION.

1856 {
1857  return ONNX_PARSER_VERSION;
1858 }
#define ONNX_PARSER_VERSION
ONNX_PARSER_VERSION: "X.Y.Z" where: X = Major version number, Y = Minor version number, Z = Patch version number.
Definition: Version.hpp:25

◆ LoadModelFromBinaryFile()

ModelPtr LoadModelFromBinaryFile ( const char *  fileName)
static

Definition at line 574 of file OnnxParser.cpp.

References CHECK_LOCATION, and armnn::error.

Referenced by OnnxParserImpl::CreateNetworkFromBinaryFile().

575 {
576  FILE* fd = fopen(graphFile, "rb");
577 
578  if (fd == nullptr)
579  {
580  throw FileNotFoundException(fmt::format("Invalid (null) filename {}", CHECK_LOCATION().AsString()));
581  }
582 
583  // Parse the file into a message
584  ModelPtr modelProto = std::make_unique<onnx::ModelProto>();
585 
586  google::protobuf::io::FileInputStream inStream(fileno(fd));
587  google::protobuf::io::CodedInputStream codedStream(&inStream);
588  codedStream.SetTotalBytesLimit(INT_MAX);
589  bool success = modelProto.get()->ParseFromCodedStream(&codedStream);
590  fclose(fd);
591 
592  if (!success)
593  {
594  std::stringstream error;
595  error << "Failed to parse graph file";
596  throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString()));
597  }
598  return modelProto;
599 
600 }
std::unique_ptr< onnx::ModelProto > ModelPtr
#define CHECK_LOCATION()
Definition: Exceptions.hpp:197

◆ LoadModelFromString()

ModelPtr LoadModelFromString ( const std::string &  inputString)
static

Definition at line 609 of file OnnxParser.cpp.

References CHECK_LOCATION, and armnn::error.

Referenced by BOOST_AUTO_TEST_CASE(), BOOST_FIXTURE_TEST_CASE(), and OnnxParserImpl::CreateNetworkFromString().

610 {
611  if (protoText == "")
612  {
613  throw InvalidArgumentException(fmt::format("Invalid (empty) string for model parameter {}",
614  CHECK_LOCATION().AsString()));
615  }
616  // Parse the string into a message
617  ModelPtr modelProto = std::make_unique<onnx::ModelProto>();
618  bool success = google::protobuf::TextFormat::ParseFromString(protoText, modelProto.get());
619  if (!success)
620  {
621  std::stringstream error;
622  error << "Failed to parse graph file";
623  throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString()));
624  }
625  return modelProto;
626 }
std::unique_ptr< onnx::ModelProto > ModelPtr
#define CHECK_LOCATION()
Definition: Exceptions.hpp:197

◆ LoadModelFromTextFile()

ModelPtr LoadModelFromTextFile ( const char *  fileName)
static

Definition at line 541 of file OnnxParser.cpp.

References CHECK_LOCATION, and armnn::error.

Referenced by OnnxParserImpl::CreateNetworkFromTextFile().

542 {
543  FILE* fd = fopen(graphFile, "r");
544 
545  if (fd == nullptr)
546  {
547  throw FileNotFoundException(fmt::format("Invalid (null) filename {}", CHECK_LOCATION().AsString()));
548  }
549 
550  // Parse the file into a message
551  ModelPtr modelProto = std::make_unique<onnx::ModelProto>();
552  using google::protobuf::io::FileInputStream;
553  std::unique_ptr<FileInputStream> input = std::make_unique<FileInputStream>(fileno(fd));
554  bool success = google::protobuf::TextFormat::Parse(input.get(), modelProto.get());
555  fclose(fd);
556 
557  if (!success)
558  {
559  std::stringstream error;
560  error << "Failed to parse graph file";
561  throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString()));
562  }
563  return modelProto;
564 }
std::unique_ptr< onnx::ModelProto > ModelPtr
#define CHECK_LOCATION()
Definition: Exceptions.hpp:197

◆ ValidateInputs()

void ValidateInputs ( const onnx::NodeProto &  node,
TypePair  validInputs,
const Location &  location 
)

Definition at line 432 of file OnnxParser.cpp.

435 {
436  for(auto input : node.input())
437  {
438  CheckValidDataType(validInputs.second,
439  m_TensorsInfo[input].m_dtype,
440  validInputs.first,
441  node.name(),
442  input,
443  location);
444  }
445 }

The documentation for this class was generated from the following files: