aboutsummaryrefslogtreecommitdiff
path: root/src/armnn
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn')
-rw-r--r--src/armnn/Network.cpp12
-rw-r--r--src/armnn/Network.hpp1
-rw-r--r--src/armnn/test/TensorHandleStrategyTest.cpp2
3 files changed, 10 insertions, 5 deletions
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp
index 94a9961a81..dec9468d7b 100644
--- a/src/armnn/Network.cpp
+++ b/src/armnn/Network.cpp
@@ -861,7 +861,8 @@ EdgeStrategy CalculateEdgeStrategy(BackendsMap& backends,
ITensorHandleFactory::FactoryId srcFactoryId,
const Layer& layer,
const Layer& connectedLayer,
- TensorHandleFactoryRegistry& registry)
+ TensorHandleFactoryRegistry& registry,
+ bool importEnabled)
{
auto toBackend = backends.find(connectedLayer.GetBackendId());
ARMNN_ASSERT_MSG(toBackend != backends.end(), "Backend id not found for the connected layer");
@@ -899,7 +900,7 @@ EdgeStrategy CalculateEdgeStrategy(BackendsMap& backends,
// Search for export/import options
ITensorHandleFactory* srcFactory = registry.GetFactory(srcFactoryId);
- if (srcFactory->GetExportFlags() != 0)
+ if (srcFactory->GetExportFlags() != 0 && importEnabled)
{
for (auto&& pref : dstPrefs)
{
@@ -945,11 +946,12 @@ EdgeStrategy CalculateEdgeStrategy(BackendsMap& backends,
OptimizationResult SelectTensorHandleStrategy(Graph& optGraph,
BackendsMap& backends,
TensorHandleFactoryRegistry& registry,
+ bool importEnabled,
Optional<std::vector<std::string>&> errMessages)
{
OptimizationResult result;
- optGraph.ForEachLayer([&backends, &registry, &result, &errMessages](Layer* layer)
+ optGraph.ForEachLayer([&backends, &registry, &result, &errMessages, importEnabled](Layer* layer)
{
ARMNN_ASSERT(layer);
@@ -985,7 +987,8 @@ OptimizationResult SelectTensorHandleStrategy(Graph& optGraph,
{
const Layer& connectedLayer = connection->GetOwningLayer();
- EdgeStrategy strategy = CalculateEdgeStrategy(backends, slotOption, *layer, connectedLayer, registry);
+ EdgeStrategy strategy = CalculateEdgeStrategy(backends, slotOption, *layer, connectedLayer,
+ registry, importEnabled);
if (strategy == EdgeStrategy::Undefined)
{
@@ -1122,6 +1125,7 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork,
OptimizationResult strategyResult = SelectTensorHandleStrategy(optGraph,
backends,
tensorHandleFactoryRegistry,
+ options.m_ImportEnabled,
messages);
if (strategyResult.m_Error)
{
diff --git a/src/armnn/Network.hpp b/src/armnn/Network.hpp
index 77d6b04919..7136ee4d32 100644
--- a/src/armnn/Network.hpp
+++ b/src/armnn/Network.hpp
@@ -323,6 +323,7 @@ BackendsMap CreateSupportedBackends(TensorHandleFactoryRegistry& handleFactoryRe
OptimizationResult SelectTensorHandleStrategy(Graph& optGraph,
BackendsMap& backends,
TensorHandleFactoryRegistry& registry,
+ bool importEnabled,
Optional<std::vector<std::string>&> errMessages);
OptimizationResult AssignBackends(OptimizedNetwork* optNetObjPtr,
diff --git a/src/armnn/test/TensorHandleStrategyTest.cpp b/src/armnn/test/TensorHandleStrategyTest.cpp
index 976e58eb50..c7aa30f701 100644
--- a/src/armnn/test/TensorHandleStrategyTest.cpp
+++ b/src/armnn/test/TensorHandleStrategyTest.cpp
@@ -339,7 +339,7 @@ BOOST_AUTO_TEST_CASE(TensorHandleSelectionStrategy)
graph.TopologicalSort();
std::vector<std::string> errors;
- auto result = SelectTensorHandleStrategy(graph, backends, registry, errors);
+ auto result = SelectTensorHandleStrategy(graph, backends, registry, true, errors);
BOOST_TEST(result.m_Error == false);
BOOST_TEST(result.m_Warning == false);