ArmNN
 20.02
Squeeze.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <boost/test/unit_test.hpp>
9 
10 BOOST_AUTO_TEST_SUITE(TensorflowParser)
11 
12 template <bool withDimZero, bool withDimOne>
13 struct SqueezeFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
14 {
15  SqueezeFixture()
16  {
17  m_Prototext =
18  "node { \n"
19  " name: \"graphInput\" \n"
20  " op: \"Placeholder\" \n"
21  " attr { \n"
22  " key: \"dtype\" \n"
23  " value { \n"
24  " type: DT_FLOAT \n"
25  " } \n"
26  " } \n"
27  " attr { \n"
28  " key: \"shape\" \n"
29  " value { \n"
30  " shape { \n"
31  " } \n"
32  " } \n"
33  " } \n"
34  " } \n"
35  "node { \n"
36  " name: \"Squeeze\" \n"
37  " op: \"Squeeze\" \n"
38  " input: \"graphInput\" \n"
39  " attr { \n"
40  " key: \"T\" \n"
41  " value { \n"
42  " type: DT_FLOAT \n"
43  " } \n"
44  " } \n"
45  " attr { \n"
46  " key: \"squeeze_dims\" \n"
47  " value { \n"
48  " list {\n";
49 
50  if (withDimZero)
51  {
52  m_Prototext += "i:0\n";
53  }
54 
55  if (withDimOne)
56  {
57  m_Prototext += "i:1\n";
58  }
59 
60  m_Prototext +=
61  " } \n"
62  " } \n"
63  " } \n"
64  "} \n";
65 
66  SetupSingleInputSingleOutput({ 1, 1, 2, 2 }, "graphInput", "Squeeze");
67  }
68 };
69 
70 typedef SqueezeFixture<false, false> ImpliedDimensionsSqueezeFixture;
71 typedef SqueezeFixture<true, false> ExplicitDimensionZeroSqueezeFixture;
72 typedef SqueezeFixture<false, true> ExplicitDimensionOneSqueezeFixture;
73 typedef SqueezeFixture<true, true> ExplicitDimensionsSqueezeFixture;
74 
75 BOOST_FIXTURE_TEST_CASE(ParseImplicitSqueeze, ImpliedDimensionsSqueezeFixture)
76 {
77  BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("Squeeze").second.GetShape() ==
78  armnn::TensorShape({2,2})));
79  RunTest<2>({ 1.0f, 2.0f, 3.0f, 4.0f },
80  { 1.0f, 2.0f, 3.0f, 4.0f });
81 }
82 
83 BOOST_FIXTURE_TEST_CASE(ParseDimensionZeroSqueeze, ExplicitDimensionZeroSqueezeFixture)
84 {
85  BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("Squeeze").second.GetShape() ==
86  armnn::TensorShape({1,2,2})));
87  RunTest<3>({ 1.0f, 2.0f, 3.0f, 4.0f },
88  { 1.0f, 2.0f, 3.0f, 4.0f });
89 }
90 
91 BOOST_FIXTURE_TEST_CASE(ParseDimensionOneSqueeze, ExplicitDimensionOneSqueezeFixture)
92 {
93  BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("Squeeze").second.GetShape() ==
94  armnn::TensorShape({1,2,2})));
95  RunTest<3>({ 1.0f, 2.0f, 3.0f, 4.0f },
96  { 1.0f, 2.0f, 3.0f, 4.0f });
97 }
98 
99 BOOST_FIXTURE_TEST_CASE(ParseExplicitDimensionsSqueeze, ExplicitDimensionsSqueezeFixture)
100 {
101  BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("Squeeze").second.GetShape() ==
102  armnn::TensorShape({2,2})));
103  RunTest<2>({ 1.0f, 2.0f, 3.0f, 4.0f },
104  { 1.0f, 2.0f, 3.0f, 4.0f });
105 }
106 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
SqueezeFixture< true, true > ExplicitDimensionsSqueezeFixture
Definition: Squeeze.cpp:73
BOOST_FIXTURE_TEST_CASE(ParseSqueezeWithSqueezeDims, SqueezeFixtureWithSqueezeDims)
Definition: Squeeze.cpp:85
SqueezeFixture< false, false > ImpliedDimensionsSqueezeFixture
Definition: Squeeze.cpp:70
SqueezeFixture< true, false > ExplicitDimensionZeroSqueezeFixture
Definition: Squeeze.cpp:71
SqueezeFixture< false, true > ExplicitDimensionOneSqueezeFixture
Definition: Squeeze.cpp:72
BOOST_AUTO_TEST_SUITE_END()
void SetupSingleInputSingleOutput(const std::string &inputName, const std::string &outputName)
Parses and loads the network defined by the m_Prototext string.