ArmNN
 20.02
ExpandDims.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <boost/test/unit_test.hpp>
9 
10 BOOST_AUTO_TEST_SUITE(TensorflowParser)
11 
12 struct ExpandDimsFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
13 {
14  ExpandDimsFixture(const std::string& expandDim)
15  {
16  m_Prototext =
17  "node { \n"
18  " name: \"graphInput\" \n"
19  " op: \"Placeholder\" \n"
20  " attr { \n"
21  " key: \"dtype\" \n"
22  " value { \n"
23  " type: DT_FLOAT \n"
24  " } \n"
25  " } \n"
26  " attr { \n"
27  " key: \"shape\" \n"
28  " value { \n"
29  " shape { \n"
30  " } \n"
31  " } \n"
32  " } \n"
33  " } \n"
34  "node { \n"
35  " name: \"ExpandDims\" \n"
36  " op: \"ExpandDims\" \n"
37  " input: \"graphInput\" \n"
38  " attr { \n"
39  " key: \"T\" \n"
40  " value { \n"
41  " type: DT_FLOAT \n"
42  " } \n"
43  " } \n"
44  " attr { \n"
45  " key: \"Tdim\" \n"
46  " value { \n";
47  m_Prototext += "i:" + expandDim;
48  m_Prototext +=
49  " } \n"
50  " } \n"
51  "} \n";
52 
53  SetupSingleInputSingleOutput({ 2, 3, 5 }, "graphInput", "ExpandDims");
54  }
55 };
56 
57 struct ExpandZeroDim : ExpandDimsFixture
58 {
59  ExpandZeroDim() : ExpandDimsFixture("0") {}
60 };
61 
62 struct ExpandTwoDim : ExpandDimsFixture
63 {
64  ExpandTwoDim() : ExpandDimsFixture("2") {}
65 };
66 
67 struct ExpandThreeDim : ExpandDimsFixture
68 {
69  ExpandThreeDim() : ExpandDimsFixture("3") {}
70 };
71 
72 struct ExpandMinusOneDim : ExpandDimsFixture
73 {
74  ExpandMinusOneDim() : ExpandDimsFixture("-1") {}
75 };
76 
77 struct ExpandMinusThreeDim : ExpandDimsFixture
78 {
79  ExpandMinusThreeDim() : ExpandDimsFixture("-3") {}
80 };
81 
82 BOOST_FIXTURE_TEST_CASE(ParseExpandZeroDim, ExpandZeroDim)
83 {
84  BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
85  armnn::TensorShape({1, 2, 3, 5})));
86 }
87 
88 BOOST_FIXTURE_TEST_CASE(ParseExpandTwoDim, ExpandTwoDim)
89 {
90  BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
91  armnn::TensorShape({2, 3, 1, 5})));
92 }
93 
94 BOOST_FIXTURE_TEST_CASE(ParseExpandThreeDim, ExpandThreeDim)
95 {
96  BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
97  armnn::TensorShape({2, 3, 5, 1})));
98 }
99 
100 BOOST_FIXTURE_TEST_CASE(ParseExpandMinusOneDim, ExpandMinusOneDim)
101 {
102  BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
103  armnn::TensorShape({2, 3, 5, 1})));
104 }
105 
106 BOOST_FIXTURE_TEST_CASE(ParseExpandMinusThreeDim, ExpandMinusThreeDim)
107 {
108  BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
109  armnn::TensorShape({2, 1, 3, 5})));
110 }
111 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
BOOST_AUTO_TEST_SUITE_END()
void SetupSingleInputSingleOutput(const std::string &inputName, const std::string &outputName)
Parses and loads the network defined by the m_Prototext string.
BOOST_FIXTURE_TEST_CASE(ParseExpandZeroDim, ExpandZeroDim)
Definition: ExpandDims.cpp:82