From a2493a0483f19fe9654be63a15badfb0834aaff6 Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Wed, 19 Aug 2020 14:39:07 +0100 Subject: IVGCVSW-5012 Add importEnabled option for OptimizerOptions * Default importEnabled to false * Improve error messages Signed-off-by: Narumol Prangnawarat Change-Id: I17f78986aa1d23e48b0844297a52029b1a9bbe3e --- src/armnn/Network.cpp | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) (limited to 'src/armnn/Network.cpp') diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp index 94a9961a81..dec9468d7b 100644 --- a/src/armnn/Network.cpp +++ b/src/armnn/Network.cpp @@ -861,7 +861,8 @@ EdgeStrategy CalculateEdgeStrategy(BackendsMap& backends, ITensorHandleFactory::FactoryId srcFactoryId, const Layer& layer, const Layer& connectedLayer, - TensorHandleFactoryRegistry& registry) + TensorHandleFactoryRegistry& registry, + bool importEnabled) { auto toBackend = backends.find(connectedLayer.GetBackendId()); ARMNN_ASSERT_MSG(toBackend != backends.end(), "Backend id not found for the connected layer"); @@ -899,7 +900,7 @@ EdgeStrategy CalculateEdgeStrategy(BackendsMap& backends, // Search for export/import options ITensorHandleFactory* srcFactory = registry.GetFactory(srcFactoryId); - if (srcFactory->GetExportFlags() != 0) + if (srcFactory->GetExportFlags() != 0 && importEnabled) { for (auto&& pref : dstPrefs) { @@ -945,11 +946,12 @@ EdgeStrategy CalculateEdgeStrategy(BackendsMap& backends, OptimizationResult SelectTensorHandleStrategy(Graph& optGraph, BackendsMap& backends, TensorHandleFactoryRegistry& registry, + bool importEnabled, Optional&> errMessages) { OptimizationResult result; - optGraph.ForEachLayer([&backends, ®istry, &result, &errMessages](Layer* layer) + optGraph.ForEachLayer([&backends, ®istry, &result, &errMessages, importEnabled](Layer* layer) { ARMNN_ASSERT(layer); @@ -985,7 +987,8 @@ OptimizationResult SelectTensorHandleStrategy(Graph& optGraph, { const Layer& connectedLayer = connection->GetOwningLayer(); - EdgeStrategy strategy = CalculateEdgeStrategy(backends, slotOption, *layer, connectedLayer, registry); + EdgeStrategy strategy = CalculateEdgeStrategy(backends, slotOption, *layer, connectedLayer, + registry, importEnabled); if (strategy == EdgeStrategy::Undefined) { @@ -1122,6 +1125,7 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork, OptimizationResult strategyResult = SelectTensorHandleStrategy(optGraph, backends, tensorHandleFactoryRegistry, + options.m_ImportEnabled, messages); if (strategyResult.m_Error) { -- cgit v1.2.1