aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Network.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/Network.cpp')
-rw-r--r--src/armnn/Network.cpp51
1 files changed, 27 insertions, 24 deletions
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp
index 2195c71735..b30cd9f3c2 100644
--- a/src/armnn/Network.cpp
+++ b/src/armnn/Network.cpp
@@ -441,7 +441,7 @@ bool RequiresCopy(ITensorHandleFactory::FactoryId src,
ITensorHandleFactory* srcFactory = registry.GetFactory(src);
ITensorHandleFactory* dstFactory = registry.GetFactory(dst);
- if (srcFactory->SupportsExport() && dstFactory->SupportsImport())
+ if ((srcFactory->GetExportFlags() & dstFactory->GetImportFlags()) != 0)
{
return false;
}
@@ -493,11 +493,14 @@ ITensorHandleFactory::FactoryId CalculateSlotOptionForInput(BackendsMap& backend
auto dstPrefs = toBackend->second.get()->GetHandleFactoryPreferences();
for (auto&& dst : dstPrefs)
{
- // Input layers use the mem copy workload, so the selected factory must support map/unmap API
+ // Input layers use the mem copy workload or import, so the selected factory must
+ // support either the map/unmap API or Import API
ITensorHandleFactory* factory = registry.GetFactory(dst);
- if (!factory->SupportsMapUnmap())
+ if (!factory->SupportsMapUnmap() &&
+ !CheckFlag(factory->GetImportFlags(), MemorySource::Malloc)) // Just support cpu mem imports for now
{
- // The current tensor handle factory does not support the map/unmap strategy, move to the next one
+ // The current tensor handle factory does not support the map/unmap or import
+ // strategy, move to the next one
continue;
}
@@ -648,11 +651,11 @@ ITensorHandleFactory::FactoryId CalculateSlotOption(BackendsMap& backends,
return ITensorHandleFactory::LegacyFactoryId;
}
-MemoryStrategy CalculateStrategy(BackendsMap& backends,
- ITensorHandleFactory::FactoryId srcFactoryId,
- const Layer& layer,
- const Layer& connectedLayer,
- TensorHandleFactoryRegistry& registry)
+EdgeStrategy CalculateEdgeStrategy(BackendsMap& backends,
+ ITensorHandleFactory::FactoryId srcFactoryId,
+ const Layer& layer,
+ const Layer& connectedLayer,
+ TensorHandleFactoryRegistry& registry)
{
auto toBackend = backends.find(connectedLayer.GetBackendId());
BOOST_ASSERT_MSG(toBackend != backends.end(), "Backend id not found for the connected layer");
@@ -664,19 +667,19 @@ MemoryStrategy CalculateStrategy(BackendsMap& backends,
{
if (layer.GetBackendId() != connectedLayer.GetBackendId())
{
- return MemoryStrategy::CopyToTarget;
+ return EdgeStrategy::CopyToTarget;
}
else
{
- return MemoryStrategy::DirectCompatibility;
+ return EdgeStrategy::DirectCompatibility;
}
}
// TensorHandleFactory API present, so perform more sophisticated strategies.
- // Dst Output layers don't require copy because they use map/unmap
+ // Dst Output layers don't require copy because they use import or map/unmap
if (connectedLayer.GetType() == LayerType::Output)
{
- return MemoryStrategy::DirectCompatibility;
+ return EdgeStrategy::DirectCompatibility;
}
// Search for direct match in prefs
@@ -684,20 +687,20 @@ MemoryStrategy CalculateStrategy(BackendsMap& backends,
{
if (pref == srcFactoryId)
{
- return MemoryStrategy::DirectCompatibility;
+ return EdgeStrategy::DirectCompatibility;
}
}
// Search for export/import options
ITensorHandleFactory* srcFactory = registry.GetFactory(srcFactoryId);
- if (srcFactory->SupportsExport())
+ if (srcFactory->GetExportFlags() != 0)
{
for (auto&& pref : dstPrefs)
{
ITensorHandleFactory* dstFactory = registry.GetFactory(pref);
- if (dstFactory->SupportsImport())
+ if ((dstFactory->GetImportFlags() & srcFactory->GetExportFlags()) != 0)
{
- return MemoryStrategy::ExportToTarget;
+ return EdgeStrategy::ExportToTarget;
}
}
}
@@ -710,12 +713,12 @@ MemoryStrategy CalculateStrategy(BackendsMap& backends,
ITensorHandleFactory* dstFactory = registry.GetFactory(pref);
if (dstFactory->SupportsMapUnmap())
{
- return MemoryStrategy::CopyToTarget;
+ return EdgeStrategy::CopyToTarget;
}
}
}
- return MemoryStrategy::Undefined;
+ return EdgeStrategy::Undefined;
}
// Select the TensorHandleFactories and the corresponding memory strategy
@@ -756,15 +759,15 @@ OptimizationResult SelectTensorHandleStrategy(Graph& optGraph,
}
outputSlot.SetTensorHandleFactory(slotOption);
- // Now determine the "best" memory strategy for each connection given the slotOption.
+ // Now determine the "best" edge strategy for each connection given the slotOption.
unsigned int connectionIdx = 0;
for (auto&& connection : outputSlot.GetConnections())
{
const Layer& connectedLayer = connection->GetOwningLayer();
- MemoryStrategy strategy = CalculateStrategy(backends, slotOption, *layer, connectedLayer, registry);
+ EdgeStrategy strategy = CalculateEdgeStrategy(backends, slotOption, *layer, connectedLayer, registry);
- if (strategy == MemoryStrategy::Undefined)
+ if (strategy == EdgeStrategy::Undefined)
{
result.m_Error = true;
if (errMessages)
@@ -775,7 +778,7 @@ OptimizationResult SelectTensorHandleStrategy(Graph& optGraph,
return;
}
- outputSlot.SetMemoryStrategy(connectionIdx, strategy);
+ outputSlot.SetEdgeStrategy(connectionIdx, strategy);
connectionIdx++;
}
@@ -887,7 +890,7 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork,
}
// Based on the tensor handle strategy determined above, insert copy layers where required.
- optGraph.AddCopyLayers(backends, tensorHandleFactoryRegistry);
+ optGraph.AddCompatibilityLayers(backends, tensorHandleFactoryRegistry);
// Convert constants
Optimizer::Pass(optGraph, MakeOptimizations(ConvertConstantsFloatToHalf()));