ArmNN
 20.02
Split.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
8 
10 
11 #include <boost/test/unit_test.hpp>
12 
13 BOOST_AUTO_TEST_SUITE(TensorflowParser)
14 
15 struct SplitFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
16 {
17  SplitFixture(bool withDimZero=false) {
18  m_Prototext = R"(
19  node {
20  name: "graphInput"
21  op: "Placeholder"
22  attr {
23  key: "dtype"
24  value {
25  type: DT_FLOAT
26  }
27  }
28  attr {
29  key: "shape"
30  value {
31  shape {
32  }
33  }
34  }
35  }
36  node {
37  name: "graphInput2"
38  op: "Placeholder"
39  attr {
40  key: "dtype"
41  value {
42  type: DT_FLOAT
43  }
44  }
45  attr {
46  key: "shape"
47  value {
48  shape {
49  }
50  }
51  }
52  }
53  node {
54  name: "multiplication"
55  op : "Mul"
56  input: "graphInput"
57  input: "graphInput2"
58  attr {
59  key: "T"
60  value {
61  type: DT_FLOAT
62  }
63  }
64  }
65  node {
66  name: "SplitInput"
67  op: "Const"
68  attr {
69  key: "dtype"
70  value {
71  type: DT_INT32
72  }
73  }
74  attr {
75  key: "value"
76  value {
77  tensor {
78  dtype: DT_INT32
79  tensor_shape {
80  }
81  int_val: )";
82 
83  if(withDimZero)
84  {
85  m_Prototext += std::to_string(3);
86  }
87  else
88  {
89  m_Prototext += std::to_string(1);
90  }
91 
92  m_Prototext += R"(
93  }
94  }
95  }
96  }
97  node {
98  name: "Split"
99  op: "Split" )";
100  if(withDimZero)
101  {
102  m_Prototext += "input: \"SplitInput\"\n";
103  m_Prototext += "input: \"multiplication\"\n";
104  }
105  else
106  {
107  m_Prototext += "input: \"graphInput\"\n";
108  m_Prototext += "input: \"SplitInput\"\n";
109  }
110  m_Prototext += R"(
111  attr {
112  key: "num_split"
113  value {
114  i: 2
115  }
116  }
117  }
118  node {
119  name: "Relu_1"
120  op: "Relu"
121  input: "Split:0"
122  attr {
123  key: "T"
124  value {
125  type: DT_FLOAT
126  }
127  }
128  }
129  node {
130  name: "Relu_2"
131  op: "Relu"
132  input:"Split:1"
133  attr {
134  key: "T"
135  value {
136  type: DT_FLOAT
137  }
138  }
139  } )";
140 
141  Setup( { { "graphInput", { 1, 2, 2 , 2} } , { "graphInput2", { 1, 2, 2 , 2} }},
142  { "Relu_1", "Relu_2" });
143  }
144 };
145 
146 struct InputFirstSplitFixture : SplitFixture
147 {
148  InputFirstSplitFixture() : SplitFixture(true) {}
149 };
150 
151 BOOST_FIXTURE_TEST_CASE(ParseAxisOneSplitTwo, SplitFixture)
152 {
153  BOOST_TEST(
154  (m_Parser->GetNetworkOutputBindingInfo("Relu_1").second.GetShape() == armnn::TensorShape({ 1, 1, 2, 2 })));
155 
156  BOOST_TEST(
157  (m_Parser->GetNetworkOutputBindingInfo("Relu_2").second.GetShape() == armnn::TensorShape({ 1, 1, 2, 2 })));
158 
159  RunTest<4>({ { "graphInput", { -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f, 1.75f } } },
160  { { "Relu_1", { 0.0f, 0.0f, 1.25f, 0.0f } },
161  { "Relu_2", { 0.0f, 0.5f, 0.0f, 1.75f } } });
162 }
163 
164 BOOST_FIXTURE_TEST_CASE(ParseSplit, InputFirstSplitFixture)
165 {
166 
167  BOOST_TEST(
168  (m_Parser->GetNetworkOutputBindingInfo("Relu_1").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
169 
170  BOOST_TEST(
171  (m_Parser->GetNetworkOutputBindingInfo("Relu_2").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
172 
173  RunTest<4>({ { "graphInput", { -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f , 1.75f } } ,
174  { "graphInput2", { -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f , 1.75f } } },
175  { { "Relu_1", { 1.0f, 1.5625f, 0, 0.5625f } },
176  { "Relu_2", { 0.25, 9.0f, 0.25f, 3.0625f } } });
177 }
178 
179 struct SplitLastDimFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
180 {
181  SplitLastDimFixture(bool withDimZero=false) {
182  armnn::IgnoreUnused(withDimZero);
183  m_Prototext = R"(
184  node {
185  name: "Placeholder"
186  op: "Placeholder"
187  attr {
188  key: "dtype"
189  value {
190  type: DT_FLOAT
191  }
192  }
193  attr {
194  key: "shape"
195  value {
196  shape {
197  dim {
198  size: 1
199  }
200  dim {
201  size: 2
202  }
203  dim {
204  size: 2
205  }
206  dim {
207  size: 3
208  }
209  }
210  }
211  }
212  }
213  node {
214  name: "Const"
215  op: "Const"
216  attr {
217  key: "dtype"
218  value {
219  type: DT_INT32
220  }
221  }
222  attr {
223  key: "value"
224  value {
225  tensor {
226  dtype: DT_INT32
227  tensor_shape {
228  }
229  int_val: 3
230  }
231  }
232  }
233  }
234  node {
235  name: "split/split_dim"
236  op: "Const"
237  attr {
238  key: "dtype"
239  value {
240  type: DT_INT32
241  }
242  }
243  attr {
244  key: "value"
245  value {
246  tensor {
247  dtype: DT_INT32
248  tensor_shape {
249  }
250  int_val: 3
251  }
252  }
253  }
254  }
255  node {
256  name: "split"
257  op: "Split"
258  input: "split/split_dim"
259  input: "Placeholder"
260  attr {
261  key: "T"
262  value {
263  type: DT_FLOAT
264  }
265  }
266  attr {
267  key: "num_split"
268  value {
269  i: 3
270  }
271  }
272  }
273  node {
274  name: "sub0/y"
275  op: "Const"
276  attr {
277  key: "dtype"
278  value {
279  type: DT_FLOAT
280  }
281  }
282  attr {
283  key: "value"
284  value {
285  tensor {
286  dtype: DT_FLOAT
287  tensor_shape {
288  }
289  float_val: 3.0
290  }
291  }
292  }
293  }
294  node {
295  name: "sub0"
296  op: "Sub"
297  input: "split"
298  input: "sub0/y"
299  attr {
300  key: "T"
301  value {
302  type: DT_FLOAT
303  }
304  }
305  }
306  node {
307  name: "sub1/y"
308  op: "Const"
309  attr {
310  key: "dtype"
311  value {
312  type: DT_FLOAT
313  }
314  }
315  attr {
316  key: "value"
317  value {
318  tensor {
319  dtype: DT_FLOAT
320  tensor_shape {
321  }
322  float_val: 2.0
323  }
324  }
325  }
326  }
327  node {
328  name: "sub1"
329  op: "Sub"
330  input: "split:1"
331  input: "sub1/y"
332  attr {
333  key: "T"
334  value {
335  type: DT_FLOAT
336  }
337  }
338  }
339  node {
340  name: "sub2/y"
341  op: "Const"
342  attr {
343  key: "dtype"
344  value {
345  type: DT_FLOAT
346  }
347  }
348  attr {
349  key: "value"
350  value {
351  tensor {
352  dtype: DT_FLOAT
353  tensor_shape {
354  }
355  float_val: 1.0
356  }
357  }
358  }
359  }
360  node {
361  name: "sub2"
362  op: "Sub"
363  input: "split:2"
364  input: "sub2/y"
365  attr {
366  key: "T"
367  value {
368  type: DT_FLOAT
369  }
370  }
371  }
372  versions {
373  producer: 27
374  } )";
375 
376  Setup( { { "Placeholder", { 1, 2, 2 , 3} } },
377  { "sub0", "sub1", "sub2" });
378  }
379 };
380 
381 BOOST_FIXTURE_TEST_CASE(SplitLastDimTest, SplitLastDimFixture)
382 {
383  BOOST_TEST(
384  (m_Parser->GetNetworkOutputBindingInfo("sub0").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
385 
386  BOOST_TEST(
387  (m_Parser->GetNetworkOutputBindingInfo("sub1").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
388 
389  BOOST_TEST(
390  (m_Parser->GetNetworkOutputBindingInfo("sub2").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
391 
392  RunTest<4>({ { "Placeholder", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f } } },
393  { { "sub0", { -2.0f, 1.0f, 4.0f, 7.0f } },
394  { "sub1", { 0.0f, 3.0f, 6.0f, 9.0f } },
395  { "sub2", { 2.0f, 5.0f, 8.0f, 11.0f } } });
396 }
397 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
void IgnoreUnused(Ts &&...)
DataLayout::NCHW DataLayout::NCHW DataLayout::NHWC DataLayout::NHWC true
BOOST_AUTO_TEST_SUITE_END()
BOOST_FIXTURE_TEST_CASE(ParseAxisOneSplitTwoFloat32, SimpleSplitFixtureFloat32)
Definition: Split.cpp:111