ArmNN
 22.05.01
OnnxParserImpl Class Reference

#include <OnnxParser.hpp>

Public Types

using GraphPtr = std::unique_ptr< onnx::GraphProto >
 

Public Member Functions

armnn::INetworkPtr CreateNetworkFromBinaryFile (const char *graphFile)
 Create the network from a protobuf binary file on disk. More...
 
armnn::INetworkPtr CreateNetworkFromBinaryFile (const char *graphFile, const std::map< std::string, armnn::TensorShape > &inputShapes)
 Create the network from a protobuf binary file on disk, with inputShapes specified. More...
 
armnn::INetworkPtr CreateNetworkFromTextFile (const char *graphFile)
 Create the network from a protobuf text file on disk. More...
 
armnn::INetworkPtr CreateNetworkFromTextFile (const char *graphFile, const std::map< std::string, armnn::TensorShape > &inputShapes)
 Create the network from a protobuf text file on disk, with inputShapes specified. More...
 
armnn::INetworkPtr CreateNetworkFromString (const std::string &protoText)
 Create the network directly from protobuf text in a string. Useful for debugging/testing. More...
 
armnn::INetworkPtr CreateNetworkFromString (const std::string &protoText, const std::map< std::string, armnn::TensorShape > &inputShapes)
 Create the network directly from protobuf text in a string, with inputShapes specified. More...
 
BindingPointInfo GetNetworkInputBindingInfo (const std::string &name) const
 Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name. More...
 
BindingPointInfo GetNetworkOutputBindingInfo (const std::string &name) const
 Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name. More...
 
 OnnxParserImpl ()
 
 ~OnnxParserImpl ()=default
 
template<typename TypePair , typename Location >
void ValidateInputs (const onnx::NodeProto &node, TypePair validInputs, const Location &location)
 

Static Public Member Functions

static ModelPtr LoadModelFromBinaryFile (const char *fileName)
 
static ModelPtr LoadModelFromTextFile (const char *fileName)
 
static ModelPtr LoadModelFromString (const std::string &inputString)
 
static std::vector< std::string > GetInputs (ModelPtr &model)
 Retrieve inputs names. More...
 
static std::vector< std::string > GetOutputs (ModelPtr &model)
 Retrieve outputs names. More...
 
static const std::string GetVersion ()
 Retrieve version in X.Y.Z form. More...
 

Detailed Description

Definition at line 25 of file OnnxParser.hpp.

Member Typedef Documentation

◆ GraphPtr

using GraphPtr = std::unique_ptr<onnx::GraphProto>

Definition at line 32 of file OnnxParser.hpp.

Constructor & Destructor Documentation

◆ OnnxParserImpl()

Definition at line 542 of file OnnxParser.cpp.

543  : m_Network(nullptr, nullptr)
544 {
545 }

◆ ~OnnxParserImpl()

~OnnxParserImpl ( )
default

Member Function Documentation

◆ CreateNetworkFromBinaryFile() [1/2]

INetworkPtr CreateNetworkFromBinaryFile ( const char *  graphFile)

Create the network from a protobuf binary file on disk.

Definition at line 762 of file OnnxParser.cpp.

References OnnxParserImpl::LoadModelFromBinaryFile().

763 {
764  ResetParser();
765  ModelPtr modelProto = LoadModelFromBinaryFile(graphFile);
766  return CreateNetworkFromModel(*modelProto);
767 }
std::unique_ptr< onnx::ModelProto > ModelPtr
static ModelPtr LoadModelFromBinaryFile(const char *fileName)
Definition: OnnxParser.cpp:734

◆ CreateNetworkFromBinaryFile() [2/2]

armnn::INetworkPtr CreateNetworkFromBinaryFile ( const char *  graphFile,
const std::map< std::string, armnn::TensorShape > &  inputShapes 
)

Create the network from a protobuf binary file on disk, with inputShapes specified.

◆ CreateNetworkFromString() [1/2]

INetworkPtr CreateNetworkFromString ( const std::string &  protoText)

Create the network directly from protobuf text in a string. Useful for debugging/testing.

Definition at line 797 of file OnnxParser.cpp.

References ARMNN_ASSERT, ARMNN_NO_DEPRECATE_WARN_BEGIN, ARMNN_NO_DEPRECATE_WARN_END, CHECK_LOCATION, CHECK_VALID_DATATYPE, CHECK_VALID_SIZE, CHECKED_INT32, CHECKED_NON_NEGATIVE, IOutputSlot::Connect(), TensorShape::GetDimensionality(), IConnectableLayer::GetInputSlot(), TensorShape::GetNumDimensions(), IConnectableLayer::GetNumInputSlots(), IConnectableLayer::GetNumOutputSlots(), IConnectableLayer::GetOutputSlot(), TensorInfo::GetShape(), OnnxParserImpl::LoadModelFromString(), ActivationDescriptor::m_A, GatherDescriptor::m_Axis, ActivationDescriptor::m_B, FullyConnectedDescriptor::m_BiasEnabled, Convolution2dDescriptor::m_BiasEnabled, Convolution2dDescriptor::m_DilationX, Convolution2dDescriptor::m_DilationY, TransposeDescriptor::m_DimMappings, BatchNormalizationDescriptor::m_Eps, ActivationDescriptor::m_Function, Pooling2dDescriptor::m_OutputShapeRounding, Pooling2dDescriptor::m_PadBottom, Convolution2dDescriptor::m_PadBottom, Pooling2dDescriptor::m_PaddingMethod, Pooling2dDescriptor::m_PadLeft, Convolution2dDescriptor::m_PadLeft, DepthwiseConvolution2dDescriptor::m_PadLeft, Pooling2dDescriptor::m_PadRight, Convolution2dDescriptor::m_PadRight, Pooling2dDescriptor::m_PadTop, Convolution2dDescriptor::m_PadTop, Pooling2dDescriptor::m_PoolHeight, Pooling2dDescriptor::m_PoolType, Pooling2dDescriptor::m_PoolWidth, Pooling2dDescriptor::m_StrideX, Convolution2dDescriptor::m_StrideX, Pooling2dDescriptor::m_StrideY, Convolution2dDescriptor::m_StrideY, ReshapeDescriptor::m_TargetShape, FullyConnectedDescriptor::m_TransposeWeightMatrix, armnn::numeric_cast(), armnnUtils::Permuted(), armnnUtils::ProcessConcatInputTensorInfo(), OriginsDescriptor::SetConcatAxis(), TensorInfo::SetConstant(), TensorInfo::SetShape(), IOutputSlot::SetTensorInfo(), STR_LIST, armnnDeserializer::ToTensorInfo(), and VALID_INPUTS.

798 {
799  ResetParser();
800  ModelPtr modelProto = LoadModelFromString(protoText);
801  return CreateNetworkFromModel(*modelProto);
802 }
std::unique_ptr< onnx::ModelProto > ModelPtr
static ModelPtr LoadModelFromString(const std::string &inputString)
Definition: OnnxParser.cpp:778

◆ CreateNetworkFromString() [2/2]

armnn::INetworkPtr CreateNetworkFromString ( const std::string &  protoText,
const std::map< std::string, armnn::TensorShape > &  inputShapes 
)

Create the network directly from protobuf text in a string, with inputShapes specified.

Useful for debugging/testing.

◆ CreateNetworkFromTextFile() [1/2]

INetworkPtr CreateNetworkFromTextFile ( const char *  graphFile)

Create the network from a protobuf text file on disk.

Definition at line 718 of file OnnxParser.cpp.

References OnnxParserImpl::LoadModelFromTextFile().

719 {
720  ResetParser();
721  ModelPtr modelProto = LoadModelFromTextFile(graphFile);
722  return CreateNetworkFromModel(*modelProto);
723 }
std::unique_ptr< onnx::ModelProto > ModelPtr
static ModelPtr LoadModelFromTextFile(const char *fileName)
Definition: OnnxParser.cpp:693

◆ CreateNetworkFromTextFile() [2/2]

armnn::INetworkPtr CreateNetworkFromTextFile ( const char *  graphFile,
const std::map< std::string, armnn::TensorShape > &  inputShapes 
)

Create the network from a protobuf text file on disk, with inputShapes specified.

◆ GetInputs()

std::vector< std::string > GetInputs ( ModelPtr &  model)
static

Retrieve inputs names.

Definition at line 2443 of file OnnxParser.cpp.

References CHECK_LOCATION.

Referenced by TEST_SUITE().

2444 {
2445  if(model == nullptr) {
2446  throw InvalidArgumentException(fmt::format("The given model cannot be null {}",
2447  CHECK_LOCATION().AsString()));
2448  }
2449 
2450  std::vector<std::string> inputNames;
2451  std::map<std::string, bool> isConstant;
2452  for(auto tensor : model->graph().initializer())
2453  {
2454  isConstant[tensor.name()] = true;
2455  }
2456  for(auto input : model->graph().input())
2457  {
2458  auto it = isConstant.find(input.name());
2459  if(it == isConstant.end())
2460  {
2461  inputNames.push_back(input.name());
2462  }
2463  }
2464  return inputNames;
2465 }
#define CHECK_LOCATION()
Definition: Exceptions.hpp:203

◆ GetNetworkInputBindingInfo()

BindingPointInfo GetNetworkInputBindingInfo ( const std::string &  name) const

Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name.

Definition at line 2405 of file OnnxParser.cpp.

References CHECK_LOCATION.

2406 {
2407  for(int i = 0; i < m_Graph->input_size(); ++i)
2408  {
2409  auto input = m_Graph->input(i);
2410  if(input.name() == name)
2411  {
2412  auto it = m_InputInfos.find(name);
2413 
2414  if (it != m_InputInfos.end())
2415  {
2416  return std::make_pair(static_cast<armnn::LayerBindingId>(i), it->second);
2417  }
2418  }
2419  }
2420  throw InvalidArgumentException(fmt::format("The input layer '{}' does not exist {}",
2421  name, CHECK_LOCATION().AsString()));
2422 }
#define CHECK_LOCATION()
Definition: Exceptions.hpp:203

◆ GetNetworkOutputBindingInfo()

BindingPointInfo GetNetworkOutputBindingInfo ( const std::string &  name) const

Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name.

Definition at line 2424 of file OnnxParser.cpp.

References CHECK_LOCATION.

2425 {
2426  for(int i = 0; i < m_Graph->output_size(); ++i)
2427  {
2428  auto output = m_Graph->output(i);
2429  if(output.name() == name)
2430  {
2431  auto it = m_OutputInfos.find(name);
2432 
2433  if (it != m_OutputInfos.end())
2434  {
2435  return std::make_pair(static_cast<armnn::LayerBindingId>(i), it->second);
2436  }
2437  }
2438  }
2439  throw InvalidArgumentException(fmt::format("The output layer '{}' does not exist {}",
2440  name, CHECK_LOCATION().AsString()));
2441 }
#define CHECK_LOCATION()
Definition: Exceptions.hpp:203

◆ GetOutputs()

std::vector< std::string > GetOutputs ( ModelPtr &  model)
static

Retrieve outputs names.

Definition at line 2467 of file OnnxParser.cpp.

References CHECK_LOCATION.

Referenced by TEST_SUITE().

2468 {
2469  if(model == nullptr) {
2470  throw InvalidArgumentException(fmt::format("The given model cannot be null {}",
2471  CHECK_LOCATION().AsString()));
2472  }
2473 
2474  std::vector<std::string> outputNames;
2475  for(auto output : model->graph().output())
2476  {
2477  outputNames.push_back(output.name());
2478  }
2479  return outputNames;
2480 }
#define CHECK_LOCATION()
Definition: Exceptions.hpp:203

◆ GetVersion()

const std::string GetVersion ( )
static

Retrieve version in X.Y.Z form.

Definition at line 2482 of file OnnxParser.cpp.

References ONNX_PARSER_VERSION.

2483 {
2484  return ONNX_PARSER_VERSION;
2485 }
#define ONNX_PARSER_VERSION
ONNX_PARSER_VERSION: "X.Y.Z" where X = Major version number, Y = Minor version number, Z = Patch version number.
Definition: Version.hpp:25

◆ LoadModelFromBinaryFile()

ModelPtr LoadModelFromBinaryFile ( const char *  fileName)
static

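Definition at line 734 of file OnnxParser.cpp. (The definition names the parameter graphFile rather than fileName; the listing below uses that name.)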

References CHECK_LOCATION, and armnn::error.

Referenced by OnnxParserImpl::CreateNetworkFromBinaryFile().

735 {
736  FILE* fd = fopen(graphFile, "rb");
737 
738  if (fd == nullptr)
739  {
740  throw FileNotFoundException(fmt::format("Invalid (null) filename {}", CHECK_LOCATION().AsString()));
741  }
742 
743  // Parse the file into a message
744  ModelPtr modelProto = std::make_unique<onnx::ModelProto>();
745 
746  google::protobuf::io::FileInputStream inStream(fileno(fd));
747  google::protobuf::io::CodedInputStream codedStream(&inStream);
748  codedStream.SetTotalBytesLimit(INT_MAX);
749  bool success = modelProto.get()->ParseFromCodedStream(&codedStream);
750  fclose(fd);
751 
752  if (!success)
753  {
754  std::stringstream error;
755  error << "Failed to parse graph file";
756  throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString()));
757  }
758  return modelProto;
759 
760 }
std::unique_ptr< onnx::ModelProto > ModelPtr
#define CHECK_LOCATION()
Definition: Exceptions.hpp:203

◆ LoadModelFromString()

ModelPtr LoadModelFromString ( const std::string &  inputString)
static

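Definition at line 778 of file OnnxParser.cpp. (The definition names the parameter protoText rather than inputString; the listing below uses that name.)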

References CHECK_LOCATION, and armnn::error.

Referenced by OnnxParserImpl::CreateNetworkFromString(), and TEST_SUITE().

779 {
780  if (protoText == "")
781  {
782  throw InvalidArgumentException(fmt::format("Invalid (empty) string for model parameter {}",
783  CHECK_LOCATION().AsString()));
784  }
785  // Parse the string into a message
786  ModelPtr modelProto = std::make_unique<onnx::ModelProto>();
787  bool success = google::protobuf::TextFormat::ParseFromString(protoText, modelProto.get());
788  if (!success)
789  {
790  std::stringstream error;
791  error << "Failed to parse graph file";
792  throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString()));
793  }
794  return modelProto;
795 }
std::unique_ptr< onnx::ModelProto > ModelPtr
#define CHECK_LOCATION()
Definition: Exceptions.hpp:203

◆ LoadModelFromTextFile()

ModelPtr LoadModelFromTextFile ( const char *  fileName)
static

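Definition at line 693 of file OnnxParser.cpp. (The definition names the parameter graphFile rather than fileName; the listing below uses that name.)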

References CHECK_LOCATION, and armnn::error.

Referenced by OnnxParserImpl::CreateNetworkFromTextFile().

694 {
695  FILE* fd = fopen(graphFile, "r");
696 
697  if (fd == nullptr)
698  {
699  throw FileNotFoundException(fmt::format("Invalid (null) filename {}", CHECK_LOCATION().AsString()));
700  }
701 
702  // Parse the file into a message
703  ModelPtr modelProto = std::make_unique<onnx::ModelProto>();
704  using google::protobuf::io::FileInputStream;
705  std::unique_ptr<FileInputStream> input = std::make_unique<FileInputStream>(fileno(fd));
706  bool success = google::protobuf::TextFormat::Parse(input.get(), modelProto.get());
707  fclose(fd);
708 
709  if (!success)
710  {
711  std::stringstream error;
712  error << "Failed to parse graph file";
713  throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString()));
714  }
715  return modelProto;
716 }
std::unique_ptr< onnx::ModelProto > ModelPtr
#define CHECK_LOCATION()
Definition: Exceptions.hpp:203

◆ ValidateInputs()

void ValidateInputs ( const onnx::NodeProto &  node,
TypePair  validInputs,
const Location &  location 
)

Definition at line 467 of file OnnxParser.cpp.

470 {
471  for(auto input : node.input())
472  {
473  CheckValidDataType(validInputs.second,
474  m_TensorsInfo[input].m_dtype,
475  validInputs.first,
476  node.name(),
477  input,
478  location);
479  }
480 }

The documentation for this class was generated from the following files: