ArmNN
 22.05.01
Concat.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
9 
10 TEST_SUITE("OnnxParser_Concat")
11 {
12 
// Parameterised fixture for the two-input Concat parser tests.
// The constructor assembles an ONNX text-format model containing a single
// "Concat" node (inputs "Input0" and "Input1", output "Output") and stores it
// in m_Prototext, then calls Setup(). Both m_Prototext and Setup() come from
// outside this struct (presumably the ParserPrototxtFixture base - confirm).
13 struct ConcatFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
14 {
    // axis        - value of the node's "axis" attribute, spliced verbatim into
    //               the prototext, so it must already be a textual integer
    //               (the derived fixtures pass "0".."3" and "-1").
    // input0Shape - dimensions declared for "Input0".
    // input1Shape - dimensions declared for "Input1".
    // outputShape - dimensions declared for "Output".
    // Each shape is rendered by armnnUtils::ConstructTensorShapeString.
15  ConcatFixture(const std::string& axis,
16  const std::vector<int>& input0Shape,
17  const std::vector<int>& input1Shape,
18  const std::vector<int>& outputShape)
19  {
    // All three tensors are declared with elem_type 1; the tests in this file
    // feed float data (assumes 1 is the ONNX FLOAT type - confirm).
20  m_Prototext = R"(
21  ir_version: 8
22  producer_name: "onnx-example"
23  graph {
24  node {
25  input: "Input0"
26  input: "Input1"
27  output: "Output"
28  op_type: "Concat"
29  attribute {
30  name: "axis"
31  i: )" + axis + R"(
32  type: INT
33  }
34  }
35  name: "concat-model"
36  input {
37  name: "Input0"
38  type {
39  tensor_type {
40  elem_type: 1
41  shape {
42  )" + armnnUtils::ConstructTensorShapeString(input0Shape) + R"(
43  }
44  }
45  }
46  }
47  input {
48  name: "Input1"
49  type {
50  tensor_type {
51  elem_type: 1
52  shape {
53  )" + armnnUtils::ConstructTensorShapeString(input1Shape) + R"(
54  }
55  }
56  }
57  }
58  output {
59  name: "Output"
60  type {
61  tensor_type {
62  elem_type: 1
63  shape {
64  )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"(
65  }
66  }
67  }
68  }
69  })";
    // Called only after the prototext is complete; presumably parses
    // m_Prototext and prepares the network for RunTest - confirm in the base.
70  Setup();
71  }
72 };
73 
// Concat along axis 0: two { 1, 3, 2, 5 } inputs -> { 2, 3, 2, 5 } output.
74 struct ConcatAxis0Fixture : ConcatFixture
75 {
76  ConcatAxis0Fixture() : ConcatFixture("0", { 1, 3, 2, 5 }, { 1, 3, 2, 5 }, { 2, 3, 2, 5 }) {}
77 };
78 
// Concat along axis 1 with differently sized inputs:
// { 2, 2, 1, 3 } and { 2, 1, 1, 3 } -> { 2, 3, 1, 3 }.
79 struct ConcatAxis1Fixture : ConcatFixture
80 {
81  ConcatAxis1Fixture() : ConcatFixture("1", { 2, 2, 1, 3 }, { 2, 1, 1, 3 }, { 2, 3, 1, 3 }) {}
82 };
83 
// Concat along axis 2 with differently sized inputs:
// { 2, 3, 1, 1 } and { 2, 3, 2, 1 } -> { 2, 3, 3, 1 }.
84 struct ConcatAxis2Fixture : ConcatFixture
85 {
86  ConcatAxis2Fixture() : ConcatFixture("2", { 2, 3, 1, 1 }, { 2, 3, 2, 1 }, { 2, 3, 3, 1 }) {}
87 };
88 
// Concat along axis 3 (the last of four): two { 1, 3, 2, 2 } inputs -> { 1, 3, 2, 4 }.
89 struct ConcatAxis3Fixture : ConcatFixture
90 {
91  ConcatAxis3Fixture() : ConcatFixture("3", { 1, 3, 2, 2 }, { 1, 3, 2, 2 }, { 1, 3, 2, 4 }) {}
92 };
93 
// Concat with a negative axis ("-1") on three-dimensional tensors:
// { 1, 2, 5 } and { 1, 2, 3 } -> { 1, 2, 8 }, i.e. the last dimension grows.
94 struct ConcatNegativeAxisFixture : ConcatFixture
95 {
96  ConcatNegativeAxisFixture() : ConcatFixture("-1", { 1, 2, 5 }, { 1, 2, 3 }, { 1, 2, 8 }) {}
97 };
98 
// Axis-0 concat: each input holds 30 values (1..30 and 31..60) and the
// expected output is simply Input0's values followed by Input1's (1..60).
// RunTest's template arguments are <4, float>; presumably the tensor rank and
// element type - confirm against ParserPrototxtFixture.
99 TEST_CASE_FIXTURE(ConcatAxis0Fixture, "ConcatAxis0Test")
100 {
101  RunTest<4, float>({{"Input0", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
102  6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
103  11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
104  16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
105  21.0f, 22.0f, 23.0f, 24.0f, 25.0f,
106  26.0f, 27.0f, 28.0f, 29.0f, 30.0f }},
107  {"Input1", { 31.0f, 32.0f, 33.0f, 34.0f, 35.0f,
108  36.0f, 37.0f, 38.0f, 39.0f, 40.0f,
109  41.0f, 42.0f, 43.0f, 44.0f, 45.0f,
110  46.0f, 47.0f, 48.0f, 49.0f, 50.0f,
111  51.0f, 52.0f, 53.0f, 54.0f, 55.0f,
112  56.0f, 57.0f, 58.0f, 59.0f, 60.0f }}},
113  {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
114  6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
115  11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
116  16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
117  21.0f, 22.0f, 23.0f, 24.0f, 25.0f,
118  26.0f, 27.0f, 28.0f, 29.0f, 30.0f,
119  31.0f, 32.0f, 33.0f, 34.0f, 35.0f,
120  36.0f, 37.0f, 38.0f, 39.0f, 40.0f,
121  41.0f, 42.0f, 43.0f, 44.0f, 45.0f,
122  46.0f, 47.0f, 48.0f, 49.0f, 50.0f,
123  51.0f, 52.0f, 53.0f, 54.0f, 55.0f,
124  56.0f, 57.0f, 58.0f, 59.0f, 60.0f }}});
125 }
126 
// Axis-1 concat of a 12-value Input0 and a 6-value Input1. In the expected
// output each half starts with six Input0 values followed by three Input1
// values: 1..6 then 13..15, and 7..12 then 16..18.
// The test-case name previously read "ConcatAxis1est"; it is corrected so the
// case matches the "...Test" naming of its siblings and name-based filters.
127 TEST_CASE_FIXTURE(ConcatAxis1Fixture, "ConcatAxis1Test")
128 {
129  RunTest<4, float>({{"Input0", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }},
130  {"Input1", { 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f }}},
131  {{"Output", { 1.0f, 2.0f, 3.0f,
132  4.0f, 5.0f, 6.0f,
133  13.0f, 14.0f, 15.0f,
134  7.0f, 8.0f, 9.0f,
135  10.0f, 11.0f, 12.0f,
136  16.0f, 17.0f, 18.0f }}});
137 }
138 
// Axis-2 concat of a 6-value Input0 and a 12-value Input1. Each expected
// output row is one Input0 value followed by the next two Input1 values
// (1,7,8 / 2,9,10 / ... / 6,17,18).
139 TEST_CASE_FIXTURE(ConcatAxis2Fixture, "ConcatAxis2Test")
140 {
141  RunTest<4, float>({{"Input0", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }},
142  {"Input1", { 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f }}},
143  {{"Output", { 1.0f, 7.0f, 8.0f,
144  2.0f, 9.0f, 10.0f,
145  3.0f, 11.0f, 12.0f,
146  4.0f, 13.0f, 14.0f,
147  5.0f, 15.0f, 16.0f,
148  6.0f, 17.0f, 18.0f }}});
149 }
150 
// Axis-3 concat of two 12-value inputs. Each expected output row of four is
// a pair of Input0 values followed by the matching pair of Input1 values
// (1,2,13,14 / 3,4,15,16 / ... / 11,12,23,24).
151 TEST_CASE_FIXTURE(ConcatAxis3Fixture, "ConcatAxis3Test")
152 {
153  RunTest<4, float>({{"Input0", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
154  7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }},
155  {"Input1", { 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f,
156  19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f }}},
157  {{"Output", { 1.0f, 2.0f, 13.0f, 14.0f,
158  3.0f, 4.0f, 15.0f, 16.0f,
159  5.0f, 6.0f, 17.0f, 18.0f,
160  7.0f, 8.0f, 19.0f, 20.0f,
161  9.0f, 10.0f, 21.0f, 22.0f,
162  11.0f, 12.0f, 23.0f, 24.0f }}});
163 }
164 
// Negative-axis ("-1") concat on three-dimensional tensors, hence
// RunTest<3, float>. Each expected output row of eight is five Input0 values
// followed by three Input1 values.
165 TEST_CASE_FIXTURE(ConcatNegativeAxisFixture, "ConcatNegativeAxisTest")
166 {
167  RunTest<3, float>({{"Input0", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
168  6.0f, 7.0f, 8.0f, 9.0f, 10.0f }},
169  {"Input1", { 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f }}},
170  {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 11.0f, 12.0f, 13.0f,
171  6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 14.0f, 15.0f, 16.0f }}});
172 }
173 
// Fixture for a Concat node with three inputs instead of two.
// The model is fully hard-coded: axis 1, inputs "Input0" (3x2), "Input1" (3x3)
// and "Input2" (3x1), output "Output" (3x6). Shapes are written out as
// explicit dim blocks rather than generated by a helper.
174 struct ConcatMultipleInputsFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
175 {
176  ConcatMultipleInputsFixture()
177  {
    // All tensors are declared with elem_type 1; the test below feeds float
    // data (assumes 1 is the ONNX FLOAT type - confirm).
178  m_Prototext = R"(
179  ir_version: 8
180  producer_name: "onnx-example"
181  graph {
182  node {
183  input: "Input0"
184  input: "Input1"
185  input: "Input2"
186  output: "Output"
187  op_type: "Concat"
188  attribute {
189  name: "axis"
190  i: 1
191  type: INT
192  }
193  }
194  name: "concat-model"
195  input {
196  name: "Input0"
197  type {
198  tensor_type {
199  elem_type: 1
200  shape {
201  dim {
202  dim_value: 3
203  }
204  dim {
205  dim_value: 2
206  }
207  }
208  }
209  }
210  }
211  input {
212  name: "Input1"
213  type {
214  tensor_type {
215  elem_type: 1
216  shape {
217  dim {
218  dim_value: 3
219  }
220  dim {
221  dim_value: 3
222  }
223  }
224  }
225  }
226  }
227  input {
228  name: "Input2"
229  type {
230  tensor_type {
231  elem_type: 1
232  shape {
233  dim {
234  dim_value: 3
235  }
236  dim {
237  dim_value: 1
238  }
239  }
240  }
241  }
242  }
243  output {
244  name: "Output"
245  type {
246  tensor_type {
247  elem_type: 1
248  shape {
249  dim {
250  dim_value: 3
251  }
252  dim {
253  dim_value: 6
254  }
255  }
256  }
257  }
258  }
259  })";
    // Called only after the prototext is complete; presumably parses
    // m_Prototext and prepares the network for RunTest - confirm in the base.
260  Setup();
261  }
262 };
263 
// Three-input concat along axis 1 on two-dimensional tensors, hence
// RunTest<2, float>. Each expected output row of six is two values from
// Input0, three from Input1 and one from Input2.
264 TEST_CASE_FIXTURE(ConcatMultipleInputsFixture, "ConcatMultipleInputsTest")
265 {
266  RunTest<2, float>({{"Input0", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }},
267  {"Input1", { 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f }},
268  {"Input2", { 16.0f, 17.0f, 18.0f }}},
269  {{"Output", { 1.0f, 2.0f, 7.0f, 8.0f, 9.0f, 16.0f,
270  3.0f, 4.0f, 10.0f, 11.0f, 12.0f, 17.0f,
271  5.0f, 6.0f, 13.0f, 14.0f, 15.0f, 18.0f }}});
272 }
273 
274 }
std::string ConstructTensorShapeString(const std::vector< int > &shape)
TEST_SUITE("OnnxParser_Concat")
Definition: Concat.cpp:10
TEST_CASE_FIXTURE(ClContextControlFixture, "CopyBetweenNeonAndGpu")