ArmNN
 20.02
Transpose.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2020 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
8 
9 #include <boost/test/unit_test.hpp>
10 #include <PrototxtConversions.hpp>
11 
12 BOOST_AUTO_TEST_SUITE(TensorflowParser)
13 
14 namespace
15 {
16  std::string ConvertInt32VectorToOctalString(const std::vector<unsigned int>& data)
17  {
18  std::stringstream ss;
19  ss << "\"";
20  std::for_each(data.begin(), data.end(), [&ss](unsigned int d) {
21  ss << armnnUtils::ConvertInt32ToOctalString(static_cast<int>(d));
22  });
23  ss << "\"";
24  return ss.str();
25  }
26 } // namespace
27 
28 struct TransposeFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
29 {
30  TransposeFixture(const armnn::TensorShape& inputShape,
31  const std::vector<unsigned int>& permuteVectorData)
32  {
34  armnn::TensorShape permuteVectorShape({static_cast<unsigned int>(permuteVectorData.size())});
35 
36  m_Prototext = "node {\n"
37 " name: \"input\"\n"
38 " op: \"Placeholder\"\n"
39 " attr {\n"
40 " key: \"dtype\"\n"
41 " value {\n"
42 " type: DT_FLOAT\n"
43 " }\n"
44 " }\n"
45 " attr {\n"
46 " key: \"shape\"\n"
47 " value {\n"
48 " shape {\n";
49  m_Prototext.append(ConvertTensorShapeToString(inputShape));
50  m_Prototext.append(
51 " }\n"
52 " }\n"
53 " }\n"
54 "}\n"
55 "node {\n"
56 " name: \"transpose/perm\"\n"
57 " op: \"Const\"\n"
58 " attr {\n"
59 " key: \"dtype\"\n"
60 " value {\n"
61 " type: DT_INT32\n"
62 " }\n"
63 " }\n"
64 " attr {\n"
65 " key: \"value\"\n"
66 " value {\n"
67 " tensor {\n"
68 " dtype: DT_INT32\n"
69 " tensor_shape {\n"
70  );
71  m_Prototext.append(ConvertTensorShapeToString(permuteVectorShape));
72  m_Prototext.append(
73 " }\n"
74 " tensor_content: "
75  );
76  m_Prototext.append(ConvertInt32VectorToOctalString(permuteVectorData) + "\n");
77  m_Prototext.append(
78 " }\n"
79 " }\n"
80 " }\n"
81 "}\n"
82  );
83  m_Prototext.append(
84 "node {\n"
85 " name: \"output\"\n"
86 " op: \"Transpose\"\n"
87 " input: \"input\"\n"
88 " input: \"transpose/perm\"\n"
89 " attr {\n"
90 " key: \"T\"\n"
91 " value {\n"
92 " type: DT_FLOAT\n"
93 " }\n"
94 " }\n"
95 " attr {\n"
96 " key: \"Tperm\"\n"
97 " value {\n"
98 " type: DT_INT32\n"
99 " }\n"
100 " }\n"
101 "}\n"
102  );
103  Setup({{"input", inputShape}}, {"output"});
104  }
105 };
106 
107 struct TransposeFixtureWithPermuteData : TransposeFixture
108 {
109  TransposeFixtureWithPermuteData()
110  : TransposeFixture({2, 2, 3, 4},
111  std::vector<unsigned int>({1, 3, 2, 0})) {}
112 };
113 
114 BOOST_FIXTURE_TEST_CASE(TransposeWithPermuteData, TransposeFixtureWithPermuteData)
115 {
116  RunTest<4>(
117  {{"input", {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
118  16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
119  32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}}},
120  {{"output", {0, 24, 4, 28, 8, 32, 1, 25, 5, 29, 9, 33, 2, 26, 6,
121  30, 10, 34, 3, 27, 7, 31, 11, 35, 12, 36, 16, 40, 20, 44, 13, 37,
122  17, 41, 21, 45, 14, 38, 18, 42, 22, 46, 15, 39, 19, 43, 23, 47}}});
123 
124  BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("output").second.GetShape()
125  == armnn::TensorShape({2, 4, 3, 2})));
126 }
127 
128 struct TransposeFixtureWithoutPermuteData : TransposeFixture
129 {
130  // In case permute data is not given, it assumes (n-1,...,0) is given
131  // where n is the rank of input tensor.
132  TransposeFixtureWithoutPermuteData()
133  : TransposeFixture({2, 2, 3, 4},
134  std::vector<unsigned int>({3, 2, 1, 0})) {}
135 };
136 
137 BOOST_FIXTURE_TEST_CASE(TransposeWithoutPermuteData, TransposeFixtureWithoutPermuteData)
138 {
139  RunTest<4>(
140  {{"input", {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
141  16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
142  32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}}},
143  {{"output", {0, 24, 12, 36, 4, 28, 16, 40, 8, 32, 20, 44, 1, 25,
144  13, 37, 5, 29, 17, 41, 9, 33, 21, 45, 2, 26, 14, 38, 6, 30, 18,
145  42,10, 34, 22, 46, 3, 27, 15, 39, 7, 31, 19, 43, 11, 35, 23, 47}}});
146 
147  BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("output").second.GetShape()
148  == armnn::TensorShape({4, 3, 2, 2})));
149 }
150 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
BOOST_FIXTURE_TEST_CASE(TransposeWithPermuteData, TransposeFixtureWithPermuteData)
Definition: Transpose.cpp:121
BOOST_AUTO_TEST_SUITE_END()
std::string ConvertTensorShapeToString(const armnn::TensorShape &shape)
Converts a TensorShape into its Prototxt representation.