diff options
author | Saoirse Stewart <saoirse.stewart@arm.com> | 2019-02-27 11:07:57 +0000 |
---|---|---|
committer | Saoirse Stewart Arm <saoirse.stewart@arm.com> | 2019-02-27 13:16:23 +0000 |
commit | 91c0eff2ff4ff384e013cb69cac1e07e28b9e2b1 (patch) | |
tree | 9db521e088c92dd7b9d9e2e6dfd917559aab4745 /src/armnnTfParser/TfParser.cpp | |
parent | dbfb8549d4aa80115a7049b3e94788fb7a474d9b (diff) | |
download | armnn-91c0eff2ff4ff384e013cb69cac1e07e28b9e2b1.tar.gz |
IVGCVSW-2598 Fix for constant axis issue for Tensorflow Parser
Change-Id: I8b081012529aed8e434273259c5a5ef7dc3afff7
Signed-off-by: Finn Williams <finn.williams@arm.com>
Signed-off-by: Saoirse Stewart <saoirse.stewart@arm.com>
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rwxr-xr-x | src/armnnTfParser/TfParser.cpp | 104 |
1 files changed, 51 insertions, 53 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp index 0410460059..1e304cbfd7 100755 --- a/src/armnnTfParser/TfParser.cpp +++ b/src/armnnTfParser/TfParser.cpp @@ -1158,6 +1158,23 @@ bool TfParser::HasParsedConstTensor(ParsedTfOperation* parsedTfOpPtr) const return dynamic_cast<ParsedConstTfOperation<Type>*>(parsedTfOpPtr) != nullptr; } +unsigned int TfParser::GetConstInputIndex(const std::vector<OutputOfParsedTfOperation>& inputs) +{ + for (unsigned int i = 0; i < inputs.size(); i++) + { + if (HasParsedConstTensor<int32_t>(inputs[i].m_IndexedValue->GetNode().name())) + { + return i; + } + } + throw ParseException( + boost::str( + boost::format( + "ArmNN only supports operators with constant axis. %1%") + % CHECK_LOCATION().AsString())); + +} + ParsedTfOperationPtr TfParser::ParseConv2D(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef) { @@ -2040,22 +2057,12 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef, std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, numInputs); - // The last input is the axis for concatenation. - if (!HasParsedConstTensor<int32_t>(inputs[numInputs - 1].m_IndexedValue->GetNode().name())) - { - throw ParseException( - boost::str( - boost::format( - "ArmNN only supports Concat with constant axis. " - "Input %1%. Node %2% %3%") - % inputs[numInputs - 1].m_IndexedValue->GetNode().name() - % nodeDef.name() - % CHECK_LOCATION().AsString())); - } + // Constant tensor index + unsigned int index = GetConstInputIndex(inputs); + // Get the axis tensor data ParsedConstTfOperation<int32_t>* shapeNode = - boost::polymorphic_downcast<ParsedConstTfOperation<int32_t>*>(inputs[numInputs - 1].m_IndexedValue); + boost::polymorphic_downcast<ParsedConstTfOperation<int32_t>*>(inputs[index].m_IndexedValue); - // Get the axis tensor data std::vector<int32_t> axisTensorData; shapeNode->GetConstTensor(axisTensorData); @@ -2066,13 +2073,13 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef, if (concatDim == 0 || concatDim == 2) { throw ParseException( - boost::str( - boost::format( + boost::str( + boost::format( "Dimension %1% for concatenation is not supported by Armnn. " "Node %2% %3%") - % concatDim - % nodeDef.name() - % CHECK_LOCATION().AsString())); + % concatDim + % nodeDef.name() + % CHECK_LOCATION().AsString())); } unsigned int numConcatViews = numInputs - 1; @@ -2090,13 +2097,13 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef, if (inputTensorInfo.GetNumDimensions() != MaxNumOfTensorDimensions) { throw armnn::ParseException( - boost::str( - boost::format( + boost::str( + boost::format( "The number of dimensions: %1% for input tensors of the " "concatenation op should be %2% %3%") - % inputTensorInfo.GetNumDimensions() - % MaxNumOfTensorDimensions - % CHECK_LOCATION().AsString())); + % inputTensorInfo.GetNumDimensions() + % MaxNumOfTensorDimensions + % CHECK_LOCATION().AsString())); } // Copy the input tensor shape to mergeDimSizes and initialize the view origin coordinates for the current input @@ -2605,22 +2612,12 @@ ParsedTfOperationPtr TfParser::ParseSplit(const tensorflow::NodeDef& nodeDef, unsigned int numInputs = static_cast<unsigned int>(nodes.size()); std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, numInputs); - // The last input is the axis for split operation. - if (!HasParsedConstTensor<int32_t>(inputs[numInputs - 1].m_IndexedValue->GetNode().name())) - { - throw ParseException( - boost::str( - boost::format( - "ArmNN only supports split with constant axis. " - "Input %1%. Node %2% %3%") - % inputs[numInputs - 1].m_IndexedValue->GetNode().name() - % nodeDef.name() - % CHECK_LOCATION().AsString())); - } + // Constant tensor index + unsigned int index = GetConstInputIndex(inputs); + // Get the axis tensor data ParsedConstTfOperation<int32_t>* shapeNode = - boost::polymorphic_downcast<ParsedConstTfOperation<int32_t>*>(inputs[numInputs - 1].m_IndexedValue); + boost::polymorphic_downcast<ParsedConstTfOperation<int32_t>*>(inputs[index].m_IndexedValue); - // Get the axis tensor data std::vector<int32_t> axisTensorData; shapeNode->GetConstTensor(axisTensorData); @@ -2630,34 +2627,35 @@ ParsedTfOperationPtr TfParser::ParseSplit(const tensorflow::NodeDef& nodeDef, // Armnn supports split along the channel dimension for data formats NHWC and NCHW. if (splitDim == 0 || splitDim == 2) { - throw ParseException( - boost::str( - boost::format( + throw armnn::ParseException( + boost::str( + boost::format( "Dimension %1% for split is not supported by Armnn. " "Node %2% %3%") - % splitDim - % nodeDef.name() - % CHECK_LOCATION().AsString())); + % splitDim + % nodeDef.name() + % CHECK_LOCATION().AsString())); } // As Armnn only supports splitter outputs of the same shape, therefore num_splits will be limited to an integer. uint32_t num_split = ReadMandatoryNodeUint32Attribute(nodeDef, "num_or_size_splits"); - IOutputSlot& inputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index); + IOutputSlot& inputSlot = inputs[1 - index].m_IndexedValue->ResolveArmnnOutputSlot(inputs[1 - index].m_Index); TensorInfo inputTensorInfo = inputSlot.GetTensorInfo(); - if (inputTensorInfo.GetNumDimensions() != MaxNumOfTensorDimensions) + auto inputDimSize = inputTensorInfo.GetNumDimensions(); + + if (inputDimSize != MaxNumOfTensorDimensions) { throw armnn::ParseException( - boost::str( - boost::format( + boost::str( + boost::format( "The number of dimensions: %1% for input tensors of the " - "splitter op should be %2% %3%") - % inputTensorInfo.GetNumDimensions() - % MaxNumOfTensorDimensions - % CHECK_LOCATION().AsString())); + "split op should be %2% %3%") + % inputTensorInfo.GetNumDimensions() + % MaxNumOfTensorDimensions + % CHECK_LOCATION().AsString())); } - auto inputDimSize = inputTensorInfo.GetNumDimensions(); std::vector<unsigned int> splitterDimSizes(inputDimSize); |