ArmNN
 21.02
Constant.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
#include <boost/test/unit_test.hpp>

#include <locale>
#include <sstream>
#include <string>
7 
9 
11 
12 BOOST_AUTO_TEST_SUITE(TensorflowParser)
13 
14 // Tests that a Const node in Tensorflow can be converted to a ConstLayer in armnn (as opposed to most
15 // Const nodes which are used as weight inputs for convolutions etc. and are therefore not converted to
16 // armnn ConstLayers).
17 struct ConstantFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
18 {
19  ConstantFixture()
20  {
21  // Input = tf.placeholder(tf.float32, name = "input")
22  // Const = tf.constant([17], tf.float32, [1])
23  // Output = tf.add(input, const, name = "output")
24  m_Prototext =
25  R"(
26 node {
27  name: "input"
28  op: "Placeholder"
29  attr {
30  key: "dtype"
31  value {
32  type: DT_FLOAT
33  }
34  }
35  attr {
36  key: "shape"
37  value {
38  shape {
39  unknown_rank: true
40  }
41  }
42  }
43 }
44 node {
45  name: "Const"
46  op: "Const"
47  attr {
48  key: "dtype"
49  value {
50  type: DT_FLOAT
51  }
52  }
53  attr {
54  key: "value"
55  value {
56  tensor {
57  dtype: DT_FLOAT
58  tensor_shape {
59  dim {
60  size: 1
61  }
62  }
63  float_val: 17.0
64  }
65  }
66  }
67 }
68 node {
69  name: "output"
70  op: "Add"
71  input: "input"
72  input: "Const"
73  attr {
74  key: "T"
75  value {
76  type: DT_FLOAT
77  }
78  }
79 }
80  )";
81  SetupSingleInputSingleOutput({ 1 }, "input", "output");
82  }
83 };
84 
85 BOOST_FIXTURE_TEST_CASE(Constant, ConstantFixture)
86 {
87  RunTest<1>({1}, {18});
88 }
89 
90 
91 // Tests that a single Const node in Tensorflow can be used twice by a dependant node. This should result in only
92 // a single armnn ConstLayer being created.
93 struct ConstantReusedFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
94 {
95  ConstantReusedFixture()
96  {
97  // Const = tf.constant([17], tf.float32, [1])
98  // Output = tf.add(const, const, name = "output")
99  m_Prototext =
100  R"(
101 node {
102  name: "Const"
103  op: "Const"
104  attr {
105  key: "dtype"
106  value {
107  type: DT_FLOAT
108  }
109  }
110  attr {
111  key: "value"
112  value {
113  tensor {
114  dtype: DT_FLOAT
115  tensor_shape {
116  dim {
117  size: 1
118  }
119  }
120  float_val: 17.0
121  }
122  }
123  }
124 }
125 node {
126  name: "output"
127  op: "Add"
128  input: "Const"
129  input: "Const"
130  attr {
131  key: "T"
132  value {
133  type: DT_FLOAT
134  }
135  }
136 }
137  )";
138  Setup({}, { "output" });
139  }
140 };
141 
142 BOOST_FIXTURE_TEST_CASE(ConstantReused, ConstantReusedFixture)
143 {
144  RunTest<1>({}, { { "output", { 34 } } });
145 }
146 
147 template <int ListSize>
148 struct ConstantValueListFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
149 {
150  ConstantValueListFixture()
151  {
152  m_Prototext =
153  R"(
154 node {
155  name: "output"
156  op: "Const"
157  attr {
158  key: "dtype"
159  value {
160  type: DT_FLOAT
161  }
162  }
163  attr {
164  key: "value"
165  value {
166  tensor {
167  dtype: DT_FLOAT
168  tensor_shape {
169  dim {
170  size: 2
171  }
172  dim {
173  size: 3
174  }
175  })";
176 
177  double value = 0.75;
178  for (int i = 0; i < ListSize; i++, value += 0.25)
179  {
180  m_Prototext += std::string("float_val : ") + std::to_string(value) + "\n";
181  }
182 
183  m_Prototext +=
184  R"(
185  }
186  }
187  }
188 }
189  )";
190  Setup({}, { "output" });
191  }
192 };
193 
194 using ConstantSingleValueListFixture = ConstantValueListFixture<1>;
195 using ConstantMultipleValueListFixture = ConstantValueListFixture<4>;
196 using ConstantMaxValueListFixture = ConstantValueListFixture<6>;
197 
199 {
200  RunTest<2>({}, { { "output", { 0.75f, 0.75f, 0.75f, 0.75f, 0.75f, 0.75f } } });
201 }
203 {
204  RunTest<2>({}, { { "output", { 0.75f, 1.f, 1.25f, 1.5f, 1.5f, 1.5f } } });
205 }
207 {
208  RunTest<2>({}, { { "output", { 0.75f, 1.f, 1.25f, 1.50f, 1.75f, 2.f } } });
209 }
210 
211 template <bool WithShape, bool WithContent, bool WithValueList>
212 struct ConstantCreateFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
213 {
214  ConstantCreateFixture()
215  {
216  m_Prototext =
217  R"(
218 node {
219  name: "output"
220  op: "Const"
221  attr {
222  key: "dtype"
223  value {
224  type: DT_FLOAT
225  }
226  }
227  attr {
228  key: "value"
229  value {
230  tensor {
231  dtype: DT_FLOAT
232  )";
233 
234  if (WithShape)
235  {
236  m_Prototext +=
237  R"(
238 tensor_shape {
239  dim {
240  size: 2
241  }
242  dim {
243  size: 2
244  }
245 }
246  )";
247  }
248  else
249  {
250  m_Prototext +=
251  R"(
252 tensor_shape {
253 }
254  )";
255  }
256 
257  if (WithContent)
258  {
259  m_Prototext +=
260  R"(
261 tensor_content: "\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?"
262  )";
263  }
264 
265  if (WithValueList)
266  {
267  m_Prototext +=
268  R"(
269 float_val: 1.0
270 float_val: 1.0
271 float_val: 1.0
272 float_val: 1.0
273 float_val: 1.0
274  )";
275  }
276 
277  m_Prototext +=
278  R"(
279  }
280  }
281  }
282 }
283  )";
284  }
285 };
286 
287 using ConstantCreateNoValueListFixture = ConstantCreateFixture<true, false, true>;
288 using ConstantCreateNoValueList2Fixture = ConstantCreateFixture<true, false, false>;
289 using ConstantCreateNoContentFixture = ConstantCreateFixture<true, true, false>;
290 using ConstantCreateNoContent2Fixture = ConstantCreateFixture<true, false, false>;
291 using ConstantCreateNoShapeFixture = ConstantCreateFixture<false, false, false>;
292 using ConstantCreateNoShape2Fixture = ConstantCreateFixture<false, true, false>;
293 using ConstantCreateNoShape3Fixture = ConstantCreateFixture<false, false, true>;
294 
296 {
297  BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException);
298 }
300 {
301  BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException);
302 }
304 {
305  BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException);
306 }
308 {
309  BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException);
310 }
312 {
313  BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException);
314 }
316 {
317  Setup({}, { "output" });
318  RunTest<1>({}, { { "output", { 1.f, 1.f, 1.f, 1.f, 1.f } } });
319 }
320 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
ConstantCreateFixture< true, true, false > ConstantCreateNoContentFixture
Definition: Constant.cpp:289
ConstantCreateFixture< true, false, false > ConstantCreateNoContent2Fixture
Definition: Constant.cpp:290
ConstantCreateFixture< false, true, false > ConstantCreateNoShape2Fixture
Definition: Constant.cpp:292
ConstantValueListFixture< 6 > ConstantMaxValueListFixture
Definition: Constant.cpp:196
ConstantCreateFixture< true, false, true > ConstantCreateNoValueListFixture
Definition: Constant.cpp:287
ConstantValueListFixture< 1 > ConstantSingleValueListFixture
Definition: Constant.cpp:194
ConstantCreateFixture< false, false, true > ConstantCreateNoShape3Fixture
Definition: Constant.cpp:293
ConstantCreateFixture< false, false, false > ConstantCreateNoShapeFixture
Definition: Constant.cpp:291
BOOST_FIXTURE_TEST_CASE(SimpleConstantAdd, SimpleConstantAddFixture)
Definition: Constant.cpp:104
BOOST_AUTO_TEST_SUITE_END()
void SetupSingleInputSingleOutput(const std::string &inputName, const std::string &outputName)
Parses and loads the network defined by the m_Prototext string.
ConstantValueListFixture< 4 > ConstantMultipleValueListFixture
Definition: Constant.cpp:195
ConstantCreateFixture< true, false, false > ConstantCreateNoValueList2Fixture
Definition: Constant.cpp:288