ArmNN
 20.02
GetInputsOutputs.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
#include <boost/test/unit_test.hpp>
#include "../OnnxParser.hpp"
#include "ParserPrototxtFixture.hpp"
#include <armnn/Exceptions.hpp>
#include <onnx/onnx.pb.h>
#include "google/protobuf/stubs/logging.h"
10 
11 
12 using ModelPtr = std::unique_ptr<onnx::ModelProto>;
13 
14 BOOST_AUTO_TEST_SUITE(OnnxParser)
15 
16 struct GetInputsOutputsMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
17 {
18  explicit GetInputsOutputsMainFixture()
19  {
20  m_Prototext = R"(
21  ir_version: 3
22  producer_name: "CNTK"
23  producer_version: "2.5.1"
24  domain: "ai.cntk"
25  model_version: 1
26  graph {
27  name: "CNTKGraph"
28  input {
29  name: "Input"
30  type {
31  tensor_type {
32  elem_type: 1
33  shape {
34  dim {
35  dim_value: 4
36  }
37  }
38  }
39  }
40  }
41  node {
42  input: "Input"
43  output: "Output"
44  name: "ActivationLayer"
45  op_type: "Relu"
46  }
47  output {
48  name: "Output"
49  type {
50  tensor_type {
51  elem_type: 1
52  shape {
53  dim {
54  dim_value: 4
55  }
56  }
57  }
58  }
59  }
60  }
61  opset_import {
62  version: 7
63  })";
64  Setup();
65  }
66 };
67 
68 
69 BOOST_FIXTURE_TEST_CASE(GetInput, GetInputsOutputsMainFixture)
70 {
72  std::vector<std::string> tensors = armnnOnnxParser::OnnxParser::GetInputs(model);
73  BOOST_CHECK_EQUAL(1, tensors.size());
74  BOOST_CHECK_EQUAL("Input", tensors[0]);
75 
76 }
77 
78 BOOST_FIXTURE_TEST_CASE(GetOutput, GetInputsOutputsMainFixture)
79 {
81  std::vector<std::string> tensors = armnnOnnxParser::OnnxParser::GetOutputs(model);
82  BOOST_CHECK_EQUAL(1, tensors.size());
83  BOOST_CHECK_EQUAL("Output", tensors[0]);
84 }
85 
86 struct GetEmptyInputsOutputsFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
87 {
88  GetEmptyInputsOutputsFixture()
89  {
90  m_Prototext = R"(
91  ir_version: 3
92  producer_name: "CNTK "
93  producer_version: "2.5.1 "
94  domain: "ai.cntk "
95  model_version: 1
96  graph {
97  name: "CNTKGraph "
98  node {
99  output: "Output"
100  attribute {
101  name: "value"
102  t {
103  dims: 7
104  data_type: 1
105  float_data: 0.0
106  float_data: 1.0
107  float_data: 2.0
108  float_data: 3.0
109  float_data: 4.0
110  float_data: 5.0
111  float_data: 6.0
112 
113  }
114  type: 1
115  }
116  name: "constantNode"
117  op_type: "Constant"
118  }
119  output {
120  name: "Output"
121  type {
122  tensor_type {
123  elem_type: 1
124  shape {
125  dim {
126  dim_value: 7
127  }
128  }
129  }
130  }
131  }
132  }
133  opset_import {
134  version: 7
135  })";
136  Setup();
137  }
138 };
139 
140 BOOST_FIXTURE_TEST_CASE(GetEmptyInputs, GetEmptyInputsOutputsFixture)
141 {
142  ModelPtr model = armnnOnnxParser::OnnxParser::LoadModelFromString(m_Prototext.c_str());
143  std::vector<std::string> tensors = armnnOnnxParser::OnnxParser::GetInputs(model);
144  BOOST_CHECK_EQUAL(0, tensors.size());
145 }
146 
147 BOOST_AUTO_TEST_CASE(GetInputsNullModel)
148 {
150 }
151 
152 BOOST_AUTO_TEST_CASE(GetOutputsNullModel)
153 {
154  auto silencer = google::protobuf::LogSilencer(); //get rid of errors from protobuf
156 }
157 
158 struct GetInputsMultipleFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
159 {
160  GetInputsMultipleFixture() {
161 
162  m_Prototext = R"(
163  ir_version: 3
164  producer_name: "CNTK"
165  producer_version: "2.5.1"
166  domain: "ai.cntk"
167  model_version: 1
168  graph {
169  name: "CNTKGraph"
170  input {
171  name: "Input0"
172  type {
173  tensor_type {
174  elem_type: 1
175  shape {
176  dim {
177  dim_value: 1
178  }
179  dim {
180  dim_value: 1
181  }
182  dim {
183  dim_value: 1
184  }
185  dim {
186  dim_value: 4
187  }
188  }
189  }
190  }
191  }
192  input {
193  name: "Input1"
194  type {
195  tensor_type {
196  elem_type: 1
197  shape {
198  dim {
199  dim_value: 4
200  }
201  }
202  }
203  }
204  }
205  node {
206  input: "Input0"
207  input: "Input1"
208  output: "Output"
209  name: "addition"
210  op_type: "Add"
211  doc_string: ""
212  domain: ""
213  }
214  output {
215  name: "Output"
216  type {
217  tensor_type {
218  elem_type: 1
219  shape {
220  dim {
221  dim_value: 1
222  }
223  dim {
224  dim_value: 1
225  }
226  dim {
227  dim_value: 1
228  }
229  dim {
230  dim_value: 4
231  }
232  }
233  }
234  }
235  }
236  }
237  opset_import {
238  version: 7
239  })";
240  Setup();
241  }
242 };
243 
244 BOOST_FIXTURE_TEST_CASE(GetInputsMultipleInputs, GetInputsMultipleFixture)
245 {
246  ModelPtr model = armnnOnnxParser::OnnxParser::LoadModelFromString(m_Prototext.c_str());
247  std::vector<std::string> tensors = armnnOnnxParser::OnnxParser::GetInputs(model);
248  BOOST_CHECK_EQUAL(2, tensors.size());
249  BOOST_CHECK_EQUAL("Input0", tensors[0]);
250  BOOST_CHECK_EQUAL("Input1", tensors[1]);
251 }
252 
253 
254 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
BOOST_AUTO_TEST_CASE(GetInputsNullModel)
BOOST_FIXTURE_TEST_CASE(GetInput, GetInputsOutputsMainFixture)
std::unique_ptr< onnx::ModelProto > ModelPtr
static std::vector< std::string > GetInputs(ModelPtr &model)
Retrieve inputs names.
static std::vector< std::string > GetOutputs(ModelPtr &model)
Retrieve outputs names.
BOOST_AUTO_TEST_SUITE_END()
static ModelPtr LoadModelFromString(const std::string &inputString)
Definition: OnnxParser.cpp:564