From 91c0eff2ff4ff384e013cb69cac1e07e28b9e2b1 Mon Sep 17 00:00:00 2001 From: Saoirse Stewart Date: Wed, 27 Feb 2019 11:07:57 +0000 Subject: IVGCVSW-2598 Fix for constant axis issue for Tensorflow Parser Change-Id: I8b081012529aed8e434273259c5a5ef7dc3afff7 Signed-off-by: Finn Williams Signed-off-by: Saoirse Stewart --- src/armnnTfParser/test/Split.cpp | 226 +++++++++++++++++++++++++-------------- 1 file changed, 144 insertions(+), 82 deletions(-) (limited to 'src/armnnTfParser/test/Split.cpp') diff --git a/src/armnnTfParser/test/Split.cpp b/src/armnnTfParser/test/Split.cpp index de6b5d861e..87cd6544c9 100644 --- a/src/armnnTfParser/test/Split.cpp +++ b/src/armnnTfParser/test/Split.cpp @@ -11,93 +11,140 @@ BOOST_AUTO_TEST_SUITE(TensorflowParser) struct SplitFixture : public armnnUtils::ParserPrototxtFixture { - SplitFixture() { - m_Prototext = - "node { \n" - " name: \"graphInput\" \n" - " op: \"Placeholder\" \n" - " attr { \n" - " key: \"dtype\" \n" - " value { \n" - " type: DT_FLOAT \n" - " } \n" - " } \n" - " attr { \n" - " key: \"shape\" \n" - " value { \n" - " shape { \n" - " } \n" - " } \n" - " } \n" - " } \n" - " node {" - " name: \"splitInput\" \n" - " op: \"Const\" \n" - "attr {\n" - " key: \"dtype\" \n" - " value {" - " type: DT_INT32" - " }" - "}" - "attr {" - " key: \"value\"\n" - " value { " - " tensor {" - " dtype: DT_INT32" - " tensor_shape {" - "}" - "int_val: 1" - "}" - "}" - "}" - "}" - "node { \n" - " name: \"Split\" \n" - " op: \"Split\" \n" - "input: \"graphInput\"\n" - "input: \"splitInput\"\n" - "attr { \n " - "key: \"T\"\n" - "value {\n" - "type: DT_FLOAT\n" - " }\n" - "}\n" - "\n" - " attr { \n" - " key: \"num_or_size_splits\" \n" - " value { \n" - " i:2 \n " - " } \n" - " } \n" - "} \n" - "node { \n" - "name: \"Relu_1\"\n" - "op: \"Relu\"\n" - "input: \"Split:0\"\n" - "attr { \n " - "key: \"T\"\n" - "value {\n" - "type: DT_FLOAT\n" - " }\n" - "}\n" - "}\n" - "node { \n" - "name: \"Relu_2\"\n" - "op: \"Relu\"\n" - "input: \"Split:1\"\n" - "attr { \n " - "key: \"T\"\n" - "value {\n" - "type: DT_FLOAT\n" - " }\n" - "}\n" - "}\n"; + SplitFixture(bool withDimZero=false) { + m_Prototext = R"( + node { + name: "graphInput" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "graphInput2" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "multiplication" + op : "Mul" + input: "graphInput" + input: "graphInput2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node { + name: "SplitInput" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: )"; - Setup( { { "graphInput", { 1, 2, 2 , 2} } }, + if(withDimZero) + { + m_Prototext += std::to_string(3); + } + else + { + m_Prototext += std::to_string(1); + } + + m_Prototext += R"( + } + } + } + } + node { + name: "Split" + op: "Split" )"; + if(withDimZero) + { + m_Prototext += "input: \"SplitInput\"\n"; + m_Prototext += "input: \"multiplication\"\n"; + } + else + { + m_Prototext += "input: \"graphInput\"\n"; + m_Prototext += "input: \"SplitInput\"\n"; + } + m_Prototext += R"( + attr { + key: "num_or_size_splits" + value { + i: 2 + } + } + } + node { + name: "Relu_1" + op: "Relu" + input: "Split:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node { + name: "Relu_2" + op: "Relu" + input:"Split:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } )"; + + Setup( { { "graphInput", { 1, 2, 2 , 2} } , { "graphInput2", { 1, 2, 2 , 2} }}, { "Relu_1", "Relu_2" }); } }; +struct InputFirstSplitFixture : SplitFixture +{ + InputFirstSplitFixture() : SplitFixture(true) {} +}; + BOOST_FIXTURE_TEST_CASE(ParseAxisOneSplitTwo, SplitFixture) { BOOST_TEST( @@ -111,4 +158,19 @@ BOOST_FIXTURE_TEST_CASE(ParseAxisOneSplitTwo, SplitFixture) { "Relu_2", { 0.0f, 0.5f, 0.0f, 1.75f } } }); } +BOOST_FIXTURE_TEST_CASE(ParseSplit, InputFirstSplitFixture) +{ + + BOOST_TEST( + (m_Parser->GetNetworkOutputBindingInfo("Relu_1").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 }))); + + BOOST_TEST( + (m_Parser->GetNetworkOutputBindingInfo("Relu_2").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 }))); + + RunTest<4>({ { "graphInput", { -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f , 1.75f } } , + { "graphInput2", { -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f , 1.75f } } }, + { { "Relu_1", { 1.0f, 1.5625f, 0, 0.5625f } }, + { "Relu_2", { 0.25, 9.0f, 0.25f, 3.0625f } } }); +} + BOOST_AUTO_TEST_SUITE_END() -- cgit v1.2.1