aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/TfParser.cpp
diff options
context:
space:
mode:
authorMatthew Jackson <matthew.jackson@arm.com>2019-08-15 15:14:18 +0100
committerÁron Virginás-Tar <aron.virginas-tar@arm.com>2019-08-16 11:43:00 +0000
commitdba634fd6a66a9e033a1925b0b26c80b270bbf21 (patch)
treec89740a40b9c109582635b7c40b17a16dd6c0649 /src/armnnTfParser/TfParser.cpp
parent11f99b4e72a92051329b23af7ded759463380086 (diff)
downloadarmnn-dba634fd6a66a9e033a1925b0b26c80b270bbf21.tar.gz
IVGCVSW-3639 Add 5d tensor support
* Increased MaxNumOfTensorDimensions and fixed issues related to its use * Fixed issues caused by assuming 5d tensors are invalid * Updated ArmComputeTensorUtils for 5d tensors * Added 5d tensor unit tests for add, mul, stack and reshape (needed by IVGCVSW-3527) Signed-off-by: Matthew Jackson <matthew.jackson@arm.com> Change-Id: I5bcd64942d0d04efcc6c5acb240ad4b88e010743
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rwxr-xr-xsrc/armnnTfParser/TfParser.cpp16
1 files changed, 9 insertions, 7 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index 39e6971ab5..76d25d1d05 100755
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -2090,10 +2090,11 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef,
% CHECK_LOCATION().AsString()));
}
+ const unsigned int supportedNumDims = 4;
unsigned int numConcatViews = numInputs - 1;
- OriginsDescriptor concatDescriptor(static_cast<uint32_t>(numConcatViews), MaxNumOfTensorDimensions);
+ OriginsDescriptor concatDescriptor(static_cast<uint32_t>(numConcatViews), supportedNumDims);
concatDescriptor.SetConcatAxis(concatDim);
- TensorShape mergeDims(MaxNumOfTensorDimensions);
+ TensorShape mergeDims(supportedNumDims);
unsigned int mergeDim = 0;
for (unsigned int viewIndex = 0; viewIndex < numConcatViews; ++viewIndex)
{
@@ -2102,7 +2103,7 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef,
TensorInfo inputTensorInfo = inputSlot.GetTensorInfo();
// Double check dimensions of the tensors
- if (inputTensorInfo.GetNumDimensions() != MaxNumOfTensorDimensions)
+ if (inputTensorInfo.GetNumDimensions() != supportedNumDims)
{
throw armnn::ParseException(
boost::str(
@@ -2110,14 +2111,14 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef,
"The number of dimensions: %1% for input tensors of the "
"concatenation op should be %2% %3%")
% inputTensorInfo.GetNumDimensions()
- % MaxNumOfTensorDimensions
+ % supportedNumDims
% CHECK_LOCATION().AsString()));
}
// Copy the input tensor shape to mergeDimSizes and initialize the view origin coordinates for the current input
mergeDims = inputTensorInfo.GetShape();
unsigned int* viewOrigin = const_cast<unsigned int*>(concatDescriptor.GetViewOrigin(viewIndex));
- std::fill(viewOrigin, viewOrigin + MaxNumOfTensorDimensions, 0);
+ std::fill(viewOrigin, viewOrigin + supportedNumDims, 0);
// Update the view origin coordinates and the merge dimension value
concatDescriptor.SetViewOriginCoord(viewIndex, concatDim, mergeDim);
@@ -2652,9 +2653,10 @@ ParsedTfOperationPtr TfParser::ParseSplit(const tensorflow::NodeDef& nodeDef,
IOutputSlot& inputSlot = inputs[1 - index].m_IndexedValue->ResolveArmnnOutputSlot(inputs[1 - index].m_Index);
TensorInfo inputTensorInfo = inputSlot.GetTensorInfo();
+ const unsigned int supportedNumDims = 4;
auto inputDimSize = inputTensorInfo.GetNumDimensions();
- if (inputDimSize != MaxNumOfTensorDimensions)
+ if (inputDimSize != supportedNumDims)
{
throw armnn::ParseException(
boost::str(
@@ -2662,7 +2664,7 @@ ParsedTfOperationPtr TfParser::ParseSplit(const tensorflow::NodeDef& nodeDef,
"The number of dimensions: %1% for input tensors of the "
"split op should be %2% %3%")
% inputTensorInfo.GetNumDimensions()
- % MaxNumOfTensorDimensions
+ % supportedNumDims
% CHECK_LOCATION().AsString()));
}