ArmNN 21.02
ExpandDims.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <boost/test/unit_test.hpp>
9 
10 BOOST_AUTO_TEST_SUITE(TensorflowParser)
11 
12 struct ExpandDimsFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
13 {
14  ExpandDimsFixture(const std::string& expandDim)
15  {
16  m_Prototext =
17  "node { \n"
18  " name: \"graphInput\" \n"
19  " op: \"Placeholder\" \n"
20  " attr { \n"
21  " key: \"dtype\" \n"
22  " value { \n"
23  " type: DT_FLOAT \n"
24  " } \n"
25  " } \n"
26  " attr { \n"
27  " key: \"shape\" \n"
28  " value { \n"
29  " shape { \n"
30  " } \n"
31  " } \n"
32  " } \n"
33  " } \n"
34  "node { \n"
35  " name: \"ExpandDims\" \n"
36  " op: \"ExpandDims\" \n"
37  " input: \"graphInput\" \n"
38  " attr { \n"
39  " key: \"T\" \n"
40  " value { \n"
41  " type: DT_FLOAT \n"
42  " } \n"
43  " } \n"
44  " attr { \n"
45  " key: \"Tdim\" \n"
46  " value { \n";
47  m_Prototext += "i:" + expandDim;
48  m_Prototext +=
49  " } \n"
50  " } \n"
51  "} \n";
52 
53  SetupSingleInputSingleOutput({ 2, 3, 5 }, "graphInput", "ExpandDims");
54  }
55 };
56 
57 struct ExpandZeroDim : ExpandDimsFixture
58 {
59  ExpandZeroDim() : ExpandDimsFixture("0") {}
60 };
61 
62 struct ExpandTwoDim : ExpandDimsFixture
63 {
64  ExpandTwoDim() : ExpandDimsFixture("2") {}
65 };
66 
67 struct ExpandThreeDim : ExpandDimsFixture
68 {
69  ExpandThreeDim() : ExpandDimsFixture("3") {}
70 };
71 
72 struct ExpandMinusOneDim : ExpandDimsFixture
73 {
74  ExpandMinusOneDim() : ExpandDimsFixture("-1") {}
75 };
76 
77 struct ExpandMinusThreeDim : ExpandDimsFixture
78 {
79  ExpandMinusThreeDim() : ExpandDimsFixture("-3") {}
80 };
81 
82 BOOST_FIXTURE_TEST_CASE(ParseExpandZeroDim, ExpandZeroDim)
83 {
84  BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
85  armnn::TensorShape({1, 2, 3, 5})));
86 }
87 
88 BOOST_FIXTURE_TEST_CASE(ParseExpandTwoDim, ExpandTwoDim)
89 {
90  BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
91  armnn::TensorShape({2, 3, 1, 5})));
92 }
93 
94 BOOST_FIXTURE_TEST_CASE(ParseExpandThreeDim, ExpandThreeDim)
95 {
96  BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
97  armnn::TensorShape({2, 3, 5, 1})));
98 }
99 
100 BOOST_FIXTURE_TEST_CASE(ParseExpandMinusOneDim, ExpandMinusOneDim)
101 {
102  BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
103  armnn::TensorShape({2, 3, 5, 1})));
104 }
105 
106 BOOST_FIXTURE_TEST_CASE(ParseExpandMinusThreeDim, ExpandMinusThreeDim)
107 {
108  BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
109  armnn::TensorShape({2, 1, 3, 5})));
110 }
111 
112 struct ExpandDimsAsInputFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
113 {
114  ExpandDimsAsInputFixture(const std::string& expandDim,
115  const bool wrongDataType = false,
116  const std::string& numElements = "1")
117  {
118  std::string dataType = (wrongDataType) ? "DT_FLOAT" : "DT_INT32";
119  std::string val = (wrongDataType) ? ("float_val: " + expandDim + ".0") : ("int_val: "+ expandDim);
120 
121  m_Prototext = R"(
122  node {
123  name: "a"
124  op: "Placeholder"
125  attr {
126  key: "dtype"
127  value {
128  type: DT_FLOAT
129  }
130  }
131  attr {
132  key: "shape"
133  value {
134  shape {
135  dim {
136  size: 1
137  }
138  dim {
139  size: 4
140  }
141  }
142  }
143  }
144  }
145  node {
146  name: "b"
147  op: "Const"
148  attr {
149  key: "dtype"
150  value {
151  type: )" + dataType + R"(
152  }
153  }
154  attr {
155  key: "value"
156  value {
157  tensor {
158  dtype: )" + dataType + R"(
159  tensor_shape {
160  dim {
161  size: )" + numElements + R"(
162  }
163  }
164  )" + val + R"(
165  }
166  }
167  }
168  }
169  node {
170  name: "ExpandDims"
171  op: "ExpandDims"
172  input: "a"
173  input: "b"
174  attr {
175  key: "T"
176  value {
177  type: DT_FLOAT
178  }
179  }
180  attr {
181  key: "Tdim"
182  value {
183  type: DT_INT32
184  }
185  }
186  }
187  versions {
188  producer: 134
189  })";
190  }
191 };
192 
193 struct ExpandDimAsInput : ExpandDimsAsInputFixture
194 {
195  ExpandDimAsInput() : ExpandDimsAsInputFixture("0")
196  {
197  Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" });
198  }
199 };
200 
201 
202 BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsInput, ExpandDimAsInput)
203 {
204  // Axis parameter that describes which axis/dim should be expanded is passed as a second input
205  BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
206  armnn::TensorShape({1, 1, 4})));
207 }
208 
209 struct ExpandDimAsInputWrongDataType : ExpandDimsAsInputFixture
210 {
211  ExpandDimAsInputWrongDataType() : ExpandDimsAsInputFixture("0", true, "1") {}
212 };
213 
214 BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsInputWrongDataType, ExpandDimAsInputWrongDataType)
215 {
216  // Axis parameter that describes which axis/dim should be expanded is passed as a second input
217  // Axis parameter is of wrong data type (float instead of int32)
218  BOOST_REQUIRE_THROW(Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" }), armnn::ParseException);
219 }
220 
221 struct ExpandDimAsInputWrongShape : ExpandDimsAsInputFixture
222 {
223  ExpandDimAsInputWrongShape() : ExpandDimsAsInputFixture("0", false, "2") {}
224 };
225 
226 BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsInputWrongShape, ExpandDimAsInputWrongShape)
227 {
228  // Axis parameter that describes which axis/dim should be expanded is passed as a second input
229  // Axis parameter is of wrong shape
230  BOOST_REQUIRE_THROW(Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" }), armnn::ParseException);
231 }
232 
// Fixture in which the ExpandDims axis arrives on the second input "b", but "b" is a DT_INT32
// Placeholder rather than a Const node. The constructor only stores the prototext; Setup() is
// invoked by the test case, which expects it to throw because a non-constant axis is not
// supported.
233 struct ExpandDimsAsNotConstInputFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
234 {
// Only populates m_Prototext; no network is built here.
235  ExpandDimsAsNotConstInputFixture()
236  {
// Graph: Placeholder "a" (DT_FLOAT, shape [1, 4]) and Placeholder "b" (DT_INT32, shape [1])
// feed ExpandDims as its first and second inputs respectively.
237  m_Prototext = R"(
238  node {
239  name: "a"
240  op: "Placeholder"
241  attr {
242  key: "dtype"
243  value {
244  type: DT_FLOAT
245  }
246  }
247  attr {
248  key: "shape"
249  value {
250  shape {
251  dim {
252  size: 1
253  }
254  dim {
255  size: 4
256  }
257  }
258  }
259  }
260  }
261  node {
262  name: "b"
263  op: "Placeholder"
264  attr {
265  key: "dtype"
266  value {
267  type: DT_INT32
268  }
269  }
270  attr {
271  key: "shape"
272  value {
273  shape {
274  dim {
275  size: 1
276  }
277  }
278  }
279  }
280  }
281  node {
282  name: "ExpandDims"
283  op: "ExpandDims"
284  input: "a"
285  input: "b"
286  attr {
287  key: "T"
288  value {
289  type: DT_FLOAT
290  }
291  }
292  attr {
293  key: "Tdim"
294  value {
295  type: DT_INT32
296  }
297  }
298  }
299  versions {
300  producer: 134
301  })";
302  }
303 };
304 
305 BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsNotConstInput, ExpandDimsAsNotConstInputFixture)
306 {
307  // Axis parameter that describes which axis/dim should be expanded is passed as a second input.
308  // But is not a constant tensor --> not supported
309  BOOST_REQUIRE_THROW(Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" }),
311 }
312 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
BOOST_AUTO_TEST_SUITE_END()
void SetupSingleInputSingleOutput(const std::string &inputName, const std::string &outputName)
Parses and loads the network defined by the m_Prototext string.
BOOST_FIXTURE_TEST_CASE(ParseExpandZeroDim, ExpandZeroDim)
Definition: ExpandDims.cpp:82