// // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // #include #include "armnnTfParser/ITfParser.hpp" #include "ParserPrototxtFixture.hpp" BOOST_AUTO_TEST_SUITE(TensorflowParser) // Graph which tests that nodes are re-ordered in the queue when they are encountered a second time. // In this case R0 will be encountered first via R1 and then via R2. At that time // we need to make sure that R0 (and the I on which it is dependent) is moved to the front again // so that it is before both R1 and R2. // I // | // R0 // / \' // R1 R2 // \ | // \ R3 // \| // O struct RediscoveredDependenciesFixture : public armnnUtils::ParserPrototxtFixture { RediscoveredDependenciesFixture() { // Input = tf.placeholder(tf.float32, 1, "input") // Relu0 = tf.nn.relu(input, "relu0") // Relu1 = tf.nn.relu(relu0, "relu1") // Relu2 = tf.nn.relu(relu0, "relu2") // Relu3 = tf.nn.relu(relu2, "relu3") // Output = tf.add(relu1, relu3, "output") m_Prototext = R"( node { name: "input" op: "Placeholder" attr { key: "dtype" value { type: DT_FLOAT } } attr { key: "shape" value { shape { dim { size: 1 } } } } } node { name: "relu0" op: "Relu" input: "input" attr { key: "T" value { type: DT_FLOAT } } } node { name: "relu1" op: "Relu" input: "relu0" attr { key: "T" value { type: DT_FLOAT } } } node { name: "relu2" op: "Relu" input: "relu0" attr { key: "T" value { type: DT_FLOAT } } } node { name: "relu3" op: "Relu" input: "relu2" attr { key: "T" value { type: DT_FLOAT } } } node { name: "output" op: "Add" input: "relu1" input: "relu3" attr { key: "T" value { type: DT_FLOAT } } } )"; SetupSingleInputSingleOutput({ 1 }, "input", "output"); } }; BOOST_FIXTURE_TEST_CASE(RediscoveredDependencies, RediscoveredDependenciesFixture) { RunTest<1>({1}, {2}); } // Tests that a simple cycle in the tensorflow graph will be detected and an exception thrown, rather than the TfParser // getting stuck in an infinite loop. BOOST_AUTO_TEST_CASE(SimpleCycle) { const char* prototext = R"( node { name: "r1" op: "Relu" input: "r2" attr { key: "T" value { type: DT_FLOAT } } } node { name: "r2" op: "Relu" input: "r1" attr { key: "T" value { type: DT_FLOAT } } } )"; armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create(); BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "r2" }), armnn::ParseException); } // Similar to the above SimpleCycle test, but has a single node which connects to itself. BOOST_AUTO_TEST_CASE(SingleNodeCycle) { const char* prototext = R"( node { name: "r1" op: "Relu" input: "r1" attr { key: "T" value { type: DT_FLOAT } } } )"; armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create(); BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "r1" }), armnn::ParseException); } // Similar to the above SimpleCycle test, but with a more complicated graph. // I // | // A2---<---<- // / \' | // R1 R2 | // \ | | // \ R3 | // \| | // A1-->--->| // BOOST_AUTO_TEST_CASE(ComplexCycle) { // Input = tf.placeholder(tf.float32, 1, "input") // Add2 = tf.nn.relu(input, add1, "add2") // This line won't actually run in TF, because add1 is not yet defined // Relu1 = tf.nn.relu(relu0, "relu1") // Relu2 = tf.nn.relu(relu0, "relu2") // Relu3 = tf.nn.relu(relu2, "relu3") // Add1 = tf.add(relu1, relu3, "add1") const char* prototext = R"( node { name: "input" op: "Placeholder" attr { key: "dtype" value { type: DT_FLOAT } } attr { key: "shape" value { shape { dim { size: 1 } } } } } node { name: "add2" op: "Add" input: "input" input: "add1" attr { key: "T" value { type: DT_FLOAT } } } node { name: "relu1" op: "Relu" input: "add2" attr { key: "T" value { type: DT_FLOAT } } } node { name: "relu2" op: "Relu" input: "add2" attr { key: "T" value { type: DT_FLOAT } } } node { name: "relu3" op: "Relu" input: "relu2" attr { key: "T" value { type: DT_FLOAT } } } node { name: "add1" op: "Add" input: "relu1" input: "relu3" attr { key: "T" value { type: DT_FLOAT } } } )"; armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create(); BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "add1" }), armnn::ParseException); } // Tests that a graph with an input that is not present throws a ParseException. BOOST_AUTO_TEST_CASE(InvalidInput) { const char* prototext = R"( node { name: "r1" op: "Relu" input: "a-node-that-does-not-exist" attr { key: "T" value { type: DT_FLOAT } } } )"; armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create(); BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "r1" }), armnn::ParseException); } BOOST_AUTO_TEST_SUITE_END()