diff options
Diffstat (limited to 'src/armnnTfParser/test/Equal.cpp')
-rw-r--r-- | src/armnnTfParser/test/Equal.cpp | 139 |
1 files changed, 139 insertions, 0 deletions
diff --git a/src/armnnTfParser/test/Equal.cpp b/src/armnnTfParser/test/Equal.cpp new file mode 100644 index 0000000000..43a1c6abb5 --- /dev/null +++ b/src/armnnTfParser/test/Equal.cpp @@ -0,0 +1,139 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include <boost/test/unit_test.hpp> +#include "armnnTfParser/ITfParser.hpp" +#include "ParserPrototxtFixture.hpp" + +BOOST_AUTO_TEST_SUITE(TensorflowParser) + + struct EqualFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser> + { + EqualFixture() + { + m_Prototext = R"( +node { + name: "input0" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } +} +node { + name: "input1" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } +} +node { + name: "output" + op: "Equal" + input: "input0" + input: "input1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} + )"; + } + }; + +BOOST_FIXTURE_TEST_CASE(ParseEqualUnsupportedBroadcast, EqualFixture) +{ + BOOST_REQUIRE_THROW(Setup({ { "input0", {2, 3} }, + { "input1", {1, 2, 2, 3} } }, + { "output" }), + armnn::ParseException); +} + +struct EqualFixtureAutoSetup : public EqualFixture +{ + EqualFixtureAutoSetup(const armnn::TensorShape& input0Shape, + const armnn::TensorShape& input1Shape) + : EqualFixture() + { + Setup({ { "input0", input0Shape }, + { "input1", input1Shape } }, + { "output" }); + } +}; + +struct EqualTwoByTwo : public EqualFixtureAutoSetup +{ + EqualTwoByTwo() : EqualFixtureAutoSetup({2,2}, {2,2}) {} +}; + +BOOST_FIXTURE_TEST_CASE(ParseEqualTwoByTwo, EqualTwoByTwo) +{ + RunTest<2>({ { "input0", { 1.0f, 2.0f, 3.0f, 2.0f } }, + { "input1", { 1.0f, 5.0f, 2.0f, 2.0f } } }, + { { "output", { 1.0f, 0.0f, 0.0f, 1.0f } } }); +} + +struct EqualBroadcast1DAnd4D : public EqualFixtureAutoSetup +{ + EqualBroadcast1DAnd4D() : EqualFixtureAutoSetup({1}, {1,1,2,2}) {} +}; + +BOOST_FIXTURE_TEST_CASE(ParseEqualBroadcast1DToTwoByTwo, EqualBroadcast1DAnd4D) +{ + RunTest<4>({ { "input0", { 2.0f } }, + { "input1", { 1.0f, 2.0f, 3.0f, 2.0f } } }, + { { "output", { 0.0f, 1.0f, 0.0f, 1.0f } } }); +} + +struct EqualBroadcast4DAnd1D : public EqualFixtureAutoSetup +{ + EqualBroadcast4DAnd1D() : EqualFixtureAutoSetup({1,1,2,2}, {1}) {} +}; + +BOOST_FIXTURE_TEST_CASE(ParseEqualBroadcast4DAnd1D, EqualBroadcast4DAnd1D) +{ + RunTest<4>({ { "input0", { 1.0f, 2.0f, 3.0f, 2.0f } }, + { "input1", { 3.0f } } }, + { { "output", { 0.0f, 0.0f, 1.0f, 0.0f } } }); +} + +struct EqualMultiDimBroadcast : public EqualFixtureAutoSetup +{ + EqualMultiDimBroadcast() : EqualFixtureAutoSetup({1,1,2,1}, {1,2,1,3}) {} +}; + +BOOST_FIXTURE_TEST_CASE(ParseEqualMultiDimBroadcast, EqualMultiDimBroadcast) +{ + RunTest<4>({ { "input0", { 1.0f, 2.0f } }, + { "input1", { 1.0f, 2.0f, 3.0f, + 3.0f, 2.0f, 2.0f } } }, + { { "output", { 1.0f, 0.0f, 0.0f, + 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, + 0.0f, 1.0f, 1.0f } } }); +} + +BOOST_AUTO_TEST_SUITE_END() |