aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/test
diff options
context:
space:
mode:
authorjimfly01 <jim.flynn@arm.com>2018-12-19 13:14:46 +0000
committerjimfly01 <jim.flynn@arm.com>2018-12-20 14:19:25 +0000
commit84c70e65a193aa5faa959d305af82783fa8f78b5 (patch)
tree4e22d949b863c21c6bb7aa1bb92e8465a337dd79 /src/armnnTfParser/test
parent4fa0916386e720e254bee6b9fd1576e90ba6a42f (diff)
downloadarmnn-84c70e65a193aa5faa959d305af82783fa8f78b5.tar.gz
IVGCVSW-2367 Add Equal Operator to TfParser
* Unit tests in Equal.cpp * Fixed error in Network::AddEqualLayer * Refactored TfParser::Minimum/Equal to get rid of duplicate code Change-Id: I0ed6f888eb391c995b88be20dc0c1b916dd14c3c
Diffstat (limited to 'src/armnnTfParser/test')
-rw-r--r--src/armnnTfParser/test/Equal.cpp139
1 files changed, 139 insertions, 0 deletions
diff --git a/src/armnnTfParser/test/Equal.cpp b/src/armnnTfParser/test/Equal.cpp
new file mode 100644
index 0000000000..43a1c6abb5
--- /dev/null
+++ b/src/armnnTfParser/test/Equal.cpp
@@ -0,0 +1,139 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include <boost/test/unit_test.hpp>
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+ struct EqualFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
+ {
+ EqualFixture()
+ {
+ m_Prototext = R"(
+node {
+ name: "input0"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+}
+node {
+ name: "input1"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+}
+node {
+ name: "output"
+ op: "Equal"
+ input: "input0"
+ input: "input1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+ )";
+ }
+ };
+
+BOOST_FIXTURE_TEST_CASE(ParseEqualUnsupportedBroadcast, EqualFixture)
+{
+ BOOST_REQUIRE_THROW(Setup({ { "input0", {2, 3} },
+ { "input1", {1, 2, 2, 3} } },
+ { "output" }),
+ armnn::ParseException);
+}
+
+struct EqualFixtureAutoSetup : public EqualFixture
+{
+ EqualFixtureAutoSetup(const armnn::TensorShape& input0Shape,
+ const armnn::TensorShape& input1Shape)
+ : EqualFixture()
+ {
+ Setup({ { "input0", input0Shape },
+ { "input1", input1Shape } },
+ { "output" });
+ }
+};
+
+struct EqualTwoByTwo : public EqualFixtureAutoSetup
+{
+ EqualTwoByTwo() : EqualFixtureAutoSetup({2,2}, {2,2}) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseEqualTwoByTwo, EqualTwoByTwo)
+{
+ RunTest<2>({ { "input0", { 1.0f, 2.0f, 3.0f, 2.0f } },
+ { "input1", { 1.0f, 5.0f, 2.0f, 2.0f } } },
+ { { "output", { 1.0f, 0.0f, 0.0f, 1.0f } } });
+}
+
+struct EqualBroadcast1DAnd4D : public EqualFixtureAutoSetup
+{
+ EqualBroadcast1DAnd4D() : EqualFixtureAutoSetup({1}, {1,1,2,2}) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseEqualBroadcast1DToTwoByTwo, EqualBroadcast1DAnd4D)
+{
+ RunTest<4>({ { "input0", { 2.0f } },
+ { "input1", { 1.0f, 2.0f, 3.0f, 2.0f } } },
+ { { "output", { 0.0f, 1.0f, 0.0f, 1.0f } } });
+}
+
+struct EqualBroadcast4DAnd1D : public EqualFixtureAutoSetup
+{
+ EqualBroadcast4DAnd1D() : EqualFixtureAutoSetup({1,1,2,2}, {1}) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseEqualBroadcast4DAnd1D, EqualBroadcast4DAnd1D)
+{
+ RunTest<4>({ { "input0", { 1.0f, 2.0f, 3.0f, 2.0f } },
+ { "input1", { 3.0f } } },
+ { { "output", { 0.0f, 0.0f, 1.0f, 0.0f } } });
+}
+
+struct EqualMultiDimBroadcast : public EqualFixtureAutoSetup
+{
+ EqualMultiDimBroadcast() : EqualFixtureAutoSetup({1,1,2,1}, {1,2,1,3}) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseEqualMultiDimBroadcast, EqualMultiDimBroadcast)
+{
+ RunTest<4>({ { "input0", { 1.0f, 2.0f } },
+ { "input1", { 1.0f, 2.0f, 3.0f,
+ 3.0f, 2.0f, 2.0f } } },
+ { { "output", { 1.0f, 0.0f, 0.0f,
+ 0.0f, 1.0f, 0.0f,
+ 0.0f, 0.0f, 0.0f,
+ 0.0f, 1.0f, 1.0f } } });
+}
+
+BOOST_AUTO_TEST_SUITE_END()