ArmNN
 20.02
Reshape.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <boost/test/unit_test.hpp>
8 #include "../TfLiteParser.hpp"
9 
10 #include <string>
11 #include <iostream>
12 
13 BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
14 
15 struct ReshapeFixture : public ParserFlatbuffersFixture
16 {
17  explicit ReshapeFixture(const std::string& inputShape,
18  const std::string& outputShape,
19  const std::string& newShape)
20  {
21  m_JsonString = R"(
22  {
23  "version": 3,
24  "operator_codes": [ { "builtin_code": "RESHAPE" } ],
25  "subgraphs": [ {
26  "tensors": [
27  {)";
28  m_JsonString += R"(
29  "shape" : )" + inputShape + ",";
30  m_JsonString += R"(
31  "type": "UINT8",
32  "buffer": 0,
33  "name": "inputTensor",
34  "quantization": {
35  "min": [ 0.0 ],
36  "max": [ 255.0 ],
37  "scale": [ 1.0 ],
38  "zero_point": [ 0 ],
39  }
40  },
41  {)";
42  m_JsonString += R"(
43  "shape" : )" + outputShape;
44  m_JsonString += R"(,
45  "type": "UINT8",
46  "buffer": 1,
47  "name": "outputTensor",
48  "quantization": {
49  "min": [ 0.0 ],
50  "max": [ 255.0 ],
51  "scale": [ 1.0 ],
52  "zero_point": [ 0 ],
53  }
54  }
55  ],
56  "inputs": [ 0 ],
57  "outputs": [ 1 ],
58  "operators": [
59  {
60  "opcode_index": 0,
61  "inputs": [ 0 ],
62  "outputs": [ 1 ],
63  "builtin_options_type": "ReshapeOptions",
64  "builtin_options": {)";
65  if (!newShape.empty())
66  {
67  m_JsonString += R"("new_shape" : )" + newShape;
68  }
69  m_JsonString += R"(},
70  "custom_options_format": "FLEXBUFFERS"
71  }
72  ],
73  } ],
74  "buffers" : [ {}, {} ]
75  }
76  )";
77 
78  }
79 };
80 
81 struct ReshapeFixtureWithReshapeDims : ReshapeFixture
82 {
83  ReshapeFixtureWithReshapeDims() : ReshapeFixture("[ 1, 9 ]", "[ 3, 3 ]", "[ 3, 3 ]") {}
84 };
85 
86 BOOST_FIXTURE_TEST_CASE(ParseReshapeWithReshapeDims, ReshapeFixtureWithReshapeDims)
87 {
88  SetupSingleInputSingleOutput("inputTensor", "outputTensor");
89  RunTest<2, armnn::DataType::QAsymmU8>(0,
90  { 1, 2, 3, 4, 5, 6, 7, 8, 9 },
91  { 1, 2, 3, 4, 5, 6, 7, 8, 9 });
92  BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
93  == armnn::TensorShape({3,3})));
94 }
95 
96 struct ReshapeFixtureWithReshapeDimsFlatten : ReshapeFixture
97 {
98  ReshapeFixtureWithReshapeDimsFlatten() : ReshapeFixture("[ 3, 3 ]", "[ 9 ]", "[ -1 ]") {}
99 };
100 
101 BOOST_FIXTURE_TEST_CASE(ParseReshapeWithReshapeDimsFlatten, ReshapeFixtureWithReshapeDimsFlatten)
102 {
103  SetupSingleInputSingleOutput("inputTensor", "outputTensor");
104  RunTest<1, armnn::DataType::QAsymmU8>(0,
105  { 1, 2, 3, 4, 5, 6, 7, 8, 9 },
106  { 1, 2, 3, 4, 5, 6, 7, 8, 9 });
107  BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
108  == armnn::TensorShape({9})));
109 }
110 
111 struct ReshapeFixtureWithReshapeDimsFlattenTwoDims : ReshapeFixture
112 {
113  ReshapeFixtureWithReshapeDimsFlattenTwoDims() : ReshapeFixture("[ 3, 2, 3 ]", "[ 2, 9 ]", "[ 2, -1 ]") {}
114 };
115 
116 BOOST_FIXTURE_TEST_CASE(ParseReshapeWithReshapeDimsFlattenTwoDims, ReshapeFixtureWithReshapeDimsFlattenTwoDims)
117 {
118  SetupSingleInputSingleOutput("inputTensor", "outputTensor");
119  RunTest<2, armnn::DataType::QAsymmU8>(0,
120  { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 },
121  { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 });
122  BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
123  == armnn::TensorShape({2,9})));
124 }
125 
126 struct ReshapeFixtureWithReshapeDimsFlattenOneDim : ReshapeFixture
127 {
128  ReshapeFixtureWithReshapeDimsFlattenOneDim() : ReshapeFixture("[ 2, 9 ]", "[ 2, 3, 3 ]", "[ 2, -1, 3 ]") {}
129 };
130 
131 BOOST_FIXTURE_TEST_CASE(ParseReshapeWithReshapeDimsFlattenOneDim, ReshapeFixtureWithReshapeDimsFlattenOneDim)
132 {
133  SetupSingleInputSingleOutput("inputTensor", "outputTensor");
134  RunTest<3, armnn::DataType::QAsymmU8>(0,
135  { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 },
136  { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 });
137  BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
138  == armnn::TensorShape({2,3,3})));
139 }
140 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
BOOST_FIXTURE_TEST_CASE(ValidReshapeTest, ReshapeValidFixture)
Definition: Reshape.cpp:100
BOOST_AUTO_TEST_SUITE_END()