aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Network.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/Network.cpp')
-rw-r--r-- src/armnn/Network.cpp 75
1 file changed, 61 insertions, 14 deletions
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp
index b79576c87e..f097e677d7 100644
--- a/src/armnn/Network.cpp
+++ b/src/armnn/Network.cpp
@@ -1165,7 +1165,8 @@ bool RequiresCopy(ITensorHandleFactory::FactoryId src,
// Find the handle factory for the input layer which results in fewest required copies.
ITensorHandleFactory::FactoryId CalculateSlotOptionForInput(BackendsMap& backends,
OutputSlot& slot,
- TensorHandleFactoryRegistry& registry)
+ TensorHandleFactoryRegistry& registry,
+ bool importEnabled)
{
Layer& layer = slot.GetOwningLayer();
ARMNN_ASSERT(layer.GetType() == LayerType::Input);
@@ -1191,6 +1192,7 @@ ITensorHandleFactory::FactoryId CalculateSlotOptionForInput(BackendsMap& backend
for (auto&& connection : slot.GetConnections())
{
+
const Layer& connectedLayer = connection->GetOwningLayer();
auto toBackend = backends.find(connectedLayer.GetBackendId());
@@ -1208,11 +1210,12 @@ ITensorHandleFactory::FactoryId CalculateSlotOptionForInput(BackendsMap& backend
// Input layers use the mem copy workload or import, so the selected factory must
// support either the map/unmap API or Import API
ITensorHandleFactory* factory = registry.GetFactory(dst);
- if (!factory->SupportsMapUnmap() &&
- !CheckFlag(factory->GetImportFlags(), MemorySource::Malloc)) // Just support cpu mem imports for now
+ if (importEnabled && factory->GetImportFlags() == 0)
+ {
+ continue;
+ }
+ else if (!importEnabled && !factory->SupportsMapUnmap())
{
- // The current tensor handle factory does not support the map/unmap or import
- // strategy, move to the next one
continue;
}
@@ -1257,7 +1260,8 @@ ITensorHandleFactory::FactoryId CalculateSlotOptionForOutput(BackendsMap& backen
// when considering all connections.
ITensorHandleFactory::FactoryId CalculateSlotOption(BackendsMap& backends,
OutputSlot& outputSlot,
- TensorHandleFactoryRegistry& registry)
+ TensorHandleFactoryRegistry& registry,
+ bool importEnabled)
{
// First ensure the from backends can support the TensorHandeAPI
Layer& layer = outputSlot.GetOwningLayer();
@@ -1268,14 +1272,13 @@ ITensorHandleFactory::FactoryId CalculateSlotOption(BackendsMap& backends,
return ITensorHandleFactory::LegacyFactoryId;
}
- // Connections to Output Layers requires support for map/unmap on the TensorHandle.
- bool requiresMapUnmap = false;
+ bool outputConnection = false;
for (auto&& connection : outputSlot.GetConnections())
{
const Layer& connectedLayer = connection->GetOwningLayer();
if (connectedLayer.GetType() == LayerType::Output)
{
- requiresMapUnmap = true;
+ outputConnection = true;
}
}
@@ -1286,8 +1289,48 @@ ITensorHandleFactory::FactoryId CalculateSlotOption(BackendsMap& backends,
std::map<ITensorHandleFactory::FactoryId, int> factoryScores;
for (auto&& pref : srcPrefs)
{
- if (requiresMapUnmap) // Only consider factories that support map/unmap if required
+ if (importEnabled)
+ {
+ ITensorHandleFactory* factory = registry.GetFactory(pref);
+ if (outputConnection)
+ {
+ // Check if this is fallback case
+ bool fallbackConnection = false;
+ for (auto&& inputSlot : layer.GetInputSlots())
+ {
+ if (inputSlot.GetConnectedOutputSlot()->GetOwningLayer().GetBackendId() != layer.GetBackendId())
+ {
+ fallbackConnection = true;
+ }
+ }
+ if (fallbackConnection)
+ {
+ auto factoryCap = factory->GetCapabilities(&layer, &layer, CapabilityClass::FallbackImportDisabled);
+ // Cannot use factory import if fallback import is not supported.
+ if (!factoryCap.empty())
+ {
+ continue;
+ }
+ }
+ else if (factory->GetExportFlags() == 0)
+ {
+ continue;
+ }
+ }
+ if (!outputConnection)
+ {
+ auto factoryCap = factory->GetCapabilities(&layer, &layer, CapabilityClass::FallbackImportDisabled);
+ // Cannot use factory import if fallback import is not supported.
+ if (!factoryCap.empty())
+ {
+ continue;
+ }
+ }
+
+ }
+ else
{
+ // Only consider factories that support map/unmap
ITensorHandleFactory* factory = registry.GetFactory(pref);
if (!factory->SupportsMapUnmap())
{
@@ -1296,6 +1339,7 @@ ITensorHandleFactory::FactoryId CalculateSlotOption(BackendsMap& backends,
}
}
+
auto it = factoryScores.find(pref);
if (it == factoryScores.end())
{
@@ -1417,15 +1461,18 @@ EdgeStrategy CalculateEdgeStrategy(BackendsMap& backends,
if (!dstFactory) {
continue;
}
-
if ((dstFactory->GetImportFlags() & srcFactory->GetExportFlags()) != 0)
{
auto srcCapability = srcFactory->GetCapabilities(&layer, &layer, CapabilityClass::PaddingRequired);
auto dstCapability = dstFactory->GetCapabilities(&connectedLayer,
&connectedLayer,
CapabilityClass::PaddingRequired);
+ auto srcFallback = srcFactory->GetCapabilities(&layer, &layer, CapabilityClass::FallbackImportDisabled);
+ auto dstFallback = dstFactory->GetCapabilities(&connectedLayer,
+ &connectedLayer,
+ CapabilityClass::FallbackImportDisabled);
// Do not require memory copy if the source and destination do not require padding.
- if (srcCapability.empty() && dstCapability.empty())
+ if (srcCapability.empty() && dstCapability.empty() && srcFallback.empty() && dstFallback.empty())
{
return EdgeStrategy::ExportToTarget;
}
@@ -1477,13 +1524,13 @@ OptimizationResult SelectTensorHandleStrategy(Graph& optGraph,
switch(layer->GetType())
{
case LayerType::Input:
- slotOption = CalculateSlotOptionForInput(backends, outputSlot, registry);
+ slotOption = CalculateSlotOptionForInput(backends, outputSlot, registry, importEnabled);
break;
case LayerType::Output:
slotOption = CalculateSlotOptionForOutput(backends, outputSlot, registry);
break;
default:
- slotOption = CalculateSlotOption(backends, outputSlot, registry);
+ slotOption = CalculateSlotOption(backends, outputSlot, registry, importEnabled);
break;
}
outputSlot.SetTensorHandleFactory(slotOption);