ArmNN
 21.08
DeserializeStridedSlice.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
8 
9 #include <string>
10 
11 TEST_SUITE("Deserializer_StridedSlice")
12 {
13 struct StridedSliceFixture : public ParserFlatbuffersSerializeFixture
14 {
15  explicit StridedSliceFixture(const std::string& inputShape,
16  const std::string& begin,
17  const std::string& end,
18  const std::string& stride,
19  const std::string& beginMask,
20  const std::string& endMask,
21  const std::string& shrinkAxisMask,
22  const std::string& ellipsisMask,
23  const std::string& newAxisMask,
24  const std::string& dataLayout,
25  const std::string& outputShape,
26  const std::string& dataType)
27  {
28  m_JsonString = R"(
29  {
30  inputIds: [0],
31  outputIds: [2],
32  layers: [
33  {
34  layer_type: "InputLayer",
35  layer: {
36  base: {
37  layerBindingId: 0,
38  base: {
39  index: 0,
40  layerName: "InputLayer",
41  layerType: "Input",
42  inputSlots: [{
43  index: 0,
44  connection: {sourceLayerIndex:0, outputSlotIndex:0 },
45  }],
46  outputSlots: [{
47  index: 0,
48  tensorInfo: {
49  dimensions: )" + inputShape + R"(,
50  dataType: )" + dataType + R"(
51  }
52  }]
53  }
54  }
55  }
56  },
57  {
58  layer_type: "StridedSliceLayer",
59  layer: {
60  base: {
61  index: 1,
62  layerName: "StridedSliceLayer",
63  layerType: "StridedSlice",
64  inputSlots: [{
65  index: 0,
66  connection: {sourceLayerIndex:0, outputSlotIndex:0 },
67  }],
68  outputSlots: [{
69  index: 0,
70  tensorInfo: {
71  dimensions: )" + outputShape + R"(,
72  dataType: )" + dataType + R"(
73  }
74  }]
75  },
76  descriptor: {
77  begin: )" + begin + R"(,
78  end: )" + end + R"(,
79  stride: )" + stride + R"(,
80  beginMask: )" + beginMask + R"(,
81  endMask: )" + endMask + R"(,
82  shrinkAxisMask: )" + shrinkAxisMask + R"(,
83  ellipsisMask: )" + ellipsisMask + R"(,
84  newAxisMask: )" + newAxisMask + R"(,
85  dataLayout: )" + dataLayout + R"(,
86  }
87  }
88  },
89  {
90  layer_type: "OutputLayer",
91  layer: {
92  base:{
93  layerBindingId: 2,
94  base: {
95  index: 2,
96  layerName: "OutputLayer",
97  layerType: "Output",
98  inputSlots: [{
99  index: 0,
100  connection: {sourceLayerIndex:1, outputSlotIndex:0 },
101  }],
102  outputSlots: [{
103  index: 0,
104  tensorInfo: {
105  dimensions: )" + outputShape + R"(,
106  dataType: )" + dataType + R"(
107  },
108  }],
109  }
110  }
111  },
112  }
113  ]
114  }
115  )";
116  SetupSingleInputSingleOutput("InputLayer", "OutputLayer");
117  }
118 };
119 
120 struct SimpleStridedSliceFixture : StridedSliceFixture
121 {
122  SimpleStridedSliceFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]",
123  "[ 0, 0, 0, 0 ]",
124  "[ 3, 2, 3, 1 ]",
125  "[ 2, 2, 2, 1 ]",
126  "0",
127  "0",
128  "0",
129  "0",
130  "0",
131  "NCHW",
132  "[ 2, 1, 2, 1 ]",
133  "Float32") {}
134 };
135 
136 TEST_CASE_FIXTURE(SimpleStridedSliceFixture, "SimpleStridedSliceFloat32")
137 {
138  RunTest<4, armnn::DataType::Float32>(0,
139  {
140  1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
141  3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
142  5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f
143  },
144  {
145  1.0f, 1.0f, 5.0f, 5.0f
146  });
147 }
148 
149 struct StridedSliceMaskFixture : StridedSliceFixture
150 {
151  StridedSliceMaskFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]",
152  "[ 1, 1, 1, 1 ]",
153  "[ 1, 1, 1, 1 ]",
154  "[ 1, 1, 1, 1 ]",
155  "15",
156  "15",
157  "0",
158  "0",
159  "0",
160  "NCHW",
161  "[ 3, 2, 3, 1 ]",
162  "Float32") {}
163 };
164 
165 TEST_CASE_FIXTURE(StridedSliceMaskFixture, "StridedSliceMaskFloat32")
166 {
167  RunTest<4, armnn::DataType::Float32>(0,
168  {
169  1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
170  3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
171  5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f
172  },
173  {
174  1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
175  3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
176  5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f
177  });
178 }
179 
180 }
void SetupSingleInputSingleOutput(const std::string &inputName, const std::string &outputName)
TEST_CASE_FIXTURE(ClContextControlFixture, "CopyBetweenNeonAndGpu")
TEST_SUITE("Deserializer_StridedSlice")