diff options
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 153 |
1 files changed, 103 insertions, 50 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index e04f9adc9f..937131ccd7 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -427,42 +427,46 @@ bool CheckShape(const armnn::TensorShape& actual, const std::vector<int32_t>& ex } // <anonymous> -TfLiteParser::TfLiteParser() -: m_Network(nullptr, nullptr) +TfLiteParser::TfLiteParser(const Optional<ITfLiteParser::TfLiteParserOptions>& options) +: m_Options(options) +, m_Network(nullptr, nullptr) , m_ParserFunctions(tflite::BuiltinOperator_MAX+1, &TfLiteParser::ParseUnsupportedOperator) { // register supported operators - m_ParserFunctions[tflite::BuiltinOperator_AVERAGE_POOL_2D] = &TfLiteParser::ParseAveragePool2D; - m_ParserFunctions[tflite::BuiltinOperator_BATCH_TO_SPACE_ND] = &TfLiteParser::ParseBatchToSpaceND; - m_ParserFunctions[tflite::BuiltinOperator_CONCATENATION] = &TfLiteParser::ParseConcatenation; - m_ParserFunctions[tflite::BuiltinOperator_CONV_2D] = &TfLiteParser::ParseConv2D; - m_ParserFunctions[tflite::BuiltinOperator_DEPTHWISE_CONV_2D] = &TfLiteParser::ParseDepthwiseConv2D; - m_ParserFunctions[tflite::BuiltinOperator_CUSTOM] = &TfLiteParser::ParseDetectionPostProcess; - m_ParserFunctions[tflite::BuiltinOperator_FULLY_CONNECTED] = &TfLiteParser::ParseFullyConnected; - m_ParserFunctions[tflite::BuiltinOperator_LOGISTIC] = &TfLiteParser::ParseLogistic; - m_ParserFunctions[tflite::BuiltinOperator_L2_NORMALIZATION] = &TfLiteParser::ParseL2Normalization; - m_ParserFunctions[tflite::BuiltinOperator_MAX_POOL_2D] = &TfLiteParser::ParseMaxPool2D; - m_ParserFunctions[tflite::BuiltinOperator_MAXIMUM] = &TfLiteParser::ParseMaximum; - m_ParserFunctions[tflite::BuiltinOperator_MINIMUM] = &TfLiteParser::ParseMinimum; - m_ParserFunctions[tflite::BuiltinOperator_RELU] = &TfLiteParser::ParseRelu; - m_ParserFunctions[tflite::BuiltinOperator_RELU6] = &TfLiteParser::ParseRelu6; - m_ParserFunctions[tflite::BuiltinOperator_RESHAPE] = &TfLiteParser::ParseReshape; - m_ParserFunctions[tflite::BuiltinOperator_RESIZE_BILINEAR] = &TfLiteParser::ParseResizeBilinear; - m_ParserFunctions[tflite::BuiltinOperator_SOFTMAX] = &TfLiteParser::ParseSoftmax; - m_ParserFunctions[tflite::BuiltinOperator_SPACE_TO_BATCH_ND] = &TfLiteParser::ParseSpaceToBatchND; - m_ParserFunctions[tflite::BuiltinOperator_SQUEEZE] = &TfLiteParser::ParseSqueeze; - m_ParserFunctions[tflite::BuiltinOperator_STRIDED_SLICE] = &TfLiteParser::ParseStridedSlice; - m_ParserFunctions[tflite::BuiltinOperator_SUB] = &TfLiteParser::ParseSub; - m_ParserFunctions[tflite::BuiltinOperator_ADD] = &TfLiteParser::ParseAdd; - m_ParserFunctions[tflite::BuiltinOperator_MUL] = &TfLiteParser::ParseMul; - m_ParserFunctions[tflite::BuiltinOperator_MEAN] = &TfLiteParser::ParseMean; - m_ParserFunctions[tflite::BuiltinOperator_PACK] = &TfLiteParser::ParsePack; - m_ParserFunctions[tflite::BuiltinOperator_PAD] = &TfLiteParser::ParsePad; - m_ParserFunctions[tflite::BuiltinOperator_SPLIT] = &TfLiteParser::ParseSplit; - m_ParserFunctions[tflite::BuiltinOperator_TANH] = &TfLiteParser::ParseTanH; - m_ParserFunctions[tflite::BuiltinOperator_TRANSPOSE] = &TfLiteParser::ParseTranspose; - m_ParserFunctions[tflite::BuiltinOperator_TRANSPOSE_CONV] = &TfLiteParser::ParseTransposeConv; - m_ParserFunctions[tflite::BuiltinOperator_UNPACK] = &TfLiteParser::ParseUnpack; + m_ParserFunctions[tflite::BuiltinOperator_AVERAGE_POOL_2D] = &TfLiteParser::ParseAveragePool2D; + m_ParserFunctions[tflite::BuiltinOperator_BATCH_TO_SPACE_ND] = &TfLiteParser::ParseBatchToSpaceND; + m_ParserFunctions[tflite::BuiltinOperator_CONCATENATION] = &TfLiteParser::ParseConcatenation; + m_ParserFunctions[tflite::BuiltinOperator_CONV_2D] = &TfLiteParser::ParseConv2D; + m_ParserFunctions[tflite::BuiltinOperator_DEPTHWISE_CONV_2D] = &TfLiteParser::ParseDepthwiseConv2D; + m_ParserFunctions[tflite::BuiltinOperator_CUSTOM] = &TfLiteParser::ParseCustomOperator; + m_ParserFunctions[tflite::BuiltinOperator_FULLY_CONNECTED] = &TfLiteParser::ParseFullyConnected; + m_ParserFunctions[tflite::BuiltinOperator_LOGISTIC] = &TfLiteParser::ParseLogistic; + m_ParserFunctions[tflite::BuiltinOperator_L2_NORMALIZATION] = &TfLiteParser::ParseL2Normalization; + m_ParserFunctions[tflite::BuiltinOperator_MAX_POOL_2D] = &TfLiteParser::ParseMaxPool2D; + m_ParserFunctions[tflite::BuiltinOperator_MAXIMUM] = &TfLiteParser::ParseMaximum; + m_ParserFunctions[tflite::BuiltinOperator_MINIMUM] = &TfLiteParser::ParseMinimum; + m_ParserFunctions[tflite::BuiltinOperator_RELU] = &TfLiteParser::ParseRelu; + m_ParserFunctions[tflite::BuiltinOperator_RELU6] = &TfLiteParser::ParseRelu6; + m_ParserFunctions[tflite::BuiltinOperator_RESHAPE] = &TfLiteParser::ParseReshape; + m_ParserFunctions[tflite::BuiltinOperator_RESIZE_BILINEAR] = &TfLiteParser::ParseResizeBilinear; + m_ParserFunctions[tflite::BuiltinOperator_SOFTMAX] = &TfLiteParser::ParseSoftmax; + m_ParserFunctions[tflite::BuiltinOperator_SPACE_TO_BATCH_ND] = &TfLiteParser::ParseSpaceToBatchND; + m_ParserFunctions[tflite::BuiltinOperator_SQUEEZE] = &TfLiteParser::ParseSqueeze; + m_ParserFunctions[tflite::BuiltinOperator_STRIDED_SLICE] = &TfLiteParser::ParseStridedSlice; + m_ParserFunctions[tflite::BuiltinOperator_SUB] = &TfLiteParser::ParseSub; + m_ParserFunctions[tflite::BuiltinOperator_ADD] = &TfLiteParser::ParseAdd; + m_ParserFunctions[tflite::BuiltinOperator_MUL] = &TfLiteParser::ParseMul; + m_ParserFunctions[tflite::BuiltinOperator_MEAN] = &TfLiteParser::ParseMean; + m_ParserFunctions[tflite::BuiltinOperator_PACK] = &TfLiteParser::ParsePack; + m_ParserFunctions[tflite::BuiltinOperator_PAD] = &TfLiteParser::ParsePad; + m_ParserFunctions[tflite::BuiltinOperator_SPLIT] = &TfLiteParser::ParseSplit; + m_ParserFunctions[tflite::BuiltinOperator_TANH] = &TfLiteParser::ParseTanH; + m_ParserFunctions[tflite::BuiltinOperator_TRANSPOSE] = &TfLiteParser::ParseTranspose; + m_ParserFunctions[tflite::BuiltinOperator_TRANSPOSE_CONV] = &TfLiteParser::ParseTransposeConv; + m_ParserFunctions[tflite::BuiltinOperator_UNPACK] = &TfLiteParser::ParseUnpack; + + // register supported custom operators + m_CustomParserFunctions["TFLite_Detection_PostProcess"] = &TfLiteParser::ParseDetectionPostProcess; } void TfLiteParser::ResetParser() @@ -675,25 +679,74 @@ void TfLiteParser::RegisterConsumerOfTensor(size_t subgraphIndex, tensorSlots.inputSlots.push_back(slot); } +void TfLiteParser::ParseCustomOperator(size_t subgraphIndex, size_t operatorIndex) +{ + CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); + + // NOTE: By default we presume the custom operator is not supported + auto customParserFunction = &TfLiteParser::ParseUnsupportedOperator; + + // Identify custom code defined for custom operator + const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex]; + const auto& customCode = m_Model->operator_codes[operatorPtr->opcode_index]->custom_code; + + // Find parser function that correspondes to custom code (if any) + auto iterator = m_CustomParserFunctions.find(customCode); + if (iterator != m_CustomParserFunctions.end()) + { + customParserFunction = iterator->second; + } + + // Run parser function + (this->*customParserFunction)(subgraphIndex, operatorIndex); +} + void TfLiteParser::ParseUnsupportedOperator(size_t subgraphIndex, size_t operatorIndex) { CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); + const auto & operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex]; - // + auto opcodeIndex = operatorPtr->opcode_index; - auto opcode = m_Model->operator_codes[opcodeIndex]->builtin_code; + auto opcode = m_Model->operator_codes[opcodeIndex]->builtin_code; - throw ParseException( - boost::str( - boost::format("Operator not supported. " - "subgraph:%1% operator:%2% " - "opcode_index:%3% opcode:%4% / %5% %6%") % - subgraphIndex % - operatorIndex % - opcodeIndex % - opcode % - tflite::EnumNameBuiltinOperator(opcode) % - CHECK_LOCATION().AsString())); + if (!m_Options || !m_Options.value().m_StandInLayerForUnsupported) + { + // Do not add StandInLayer, throw ParseException instead + throw ParseException( + boost::str( + boost::format("Operator not supported. " + "subgraph:%1% operator:%2% " + "opcode_index:%3% opcode:%4% / %5% %6%") % + subgraphIndex % + operatorIndex % + opcodeIndex % + opcode % + tflite::EnumNameBuiltinOperator(opcode) % + CHECK_LOCATION().AsString())); + } + + auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex); + auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex); + + const unsigned int numInputs = boost::numeric_cast<unsigned int>(inputs.size()); + const unsigned int numOutputs = boost::numeric_cast<unsigned int>(outputs.size()); + + StandInDescriptor descriptor(numInputs, numOutputs); + auto layerName = boost::str(boost::format("StandIn:%1%:%2%:%3%") % subgraphIndex % operatorIndex % opcode); + + // Add a non-executable StandInLayer as a placeholder for any unsupported operator + IConnectableLayer* layer = m_Network->AddStandInLayer(descriptor, layerName.c_str()); + for (unsigned int i = 0u; i < numOutputs; ++i) + { + layer->GetOutputSlot(i).SetTensorInfo(ToTensorInfo(outputs[i])); + } + + auto inputTensorIds = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex)); + auto outputTensorIds = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex)); + + RegisterInputSlots(subgraphIndex, operatorIndex, layer, inputTensorIds); + RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIds); } void TfLiteParser::ParseConv2D(size_t subgraphIndex, size_t operatorIndex) @@ -2761,14 +2814,14 @@ std::vector<std::string> TfLiteParser::GetSubgraphOutputTensorNames(size_t subgr return result; } -ITfLiteParser* ITfLiteParser::CreateRaw() +ITfLiteParser* ITfLiteParser::CreateRaw(const Optional<ITfLiteParser::TfLiteParserOptions>& options) { - return new TfLiteParser(); + return new TfLiteParser(options); } -ITfLiteParserPtr ITfLiteParser::Create() +ITfLiteParserPtr ITfLiteParser::Create(const Optional<ITfLiteParser::TfLiteParserOptions>& options) { - return ITfLiteParserPtr(CreateRaw(), &ITfLiteParser::Destroy); + return ITfLiteParserPtr(CreateRaw(options), &ITfLiteParser::Destroy); } void ITfLiteParser::Destroy(ITfLiteParser* parser) |