ArmNN
 20.02
LoadModel.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include <boost/test/unit_test.hpp>
7 #include "../TfLiteParser.hpp"
8 
13 
// Opens the TensorflowLiteParser Boost.Test suite; the model-loading test cases below are registered in it.
14 BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
15 
16 struct LoadModelFixture : public ParserFlatbuffersFixture
17 {
18  explicit LoadModelFixture()
19  {
20  m_JsonString = R"(
21  {
22  "version": 3,
23  "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" }, { "builtin_code": "CONV_2D" } ],
24  "subgraphs": [
25  {
26  "tensors": [
27  {
28  "shape": [ 1, 1, 1, 1 ] ,
29  "type": "UINT8",
30  "buffer": 0,
31  "name": "OutputTensor",
32  "quantization": {
33  "min": [ 0.0 ],
34  "max": [ 255.0 ],
35  "scale": [ 1.0 ],
36  "zero_point": [ 0 ]
37  }
38  },
39  {
40  "shape": [ 1, 2, 2, 1 ] ,
41  "type": "UINT8",
42  "buffer": 1,
43  "name": "InputTensor",
44  "quantization": {
45  "min": [ 0.0 ],
46  "max": [ 255.0 ],
47  "scale": [ 1.0 ],
48  "zero_point": [ 0 ]
49  }
50  }
51  ],
52  "inputs": [ 1 ],
53  "outputs": [ 0 ],
54  "operators": [ {
55  "opcode_index": 0,
56  "inputs": [ 1 ],
57  "outputs": [ 0 ],
58  "builtin_options_type": "Pool2DOptions",
59  "builtin_options":
60  {
61  "padding": "VALID",
62  "stride_w": 2,
63  "stride_h": 2,
64  "filter_width": 2,
65  "filter_height": 2,
66  "fused_activation_function": "NONE"
67  },
68  "custom_options_format": "FLEXBUFFERS"
69  } ]
70  },
71  {
72  "tensors": [
73  {
74  "shape": [ 1, 3, 3, 1 ],
75  "type": "UINT8",
76  "buffer": 0,
77  "name": "ConvInputTensor",
78  "quantization": {
79  "scale": [ 1.0 ],
80  "zero_point": [ 0 ],
81  }
82  },
83  {
84  "shape": [ 1, 1, 1, 1 ],
85  "type": "UINT8",
86  "buffer": 1,
87  "name": "ConvOutputTensor",
88  "quantization": {
89  "min": [ 0.0 ],
90  "max": [ 511.0 ],
91  "scale": [ 2.0 ],
92  "zero_point": [ 0 ],
93  }
94  },
95  {
96  "shape": [ 1, 3, 3, 1 ],
97  "type": "UINT8",
98  "buffer": 2,
99  "name": "filterTensor",
100  "quantization": {
101  "min": [ 0.0 ],
102  "max": [ 255.0 ],
103  "scale": [ 1.0 ],
104  "zero_point": [ 0 ],
105  }
106  }
107  ],
108  "inputs": [ 0 ],
109  "outputs": [ 1 ],
110  "operators": [
111  {
112  "opcode_index": 1,
113  "inputs": [ 0, 2 ],
114  "outputs": [ 1 ],
115  "builtin_options_type": "Conv2DOptions",
116  "builtin_options": {
117  "padding": "VALID",
118  "stride_w": 1,
119  "stride_h": 1,
120  "fused_activation_function": "NONE"
121  },
122  "custom_options_format": "FLEXBUFFERS"
123  }
124  ],
125  }
126  ],
127  "description": "Test loading a model",
128  "buffers" : [ {}, {} ]
129  })";
130 
132  }
133 
134  void CheckModel(const ModelPtr& model, uint32_t version, size_t opcodeSize,
135  const std::vector<tflite::BuiltinOperator>& opcodes,
136  size_t subgraphs, const std::string desc, size_t buffers)
137  {
138  BOOST_CHECK(model);
139  BOOST_CHECK_EQUAL(version, model->version);
140  BOOST_CHECK_EQUAL(opcodeSize, model->operator_codes.size());
141  CheckBuiltinOperators(opcodes, model->operator_codes);
142  BOOST_CHECK_EQUAL(subgraphs, model->subgraphs.size());
143  BOOST_CHECK_EQUAL(desc, model->description);
144  BOOST_CHECK_EQUAL(buffers, model->buffers.size());
145  }
146 
147  void CheckBuiltinOperators(const std::vector<tflite::BuiltinOperator>& expectedOperators,
148  const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& result)
149  {
150  BOOST_CHECK_EQUAL(expectedOperators.size(), result.size());
151  for (size_t i = 0; i < expectedOperators.size(); i++)
152  {
153  BOOST_CHECK_EQUAL(expectedOperators[i], result[i]->builtin_code);
154  }
155  }
156 
157  void CheckSubgraph(const SubgraphPtr& subgraph, size_t tensors, const std::vector<int32_t>& inputs,
158  const std::vector<int32_t>& outputs, size_t operators, const std::string& name)
159  {
160  BOOST_CHECK(subgraph);
161  BOOST_CHECK_EQUAL(tensors, subgraph->tensors.size());
162  BOOST_CHECK_EQUAL_COLLECTIONS(inputs.begin(), inputs.end(), subgraph->inputs.begin(), subgraph->inputs.end());
163  BOOST_CHECK_EQUAL_COLLECTIONS(outputs.begin(), outputs.end(),
164  subgraph->outputs.begin(), subgraph->outputs.end());
165  BOOST_CHECK_EQUAL(operators, subgraph->operators.size());
166  BOOST_CHECK_EQUAL(name, subgraph->name);
167  }
168 
169  void CheckOperator(const OperatorPtr& operatorPtr, uint32_t opcode, const std::vector<int32_t>& inputs,
170  const std::vector<int32_t>& outputs, tflite::BuiltinOptions optionType,
171  tflite::CustomOptionsFormat custom_options_format)
172  {
173  BOOST_CHECK(operatorPtr);
174  BOOST_CHECK_EQUAL(opcode, operatorPtr->opcode_index);
175  BOOST_CHECK_EQUAL_COLLECTIONS(inputs.begin(), inputs.end(),
176  operatorPtr->inputs.begin(), operatorPtr->inputs.end());
177  BOOST_CHECK_EQUAL_COLLECTIONS(outputs.begin(), outputs.end(),
178  operatorPtr->outputs.begin(), operatorPtr->outputs.end());
179  BOOST_CHECK_EQUAL(optionType, operatorPtr->builtin_options.type);
180  BOOST_CHECK_EQUAL(custom_options_format, operatorPtr->custom_options_format);
181  }
182 };
183 
184 BOOST_FIXTURE_TEST_CASE(LoadModelFromBinary, LoadModelFixture)
185 {
186  TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
187  CheckModel(model, 3, 2, { tflite::BuiltinOperator_AVERAGE_POOL_2D, tflite::BuiltinOperator_CONV_2D },
188  2, "Test loading a model", 2);
189  CheckSubgraph(model->subgraphs[0], 2, { 1 }, { 0 }, 1, "");
190  CheckSubgraph(model->subgraphs[1], 3, { 0 }, { 1 }, 1, "");
191  CheckOperator(model->subgraphs[0]->operators[0], 0, { 1 }, { 0 }, tflite::BuiltinOptions_Pool2DOptions,
192  tflite::CustomOptionsFormat_FLEXBUFFERS);
193  CheckOperator(model->subgraphs[1]->operators[0], 1, { 0, 2 }, { 1 }, tflite::BuiltinOptions_Conv2DOptions,
194  tflite::CustomOptionsFormat_FLEXBUFFERS);
195 }
196 
197 BOOST_FIXTURE_TEST_CASE(LoadModelFromFile, LoadModelFixture)
198 {
199  using namespace boost::filesystem;
200  std::string fname = unique_path(temp_directory_path() / "%%%%-%%%%-%%%%.tflite").string();
201  bool saved = flatbuffers::SaveFile(fname.c_str(),
202  reinterpret_cast<char *>(m_GraphBinary.data()),
203  m_GraphBinary.size(), true);
204  BOOST_CHECK_MESSAGE(saved, "Cannot save test file");
205 
206  TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromFile(fname.c_str());
207  CheckModel(model, 3, 2, { tflite::BuiltinOperator_AVERAGE_POOL_2D, tflite::BuiltinOperator_CONV_2D },
208  2, "Test loading a model", 2);
209  CheckSubgraph(model->subgraphs[0], 2, { 1 }, { 0 }, 1, "");
210  CheckSubgraph(model->subgraphs[1], 3, { 0 }, { 1 }, 1, "");
211  CheckOperator(model->subgraphs[0]->operators[0], 0, { 1 }, { 0 }, tflite::BuiltinOptions_Pool2DOptions,
212  tflite::CustomOptionsFormat_FLEXBUFFERS);
213  CheckOperator(model->subgraphs[1]->operators[0], 1, { 0, 2 }, { 1 }, tflite::BuiltinOptions_Conv2DOptions,
214  tflite::CustomOptionsFormat_FLEXBUFFERS);
215  remove(fname);
216 }
217 
218 BOOST_AUTO_TEST_CASE(LoadNullBinary)
219 {
220  BOOST_CHECK_THROW(TfLiteParser::LoadModelFromBinary(nullptr, 0), armnn::InvalidArgumentException);
221 }
222 
223 BOOST_AUTO_TEST_CASE(LoadInvalidBinary)
224 {
225  std::string testData = "invalid data";
226  BOOST_CHECK_THROW(TfLiteParser::LoadModelFromBinary(reinterpret_cast<const uint8_t*>(&testData),
227  testData.length()), armnn::ParseException);
228 }
229 
230 BOOST_AUTO_TEST_CASE(LoadFileNotFound)
231 {
232  BOOST_CHECK_THROW(TfLiteParser::LoadModelFromFile("invalidfile.tflite"), armnn::FileNotFoundException);
233 }
234 
235 BOOST_AUTO_TEST_CASE(LoadNullPtrFile)
236 {
237  BOOST_CHECK_THROW(TfLiteParser::LoadModelFromFile(nullptr), armnn::InvalidArgumentException);
238 }
239 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
TfLiteParser::ModelPtr ModelPtr
Definition: LoadModel.cpp:10
TfLiteParser::SubgraphPtr SubgraphPtr
Definition: LoadModel.cpp:11
std::unique_ptr< onnx::ModelProto > ModelPtr
BOOST_CHECK(profilingService.GetCurrentState()==ProfilingState::WaitingForAck)
BOOST_AUTO_TEST_CASE(LoadNullBinary)
Definition: LoadModel.cpp:218
TfLiteParser::OperatorPtr OperatorPtr
Definition: LoadModel.cpp:12
BOOST_AUTO_TEST_SUITE_END()
BOOST_FIXTURE_TEST_CASE(LoadModelFromBinary, LoadModelFixture)
Definition: LoadModel.cpp:184