From bceff2fb3fc68bb0aa88b886900c34b77340c826 Mon Sep 17 00:00:00 2001 From: surmeh01 Date: Thu, 29 Mar 2018 16:29:27 +0100 Subject: Release 18.03 --- src/armnnTfParser/test/BiasAdd.cpp | 104 +++++++++++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 src/armnnTfParser/test/BiasAdd.cpp (limited to 'src/armnnTfParser/test/BiasAdd.cpp') diff --git a/src/armnnTfParser/test/BiasAdd.cpp b/src/armnnTfParser/test/BiasAdd.cpp new file mode 100644 index 0000000000..e29aeb1057 --- /dev/null +++ b/src/armnnTfParser/test/BiasAdd.cpp @@ -0,0 +1,104 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// + +#include +#include "armnnTfParser/ITfParser.hpp" +#include "ParserPrototxtFixture.hpp" + +BOOST_AUTO_TEST_SUITE(TensorflowParser) + +struct BiasAddFixture : public ParserPrototxtFixture +{ + explicit BiasAddFixture(const std::string& dataFormat) + { + m_Prototext = R"( +node { + name: "graphInput" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } +} +node { + name: "bias" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 3 + } + } + float_val: 1 + float_val: 2 + float_val: 3 + } + } + } +} +node { + name: "biasAdd" + op : "BiasAdd" + input: "graphInput" + input: "bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: ")" + dataFormat + R"(" + } + } +} +)"; + + SetupSingleInputSingleOutput({ 1, 3, 1, 3 }, "graphInput", "biasAdd"); + } +}; + +struct BiasAddFixtureNCHW : BiasAddFixture +{ + BiasAddFixtureNCHW() : BiasAddFixture("NCHW") {} +}; + +struct BiasAddFixtureNHWC : BiasAddFixture +{ + BiasAddFixtureNHWC() : BiasAddFixture("NHWC") {} +}; + +BOOST_FIXTURE_TEST_CASE(ParseBiasAddNCHW, BiasAddFixtureNCHW) +{ + RunTest<4>(std::vector(9), { 1, 1, 1, 2, 2, 2, 3, 3, 3 }); +} + +BOOST_FIXTURE_TEST_CASE(ParseBiasAddNHWC, BiasAddFixtureNHWC) +{ + RunTest<4>(std::vector(9), { 1, 2, 3, 1, 2, 3, 1, 2, 3 }); +} + +BOOST_AUTO_TEST_SUITE_END() -- cgit v1.2.1