diff options
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rw-r--r-- | src/armnnTfParser/TfParser.cpp | 39 |
1 files changed, 4 insertions, 35 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp index eca393663b..52ba92cad5 100644 --- a/src/armnnTfParser/TfParser.cpp +++ b/src/armnnTfParser/TfParser.cpp @@ -11,6 +11,7 @@ #include <armnn/Descriptors.hpp> #include <GraphTopologicalSort.hpp> +#include <ParserHelper.hpp> #include <Permute.hpp> #include <VerificationHelpers.hpp> @@ -1478,41 +1479,9 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef, inputs[viewIndex].m_IndexedValue->ResolveArmnnOutputSlot(inputs[viewIndex].m_Index); TensorInfo inputTensorInfo = inputSlot.GetTensorInfo(); - if (inputTensorInfo.GetNumDimensions() != MaxNumOfTensorDimensions) - { - throw ParseException( - boost::str( - boost::format( - "The number of dimensions: %1% for input tensors of the " - "concatenation op should be %2% for Node %3% %4%") - % inputTensorInfo.GetNumDimensions() - % MaxNumOfTensorDimensions - % nodeDef.name() - % CHECK_LOCATION().AsString())); - } - - if (concatDimInput == 3) - { - inputTensorInfo = armnnUtils::Permuted(inputTensorInfo, NHWCToArmNN); - } - - for (unsigned int dim = 0; dim < MaxNumOfTensorDimensions; ++dim) - { - mergeDimSizes[dim] = inputTensorInfo.GetShape()[dim]; - } - - for (unsigned int j = 0; j < concatDim; ++j) - { - concatDescriptor.SetViewOriginCoord(viewIndex, j, 0); - } - - concatDescriptor.SetViewOriginCoord(viewIndex, concatDim, mergeDim); - mergeDim += mergeDimSizes[concatDim]; - - for (unsigned int j = concatDim+1; j < MaxNumOfTensorDimensions; ++j) - { - concatDescriptor.SetViewOriginCoord(viewIndex, j, 0); - } + // process the input tensor info + armnnUtils::ProcessConcatInputTensorInfo(inputTensorInfo, concatDescriptor, + concatDimInput, viewIndex, mergeDimSizes, mergeDim); } mergeDimSizes[concatDim] = mergeDim; |