ArmNN
 20.08
FusedBatchNorm.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
#include <boost/test/unit_test.hpp>

#include <array>
#include <stdexcept>
11 
12 BOOST_AUTO_TEST_SUITE(TensorflowParser)
13 
14 struct FusedBatchNormFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
15 {
16  explicit FusedBatchNormFixture(const std::string& dataLayout)
17  {
18  m_Prototext = "node { \n"
19  " name: \"graphInput\" \n"
20  " op: \"Placeholder\" \n"
21  " attr { \n"
22  " key: \"dtype\" \n"
23  " value { \n"
24  " type: DT_FLOAT \n"
25  " } \n"
26  " } \n"
27  " attr { \n"
28  " key: \"shape\" \n"
29  " value { \n"
30  " shape { \n"
31  " } \n"
32  " } \n"
33  " } \n"
34  "} \n"
35  "node { \n"
36  " name: \"Const_1\" \n"
37  " op: \"Const\" \n"
38  " attr { \n"
39  " key: \"dtype\" \n"
40  " value { \n"
41  " type: DT_FLOAT \n"
42  " } \n"
43  " } \n"
44  " attr { \n"
45  " key: \"value\" \n"
46  " value { \n"
47  " tensor { \n"
48  " dtype: DT_FLOAT \n"
49  " tensor_shape { \n"
50  " dim { \n"
51  " size: 1 \n"
52  " } \n"
53  " } \n"
54  " float_val: 1.0 \n"
55  " } \n"
56  " } \n"
57  " } \n"
58  "} \n"
59  "node { \n"
60  " name: \"Const_2\" \n"
61  " op: \"Const\" \n"
62  " attr { \n"
63  " key: \"dtype\" \n"
64  " value { \n"
65  " type: DT_FLOAT \n"
66  " } \n"
67  " } \n"
68  " attr { \n"
69  " key: \"value\" \n"
70  " value { \n"
71  " tensor { \n"
72  " dtype: DT_FLOAT \n"
73  " tensor_shape { \n"
74  " dim { \n"
75  " size: 1 \n"
76  " } \n"
77  " } \n"
78  " float_val: 0.0 \n"
79  " } \n"
80  " } \n"
81  " } \n"
82  "} \n"
83  "node { \n"
84  " name: \"FusedBatchNormLayer/mean\" \n"
85  " op: \"Const\" \n"
86  " attr { \n"
87  " key: \"dtype\" \n"
88  " value { \n"
89  " type: DT_FLOAT \n"
90  " } \n"
91  " } \n"
92  " attr { \n"
93  " key: \"value\" \n"
94  " value { \n"
95  " tensor { \n"
96  " dtype: DT_FLOAT \n"
97  " tensor_shape { \n"
98  " dim { \n"
99  " size: 1 \n"
100  " } \n"
101  " } \n"
102  " float_val: 5.0 \n"
103  " } \n"
104  " } \n"
105  " } \n"
106  "} \n"
107  "node { \n"
108  " name: \"FusedBatchNormLayer/variance\" \n"
109  " op: \"Const\" \n"
110  " attr { \n"
111  " key: \"dtype\" \n"
112  " value { \n"
113  " type: DT_FLOAT \n"
114  " } \n"
115  " } \n"
116  " attr { \n"
117  " key: \"value\" \n"
118  " value { \n"
119  " tensor { \n"
120  " dtype: DT_FLOAT \n"
121  " tensor_shape { \n"
122  " dim { \n"
123  " size: 1 \n"
124  " } \n"
125  " } \n"
126  " float_val: 2.0 \n"
127  " } \n"
128  " } \n"
129  " } \n"
130  "} \n"
131  "node { \n"
132  " name: \"output\" \n"
133  " op: \"FusedBatchNorm\" \n"
134  " input: \"graphInput\" \n"
135  " input: \"Const_1\" \n"
136  " input: \"Const_2\" \n"
137  " input: \"FusedBatchNormLayer/mean\" \n"
138  " input: \"FusedBatchNormLayer/variance\" \n"
139  " attr { \n"
140  " key: \"T\" \n"
141  " value { \n"
142  " type: DT_FLOAT \n"
143  " } \n"
144  " } \n";
145 
146  // NOTE: we only explicitly set data_format when it is not the default NHWC
147  if (dataLayout != "NHWC")
148  {
149  m_Prototext.append(" attr { \n"
150  " key: \"data_format\" \n"
151  " value { \n"
152  " s: \"");
153  m_Prototext.append(dataLayout);
154  m_Prototext.append("\" \n"
155  " } \n"
156  " } \n");
157  }
158 
159  m_Prototext.append(" attr { \n"
160  " key: \"epsilon\" \n"
161  " value { \n"
162  " f: 0.0010000000475 \n"
163  " } \n"
164  " } \n"
165  " attr { \n"
166  " key: \"is_training\" \n"
167  " value { \n"
168  " b: false \n"
169  " } \n"
170  " } \n"
171  "} \n");
172 
173  // Set the input shape according to the data layout
174  std::array<unsigned int, 4> dims;
175  if (dataLayout == "NHWC")
176  {
177  dims = { 1u, 3u, 3u, 1u };
178  }
179  else // dataLayout == "NCHW"
180  {
181  dims = { 1u, 1u, 3u, 3u };
182  }
183 
184  SetupSingleInputSingleOutput(armnn::TensorShape(4, dims.data()), "graphInput", "output");
185  }
186 };
187 
188 struct FusedBatchNormNhwcFixture : FusedBatchNormFixture
189 {
190  FusedBatchNormNhwcFixture() : FusedBatchNormFixture("NHWC"){}
191 };
192 BOOST_FIXTURE_TEST_CASE(ParseFusedBatchNormNhwc, FusedBatchNormNhwcFixture)
193 {
194  RunTest<4>({ 1, 2, 3, 4, 5, 6, 7, 8, 9 }, // Input data.
195  { -2.8277204f, -2.12079024f, -1.4138602f,
196  -0.7069301f, 0.0f, 0.7069301f,
197  1.4138602f, 2.12079024f, 2.8277204f }); // Expected output data.
198 }
199 
200 struct FusedBatchNormNchwFixture : FusedBatchNormFixture
201 {
202  FusedBatchNormNchwFixture() : FusedBatchNormFixture("NCHW"){}
203 };
204 BOOST_FIXTURE_TEST_CASE(ParseFusedBatchNormNchw, FusedBatchNormNchwFixture)
205 {
206  RunTest<4>({ 1, 2, 3, 4, 5, 6, 7, 8, 9 }, // Input data.
207  { -2.8277204f, -2.12079024f, -1.4138602f,
208  -0.7069301f, 0.0f, 0.7069301f,
209  1.4138602f, 2.12079024f, 2.8277204f }); // Expected output data.
210 }
211 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
BOOST_FIXTURE_TEST_CASE(ParseFusedBatchNormNhwc, FusedBatchNormNhwcFixture)
BOOST_AUTO_TEST_SUITE_END()
void SetupSingleInputSingleOutput(const std::string &inputName, const std::string &outputName)
Parses and loads the network defined by the m_Prototext string.