diff options
Diffstat (limited to 'src/armnnTfParser/test/TestDependencies.cpp')
-rw-r--r-- | src/armnnTfParser/test/TestDependencies.cpp | 296 |
1 files changed, 296 insertions, 0 deletions
diff --git a/src/armnnTfParser/test/TestDependencies.cpp b/src/armnnTfParser/test/TestDependencies.cpp new file mode 100644 index 0000000000..13ab17c5b6 --- /dev/null +++ b/src/armnnTfParser/test/TestDependencies.cpp @@ -0,0 +1,296 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// + +#include <boost/test/unit_test.hpp> +#include "armnnTfParser/ITfParser.hpp" +#include "ParserPrototxtFixture.hpp" + +BOOST_AUTO_TEST_SUITE(TensorflowParser) + +// Graph which tests that nodes are re-ordered in the queue when they are encountered a second time. +// In this case R0 will be encountered first via R1 and then via R2. At that time +// we need to make sure that R0 (and the I on which it is dependent) is moved to the front again +// so that it is before both R1 and R2. +// I +// | +// R0 +// / \' +// R1 R2 +// \ | +// \ R3 +// \| +// O +struct RediscoveredDependenciesFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser> +{ + RediscoveredDependenciesFixture() + { + // input = tf.placeholder(tf.float32, 1, "input") + // relu0 = tf.nn.relu(input, "relu0") + // relu1 = tf.nn.relu(relu0, "relu1") + // relu2 = tf.nn.relu(relu0, "relu2") + // relu3 = tf.nn.relu(relu2, "relu3") + // output = tf.add(relu1, relu3, "output") + m_Prototext = R"( + node { + name: "input" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1 + } + } + } + } + } + node { + name: "relu0" + op: "Relu" + input: "input" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node { + name: "relu1" + op: "Relu" + input: "relu0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node { + name: "relu2" + op: "Relu" + input: "relu0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node { + name: "relu3" + op: "Relu" + input: "relu2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node { + name: "output" + op: "Add" + input: "relu1" + input: "relu3" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + )"; + SetupSingleInputSingleOutput({ 1 }, "input", "output"); + } +}; + +BOOST_FIXTURE_TEST_CASE(RediscoveredDependencies, RediscoveredDependenciesFixture) +{ + RunTest<1>({1}, {2}); +} + +// Tests that a simple cycle in the tensorflow graph will be detected and an exception thrown, rather than the TfParser +// getting stuck in an infinite loop. +BOOST_AUTO_TEST_CASE(SimpleCycle) +{ + const char* prototext = R"( +node { + name: "r1" + op: "Relu" + input: "r2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "r2" + op: "Relu" + input: "r1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} + )"; + armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create(); + BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "r2" }), armnn::ParseException); +} + +// Similar to the above SimpleCycle test, but has a single node which connects to itself. +BOOST_AUTO_TEST_CASE(SingleNodeCycle) +{ + const char* prototext = R"( +node { + name: "r1" + op: "Relu" + input: "r1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} + )"; + armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create(); + BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "r1" }), armnn::ParseException); +} + +// Similar to the above SimpleCycle test, but with a more complicated graph. +// I +// | +// A2---<---<- +// / \' | +// R1 R2 | +// \ | | +// \ R3 | +// \| | +// A1-->--->| +// +BOOST_AUTO_TEST_CASE(ComplexCycle) +{ + // input = tf.placeholder(tf.float32, 1, "input") + // add2 = tf.nn.relu(input, add1, "add2") // This line won't actually run in TF, because add1 is not yet defined + // relu1 = tf.nn.relu(relu0, "relu1") + // relu2 = tf.nn.relu(relu0, "relu2") + // relu3 = tf.nn.relu(relu2, "relu3") + // add1 = tf.add(relu1, relu3, "add1") + const char* prototext = R"( + node { + name: "input" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1 + } + } + } + } + } + node { + name: "add2" + op: "Add" + input: "input" + input: "add1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node { + name: "relu1" + op: "Relu" + input: "add2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node { + name: "relu2" + op: "Relu" + input: "add2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node { + name: "relu3" + op: "Relu" + input: "relu2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node { + name: "add1" + op: "Add" + input: "relu1" + input: "relu3" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + )"; + armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create(); + BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "add1" }), armnn::ParseException); +} + +// Tests that a graph with an input that is not present throws a ParseException. +BOOST_AUTO_TEST_CASE(InvalidInput) +{ + const char* prototext = R"( +node { + name: "r1" + op: "Relu" + input: "a-node-that-does-not-exist" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} + )"; + armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create(); + BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "r1" }), armnn::ParseException); +} + +BOOST_AUTO_TEST_SUITE_END() |