aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorColm Donelan <colm.donelan@arm.com>2023-11-08 21:11:05 +0000
committerColm Donelan <colm.donelan@arm.com>2023-11-09 15:59:58 +0000
commit5812504a07caf6d3b8d7b5179b34e4f8fb31cdb0 (patch)
treeb2aa714b4327a3fd9b6d87cb3578affff2e39fd4
parent2199017ce9645503207b1803c0a1689bf59427f6 (diff)
downloadarmnn-5812504a07caf6d3b8d7b5179b34e4f8fb31cdb0.tar.gz
IVGCVSW-7861 Updating tflite parser to ignore VALIDATION: subgraphs.
Signed-off-by: Colm Donelan <colm.donelan@arm.com> Change-Id: If156b975d37868db77d7c6bb75d884652278e02a
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp75
1 files changed, 41 insertions, 34 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 67591b9339..049d6049a7 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -1005,57 +1005,64 @@ INetworkPtr TfLiteParserImpl::CreateNetworkFromModel()
throw ParseException(fmt::format("Tflite Model pointer is null {}", CHECK_LOCATION().AsString()));
}
- if (m_Model->subgraphs.size() != 1)
+ // Identify which subgraph we are going to parse. We only support one subgraph but there may be validation
+ // subgraphs still stored in the model. We'll ignore these. In the tflite code base they are identified by
+ // their name beginning with "VALIDATION:".
+ size_t subgraphIndex = 0;
+ uint8_t usableSubgraphs = 0;
+ for (size_t i = 0; i < m_Model->subgraphs.size(); i++)
+ {
+ if (m_Model->subgraphs[i]->name.rfind("VALIDATION:", 0) != 0)
+ {
+ usableSubgraphs++;
+ subgraphIndex = i;
+ }
+ }
+
+ if (usableSubgraphs > 1)
{
throw ParseException(
- fmt::format("Current TfLite parser only supports 1 subgraph. Current one has: {} {}",
- m_Model->subgraphs.size(),
- CHECK_LOCATION().AsString()));
+ fmt::format("Current TfLite parser only supports 1 non validation subgraph. This model has: {} {}",
+ usableSubgraphs, CHECK_LOCATION().AsString()));
}
- size_t subgraphIndex = 0;
size_t operatorIndex = 0;
try
{
- for (SubgraphPtr const& subgraph : m_Model->subgraphs)
- {
- SetupInputLayerTensorInfos(subgraphIndex);
- SetupConstantLayerTensorInfos(subgraphIndex);
+ const SubgraphPtr& subgraph = m_Model->subgraphs[subgraphIndex];
+ SetupInputLayerTensorInfos(subgraphIndex);
+ SetupConstantLayerTensorInfos(subgraphIndex);
- m_SubgraphConnections.emplace_back(subgraph->tensors.size());
- for (OperatorPtr const& op : subgraph->operators)
- {
- auto const& opCodePtr = m_Model->operator_codes[op->opcode_index];
+ m_SubgraphConnections.emplace_back(subgraph->tensors.size());
+ for (const OperatorPtr& op : subgraph->operators)
+ {
+ const auto& opCodePtr = m_Model->operator_codes[op->opcode_index];
// work around the introduction of the deprecated_builtin_code introduced in 2.4 in a backwards compatible manner
#if defined(ARMNN_POST_TFLITE_2_3)
- auto builtinCode = std::max(opCodePtr->builtin_code,
- static_cast<tflite::BuiltinOperator>(opCodePtr->deprecated_builtin_code));
+ auto builtinCode = std::max(opCodePtr->builtin_code,
+ static_cast<tflite::BuiltinOperator>(opCodePtr->deprecated_builtin_code));
#else
- auto builtinCode = opCodePtr->builtin_code;
+ auto builtinCode = opCodePtr->builtin_code;
#endif
- if (builtinCode > tflite::BuiltinOperator_MAX)
- {
- throw ParseException(fmt::format("Operator code {} is out of range 0-{}. "
- "subgraph:{} operator idx:{}. {}",
- builtinCode, tflite::BuiltinOperator_MAX, subgraphIndex,
- operatorIndex, CHECK_LOCATION().AsString()));
- }
-
- // lookup and call the parser function
- auto& parserFunction = m_ParserFunctions[builtinCode];
- (this->*parserFunction)(subgraphIndex, operatorIndex);
- ++operatorIndex;
+ if (builtinCode > tflite::BuiltinOperator_MAX)
+ {
+ throw ParseException(fmt::format("Operator code {} is out of range 0-{}. "
+ "subgraph:{} operator idx:{}. {}",
+ builtinCode, tflite::BuiltinOperator_MAX, subgraphIndex,
+ operatorIndex, CHECK_LOCATION().AsString()));
}
- SetupInputLayers(subgraphIndex);
- SetupOutputLayers(subgraphIndex);
- SetupConstantLayers(subgraphIndex);
-
- ++subgraphIndex;
- operatorIndex = 0;
+ // lookup and call the parser function
+ auto& parserFunction = m_ParserFunctions[builtinCode];
+ (this->*parserFunction)(subgraphIndex, operatorIndex);
+ ++operatorIndex;
}
+
+ SetupInputLayers(subgraphIndex);
+ SetupOutputLayers(subgraphIndex);
+ SetupConstantLayers(subgraphIndex);
}
catch (const ParseException& e)
{