ArmNN  NotReleased
Split.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <boost/test/unit_test.hpp>
9 
10 BOOST_AUTO_TEST_SUITE(TensorflowParser)
11 
12 struct SplitFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
13 {
14  SplitFixture(bool withDimZero=false) {
15  m_Prototext = R"(
16  node {
17  name: "graphInput"
18  op: "Placeholder"
19  attr {
20  key: "dtype"
21  value {
22  type: DT_FLOAT
23  }
24  }
25  attr {
26  key: "shape"
27  value {
28  shape {
29  }
30  }
31  }
32  }
33  node {
34  name: "graphInput2"
35  op: "Placeholder"
36  attr {
37  key: "dtype"
38  value {
39  type: DT_FLOAT
40  }
41  }
42  attr {
43  key: "shape"
44  value {
45  shape {
46  }
47  }
48  }
49  }
50  node {
51  name: "multiplication"
52  op : "Mul"
53  input: "graphInput"
54  input: "graphInput2"
55  attr {
56  key: "T"
57  value {
58  type: DT_FLOAT
59  }
60  }
61  }
62  node {
63  name: "SplitInput"
64  op: "Const"
65  attr {
66  key: "dtype"
67  value {
68  type: DT_INT32
69  }
70  }
71  attr {
72  key: "value"
73  value {
74  tensor {
75  dtype: DT_INT32
76  tensor_shape {
77  }
78  int_val: )";
79 
80  if(withDimZero)
81  {
82  m_Prototext += std::to_string(3);
83  }
84  else
85  {
86  m_Prototext += std::to_string(1);
87  }
88 
89  m_Prototext += R"(
90  }
91  }
92  }
93  }
94  node {
95  name: "Split"
96  op: "Split" )";
97  if(withDimZero)
98  {
99  m_Prototext += "input: \"SplitInput\"\n";
100  m_Prototext += "input: \"multiplication\"\n";
101  }
102  else
103  {
104  m_Prototext += "input: \"graphInput\"\n";
105  m_Prototext += "input: \"SplitInput\"\n";
106  }
107  m_Prototext += R"(
108  attr {
109  key: "num_split"
110  value {
111  i: 2
112  }
113  }
114  }
115  node {
116  name: "Relu_1"
117  op: "Relu"
118  input: "Split:0"
119  attr {
120  key: "T"
121  value {
122  type: DT_FLOAT
123  }
124  }
125  }
126  node {
127  name: "Relu_2"
128  op: "Relu"
129  input:"Split:1"
130  attr {
131  key: "T"
132  value {
133  type: DT_FLOAT
134  }
135  }
136  } )";
137 
138  Setup( { { "graphInput", { 1, 2, 2 , 2} } , { "graphInput2", { 1, 2, 2 , 2} }},
139  { "Relu_1", "Relu_2" });
140  }
141 };
142 
143 struct InputFirstSplitFixture : SplitFixture
144 {
145  InputFirstSplitFixture() : SplitFixture(true) {}
146 };
147 
148 BOOST_FIXTURE_TEST_CASE(ParseAxisOneSplitTwo, SplitFixture)
149 {
150  BOOST_TEST(
151  (m_Parser->GetNetworkOutputBindingInfo("Relu_1").second.GetShape() == armnn::TensorShape({ 1, 1, 2, 2 })));
152 
153  BOOST_TEST(
154  (m_Parser->GetNetworkOutputBindingInfo("Relu_2").second.GetShape() == armnn::TensorShape({ 1, 1, 2, 2 })));
155 
156  RunTest<4>({ { "graphInput", { -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f, 1.75f } } },
157  { { "Relu_1", { 0.0f, 0.0f, 1.25f, 0.0f } },
158  { "Relu_2", { 0.0f, 0.5f, 0.0f, 1.75f } } });
159 }
160 
161 BOOST_FIXTURE_TEST_CASE(ParseSplit, InputFirstSplitFixture)
162 {
163 
164  BOOST_TEST(
165  (m_Parser->GetNetworkOutputBindingInfo("Relu_1").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
166 
167  BOOST_TEST(
168  (m_Parser->GetNetworkOutputBindingInfo("Relu_2").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
169 
170  RunTest<4>({ { "graphInput", { -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f , 1.75f } } ,
171  { "graphInput2", { -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f , 1.75f } } },
172  { { "Relu_1", { 1.0f, 1.5625f, 0, 0.5625f } },
173  { "Relu_2", { 0.25, 9.0f, 0.25f, 3.0625f } } });
174 }
175 
176 struct SplitLastDimFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
177 {
178  SplitLastDimFixture(bool withDimZero=false) {
179  boost::ignore_unused(withDimZero);
180  m_Prototext = R"(
181  node {
182  name: "Placeholder"
183  op: "Placeholder"
184  attr {
185  key: "dtype"
186  value {
187  type: DT_FLOAT
188  }
189  }
190  attr {
191  key: "shape"
192  value {
193  shape {
194  dim {
195  size: 1
196  }
197  dim {
198  size: 2
199  }
200  dim {
201  size: 2
202  }
203  dim {
204  size: 3
205  }
206  }
207  }
208  }
209  }
210  node {
211  name: "Const"
212  op: "Const"
213  attr {
214  key: "dtype"
215  value {
216  type: DT_INT32
217  }
218  }
219  attr {
220  key: "value"
221  value {
222  tensor {
223  dtype: DT_INT32
224  tensor_shape {
225  }
226  int_val: 3
227  }
228  }
229  }
230  }
231  node {
232  name: "split/split_dim"
233  op: "Const"
234  attr {
235  key: "dtype"
236  value {
237  type: DT_INT32
238  }
239  }
240  attr {
241  key: "value"
242  value {
243  tensor {
244  dtype: DT_INT32
245  tensor_shape {
246  }
247  int_val: 3
248  }
249  }
250  }
251  }
252  node {
253  name: "split"
254  op: "Split"
255  input: "split/split_dim"
256  input: "Placeholder"
257  attr {
258  key: "T"
259  value {
260  type: DT_FLOAT
261  }
262  }
263  attr {
264  key: "num_split"
265  value {
266  i: 3
267  }
268  }
269  }
270  node {
271  name: "sub0/y"
272  op: "Const"
273  attr {
274  key: "dtype"
275  value {
276  type: DT_FLOAT
277  }
278  }
279  attr {
280  key: "value"
281  value {
282  tensor {
283  dtype: DT_FLOAT
284  tensor_shape {
285  }
286  float_val: 3.0
287  }
288  }
289  }
290  }
291  node {
292  name: "sub0"
293  op: "Sub"
294  input: "split"
295  input: "sub0/y"
296  attr {
297  key: "T"
298  value {
299  type: DT_FLOAT
300  }
301  }
302  }
303  node {
304  name: "sub1/y"
305  op: "Const"
306  attr {
307  key: "dtype"
308  value {
309  type: DT_FLOAT
310  }
311  }
312  attr {
313  key: "value"
314  value {
315  tensor {
316  dtype: DT_FLOAT
317  tensor_shape {
318  }
319  float_val: 2.0
320  }
321  }
322  }
323  }
324  node {
325  name: "sub1"
326  op: "Sub"
327  input: "split:1"
328  input: "sub1/y"
329  attr {
330  key: "T"
331  value {
332  type: DT_FLOAT
333  }
334  }
335  }
336  node {
337  name: "sub2/y"
338  op: "Const"
339  attr {
340  key: "dtype"
341  value {
342  type: DT_FLOAT
343  }
344  }
345  attr {
346  key: "value"
347  value {
348  tensor {
349  dtype: DT_FLOAT
350  tensor_shape {
351  }
352  float_val: 1.0
353  }
354  }
355  }
356  }
357  node {
358  name: "sub2"
359  op: "Sub"
360  input: "split:2"
361  input: "sub2/y"
362  attr {
363  key: "T"
364  value {
365  type: DT_FLOAT
366  }
367  }
368  }
369  versions {
370  producer: 27
371  } )";
372 
373  Setup( { { "Placeholder", { 1, 2, 2 , 3} } },
374  { "sub0", "sub1", "sub2" });
375  }
376 };
377 
378 BOOST_FIXTURE_TEST_CASE(SplitLastDimTest, SplitLastDimFixture)
379 {
380  BOOST_TEST(
381  (m_Parser->GetNetworkOutputBindingInfo("sub0").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
382 
383  BOOST_TEST(
384  (m_Parser->GetNetworkOutputBindingInfo("sub1").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
385 
386  BOOST_TEST(
387  (m_Parser->GetNetworkOutputBindingInfo("sub2").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
388 
389  RunTest<4>({ { "Placeholder", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f } } },
390  { { "sub0", { -2.0f, 1.0f, 4.0f, 7.0f } },
391  { "sub1", { 0.0f, 3.0f, 6.0f, 9.0f } },
392  { "sub2", { 2.0f, 5.0f, 8.0f, 11.0f } } });
393 }
394 
BOOST_FIXTURE_TEST_CASE(ParseAxisOneSplitTwoFloat32, SimpleSplitFixtureFloat32)
Definition: Split.cpp:111
DataLayout::NCHW DataLayout::NCHW DataLayout::NHWC DataLayout::NHWC true
BOOST_AUTO_TEST_SUITE_END()
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)