ArmNN
 20.11
ConcatOfConcats.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <boost/test/unit_test.hpp>
9 
10 BOOST_AUTO_TEST_SUITE(TensorflowParser)
11 
12 struct ConcatOfConcatsFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
13 {
14  explicit ConcatOfConcatsFixture(const armnn::TensorShape& inputShape0, const armnn::TensorShape& inputShape1,
15  const armnn::TensorShape& inputShape2, const armnn::TensorShape& inputShape3,
16  unsigned int concatDim)
17  {
18  m_Prototext = R"(
19  node {
20  name: "graphInput0"
21  op: "Placeholder"
22  attr {
23  key: "dtype"
24  value {
25  type: DT_FLOAT
26  }
27  }
28  attr {
29  key: "shape"
30  value {
31  shape {
32  }
33  }
34  }
35  }
36  node {
37  name: "graphInput1"
38  op: "Placeholder"
39  attr {
40  key: "dtype"
41  value {
42  type: DT_FLOAT
43  }
44  }
45  attr {
46  key: "shape"
47  value {
48  shape {
49  }
50  }
51  }
52  }
53  node {
54  name: "graphInput2"
55  op: "Placeholder"
56  attr {
57  key: "dtype"
58  value {
59  type: DT_FLOAT
60  }
61  }
62  attr {
63  key: "shape"
64  value {
65  shape {
66  }
67  }
68  }
69  }
70  node {
71  name: "graphInput3"
72  op: "Placeholder"
73  attr {
74  key: "dtype"
75  value {
76  type: DT_FLOAT
77  }
78  }
79  attr {
80  key: "shape"
81  value {
82  shape {
83  }
84  }
85  }
86  }
87  node {
88  name: "Relu"
89  op: "Relu"
90  input: "graphInput0"
91  attr {
92  key: "T"
93  value {
94  type: DT_FLOAT
95  }
96  }
97  }
98  node {
99  name: "Relu_1"
100  op: "Relu"
101  input: "graphInput1"
102  attr {
103  key: "T"
104  value {
105  type: DT_FLOAT
106  }
107  }
108  }
109  node {
110  name: "Relu_2"
111  op: "Relu"
112  input: "graphInput2"
113  attr {
114  key: "T"
115  value {
116  type: DT_FLOAT
117  }
118  }
119  }
120  node {
121  name: "Relu_3"
122  op: "Relu"
123  input: "graphInput3"
124  attr {
125  key: "T"
126  value {
127  type: DT_FLOAT
128  }
129  }
130  }
131  node {
132  name: "concat/axis"
133  op: "Const"
134  attr {
135  key: "dtype"
136  value {
137  type: DT_INT32
138  }
139  }
140  attr {
141  key: "value"
142  value {
143  tensor {
144  dtype: DT_INT32
145  tensor_shape {
146  }
147  int_val: )";
148  m_Prototext += std::to_string(concatDim);
149  m_Prototext += R"(
150  }
151  }
152  }
153  }
154  node {
155  name: "concat"
156  op: "ConcatV2"
157  input: "Relu"
158  input: "Relu_1"
159  input: "concat/axis"
160  attr {
161  key: "N"
162  value {
163  i: 2
164  }
165  }
166  attr {
167  key: "T"
168  value {
169  type: DT_FLOAT
170  }
171  }
172  attr {
173  key: "Tidx"
174  value {
175  type: DT_INT32
176  }
177  }
178  }
179  node {
180  name: "concat_1/axis"
181  op: "Const"
182  attr {
183  key: "dtype"
184  value {
185  type: DT_INT32
186  }
187  }
188  attr {
189  key: "value"
190  value {
191  tensor {
192  dtype: DT_INT32
193  tensor_shape {
194  }
195  int_val: )";
196  m_Prototext += std::to_string(concatDim);
197  m_Prototext += R"(
198  }
199  }
200  }
201  }
202  node {
203  name: "concat_1"
204  op: "ConcatV2"
205  input: "Relu_2"
206  input: "Relu_3"
207  input: "concat_1/axis"
208  attr {
209  key: "N"
210  value {
211  i: 2
212  }
213  }
214  attr {
215  key: "T"
216  value {
217  type: DT_FLOAT
218  }
219  }
220  attr {
221  key: "Tidx"
222  value {
223  type: DT_INT32
224  }
225  }
226  }
227  node {
228  name: "concat_2/axis"
229  op: "Const"
230  attr {
231  key: "dtype"
232  value {
233  type: DT_INT32
234  }
235  }
236  attr {
237  key: "value"
238  value {
239  tensor {
240  dtype: DT_INT32
241  tensor_shape {
242  }
243  int_val: )";
244  m_Prototext += std::to_string(concatDim);
245  m_Prototext += R"(
246  }
247  }
248  }
249  }
250  node {
251  name: "concat_2"
252  op: "ConcatV2"
253  input: "concat"
254  input: "concat_1"
255  input: "concat_2/axis"
256  attr {
257  key: "N"
258  value {
259  i: 2
260  }
261  }
262  attr {
263  key: "T"
264  value {
265  type: DT_FLOAT
266  }
267  }
268  attr {
269  key: "Tidx"
270  value {
271  type: DT_INT32
272  }
273  }
274  }
275  )";
276 
277  Setup({{ "graphInput0", inputShape0 },
278  { "graphInput1", inputShape1 },
279  { "graphInput2", inputShape2 },
280  { "graphInput3", inputShape3}}, {"concat_2"});
281  }
282 };
283 
284 struct ConcatOfConcatsFixtureNCHW : ConcatOfConcatsFixture
285 {
286  ConcatOfConcatsFixtureNCHW() : ConcatOfConcatsFixture({ 1, 1, 2, 2 }, { 1, 1, 2, 2 }, { 1, 1, 2, 2 },
287  { 1, 1, 2, 2 }, 1 ) {}
288 };
289 
290 struct ConcatOfConcatsFixtureNHWC : ConcatOfConcatsFixture
291 {
292  ConcatOfConcatsFixtureNHWC() : ConcatOfConcatsFixture({ 1, 1, 2, 2 }, { 1, 1, 2, 2 }, { 1, 1, 2, 2 },
293  { 1, 1, 2, 2 }, 3 ) {}
294 };
295 
296 BOOST_FIXTURE_TEST_CASE(ParseConcatOfConcatsNCHW, ConcatOfConcatsFixtureNCHW)
297 {
298  RunTest<4>({{"graphInput0", {0.0, 1.0, 2.0, 3.0}},
299  {"graphInput1", {4.0, 5.0, 6.0, 7.0}},
300  {"graphInput2", {8.0, 9.0, 10.0, 11.0}},
301  {"graphInput3", {12.0, 13.0, 14.0, 15.0}}},
302  {{"concat_2", { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
303  8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0 }}});
304 }
305 
306 BOOST_FIXTURE_TEST_CASE(ParseConcatOfConcatsNHWC, ConcatOfConcatsFixtureNHWC)
307 {
308  RunTest<4>({{"graphInput0", {0.0, 1.0, 2.0, 3.0}},
309  {"graphInput1", {4.0, 5.0, 6.0, 7.0}},
310  {"graphInput2", {8.0, 9.0, 10.0, 11.0}},
311  {"graphInput3", {12.0, 13.0, 14.0, 15.0}}},
312  {{"concat_2", { 0.0, 1.0, 4.0, 5.0, 8.0, 9.0, 12.0, 13.0,
313  2.0, 3.0, 6.0, 7.0, 10.0, 11.0, 14.0, 15.0 }}});
314 }
315 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
BOOST_AUTO_TEST_SUITE_END()
BOOST_FIXTURE_TEST_CASE(ParseConcatOfConcatsNCHW, ConcatOfConcatsFixtureNCHW)