diff options
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 84 |
1 files changed, 84 insertions, 0 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index f689deedf6..86688add9d 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -468,6 +468,7 @@ TfLiteParser::TfLiteParser() m_ParserFunctions[tflite::BuiltinOperator_PAD] = &TfLiteParser::ParsePad; m_ParserFunctions[tflite::BuiltinOperator_SPLIT] = &TfLiteParser::ParseSplit; m_ParserFunctions[tflite::BuiltinOperator_TANH] = &TfLiteParser::ParseTanH; + m_ParserFunctions[tflite::BuiltinOperator_UNPACK] = &TfLiteParser::ParseUnpack; } void TfLiteParser::ResetParser() @@ -1867,6 +1868,83 @@ void TfLiteParser::ParseDetectionPostProcess(size_t subgraphIndex, size_t operat outputTensorIndexes[3]}); } +void TfLiteParser::ParseUnpack(size_t subgraphIndex, size_t operatorIndex) +{ + CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); + + const auto & operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex]; + const auto * options = operatorPtr->builtin_options.AsUnpackOptions(); + + // This unpackAxis indicates the axis to unpack + const unsigned int unpackAxis = CHECKED_NON_NEGATIVE(options->axis); + + auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex); + CHECK_VALID_SIZE(inputs.size(), 1); + + armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]); + unsigned int unpackNum = CHECKED_NON_NEGATIVE(options->num); + // If num is not defined, automatically infer from the length of the dimension axis. + if(unpackNum == 0) + { + unpackNum = inputTensorInfo.GetShape()[unpackAxis]; + } + + // If unpack number cannot be inferred and is still zero, throw ParseException. + if(unpackNum == 0) + { + throw ParseException("Number to unpack must greater than zero."); + } + + auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex); + CHECK_VALID_SIZE(outputs.size(), unpackNum); + + auto inputDimSize = inputTensorInfo.GetNumDimensions(); + std::vector<unsigned int> unpackDimSizes(inputDimSize); + + // Add current input shape to unpackDimSizes + for (unsigned int i = 0; i < inputDimSize; ++i) + { + unpackDimSizes[i] = inputTensorInfo.GetShape()[i]; + } + + if (unpackDimSizes[unpackAxis] != unpackNum) + { + throw ParseException("Number to unpack must be the same as length of the dimension to " + "unpack along."); + } + + unpackDimSizes[unpackAxis] /= unpackNum; + + SplitterDescriptor splitDesc(unpackNum, static_cast<unsigned int>(unpackDimSizes.size())); + for (unsigned int j = 0; j < unpackNum; ++j) + { + // Set the size of the views. + for (unsigned int dimIdx = 0; dimIdx < unpackDimSizes.size(); ++dimIdx) + { + splitDesc.SetViewSize(j, dimIdx, unpackDimSizes[dimIdx]); + } + splitDesc.SetViewOriginCoord(j, unpackAxis, unpackDimSizes[unpackAxis] * j); + } + + auto layerName = boost::str(boost::format("Unpack:%1%:%2%") % subgraphIndex % operatorIndex); + IConnectableLayer* layer = m_Network->AddSplitterLayer(splitDesc, layerName.c_str()); + + auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex)); + RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]}); + + TensorShape outShape = TensorShape(static_cast<unsigned int>(unpackDimSizes.size()), + unpackDimSizes.data()); + + for (unsigned int k = 0; k < layer->GetNumOutputSlots(); ++k) + { + layer->GetOutputSlot(k).SetTensorInfo(armnn::TensorInfo(outShape, + inputTensorInfo.GetDataType())); + } + + auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex)); + RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes); +} + void TfLiteParser::ParseSplit(size_t subgraphIndex, size_t operatorIndex) { CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); @@ -1876,6 +1954,12 @@ void TfLiteParser::ParseSplit(size_t subgraphIndex, size_t operatorIndex) const unsigned int numSplits = CHECKED_NON_NEGATIVE(options->num_splits); + // If number of splits cannot be inferred and is zero, throw ParseException. + if(numSplits == 0) + { + throw ParseException("Number to splits must greater than zero."); + } + auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex); CHECK_VALID_SIZE(inputs.size(), 2); auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex); |