diff options
Diffstat (limited to 'src/armnn')
-rw-r--r-- | src/armnn/Network.cpp | 12 | ||||
-rw-r--r-- | src/armnn/Network.hpp | 1 | ||||
-rw-r--r-- | src/armnn/test/TensorHandleStrategyTest.cpp | 2 |
3 files changed, 10 insertions, 5 deletions
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp index 94a9961a81..dec9468d7b 100644 --- a/src/armnn/Network.cpp +++ b/src/armnn/Network.cpp @@ -861,7 +861,8 @@ EdgeStrategy CalculateEdgeStrategy(BackendsMap& backends, ITensorHandleFactory::FactoryId srcFactoryId, const Layer& layer, const Layer& connectedLayer, - TensorHandleFactoryRegistry& registry) + TensorHandleFactoryRegistry& registry, + bool importEnabled) { auto toBackend = backends.find(connectedLayer.GetBackendId()); ARMNN_ASSERT_MSG(toBackend != backends.end(), "Backend id not found for the connected layer"); @@ -899,7 +900,7 @@ EdgeStrategy CalculateEdgeStrategy(BackendsMap& backends, // Search for export/import options ITensorHandleFactory* srcFactory = registry.GetFactory(srcFactoryId); - if (srcFactory->GetExportFlags() != 0) + if (srcFactory->GetExportFlags() != 0 && importEnabled) { for (auto&& pref : dstPrefs) { @@ -945,11 +946,12 @@ EdgeStrategy CalculateEdgeStrategy(BackendsMap& backends, OptimizationResult SelectTensorHandleStrategy(Graph& optGraph, BackendsMap& backends, TensorHandleFactoryRegistry& registry, + bool importEnabled, Optional<std::vector<std::string>&> errMessages) { OptimizationResult result; - optGraph.ForEachLayer([&backends, ®istry, &result, &errMessages](Layer* layer) + optGraph.ForEachLayer([&backends, ®istry, &result, &errMessages, importEnabled](Layer* layer) { ARMNN_ASSERT(layer); @@ -985,7 +987,8 @@ OptimizationResult SelectTensorHandleStrategy(Graph& optGraph, { const Layer& connectedLayer = connection->GetOwningLayer(); - EdgeStrategy strategy = CalculateEdgeStrategy(backends, slotOption, *layer, connectedLayer, registry); + EdgeStrategy strategy = CalculateEdgeStrategy(backends, slotOption, *layer, connectedLayer, + registry, importEnabled); if (strategy == EdgeStrategy::Undefined) { @@ -1122,6 +1125,7 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork, OptimizationResult strategyResult = SelectTensorHandleStrategy(optGraph, backends, tensorHandleFactoryRegistry, + options.m_ImportEnabled, messages); if (strategyResult.m_Error) { diff --git a/src/armnn/Network.hpp b/src/armnn/Network.hpp index 77d6b04919..7136ee4d32 100644 --- a/src/armnn/Network.hpp +++ b/src/armnn/Network.hpp @@ -323,6 +323,7 @@ BackendsMap CreateSupportedBackends(TensorHandleFactoryRegistry& handleFactoryRe OptimizationResult SelectTensorHandleStrategy(Graph& optGraph, BackendsMap& backends, TensorHandleFactoryRegistry& registry, + bool importEnabled, Optional<std::vector<std::string>&> errMessages); OptimizationResult AssignBackends(OptimizedNetwork* optNetObjPtr, diff --git a/src/armnn/test/TensorHandleStrategyTest.cpp b/src/armnn/test/TensorHandleStrategyTest.cpp index 976e58eb50..c7aa30f701 100644 --- a/src/armnn/test/TensorHandleStrategyTest.cpp +++ b/src/armnn/test/TensorHandleStrategyTest.cpp @@ -339,7 +339,7 @@ BOOST_AUTO_TEST_CASE(TensorHandleSelectionStrategy) graph.TopologicalSort(); std::vector<std::string> errors; - auto result = SelectTensorHandleStrategy(graph, backends, registry, errors); + auto result = SelectTensorHandleStrategy(graph, backends, registry, true, errors); BOOST_TEST(result.m_Error == false); BOOST_TEST(result.m_Warning == false); |