ArmNN
 21.08
OnnxParserImpl Class Reference

#include <OnnxParser.hpp>

Public Types

using GraphPtr = std::unique_ptr< onnx::GraphProto >
 

Public Member Functions

armnn::INetworkPtr CreateNetworkFromBinaryFile (const char *graphFile)
 Create the network from a protobuf binary file on disk. More...
 
armnn::INetworkPtr CreateNetworkFromTextFile (const char *graphFile)
 Create the network from a protobuf text file on disk. More...
 
armnn::INetworkPtr CreateNetworkFromString (const std::string &protoText)
 Create the network directly from protobuf text in a string. Useful for debugging/testing. More...
 
BindingPointInfo GetNetworkInputBindingInfo (const std::string &name) const
 Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name. More...
 
BindingPointInfo GetNetworkOutputBindingInfo (const std::string &name) const
 Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name. More...
 
 OnnxParserImpl ()
 
 ~OnnxParserImpl ()=default
 
template<typename TypePair , typename Location >
void ValidateInputs (const onnx::NodeProto &node, TypePair validInputs, const Location &location)
 

Static Public Member Functions

static ModelPtr LoadModelFromBinaryFile (const char *fileName)
 
static ModelPtr LoadModelFromTextFile (const char *fileName)
 
static ModelPtr LoadModelFromString (const std::string &inputString)
 
static std::vector< std::string > GetInputs (ModelPtr &model)
 Retrieve input names. More...
 
static std::vector< std::string > GetOutputs (ModelPtr &model)
 Retrieve output names. More...
 
static const std::string GetVersion ()
 Retrieve version in X.Y.Z form. More...
 

Detailed Description

Definition at line 25 of file OnnxParser.hpp.

Member Typedef Documentation

◆ GraphPtr

using GraphPtr = std::unique_ptr<onnx::GraphProto>

Definition at line 32 of file OnnxParser.hpp.

Constructor & Destructor Documentation

◆ OnnxParserImpl()

Definition at line 485 of file OnnxParser.cpp.

486  : m_Network(nullptr, nullptr)
487 {
488 }

◆ ~OnnxParserImpl()

~OnnxParserImpl ( )
default

Member Function Documentation

◆ CreateNetworkFromBinaryFile()

INetworkPtr CreateNetworkFromBinaryFile ( const char *  graphFile)

Create the network from a protobuf binary file on disk.

Definition at line 631 of file OnnxParser.cpp.

References OnnxParserImpl::LoadModelFromBinaryFile().

632 {
633  ResetParser();
634  ModelPtr modelProto = LoadModelFromBinaryFile(graphFile);
635  return CreateNetworkFromModel(*modelProto);
636 }
std::unique_ptr< onnx::ModelProto > ModelPtr
static ModelPtr LoadModelFromBinaryFile(const char *fileName)
Definition: OnnxParser.cpp:603

◆ CreateNetworkFromString()

INetworkPtr CreateNetworkFromString ( const std::string &  protoText)

Create the network directly from protobuf text in a string. Useful for debugging/testing.

Definition at line 657 of file OnnxParser.cpp.

References ARMNN_ASSERT, CHECK_LOCATION, CHECK_VALID_DATATYPE, CHECK_VALID_SIZE, CHECKED_INT32, IOutputSlot::Connect(), IConnectableLayer::GetInputSlot(), TensorShape::GetNumDimensions(), IConnectableLayer::GetNumInputSlots(), IConnectableLayer::GetNumOutputSlots(), IConnectableLayer::GetOutputSlot(), TensorInfo::GetShape(), OnnxParserImpl::LoadModelFromString(), ActivationDescriptor::m_A, ActivationDescriptor::m_B, FullyConnectedDescriptor::m_BiasEnabled, Convolution2dDescriptor::m_BiasEnabled, Convolution2dDescriptor::m_DilationX, Convolution2dDescriptor::m_DilationY, BatchNormalizationDescriptor::m_Eps, ActivationDescriptor::m_Function, Pooling2dDescriptor::m_OutputShapeRounding, Pooling2dDescriptor::m_PadBottom, Convolution2dDescriptor::m_PadBottom, Pooling2dDescriptor::m_PaddingMethod, Pooling2dDescriptor::m_PadLeft, Convolution2dDescriptor::m_PadLeft, DepthwiseConvolution2dDescriptor::m_PadLeft, Pooling2dDescriptor::m_PadRight, Convolution2dDescriptor::m_PadRight, Pooling2dDescriptor::m_PadTop, Convolution2dDescriptor::m_PadTop, Pooling2dDescriptor::m_PoolHeight, Pooling2dDescriptor::m_PoolType, Pooling2dDescriptor::m_PoolWidth, Pooling2dDescriptor::m_StrideX, Convolution2dDescriptor::m_StrideX, Pooling2dDescriptor::m_StrideY, Convolution2dDescriptor::m_StrideY, ReshapeDescriptor::m_TargetShape, TensorInfo::SetConstant(), TensorInfo::SetShape(), IOutputSlot::SetTensorInfo(), STR_LIST, armnnDeserializer::ToTensorInfo(), and VALID_INPUTS.

658 {
659  ResetParser();
660  ModelPtr modelProto = LoadModelFromString(protoText);
661  return CreateNetworkFromModel(*modelProto);
662 }
std::unique_ptr< onnx::ModelProto > ModelPtr
static ModelPtr LoadModelFromString(const std::string &inputString)
Definition: OnnxParser.cpp:638

◆ CreateNetworkFromTextFile()

INetworkPtr CreateNetworkFromTextFile ( const char *  graphFile)

Create the network from a protobuf text file on disk.

Definition at line 595 of file OnnxParser.cpp.

References OnnxParserImpl::LoadModelFromTextFile().

596 {
597  ResetParser();
598  ModelPtr modelProto = LoadModelFromTextFile(graphFile);
599  return CreateNetworkFromModel(*modelProto);
600 }
std::unique_ptr< onnx::ModelProto > ModelPtr
static ModelPtr LoadModelFromTextFile(const char *fileName)
Definition: OnnxParser.cpp:570

◆ GetInputs()

std::vector< std::string > GetInputs ( ModelPtr & model)
static

Retrieve input names. Graph inputs that also appear as initializers (constant tensors) are excluded from the result.

Definition at line 1870 of file OnnxParser.cpp.

References CHECK_LOCATION.

Referenced by TEST_SUITE().

1871 {
1872  if(model == nullptr) {
1873  throw InvalidArgumentException(fmt::format("The given model cannot be null {}",
1874  CHECK_LOCATION().AsString()));
1875  }
1876 
1877  std::vector<std::string> inputNames;
1878  std::map<std::string, bool> isConstant;
1879  for(auto tensor : model->graph().initializer())
1880  {
1881  isConstant[tensor.name()] = true;
1882  }
1883  for(auto input : model->graph().input())
1884  {
1885  auto it = isConstant.find(input.name());
1886  if(it == isConstant.end())
1887  {
1888  inputNames.push_back(input.name());
1889  }
1890  }
1891  return inputNames;
1892 }
#define CHECK_LOCATION()
Definition: Exceptions.hpp:197

◆ GetNetworkInputBindingInfo()

BindingPointInfo GetNetworkInputBindingInfo ( const std::string &  name) const

Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name.

Definition at line 1842 of file OnnxParser.cpp.

References CHECK_LOCATION, and armnnDeserializer::ToTensorInfo().

1843 {
1844  for(int i = 0; i < m_Graph->input_size(); ++i)
1845  {
1846  auto input = m_Graph->input(i);
1847  if(input.name() == name)
1848  {
1849  return std::make_pair(static_cast<armnn::LayerBindingId>(i), ToTensorInfo(input));
1850  }
1851  }
1852  throw InvalidArgumentException(fmt::format("The input layer '{}' does not exist {}",
1853  name, CHECK_LOCATION().AsString()));
1854 }
#define CHECK_LOCATION()
Definition: Exceptions.hpp:197
armnn::TensorInfo ToTensorInfo(TensorRawPtr tensorPtr)

◆ GetNetworkOutputBindingInfo()

BindingPointInfo GetNetworkOutputBindingInfo ( const std::string &  name) const

Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name.

Definition at line 1856 of file OnnxParser.cpp.

References CHECK_LOCATION, and armnnDeserializer::ToTensorInfo().

1857 {
1858  for(int i = 0; i < m_Graph->output_size(); ++i)
1859  {
1860  auto output = m_Graph->output(i);
1861  if(output.name() == name)
1862  {
1863  return std::make_pair(static_cast<armnn::LayerBindingId>(i), ToTensorInfo(output));
1864  }
1865  }
1866  throw InvalidArgumentException(fmt::format("The output layer '{}' does not exist {}",
1867  name, CHECK_LOCATION().AsString()));
1868 }
#define CHECK_LOCATION()
Definition: Exceptions.hpp:197
armnn::TensorInfo ToTensorInfo(TensorRawPtr tensorPtr)

◆ GetOutputs()

std::vector< std::string > GetOutputs ( ModelPtr & model)
static

Retrieve output names.

Definition at line 1894 of file OnnxParser.cpp.

References CHECK_LOCATION.

Referenced by TEST_SUITE().

1895 {
1896  if(model == nullptr) {
1897  throw InvalidArgumentException(fmt::format("The given model cannot be null {}",
1898  CHECK_LOCATION().AsString()));
1899  }
1900 
1901  std::vector<std::string> outputNames;
1902  for(auto output : model->graph().output())
1903  {
1904  outputNames.push_back(output.name());
1905  }
1906  return outputNames;
1907 }
#define CHECK_LOCATION()
Definition: Exceptions.hpp:197

◆ GetVersion()

const std::string GetVersion ( )
static

Retrieve version in X.Y.Z form.

Definition at line 1909 of file OnnxParser.cpp.

References ONNX_PARSER_VERSION.

1910 {
1911  return ONNX_PARSER_VERSION;
1912 }
#define ONNX_PARSER_VERSION
ONNX_PARSER_VERSION: "X.Y.Z" where: X = Major version number, Y = Minor version number, Z = Patch version number.
Definition: Version.hpp:25

◆ LoadModelFromBinaryFile()

ModelPtr LoadModelFromBinaryFile ( const char *  fileName)
static

Definition at line 603 of file OnnxParser.cpp.

References CHECK_LOCATION, and armnn::error.

Referenced by OnnxParserImpl::CreateNetworkFromBinaryFile().

604 {
605  FILE* fd = fopen(graphFile, "rb");
606 
607  if (fd == nullptr)
608  {
609  throw FileNotFoundException(fmt::format("Invalid (null) filename {}", CHECK_LOCATION().AsString()));
610  }
611 
612  // Parse the file into a message
613  ModelPtr modelProto = std::make_unique<onnx::ModelProto>();
614 
615  google::protobuf::io::FileInputStream inStream(fileno(fd));
616  google::protobuf::io::CodedInputStream codedStream(&inStream);
617  codedStream.SetTotalBytesLimit(INT_MAX);
618  bool success = modelProto.get()->ParseFromCodedStream(&codedStream);
619  fclose(fd);
620 
621  if (!success)
622  {
623  std::stringstream error;
624  error << "Failed to parse graph file";
625  throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString()));
626  }
627  return modelProto;
628 
629 }
std::unique_ptr< onnx::ModelProto > ModelPtr
#define CHECK_LOCATION()
Definition: Exceptions.hpp:197

◆ LoadModelFromString()

ModelPtr LoadModelFromString ( const std::string &  inputString)
static

Definition at line 638 of file OnnxParser.cpp.

References CHECK_LOCATION, and armnn::error.

Referenced by OnnxParserImpl::CreateNetworkFromString(), and TEST_SUITE().

639 {
640  if (protoText == "")
641  {
642  throw InvalidArgumentException(fmt::format("Invalid (empty) string for model parameter {}",
643  CHECK_LOCATION().AsString()));
644  }
645  // Parse the string into a message
646  ModelPtr modelProto = std::make_unique<onnx::ModelProto>();
647  bool success = google::protobuf::TextFormat::ParseFromString(protoText, modelProto.get());
648  if (!success)
649  {
650  std::stringstream error;
651  error << "Failed to parse graph file";
652  throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString()));
653  }
654  return modelProto;
655 }
std::unique_ptr< onnx::ModelProto > ModelPtr
#define CHECK_LOCATION()
Definition: Exceptions.hpp:197

◆ LoadModelFromTextFile()

ModelPtr LoadModelFromTextFile ( const char *  fileName)
static

Definition at line 570 of file OnnxParser.cpp.

References CHECK_LOCATION, and armnn::error.

Referenced by OnnxParserImpl::CreateNetworkFromTextFile().

571 {
572  FILE* fd = fopen(graphFile, "r");
573 
574  if (fd == nullptr)
575  {
576  throw FileNotFoundException(fmt::format("Invalid (null) filename {}", CHECK_LOCATION().AsString()));
577  }
578 
579  // Parse the file into a message
580  ModelPtr modelProto = std::make_unique<onnx::ModelProto>();
581  using google::protobuf::io::FileInputStream;
582  std::unique_ptr<FileInputStream> input = std::make_unique<FileInputStream>(fileno(fd));
583  bool success = google::protobuf::TextFormat::Parse(input.get(), modelProto.get());
584  fclose(fd);
585 
586  if (!success)
587  {
588  std::stringstream error;
589  error << "Failed to parse graph file";
590  throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString()));
591  }
592  return modelProto;
593 }
std::unique_ptr< onnx::ModelProto > ModelPtr
#define CHECK_LOCATION()
Definition: Exceptions.hpp:197

◆ ValidateInputs()

void ValidateInputs ( const onnx::NodeProto &  node,
TypePair  validInputs,
const Location &  location 
)

Definition at line 433 of file OnnxParser.cpp.

436 {
437  for(auto input : node.input())
438  {
439  CheckValidDataType(validInputs.second,
440  m_TensorsInfo[input].m_dtype,
441  validInputs.first,
442  node.name(),
443  input,
444  location);
445  }
446 }

The documentation for this class was generated from the following files: