diff options
author | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2021-05-07 17:52:36 +0100 |
---|---|---|
committer | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2021-05-08 20:15:32 +0100 |
commit | e5f0b2409c2e557a5a78e2f4659d203154289b23 (patch) | |
tree | 0e32680ed15ed5157c78d5deeabda2c0ceeeb4a3 /src/armnn/Network.cpp | |
parent | ae12306486efc55293a40048618abe5e8b19151b (diff) | |
download | armnn-e5f0b2409c2e557a5a78e2f4659d203154289b23.tar.gz |
IVGCVSW-5818 Enable import on GPU
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I4e4eb107aa2bfa09625840d738001f33152e6792
Diffstat (limited to 'src/armnn/Network.cpp')
-rw-r--r-- | src/armnn/Network.cpp | 75 |
1 files changed, 61 insertions, 14 deletions
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp index b79576c87e..f097e677d7 100644 --- a/src/armnn/Network.cpp +++ b/src/armnn/Network.cpp @@ -1165,7 +1165,8 @@ bool RequiresCopy(ITensorHandleFactory::FactoryId src, // Find the handle factory for the input layer which results in fewest required copies. ITensorHandleFactory::FactoryId CalculateSlotOptionForInput(BackendsMap& backends, OutputSlot& slot, - TensorHandleFactoryRegistry& registry) + TensorHandleFactoryRegistry& registry, + bool importEnabled) { Layer& layer = slot.GetOwningLayer(); ARMNN_ASSERT(layer.GetType() == LayerType::Input); @@ -1191,6 +1192,7 @@ ITensorHandleFactory::FactoryId CalculateSlotOptionForInput(BackendsMap& backend for (auto&& connection : slot.GetConnections()) { + const Layer& connectedLayer = connection->GetOwningLayer(); auto toBackend = backends.find(connectedLayer.GetBackendId()); @@ -1208,11 +1210,12 @@ ITensorHandleFactory::FactoryId CalculateSlotOptionForInput(BackendsMap& backend // Input layers use the mem copy workload or import, so the selected factory must // support either the map/unmap API or Import API ITensorHandleFactory* factory = registry.GetFactory(dst); - if (!factory->SupportsMapUnmap() && - !CheckFlag(factory->GetImportFlags(), MemorySource::Malloc)) // Just support cpu mem imports for now + if (importEnabled && factory->GetImportFlags() == 0) + { + continue; + } + else if (!importEnabled && !factory->SupportsMapUnmap()) { - // The current tensor handle factory does not support the map/unmap or import - // strategy, move to the next one continue; } @@ -1257,7 +1260,8 @@ ITensorHandleFactory::FactoryId CalculateSlotOptionForOutput(BackendsMap& backen // when considering all connections. ITensorHandleFactory::FactoryId CalculateSlotOption(BackendsMap& backends, OutputSlot& outputSlot, - TensorHandleFactoryRegistry& registry) + TensorHandleFactoryRegistry& registry, + bool importEnabled) { // First ensure the from backends can support the TensorHandeAPI Layer& layer = outputSlot.GetOwningLayer(); @@ -1268,14 +1272,13 @@ ITensorHandleFactory::FactoryId CalculateSlotOption(BackendsMap& backends, return ITensorHandleFactory::LegacyFactoryId; } - // Connections to Output Layers requires support for map/unmap on the TensorHandle. - bool requiresMapUnmap = false; + bool outputConnection = false; for (auto&& connection : outputSlot.GetConnections()) { const Layer& connectedLayer = connection->GetOwningLayer(); if (connectedLayer.GetType() == LayerType::Output) { - requiresMapUnmap = true; + outputConnection = true; } } @@ -1286,8 +1289,48 @@ ITensorHandleFactory::FactoryId CalculateSlotOption(BackendsMap& backends, std::map<ITensorHandleFactory::FactoryId, int> factoryScores; for (auto&& pref : srcPrefs) { - if (requiresMapUnmap) // Only consider factories that support map/unmap if required + if (importEnabled) + { + ITensorHandleFactory* factory = registry.GetFactory(pref); + if (outputConnection) + { + // Check if this is fallback case + bool fallbackConnection = false; + for (auto&& inputSlot : layer.GetInputSlots()) + { + if (inputSlot.GetConnectedOutputSlot()->GetOwningLayer().GetBackendId() != layer.GetBackendId()) + { + fallbackConnection = true; + } + } + if (fallbackConnection) + { + auto factoryCap = factory->GetCapabilities(&layer, &layer, CapabilityClass::FallbackImportDisabled); + // Cannot use factory import if fallback import is not supported. + if (!factoryCap.empty()) + { + continue; + } + } + else if (factory->GetExportFlags() == 0) + { + continue; + } + } + if (!outputConnection) + { + auto factoryCap = factory->GetCapabilities(&layer, &layer, CapabilityClass::FallbackImportDisabled); + // Cannot use factory import if fallback import is not supported. + if (!factoryCap.empty()) + { + continue; + } + } + + } + else { + // Only consider factories that support map/unmap ITensorHandleFactory* factory = registry.GetFactory(pref); if (!factory->SupportsMapUnmap()) { @@ -1296,6 +1339,7 @@ ITensorHandleFactory::FactoryId CalculateSlotOption(BackendsMap& backends, } } + auto it = factoryScores.find(pref); if (it == factoryScores.end()) { @@ -1417,15 +1461,18 @@ EdgeStrategy CalculateEdgeStrategy(BackendsMap& backends, if (!dstFactory) { continue; } - if ((dstFactory->GetImportFlags() & srcFactory->GetExportFlags()) != 0) { auto srcCapability = srcFactory->GetCapabilities(&layer, &layer, CapabilityClass::PaddingRequired); auto dstCapability = dstFactory->GetCapabilities(&connectedLayer, &connectedLayer, CapabilityClass::PaddingRequired); + auto srcFallback = srcFactory->GetCapabilities(&layer, &layer, CapabilityClass::FallbackImportDisabled); + auto dstFallback = dstFactory->GetCapabilities(&connectedLayer, + &connectedLayer, + CapabilityClass::FallbackImportDisabled); // Do not require memory copy if the source and destination do not require padding. - if (srcCapability.empty() && dstCapability.empty()) + if (srcCapability.empty() && dstCapability.empty() && srcFallback.empty() && dstFallback.empty()) { return EdgeStrategy::ExportToTarget; } @@ -1477,13 +1524,13 @@ OptimizationResult SelectTensorHandleStrategy(Graph& optGraph, switch(layer->GetType()) { case LayerType::Input: - slotOption = CalculateSlotOptionForInput(backends, outputSlot, registry); + slotOption = CalculateSlotOptionForInput(backends, outputSlot, registry, importEnabled); break; case LayerType::Output: slotOption = CalculateSlotOptionForOutput(backends, outputSlot, registry); break; default: - slotOption = CalculateSlotOption(backends, outputSlot, registry); + slotOption = CalculateSlotOption(backends, outputSlot, registry, importEnabled); break; } outputSlot.SetTensorHandleFactory(slotOption); |