aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/TfLiteParser.cpp
diff options
context:
space:
mode:
authorSadik Armagan <sadik.armagan@arm.com>2018-10-01 11:51:37 +0100
committerMatthew Bentham <matthew.bentham@arm.com>2018-10-10 16:16:57 +0100
commit479045bdcac4faddbf567aa0f73d2899881f341c (patch)
tree5cd11ee39543ceda2730d21c1e7036543712a335 /src/armnnTfLiteParser/TfLiteParser.cpp
parente4ba53a85c559d4fe574305276ac815cf7995762 (diff)
downloadarmnn-479045bdcac4faddbf567aa0f73d2899881f341c.tar.gz
IVGCVSW-1787 Add Support for Concatenation on TfLite parser
* Concatenation Parser function added to the TfLite Parser Change-Id: I42a42cd765ea09a30841c66b1942b9e09a876b10
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp98
1 files changed, 96 insertions, 2 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 13e4604490..66746e488b 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -10,6 +10,7 @@
#include <boost/filesystem.hpp>
// armnnUtils:
+#include <ParserHelper.hpp>
#include <Permute.hpp>
#include <VerificationHelpers.hpp>
@@ -452,13 +453,14 @@ TfLiteParser::TfLiteParser()
{
// register supported operators
m_ParserFunctions[tflite::BuiltinOperator_AVERAGE_POOL_2D] = &TfLiteParser::ParseAveragePool2D;
+ m_ParserFunctions[tflite::BuiltinOperator_CONCATENATION] = &TfLiteParser::ParseConcatenation;
m_ParserFunctions[tflite::BuiltinOperator_CONV_2D] = &TfLiteParser::ParseConv2D;
m_ParserFunctions[tflite::BuiltinOperator_DEPTHWISE_CONV_2D] = &TfLiteParser::ParseDepthwiseConv2D;
- m_ParserFunctions[tflite::BuiltinOperator_SOFTMAX] = &TfLiteParser::ParseSoftmax;
- m_ParserFunctions[tflite::BuiltinOperator_SQUEEZE] = &TfLiteParser::ParseSqueeze;
m_ParserFunctions[tflite::BuiltinOperator_RELU] = &TfLiteParser::ParseRelu;
m_ParserFunctions[tflite::BuiltinOperator_RELU6] = &TfLiteParser::ParseRelu6;
m_ParserFunctions[tflite::BuiltinOperator_RESHAPE] = &TfLiteParser::ParseReshape;
+ m_ParserFunctions[tflite::BuiltinOperator_SOFTMAX] = &TfLiteParser::ParseSoftmax;
+ m_ParserFunctions[tflite::BuiltinOperator_SQUEEZE] = &TfLiteParser::ParseSqueeze;
}
void TfLiteParser::ResetParser()
@@ -1097,6 +1099,98 @@ void TfLiteParser::ParseReshape(size_t subgraphIndex, size_t operatorIndex)
RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
}
+void TfLiteParser::ParseConcatenation(size_t subgraphIndex, size_t operatorIndex)
+{
+ CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
+
+ const auto & operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
+ const auto * options = operatorPtr->builtin_options.AsConcatenationOptions();
+
+ CHECK_SUPPORTED_FUSED_ACTIVATION(options, subgraphIndex, operatorIndex);
+
+ auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
+ auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
+ CHECK_VALID_SIZE(outputs.size(), 1);
+
+ unsigned int numInputs = static_cast<unsigned int>(inputs.size());
+ unsigned int numConcatView = numInputs;
+
+ OriginsDescriptor concatDescriptor(static_cast<uint32_t>(numConcatView), MaxNumOfTensorDimensions);
+ std::vector<unsigned int>mergeDimSizes(MaxNumOfTensorDimensions, 0u);
+
+ unsigned int mergeDim = 0;
+
+ // This concatDim indicates the data format: 3 is the NHWC, 1 is the NCHW.
+ // axis could also be negative numbers. Negative axis are interpreted as counting from the end of the rank,
+ // i.e., axis + rank(values)-th dimension.
+ int32_t inputRank = static_cast<int32_t>(ToTensorInfo(inputs[0]).GetNumDimensions());
+ const unsigned int concatDimInput = static_cast<unsigned int>((inputRank + options->axis) % inputRank);
+
+ // ArmNN supports concatenation along the channel dimension for data formats NHWC and NCHW.
+ if (concatDimInput == 0 || concatDimInput == 2)
+ {
+ throw ParseException(
+ boost::str(
+ boost::format(
+ "Dimension %1% for concatenation is not supported by Armnn. "
+ "Node %2%")
+ % concatDimInput
+ % CHECK_LOCATION().AsString()));
+ }
+
+ for (unsigned int viewIndex = 0; viewIndex < numConcatView; ++viewIndex)
+ {
+ TensorInfo inputTensorInfo = ToTensorInfo(inputs[viewIndex]);
+
+ // process the input tensor info
+ armnnUtils::ProcessConcatInputTensorInfo(inputTensorInfo, concatDescriptor,
+ concatDimInput, viewIndex, mergeDimSizes, mergeDim);
+ }
+
+ auto layerName = boost::str(boost::format("Concatenation:%1%:%2%") % subgraphIndex % operatorIndex);
+ IConnectableLayer* layer = m_Network->AddMergerLayer(concatDescriptor, layerName.c_str());
+
+ BOOST_ASSERT(layer != nullptr);
+
+ armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
+ auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
+ if (concatDimInput == 3)
+ {
+ // Adding Fused Activation Layer after this moment....
+ for (unsigned int viewIndex = 0; viewIndex < numConcatView; ++viewIndex)
+ {
+ // add permute layers to swizzle the inputs
+ armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[viewIndex]);
+ IConnectableLayer* const swizzleLayer = SwizzleIn(*m_Network, layer, viewIndex, inputTensorInfo);
+
+ BOOST_ASSERT(swizzleLayer != nullptr);
+
+ // register the input connection slots for the layer
+ // only the tensors for the inputs are relevant, exclude the const tensors
+ RegisterInputSlots(subgraphIndex, operatorIndex, swizzleLayer, {inputTensorIndexes[viewIndex]});
+ }
+
+ // add permute layer to deswizzle the output
+ IConnectableLayer* const deswizzleLayer = DeswizzleOut(*m_Network, layer, 0, outputTensorInfo);
+
+ // add fused activation layer after the trailing swizzle layer
+ layer = AddFusedActivationLayer(deswizzleLayer, 0, options->fused_activation_function);
+ }
+ else
+ {
+ // set the layer output tensor info
+ layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+ // register the input connection slots for the layer, connections are made after all layers have been created
+ // only the tensors for the inputs are relevant, exclude the const tensors
+ RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes});
+ }
+
+ // register the output connection slots for the layer, connections are made after all layers have been created
+ auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
+ RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
+}
+
armnn::IConnectableLayer* TfLiteParser::AddFusedActivationLayer(armnn::IConnectableLayer* prevLayer,
unsigned int outputSlot,
tflite::ActivationFunctionType activationType)