From 91c0eff2ff4ff384e013cb69cac1e07e28b9e2b1 Mon Sep 17 00:00:00 2001 From: Saoirse Stewart Date: Wed, 27 Feb 2019 11:07:57 +0000 Subject: IVGCVSW-2598 Fix for constant axis issue for Tensorflow Parser Change-Id: I8b081012529aed8e434273259c5a5ef7dc3afff7 Signed-off-by: Finn Williams Signed-off-by: Saoirse Stewart --- src/armnnTfParser/TfParser.cpp | 104 +++++++++--------- src/armnnTfParser/TfParser.hpp | 2 + src/armnnTfParser/test/Split.cpp | 226 +++++++++++++++++++++++++-------------- 3 files changed, 197 insertions(+), 135 deletions(-) diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp index 0410460059..1e304cbfd7 100755 --- a/src/armnnTfParser/TfParser.cpp +++ b/src/armnnTfParser/TfParser.cpp @@ -1158,6 +1158,23 @@ bool TfParser::HasParsedConstTensor(ParsedTfOperation* parsedTfOpPtr) const return dynamic_cast*>(parsedTfOpPtr) != nullptr; } +unsigned int TfParser::GetConstInputIndex(const std::vector& inputs) +{ + for (unsigned int i = 0; i < inputs.size(); i++) + { + if (HasParsedConstTensor(inputs[i].m_IndexedValue->GetNode().name())) + { + return i; + } + } + throw ParseException( + boost::str( + boost::format( + "ArmNN only supports operators with constant axis. %1%") + % CHECK_LOCATION().AsString())); + +} + ParsedTfOperationPtr TfParser::ParseConv2D(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef) { @@ -2040,22 +2057,12 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef, std::vector inputs = GetInputParsedTfOperationsChecked(nodeDef, numInputs); - // The last input is the axis for concatenation. - if (!HasParsedConstTensor(inputs[numInputs - 1].m_IndexedValue->GetNode().name())) - { - throw ParseException( - boost::str( - boost::format( - "ArmNN only supports Concat with constant axis. " - "Input %1%. Node %2% %3%") - % inputs[numInputs - 1].m_IndexedValue->GetNode().name() - % nodeDef.name() - % CHECK_LOCATION().AsString())); - } + // Constant tensor index + unsigned int index = GetConstInputIndex(inputs); + // Get the axis tensor data ParsedConstTfOperation* shapeNode = - boost::polymorphic_downcast*>(inputs[numInputs - 1].m_IndexedValue); + boost::polymorphic_downcast*>(inputs[index].m_IndexedValue); - // Get the axis tensor data std::vector axisTensorData; shapeNode->GetConstTensor(axisTensorData); @@ -2066,13 +2073,13 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef, if (concatDim == 0 || concatDim == 2) { throw ParseException( - boost::str( - boost::format( + boost::str( + boost::format( "Dimension %1% for concatenation is not supported by Armnn. " "Node %2% %3%") - % concatDim - % nodeDef.name() - % CHECK_LOCATION().AsString())); + % concatDim + % nodeDef.name() + % CHECK_LOCATION().AsString())); } unsigned int numConcatViews = numInputs - 1; @@ -2090,13 +2097,13 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef, if (inputTensorInfo.GetNumDimensions() != MaxNumOfTensorDimensions) { throw armnn::ParseException( - boost::str( - boost::format( + boost::str( + boost::format( "The number of dimensions: %1% for input tensors of the " "concatenation op should be %2% %3%") - % inputTensorInfo.GetNumDimensions() - % MaxNumOfTensorDimensions - % CHECK_LOCATION().AsString())); + % inputTensorInfo.GetNumDimensions() + % MaxNumOfTensorDimensions + % CHECK_LOCATION().AsString())); } // Copy the input tensor shape to mergeDimSizes and initialize the view origin coordinates for the current input @@ -2605,22 +2612,12 @@ ParsedTfOperationPtr TfParser::ParseSplit(const tensorflow::NodeDef& nodeDef, unsigned int numInputs = static_cast(nodes.size()); std::vector inputs = GetInputParsedTfOperationsChecked(nodeDef, numInputs); - // The last input is the axis for split operation. - if (!HasParsedConstTensor(inputs[numInputs - 1].m_IndexedValue->GetNode().name())) - { - throw ParseException( - boost::str( - boost::format( - "ArmNN only supports split with constant axis. " - "Input %1%. Node %2% %3%") - % inputs[numInputs - 1].m_IndexedValue->GetNode().name() - % nodeDef.name() - % CHECK_LOCATION().AsString())); - } + // Constant tensor index + unsigned int index = GetConstInputIndex(inputs); + // Get the axis tensor data ParsedConstTfOperation* shapeNode = - boost::polymorphic_downcast*>(inputs[numInputs - 1].m_IndexedValue); + boost::polymorphic_downcast*>(inputs[index].m_IndexedValue); - // Get the axis tensor data std::vector axisTensorData; shapeNode->GetConstTensor(axisTensorData); @@ -2630,34 +2627,35 @@ ParsedTfOperationPtr TfParser::ParseSplit(const tensorflow::NodeDef& nodeDef, // Armnn supports split along the channel dimension for data formats NHWC and NCHW. if (splitDim == 0 || splitDim == 2) { - throw ParseException( - boost::str( - boost::format( + throw armnn::ParseException( + boost::str( + boost::format( "Dimension %1% for split is not supported by Armnn. " "Node %2% %3%") - % splitDim - % nodeDef.name() - % CHECK_LOCATION().AsString())); + % splitDim + % nodeDef.name() + % CHECK_LOCATION().AsString())); } // As Armnn only supports splitter outputs of the same shape, therefore num_splits will be limited to an integer. uint32_t num_split = ReadMandatoryNodeUint32Attribute(nodeDef, "num_or_size_splits"); - IOutputSlot& inputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index); + IOutputSlot& inputSlot = inputs[1 - index].m_IndexedValue->ResolveArmnnOutputSlot(inputs[1 - index].m_Index); TensorInfo inputTensorInfo = inputSlot.GetTensorInfo(); - if (inputTensorInfo.GetNumDimensions() != MaxNumOfTensorDimensions) + auto inputDimSize = inputTensorInfo.GetNumDimensions(); + + if (inputDimSize != MaxNumOfTensorDimensions) { throw armnn::ParseException( - boost::str( - boost::format( + boost::str( + boost::format( "The number of dimensions: %1% for input tensors of the " - "splitter op should be %2% %3%") - % inputTensorInfo.GetNumDimensions() - % MaxNumOfTensorDimensions - % CHECK_LOCATION().AsString())); + "split op should be %2% %3%") + % inputTensorInfo.GetNumDimensions() + % MaxNumOfTensorDimensions + % CHECK_LOCATION().AsString())); } - auto inputDimSize = inputTensorInfo.GetNumDimensions(); std::vector splitterDimSizes(inputDimSize); diff --git a/src/armnnTfParser/TfParser.hpp b/src/armnnTfParser/TfParser.hpp index 46da55f1d1..95ccf397c1 100644 --- a/src/armnnTfParser/TfParser.hpp +++ b/src/armnnTfParser/TfParser.hpp @@ -129,6 +129,8 @@ private: template bool HasParsedConstTensor(ParsedTfOperation* parsedTfOpPtr) const; + unsigned int GetConstInputIndex(const std::vector& inputs); + ParsedTfOperationPtr ParseAdd(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); ParsedTfOperationPtr ParseAddN(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); ParsedTfOperationPtr ParseBiasAdd(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); diff --git a/src/armnnTfParser/test/Split.cpp b/src/armnnTfParser/test/Split.cpp index de6b5d861e..87cd6544c9 100644 --- a/src/armnnTfParser/test/Split.cpp +++ b/src/armnnTfParser/test/Split.cpp @@ -11,93 +11,140 @@ BOOST_AUTO_TEST_SUITE(TensorflowParser) struct SplitFixture : public armnnUtils::ParserPrototxtFixture { - SplitFixture() { - m_Prototext = - "node { \n" - " name: \"graphInput\" \n" - " op: \"Placeholder\" \n" - " attr { \n" - " key: \"dtype\" \n" - " value { \n" - " type: DT_FLOAT \n" - " } \n" - " } \n" - " attr { \n" - " key: \"shape\" \n" - " value { \n" - " shape { \n" - " } \n" - " } \n" - " } \n" - " } \n" - " node {" - " name: \"splitInput\" \n" - " op: \"Const\" \n" - "attr {\n" - " key: \"dtype\" \n" - " value {" - " type: DT_INT32" - " }" - "}" - "attr {" - " key: \"value\"\n" - " value { " - " tensor {" - " dtype: DT_INT32" - " tensor_shape {" - "}" - "int_val: 1" - "}" - "}" - "}" - "}" - "node { \n" - " name: \"Split\" \n" - " op: \"Split\" \n" - "input: \"graphInput\"\n" - "input: \"splitInput\"\n" - "attr { \n " - "key: \"T\"\n" - "value {\n" - "type: DT_FLOAT\n" - " }\n" - "}\n" - "\n" - " attr { \n" - " key: \"num_or_size_splits\" \n" - " value { \n" - " i:2 \n " - " } \n" - " } \n" - "} \n" - "node { \n" - "name: \"Relu_1\"\n" - "op: \"Relu\"\n" - "input: \"Split:0\"\n" - "attr { \n " - "key: \"T\"\n" - "value {\n" - "type: DT_FLOAT\n" - " }\n" - "}\n" - "}\n" - "node { \n" - "name: \"Relu_2\"\n" - "op: \"Relu\"\n" - "input: \"Split:1\"\n" - "attr { \n " - "key: \"T\"\n" - "value {\n" - "type: DT_FLOAT\n" - " }\n" - "}\n" - "}\n"; + SplitFixture(bool withDimZero=false) { + m_Prototext = R"( + node { + name: "graphInput" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "graphInput2" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "multiplication" + op : "Mul" + input: "graphInput" + input: "graphInput2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node { + name: "SplitInput" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: )"; - Setup( { { "graphInput", { 1, 2, 2 , 2} } }, + if(withDimZero) + { + m_Prototext += std::to_string(3); + } + else + { + m_Prototext += std::to_string(1); + } + + m_Prototext += R"( + } + } + } + } + node { + name: "Split" + op: "Split" )"; + if(withDimZero) + { + m_Prototext += "input: \"SplitInput\"\n"; + m_Prototext += "input: \"multiplication\"\n"; + } + else + { + m_Prototext += "input: \"graphInput\"\n"; + m_Prototext += "input: \"SplitInput\"\n"; + } + m_Prototext += R"( + attr { + key: "num_or_size_splits" + value { + i: 2 + } + } + } + node { + name: "Relu_1" + op: "Relu" + input: "Split:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node { + name: "Relu_2" + op: "Relu" + input:"Split:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } )"; + + Setup( { { "graphInput", { 1, 2, 2 , 2} } , { "graphInput2", { 1, 2, 2 , 2} }}, { "Relu_1", "Relu_2" }); } }; +struct InputFirstSplitFixture : SplitFixture +{ + InputFirstSplitFixture() : SplitFixture(true) {} +}; + BOOST_FIXTURE_TEST_CASE(ParseAxisOneSplitTwo, SplitFixture) { BOOST_TEST( @@ -111,4 +158,19 @@ BOOST_FIXTURE_TEST_CASE(ParseAxisOneSplitTwo, SplitFixture) { "Relu_2", { 0.0f, 0.5f, 0.0f, 1.75f } } }); } +BOOST_FIXTURE_TEST_CASE(ParseSplit, InputFirstSplitFixture) +{ + + BOOST_TEST( + (m_Parser->GetNetworkOutputBindingInfo("Relu_1").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 }))); + + BOOST_TEST( + (m_Parser->GetNetworkOutputBindingInfo("Relu_2").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 }))); + + RunTest<4>({ { "graphInput", { -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f , 1.75f } } , + { "graphInput2", { -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f , 1.75f } } }, + { { "Relu_1", { 1.0f, 1.5625f, 0, 0.5625f } }, + { "Relu_2", { 0.25, 9.0f, 0.25f, 3.0625f } } }); +} + BOOST_AUTO_TEST_SUITE_END() -- cgit v1.2.1