diff options
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 89 |
1 files changed, 89 insertions, 0 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index b9a3522736..c00c2188a9 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -464,6 +464,7 @@ TfLiteParser::TfLiteParser() m_ParserFunctions[tflite::BuiltinOperator_MUL] = &TfLiteParser::ParseMul; m_ParserFunctions[tflite::BuiltinOperator_MEAN] = &TfLiteParser::ParseMean; m_ParserFunctions[tflite::BuiltinOperator_PAD] = &TfLiteParser::ParsePad; + m_ParserFunctions[tflite::BuiltinOperator_SPLIT] = &TfLiteParser::ParseSplit; } void TfLiteParser::ResetParser() @@ -1851,6 +1852,94 @@ void TfLiteParser::ParseDetectionPostProcess(size_t subgraphIndex, size_t operat outputTensorIndexes[3]}); } +void TfLiteParser::ParseSplit(size_t subgraphIndex, size_t operatorIndex) +{ + CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); + + const auto & operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex]; + const auto * options = operatorPtr->builtin_options.AsSplitOptions(); + + const unsigned int numSplits = CHECKED_NON_NEGATIVE(options->num_splits); + + auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex); + CHECK_VALID_SIZE(inputs.size(), 2); + auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex); + CHECK_VALID_SIZE(outputs.size(), numSplits); + + armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]); + armnn::TensorInfo axisTensorInfo = ToTensorInfo(inputs[1]); + + // This splitDim indicates the data format: 3 is the NHWC, 1 is the NCHW. + const unsigned int splitDim = static_cast<unsigned int>(axisTensorInfo.GetShape()[0]); + + // Armnn supports split along the channel dimension for data formats NHWC and NCHW. + if (splitDim == 0 || splitDim == 2) + { + throw ParseException( + boost::str( + boost::format( + "Dimension %1% for split is not supported by Armnn. %2%") + % splitDim + % CHECK_LOCATION().AsString())); + } + + auto inputDimSize = inputTensorInfo.GetNumDimensions(); + if (inputDimSize != MaxNumOfTensorDimensions) + { + throw ParseException( + boost::str( + boost::format( + "The number of dimensions: %1% for input tensors of the " + "split op should be %2% %3%") + % inputTensorInfo.GetNumDimensions() + % MaxNumOfTensorDimensions + % CHECK_LOCATION().AsString())); + } + + std::vector<unsigned int> splitterDimSizes(inputDimSize); + + // Add current input shape to splitterDimSizes + for (unsigned int i = 0; i < inputDimSize; ++i) + { + splitterDimSizes[i] = inputTensorInfo.GetShape()[i]; + } + + if (splitterDimSizes[splitDim] % numSplits != 0) + { + throw ParseException("Number of splits must evenly divide the dimension"); + } + splitterDimSizes[splitDim] /= numSplits; + + SplitterDescriptor splitDesc(numSplits); + for (unsigned int j = 0; j < numSplits; ++j) + { + // Set the size of the views. + for (unsigned int dimIdx = 0; dimIdx < splitterDimSizes.size(); ++dimIdx) + { + splitDesc.SetViewSize(j, dimIdx, splitterDimSizes[dimIdx]); + } + splitDesc.SetViewOriginCoord(j, splitDim, splitterDimSizes[splitDim] * j); + } + + auto layerName = boost::str(boost::format("Split:%1%:%2%") % subgraphIndex % operatorIndex); + IConnectableLayer* layer = m_Network->AddSplitterLayer(splitDesc, layerName.c_str()); + + auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex)); + RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]}); + + TensorShape outShape = TensorShape(static_cast<unsigned int>(splitterDimSizes.size()), + splitterDimSizes.data()); + + for (unsigned int k = 0; k < layer->GetNumOutputSlots(); ++k) + { + layer->GetOutputSlot(k).SetTensorInfo(armnn::TensorInfo(outShape, + inputTensorInfo.GetDataType())); + } + + auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex)); + RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes); +} + armnn::IConnectableLayer* TfLiteParser::AddFusedActivationLayer(armnn::IConnectableLayer* prevLayer, unsigned int outputSlot, tflite::ActivationFunctionType activationType) |