12 #include <schema_generated.h>
14 #include <unordered_map>
17 #include <tensorflow/lite/version.h>
19 #if TF_MAJOR_VERSION > 2 || (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 3)
20 #define ARMNN_POST_TFLITE_2_3
30 using ModelPtr = std::unique_ptr<tflite::ModelT>;
53 const std::string& name)
const;
58 const std::string& name)
const;
91 const std::vector<int32_t>& targetDimsIn);
106 using OperatorParsingFunction = void(
TfLiteParserImpl::*)(
size_t subgraphIndex,
size_t operatorIndex);
108 void ParseCustomOperator(
size_t subgraphIndex,
size_t operatorIndex);
109 void ParseUnsupportedOperator(
size_t subgraphIndex,
size_t operatorIndex);
111 void ParseAbs(
size_t subgraphIndex,
size_t operatorIndex);
113 void ParseAdd(
size_t subgraphIndex,
size_t operatorIndex);
115 void ParseArgMin(
size_t subgraphIndex,
size_t operatorIndex);
116 void ParseArgMax(
size_t subgraphIndex,
size_t operatorIndex);
117 void ParseAveragePool2D(
size_t subgraphIndex,
size_t operatorIndex);
118 void ParseBatchMatMul(
size_t subgraphIndex,
size_t operatorIndex);
119 void ParseBatchToSpaceND(
size_t subgraphIndex,
size_t operatorIndex);
120 void ParseCast(
size_t subgraphIndex,
size_t operatorIndex);
121 void ParseCeil(
size_t subgraphIndex,
size_t operatorIndex);
123 void ParseConcatenation(
size_t subgraphIndex,
size_t operatorIndex);
124 void ParseConv2D(
size_t subgraphIndex,
size_t operatorIndex);
126 #if defined(ARMNN_POST_TFLITE_2_4)
127 void ParseConv3D(
size_t subgraphIndex,
size_t operatorIndex);
129 void ParseDepthToSpace(
size_t subgraphIndex,
size_t operatorIndex);
130 void ParseDepthwiseConv2D(
size_t subgraphIndex,
size_t operatorIndex);
131 void ParseDequantize(
size_t subgraphIndex,
size_t operatorIndex);
132 void ParseDetectionPostProcess(
size_t subgraphIndex,
size_t operatorIndex);
133 void ParseDiv(
size_t subgraphIndex,
size_t operatorIndex);
134 void ParseElementwiseUnary(
size_t subgraphIndex,
size_t operatorIndex,
armnn::UnaryOperation unaryOperation);
135 void ParseElu(
size_t subgraphIndex,
size_t operatorIndex);
136 void ParseEqual(
size_t subgraphIndex,
size_t operatorIndex);
137 void ParseExp(
size_t subgraphIndex,
size_t operatorIndex);
138 void ParseExpandDims(
size_t subgraphIndex,
size_t operatorIndex);
139 void ParseFloorDiv(
size_t subgraphIndex,
size_t operatorIndex);
140 void ParseFullyConnected(
size_t subgraphIndex,
size_t operatorIndex);
141 void ParseGather(
size_t subgraphIndex,
size_t operatorIndex);
142 void ParseGatherNd(
size_t subgraphIndex,
size_t operatorIndex);
143 void ParseGreater(
size_t subgraphIndex,
size_t operatorIndex);
144 void ParseGreaterOrEqual(
size_t subgraphIndex,
size_t operatorIndex);
145 void ParseHardSwish(
size_t subgraphIndex,
size_t operatorIndex);
146 void ParseLeakyRelu(
size_t subgraphIndex,
size_t operatorIndex);
147 void ParseLess(
size_t subgraphIndex,
size_t operatorIndex);
148 void ParseLessOrEqual(
size_t subgraphIndex,
size_t operatorIndex);
149 void ParseLog(
size_t subgraphIndex,
size_t operatorIndex);
150 void ParseLocalResponseNormalization(
size_t subgraphIndex,
size_t operatorIndex);
151 void ParseLogicalNot(
size_t subgraphIndex,
size_t operatorIndex);
152 void ParseLogistic(
size_t subgraphIndex,
size_t operatorIndex);
153 void ParseLogSoftmax(
size_t subgraphIndex,
size_t operatorIndex);
154 void ParseL2Normalization(
size_t subgraphIndex,
size_t operatorIndex);
155 void ParseMaxPool2D(
size_t subgraphIndex,
size_t operatorIndex);
156 void ParseMaximum(
size_t subgraphIndex,
size_t operatorIndex);
157 void ParseMean(
size_t subgraphIndex,
size_t operatorIndex);
158 void ParseMinimum(
size_t subgraphIndex,
size_t operatorIndex);
159 void ParseMirrorPad(
size_t subgraphIndex,
size_t operatorIndex);
160 void ParseMul(
size_t subgraphIndex,
size_t operatorIndex);
161 void ParseNeg(
size_t subgraphIndex,
size_t operatorIndex);
162 void ParseNotEqual(
size_t subgraphIndex,
size_t operatorIndex);
163 void ParsePack(
size_t subgraphIndex,
size_t operatorIndex);
164 void ParsePad(
size_t subgraphIndex,
size_t operatorIndex);
166 void ParsePrelu(
size_t subgraphIndex,
size_t operatorIndex);
167 void ParseQuantize(
size_t subgraphIndex,
size_t operatorIndex);
169 void ParseReduceMax(
size_t subgraphIndex,
size_t operatorIndex);
170 void ParseReduceMin(
size_t subgraphIndex,
size_t operatorIndex);
171 void ParseReduceProd(
size_t subgraphIndex,
size_t operatorIndex);
172 void ParseRelu(
size_t subgraphIndex,
size_t operatorIndex);
173 void ParseRelu6(
size_t subgraphIndex,
size_t operatorIndex);
174 void ParseReshape(
size_t subgraphIndex,
size_t operatorIndex);
175 void ParseResize(
size_t subgraphIndex,
size_t operatorIndex,
armnn::ResizeMethod resizeMethod);
176 void ParseResizeBilinear(
size_t subgraphIndex,
size_t operatorIndex);
177 void ParseResizeNearestNeighbor(
size_t subgraphIndex,
size_t operatorIndex);
178 void ParseRsqrt(
size_t subgraphIndex,
size_t operatorIndex);
179 void ParseShape(
size_t subgraphIndex,
size_t operatorIndex);
180 void ParseSin(
size_t subgraphIndex,
size_t operatorIndex);
181 void ParseSlice(
size_t subgraphIndex,
size_t operatorIndex);
182 void ParseSoftmax(
size_t subgraphIndex,
size_t operatorIndex);
183 void ParseSqrt(
size_t subgraphIndex,
size_t operatorIndex);
184 void ParseSpaceToBatchND(
size_t subgraphIndex,
size_t operatorIndex);
185 void ParseSpaceToDepth(
size_t subgraphIndex,
size_t operatorIndex);
186 void ParseSplit(
size_t subgraphIndex,
size_t operatorIndex);
187 void ParseSplitV(
size_t subgraphIndex,
size_t operatorIndex);
188 void ParseSqueeze(
size_t subgraphIndex,
size_t operatorIndex);
189 void ParseStridedSlice(
size_t subgraphIndex,
size_t operatorIndex);
190 void ParseSub(
size_t subgraphIndex,
size_t operatorIndex);
191 void ParseSum(
size_t subgraphIndex,
size_t operatorIndex);
192 void ParseTanH(
size_t subgraphIndex,
size_t operatorIndex);
193 void ParseTranspose(
size_t subgraphIndex,
size_t operatorIndex);
194 void ParseTransposeConv(
size_t subgraphIndex,
size_t operatorIndex);
195 void ParseUnidirectionalSequenceLSTM(
size_t subgraphIndex,
size_t operatorIndex);
196 void ParseUnpack(
size_t subgraphIndex,
size_t operatorIndex);
198 void RegisterProducerOfTensor(
size_t subgraphIndex,
size_t tensorIndex,
armnn::IOutputSlot* slot);
199 void RegisterConsumerOfTensor(
size_t subgraphIndex,
size_t tensorIndex,
armnn::IInputSlot* slot);
200 void RegisterInputSlots(
size_t subgraphIndex,
201 size_t operatorIndex,
203 const std::vector<unsigned int>& tensorIndexes,
204 unsigned int startingSlotIndex = 0);
205 void RegisterOutputSlots(
size_t subgraphIndex,
206 size_t operatorIndex,
208 const std::vector<unsigned int>& tensorIndexes);
210 void SetupInputLayerTensorInfos(
size_t subgraphIndex);
211 void SetupConstantLayerTensorInfos(
size_t subgraphIndex);
213 void SetupInputLayers(
size_t subgraphIndex);
214 void SetupOutputLayers(
size_t subgraphIndex);
215 void SetupConstantLayers(
size_t subgraphIndex);
219 void AddBroadcastReshapeLayer(
size_t subgraphIndex,
220 size_t operatorIndex,
225 unsigned int outputSlot,
226 std::string reshapeLayerName,
231 unsigned int outputSlot,
232 tflite::ActivationFunctionType activationType);
239 struct SupportedDataStorage
243 SupportedDataStorage(std::unique_ptr<
float[]>&& data);
244 SupportedDataStorage(std::unique_ptr<uint8_t[]>&& data);
245 SupportedDataStorage(std::unique_ptr<int8_t[]>&& data);
246 SupportedDataStorage(std::unique_ptr<int32_t[]>&& data);
250 std::unique_ptr<float[]> m_FloatData;
251 std::unique_ptr<uint8_t[]> m_Uint8Data;
252 std::unique_ptr<int8_t[]> m_Int8Data;
253 std::unique_ptr<int32_t[]> m_Int32Data;
256 bool ShouldConstantTensorBeCreated(
unsigned int tensorIndex);
267 std::pair<armnn::ConstTensor, SupportedDataStorage>
272 std::pair<armnn::ConstTensor, std::unique_ptr<float[]>>
278 std::pair<armnn::ConstTensor, TfLiteParserImpl::SupportedDataStorage>
284 std::pair<armnn::ConstTensor*, std::unique_ptr<float[]>>
289 size_t operatorIndex,
293 size_t operatorIndex,
296 std::vector<int> inputs);
299 size_t operatorIndex,
302 std::vector<armnn::TensorShape> inputShapes = {});
311 std::vector<OperatorParsingFunction> m_ParserFunctions;
312 std::unordered_map<std::string, OperatorParsingFunction> m_CustomParserFunctions;
320 std::vector<armnn::IInputSlot*> inputSlots;
322 TensorSlots() : outputSlot(nullptr) { }
324 typedef std::vector<TensorSlots> TensorConnections;
327 std::vector<TensorConnections> m_SubgraphConnections;
331 std::vector<std::vector<unsigned int>> m_OverriddenOutputShapes;
333 std::vector<unsigned int> m_ConstantsToDequantize;
334 std::vector<unsigned int> m_ConstantsToBeCreated;
335 std::map<size_t, armnn::TensorInfo> m_TensorInfos;