ArmNN
 22.05.01
Gemm.cpp File Reference

Go to the source code of this file.

Functions

 TEST_SUITE ("OnnxParser_Gemm")
 

Function Documentation

◆ TEST_SUITE()

TEST_SUITE ( "OnnxParser_Gemm"  )

Definition at line 10 of file Gemm.cpp.

References armnnUtils::ConstructTensorShapeString(), and TEST_CASE_FIXTURE().

11 {
12 
13 struct GemmFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
14 {
15  GemmFixture(const std::string& alpha,
16  const std::string& beta,
17  const std::string& transA,
18  const std::string& transB,
19  const std::vector<int>& inputAShape,
20  const std::vector<int>& inputBShape,
21  const std::vector<int>& inputCShape,
22  const std::vector<int>& outputShape)
23  {
24  m_Prototext = R"(
25  ir_version: 8
26  producer_name: "onnx-example"
27  graph {
28  node {
29  input: "A"
30  input: "B"
31  input: "C"
32  output: "Output"
33  op_type: "Gemm"
34  attribute {
35  name: "alpha"
36  f: )" + alpha + R"(
37  type: FLOAT
38  }
39  attribute {
40  name: "beta"
41  f: )" + beta + R"(
42  type: FLOAT
43  }
44  attribute {
45  name: "transA"
46  i: )" + transA + R"(
47  type: INT
48  }
49  attribute {
50  name: "transB"
51  i: )" + transB + R"(
52  type: INT
53  }
54  }
55  name: "gem-model"
56  input {
57  name: "A"
58  type {
59  tensor_type {
60  elem_type: 1
61  shape {
62  )" + armnnUtils::ConstructTensorShapeString(inputAShape) + R"(
63  }
64  }
65  }
66  }
67  input {
68  name: "B"
69  type {
70  tensor_type {
71  elem_type: 1
72  shape {
73  )" + armnnUtils::ConstructTensorShapeString(inputBShape) + R"(
74  }
75  }
76  }
77  }
78  input {
79  name: "C"
80  type {
81  tensor_type {
82  elem_type: 1
83  shape {
84  )" + armnnUtils::ConstructTensorShapeString(inputCShape) + R"(
85  }
86  }
87  }
88  }
89  output {
90  name: "Output"
91  type {
92  tensor_type {
93  elem_type: 1
94  shape {
95  )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"(
96  }
97  }
98  }
99  }
100  })";
101  }
102 };
103 
104 struct GemmAllAttributesFixture : GemmFixture
105 {
106  GemmAllAttributesFixture() : GemmFixture("0.25", "0.35", "1", "1", { 4, 3 }, { 5, 4 }, { 5 }, { 3, 5 })
107  {
108  Setup();
109  }
110 };
111 
112 struct GemmSimpleFixture : GemmFixture
113 {
114  GemmSimpleFixture() : GemmFixture("1", "1", "0", "0", { 3, 4 }, { 4, 5 }, { 5 }, { 3, 5 })
115  {
116  Setup();
117  }
118 };
119 
120 struct GemmTransAFixture : GemmFixture
121 {
122  GemmTransAFixture() : GemmFixture("1", "1", "1", "0", { 4, 3 }, { 4, 5 }, { 5 }, { 3, 5 })
123  {
124  Setup();
125  }
126 };
127 
128 struct GemmTransBFixture : GemmFixture
129 {
130  GemmTransBFixture() : GemmFixture("1", "1", "0", "1", { 3, 4 }, { 5, 4 }, { 5 }, { 3, 5 })
131  {
132  Setup();
133  }
134 };
135 
136 struct GemmParseExceptionFixture : GemmFixture
137 {
138  GemmParseExceptionFixture() : GemmFixture("1", "1", "0", "1", { 3, 4 }, { 5, 4 }, { 3, 5 }, { 3, 5 }) {}
139 };
140 
141 TEST_CASE_FIXTURE(GemmAllAttributesFixture, "GemmTest")
142 {
143  RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
144  6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
145  {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
146  6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
147  11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
148  16.0f, 17.0f, 18.0f, 19.0f, 20.0f }},
149  {"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}},
150  {{"Output", { 15.035f, 45.07f, 75.105f, 105.14f, 135.175f,
151  12.535f, 38.57f, 64.605f, 90.64f, 116.675f,
152  10.035f, 32.07f, 54.105f, 76.14f, 98.175f }}});
153 }
154 
155 TEST_CASE_FIXTURE(GemmSimpleFixture, "GemmSimpleTest")
156 {
157  RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
158  6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
159  {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
160  6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
161  11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
162  16.0f, 17.0f, 18.0f, 19.0f, 20.0f }},
163  {"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}},
164  {{"Output", { 332.1f, 374.2f, 416.3f, 458.4f, 500.5f,
165  196.1f, 222.2f, 248.3f, 274.4f, 300.5f,
166  60.1f, 70.2f, 80.3f, 90.4f, 100.5f }}});
167 }
168 
169 TEST_CASE_FIXTURE(GemmTransAFixture, "GemmTransposeATest")
170 {
171  RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
172  6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
173  {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
174  6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
175  11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
176  16.0f, 17.0f, 18.0f, 19.0f, 20.0f }},
177  {"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}},
178  {{"Output", { 180.1f, 210.2f, 240.3f, 270.4f, 300.5f,
179  146.1f, 172.2f, 198.3f, 224.4f, 250.5f,
180  112.1f, 134.2f, 156.3f, 178.4f, 200.5f }}});
181 }
182 
183 TEST_CASE_FIXTURE(GemmTransBFixture, "GemmTransposeBTest")
184 {
185  RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
186  6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
187  {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
188  6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
189  11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
190  16.0f, 17.0f, 18.0f, 19.0f, 20.0f }},
191  {"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}},
192  {{"Output", { 100.1f, 268.2f, 436.3f, 604.4f, 772.5f,
193  60.1f, 164.2f, 268.3f, 372.4f, 476.5f,
194  20.1f, 60.2f, 100.3f, 140.4f, 180.5f }}});
195 }
196 
197 TEST_CASE_FIXTURE(GemmParseExceptionFixture, "GemmParseExceptionTest")
198 {
199  // ParseException because Input C is non-constant and has 2 dimension (should be 1 dimension)
200  CHECK_THROWS_AS(Setup(), armnn::ParseException);
201 }
202 
203 struct GemmConstantFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
204 {
205  GemmConstantFixture()
206  {
207  m_Prototext = R"(
208  ir_version: 8
209  producer_name: "onnx-example"
210  graph {
211  node {
212  input: "A"
213  input: "B"
214  input: "C"
215  output: "Output"
216  op_type: "Gemm"
217  attribute {
218  name: "alpha"
219  f: 0.25
220  type: FLOAT
221  }
222  attribute {
223  name: "beta"
224  f: 0.35
225  type: FLOAT
226  }
227  attribute {
228  name: "transA"
229  i: 1
230  type: INT
231  }
232  attribute {
233  name: "transB"
234  i: 1
235  type: INT
236  }
237  }
238  name: "gem-model"
239  initializer {
240  dims: 5
241  dims: 4
242  data_type: 1
243  float_data: 1.0
244  float_data: 2.0
245  float_data: 3.0
246  float_data: 4.0
247  float_data: 5.0
248  float_data: 6.0
249  float_data: 7.0
250  float_data: 8.0
251  float_data: 9.0
252  float_data: 10.0
253  float_data: 11.0
254  float_data: 12.0
255  float_data: 13.0
256  float_data: 14.0
257  float_data: 15.0
258  float_data: 16.0
259  float_data: 17.0
260  float_data: 18.0
261  float_data: 19.0
262  float_data: 20.0
263  name: "B"
264  }
265  initializer {
266  dims: 1
267  dims: 5
268  data_type: 1
269  float_data: 0.1
270  float_data: 0.2
271  float_data: 0.3
272  float_data: 0.4
273  float_data: 0.5
274  name: "C"
275  }
276  input {
277  name: "A"
278  type {
279  tensor_type {
280  elem_type: 1
281  shape {
282  dim {
283  dim_value: 4
284  }
285  dim {
286  dim_value: 3
287  }
288  }
289  }
290  }
291  }
292  output {
293  name: "Output"
294  type {
295  tensor_type {
296  elem_type: 1
297  shape {
298  dim {
299  dim_value: 3
300  }
301  dim {
302  dim_value: 5
303  }
304  }
305  }
306  }
307  }
308  })";
309  Setup();
310  }
311 };
312 
313 TEST_CASE_FIXTURE(GemmConstantFixture, "GemmConstantTest")
314 {
315  RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
316  6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}},
317  {{"Output", { 15.035f, 45.07f, 75.105f, 105.14f, 135.175f,
318  12.535f, 38.57f, 64.605f, 90.64f, 116.675f,
319  10.035f, 32.07f, 54.105f, 76.14f, 98.175f }}});
320 }
321 
322 struct GemmConstantSimpleFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
323 {
324  GemmConstantSimpleFixture()
325  {
326  m_Prototext = R"(
327  ir_version: 8
328  producer_name: "onnx-example"
329  graph {
330  node {
331  input: "A"
332  input: "B"
333  input: "C"
334  output: "Output"
335  op_type: "Gemm"
336  attribute {
337  name: "alpha"
338  f: 1
339  type: FLOAT
340  }
341  attribute {
342  name: "beta"
343  f: 1
344  type: FLOAT
345  }
346  attribute {
347  name: "transA"
348  i: 0
349  type: INT
350  }
351  attribute {
352  name: "transB"
353  i: 0
354  type: INT
355  }
356  }
357  name: "gem-model"
358  initializer {
359  dims: 4
360  dims: 5
361  data_type: 1
362  float_data: 1.0
363  float_data: 2.0
364  float_data: 3.0
365  float_data: 4.0
366  float_data: 5.0
367  float_data: 6.0
368  float_data: 7.0
369  float_data: 8.0
370  float_data: 9.0
371  float_data: 10.0
372  float_data: 11.0
373  float_data: 12.0
374  float_data: 13.0
375  float_data: 14.0
376  float_data: 15.0
377  float_data: 16.0
378  float_data: 17.0
379  float_data: 18.0
380  float_data: 19.0
381  float_data: 20.0
382  name: "B"
383  }
384  initializer {
385  dims: 1
386  dims: 5
387  data_type: 1
388  float_data: 0.1
389  float_data: 0.2
390  float_data: 0.3
391  float_data: 0.4
392  float_data: 0.5
393  name: "C"
394  }
395  input {
396  name: "A"
397  type {
398  tensor_type {
399  elem_type: 1
400  shape {
401  dim {
402  dim_value: 3
403  }
404  dim {
405  dim_value: 4
406  }
407  }
408  }
409  }
410  }
411  output {
412  name: "Output"
413  type {
414  tensor_type {
415  elem_type: 1
416  shape {
417  dim {
418  dim_value: 3
419  }
420  dim {
421  dim_value: 5
422  }
423  }
424  }
425  }
426  }
427  })";
428  Setup();
429  }
430 };
431 
432 TEST_CASE_FIXTURE(GemmConstantSimpleFixture, "GemmConstantSimpleTest")
433 {
434  RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
435  6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}},
436  {{"Output", { 332.1f, 374.2f, 416.3f, 458.4f, 500.5f,
437  196.1f, 222.2f, 248.3f, 274.4f, 300.5f,
438  60.1f, 70.2f, 80.3f, 90.4f, 100.5f }}});
439 }
440 
441 struct GemmABFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
442 {
443  GemmABFixture(const std::string& alpha,
444  const std::string& beta,
445  const std::string& transA,
446  const std::string& transB,
447  const std::vector<int>& inputAShape,
448  const std::vector<int>& inputBShape,
449  const std::vector<int>& outputShape)
450  {
451  m_Prototext = R"(
452  ir_version: 8
453  producer_name: "onnx-example"
454  graph {
455  node {
456  input: "A"
457  input: "B"
458  output: "Output"
459  op_type: "Gemm"
460  attribute {
461  name: "alpha"
462  f: )" + alpha + R"(
463  type: FLOAT
464  }
465  attribute {
466  name: "beta"
467  f: )" + beta + R"(
468  type: FLOAT
469  }
470  attribute {
471  name: "transA"
472  i: )" + transA + R"(
473  type: INT
474  }
475  attribute {
476  name: "transB"
477  i: )" + transB + R"(
478  type: INT
479  }
480  }
481  name: "gem-model"
482  input {
483  name: "A"
484  type {
485  tensor_type {
486  elem_type: 1
487  shape {
488  )" + armnnUtils::ConstructTensorShapeString(inputAShape) + R"(
489  }
490  }
491  }
492  }
493  input {
494  name: "B"
495  type {
496  tensor_type {
497  elem_type: 1
498  shape {
499  )" + armnnUtils::ConstructTensorShapeString(inputBShape) + R"(
500  }
501  }
502  }
503  }
504  output {
505  name: "Output"
506  type {
507  tensor_type {
508  elem_type: 1
509  shape {
510  )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"(
511  }
512  }
513  }
514  }
515  })";
516  Setup();
517  }
518 };
519 
520 struct GemmAlphaTransAFixture : GemmABFixture
521 {
522  GemmAlphaTransAFixture() : GemmABFixture("0.25", "0.35", "1", "0", { 4, 3 }, { 4, 5 }, { 3, 5 }) {}
523 };
524 
525 struct GemmAlphaTransBFixture : GemmABFixture
526 {
527  GemmAlphaTransBFixture() : GemmABFixture("0.25", "0.35", "0", "1", { 3, 4 }, { 5, 4 }, { 3, 5 }) {}
528 };
529 
530 TEST_CASE_FIXTURE(GemmAlphaTransAFixture, "GemmAlphaTransATest")
531 {
532  RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
533  6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
534  {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
535  6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
536  11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
537  16.0f, 17.0f, 18.0f, 19.0f, 20.0f }}},
538  {{"Output", { 45.0f, 52.5f, 60.0f, 67.5f, 75.0f,
539  36.5f, 43.0f, 49.5f, 56.0f, 62.5f,
540  28.0f, 33.5f, 39.0f, 44.5f, 50.0f }}});
541 }
542 
543 TEST_CASE_FIXTURE(GemmAlphaTransBFixture, "GemmAlphaTransBTest")
544 {
545  RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
546  6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
547  {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
548  6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
549  11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
550  16.0f, 17.0f, 18.0f, 19.0f, 20.0f }}},
551  {{"Output", { 25.0f, 67.0f, 109.0f, 151.0f, 193.0f,
552  15.0f, 41.0f, 67.0f, 93.0f, 119.0f,
553  5.0f, 15.0f, 25.0f, 35.0f, 45.0f }}});
554 }
555 
556 }
std::string ConstructTensorShapeString(const std::vector< int > &shape)
TEST_CASE_FIXTURE(ClContextControlFixture, "CopyBetweenNeonAndGpu")