ArmNN
 21.02
Split.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <boost/test/unit_test.hpp>
8 #include "../TfLiteParser.hpp"
9 
10 #include <string>
11 #include <iostream>
12 
13 BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
14 
15 struct SplitFixture : public ParserFlatbuffersFixture
16 {
17  explicit SplitFixture(const std::string& inputShape,
18  const std::string& axisShape,
19  const std::string& numSplits,
20  const std::string& outputShape1,
21  const std::string& outputShape2,
22  const std::string& axisData,
23  const std::string& dataType)
24  {
25  m_JsonString = R"(
26  {
27  "version": 3,
28  "operator_codes": [ { "builtin_code": "SPLIT" } ],
29  "subgraphs": [ {
30  "tensors": [
31  {
32  "shape": )" + inputShape + R"(,
33  "type": )" + dataType + R"(,
34  "buffer": 0,
35  "name": "inputTensor",
36  "quantization": {
37  "min": [ 0.0 ],
38  "max": [ 255.0 ],
39  "scale": [ 1.0 ],
40  "zero_point": [ 0 ],
41  }
42  },
43  {
44  "shape": )" + axisShape + R"(,
45  "type": "INT32",
46  "buffer": 1,
47  "name": "axis",
48  "quantization": {
49  "min": [ 0.0 ],
50  "max": [ 255.0 ],
51  "scale": [ 1.0 ],
52  "zero_point": [ 0 ],
53  }
54  },
55  {
56  "shape": )" + outputShape1 + R"( ,
57  "type":)" + dataType + R"(,
58  "buffer": 2,
59  "name": "outputTensor1",
60  "quantization": {
61  "min": [ 0.0 ],
62  "max": [ 255.0 ],
63  "scale": [ 1.0 ],
64  "zero_point": [ 0 ],
65  }
66  },
67  {
68  "shape": )" + outputShape2 + R"( ,
69  "type":)" + dataType + R"(,
70  "buffer": 3,
71  "name": "outputTensor2",
72  "quantization": {
73  "min": [ 0.0 ],
74  "max": [ 255.0 ],
75  "scale": [ 1.0 ],
76  "zero_point": [ 0 ],
77  }
78  }
79  ],
80  "inputs": [ 0 ],
81  "outputs": [ 2, 3 ],
82  "operators": [
83  {
84  "opcode_index": 0,
85  "inputs": [ 1, 0 ],
86  "outputs": [ 2, 3 ],
87  "builtin_options_type": "SplitOptions",
88  "builtin_options": {
89  "num_splits": )" + numSplits + R"(
90  },
91  "custom_options_format": "FLEXBUFFERS"
92  }
93  ],
94  } ],
95  "buffers" : [ {}, {"data": )" + axisData + R"( }, {}, {} ]
96  }
97  )";
98 
99  Setup();
100  }
101 };
102 
103 
104 struct SimpleSplitFixtureFloat32 : SplitFixture
105 {
106  SimpleSplitFixtureFloat32()
107  : SplitFixture( "[ 2, 2, 2, 2 ]", "[ ]", "2", "[ 2, 1, 2, 2 ]", "[ 2, 1, 2, 2 ]", "[ 1, 0, 0, 0 ]", "FLOAT32")
108  {}
109 };
110 
111 BOOST_FIXTURE_TEST_CASE(ParseAxisOneSplitTwoFloat32, SimpleSplitFixtureFloat32)
112 {
113 
114  RunTest<4, armnn::DataType::Float32>(
115  0,
116  { {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
117  9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f } } },
118  { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 4.0f, 9.0f, 10.0f, 11.0f, 12.0f } },
119  {"outputTensor2", { 5.0f, 6.0f, 7.0f, 8.0f, 13.0f, 14.0f, 15.0f, 16.0f } } });
120 }
121 
122 struct SimpleSplitAxisThreeFixtureFloat32 : SplitFixture
123 {
124  SimpleSplitAxisThreeFixtureFloat32()
125  : SplitFixture( "[ 2, 2, 2, 2 ]", "[ ]", "2", "[ 2, 2, 2, 1 ]", "[ 2, 2, 2, 1 ]", "[ 3, 0, 0, 0 ]", "FLOAT32")
126  {}
127 };
128 
129 BOOST_FIXTURE_TEST_CASE(ParseAxisThreeSplitTwoFloat32, SimpleSplitAxisThreeFixtureFloat32)
130 {
131  RunTest<4, armnn::DataType::Float32>(
132  0,
133  { {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
134  9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f } } },
135  { {"outputTensor1", { 1.0f, 3.0f, 5.0f, 7.0f, 9.0f, 11.0f, 13.0f, 15.0f } },
136  {"outputTensor2", { 2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f, 14.0f, 16.0f } } } );
137 }
138 
139 struct SimpleSplit2DFixtureFloat32 : SplitFixture
140 {
141  SimpleSplit2DFixtureFloat32()
142  : SplitFixture( "[ 1, 8 ]", "[ ]", "2", "[ 1, 4 ]", "[ 1, 4 ]", "[ 1, 0, 0, 0 ]", "FLOAT32")
143  {}
144 };
145 
146 BOOST_FIXTURE_TEST_CASE(SimpleSplit2DFloat32, SimpleSplit2DFixtureFloat32)
147 {
148  RunTest<2, armnn::DataType::Float32>(
149  0,
150  { {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f } } },
151  { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 4.0f } },
152  {"outputTensor2", { 5.0f, 6.0f, 7.0f, 8.0f } } } );
153 }
154 
155 struct SimpleSplit3DFixtureFloat32 : SplitFixture
156 {
157  SimpleSplit3DFixtureFloat32()
158  : SplitFixture( "[ 1, 8, 2 ]", "[ ]", "2", "[ 1, 4, 2 ]", "[ 1, 4, 2 ]", "[ 1, 0, 0, 0 ]", "FLOAT32")
159  {}
160 };
161 
162 BOOST_FIXTURE_TEST_CASE(SimpleSplit3DFloat32, SimpleSplit3DFixtureFloat32)
163 {
164  RunTest<3, armnn::DataType::Float32>(
165  0,
166  { {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
167  9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f } } },
168  { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f } },
169  {"outputTensor2", { 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f } } } );
170 }
171 
172 struct SimpleSplitFixtureUint8 : SplitFixture
173 {
174  SimpleSplitFixtureUint8()
175  : SplitFixture( "[ 2, 2, 2, 2 ]", "[ ]", "2", "[ 2, 1, 2, 2 ]", "[ 2, 1, 2, 2 ]", "[ 1, 0, 0, 0 ]", "UINT8")
176  {}
177 };
178 
179 BOOST_FIXTURE_TEST_CASE(ParseAxisOneSplitTwoUint8, SimpleSplitFixtureUint8)
180 {
181 
182  RunTest<4, armnn::DataType::QAsymmU8>(
183  0,
184  { {"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8,
185  9, 10, 11, 12, 13, 14, 15, 16 } } },
186  { {"outputTensor1", { 1, 2, 3, 4, 9, 10, 11, 12 } },
187  {"outputTensor2", { 5, 6, 7, 8, 13, 14, 15, 16 } } });
188 }
189 
190 struct SimpleSplitAxisThreeFixtureUint8 : SplitFixture
191 {
192  SimpleSplitAxisThreeFixtureUint8()
193  : SplitFixture( "[ 2, 2, 2, 2 ]", "[ ]", "2", "[ 2, 2, 2, 1 ]", "[ 2, 2, 2, 1 ]", "[ 3, 0, 0, 0 ]", "UINT8")
194  {}
195 };
196 
197 BOOST_FIXTURE_TEST_CASE(ParseAxisThreeSplitTwoUint8, SimpleSplitAxisThreeFixtureUint8)
198 {
199  RunTest<4, armnn::DataType::QAsymmU8>(
200  0,
201  { {"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8,
202  9, 10, 11, 12, 13, 14, 15, 16 } } },
203  { {"outputTensor1", { 1, 3, 5, 7, 9, 11, 13, 15 } },
204  {"outputTensor2", { 2, 4, 6, 8, 10, 12, 14, 16 } } } );
205 }
206 
207 struct SimpleSplit2DFixtureUint8 : SplitFixture
208 {
209  SimpleSplit2DFixtureUint8()
210  : SplitFixture( "[ 1, 8 ]", "[ ]", "2", "[ 1, 4 ]", "[ 1, 4 ]", "[ 1, 0, 0, 0 ]", "UINT8")
211  {}
212 };
213 
214 BOOST_FIXTURE_TEST_CASE(SimpleSplit2DUint8, SimpleSplit2DFixtureUint8)
215 {
216  RunTest<2, armnn::DataType::QAsymmU8>(
217  0,
218  { {"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8 } } },
219  { {"outputTensor1", { 1, 2, 3, 4 } },
220  {"outputTensor2", { 5, 6, 7, 8 } } } );
221 }
222 
223 struct SimpleSplit3DFixtureUint8 : SplitFixture
224 {
225  SimpleSplit3DFixtureUint8()
226  : SplitFixture( "[ 1, 8, 2 ]", "[ ]", "2", "[ 1, 4, 2 ]", "[ 1, 4, 2 ]", "[ 1, 0, 0, 0 ]", "UINT8")
227  {}
228 };
229 
230 BOOST_FIXTURE_TEST_CASE(SimpleSplit3DUint8, SimpleSplit3DFixtureUint8)
231 {
232  RunTest<3, armnn::DataType::QAsymmU8>(
233  0,
234  { {"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8,
235  9, 10, 11, 12, 13, 14, 15, 16 } } },
236  { {"outputTensor1", { 1, 2, 3, 4, 5, 6, 7, 8 } },
237  {"outputTensor2", { 9, 10, 11, 12, 13, 14, 15, 16 } } } );
238 }
239 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
BOOST_AUTO_TEST_SUITE_END()
BOOST_FIXTURE_TEST_CASE(ParseAxisOneSplitTwoFloat32, SimpleSplitFixtureFloat32)
Definition: Split.cpp:111