diff options
author | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2020-08-19 14:39:07 +0100 |
---|---|---|
committer | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2020-08-19 14:43:09 +0100 |
commit | a2493a0483f19fe9654be63a15badfb0834aaff6 (patch) | |
tree | a526ee366356b173947d13162ba4b34a965f23ed /src/armnn | |
parent | 37c8197c9153924ebe934d5b521c0985eab9e477 (diff) | |
download | armnn-a2493a0483f19fe9654be63a15badfb0834aaff6.tar.gz |
IVGCVSW-5012 Add importEnabled option for OptimizerOptions
* Default importEnabled to false
* Improve error messages
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I17f78986aa1d23e48b0844297a52029b1a9bbe3e
Diffstat (limited to 'src/armnn')
-rw-r--r-- | src/armnn/Network.cpp | 12 | ||||
-rw-r--r-- | src/armnn/Network.hpp | 1 | ||||
-rw-r--r-- | src/armnn/test/TensorHandleStrategyTest.cpp | 2 |
3 files changed, 10 insertions, 5 deletions
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp index 94a9961a81..dec9468d7b 100644 --- a/src/armnn/Network.cpp +++ b/src/armnn/Network.cpp @@ -861,7 +861,8 @@ EdgeStrategy CalculateEdgeStrategy(BackendsMap& backends, ITensorHandleFactory::FactoryId srcFactoryId, const Layer& layer, const Layer& connectedLayer, - TensorHandleFactoryRegistry& registry) + TensorHandleFactoryRegistry& registry, + bool importEnabled) { auto toBackend = backends.find(connectedLayer.GetBackendId()); ARMNN_ASSERT_MSG(toBackend != backends.end(), "Backend id not found for the connected layer"); @@ -899,7 +900,7 @@ EdgeStrategy CalculateEdgeStrategy(BackendsMap& backends, // Search for export/import options ITensorHandleFactory* srcFactory = registry.GetFactory(srcFactoryId); - if (srcFactory->GetExportFlags() != 0) + if (srcFactory->GetExportFlags() != 0 && importEnabled) { for (auto&& pref : dstPrefs) { @@ -945,11 +946,12 @@ EdgeStrategy CalculateEdgeStrategy(BackendsMap& backends, OptimizationResult SelectTensorHandleStrategy(Graph& optGraph, BackendsMap& backends, TensorHandleFactoryRegistry& registry, + bool importEnabled, Optional<std::vector<std::string>&> errMessages) { OptimizationResult result; - optGraph.ForEachLayer([&backends, ®istry, &result, &errMessages](Layer* layer) + optGraph.ForEachLayer([&backends, ®istry, &result, &errMessages, importEnabled](Layer* layer) { ARMNN_ASSERT(layer); @@ -985,7 +987,8 @@ OptimizationResult SelectTensorHandleStrategy(Graph& optGraph, { const Layer& connectedLayer = connection->GetOwningLayer(); - EdgeStrategy strategy = CalculateEdgeStrategy(backends, slotOption, *layer, connectedLayer, registry); + EdgeStrategy strategy = CalculateEdgeStrategy(backends, slotOption, *layer, connectedLayer, + registry, importEnabled); if (strategy == EdgeStrategy::Undefined) { @@ -1122,6 +1125,7 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork, OptimizationResult strategyResult = SelectTensorHandleStrategy(optGraph, backends, tensorHandleFactoryRegistry, + options.m_ImportEnabled, messages); if (strategyResult.m_Error) { diff --git a/src/armnn/Network.hpp b/src/armnn/Network.hpp index 77d6b04919..7136ee4d32 100644 --- a/src/armnn/Network.hpp +++ b/src/armnn/Network.hpp @@ -323,6 +323,7 @@ BackendsMap CreateSupportedBackends(TensorHandleFactoryRegistry& handleFactoryRe OptimizationResult SelectTensorHandleStrategy(Graph& optGraph, BackendsMap& backends, TensorHandleFactoryRegistry& registry, + bool importEnabled, Optional<std::vector<std::string>&> errMessages); OptimizationResult AssignBackends(OptimizedNetwork* optNetObjPtr, diff --git a/src/armnn/test/TensorHandleStrategyTest.cpp b/src/armnn/test/TensorHandleStrategyTest.cpp index 976e58eb50..c7aa30f701 100644 --- a/src/armnn/test/TensorHandleStrategyTest.cpp +++ b/src/armnn/test/TensorHandleStrategyTest.cpp @@ -339,7 +339,7 @@ BOOST_AUTO_TEST_CASE(TensorHandleSelectionStrategy) graph.TopologicalSort(); std::vector<std::string> errors; - auto result = SelectTensorHandleStrategy(graph, backends, registry, errors); + auto result = SelectTensorHandleStrategy(graph, backends, registry, true, errors); BOOST_TEST(result.m_Error == false); BOOST_TEST(result.m_Warning == false); |