From 91c0eff2ff4ff384e013cb69cac1e07e28b9e2b1 Mon Sep 17 00:00:00 2001 From: Saoirse Stewart Date: Wed, 27 Feb 2019 11:07:57 +0000 Subject: IVGCVSW-2598 Fix for constant axis issue for Tensorflow Parser Change-Id: I8b081012529aed8e434273259c5a5ef7dc3afff7 Signed-off-by: Finn Williams Signed-off-by: Saoirse Stewart --- src/armnnTfParser/TfParser.cpp | 104 ++++++++++++++++++++--------------------- 1 file changed, 51 insertions(+), 53 deletions(-) (limited to 'src/armnnTfParser/TfParser.cpp') diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp index 0410460059..1e304cbfd7 100755 --- a/src/armnnTfParser/TfParser.cpp +++ b/src/armnnTfParser/TfParser.cpp @@ -1158,6 +1158,23 @@ bool TfParser::HasParsedConstTensor(ParsedTfOperation* parsedTfOpPtr) const return dynamic_cast*>(parsedTfOpPtr) != nullptr; } +unsigned int TfParser::GetConstInputIndex(const std::vector& inputs) +{ + for (unsigned int i = 0; i < inputs.size(); i++) + { + if (HasParsedConstTensor(inputs[i].m_IndexedValue->GetNode().name())) + { + return i; + } + } + throw ParseException( + boost::str( + boost::format( + "ArmNN only supports operators with constant axis. %1%") + % CHECK_LOCATION().AsString())); + +} + ParsedTfOperationPtr TfParser::ParseConv2D(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef) { @@ -2040,22 +2057,12 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef, std::vector inputs = GetInputParsedTfOperationsChecked(nodeDef, numInputs); - // The last input is the axis for concatenation. - if (!HasParsedConstTensor(inputs[numInputs - 1].m_IndexedValue->GetNode().name())) - { - throw ParseException( - boost::str( - boost::format( - "ArmNN only supports Concat with constant axis. " - "Input %1%. Node %2% %3%") - % inputs[numInputs - 1].m_IndexedValue->GetNode().name() - % nodeDef.name() - % CHECK_LOCATION().AsString())); - } + // Constant tensor index + unsigned int index = GetConstInputIndex(inputs); + // Get the axis tensor data ParsedConstTfOperation* shapeNode = - boost::polymorphic_downcast*>(inputs[numInputs - 1].m_IndexedValue); + boost::polymorphic_downcast*>(inputs[index].m_IndexedValue); - // Get the axis tensor data std::vector axisTensorData; shapeNode->GetConstTensor(axisTensorData); @@ -2066,13 +2073,13 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef, if (concatDim == 0 || concatDim == 2) { throw ParseException( - boost::str( - boost::format( + boost::str( + boost::format( "Dimension %1% for concatenation is not supported by Armnn. " "Node %2% %3%") - % concatDim - % nodeDef.name() - % CHECK_LOCATION().AsString())); + % concatDim + % nodeDef.name() + % CHECK_LOCATION().AsString())); } unsigned int numConcatViews = numInputs - 1; @@ -2090,13 +2097,13 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef, if (inputTensorInfo.GetNumDimensions() != MaxNumOfTensorDimensions) { throw armnn::ParseException( - boost::str( - boost::format( + boost::str( + boost::format( "The number of dimensions: %1% for input tensors of the " "concatenation op should be %2% %3%") - % inputTensorInfo.GetNumDimensions() - % MaxNumOfTensorDimensions - % CHECK_LOCATION().AsString())); + % inputTensorInfo.GetNumDimensions() + % MaxNumOfTensorDimensions + % CHECK_LOCATION().AsString())); } // Copy the input tensor shape to mergeDimSizes and initialize the view origin coordinates for the current input @@ -2605,22 +2612,12 @@ ParsedTfOperationPtr TfParser::ParseSplit(const tensorflow::NodeDef& nodeDef, unsigned int numInputs = static_cast(nodes.size()); std::vector inputs = GetInputParsedTfOperationsChecked(nodeDef, numInputs); - // The last input is the axis for split operation. - if (!HasParsedConstTensor(inputs[numInputs - 1].m_IndexedValue->GetNode().name())) - { - throw ParseException( - boost::str( - boost::format( - "ArmNN only supports split with constant axis. " - "Input %1%. Node %2% %3%") - % inputs[numInputs - 1].m_IndexedValue->GetNode().name() - % nodeDef.name() - % CHECK_LOCATION().AsString())); - } + // Constant tensor index + unsigned int index = GetConstInputIndex(inputs); + // Get the axis tensor data ParsedConstTfOperation* shapeNode = - boost::polymorphic_downcast*>(inputs[numInputs - 1].m_IndexedValue); + boost::polymorphic_downcast*>(inputs[index].m_IndexedValue); - // Get the axis tensor data std::vector axisTensorData; shapeNode->GetConstTensor(axisTensorData); @@ -2630,34 +2627,35 @@ ParsedTfOperationPtr TfParser::ParseSplit(const tensorflow::NodeDef& nodeDef, // Armnn supports split along the channel dimension for data formats NHWC and NCHW. if (splitDim == 0 || splitDim == 2) { - throw ParseException( - boost::str( - boost::format( + throw armnn::ParseException( + boost::str( + boost::format( "Dimension %1% for split is not supported by Armnn. " "Node %2% %3%") - % splitDim - % nodeDef.name() - % CHECK_LOCATION().AsString())); + % splitDim + % nodeDef.name() + % CHECK_LOCATION().AsString())); } // As Armnn only supports splitter outputs of the same shape, therefore num_splits will be limited to an integer. uint32_t num_split = ReadMandatoryNodeUint32Attribute(nodeDef, "num_or_size_splits"); - IOutputSlot& inputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index); + IOutputSlot& inputSlot = inputs[1 - index].m_IndexedValue->ResolveArmnnOutputSlot(inputs[1 - index].m_Index); TensorInfo inputTensorInfo = inputSlot.GetTensorInfo(); - if (inputTensorInfo.GetNumDimensions() != MaxNumOfTensorDimensions) + auto inputDimSize = inputTensorInfo.GetNumDimensions(); + + if (inputDimSize != MaxNumOfTensorDimensions) { throw armnn::ParseException( - boost::str( - boost::format( + boost::str( + boost::format( "The number of dimensions: %1% for input tensors of the " - "splitter op should be %2% %3%") - % inputTensorInfo.GetNumDimensions() - % MaxNumOfTensorDimensions - % CHECK_LOCATION().AsString())); + "split op should be %2% %3%") + % inputTensorInfo.GetNumDimensions() + % MaxNumOfTensorDimensions + % CHECK_LOCATION().AsString())); } - auto inputDimSize = inputTensorInfo.GetNumDimensions(); std::vector splitterDimSizes(inputDimSize); -- cgit v1.2.1