aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/TfParser.cpp
diff options
context:
space:
mode:
authorMatteo Martincigh <matteo.martincigh@arm.com>2018-12-06 12:03:17 +0000
committerMatteo Martincigh <matteo.martincigh@arm.com>2018-12-07 16:25:45 +0000
commitf9afc791662f9ffd639a9500de8c4e33394c8c39 (patch)
tree26b3203c189f0a0f280e8fab4fd081f2d6112da3 /src/armnnTfParser/TfParser.cpp
parentba563c6d81bfb20f01b1b54e27fe1ac4a494ece1 (diff)
downloadarmnn-f9afc791662f9ffd639a9500de8c4e33394c8c39.tar.gz
IVGCVSW-2268 Remove the input swizzling from ParseConcat
* Removed the input swizzling when the concatenation dimension is 3 in ParseConcat in the TF parser * No longer using the helper ProcessConcatInputTensorInfo, where the input was being swizzled if the concatenation dimension was 3 * Added a new convenience constuctor to TensorShape that initializes a shape to all zeros given only the number of dimensions Change-Id: I82a207e41bddc5fea21a0b5a38eafa24ad75d1c2
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rw-r--r--src/armnnTfParser/TfParser.cpp79
1 files changed, 40 insertions, 39 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index 8f6352c6e7..210b825e43 100644
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -1781,14 +1781,10 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef)
{
std::vector<OutputOfConstNodeDef> nodes = GetTfInputNodes(nodeDef);
+
// In tensorflow, we have the last input of the Concat layer as the axis for concatenation.
unsigned int numInputs = static_cast<unsigned int>(nodes.size());
- unsigned int numConcatView = numInputs - 1;
-
- OriginsDescriptor concatDescriptor(static_cast<uint32_t>(numConcatView), MaxNumOfTensorDimensions);
- std::vector<unsigned int>mergeDimSizes(MaxNumOfTensorDimensions, 0u);
- unsigned int mergeDim = 0;
std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, numInputs);
// The last input is the axis for concatenation.
@@ -1806,65 +1802,70 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef,
ParsedConstTfOperation<int32_t>* shapeNode =
boost::polymorphic_downcast<ParsedConstTfOperation<int32_t>*>(inputs[numInputs - 1].m_IndexedValue);
+ // Get the axis tensor data
std::vector<int32_t> axisTensorData;
- ConstTensor axisTensor = shapeNode->GetConstTensor(false, axisTensorData);
+ shapeNode->GetConstTensor(false, axisTensorData);
// This concatDim indicates the data format: 3 is the NHWC, 1 is the NCHW.
- const unsigned int concatDimInput = static_cast<unsigned int>(axisTensorData[0]);
+ const unsigned int concatDim = static_cast<unsigned int>(axisTensorData[0]);
// Armnn supports concatenation along the channel dimension for data formats NHWC and NCHW.
- if (concatDimInput == 0 || concatDimInput == 2)
+ if (concatDim == 0 || concatDim == 2)
{
throw ParseException(
boost::str(
boost::format(
"Dimension %1% for concatenation is not supported by Armnn. "
"Node %2% %3%")
- % concatDimInput
+ % concatDim
% nodeDef.name()
% CHECK_LOCATION().AsString()));
}
- // This is the only concatDim we support in armnn.
- const unsigned int concatDim = 1;
- for (unsigned int viewIndex = 0; viewIndex < numConcatView; ++viewIndex)
+ unsigned int numConcatViews = numInputs - 1;
+ OriginsDescriptor concatDescriptor(static_cast<uint32_t>(numConcatViews), MaxNumOfTensorDimensions);
+ concatDescriptor.SetConcatAxis(concatDim);
+ TensorShape mergeDims(MaxNumOfTensorDimensions);
+ unsigned int mergeDim = 0;
+ for (unsigned int viewIndex = 0; viewIndex < numConcatViews; ++viewIndex)
{
// Need to double check whether it should be
- IOutputSlot& inputSlot =
- inputs[viewIndex].m_IndexedValue->ResolveArmnnOutputSlot(inputs[viewIndex].m_Index);
+ IOutputSlot& inputSlot = inputs[viewIndex].m_IndexedValue->ResolveArmnnOutputSlot(inputs[viewIndex].m_Index);
TensorInfo inputTensorInfo = inputSlot.GetTensorInfo();
- // process the input tensor info
- armnnUtils::ProcessConcatInputTensorInfo(inputTensorInfo, concatDescriptor,
- concatDimInput, viewIndex, mergeDimSizes, mergeDim);
+ // Double check dimensions of the tensors
+ if (inputTensorInfo.GetNumDimensions() != MaxNumOfTensorDimensions)
+ {
+ throw armnn::ParseException(
+ boost::str(
+ boost::format(
+ "The number of dimensions: %1% for input tensors of the "
+ "concatenation op should be %2% %3%")
+ % inputTensorInfo.GetNumDimensions()
+ % MaxNumOfTensorDimensions
+ % CHECK_LOCATION().AsString()));
+ }
+
+ // Copy the input tensor shape to mergeDimSizes and initialize the view origin coordinates for the current input
+ mergeDims = inputTensorInfo.GetShape();
+ unsigned int* viewOrigin = const_cast<unsigned int*>(concatDescriptor.GetViewOrigin(viewIndex));
+ std::fill(viewOrigin, viewOrigin + MaxNumOfTensorDimensions, 0);
+
+ // Update the view origin coordinates and the merge dimension value
+ concatDescriptor.SetViewOriginCoord(viewIndex, concatDim, mergeDim);
+ mergeDim += mergeDims[concatDim];
}
- mergeDimSizes[concatDim] = mergeDim;
+ // Update the output shape
+ mergeDims[concatDim] = mergeDim;
armnn::IConnectableLayer *layer = m_Network->AddMergerLayer(concatDescriptor, nodeDef.name().c_str());
- layer->GetOutputSlot(0).SetTensorInfo(armnn::TensorInfo(MaxNumOfTensorDimensions, mergeDimSizes.data(),
- DataType::Float32));
-
- for (unsigned int v = 0; v < numConcatView; ++v)
- {
- IOutputSlot& inputSlot = inputs[v].m_IndexedValue->ResolveArmnnOutputSlot(inputs[v].m_Index);
- if (concatDimInput == 3)
- {
- IConnectableLayer* const swizzleLayer = AddSwizzleLayer(*m_Network, inputSlot, NHWCToArmNN,
- "swizzle_for-" + nodeDef.name());
- swizzleLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(v));
- }
- else
- {
- inputSlot.Connect(layer->GetInputSlot(v));
- }
- }
+ layer->GetOutputSlot(0).SetTensorInfo(armnn::TensorInfo(mergeDims, DataType::Float32));
- if (concatDimInput == 3)
+ for (unsigned int viewIndex = 0; viewIndex < numConcatViews; ++viewIndex)
{
- IConnectableLayer* const deswizzleLayer = AddSwizzleLayer(*m_Network, layer->GetOutputSlot(0), ArmNNToNHWC,
- "deswizzle_for-" + nodeDef.name());
- layer = deswizzleLayer;
+ IOutputSlot& inputSlot = inputs[viewIndex].m_IndexedValue->ResolveArmnnOutputSlot(inputs[viewIndex].m_Index);
+ inputSlot.Connect(layer->GetInputSlot(viewIndex));
}
return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);