6 #include <boost/test/unit_test.hpp> 16 explicit FusedBatchNormFixture(
const std::string& dataLayout)
19 " name: \"graphInput\" \n" 20 " op: \"Placeholder\" \n" 36 " name: \"Const_1\" \n" 60 " name: \"Const_2\" \n" 84 " name: \"FusedBatchNormLayer/mean\" \n" 108 " name: \"FusedBatchNormLayer/variance\" \n" 120 " dtype: DT_FLOAT \n" 132 " name: \"output\" \n" 133 " op: \"FusedBatchNorm\" \n" 134 " input: \"graphInput\" \n" 135 " input: \"Const_1\" \n" 136 " input: \"Const_2\" \n" 137 " input: \"FusedBatchNormLayer/mean\" \n" 138 " input: \"FusedBatchNormLayer/variance\" \n" 147 if (dataLayout !=
"NHWC")
150 " key: \"data_format\" \n" 160 " key: \"epsilon\" \n" 162 " f: 0.0010000000475 \n" 166 " key: \"is_training\" \n" 174 std::array<unsigned int, 4> dims;
175 if (dataLayout ==
"NHWC")
177 dims = { 1u, 3u, 3u, 1u };
181 dims = { 1u, 1u, 3u, 3u };
188 struct FusedBatchNormNhwcFixture : FusedBatchNormFixture
190 FusedBatchNormNhwcFixture() : FusedBatchNormFixture(
"NHWC"){}
194 RunTest<4>({ 1, 2, 3, 4, 5, 6, 7, 8, 9 },
195 { -2.8277204f, -2.12079024f, -1.4138602f,
196 -0.7069301f, 0.0f, 0.7069301f,
197 1.4138602f, 2.12079024f, 2.8277204f });
200 struct FusedBatchNormNchwFixture : FusedBatchNormFixture
202 FusedBatchNormNchwFixture() : FusedBatchNormFixture(
"NCHW"){}
206 RunTest<4>({ 1, 2, 3, 4, 5, 6, 7, 8, 9 },
207 { -2.8277204f, -2.12079024f, -1.4138602f,
208 -0.7069301f, 0.0f, 0.7069301f,
209 1.4138602f, 2.12079024f, 2.8277204f });
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
BOOST_FIXTURE_TEST_CASE(ParseFusedBatchNormNhwc, FusedBatchNormNhwcFixture)
BOOST_AUTO_TEST_SUITE_END()
void SetupSingleInputSingleOutput(const std::string &inputName, const std::string &outputName)
Parses and loads the network defined by the m_Prototext string.