From 479045bdcac4faddbf567aa0f73d2899881f341c Mon Sep 17 00:00:00 2001 From: Sadik Armagan Date: Mon, 1 Oct 2018 11:51:37 +0100 Subject: IVGCVSW-1787 Add Support for Concatenation on TfLite parser * Concatenation Parser function added to the TfLite Parser Change-Id: I42a42cd765ea09a30841c66b1942b9e09a876b10 --- src/armnnTfLiteParser/TfLiteParser.cpp | 98 +++++++++++++++++++++++++++++++++- 1 file changed, 96 insertions(+), 2 deletions(-) (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp') diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 13e4604490..66746e488b 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -10,6 +10,7 @@ #include // armnnUtils: +#include #include #include @@ -452,13 +453,14 @@ TfLiteParser::TfLiteParser() { // register supported operators m_ParserFunctions[tflite::BuiltinOperator_AVERAGE_POOL_2D] = &TfLiteParser::ParseAveragePool2D; + m_ParserFunctions[tflite::BuiltinOperator_CONCATENATION] = &TfLiteParser::ParseConcatenation; m_ParserFunctions[tflite::BuiltinOperator_CONV_2D] = &TfLiteParser::ParseConv2D; m_ParserFunctions[tflite::BuiltinOperator_DEPTHWISE_CONV_2D] = &TfLiteParser::ParseDepthwiseConv2D; - m_ParserFunctions[tflite::BuiltinOperator_SOFTMAX] = &TfLiteParser::ParseSoftmax; - m_ParserFunctions[tflite::BuiltinOperator_SQUEEZE] = &TfLiteParser::ParseSqueeze; m_ParserFunctions[tflite::BuiltinOperator_RELU] = &TfLiteParser::ParseRelu; m_ParserFunctions[tflite::BuiltinOperator_RELU6] = &TfLiteParser::ParseRelu6; m_ParserFunctions[tflite::BuiltinOperator_RESHAPE] = &TfLiteParser::ParseReshape; + m_ParserFunctions[tflite::BuiltinOperator_SOFTMAX] = &TfLiteParser::ParseSoftmax; + m_ParserFunctions[tflite::BuiltinOperator_SQUEEZE] = &TfLiteParser::ParseSqueeze; } void TfLiteParser::ResetParser() @@ -1097,6 +1099,98 @@ void TfLiteParser::ParseReshape(size_t subgraphIndex, size_t operatorIndex) RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]}); } +void TfLiteParser::ParseConcatenation(size_t subgraphIndex, size_t operatorIndex) +{ + CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); + + const auto & operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex]; + const auto * options = operatorPtr->builtin_options.AsConcatenationOptions(); + + CHECK_SUPPORTED_FUSED_ACTIVATION(options, subgraphIndex, operatorIndex); + + auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex); + auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex); + CHECK_VALID_SIZE(outputs.size(), 1); + + unsigned int numInputs = static_cast(inputs.size()); + unsigned int numConcatView = numInputs; + + OriginsDescriptor concatDescriptor(static_cast(numConcatView), MaxNumOfTensorDimensions); + std::vectormergeDimSizes(MaxNumOfTensorDimensions, 0u); + + unsigned int mergeDim = 0; + + // This concatDim indicates the data format: 3 is the NHWC, 1 is the NCHW. + // axis could also be negative numbers. Negative axis are interpreted as counting from the end of the rank, + // i.e., axis + rank(values)-th dimension. + int32_t inputRank = static_cast(ToTensorInfo(inputs[0]).GetNumDimensions()); + const unsigned int concatDimInput = static_cast((inputRank + options->axis) % inputRank); + + // ArmNN supports concatenation along the channel dimension for data formats NHWC and NCHW. + if (concatDimInput == 0 || concatDimInput == 2) + { + throw ParseException( + boost::str( + boost::format( + "Dimension %1% for concatenation is not supported by Armnn. " + "Node %2%") + % concatDimInput + % CHECK_LOCATION().AsString())); + } + + for (unsigned int viewIndex = 0; viewIndex < numConcatView; ++viewIndex) + { + TensorInfo inputTensorInfo = ToTensorInfo(inputs[viewIndex]); + + // process the input tensor info + armnnUtils::ProcessConcatInputTensorInfo(inputTensorInfo, concatDescriptor, + concatDimInput, viewIndex, mergeDimSizes, mergeDim); + } + + auto layerName = boost::str(boost::format("Concatenation:%1%:%2%") % subgraphIndex % operatorIndex); + IConnectableLayer* layer = m_Network->AddMergerLayer(concatDescriptor, layerName.c_str()); + + BOOST_ASSERT(layer != nullptr); + + armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]); + auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex)); + if (concatDimInput == 3) + { + // Adding Fused Activation Layer after this moment.... + for (unsigned int viewIndex = 0; viewIndex < numConcatView; ++viewIndex) + { + // add permute layers to swizzle the inputs + armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[viewIndex]); + IConnectableLayer* const swizzleLayer = SwizzleIn(*m_Network, layer, viewIndex, inputTensorInfo); + + BOOST_ASSERT(swizzleLayer != nullptr); + + // register the input connection slots for the layer + // only the tensors for the inputs are relevant, exclude the const tensors + RegisterInputSlots(subgraphIndex, operatorIndex, swizzleLayer, {inputTensorIndexes[viewIndex]}); + } + + // add permute layer to deswizzle the output + IConnectableLayer* const deswizzleLayer = DeswizzleOut(*m_Network, layer, 0, outputTensorInfo); + + // add fused activation layer after the trailing swizzle layer + layer = AddFusedActivationLayer(deswizzleLayer, 0, options->fused_activation_function); + } + else + { + // set the layer output tensor info + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + // register the input connection slots for the layer, connections are made after all layers have been created + // only the tensors for the inputs are relevant, exclude the const tensors + RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes}); + } + + // register the output connection slots for the layer, connections are made after all layers have been created + auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex)); + RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]}); +} + armnn::IConnectableLayer* TfLiteParser::AddFusedActivationLayer(armnn::IConnectableLayer* prevLayer, unsigned int outputSlot, tflite::ActivationFunctionType activationType) -- cgit v1.2.1