aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/test/TestDependencies.cpp
diff options
context:
space:
mode:
authorsurmeh01 <surabhi.mehta@arm.com>2018-03-29 16:29:27 +0100
committersurmeh01 <surabhi.mehta@arm.com>2018-03-29 16:29:27 +0100
commitbceff2fb3fc68bb0aa88b886900c34b77340c826 (patch)
treed867d3e090d58d3012dfbbac456e9ea8c7f789bc /src/armnnTfParser/test/TestDependencies.cpp
parent4fcda0101ec3d110c1d6d7bee5c83416b645528a (diff)
downloadarmnn-bceff2fb3fc68bb0aa88b886900c34b77340c826.tar.gz
Release 18.03
Diffstat (limited to 'src/armnnTfParser/test/TestDependencies.cpp')
-rw-r--r--src/armnnTfParser/test/TestDependencies.cpp296
1 files changed, 296 insertions, 0 deletions
diff --git a/src/armnnTfParser/test/TestDependencies.cpp b/src/armnnTfParser/test/TestDependencies.cpp
new file mode 100644
index 0000000000..13ab17c5b6
--- /dev/null
+++ b/src/armnnTfParser/test/TestDependencies.cpp
@@ -0,0 +1,296 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+
+#include <boost/test/unit_test.hpp>
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+// Graph which tests that nodes are re-ordered in the queue when they are encountered a second time.
+// In this case R0 will be encountered first via R1 and then via R2. At that time
+// we need to make sure that R0 (and the I on which it is dependent) is moved to the front again
+// so that it is before both R1 and R2.
+// I
+// |
+// R0
+// / \'
+// R1 R2
+// \ |
+// \ R3
+// \|
+// O
+struct RediscoveredDependenciesFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ RediscoveredDependenciesFixture()
+ {
+ // input = tf.placeholder(tf.float32, 1, "input")
+ // relu0 = tf.nn.relu(input, "relu0")
+ // relu1 = tf.nn.relu(relu0, "relu1")
+ // relu2 = tf.nn.relu(relu0, "relu2")
+ // relu3 = tf.nn.relu(relu2, "relu3")
+ // output = tf.add(relu1, relu3, "output")
+ m_Prototext = R"(
+ node {
+ name: "input"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "relu0"
+ op: "Relu"
+ input: "input"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node {
+ name: "relu1"
+ op: "Relu"
+ input: "relu0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node {
+ name: "relu2"
+ op: "Relu"
+ input: "relu0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node {
+ name: "relu3"
+ op: "Relu"
+ input: "relu2"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node {
+ name: "output"
+ op: "Add"
+ input: "relu1"
+ input: "relu3"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ )";
+ SetupSingleInputSingleOutput({ 1 }, "input", "output");
+ }
+};
+
+BOOST_FIXTURE_TEST_CASE(RediscoveredDependencies, RediscoveredDependenciesFixture)
+{
+ RunTest<1>({1}, {2});
+}
+
+// Tests that a simple cycle in the tensorflow graph will be detected and an exception thrown, rather than the TfParser
+// getting stuck in an infinite loop.
+BOOST_AUTO_TEST_CASE(SimpleCycle)
+{
+ const char* prototext = R"(
+node {
+ name: "r1"
+ op: "Relu"
+ input: "r2"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "r2"
+ op: "Relu"
+ input: "r1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+ )";
+ armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create();
+ BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "r2" }), armnn::ParseException);
+}
+
+// Similar to the above SimpleCycle test, but has a single node which connects to itself.
+BOOST_AUTO_TEST_CASE(SingleNodeCycle)
+{
+ const char* prototext = R"(
+node {
+ name: "r1"
+ op: "Relu"
+ input: "r1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+ )";
+ armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create();
+ BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "r1" }), armnn::ParseException);
+}
+
+// Similar to the above SimpleCycle test, but with a more complicated graph.
+// I
+// |
+// A2---<---<-
+// / \' |
+// R1 R2 |
+// \ | |
+// \ R3 |
+// \| |
+// A1-->--->|
+//
+BOOST_AUTO_TEST_CASE(ComplexCycle)
+{
+ // input = tf.placeholder(tf.float32, 1, "input")
+ // add2 = tf.nn.relu(input, add1, "add2") // This line won't actually run in TF, because add1 is not yet defined
+ // relu1 = tf.nn.relu(relu0, "relu1")
+ // relu2 = tf.nn.relu(relu0, "relu2")
+ // relu3 = tf.nn.relu(relu2, "relu3")
+ // add1 = tf.add(relu1, relu3, "add1")
+ const char* prototext = R"(
+ node {
+ name: "input"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "add2"
+ op: "Add"
+ input: "input"
+ input: "add1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node {
+ name: "relu1"
+ op: "Relu"
+ input: "add2"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node {
+ name: "relu2"
+ op: "Relu"
+ input: "add2"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node {
+ name: "relu3"
+ op: "Relu"
+ input: "relu2"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node {
+ name: "add1"
+ op: "Add"
+ input: "relu1"
+ input: "relu3"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ )";
+ armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create();
+ BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "add1" }), armnn::ParseException);
+}
+
+// Tests that a graph with an input that is not present throws a ParseException.
+BOOST_AUTO_TEST_CASE(InvalidInput)
+{
+ const char* prototext = R"(
+node {
+ name: "r1"
+ op: "Relu"
+ input: "a-node-that-does-not-exist"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+ )";
+ armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create();
+ BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "r1" }), armnn::ParseException);
+}
+
+BOOST_AUTO_TEST_SUITE_END()