// ArmNN 21.02 — DeserializeStridedSlice.cpp
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 
6 #include <boost/test/unit_test.hpp>
9 
10 #include <string>
11 
12 BOOST_AUTO_TEST_SUITE(Deserializer)
13 
14 struct StridedSliceFixture : public ParserFlatbuffersSerializeFixture
15 {
16  explicit StridedSliceFixture(const std::string& inputShape,
17  const std::string& begin,
18  const std::string& end,
19  const std::string& stride,
20  const std::string& beginMask,
21  const std::string& endMask,
22  const std::string& shrinkAxisMask,
23  const std::string& ellipsisMask,
24  const std::string& newAxisMask,
25  const std::string& dataLayout,
26  const std::string& outputShape,
27  const std::string& dataType)
28  {
29  m_JsonString = R"(
30  {
31  inputIds: [0],
32  outputIds: [2],
33  layers: [
34  {
35  layer_type: "InputLayer",
36  layer: {
37  base: {
38  layerBindingId: 0,
39  base: {
40  index: 0,
41  layerName: "InputLayer",
42  layerType: "Input",
43  inputSlots: [{
44  index: 0,
45  connection: {sourceLayerIndex:0, outputSlotIndex:0 },
46  }],
47  outputSlots: [{
48  index: 0,
49  tensorInfo: {
50  dimensions: )" + inputShape + R"(,
51  dataType: )" + dataType + R"(
52  }
53  }]
54  }
55  }
56  }
57  },
58  {
59  layer_type: "StridedSliceLayer",
60  layer: {
61  base: {
62  index: 1,
63  layerName: "StridedSliceLayer",
64  layerType: "StridedSlice",
65  inputSlots: [{
66  index: 0,
67  connection: {sourceLayerIndex:0, outputSlotIndex:0 },
68  }],
69  outputSlots: [{
70  index: 0,
71  tensorInfo: {
72  dimensions: )" + outputShape + R"(,
73  dataType: )" + dataType + R"(
74  }
75  }]
76  },
77  descriptor: {
78  begin: )" + begin + R"(,
79  end: )" + end + R"(,
80  stride: )" + stride + R"(,
81  beginMask: )" + beginMask + R"(,
82  endMask: )" + endMask + R"(,
83  shrinkAxisMask: )" + shrinkAxisMask + R"(,
84  ellipsisMask: )" + ellipsisMask + R"(,
85  newAxisMask: )" + newAxisMask + R"(,
86  dataLayout: )" + dataLayout + R"(,
87  }
88  }
89  },
90  {
91  layer_type: "OutputLayer",
92  layer: {
93  base:{
94  layerBindingId: 2,
95  base: {
96  index: 2,
97  layerName: "OutputLayer",
98  layerType: "Output",
99  inputSlots: [{
100  index: 0,
101  connection: {sourceLayerIndex:1, outputSlotIndex:0 },
102  }],
103  outputSlots: [{
104  index: 0,
105  tensorInfo: {
106  dimensions: )" + outputShape + R"(,
107  dataType: )" + dataType + R"(
108  },
109  }],
110  }
111  }
112  },
113  }
114  ]
115  }
116  )";
117  SetupSingleInputSingleOutput("InputLayer", "OutputLayer");
118  }
119 };
120 
121 struct SimpleStridedSliceFixture : StridedSliceFixture
122 {
123  SimpleStridedSliceFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]",
124  "[ 0, 0, 0, 0 ]",
125  "[ 3, 2, 3, 1 ]",
126  "[ 2, 2, 2, 1 ]",
127  "0",
128  "0",
129  "0",
130  "0",
131  "0",
132  "NCHW",
133  "[ 2, 1, 2, 1 ]",
134  "Float32") {}
135 };
136 
137 BOOST_FIXTURE_TEST_CASE(SimpleStridedSliceFloat32, SimpleStridedSliceFixture)
138 {
139  RunTest<4, armnn::DataType::Float32>(0,
140  {
141  1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
142  3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
143  5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f
144  },
145  {
146  1.0f, 1.0f, 5.0f, 5.0f
147  });
148 }
149 
150 struct StridedSliceMaskFixture : StridedSliceFixture
151 {
152  StridedSliceMaskFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]",
153  "[ 1, 1, 1, 1 ]",
154  "[ 1, 1, 1, 1 ]",
155  "[ 1, 1, 1, 1 ]",
156  "15",
157  "15",
158  "0",
159  "0",
160  "0",
161  "NCHW",
162  "[ 3, 2, 3, 1 ]",
163  "Float32") {}
164 };
165 
166 BOOST_FIXTURE_TEST_CASE(StridedSliceMaskFloat32, StridedSliceMaskFixture)
167 {
168  RunTest<4, armnn::DataType::Float32>(0,
169  {
170  1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
171  3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
172  5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f
173  },
174  {
175  1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
176  3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
177  5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f
178  });
179 }
180 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
void SetupSingleInputSingleOutput(const std::string &inputName, const std::string &outputName)
BOOST_AUTO_TEST_SUITE_END()
BOOST_FIXTURE_TEST_CASE(SimpleStridedSliceFloat32, SimpleStridedSliceFixture)