ArmNN
 21.02
SplitV.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2020 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <boost/test/unit_test.hpp>
8 #include "../TfLiteParser.hpp"
9 
10 #include <string>
11 #include <iostream>
12 
13 BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
14 
15 struct SplitVFixture : public ParserFlatbuffersFixture
16 {
17  explicit SplitVFixture(const std::string& inputShape,
18  const std::string& splitValues,
19  const std::string& sizeSplitsShape,
20  const std::string& axisShape,
21  const std::string& numSplits,
22  const std::string& outputShape1,
23  const std::string& outputShape2,
24  const std::string& axisData,
25  const std::string& dataType)
26  {
27  m_JsonString = R"(
28  {
29  "version": 3,
30  "operator_codes": [ { "builtin_code": "SPLIT_V" } ],
31  "subgraphs": [ {
32  "tensors": [
33  {
34  "shape": )" + inputShape + R"(,
35  "type": )" + dataType + R"(,
36  "buffer": 0,
37  "name": "inputTensor",
38  "quantization": {
39  "min": [ 0.0 ],
40  "max": [ 255.0 ],
41  "scale": [ 1.0 ],
42  "zero_point": [ 0 ],
43  }
44  },
45  {
46  "shape": )" + sizeSplitsShape + R"(,
47  "type": "INT32",
48  "buffer": 1,
49  "name": "sizeSplits",
50  "quantization": {
51  "min": [ 0.0 ],
52  "max": [ 255.0 ],
53  "scale": [ 1.0 ],
54  "zero_point": [ 0 ],
55  }
56  },
57  {
58  "shape": )" + axisShape + R"(,
59  "type": "INT32",
60  "buffer": 2,
61  "name": "axis",
62  "quantization": {
63  "min": [ 0.0 ],
64  "max": [ 255.0 ],
65  "scale": [ 1.0 ],
66  "zero_point": [ 0 ],
67  }
68  },
69  {
70  "shape": )" + outputShape1 + R"( ,
71  "type":)" + dataType + R"(,
72  "buffer": 3,
73  "name": "outputTensor1",
74  "quantization": {
75  "min": [ 0.0 ],
76  "max": [ 255.0 ],
77  "scale": [ 1.0 ],
78  "zero_point": [ 0 ],
79  }
80  },
81  {
82  "shape": )" + outputShape2 + R"( ,
83  "type":)" + dataType + R"(,
84  "buffer": 4,
85  "name": "outputTensor2",
86  "quantization": {
87  "min": [ 0.0 ],
88  "max": [ 255.0 ],
89  "scale": [ 1.0 ],
90  "zero_point": [ 0 ],
91  }
92  }
93  ],
94  "inputs": [ 0, 1, 2 ],
95  "outputs": [ 3, 4 ],
96  "operators": [
97  {
98  "opcode_index": 0,
99  "inputs": [ 0, 1, 2 ],
100  "outputs": [ 3, 4 ],
101  "builtin_options_type": "SplitVOptions",
102  "builtin_options": {
103  "num_splits": )" + numSplits + R"(
104  },
105  "custom_options_format": "FLEXBUFFERS"
106  }
107  ],
108  } ],
109  "buffers" : [ {}, { "data": )" + splitValues + R"( }, { "data": )" + axisData + R"( }, {}, {}]
110  }
111  )";
112 
113  Setup();
114  }
115 };
116 
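/*
 * Split sizes inferred from a -1 entry (e.g. splitValues [-1, 1]) were only verified locally;
 * every fixture below uses explicit split sizes, so that path is not covered by these tests.
 */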
120 
121 struct SimpleSplitVAxisOneFixture : SplitVFixture
122 {
123  SimpleSplitVAxisOneFixture()
124  : SplitVFixture( "[ 4, 2, 2, 2 ]", "[ 1, 0, 0, 0, 3, 0, 0, 0 ]", "[ 2 ]","[ ]", "2",
125  "[ 1, 2, 2, 2 ]", "[ 3, 2, 2, 2 ]", "[ 0, 0, 0, 0 ]", "FLOAT32")
126  {}
127 };
128 
129 BOOST_FIXTURE_TEST_CASE(ParseAxisOneSplitVTwo, SimpleSplitVAxisOneFixture)
130 {
131  RunTest<4, armnn::DataType::Float32>(
132  0,
133  { {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
134  9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f,
135  17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f,
136  25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } } },
137  { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f } },
138  {"outputTensor2", { 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f,
139  17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f,
140  25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } } } );
141 }
142 
143 struct SimpleSplitVAxisTwoFixture : SplitVFixture
144 {
145  SimpleSplitVAxisTwoFixture()
146  : SplitVFixture( "[ 2, 4, 2, 2 ]", "[ 3, 0, 0, 0, 1, 0, 0, 0 ]", "[ 2 ]","[ ]", "2",
147  "[ 2, 3, 2, 2 ]", "[ 2, 1, 2, 2 ]", "[ 1, 0, 0, 0 ]", "FLOAT32")
148  {}
149 };
150 
151 BOOST_FIXTURE_TEST_CASE(ParseAxisTwoSplitVTwo, SimpleSplitVAxisTwoFixture)
152 {
153  RunTest<4, armnn::DataType::Float32>(
154  0,
155  { {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
156  9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f,
157  17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f,
158  25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } } },
159  { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
160  9.0f, 10.0f, 11.0f, 12.0f, 17.0f, 18.0f, 19.0f, 20.0f,
161  21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f } },
162  {"outputTensor2", { 13.0f, 14.0f, 15.0f, 16.0f, 29.0f, 30.0f, 31.0f, 32.0f } } } );
163 }
164 
165 struct SimpleSplitVAxisThreeFixture : SplitVFixture
166 {
167  SimpleSplitVAxisThreeFixture()
168  : SplitVFixture( "[ 2, 2, 4, 2 ]", "[ 1, 0, 0, 0, 3, 0, 0, 0 ]", "[ 2 ]","[ ]", "2",
169  "[ 2, 2, 1, 2 ]", "[ 2, 2, 3, 2 ]", "[ 2, 0, 0, 0 ]", "FLOAT32")
170  {}
171 };
172 
173 BOOST_FIXTURE_TEST_CASE(ParseAxisThreeSplitVTwo, SimpleSplitVAxisThreeFixture)
174 {
175  RunTest<4, armnn::DataType::Float32>(
176  0,
177  { {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
178  9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f,
179  17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f,
180  25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } } },
181  { {"outputTensor1", { 1.0f, 2.0f, 9.0f, 10.0f, 17.0f, 18.0f, 25.0f, 26.0f } },
182  {"outputTensor2", { 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 11.0f, 12.0f,
183  13.0f, 14.0f, 15.0f, 16.0f, 19.0f, 20.0f, 21.0f, 22.0f,
184  23.0f, 24.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } } } );
185 }
186 
187 struct SimpleSplitVAxisFourFixture : SplitVFixture
188 {
189  SimpleSplitVAxisFourFixture()
190  : SplitVFixture( "[ 2, 2, 2, 4 ]", "[ 3, 0, 0, 0, 1, 0, 0, 0 ]", "[ 2 ]","[ ]", "2",
191  "[ 2, 2, 2, 3 ]", "[ 2, 2, 2, 1 ]", "[ 3, 0, 0, 0 ]", "FLOAT32")
192  {}
193 };
194 
195 BOOST_FIXTURE_TEST_CASE(ParseAxisFourSplitVTwo, SimpleSplitVAxisFourFixture)
196 {
197  RunTest<4, armnn::DataType::Float32>(
198  0,
199  { {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
200  9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f,
201  17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f,
202  25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } } },
203  { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 5.0f, 6.0f, 7.0f, 9.0f, 10.0f,
204  11.0f, 13.0f, 14.0f, 15.0f, 17.0f, 18.0f, 19.0f, 21.0f,
205  22.0f, 23.0f, 25.0f, 26.0f, 27.0f, 29.0f, 30.0f, 31.0f} },
206  {"outputTensor2", { 4.0f, 8.0f, 12.0f, 16.0f, 20.0f, 24.0f, 28.0f, 32.0f } } } );
207 }
208 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
BOOST_FIXTURE_TEST_CASE(ParseAxisOneSplitVTwo, SimpleSplitVAxisOneFixture)
Definition: SplitV.cpp:129
BOOST_AUTO_TEST_SUITE_END()