-rw-r--r--  include/armnn/Tensor.hpp        2
-rw-r--r--  src/armnn/Tensor.cpp           16
-rw-r--r--  src/armnnTfParser/TfParser.cpp 79
3 files changed, 58 insertions, 39 deletions
diff --git a/include/armnn/Tensor.hpp b/include/armnn/Tensor.hpp
index e69d1a956c..f4d7f9f984 100644
--- a/include/armnn/Tensor.hpp
+++ b/include/armnn/Tensor.hpp
@@ -21,6 +21,8 @@ public:
/// Empty (invalid) constructor.
TensorShape();
+ TensorShape(unsigned int numDimensions);
+
TensorShape(unsigned int numDimensions, const unsigned int* dimensionSizes);
TensorShape(std::initializer_list<unsigned int> dimensionSizeList);
diff --git a/src/armnn/Tensor.cpp b/src/armnn/Tensor.cpp
index 8e72d4694c..6e09e3bc59 100644
--- a/src/armnn/Tensor.cpp
+++ b/src/armnn/Tensor.cpp
@@ -23,6 +23,22 @@ TensorShape::TensorShape()
{
}
+TensorShape::TensorShape(unsigned int numDimensions)
+ : m_NumDimensions(numDimensions)
+{
+ if (numDimensions < 1)
+ {
+ throw InvalidArgumentException("Tensor numDimensions must be greater than 0");
+ }
+
+ if (numDimensions > MaxNumOfTensorDimensions)
+ {
+ throw InvalidArgumentException("Tensor numDimensions must be less than or equal to MaxNumOfTensorDimensions");
+ }
+
+ std::fill(m_Dimensions.begin(), m_Dimensions.begin() + m_NumDimensions, 0);
+}
+
TensorShape::TensorShape(const unsigned int numDimensions, const unsigned int* const dimensionSizes)
: m_NumDimensions(numDimensions)
{
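Note (not part of the patch): a minimal usage sketch of the new rank-only TensorShape constructor. The constructor zero-fills the dimensions and throws InvalidArgumentException if the rank is 0 or greater than MaxNumOfTensorDimensions, so callers are expected to fill in each dimension afterwards, e.g. through the writable operator[] that the TfParser change below relies on (mergeDims[concatDim] = mergeDim). The shape values here are illustrative only.

    #include <armnn/Tensor.hpp>     // TensorShape, TensorInfo
    #include <armnn/Types.hpp>      // DataType

    // Sketch only: build a 4-D shape with the new constructor and fill it in.
    armnn::TensorShape shape(4);    // rank 4, all dimensions initialized to 0
    shape[0] = 1;                   // N  (illustrative NCHW values)
    shape[1] = 3;                   // C
    shape[2] = 224;                 // H
    shape[3] = 224;                 // W
    armnn::TensorInfo info(shape, armnn::DataType::Float32);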
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index 8f6352c6e7..210b825e43 100644
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -1781,14 +1781,10 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef)
{
std::vector<OutputOfConstNodeDef> nodes = GetTfInputNodes(nodeDef);
+
// In tensorflow, we have the last input of the Concat layer as the axis for concatenation.
unsigned int numInputs = static_cast<unsigned int>(nodes.size());
- unsigned int numConcatView = numInputs - 1;
-
- OriginsDescriptor concatDescriptor(static_cast<uint32_t>(numConcatView), MaxNumOfTensorDimensions);
- std::vector<unsigned int>mergeDimSizes(MaxNumOfTensorDimensions, 0u);
- unsigned int mergeDim = 0;
std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, numInputs);
// The last input is the axis for concatenation.
@@ -1806,65 +1802,70 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef,
ParsedConstTfOperation<int32_t>* shapeNode =
boost::polymorphic_downcast<ParsedConstTfOperation<int32_t>*>(inputs[numInputs - 1].m_IndexedValue);
+ // Get the axis tensor data
std::vector<int32_t> axisTensorData;
- ConstTensor axisTensor = shapeNode->GetConstTensor(false, axisTensorData);
+ shapeNode->GetConstTensor(false, axisTensorData);
// This concatDim indicates the data format: 3 is the NHWC, 1 is the NCHW.
- const unsigned int concatDimInput = static_cast<unsigned int>(axisTensorData[0]);
+ const unsigned int concatDim = static_cast<unsigned int>(axisTensorData[0]);
// Armnn supports concatenation along the channel dimension for data formats NHWC and NCHW.
- if (concatDimInput == 0 || concatDimInput == 2)
+ if (concatDim == 0 || concatDim == 2)
{
throw ParseException(
boost::str(
boost::format(
"Dimension %1% for concatenation is not supported by Armnn. "
"Node %2% %3%")
- % concatDimInput
+ % concatDim
% nodeDef.name()
% CHECK_LOCATION().AsString()));
}
- // This is the only concatDim we support in armnn.
- const unsigned int concatDim = 1;
- for (unsigned int viewIndex = 0; viewIndex < numConcatView; ++viewIndex)
+ unsigned int numConcatViews = numInputs - 1;
+ OriginsDescriptor concatDescriptor(static_cast<uint32_t>(numConcatViews), MaxNumOfTensorDimensions);
+ concatDescriptor.SetConcatAxis(concatDim);
+ TensorShape mergeDims(MaxNumOfTensorDimensions);
+ unsigned int mergeDim = 0;
+ for (unsigned int viewIndex = 0; viewIndex < numConcatViews; ++viewIndex)
{
// Need to double check whether it should be
- IOutputSlot& inputSlot =
- inputs[viewIndex].m_IndexedValue->ResolveArmnnOutputSlot(inputs[viewIndex].m_Index);
+ IOutputSlot& inputSlot = inputs[viewIndex].m_IndexedValue->ResolveArmnnOutputSlot(inputs[viewIndex].m_Index);
TensorInfo inputTensorInfo = inputSlot.GetTensorInfo();
- // process the input tensor info
- armnnUtils::ProcessConcatInputTensorInfo(inputTensorInfo, concatDescriptor,
- concatDimInput, viewIndex, mergeDimSizes, mergeDim);
+ // Double check dimensions of the tensors
+ if (inputTensorInfo.GetNumDimensions() != MaxNumOfTensorDimensions)
+ {
+ throw armnn::ParseException(
+ boost::str(
+ boost::format(
+ "The number of dimensions: %1% for input tensors of the "
+ "concatenation op should be %2% %3%")
+ % inputTensorInfo.GetNumDimensions()
+ % MaxNumOfTensorDimensions
+ % CHECK_LOCATION().AsString()));
+ }
+
+    // Copy the input tensor shape to mergeDims and initialize the view origin coordinates for the current input
+ mergeDims = inputTensorInfo.GetShape();
+ unsigned int* viewOrigin = const_cast<unsigned int*>(concatDescriptor.GetViewOrigin(viewIndex));
+ std::fill(viewOrigin, viewOrigin + MaxNumOfTensorDimensions, 0);
+
+ // Update the view origin coordinates and the merge dimension value
+ concatDescriptor.SetViewOriginCoord(viewIndex, concatDim, mergeDim);
+ mergeDim += mergeDims[concatDim];
}
- mergeDimSizes[concatDim] = mergeDim;
+ // Update the output shape
+ mergeDims[concatDim] = mergeDim;
armnn::IConnectableLayer *layer = m_Network->AddMergerLayer(concatDescriptor, nodeDef.name().c_str());
- layer->GetOutputSlot(0).SetTensorInfo(armnn::TensorInfo(MaxNumOfTensorDimensions, mergeDimSizes.data(),
- DataType::Float32));
-
- for (unsigned int v = 0; v < numConcatView; ++v)
- {
- IOutputSlot& inputSlot = inputs[v].m_IndexedValue->ResolveArmnnOutputSlot(inputs[v].m_Index);
- if (concatDimInput == 3)
- {
- IConnectableLayer* const swizzleLayer = AddSwizzleLayer(*m_Network, inputSlot, NHWCToArmNN,
- "swizzle_for-" + nodeDef.name());
- swizzleLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(v));
- }
- else
- {
- inputSlot.Connect(layer->GetInputSlot(v));
- }
- }
+ layer->GetOutputSlot(0).SetTensorInfo(armnn::TensorInfo(mergeDims, DataType::Float32));
- if (concatDimInput == 3)
+ for (unsigned int viewIndex = 0; viewIndex < numConcatViews; ++viewIndex)
{
- IConnectableLayer* const deswizzleLayer = AddSwizzleLayer(*m_Network, layer->GetOutputSlot(0), ArmNNToNHWC,
- "deswizzle_for-" + nodeDef.name());
- layer = deswizzleLayer;
+ IOutputSlot& inputSlot = inputs[viewIndex].m_IndexedValue->ResolveArmnnOutputSlot(inputs[viewIndex].m_Index);
+ inputSlot.Connect(layer->GetInputSlot(viewIndex));
}
return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
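Note (outside the patch): a rough sketch of how the reworked ParseConcat describes the concatenation through OriginsDescriptor: the concat axis comes straight from the axis tensor, each view's origin along that axis is the running merge offset, and the accumulated merge dimension becomes the output shape. The two input shapes and the axis below are made up for illustration; the remaining origin coordinates are left at zero here for brevity (the patch zero-fills them explicitly with std::fill).

    #include <armnn/Descriptors.hpp>   // OriginsDescriptor
    #include <armnn/Tensor.hpp>        // TensorShape, TensorInfo
    #include <armnn/Types.hpp>         // DataType, MaxNumOfTensorDimensions

    // Sketch: two hypothetical NCHW inputs, [1, 2, 4, 4] and [1, 3, 4, 4], concatenated on axis 1 (channels).
    armnn::OriginsDescriptor desc(2, armnn::MaxNumOfTensorDimensions);
    desc.SetConcatAxis(1);
    desc.SetViewOriginCoord(0, 1, 0);                      // first view starts at channel 0
    desc.SetViewOriginCoord(1, 1, 2);                      // second view starts after the first view's 2 channels
    armnn::TensorShape mergedShape({ 1, 2 + 3, 4, 4 });    // merged output shape: [1, 5, 4, 4]
    armnn::TensorInfo mergedInfo(mergedShape, armnn::DataType::Float32);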