16 #include <unordered_map> 34 class ParsedTfOperation;
52 : m_IndexedValue{value}
56 : m_IndexedValue{value}
69 const char* graphFile,
70 const std::map<std::string, armnn::TensorShape>& inputShapes,
71 const std::vector<std::string>& requestedOutputs)
override;
75 const char* graphFile,
76 const std::map<std::string, armnn::TensorShape>& inputShapes,
77 const std::vector<std::string>& requestedOutputs)
override;
81 const char* protoText,
82 const std::map<std::string, armnn::TensorShape>& inputShapes,
83 const std::vector<std::string>& requestedOutputs)
override;
86 virtual BindingPointInfo GetNetworkInputBindingInfo(
const std::string& name)
const override;
89 virtual BindingPointInfo GetNetworkOutputBindingInfo(
const std::string& name)
const override;
96 friend class ParsedConstTfOperation;
97 friend class ParsedMatMulTfOperation;
98 friend class ParsedMulTfOperation;
102 const std::map<std::string, armnn::TensorShape>& inputShapes,
103 const std::vector<std::string>& requestedOutputs);
106 void LoadGraphDef(
const tensorflow::GraphDef& graphDef);
109 void LoadNodeDef(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
112 const tensorflow::NodeDef* ResolveIdentityNode(
const tensorflow::NodeDef* nodeDef);
114 std::vector<OutputOfConstNodeDef> GetTfInputNodes(
const tensorflow::NodeDef& nodeDef)
const;
119 std::vector<OutputOfParsedTfOperation> GetInputParsedTfOperationsChecked(
const tensorflow::NodeDef& nodeDef,
120 std::size_t expectedNumInputs);
122 ParsedTfOperationPtr ParseConst(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
125 template<
typename Type>
126 bool HasParsedConstTensor(
const std::string & nodeName)
const;
127 template<
typename Type>
128 bool HasParsedConstTensor(ParsedTfOperation* parsedTfOpPtr)
const;
130 unsigned int GetConstInputIndex(
const std::vector<OutputOfParsedTfOperation>& inputs);
132 ParsedTfOperationPtr ParseAdd(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
133 ParsedTfOperationPtr ParseAddN(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
134 ParsedTfOperationPtr ParseBiasAdd(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
135 ParsedTfOperationPtr ParseConv2D(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
137 const tensorflow::GraphDef& graphDef);
138 ParsedTfOperationPtr ParseExpandDims(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
139 ParsedTfOperationPtr ParseFusedBatchNorm(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
140 ParsedTfOperationPtr ParseConcat(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
141 ParsedTfOperationPtr ParseIdentity(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
142 ParsedTfOperationPtr ParseLrn(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
143 ParsedTfOperationPtr ParseMatMul(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
144 ParsedTfOperationPtr ParseMean(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
145 ParsedTfOperationPtr ParseMul(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
146 ParsedTfOperationPtr ParsePlaceholder(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
147 ParsedTfOperationPtr ParseRealDiv(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
148 ParsedTfOperationPtr ParseRelu(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
149 ParsedTfOperationPtr ParseRelu6(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
150 ParsedTfOperationPtr ParseReshape(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
151 ParsedTfOperationPtr ParseResizeBilinear(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
152 ParsedTfOperationPtr ParseRsqrt(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
153 ParsedTfOperationPtr ParseShape(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
154 ParsedTfOperationPtr ParseSqueeze(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
155 ParsedTfOperationPtr ParseSigmoid(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
156 ParsedTfOperationPtr ParseSoftmax(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
157 ParsedTfOperationPtr ParseSoftplus(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
158 ParsedTfOperationPtr ParseSplit(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
159 ParsedTfOperationPtr ParseStridedSlice(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
160 ParsedTfOperationPtr ParseTanh(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
161 ParsedTfOperationPtr ParseMaxPool(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
162 ParsedTfOperationPtr ParseAvgPool(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
164 const tensorflow::GraphDef& graphDef,
166 ParsedTfOperationPtr ParseEqual(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
167 ParsedTfOperationPtr ParseMaximum(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
168 ParsedTfOperationPtr ParseMinimum(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
169 ParsedTfOperationPtr ParseGather(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
170 ParsedTfOperationPtr ParseGreater(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
171 ParsedTfOperationPtr ParsePad(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
172 ParsedTfOperationPtr ParseSub(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
173 ParsedTfOperationPtr ParseStack(
const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef);
175 ParsedTfOperationPtr AddAdditionLayer(
const tensorflow::NodeDef& nodeDef,
bool isBiasAdd =
false);
183 const tensorflow::NodeDef* addNodeDef,
const char* armnnLayerName);
185 bool IsSupportedLeakyReluPattern(
const tensorflow::NodeDef& mulNodeDef,
186 size_t alphaLayerIndex,
191 std::pair<armnn::IOutputSlot*, armnn::IOutputSlot*> ProcessElementwiseInputSlots(
192 const tensorflow::NodeDef& nodeDef,
const std::string& layerName);
198 const tensorflow::NodeDef& nodeDef);
204 const tensorflow::NodeDef& nodeDef);
207 const tensorflow::NodeDef& nodeDef,
210 const std::string& layerName);
213 const tensorflow::NodeDef& nodeDef,
216 unsigned int numberOfAddition);
219 const tensorflow::NodeDef& nodeDef,
222 unsigned int numberOfAddition,
223 unsigned long numberOfLayersToConnect,
227 const tensorflow::NodeDef& nodeDef,
231 static std::pair<armnn::LayerBindingId, armnn::TensorInfo> GetBindingInfo(
const std::string& layerName,
232 const char* bindingPointDesc,
233 const std::unordered_map<std::string, BindingPointInfo>& nameToBindingInfo);
245 const char* bindingPointDesc,
246 std::unordered_map<std::string, BindingPointInfo>& nameToBindingInfo);
254 const tensorflow::GraphDef& graphDef);
257 static const std::map<std::string, OperationParsingFunction> ms_OperationNameToParsingFunctions;
259 static const std::list<std::string> m_ControlInputs;
261 std::map<std::string, armnn::TensorShape> m_InputShapes;
262 std::vector<std::string> m_RequestedOutputs;
265 std::unordered_map<std::string, const tensorflow::NodeDef*> m_NodesByName;
267 std::unordered_map<std::string, ParsedTfOperationPtr> m_ParsedTfOperations;
270 std::unordered_map<std::string, BindingPointInfo> m_NetworkInputsBindingInfo;
273 std::unordered_map<std::string, BindingPointInfo> m_NetworkOutputsBindingInfo;
std::unique_ptr< ParsedTfOperation > ParsedTfOperationPtr
WithOutputTensorIndex(T &&value, unsigned int index)
An ActivationDescriptor for the ActivationLayer.
WithOutputTensorIndex(const T &value, unsigned int index)
armnn::BindingPointInfo BindingPointInfo
std::unique_ptr< INetwork, void(*)(INetwork *network)> INetworkPtr
An output connection slot for a layer. The output slot may be connected to 1 or more input slots of s...
Interface for a layer that is connectable to other layers via InputSlots and OutputSlots.
Parses a directed acyclic graph from a tensorflow protobuf file.
int LayerBindingId
Type of identifiers for bindable layers (inputs, outputs).