aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/TfParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rw-r--r--src/armnnTfParser/TfParser.cpp39
1 files changed, 4 insertions, 35 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index eca393663b..52ba92cad5 100644
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -11,6 +11,7 @@
#include <armnn/Descriptors.hpp>
#include <GraphTopologicalSort.hpp>
+#include <ParserHelper.hpp>
#include <Permute.hpp>
#include <VerificationHelpers.hpp>
@@ -1478,41 +1479,9 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef,
inputs[viewIndex].m_IndexedValue->ResolveArmnnOutputSlot(inputs[viewIndex].m_Index);
TensorInfo inputTensorInfo = inputSlot.GetTensorInfo();
- if (inputTensorInfo.GetNumDimensions() != MaxNumOfTensorDimensions)
- {
- throw ParseException(
- boost::str(
- boost::format(
- "The number of dimensions: %1% for input tensors of the "
- "concatenation op should be %2% for Node %3% %4%")
- % inputTensorInfo.GetNumDimensions()
- % MaxNumOfTensorDimensions
- % nodeDef.name()
- % CHECK_LOCATION().AsString()));
- }
-
- if (concatDimInput == 3)
- {
- inputTensorInfo = armnnUtils::Permuted(inputTensorInfo, NHWCToArmNN);
- }
-
- for (unsigned int dim = 0; dim < MaxNumOfTensorDimensions; ++dim)
- {
- mergeDimSizes[dim] = inputTensorInfo.GetShape()[dim];
- }
-
- for (unsigned int j = 0; j < concatDim; ++j)
- {
- concatDescriptor.SetViewOriginCoord(viewIndex, j, 0);
- }
-
- concatDescriptor.SetViewOriginCoord(viewIndex, concatDim, mergeDim);
- mergeDim += mergeDimSizes[concatDim];
-
- for (unsigned int j = concatDim+1; j < MaxNumOfTensorDimensions; ++j)
- {
- concatDescriptor.SetViewOriginCoord(viewIndex, j, 0);
- }
+ // process the input tensor info
+ armnnUtils::ProcessConcatInputTensorInfo(inputTensorInfo, concatDescriptor,
+ concatDimInput, viewIndex, mergeDimSizes, mergeDim);
}
mergeDimSizes[concatDim] = mergeDim;