// SplitFixture: test fixture that describes a TfLite model (as JSON text) with a single
// SPLIT operator. The constructor arguments are spliced into that JSON: tensor shapes,
// the number of splits, the raw bytes of the axis buffer, and the element type.
// NOTE(review): this view of the file is fragmentary. The struct header, the constructor
// body and parts of the JSON are not visible, so only the visible pieces are documented.
6 #include <boost/test/unit_test.hpp> 8 #include "../TfLiteParser.hpp" 17 explicit SplitFixture(
const std::string& inputShape, // JSON array used as the "shape" of "inputTensor"
18 const std::string& axisShape, // JSON array used as the "shape" of the tensor listed after the input (presumably the axis tensor; confirm)
19 const std::string& numSplits, // value written to SplitOptions "num_splits"
20 const std::string& outputShape1, // JSON array used as the "shape" of "outputTensor1"
21 const std::string& outputShape2, // JSON array used as the "shape" of "outputTensor2"
22 const std::string& axisData, // JSON byte array written as the "data" of the second buffer (buffers[1])
23 const std::string& dataType) // TfLite type name (e.g. FLOAT32, UINT8) used for the input and both output tensors
// The JSON fragments below declare one operator code ("SPLIT"), use
// "SplitOptions" as the builtin options type with the given num_splits, and list
// four buffers, of which only the second carries data (axisData).
28 "operator_codes": [ { "builtin_code": "SPLIT" } ], 32 "shape": )" + inputShape + R"(, 33 "type": )" + dataType + R"(, 35 "name": "inputTensor", 44 "shape": )" + axisShape + R"(, 56 "shape": )" + outputShape1 + R"( , 57 "type":)" + dataType + R"(, 59 "name": "outputTensor1", 68 "shape": )" + outputShape2 + R"( , 69 "type":)" + dataType + R"(, 71 "name": "outputTensor2", 87 "builtin_options_type": "SplitOptions", 89 "num_splits": )" + numSplits + R"( 91 "custom_options_format": "FLEXBUFFERS" 95 "buffers" : [ {}, {"data": )" + axisData + R"( }, {}, {} ] 104 struct SimpleSplitFixtureFloat32 : SplitFixture
// SimpleSplitFixtureFloat32: FLOAT32 input of shape [2, 2, 2, 2] split into 2 outputs
// of shape [2, 1, 2, 2].
// - The axis shape "[ ]" is empty, presumably a scalar.
// - Axis bytes [1, 0, 0, 0] presumably encode int32 axis = 1 (little-endian). That
//   matches the output shapes and expected values below.
106 SimpleSplitFixtureFloat32()
107 : SplitFixture(
"[ 2, 2, 2, 2 ]",
"[ ]",
"2",
"[ 2, 1, 2, 2 ]",
"[ 2, 1, 2, 2 ]",
"[ 1, 0, 0, 0 ]",
"FLOAT32")
// The input holds the values 1..16. The expected values assume a row-major layout.
// Splitting on axis 1 gives each output one axis-1 slice per batch:
// outputTensor1 = {1-4, 9-12} and outputTensor2 = {5-8, 13-16}.
// The template argument 4 presumably matches the tensor rank; confirm against RunTest.
114 RunTest<4, armnn::DataType::Float32>(
116 { {
"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
117 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f } } },
118 { {
"outputTensor1", { 1.0f, 2.0f, 3.0f, 4.0f, 9.0f, 10.0f, 11.0f, 12.0f } },
119 {
"outputTensor2", { 5.0f, 6.0f, 7.0f, 8.0f, 13.0f, 14.0f, 15.0f, 16.0f } } });
// SimpleSplitAxisThreeFixtureFloat32: FLOAT32 input of shape [2, 2, 2, 2] split into 2
// outputs of shape [2, 2, 2, 1] along the innermost axis.
// - Axis bytes [3, 0, 0, 0] presumably encode int32 axis = 3 (little-endian).
122 struct SimpleSplitAxisThreeFixtureFloat32 : SplitFixture
124 SimpleSplitAxisThreeFixtureFloat32()
125 : SplitFixture(
"[ 2, 2, 2, 2 ]",
"[ ]",
"2",
"[ 2, 2, 2, 1 ]",
"[ 2, 2, 2, 1 ]",
"[ 3, 0, 0, 0 ]",
"FLOAT32")
// With the values 1..16 in row-major order, the innermost axis alternates odd and
// even values. outputTensor1 gets the odd values and outputTensor2 the even values.
131 RunTest<4, armnn::DataType::Float32>(
133 { {
"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
134 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f } } },
135 { {
"outputTensor1", { 1.0f, 3.0f, 5.0f, 7.0f, 9.0f, 11.0f, 13.0f, 15.0f } },
136 {
"outputTensor2", { 2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f, 14.0f, 16.0f } } } );
// SimpleSplit2DFixtureFloat32: 2-D FLOAT32 input of shape [1, 8] split into 2 outputs
// of shape [1, 4].
// - Axis bytes [1, 0, 0, 0] presumably encode int32 axis = 1.
139 struct SimpleSplit2DFixtureFloat32 : SplitFixture
141 SimpleSplit2DFixtureFloat32()
142 : SplitFixture(
"[ 1, 8 ]",
"[ ]",
"2",
"[ 1, 4 ]",
"[ 1, 4 ]",
"[ 1, 0, 0, 0 ]",
"FLOAT32")
// The first four values go to outputTensor1 and the last four to outputTensor2.
148 RunTest<2, armnn::DataType::Float32>(
150 { {
"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f } } },
151 { {
"outputTensor1", { 1.0f, 2.0f, 3.0f, 4.0f } },
152 {
"outputTensor2", { 5.0f, 6.0f, 7.0f, 8.0f } } } );
// SimpleSplit3DFixtureFloat32: 3-D FLOAT32 input of shape [1, 8, 2] split into 2
// outputs of shape [1, 4, 2].
// - Axis bytes [1, 0, 0, 0] presumably encode int32 axis = 1.
155 struct SimpleSplit3DFixtureFloat32 : SplitFixture
157 SimpleSplit3DFixtureFloat32()
158 : SplitFixture(
"[ 1, 8, 2 ]",
"[ ]",
"2",
"[ 1, 4, 2 ]",
"[ 1, 4, 2 ]",
"[ 1, 0, 0, 0 ]",
"FLOAT32")
// Each axis-1 row holds 2 elements, so the first 4 rows (values 1-8) go to
// outputTensor1 and the last 4 rows (values 9-16) go to outputTensor2.
164 RunTest<3, armnn::DataType::Float32>(
166 { {
"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
167 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f } } },
168 { {
"outputTensor1", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f } },
169 {
"outputTensor2", { 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f } } } );
// SimpleSplitFixtureUint8: UINT8 version of SimpleSplitFixtureFloat32. The [2, 2, 2, 2]
// input is split into 2 outputs of shape [2, 1, 2, 2].
// - Axis bytes [1, 0, 0, 0] presumably encode int32 axis = 1.
// - NOTE(review): quantization parameters for the UINT8 tensors are not visible in this
//   view; the expected values assume they leave the raw values unchanged.
172 struct SimpleSplitFixtureUint8 : SplitFixture
174 SimpleSplitFixtureUint8()
175 : SplitFixture(
"[ 2, 2, 2, 2 ]",
"[ ]",
"2",
"[ 2, 1, 2, 2 ]",
"[ 2, 1, 2, 2 ]",
"[ 1, 0, 0, 0 ]",
"UINT8")
// These are the same expected values as the FLOAT32 axis-1 case, run as QAsymmU8:
// outputTensor1 = {1-4, 9-12} and outputTensor2 = {5-8, 13-16}.
182 RunTest<4, armnn::DataType::QAsymmU8>(
184 { {
"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8,
185 9, 10, 11, 12, 13, 14, 15, 16 } } },
186 { {
"outputTensor1", { 1, 2, 3, 4, 9, 10, 11, 12 } },
187 {
"outputTensor2", { 5, 6, 7, 8, 13, 14, 15, 16 } } });
// SimpleSplitAxisThreeFixtureUint8: UINT8 version of the axis-3 split. The [2, 2, 2, 2]
// input is split into 2 outputs of shape [2, 2, 2, 1].
// - Axis bytes [3, 0, 0, 0] presumably encode int32 axis = 3.
190 struct SimpleSplitAxisThreeFixtureUint8 : SplitFixture
192 SimpleSplitAxisThreeFixtureUint8()
193 : SplitFixture(
"[ 2, 2, 2, 2 ]",
"[ ]",
"2",
"[ 2, 2, 2, 1 ]",
"[ 2, 2, 2, 1 ]",
"[ 3, 0, 0, 0 ]",
"UINT8")
// The innermost axis alternates odd and even values. outputTensor1 gets the odd
// values and outputTensor2 the even values.
199 RunTest<4, armnn::DataType::QAsymmU8>(
201 { {
"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8,
202 9, 10, 11, 12, 13, 14, 15, 16 } } },
203 { {
"outputTensor1", { 1, 3, 5, 7, 9, 11, 13, 15 } },
204 {
"outputTensor2", { 2, 4, 6, 8, 10, 12, 14, 16 } } } );
// SimpleSplit2DFixtureUint8: UINT8 version of the 2-D split. The [1, 8] input is split
// into 2 outputs of shape [1, 4].
// - Axis bytes [1, 0, 0, 0] presumably encode int32 axis = 1.
207 struct SimpleSplit2DFixtureUint8 : SplitFixture
209 SimpleSplit2DFixtureUint8()
210 : SplitFixture(
"[ 1, 8 ]",
"[ ]",
"2",
"[ 1, 4 ]",
"[ 1, 4 ]",
"[ 1, 0, 0, 0 ]",
"UINT8")
// The first four values go to outputTensor1 and the last four to outputTensor2.
216 RunTest<2, armnn::DataType::QAsymmU8>(
218 { {
"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8 } } },
219 { {
"outputTensor1", { 1, 2, 3, 4 } },
220 {
"outputTensor2", { 5, 6, 7, 8 } } } );
// SimpleSplit3DFixtureUint8: UINT8 version of the 3-D split. The [1, 8, 2] input is
// split into 2 outputs of shape [1, 4, 2].
// - Axis bytes [1, 0, 0, 0] presumably encode int32 axis = 1.
223 struct SimpleSplit3DFixtureUint8 : SplitFixture
225 SimpleSplit3DFixtureUint8()
226 : SplitFixture(
"[ 1, 8, 2 ]",
"[ ]",
"2",
"[ 1, 4, 2 ]",
"[ 1, 4, 2 ]",
"[ 1, 0, 0, 0 ]",
"UINT8")
// The first 4 axis-1 rows (values 1-8) go to outputTensor1 and the last 4 rows
// (values 9-16) go to outputTensor2.
232 RunTest<3, armnn::DataType::QAsymmU8>(
234 { {
"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8,
235 9, 10, 11, 12, 13, 14, 15, 16 } } },
236 { {
"outputTensor1", { 1, 2, 3, 4, 5, 6, 7, 8 } },
237 {
"outputTensor2", { 9, 10, 11, 12, 13, 14, 15, 16 } } } );
// NOTE(review): in this view, the suite macros and the test-case header come after
// the fixture bodies. BOOST_AUTO_TEST_SUITE_END also comes before the only visible
// BOOST_FIXTURE_TEST_CASE. In Boost.Test, a case declared after SUITE_END is not part of
// the TensorflowLiteParser suite. This is presumably an artifact of how the chunk was
// extracted; confirm the ordering in the real file.
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
BOOST_AUTO_TEST_SUITE_END()
// Test case that uses SimpleSplitFixtureFloat32 as its fixture. Its body is not
// visible at this point in the view.
BOOST_FIXTURE_TEST_CASE(ParseAxisOneSplitTwoFloat32, SimpleSplitFixtureFloat32)