diff options
author | Matthew Jackson <matthew.jackson@arm.com> | 2019-08-15 15:14:18 +0100 |
---|---|---|
committer | Áron Virginás-Tar <aron.virginas-tar@arm.com> | 2019-08-16 11:43:00 +0000 |
commit | dba634fd6a66a9e033a1925b0b26c80b270bbf21 (patch) | |
tree | c89740a40b9c109582635b7c40b17a16dd6c0649 /src/armnnTfParser/TfParser.cpp | |
parent | 11f99b4e72a92051329b23af7ded759463380086 (diff) | |
download | armnn-dba634fd6a66a9e033a1925b0b26c80b270bbf21.tar.gz |
IVGCVSW-3639 Add 5d tensor support
* Increased MaxNumOfTensorDimensions and fixed issues related to its use
* Fixed issues caused by assuming 5d tensors are invalid
* Updated ArmComputeTensorUtils for 5d tensors
* Added 5d tensor unit tests for add, mul, stack and reshape (needed by IVGCVSW-3527)
Signed-off-by: Matthew Jackson <matthew.jackson@arm.com>
Change-Id: I5bcd64942d0d04efcc6c5acb240ad4b88e010743
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rwxr-xr-x | src/armnnTfParser/TfParser.cpp | 16 |
1 files changed, 9 insertions, 7 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp index 39e6971ab5..76d25d1d05 100755 --- a/src/armnnTfParser/TfParser.cpp +++ b/src/armnnTfParser/TfParser.cpp @@ -2090,10 +2090,11 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef, % CHECK_LOCATION().AsString())); } + const unsigned int supportedNumDims = 4; unsigned int numConcatViews = numInputs - 1; - OriginsDescriptor concatDescriptor(static_cast<uint32_t>(numConcatViews), MaxNumOfTensorDimensions); + OriginsDescriptor concatDescriptor(static_cast<uint32_t>(numConcatViews), supportedNumDims); concatDescriptor.SetConcatAxis(concatDim); - TensorShape mergeDims(MaxNumOfTensorDimensions); + TensorShape mergeDims(supportedNumDims); unsigned int mergeDim = 0; for (unsigned int viewIndex = 0; viewIndex < numConcatViews; ++viewIndex) { @@ -2102,7 +2103,7 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef, TensorInfo inputTensorInfo = inputSlot.GetTensorInfo(); // Double check dimensions of the tensors - if (inputTensorInfo.GetNumDimensions() != MaxNumOfTensorDimensions) + if (inputTensorInfo.GetNumDimensions() != supportedNumDims) { throw armnn::ParseException( boost::str( @@ -2110,14 +2111,14 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef, "The number of dimensions: %1% for input tensors of the " "concatenation op should be %2% %3%") % inputTensorInfo.GetNumDimensions() - % MaxNumOfTensorDimensions + % supportedNumDims % CHECK_LOCATION().AsString())); } // Copy the input tensor shape to mergeDimSizes and initialize the view origin coordinates for the current input mergeDims = inputTensorInfo.GetShape(); unsigned int* viewOrigin = const_cast<unsigned int*>(concatDescriptor.GetViewOrigin(viewIndex)); - std::fill(viewOrigin, viewOrigin + MaxNumOfTensorDimensions, 0); + std::fill(viewOrigin, viewOrigin + supportedNumDims, 0); // Update the view origin coordinates and the merge dimension value concatDescriptor.SetViewOriginCoord(viewIndex, concatDim, mergeDim); @@ -2652,9 +2653,10 @@ ParsedTfOperationPtr TfParser::ParseSplit(const tensorflow::NodeDef& nodeDef, IOutputSlot& inputSlot = inputs[1 - index].m_IndexedValue->ResolveArmnnOutputSlot(inputs[1 - index].m_Index); TensorInfo inputTensorInfo = inputSlot.GetTensorInfo(); + const unsigned int supportedNumDims = 4; auto inputDimSize = inputTensorInfo.GetNumDimensions(); - if (inputDimSize != MaxNumOfTensorDimensions) + if (inputDimSize != supportedNumDims) { throw armnn::ParseException( boost::str( @@ -2662,7 +2664,7 @@ ParsedTfOperationPtr TfParser::ParseSplit(const tensorflow::NodeDef& nodeDef, "The number of dimensions: %1% for input tensors of the " "split op should be %2% %3%") % inputTensorInfo.GetNumDimensions() - % MaxNumOfTensorDimensions + % supportedNumDims % CHECK_LOCATION().AsString())); } |