ArmNN
 21.02
StridedSlice.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <boost/test/unit_test.hpp>
8 #include "../TfLiteParser.hpp"
9 
10 #include <string>
11 
12 BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
13 
14 struct StridedSliceFixture : public ParserFlatbuffersFixture
15 {
16  explicit StridedSliceFixture(const std::string & inputShape,
17  const std::string & outputShape,
18  const std::string & beginData,
19  const std::string & endData,
20  const std::string & stridesData,
21  int beginMask = 0,
22  int endMask = 0)
23  {
24  m_JsonString = R"(
25  {
26  "version": 3,
27  "operator_codes": [ { "builtin_code": "STRIDED_SLICE" } ],
28  "subgraphs": [ {
29  "tensors": [
30  {
31  "shape": )" + inputShape + R"(,
32  "type": "FLOAT32",
33  "buffer": 0,
34  "name": "inputTensor",
35  "quantization": {
36  "min": [ 0.0 ],
37  "max": [ 255.0 ],
38  "scale": [ 1.0 ],
39  "zero_point": [ 0 ],
40  }
41  },
42  {
43  "shape": [ 4 ],
44  "type": "INT32",
45  "buffer": 1,
46  "name": "beginTensor",
47  "quantization": {
48  }
49  },
50  {
51  "shape": [ 4 ],
52  "type": "INT32",
53  "buffer": 2,
54  "name": "endTensor",
55  "quantization": {
56  }
57  },
58  {
59  "shape": [ 4 ],
60  "type": "INT32",
61  "buffer": 3,
62  "name": "stridesTensor",
63  "quantization": {
64  }
65  },
66  {
67  "shape": )" + outputShape + R"( ,
68  "type": "FLOAT32",
69  "buffer": 4,
70  "name": "outputTensor",
71  "quantization": {
72  "min": [ 0.0 ],
73  "max": [ 255.0 ],
74  "scale": [ 1.0 ],
75  "zero_point": [ 0 ],
76  }
77  }
78  ],
79  "inputs": [ 0, 1, 2, 3 ],
80  "outputs": [ 4 ],
81  "operators": [
82  {
83  "opcode_index": 0,
84  "inputs": [ 0, 1, 2, 3 ],
85  "outputs": [ 4 ],
86  "builtin_options_type": "StridedSliceOptions",
87  "builtin_options": {
88  "begin_mask": )" + std::to_string(beginMask) + R"(,
89  "end_mask": )" + std::to_string(endMask) + R"(
90  },
91  "custom_options_format": "FLEXBUFFERS"
92  }
93  ],
94  } ],
95  "buffers" : [
96  { },
97  { "data": )" + beginData + R"(, },
98  { "data": )" + endData + R"(, },
99  { "data": )" + stridesData + R"(, },
100  { }
101  ]
102  }
103  )";
104  Setup();
105  }
106 };
107 
108 struct StridedSlice4DFixture : StridedSliceFixture
109 {
110  StridedSlice4DFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]", // inputShape
111  "[ 1, 2, 3, 1 ]", // outputShape
112  "[ 1,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0 ]", // beginData
113  "[ 2,0,0,0, 2,0,0,0, 3,0,0,0, 1,0,0,0 ]", // endData
114  "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]" // stridesData
115  ) {}
116 };
117 
118 BOOST_FIXTURE_TEST_CASE(StridedSlice4D, StridedSlice4DFixture)
119 {
120  RunTest<4, armnn::DataType::Float32>(
121  0,
122  {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
123 
124  3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
125 
126  5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
127 
128  {{"outputTensor", { 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f }}});
129 }
130 
131 struct StridedSlice4DReverseFixture : StridedSliceFixture
132 {
133  StridedSlice4DReverseFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]", // inputShape
134  "[ 1, 2, 3, 1 ]", // outputShape
135  "[ 1,0,0,0, "
136  "255,255,255,255, "
137  "0,0,0,0, "
138  "0,0,0,0 ]", // beginData [ 1 -1 0 0 ]
139  "[ 2,0,0,0, "
140  "253,255,255,255, "
141  "3,0,0,0, "
142  "1,0,0,0 ]", // endData [ 2 -3 3 1 ]
143  "[ 1,0,0,0, "
144  "255,255,255,255, "
145  "1,0,0,0, "
146  "1,0,0,0 ]" // stridesData [ 1 -1 1 1 ]
147  ) {}
148 };
149 
150 BOOST_FIXTURE_TEST_CASE(StridedSlice4DReverse, StridedSlice4DReverseFixture)
151 {
152  RunTest<4, armnn::DataType::Float32>(
153  0,
154  {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
155 
156  3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
157 
158  5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
159 
160  {{"outputTensor", { 4.0f, 4.0f, 4.0f, 3.0f, 3.0f, 3.0f }}});
161 }
162 
163 struct StridedSliceSimpleStrideFixture : StridedSliceFixture
164 {
165  StridedSliceSimpleStrideFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]", // inputShape
166  "[ 2, 1, 2, 1 ]", // outputShape
167  "[ 0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0 ]", // beginData
168  "[ 3,0,0,0, 2,0,0,0, 3,0,0,0, 1,0,0,0 ]", // endData
169  "[ 2,0,0,0, 2,0,0,0, 2,0,0,0, 1,0,0,0 ]" // stridesData
170  ) {}
171 };
172 
173 BOOST_FIXTURE_TEST_CASE(StridedSliceSimpleStride, StridedSliceSimpleStrideFixture)
174 {
175  RunTest<4, armnn::DataType::Float32>(
176  0,
177  {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
178 
179  3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
180 
181  5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
182 
183  {{"outputTensor", { 1.0f, 1.0f,
184 
185  5.0f, 5.0f }}});
186 }
187 
188 struct StridedSliceSimpleRangeMaskFixture : StridedSliceFixture
189 {
190  StridedSliceSimpleRangeMaskFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]", // inputShape
191  "[ 3, 2, 3, 1 ]", // outputShape
192  "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]", // beginData
193  "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]", // endData
194  "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]", // stridesData
195  (1 << 4) - 1, // beginMask
196  (1 << 4) - 1 // endMask
197  ) {}
198 };
199 
200 BOOST_FIXTURE_TEST_CASE(StridedSliceSimpleRangeMask, StridedSliceSimpleRangeMaskFixture)
201 {
202  RunTest<4, armnn::DataType::Float32>(
203  0,
204  {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
205 
206  3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
207 
208  5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
209 
210  {{"outputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
211 
212  3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
213 
214  5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}});
215 }
216 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
BOOST_FIXTURE_TEST_CASE(StridedSlice4D, StridedSlice4DFixture)
BOOST_AUTO_TEST_SUITE_END()