ArmNN
 21.02
TestDependencies.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <boost/test/unit_test.hpp>
9 
10 BOOST_AUTO_TEST_SUITE(TensorflowParser)
11 
12 // Graph which tests that nodes are re-ordered in the queue when they are encountered a second time.
13 // In this case R0 will be encountered first via R1 and then via R2. At that time
14 // we need to make sure that R0 (and the I on which it is dependent) is moved to the front again
15 // so that it is before both R1 and R2.
16 // I
17 // |
18 // R0
19 // / \'
20 // R1 R2
21 // \ |
22 // \ R3
23 // \|
24 // O
25 struct RediscoveredDependenciesFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
26 {
27  RediscoveredDependenciesFixture()
28  {
29  // Input = tf.placeholder(tf.float32, 1, "input")
30  // Relu0 = tf.nn.relu(input, "relu0")
31  // Relu1 = tf.nn.relu(relu0, "relu1")
32  // Relu2 = tf.nn.relu(relu0, "relu2")
33  // Relu3 = tf.nn.relu(relu2, "relu3")
34  // Output = tf.add(relu1, relu3, "output")
35  m_Prototext = R"(
36  node {
37  name: "input"
38  op: "Placeholder"
39  attr {
40  key: "dtype"
41  value {
42  type: DT_FLOAT
43  }
44  }
45  attr {
46  key: "shape"
47  value {
48  shape {
49  dim {
50  size: 1
51  }
52  }
53  }
54  }
55  }
56  node {
57  name: "relu0"
58  op: "Relu"
59  input: "input"
60  attr {
61  key: "T"
62  value {
63  type: DT_FLOAT
64  }
65  }
66  }
67  node {
68  name: "relu1"
69  op: "Relu"
70  input: "relu0"
71  attr {
72  key: "T"
73  value {
74  type: DT_FLOAT
75  }
76  }
77  }
78  node {
79  name: "relu2"
80  op: "Relu"
81  input: "relu0"
82  attr {
83  key: "T"
84  value {
85  type: DT_FLOAT
86  }
87  }
88  }
89  node {
90  name: "relu3"
91  op: "Relu"
92  input: "relu2"
93  attr {
94  key: "T"
95  value {
96  type: DT_FLOAT
97  }
98  }
99  }
100  node {
101  name: "output"
102  op: "Add"
103  input: "relu1"
104  input: "relu3"
105  attr {
106  key: "T"
107  value {
108  type: DT_FLOAT
109  }
110  }
111  }
112  )";
113  SetupSingleInputSingleOutput({ 1 }, "input", "output");
114  }
115 };
116 
117 BOOST_FIXTURE_TEST_CASE(RediscoveredDependencies, RediscoveredDependenciesFixture)
118 {
119  RunTest<1>({1}, {2});
120 }
121 
122 // Tests that a simple cycle in the tensorflow graph will be detected and an exception thrown, rather than the TfParser
123 // getting stuck in an infinite loop.
125 {
126  const char* prototext = R"(
127 node {
128  name: "r1"
129  op: "Relu"
130  input: "r2"
131  attr {
132  key: "T"
133  value {
134  type: DT_FLOAT
135  }
136  }
137 }
138 node {
139  name: "r2"
140  op: "Relu"
141  input: "r1"
142  attr {
143  key: "T"
144  value {
145  type: DT_FLOAT
146  }
147  }
148 }
149  )";
151  BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "r2" }), armnn::ParseException);
152 }
153 
154 // Similar to the above SimpleCycle test, but has a single node which connects to itself.
155 BOOST_AUTO_TEST_CASE(SingleNodeCycle)
156 {
157  const char* prototext = R"(
158 node {
159  name: "r1"
160  op: "Relu"
161  input: "r1"
162  attr {
163  key: "T"
164  value {
165  type: DT_FLOAT
166  }
167  }
168 }
169  )";
171  BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "r1" }), armnn::ParseException);
172 }
173 
174 // Similar to the above SimpleCycle test, but with a more complicated graph.
175 // I
176 // |
177 // A2---<---<-
178 // / \' |
179 // R1 R2 |
180 // \ | |
181 // \ R3 |
182 // \| |
183 // A1-->--->|
184 //
185 BOOST_AUTO_TEST_CASE(ComplexCycle)
186 {
187  // Input = tf.placeholder(tf.float32, 1, "input")
188  // Add2 = tf.nn.relu(input, add1, "add2") // This line won't actually run in TF, because add1 is not yet defined
189  // Relu1 = tf.nn.relu(relu0, "relu1")
190  // Relu2 = tf.nn.relu(relu0, "relu2")
191  // Relu3 = tf.nn.relu(relu2, "relu3")
192  // Add1 = tf.add(relu1, relu3, "add1")
193  const char* prototext = R"(
194  node {
195  name: "input"
196  op: "Placeholder"
197  attr {
198  key: "dtype"
199  value {
200  type: DT_FLOAT
201  }
202  }
203  attr {
204  key: "shape"
205  value {
206  shape {
207  dim {
208  size: 1
209  }
210  }
211  }
212  }
213  }
214  node {
215  name: "add2"
216  op: "Add"
217  input: "input"
218  input: "add1"
219  attr {
220  key: "T"
221  value {
222  type: DT_FLOAT
223  }
224  }
225  }
226  node {
227  name: "relu1"
228  op: "Relu"
229  input: "add2"
230  attr {
231  key: "T"
232  value {
233  type: DT_FLOAT
234  }
235  }
236  }
237  node {
238  name: "relu2"
239  op: "Relu"
240  input: "add2"
241  attr {
242  key: "T"
243  value {
244  type: DT_FLOAT
245  }
246  }
247  }
248  node {
249  name: "relu3"
250  op: "Relu"
251  input: "relu2"
252  attr {
253  key: "T"
254  value {
255  type: DT_FLOAT
256  }
257  }
258  }
259  node {
260  name: "add1"
261  op: "Add"
262  input: "relu1"
263  input: "relu3"
264  attr {
265  key: "T"
266  value {
267  type: DT_FLOAT
268  }
269  }
270  }
271  )";
273  BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "add1" }), armnn::ParseException);
274 }
275 
276 // Tests that a graph with an input that is not present throws a ParseException.
277 BOOST_AUTO_TEST_CASE(InvalidInput)
278 {
279  const char* prototext = R"(
280 node {
281  name: "r1"
282  op: "Relu"
283  input: "a-node-that-does-not-exist"
284  attr {
285  key: "T"
286  value {
287  type: DT_FLOAT
288  }
289  }
290 }
291  )";
293  BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "r1" }), armnn::ParseException);
294 }
295 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
std::unique_ptr< ITfParser, void(*)(ITfParser *parser)> ITfParserPtr
Definition: ITfParser.hpp:22
BOOST_AUTO_TEST_CASE(SimpleCycle)
BOOST_FIXTURE_TEST_CASE(RediscoveredDependencies, RediscoveredDependenciesFixture)
BOOST_AUTO_TEST_SUITE_END()
static ITfParserPtr Create()
Definition: TfParser.cpp:48
void SetupSingleInputSingleOutput(const std::string &inputName, const std::string &outputName)
Parses and loads the network defined by the m_Prototext string.