ArmNN
 21.02
StridedSlice.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2020 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
7 
10 
11 #include <boost/test/unit_test.hpp>
12 
13 BOOST_AUTO_TEST_SUITE(TensorflowParser)
14 
15 namespace {
16 // helper for setting the dimensions in prototxt
17 void shapeHelper(const armnn::TensorShape& shape, std::string& text){
18  for(unsigned int i = 0; i < shape.GetNumDimensions(); ++i) {
19  text.append(R"(dim {
20  size: )");
21  text.append(std::to_string(shape[i]));
22  text.append(R"(
23  })");
24  }
25 }
26 
27 // helper for converting from integer to octal representation
28 void octalHelper(const std::vector<int>& content, std::string& text){
29  for (unsigned int i = 0; i < content.size(); ++i)
30  {
31  text.append(armnnUtils::ConvertInt32ToOctalString(static_cast<int>(content[i])));
32  }
33 }
34 } // namespace
35 
36 struct StridedSliceFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
37 {
38  StridedSliceFixture(const armnn::TensorShape& inputShape,
39  const std::vector<int>& beginData,
40  const std::vector<int>& endData,
41  const std::vector<int>& stridesData,
42  int beginMask = 0,
43  int endMask = 0,
44  int ellipsisMask = 0,
45  int newAxisMask = 0,
46  int shrinkAxisMask = 0)
47  {
48  m_Prototext = R"(
49  node {
50  name: "input"
51  op: "Placeholder"
52  attr {
53  key: "dtype"
54  value {
55  type: DT_FLOAT
56  }
57  }
58  attr {
59  key: "shape"
60  value {
61  shape {)";
62  shapeHelper(inputShape, m_Prototext);
63  m_Prototext.append(R"(
64  }
65  }
66  }
67  }
68  node {
69  name: "begin"
70  op: "Const"
71  attr {
72  key: "dtype"
73  value {
74  type: DT_INT32
75  }
76  }
77  attr {
78  key: "value"
79  value {
80  tensor {
81  dtype: DT_INT32
82  tensor_shape {
83  dim {
84  size: )");
85  m_Prototext += std::to_string(beginData.size());
86  m_Prototext.append(R"(
87  }
88  }
89  tensor_content: ")");
90  octalHelper(beginData, m_Prototext);
91  m_Prototext.append(R"("
92  }
93  }
94  }
95  }
96  node {
97  name: "end"
98  op: "Const"
99  attr {
100  key: "dtype"
101  value {
102  type: DT_INT32
103  }
104  }
105  attr {
106  key: "value"
107  value {
108  tensor {
109  dtype: DT_INT32
110  tensor_shape {
111  dim {
112  size: )");
113  m_Prototext += std::to_string(endData.size());
114  m_Prototext.append(R"(
115  }
116  }
117  tensor_content: ")");
118  octalHelper(endData, m_Prototext);
119  m_Prototext.append(R"("
120  }
121  }
122  }
123  }
124  node {
125  name: "strides"
126  op: "Const"
127  attr {
128  key: "dtype"
129  value {
130  type: DT_INT32
131  }
132  }
133  attr {
134  key: "value"
135  value {
136  tensor {
137  dtype: DT_INT32
138  tensor_shape {
139  dim {
140  size: )");
141  m_Prototext += std::to_string(stridesData.size());
142  m_Prototext.append(R"(
143  }
144  }
145  tensor_content: ")");
146  octalHelper(stridesData, m_Prototext);
147  m_Prototext.append(R"("
148  }
149  }
150  }
151  }
152  node {
153  name: "output"
154  op: "StridedSlice"
155  input: "input"
156  input: "begin"
157  input: "end"
158  input: "strides"
159  attr {
160  key: "begin_mask"
161  value {
162  i: )");
163  m_Prototext += std::to_string(beginMask);
164  m_Prototext.append(R"(
165  }
166  }
167  attr {
168  key: "end_mask"
169  value {
170  i: )");
171  m_Prototext += std::to_string(endMask);
172  m_Prototext.append(R"(
173  }
174  }
175  attr {
176  key: "ellipsis_mask"
177  value {
178  i: )");
179  m_Prototext += std::to_string(ellipsisMask);
180  m_Prototext.append(R"(
181  }
182  }
183  attr {
184  key: "new_axis_mask"
185  value {
186  i: )");
187  m_Prototext += std::to_string(newAxisMask);
188  m_Prototext.append(R"(
189  }
190  }
191  attr {
192  key: "shrink_axis_mask"
193  value {
194  i: )");
195  m_Prototext += std::to_string(shrinkAxisMask);
196  m_Prototext.append(R"(
197  }
198  }
199  })");
200 
201  Setup({ { "input", inputShape } }, { "output" });
202  }
203 };
204 
205 struct StridedSlice4DFixture : StridedSliceFixture
206 {
207  StridedSlice4DFixture() : StridedSliceFixture({ 3, 2, 3, 1 }, // inputShape
208  { 1, 0, 0, 0 }, // beginData
209  { 2, 2, 3, 1 }, // endData
210  { 1, 1, 1, 1 } // stridesData
211  ) {}
212 };
213 
214 BOOST_FIXTURE_TEST_CASE(StridedSlice4D, StridedSlice4DFixture)
215 {
216  RunTest<4>(
217  {{"input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
218  3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
219  5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
220  {{"output", { 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f }}});
221 }
222 
223 struct StridedSlice4DReverseFixture : StridedSliceFixture
224 {
225 
226  StridedSlice4DReverseFixture() : StridedSliceFixture({ 3, 2, 3, 1 }, // inputShape
227  { 1, -1, 0, 0 }, // beginData
228  { 2, -3, 3, 1 }, // endData
229  { 1, -1, 1, 1 } // stridesData
230  ) {}
231 };
232 
233 BOOST_FIXTURE_TEST_CASE(StridedSlice4DReverse, StridedSlice4DReverseFixture)
234 {
235  RunTest<4>(
236  {{"input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
237  3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
238  5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
239  {{"output", { 4.0f, 4.0f, 4.0f, 3.0f, 3.0f, 3.0f }}});
240 }
241 
242 struct StridedSliceSimpleStrideFixture : StridedSliceFixture
243 {
244  StridedSliceSimpleStrideFixture() : StridedSliceFixture({ 3, 2, 3, 1 }, // inputShape
245  { 0, 0, 0, 0 }, // beginData
246  { 3, 2, 3, 1 }, // endData
247  { 2, 2, 2, 1 } // stridesData
248  ) {}
249 };
250 
251 BOOST_FIXTURE_TEST_CASE(StridedSliceSimpleStride, StridedSliceSimpleStrideFixture)
252 {
253  RunTest<4>(
254  {{"input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
255  3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
256  5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
257  {{"output", { 1.0f, 1.0f,
258  5.0f, 5.0f }}});
259 }
260 
261 struct StridedSliceSimpleRangeMaskFixture : StridedSliceFixture
262 {
263  StridedSliceSimpleRangeMaskFixture() : StridedSliceFixture({ 3, 2, 3, 1 }, // inputShape
264  { 1, 1, 1, 1 }, // beginData
265  { 1, 1, 1, 1 }, // endData
266  { 1, 1, 1, 1 }, // stridesData
267  (1 << 4) - 1, // beginMask
268  (1 << 4) - 1 // endMask
269  ) {}
270 };
271 
272 BOOST_FIXTURE_TEST_CASE(StridedSliceSimpleRangeMask, StridedSliceSimpleRangeMaskFixture)
273 {
274  RunTest<4>(
275  {{"input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
276  3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
277  5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
278  {{"output", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
279  3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
280  5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}});
281 }
282 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
BOOST_FIXTURE_TEST_CASE(StridedSlice4D, StridedSlice4DFixture)
std::string ConvertInt32ToOctalString(int value)
Converts an int value into the Prototxt octal representation.
BOOST_AUTO_TEST_SUITE_END()
unsigned int GetNumDimensions() const
Function that returns the tensor rank.
Definition: Tensor.cpp:174