aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/TfParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rwxr-xr-xsrc/armnnTfParser/TfParser.cpp16
1 files changed, 9 insertions, 7 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index 39e6971ab5..76d25d1d05 100755
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -2090,10 +2090,11 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef,
% CHECK_LOCATION().AsString()));
}
+ const unsigned int supportedNumDims = 4;
unsigned int numConcatViews = numInputs - 1;
- OriginsDescriptor concatDescriptor(static_cast<uint32_t>(numConcatViews), MaxNumOfTensorDimensions);
+ OriginsDescriptor concatDescriptor(static_cast<uint32_t>(numConcatViews), supportedNumDims);
concatDescriptor.SetConcatAxis(concatDim);
- TensorShape mergeDims(MaxNumOfTensorDimensions);
+ TensorShape mergeDims(supportedNumDims);
unsigned int mergeDim = 0;
for (unsigned int viewIndex = 0; viewIndex < numConcatViews; ++viewIndex)
{
@@ -2102,7 +2103,7 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef,
TensorInfo inputTensorInfo = inputSlot.GetTensorInfo();
// Double check dimensions of the tensors
- if (inputTensorInfo.GetNumDimensions() != MaxNumOfTensorDimensions)
+ if (inputTensorInfo.GetNumDimensions() != supportedNumDims)
{
throw armnn::ParseException(
boost::str(
@@ -2110,14 +2111,14 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef,
"The number of dimensions: %1% for input tensors of the "
"concatenation op should be %2% %3%")
% inputTensorInfo.GetNumDimensions()
- % MaxNumOfTensorDimensions
+ % supportedNumDims
% CHECK_LOCATION().AsString()));
}
// Copy the input tensor shape to mergeDimSizes and initialize the view origin coordinates for the current input
mergeDims = inputTensorInfo.GetShape();
unsigned int* viewOrigin = const_cast<unsigned int*>(concatDescriptor.GetViewOrigin(viewIndex));
- std::fill(viewOrigin, viewOrigin + MaxNumOfTensorDimensions, 0);
+ std::fill(viewOrigin, viewOrigin + supportedNumDims, 0);
// Update the view origin coordinates and the merge dimension value
concatDescriptor.SetViewOriginCoord(viewIndex, concatDim, mergeDim);
@@ -2652,9 +2653,10 @@ ParsedTfOperationPtr TfParser::ParseSplit(const tensorflow::NodeDef& nodeDef,
IOutputSlot& inputSlot = inputs[1 - index].m_IndexedValue->ResolveArmnnOutputSlot(inputs[1 - index].m_Index);
TensorInfo inputTensorInfo = inputSlot.GetTensorInfo();
+ const unsigned int supportedNumDims = 4;
auto inputDimSize = inputTensorInfo.GetNumDimensions();
- if (inputDimSize != MaxNumOfTensorDimensions)
+ if (inputDimSize != supportedNumDims)
{
throw armnn::ParseException(
boost::str(
@@ -2662,7 +2664,7 @@ ParsedTfOperationPtr TfParser::ParseSplit(const tensorflow::NodeDef& nodeDef,
"The number of dimensions: %1% for input tensors of the "
"split op should be %2% %3%")
% inputTensorInfo.GetNumDimensions()
- % MaxNumOfTensorDimensions
+ % supportedNumDims
% CHECK_LOCATION().AsString()));
}