From dba634fd6a66a9e033a1925b0b26c80b270bbf21 Mon Sep 17 00:00:00 2001 From: Matthew Jackson Date: Thu, 15 Aug 2019 15:14:18 +0100 Subject: IVGCVSW-3639 Add 5d tensor support * Increased MaxNumOfTensorDimensions and fixed issues related to its use * Fixed issues caused by assuming 5d tensors are invalid * Updated ArmComputeTensorUtils for 5d tensors * Added 5d tensor unit tests for add, mul, stack and reshape (needed by IVGCVSW-3527) Signed-off-by: Matthew Jackson Change-Id: I5bcd64942d0d04efcc6c5acb240ad4b88e010743 --- src/armnnTfParser/TfParser.cpp | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) (limited to 'src/armnnTfParser/TfParser.cpp') diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp index 39e6971ab5..76d25d1d05 100755 --- a/src/armnnTfParser/TfParser.cpp +++ b/src/armnnTfParser/TfParser.cpp @@ -2090,10 +2090,11 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef, % CHECK_LOCATION().AsString())); } + const unsigned int supportedNumDims = 4; unsigned int numConcatViews = numInputs - 1; - OriginsDescriptor concatDescriptor(static_cast(numConcatViews), MaxNumOfTensorDimensions); + OriginsDescriptor concatDescriptor(static_cast(numConcatViews), supportedNumDims); concatDescriptor.SetConcatAxis(concatDim); - TensorShape mergeDims(MaxNumOfTensorDimensions); + TensorShape mergeDims(supportedNumDims); unsigned int mergeDim = 0; for (unsigned int viewIndex = 0; viewIndex < numConcatViews; ++viewIndex) { @@ -2102,7 +2103,7 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef, TensorInfo inputTensorInfo = inputSlot.GetTensorInfo(); // Double check dimensions of the tensors - if (inputTensorInfo.GetNumDimensions() != MaxNumOfTensorDimensions) + if (inputTensorInfo.GetNumDimensions() != supportedNumDims) { throw armnn::ParseException( boost::str( @@ -2110,14 +2111,14 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef, "The number of dimensions: %1% for input tensors of the " "concatenation op should be %2% %3%") % inputTensorInfo.GetNumDimensions() - % MaxNumOfTensorDimensions + % supportedNumDims % CHECK_LOCATION().AsString())); } // Copy the input tensor shape to mergeDimSizes and initialize the view origin coordinates for the current input mergeDims = inputTensorInfo.GetShape(); unsigned int* viewOrigin = const_cast(concatDescriptor.GetViewOrigin(viewIndex)); - std::fill(viewOrigin, viewOrigin + MaxNumOfTensorDimensions, 0); + std::fill(viewOrigin, viewOrigin + supportedNumDims, 0); // Update the view origin coordinates and the merge dimension value concatDescriptor.SetViewOriginCoord(viewIndex, concatDim, mergeDim); @@ -2652,9 +2653,10 @@ ParsedTfOperationPtr TfParser::ParseSplit(const tensorflow::NodeDef& nodeDef, IOutputSlot& inputSlot = inputs[1 - index].m_IndexedValue->ResolveArmnnOutputSlot(inputs[1 - index].m_Index); TensorInfo inputTensorInfo = inputSlot.GetTensorInfo(); + const unsigned int supportedNumDims = 4; auto inputDimSize = inputTensorInfo.GetNumDimensions(); - if (inputDimSize != MaxNumOfTensorDimensions) + if (inputDimSize != supportedNumDims) { throw armnn::ParseException( boost::str( @@ -2662,7 +2664,7 @@ ParsedTfOperationPtr TfParser::ParseSplit(const tensorflow::NodeDef& nodeDef, "The number of dimensions: %1% for input tensors of the " "split op should be %2% %3%") % inputTensorInfo.GetNumDimensions() - % MaxNumOfTensorDimensions + % supportedNumDims % CHECK_LOCATION().AsString())); } -- cgit v1.2.1