ArmNN
 21.02
Concat.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <boost/test/unit_test.hpp>
9 
// Opens the Boost.Test suite named "TensorflowParser"; the fixture test cases
// that follow in this file are declared inside it.
10 BOOST_AUTO_TEST_SUITE(TensorflowParser)
11 
12 struct ConcatFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
13 {
14  explicit ConcatFixture(const armnn::TensorShape& inputShape0, const armnn::TensorShape& inputShape1,
15  unsigned int concatDim)
16  {
17  m_Prototext = R"(
18  node {
19  name: "graphInput0"
20  op: "Placeholder"
21  attr {
22  key: "dtype"
23  value {
24  type: DT_FLOAT
25  }
26  }
27  attr {
28  key: "shape"
29  value {
30  shape {
31  }
32  }
33  }
34  }
35  node {
36  name: "graphInput1"
37  op: "Placeholder"
38  attr {
39  key: "dtype"
40  value {
41  type: DT_FLOAT
42  }
43  }
44  attr {
45  key: "shape"
46  value {
47  shape {
48  }
49  }
50  }
51  }
52  node {
53  name: "concat/axis"
54  op: "Const"
55  attr {
56  key: "dtype"
57  value {
58  type: DT_INT32
59  }
60  }
61  attr {
62  key: "value"
63  value {
64  tensor {
65  dtype: DT_INT32
66  tensor_shape {
67  }
68  int_val: )";
69 
70  m_Prototext += std::to_string(concatDim);
71 
72  m_Prototext += R"(
73  }
74  }
75  }
76  }
77  node {
78  name: "concat"
79  op: "ConcatV2"
80  input: "graphInput0"
81  input: "graphInput1"
82  input: "concat/axis"
83  attr {
84  key: "N"
85  value {
86  i: 2
87  }
88  }
89  attr {
90  key: "T"
91  value {
92  type: DT_FLOAT
93  }
94  }
95  attr {
96  key: "Tidx"
97  value {
98  type: DT_FLOAT
99  }
100  }
101  }
102  )";
103 
104  Setup({{"graphInput0", inputShape0 },
105  {"graphInput1", inputShape1 }}, {"concat"});
106  }
107 };
108 
109 struct ConcatFixtureNCHW : ConcatFixture
110 {
111  ConcatFixtureNCHW() : ConcatFixture({ 1, 1, 2, 2 }, { 1, 1, 2, 2 }, 1 ) {}
112 };
113 
114 struct ConcatFixtureNHWC : ConcatFixture
115 {
116  ConcatFixtureNHWC() : ConcatFixture({ 1, 1, 2, 2 }, { 1, 1, 2, 2 }, 3 ) {}
117 };
118 
119 BOOST_FIXTURE_TEST_CASE(ParseConcatNCHW, ConcatFixtureNCHW)
120 {
121  RunTest<4>({{"graphInput0", {0.0, 1.0, 2.0, 3.0}},
122  {"graphInput1", {4.0, 5.0, 6.0, 7.0}}},
123  {{"concat", { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0 }}});
124 }
125 
126 BOOST_FIXTURE_TEST_CASE(ParseConcatNHWC, ConcatFixtureNHWC)
127 {
128  RunTest<4>({{"graphInput0", {0.0, 1.0, 2.0, 3.0}},
129  {"graphInput1", {4.0, 5.0, 6.0, 7.0}}},
130  {{"concat", { 0.0, 1.0, 4.0, 5.0, 2.0, 3.0, 6.0, 7.0 }}});
131 }
132 
133 struct ConcatFixtureDim1 : ConcatFixture
134 {
135  ConcatFixtureDim1() : ConcatFixture({ 1, 2, 3, 4 }, { 1, 2, 3, 4 }, 1) {}
136 };
137 
138 struct ConcatFixtureDim3 : ConcatFixture
139 {
140  ConcatFixtureDim3() : ConcatFixture({ 1, 2, 3, 4 }, { 1, 2, 3, 4 }, 3) {}
141 };
142 
143 BOOST_FIXTURE_TEST_CASE(ParseConcatDim1, ConcatFixtureDim1)
144 {
145  RunTest<4>({ { "graphInput0", { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0,
146  12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0 } },
147  { "graphInput1", { 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0, 61.0,
148  62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0 } } },
149  { { "concat", { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0,
150  12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0,
151  50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0, 61.0,
152  62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0 } } });
153 }
154 
155 BOOST_FIXTURE_TEST_CASE(ParseConcatDim3, ConcatFixtureDim3)
156 {
157  RunTest<4>({ { "graphInput0", { 0.0, 1.0, 2.0, 3.0,
158  4.0, 5.0, 6.0, 7.0,
159  8.0, 9.0, 10.0, 11.0,
160  12.0, 13.0, 14.0, 15.0,
161  16.0, 17.0, 18.0, 19.0,
162  20.0, 21.0, 22.0, 23.0 } },
163  { "graphInput1", { 50.0, 51.0, 52.0, 53.0,
164  54.0, 55.0, 56.0, 57.0,
165  58.0, 59.0, 60.0, 61.0,
166  62.0, 63.0, 64.0, 65.0,
167  66.0, 67.0, 68.0, 69.0,
168  70.0, 71.0, 72.0, 73.0 } } },
169  { { "concat", { 0.0, 1.0, 2.0, 3.0,
170  50.0, 51.0, 52.0, 53.0,
171  4.0, 5.0, 6.0, 7.0,
172  54.0, 55.0, 56.0, 57.0,
173  8.0, 9.0, 10.0, 11.0,
174  58.0, 59.0, 60.0, 61.0,
175  12.0, 13.0, 14.0, 15.0,
176  62.0, 63.0, 64.0, 65.0,
177  16.0, 17.0, 18.0, 19.0,
178  66.0, 67.0, 68.0, 69.0,
179  20.0, 21.0, 22.0, 23.0,
180  70.0, 71.0, 72.0, 73.0 } } });
181 }
182 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
BOOST_FIXTURE_TEST_CASE(ParseConcatNCHW, ConcatFixtureNCHW)
Definition: Concat.cpp:119
BOOST_AUTO_TEST_SUITE_END()