From 5812504a07caf6d3b8d7b5179b34e4f8fb31cdb0 Mon Sep 17 00:00:00 2001 From: Colm Donelan Date: Wed, 8 Nov 2023 21:11:05 +0000 Subject: IVGCVSW-7861 Updating tflite parser to ignore VALIDATION: subgraphs. Signed-off-by: Colm Donelan Change-Id: If156b975d37868db77d7c6bb75d884652278e02a --- src/armnnTfLiteParser/TfLiteParser.cpp | 75 +++++++++++++++++++--------------- 1 file changed, 41 insertions(+), 34 deletions(-) diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 67591b9339..049d6049a7 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -1005,57 +1005,64 @@ INetworkPtr TfLiteParserImpl::CreateNetworkFromModel() throw ParseException(fmt::format("Tflite Model pointer is null {}", CHECK_LOCATION().AsString())); } - if (m_Model->subgraphs.size() != 1) + // Identify which subgraph we are going to parse. We only support one subgraph but there may be validation + // subgraphs still stored in the model. We'll ignore these. In the tflite code base they are identified by + // their name beginning with "VALIDATION:". + size_t subgraphIndex = 0; + uint8_t usableSubgraphs = 0; + for (size_t i = 0; i < m_Model->subgraphs.size(); i++) + { + if (m_Model->subgraphs[i]->name.rfind("VALIDATION:", 0) != 0) + { + usableSubgraphs++; + subgraphIndex = i; + } + } + + if (usableSubgraphs > 1) { throw ParseException( - fmt::format("Current TfLite parser only supports 1 subgraph. Current one has: {} {}", - m_Model->subgraphs.size(), - CHECK_LOCATION().AsString())); + fmt::format("Current TfLite parser only supports 1 non validation subgraph. This model has: {} {}", + usableSubgraphs, CHECK_LOCATION().AsString())); } - size_t subgraphIndex = 0; size_t operatorIndex = 0; try { - for (SubgraphPtr const& subgraph : m_Model->subgraphs) - { - SetupInputLayerTensorInfos(subgraphIndex); - SetupConstantLayerTensorInfos(subgraphIndex); + const SubgraphPtr& subgraph = m_Model->subgraphs[subgraphIndex]; + SetupInputLayerTensorInfos(subgraphIndex); + SetupConstantLayerTensorInfos(subgraphIndex); - m_SubgraphConnections.emplace_back(subgraph->tensors.size()); - for (OperatorPtr const& op : subgraph->operators) - { - auto const& opCodePtr = m_Model->operator_codes[op->opcode_index]; + m_SubgraphConnections.emplace_back(subgraph->tensors.size()); + for (const OperatorPtr& op : subgraph->operators) + { + const auto& opCodePtr = m_Model->operator_codes[op->opcode_index]; // work around the introduction of the deprecated_builtin_code introduced in 2.4 in a backwards compatible manner #if defined(ARMNN_POST_TFLITE_2_3) - auto builtinCode = std::max(opCodePtr->builtin_code, - static_cast(opCodePtr->deprecated_builtin_code)); + auto builtinCode = std::max(opCodePtr->builtin_code, + static_cast(opCodePtr->deprecated_builtin_code)); #else - auto builtinCode = opCodePtr->builtin_code; + auto builtinCode = opCodePtr->builtin_code; #endif - if (builtinCode > tflite::BuiltinOperator_MAX) - { - throw ParseException(fmt::format("Operator code {} is out of range 0-{}. " - "subgraph:{} operator idx:{}. {}", - builtinCode, tflite::BuiltinOperator_MAX, subgraphIndex, - operatorIndex, CHECK_LOCATION().AsString())); - } - - // lookup and call the parser function - auto& parserFunction = m_ParserFunctions[builtinCode]; - (this->*parserFunction)(subgraphIndex, operatorIndex); - ++operatorIndex; + if (builtinCode > tflite::BuiltinOperator_MAX) + { + throw ParseException(fmt::format("Operator code {} is out of range 0-{}. " + "subgraph:{} operator idx:{}. {}", + builtinCode, tflite::BuiltinOperator_MAX, subgraphIndex, + operatorIndex, CHECK_LOCATION().AsString())); } - SetupInputLayers(subgraphIndex); - SetupOutputLayers(subgraphIndex); - SetupConstantLayers(subgraphIndex); - - ++subgraphIndex; - operatorIndex = 0; + // lookup and call the parser function + auto& parserFunction = m_ParserFunctions[builtinCode]; + (this->*parserFunction)(subgraphIndex, operatorIndex); + ++operatorIndex; } + + SetupInputLayers(subgraphIndex); + SetupOutputLayers(subgraphIndex); + SetupConstantLayers(subgraphIndex); } catch (const ParseException& e) { -- cgit v1.2.1