ArmNN  NotReleased
OnnxParser.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
#include "armnnOnnxParser/IOnnxParser.hpp"
#include "google/protobuf/repeated_field.h"

#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <onnx/onnx.pb.h>
12 
13 
namespace armnn
{
// Forward declarations of armnn types referenced by the parser
// declarations below; their definitions are not pulled into this header.
enum class ActivationFunction;
class TensorInfo;
}
19 
20 namespace armnnOnnxParser
21 {
22 
23 using ModelPtr = std::unique_ptr<onnx::ModelProto>;
24 
25 class OnnxParser : public IOnnxParser
26 {
27 
28 using OperationParsingFunction = void(OnnxParser::*)(const onnx::NodeProto& NodeProto);
29 
30 public:
31 
32  using GraphPtr = std::unique_ptr<onnx::GraphProto>;
33 
35  virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile) override;
36 
38  virtual armnn::INetworkPtr CreateNetworkFromTextFile(const char* graphFile) override;
39 
41  virtual armnn::INetworkPtr CreateNetworkFromString(const std::string& protoText) override;
42 
44  virtual BindingPointInfo GetNetworkInputBindingInfo(const std::string& name) const override;
45 
47  virtual BindingPointInfo GetNetworkOutputBindingInfo(const std::string& name) const override;
48 
49 public:
50 
51  OnnxParser();
52 
53  static ModelPtr LoadModelFromBinaryFile(const char * fileName);
54  static ModelPtr LoadModelFromTextFile(const char * fileName);
55  static ModelPtr LoadModelFromString(const std::string& inputString);
56 
58  static std::vector<std::string> GetInputs(ModelPtr& model);
59 
61  static std::vector<std::string> GetOutputs(ModelPtr& model);
62 
63 private:
64 
66  armnn::INetworkPtr CreateNetworkFromModel(onnx::ModelProto& model);
67 
69  void LoadGraph();
70 
71  void SetupInfo(const google::protobuf::RepeatedPtrField<onnx::ValueInfoProto >* list);
72 
73  std::vector<armnn::TensorInfo> ComputeOutputInfo(std::vector<std::string> outNames,
74  const armnn::IConnectableLayer* layer,
75  std::vector<armnn::TensorShape> inputShapes);
76 
77  void DetectFullyConnected();
78 
79  template <typename Location>
80  void GetInputAndParam(const onnx::NodeProto& node,
81  std::string* inputName,
82  std::string* constName,
83  const Location& location);
84 
85  template <typename Location>
86  void To1DTensor(const std::string &name, const Location& location);
87 
88  //Broadcast Preparation functions
89  std::pair<std::string, std::string> AddPrepareBroadcast(const std::string& input0, const std::string& input1);
90  void PrependForBroadcast(const std::string& outputName, const std::string& input0, const std::string& input1);
91 
92  void CreateConstantLayer(const std::string& tensorName, const std::string& layerName);
93  void CreateReshapeLayer(const std::string& inputName,
94  const std::string& outputName,
95  const std::string& layerName);
96 
97  void ParseBatchNormalization(const onnx::NodeProto& node);
98  void ParseConstant(const onnx::NodeProto& nodeProto);
99 
100  void ParseMaxPool(const onnx::NodeProto& nodeProto);
101  void ParseAveragePool(const onnx::NodeProto& nodeProto);
102  void ParseGlobalAveragePool(const onnx::NodeProto& node);
103 
104  void AddPoolingLayer(const onnx::NodeProto& nodeProto, armnn::Pooling2dDescriptor& desc);
105 
106  void ParseReshape(const onnx::NodeProto& nodeProto);
107 
108  void ParseActivation(const onnx::NodeProto& nodeProto, const armnn::ActivationFunction func);
109  void ParseSigmoid(const onnx::NodeProto& nodeProto);
110  void ParseTanh(const onnx::NodeProto& nodeProto);
111  void ParseRelu(const onnx::NodeProto& nodeProto);
112  void ParseLeakyRelu(const onnx::NodeProto& nodeProto);
113 
114  void AddConvLayerWithDepthwiseConv(const onnx::NodeProto& node, const armnn::Convolution2dDescriptor& convDesc);
115  void ParseConv(const onnx::NodeProto& nodeProto);
116 
117  void ParseAdd(const onnx::NodeProto& nodeProto);
118  void AddFullyConnected(const onnx::NodeProto& matmulNode, const onnx::NodeProto* addNode = nullptr);
119 
120  void RegisterInputSlots(armnn::IConnectableLayer* layer, const std::vector<std::string>& tensorIndexes);
121  void RegisterOutputSlots(armnn::IConnectableLayer* layer, const std::vector<std::string>& tensorIndexes);
122 
123  void SetupInputLayers();
124  void SetupOutputLayers();
125 
126  void ResetParser();
127  void Cleanup();
128 
129  std::pair<armnn::ConstTensor, std::unique_ptr<float[]>> CreateConstTensor(const std::string name);
130 
131  template <typename TypeList, typename Location>
132  void ValidateInputs(const onnx::NodeProto& node,
133  TypeList validInputs,
134  const Location& location);
135 
137  armnn::INetworkPtr m_Network;
138 
140  GraphPtr m_Graph;
141 
143  struct OnnxTensor
144  {
145  std::unique_ptr<armnn::TensorInfo> m_info;
146  std::unique_ptr<const onnx::TensorProto> m_tensor;
148 
149  OnnxTensor() : m_info(nullptr), m_tensor(nullptr), m_dtype(onnx::TensorProto::FLOAT) { }
150  bool isConstant() { return m_tensor != nullptr; }
151  };
152 
153  std::unordered_map<std::string, OnnxTensor> m_TensorsInfo;
154 
156  static const std::map<std::string, OperationParsingFunction> m_ParserFunctions;
157 
161  struct TensorSlots
162  {
163  armnn::IOutputSlot* outputSlot;
164  std::vector<armnn::IInputSlot*> inputSlots;
165 
166  TensorSlots() : outputSlot(nullptr) { }
167  };
169  std::unordered_map<std::string, TensorSlots> m_TensorConnections;
170 
172  std::unordered_map<std::string, std::pair<const onnx::NodeProto*, int>> m_OutputsMap;
173 
176  struct UsageSummary
177  {
178  std::vector<size_t> fusedWithNodes;
179  size_t inputForNodes;
180 
181  UsageSummary() : fusedWithNodes({}), inputForNodes(0) { }
182 
183  };
184 
185  std::vector<UsageSummary> m_OutputsFusedAndUsed;
186 };
187 }
std::unique_ptr< onnx::ModelProto > ModelPtr
Definition: OnnxParser.hpp:23
ActivationFunction
Definition: Types.hpp:54
std::unique_ptr< onnx::GraphProto > GraphPtr
Definition: OnnxParser.hpp:32
std::unique_ptr< INetwork, void(*)(INetwork *network)> INetworkPtr
Definition: INetwork.hpp:85
armnn::BindingPointInfo BindingPointInfo
Definition: IOnnxParser.hpp:17
An output connection slot for a layer. The output slot may be connected to 1 or more input slots of subsequent layers in the graph.
Definition: INetwork.hpp:37
DataType
Definition: Types.hpp:32
Interface for a layer that is connectable to other layers via InputSlots and OutputSlots.
Definition: INetwork.hpp:61
A Pooling2dDescriptor for the Pooling2dLayer.
A Convolution2dDescriptor for the Convolution2dLayer.