ArmNN
 20.02
DeserializePooling2d.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <boost/test/unit_test.hpp>
8 #include "../Deserializer.hpp"
9 
10 #include <string>
11 #include <iostream>
12 
13 BOOST_AUTO_TEST_SUITE(Deserializer)
14 
// Fixture describing a three-layer serialized network
// (InputLayer -> Pooling2dLayer -> OutputLayer) as a JSON string in m_JsonString.
//
// The pooling descriptor is fixed to:
//  - a 2x2 window with stride 2 in both directions;
//  - zero padding on every side;
//  - "Floor" output-shape rounding;
//  - the "Exclude" padding method.
// Only the tensor shapes, data type, data layout and pooling algorithm vary
// with the constructor arguments.
struct Pooling2dFixture : public ParserFlatbuffersSerializeFixture
{
    /// @param inputShape       Dimensions of the input tensor, e.g. "[ 1, 2, 2, 1 ]".
    /// @param outputShape      Dimensions used for both the Pooling2dLayer output and the
    ///                         OutputLayer tensor.
    /// @param dataType         Data type name applied to every tensor in the network
    ///                         (e.g. "Float32", "QuantisedAsymm8").
    /// @param dataLayout       Descriptor data layout, e.g. "NHWC" or "NCHW".
    /// @param poolingAlgorithm Descriptor pool type, e.g. "Average" or "Max".
    explicit Pooling2dFixture(const std::string &inputShape,
                              const std::string &outputShape,
                              const std::string &dataType,
                              const std::string &dataLayout,
                              const std::string &poolingAlgorithm)
    {
        // The arguments are concatenated into the JSON text without added quotes,
        // so each must already be a valid token as written: an array literal for
        // shapes, a bare enum identifier for type/layout/algorithm.
        // NOTE(review): the InputLayer declares an input slot connected to itself
        // (sourceLayerIndex 0). This looks like a placeholder; confirm the
        // deserializer ignores it for input layers.
        m_JsonString = R"(
    {
        inputIds: [0],
        outputIds: [2],
        layers: [
          {
            layer_type: "InputLayer",
            layer: {
                base: {
                    layerBindingId: 0,
                    base: {
                        index: 0,
                        layerName: "InputLayer",
                        layerType: "Input",
                        inputSlots: [{
                            index: 0,
                            connection: {sourceLayerIndex:0, outputSlotIndex:0 },
                        }],
                        outputSlots: [ {
                            index: 0,
                            tensorInfo: {
                                dimensions: )" + inputShape + R"(,
                                dataType: )" + dataType + R"(
                            }}]
                    }
            }}},
          {
            layer_type: "Pooling2dLayer",
            layer: {
                base: {
                    index: 1,
                    layerName: "Pooling2dLayer",
                    layerType: "Pooling2d",
                    inputSlots: [{
                        index: 0,
                        connection: {sourceLayerIndex:0, outputSlotIndex:0 },
                    }],
                    outputSlots: [ {
                        index: 0,
                        tensorInfo: {
                            dimensions: )" + outputShape + R"(,
                            dataType: )" + dataType + R"(

                        }}]},
                descriptor: {
                    poolType: )" + poolingAlgorithm + R"(,
                    outputShapeRounding: "Floor",
                    paddingMethod: Exclude,
                    dataLayout: )" + dataLayout + R"(,
                    padLeft: 0,
                    padRight: 0,
                    padTop: 0,
                    padBottom: 0,
                    poolWidth: 2,
                    poolHeight: 2,
                    strideX: 2,
                    strideY: 2
                }
            }},
          {
            layer_type: "OutputLayer",
            layer: {
                base:{
                    layerBindingId: 0,
                    base: {
                        index: 2,
                        layerName: "OutputLayer",
                        layerType: "Output",
                        inputSlots: [{
                            index: 0,
                            connection: {sourceLayerIndex:1, outputSlotIndex:0 },
                        }],
                        outputSlots: [ {
                            index: 0,
                            tensorInfo: {
                                dimensions: )" + outputShape + R"(,
                                dataType: )" + dataType + R"(
                            },
                        }],
                    }}},
        }]
    }
    )";
        // Presumably deserializes m_JsonString and binds the layers named
        // "InputLayer" / "OutputLayer" as the network's single input and output.
        // The helper is defined in ParserFlatbuffersSerializeFixture; confirm there.
        SetupSingleInputSingleOutput("InputLayer", "OutputLayer");
    }
};
109 
110 struct SimpleAvgPooling2dFixture : Pooling2dFixture
111 {
112  SimpleAvgPooling2dFixture() : Pooling2dFixture("[ 1, 2, 2, 1 ]", "[ 1, 1, 1, 1 ]",
113  "Float32", "NHWC", "Average") {}
114 };
115 
116 struct SimpleAvgPooling2dFixture2 : Pooling2dFixture
117 {
118  SimpleAvgPooling2dFixture2() : Pooling2dFixture("[ 1, 2, 2, 1 ]",
119  "[ 1, 1, 1, 1 ]",
120  "QuantisedAsymm8", "NHWC", "Average") {}
121 };
122 
123 struct SimpleMaxPooling2dFixture : Pooling2dFixture
124 {
125  SimpleMaxPooling2dFixture() : Pooling2dFixture("[ 1, 1, 2, 2 ]",
126  "[ 1, 1, 1, 1 ]",
127  "Float32", "NCHW", "Max") {}
128 };
129 
130 struct SimpleMaxPooling2dFixture2 : Pooling2dFixture
131 {
132  SimpleMaxPooling2dFixture2() : Pooling2dFixture("[ 1, 1, 2, 2 ]",
133  "[ 1, 1, 1, 1 ]",
134  "QuantisedAsymm8", "NCHW", "Max") {}
135 };
136 
137 BOOST_FIXTURE_TEST_CASE(Pooling2dFloat32Avg, SimpleAvgPooling2dFixture)
138 {
139  RunTest<4, armnn::DataType::Float32>(0, { 2, 3, 5, 2 }, { 3 });
140 }
141 
142 BOOST_FIXTURE_TEST_CASE(Pooling2dQuantisedAsymm8Avg, SimpleAvgPooling2dFixture2)
143 {
144  RunTest<4, armnn::DataType::QAsymmU8>(0,
145  { 20, 40, 60, 80 },
146  { 50 });
147 }
148 
149 BOOST_FIXTURE_TEST_CASE(Pooling2dFloat32Max, SimpleMaxPooling2dFixture)
150 {
151  RunTest<4, armnn::DataType::Float32>(0, { 2, 5, 5, 2 }, { 5 });
152 }
153 
154 BOOST_FIXTURE_TEST_CASE(Pooling2dQuantisedAsymm8Max, SimpleMaxPooling2dFixture2)
155 {
156  RunTest<4, armnn::DataType::QAsymmU8>(0,
157  { 20, 40, 60, 80 },
158  { 80 });
159 }
160 
162 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
void SetupSingleInputSingleOutput(const std::string &inputName, const std::string &outputName)
BOOST_FIXTURE_TEST_CASE(Pooling2dFloat32Avg, SimpleAvgPooling2dFixture)
BOOST_AUTO_TEST_SUITE_END()