ArmNN
 22.05.01
Gather.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
9 
10 TEST_SUITE("OnnxParser_Gather")
11 {
12 
13 struct GatherMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
14 {
15  GatherMainFixture(const std::vector<int>& indicesShape,
16  const std::vector<int>& indices,
17  const std::vector<int>& inputShape,
18  const std::vector<int>& outputShape)
19  {
20  m_Prototext = R"(
21  ir_version: 8
22  producer_name: "onnx-example"
23  graph {
24  node {
25  output: "indices"
26  op_type: "Constant"
27  attribute {
28  name: "value"
29  t {
30  data_type: 7
31  )" + ConstructIndicesString(indicesShape, indices) + R"(
32  name: "value"
33  }
34  type: TENSOR
35  }
36  }
37  node {
38  input: "input"
39  input: "indices"
40  output: "output"
41  op_type: "Gather"
42  attribute {
43  name: "axis"
44  i: 0
45  type: INT
46  }
47  }
48  name: "gather-model"
49  input {
50  name: "input"
51  type {
52  tensor_type {
53  elem_type: 1
54  shape {
55  )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"(
56  }
57  }
58  }
59  }
60  output {
61  name: "output"
62  type {
63  tensor_type {
64  elem_type: 1
65  shape {
66  )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"(
67  }
68  }
69  }
70  }
71  })";
72  }
73  std::string ConstructIndicesString(const std::vector<int>& indicesShape, const std::vector<int>& indices)
74  {
75  std::string shapeStr;
76  for (int i : indicesShape)
77  {
78  shapeStr = fmt::format(" {} dims: {}", shapeStr, i);
79  }
80  for (int i : indices)
81  {
82  shapeStr = fmt::format(" {} int64_data: {}", shapeStr, i);
83  }
84  return shapeStr;
85  }
86 };
87 
88 struct GatherScalarFixture : GatherMainFixture
89 {
90  GatherScalarFixture() : GatherMainFixture({ }, { 0 }, { 8 }, { })
91  {
92  Setup();
93  }
94 };
95 
96 struct Gather1dFixture : GatherMainFixture
97 {
98  Gather1dFixture() : GatherMainFixture({ 4 }, { 0, 2, 1, 5 }, { 8 }, { 4 })
99  {
100  Setup();
101  }
102 };
103 
104 struct Gather2dFixture : GatherMainFixture
105 {
106  Gather2dFixture() : GatherMainFixture({ 3 }, { 1, 3, 4 }, { 5, 2 }, { 3, 2 })
107  {
108  Setup();
109  }
110 };
111 
112 struct Gather3dMultiIndicesFixture : GatherMainFixture
113 {
114  Gather3dMultiIndicesFixture() : GatherMainFixture({ 2, 3 }, { 1, 2, 1, 2, 1, 0 }, { 3, 2, 3 }, { 2, 3, 2, 3 })
115  {
116  Setup();
117  }
118 };
119 
120 struct Gather4dFixture : GatherMainFixture
121 {
122  Gather4dFixture() : GatherMainFixture({ 3 }, { 0, 1, 3 }, { 5, 4, 3, 2 }, { 3, 4, 3, 2 })
123  {
124  Setup();
125  }
126 };
127 
128 TEST_CASE_FIXTURE(GatherScalarFixture, "GatherScalarTest")
129 {
130  RunTest<1, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }}},
131  {{"output", { 1.0f }}});
132 }
133 
134 TEST_CASE_FIXTURE(Gather1dFixture, "Gather1dTest")
135 {
136  RunTest<1, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }}},
137  {{"output", { 1.0f, 3.0f, 2.0f, 6.0f }}});
138 }
139 
140 TEST_CASE_FIXTURE(Gather2dFixture, "Gather2dTest")
141 {
142  RunTest<2, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}},
143  {{"output", { 3.0f, 4.0f, 7.0f, 8.0f, 9.0f, 10.0f }}});
144 }
145 
146 TEST_CASE_FIXTURE(Gather3dMultiIndicesFixture, "Gather3dMultiIndicesTest")
147 {
148  RunTest<3, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
149  7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f,
150  13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f }}},
151  {{"output", { 7.0f, 8.0f, 9.0f,
152  10.0f, 11.0f, 12.0f,
153  13.0f, 14.0f, 15.0f,
154  16.0f, 17.0f, 18.0f,
155  7.0f, 8.0f, 9.0f,
156  10.0f, 11.0f, 12.0f,
157  13.0f, 14.0f, 15.0f,
158  16.0f, 17.0f, 18.0f,
159  7.0f, 8.0f, 9.0f,
160  10.0f, 11.0f, 12.0f,
161  1.0f, 2.0f, 3.0f,
162  4.0f, 5.0f, 6.0f }}});
163 }
164 
165 TEST_CASE_FIXTURE(Gather4dFixture, "Gather4dTest")
166 {
167  RunTest<4, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
168  6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
169  11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
170  16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
171  21.0f, 22.0f, 23.0f, 24.0f, 25.0f,
172  26.0f, 27.0f, 28.0f, 29.0f, 30.0f,
173  31.0f, 32.0f, 33.0f, 34.0f, 35.0f,
174  36.0f, 37.0f, 38.0f, 39.0f, 40.0f,
175  41.0f, 42.0f, 43.0f, 44.0f, 45.0f,
176  46.0f, 47.0f, 48.0f, 49.0f, 50.0f,
177  51.0f, 52.0f, 53.0f, 54.0f, 55.0f,
178  56.0f, 57.0f, 58.0f, 59.0f, 60.0f,
179  61.0f, 62.0f, 63.0f, 64.0f, 65.0f,
180  66.0f, 67.0f, 68.0f, 69.0f, 70.0f,
181  71.0f, 72.0f, 73.0f, 74.0f, 75.0f,
182  76.0f, 77.0f, 78.0f, 79.0f, 80.0f,
183  81.0f, 82.0f, 83.0f, 84.0f, 85.0f,
184  86.0f, 87.0f, 88.0f, 89.0f, 90.0f,
185  91.0f, 92.0f, 93.0f, 94.0f, 95.0f,
186  96.0f, 97.0f, 98.0f, 99.0f, 100.0f,
187  101.0f, 102.0f, 103.0f, 104.0f, 105.0f,
188  106.0f, 107.0f, 108.0f, 109.0f, 110.0f,
189  111.0f, 112.0f, 113.0f, 114.0f, 115.0f,
190  116.0f, 117.0f, 118.0f, 119.0f, 120.0f }}},
191  {{"output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
192  7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f,
193  13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f,
194  19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f,
195  25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f,
196  31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f,
197  37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f,
198  43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f,
199  73.0f, 74.0f, 75.0f, 76.0f, 77.0f, 78.0f,
200  79.0f, 80.0f, 81.0f, 82.0f, 83.0f, 84.0f,
201  85.0f, 86.0f, 87.0f, 88.0f, 89.0f, 90.0f,
202  91.0f, 92.0f, 93.0f, 94.0f, 95.0f, 96.0f }}});
203 }
204 
205 struct GatherRawDataFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
206 {
207  GatherRawDataFixture()
208  {
209  m_Prototext = R"(
210  ir_version: 8
211  producer_name: "onnx-example"
212  graph {
213  node {
214  output: "indices"
215  op_type: "Constant"
216  attribute {
217  name: "value"
218  t {
219  dims: 3
220  data_type: 7
221  raw_data:
222  "\000\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000\003\000\000\000\000\000\000\000"
223  name: "value"
224  }
225  type: TENSOR
226  }
227  }
228  node {
229  input: "input"
230  input: "indices"
231  output: "output"
232  op_type: "Gather"
233  attribute {
234  name: "axis"
235  i: 0
236  type: INT
237  }
238  }
239  name: "gather-model"
240  input {
241  name: "input"
242  type {
243  tensor_type {
244  elem_type: 1
245  shape {
246  dim {
247  dim_value: 5
248  }
249  dim {
250  dim_value: 4
251  }
252  dim {
253  dim_value: 3
254  }
255  dim {
256  dim_value: 2
257  }
258  }
259  }
260  }
261  }
262  output {
263  name: "output"
264  type {
265  tensor_type {
266  elem_type: 1
267  shape {
268  dim {
269  dim_value: 3
270  }
271  dim {
272  dim_value: 4
273  }
274  dim {
275  dim_value: 3
276  }
277  dim {
278  dim_value: 2
279  }
280  }
281  }
282  }
283  }
284  })";
285  Setup();
286  }
287 };
288 
289 TEST_CASE_FIXTURE(GatherRawDataFixture, "GatherRawDataTest")
290 {
291  RunTest<4, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
292  6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
293  11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
294  16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
295  21.0f, 22.0f, 23.0f, 24.0f, 25.0f,
296  26.0f, 27.0f, 28.0f, 29.0f, 30.0f,
297  31.0f, 32.0f, 33.0f, 34.0f, 35.0f,
298  36.0f, 37.0f, 38.0f, 39.0f, 40.0f,
299  41.0f, 42.0f, 43.0f, 44.0f, 45.0f,
300  46.0f, 47.0f, 48.0f, 49.0f, 50.0f,
301  51.0f, 52.0f, 53.0f, 54.0f, 55.0f,
302  56.0f, 57.0f, 58.0f, 59.0f, 60.0f,
303  61.0f, 62.0f, 63.0f, 64.0f, 65.0f,
304  66.0f, 67.0f, 68.0f, 69.0f, 70.0f,
305  71.0f, 72.0f, 73.0f, 74.0f, 75.0f,
306  76.0f, 77.0f, 78.0f, 79.0f, 80.0f,
307  81.0f, 82.0f, 83.0f, 84.0f, 85.0f,
308  86.0f, 87.0f, 88.0f, 89.0f, 90.0f,
309  91.0f, 92.0f, 93.0f, 94.0f, 95.0f,
310  96.0f, 97.0f, 98.0f, 99.0f, 100.0f,
311  101.0f, 102.0f, 103.0f, 104.0f, 105.0f,
312  106.0f, 107.0f, 108.0f, 109.0f, 110.0f,
313  111.0f, 112.0f, 113.0f, 114.0f, 115.0f,
314  116.0f, 117.0f, 118.0f, 119.0f, 120.0f }}},
315  {{"output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
316  7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f,
317  13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f,
318  19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f,
319  25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f,
320  31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f,
321  37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f,
322  43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f,
323  73.0f, 74.0f, 75.0f, 76.0f, 77.0f, 78.0f,
324  79.0f, 80.0f, 81.0f, 82.0f, 83.0f, 84.0f,
325  85.0f, 86.0f, 87.0f, 88.0f, 89.0f, 90.0f,
326  91.0f, 92.0f, 93.0f, 94.0f, 95.0f, 96.0f }}});
327 }
328 
329 }
std::string ConstructTensorShapeString(const std::vector< int > &shape)
TEST_SUITE("OnnxParser_Gather")
Definition: Gather.cpp:10
TEST_CASE_FIXTURE(ClContextControlFixture, "CopyBetweenNeonAndGpu")