aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Network.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/Network.cpp')
-rw-r--r--src/armnn/Network.cpp12
1 files changed, 8 insertions, 4 deletions
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp
index 94a9961a81..dec9468d7b 100644
--- a/src/armnn/Network.cpp
+++ b/src/armnn/Network.cpp
@@ -861,7 +861,8 @@ EdgeStrategy CalculateEdgeStrategy(BackendsMap& backends,
ITensorHandleFactory::FactoryId srcFactoryId,
const Layer& layer,
const Layer& connectedLayer,
- TensorHandleFactoryRegistry& registry)
+ TensorHandleFactoryRegistry& registry,
+ bool importEnabled)
{
auto toBackend = backends.find(connectedLayer.GetBackendId());
ARMNN_ASSERT_MSG(toBackend != backends.end(), "Backend id not found for the connected layer");
@@ -899,7 +900,7 @@ EdgeStrategy CalculateEdgeStrategy(BackendsMap& backends,
// Search for export/import options
ITensorHandleFactory* srcFactory = registry.GetFactory(srcFactoryId);
- if (srcFactory->GetExportFlags() != 0)
+ if (srcFactory->GetExportFlags() != 0 && importEnabled)
{
for (auto&& pref : dstPrefs)
{
@@ -945,11 +946,12 @@ EdgeStrategy CalculateEdgeStrategy(BackendsMap& backends,
OptimizationResult SelectTensorHandleStrategy(Graph& optGraph,
BackendsMap& backends,
TensorHandleFactoryRegistry& registry,
+ bool importEnabled,
Optional<std::vector<std::string>&> errMessages)
{
OptimizationResult result;
- optGraph.ForEachLayer([&backends, &registry, &result, &errMessages](Layer* layer)
+ optGraph.ForEachLayer([&backends, &registry, &result, &errMessages, importEnabled](Layer* layer)
{
ARMNN_ASSERT(layer);
@@ -985,7 +987,8 @@ OptimizationResult SelectTensorHandleStrategy(Graph& optGraph,
{
const Layer& connectedLayer = connection->GetOwningLayer();
- EdgeStrategy strategy = CalculateEdgeStrategy(backends, slotOption, *layer, connectedLayer, registry);
+ EdgeStrategy strategy = CalculateEdgeStrategy(backends, slotOption, *layer, connectedLayer,
+ registry, importEnabled);
if (strategy == EdgeStrategy::Undefined)
{
@@ -1122,6 +1125,7 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork,
OptimizationResult strategyResult = SelectTensorHandleStrategy(optGraph,
backends,
tensorHandleFactoryRegistry,
+ options.m_ImportEnabled,
messages);
if (strategyResult.m_Error)
{