ArmNN
 21.02
LoadModel.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include <boost/test/unit_test.hpp>
7 #include "../TfLiteParser.hpp"
8 
9 #include <Filesystem.hpp>
10 
15 
16 BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
17 
18 struct LoadModelFixture : public ParserFlatbuffersFixture
19 {
20  explicit LoadModelFixture()
21  {
22  m_JsonString = R"(
23  {
24  "version": 3,
25  "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" }, { "builtin_code": "CONV_2D" } ],
26  "subgraphs": [
27  {
28  "tensors": [
29  {
30  "shape": [ 1, 1, 1, 1 ] ,
31  "type": "UINT8",
32  "buffer": 0,
33  "name": "OutputTensor",
34  "quantization": {
35  "min": [ 0.0 ],
36  "max": [ 255.0 ],
37  "scale": [ 1.0 ],
38  "zero_point": [ 0 ]
39  }
40  },
41  {
42  "shape": [ 1, 2, 2, 1 ] ,
43  "type": "UINT8",
44  "buffer": 1,
45  "name": "InputTensor",
46  "quantization": {
47  "min": [ 0.0 ],
48  "max": [ 255.0 ],
49  "scale": [ 1.0 ],
50  "zero_point": [ 0 ]
51  }
52  }
53  ],
54  "inputs": [ 1 ],
55  "outputs": [ 0 ],
56  "operators": [ {
57  "opcode_index": 0,
58  "inputs": [ 1 ],
59  "outputs": [ 0 ],
60  "builtin_options_type": "Pool2DOptions",
61  "builtin_options":
62  {
63  "padding": "VALID",
64  "stride_w": 2,
65  "stride_h": 2,
66  "filter_width": 2,
67  "filter_height": 2,
68  "fused_activation_function": "NONE"
69  },
70  "custom_options_format": "FLEXBUFFERS"
71  } ]
72  },
73  {
74  "tensors": [
75  {
76  "shape": [ 1, 3, 3, 1 ],
77  "type": "UINT8",
78  "buffer": 0,
79  "name": "ConvInputTensor",
80  "quantization": {
81  "scale": [ 1.0 ],
82  "zero_point": [ 0 ],
83  }
84  },
85  {
86  "shape": [ 1, 1, 1, 1 ],
87  "type": "UINT8",
88  "buffer": 1,
89  "name": "ConvOutputTensor",
90  "quantization": {
91  "min": [ 0.0 ],
92  "max": [ 511.0 ],
93  "scale": [ 2.0 ],
94  "zero_point": [ 0 ],
95  }
96  },
97  {
98  "shape": [ 1, 3, 3, 1 ],
99  "type": "UINT8",
100  "buffer": 2,
101  "name": "filterTensor",
102  "quantization": {
103  "min": [ 0.0 ],
104  "max": [ 255.0 ],
105  "scale": [ 1.0 ],
106  "zero_point": [ 0 ],
107  }
108  }
109  ],
110  "inputs": [ 0 ],
111  "outputs": [ 1 ],
112  "operators": [
113  {
114  "opcode_index": 1,
115  "inputs": [ 0, 2 ],
116  "outputs": [ 1 ],
117  "builtin_options_type": "Conv2DOptions",
118  "builtin_options": {
119  "padding": "VALID",
120  "stride_w": 1,
121  "stride_h": 1,
122  "fused_activation_function": "NONE"
123  },
124  "custom_options_format": "FLEXBUFFERS"
125  }
126  ],
127  }
128  ],
129  "description": "Test loading a model",
130  "buffers" : [ {}, {} ]
131  })";
132 
134  }
135 
136  void CheckModel(const ModelPtr& model, uint32_t version, size_t opcodeSize,
137  const std::vector<tflite::BuiltinOperator>& opcodes,
138  size_t subgraphs, const std::string desc, size_t buffers)
139  {
140  BOOST_CHECK(model);
141  BOOST_CHECK_EQUAL(version, model->version);
142  BOOST_CHECK_EQUAL(opcodeSize, model->operator_codes.size());
143  CheckBuiltinOperators(opcodes, model->operator_codes);
144  BOOST_CHECK_EQUAL(subgraphs, model->subgraphs.size());
145  BOOST_CHECK_EQUAL(desc, model->description);
146  BOOST_CHECK_EQUAL(buffers, model->buffers.size());
147  }
148 
149  void CheckBuiltinOperators(const std::vector<tflite::BuiltinOperator>& expectedOperators,
150  const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& result)
151  {
152  BOOST_CHECK_EQUAL(expectedOperators.size(), result.size());
153  for (size_t i = 0; i < expectedOperators.size(); i++)
154  {
155  BOOST_CHECK_EQUAL(expectedOperators[i], result[i]->builtin_code);
156  }
157  }
158 
159  void CheckSubgraph(const SubgraphPtr& subgraph, size_t tensors, const std::vector<int32_t>& inputs,
160  const std::vector<int32_t>& outputs, size_t operators, const std::string& name)
161  {
162  BOOST_CHECK(subgraph);
163  BOOST_CHECK_EQUAL(tensors, subgraph->tensors.size());
164  BOOST_CHECK_EQUAL_COLLECTIONS(inputs.begin(), inputs.end(), subgraph->inputs.begin(), subgraph->inputs.end());
165  BOOST_CHECK_EQUAL_COLLECTIONS(outputs.begin(), outputs.end(),
166  subgraph->outputs.begin(), subgraph->outputs.end());
167  BOOST_CHECK_EQUAL(operators, subgraph->operators.size());
168  BOOST_CHECK_EQUAL(name, subgraph->name);
169  }
170 
171  void CheckOperator(const OperatorPtr& operatorPtr, uint32_t opcode, const std::vector<int32_t>& inputs,
172  const std::vector<int32_t>& outputs, tflite::BuiltinOptions optionType,
173  tflite::CustomOptionsFormat custom_options_format)
174  {
175  BOOST_CHECK(operatorPtr);
176  BOOST_CHECK_EQUAL(opcode, operatorPtr->opcode_index);
177  BOOST_CHECK_EQUAL_COLLECTIONS(inputs.begin(), inputs.end(),
178  operatorPtr->inputs.begin(), operatorPtr->inputs.end());
179  BOOST_CHECK_EQUAL_COLLECTIONS(outputs.begin(), outputs.end(),
180  operatorPtr->outputs.begin(), operatorPtr->outputs.end());
181  BOOST_CHECK_EQUAL(optionType, operatorPtr->builtin_options.type);
182  BOOST_CHECK_EQUAL(custom_options_format, operatorPtr->custom_options_format);
183  }
184 };
185 
186 BOOST_FIXTURE_TEST_CASE(LoadModelFromBinary, LoadModelFixture)
187 {
188  TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromBinary(m_GraphBinary.data(),
189  m_GraphBinary.size());
190  CheckModel(model, 3, 2, { tflite::BuiltinOperator_AVERAGE_POOL_2D, tflite::BuiltinOperator_CONV_2D },
191  2, "Test loading a model", 2);
192  CheckSubgraph(model->subgraphs[0], 2, { 1 }, { 0 }, 1, "");
193  CheckSubgraph(model->subgraphs[1], 3, { 0 }, { 1 }, 1, "");
194  CheckOperator(model->subgraphs[0]->operators[0], 0, { 1 }, { 0 }, tflite::BuiltinOptions_Pool2DOptions,
195  tflite::CustomOptionsFormat_FLEXBUFFERS);
196  CheckOperator(model->subgraphs[1]->operators[0], 1, { 0, 2 }, { 1 }, tflite::BuiltinOptions_Conv2DOptions,
197  tflite::CustomOptionsFormat_FLEXBUFFERS);
198 }
199 
200 BOOST_FIXTURE_TEST_CASE(LoadModelFromFile, LoadModelFixture)
201 {
202  using namespace fs;
203  fs::path fname = armnnUtils::Filesystem::NamedTempFile("Armnn-tfLite-LoadModelFromFile-TempFile.csv");
204  bool saved = flatbuffers::SaveFile(fname.c_str(),
205  reinterpret_cast<char *>(m_GraphBinary.data()),
206  m_GraphBinary.size(), true);
207  BOOST_CHECK_MESSAGE(saved, "Cannot save test file");
208 
209  TfLiteParserImpl::ModelPtr model = TfLiteParserImpl::LoadModelFromFile(fname.c_str());
210  CheckModel(model, 3, 2, { tflite::BuiltinOperator_AVERAGE_POOL_2D, tflite::BuiltinOperator_CONV_2D },
211  2, "Test loading a model", 2);
212  CheckSubgraph(model->subgraphs[0], 2, { 1 }, { 0 }, 1, "");
213  CheckSubgraph(model->subgraphs[1], 3, { 0 }, { 1 }, 1, "");
214  CheckOperator(model->subgraphs[0]->operators[0], 0, { 1 }, { 0 }, tflite::BuiltinOptions_Pool2DOptions,
215  tflite::CustomOptionsFormat_FLEXBUFFERS);
216  CheckOperator(model->subgraphs[1]->operators[0], 1, { 0, 2 }, { 1 }, tflite::BuiltinOptions_Conv2DOptions,
217  tflite::CustomOptionsFormat_FLEXBUFFERS);
218  remove(fname);
219 }
220 
221 BOOST_AUTO_TEST_CASE(LoadNullBinary)
222 {
223  BOOST_CHECK_THROW(TfLiteParserImpl::LoadModelFromBinary(nullptr, 0), armnn::InvalidArgumentException);
224 }
225 
226 BOOST_AUTO_TEST_CASE(LoadInvalidBinary)
227 {
228  std::string testData = "invalid data";
229  BOOST_CHECK_THROW(TfLiteParserImpl::LoadModelFromBinary(reinterpret_cast<const uint8_t*>(&testData),
230  testData.length()), armnn::ParseException);
231 }
232 
233 BOOST_AUTO_TEST_CASE(LoadFileNotFound)
234 {
235  BOOST_CHECK_THROW(TfLiteParserImpl::LoadModelFromFile("invalidfile.tflite"), armnn::FileNotFoundException);
236 }
237 
238 BOOST_AUTO_TEST_CASE(LoadNullPtrFile)
239 {
240  BOOST_CHECK_THROW(TfLiteParserImpl::LoadModelFromFile(nullptr), armnn::InvalidArgumentException);
241 }
242 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
TfLiteParserImpl::OperatorPtr OperatorPtr
Definition: LoadModel.cpp:14
std::unique_ptr< onnx::ModelProto > ModelPtr
TfLiteParserImpl::SubgraphPtr SubgraphPtr
Definition: LoadModel.cpp:13
BOOST_AUTO_TEST_CASE(LoadNullBinary)
Definition: LoadModel.cpp:221
BOOST_AUTO_TEST_SUITE_END()
fs::path NamedTempFile(const char *fileName)
Construct a temporary file name.
Definition: Filesystem.cpp:23
BOOST_FIXTURE_TEST_CASE(LoadModelFromBinary, LoadModelFixture)
Definition: LoadModel.cpp:186
TfLiteParserImpl::ModelPtr ModelPtr
Definition: LoadModel.cpp:12