diff options
Diffstat (limited to 'src/armnn/LoadedNetwork.cpp')
-rw-r--r-- | src/armnn/LoadedNetwork.cpp | 58 |
1 files changed, 34 insertions, 24 deletions
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp index 7aa66d9b09..40137779f6 100644 --- a/src/armnn/LoadedNetwork.cpp +++ b/src/armnn/LoadedNetwork.cpp @@ -12,6 +12,7 @@ #include "HeapProfiling.hpp" #include <backends/CpuTensorHandle.hpp> +#include <backends/BackendRegistry.hpp> #include <boost/polymorphic_cast.hpp> #include <boost/assert.hpp> @@ -70,8 +71,7 @@ std::unique_ptr<LoadedNetwork> LoadedNetwork::MakeLoadedNetwork(std::unique_ptr< } LoadedNetwork::LoadedNetwork(std::unique_ptr<OptimizedNetwork> net) - : m_CpuRef() - , m_OptimizedNetwork(std::move(net)) + : m_OptimizedNetwork(std::move(net)) , m_WorkingMemLock(m_WorkingMemMutex, std::defer_lock) { // Create a profiler and register it for the current thread. @@ -79,12 +79,20 @@ LoadedNetwork::LoadedNetwork(std::unique_ptr<OptimizedNetwork> net) ProfilerManager::GetInstance().RegisterProfiler(m_Profiler.get()); Graph& order = m_OptimizedNetwork->GetGraph().TopologicalSort(); - //First create tensor handlers. + //First create tensor handlers, backends and workload factories. //Handlers are created before workloads are. //Because workload creation can modify some of the handlers, //(for example the splitter and merger layers). for (auto&& layer : order) { + auto const& backend = layer->GetBackendId(); + if (m_Backends.count(backend) == 0) + { + auto createBackend = BackendRegistryInstance().GetFactory(backend); + auto it = m_Backends.emplace(std::make_pair(backend, createBackend())); + m_WorkloadFactories.emplace(std::make_pair(backend, + it.first->second->CreateWorkloadFactory())); + } layer->CreateTensorHandles(m_OptimizedNetwork->GetGraph(), GetWorkloadFactory(*layer)); } @@ -126,9 +134,10 @@ LoadedNetwork::LoadedNetwork(std::unique_ptr<OptimizedNetwork> net) m_OptimizedNetwork->GetGraph().AllocateDynamicBuffers(); // Finalize the workload factories before execution. - m_CpuRef.Finalize(); - m_CpuAcc.Finalize(); - m_GpuAcc.Finalize(); + for (auto&& workloadFactory : m_WorkloadFactories) + { + workloadFactory.second->Finalize(); + } } TensorInfo LoadedNetwork::GetInputTensorInfo(LayerBindingId layerId) const @@ -164,26 +173,25 @@ const IWorkloadFactory& LoadedNetwork::GetWorkloadFactory(const Layer& layer) co { const IWorkloadFactory* workloadFactory = nullptr; - if (layer.GetBackendId() == Compute::CpuAcc) - { - workloadFactory = &m_CpuAcc; - } - else if (layer.GetBackendId() == Compute::GpuAcc) - { - workloadFactory = &m_GpuAcc; - } - else if (layer.GetBackendId() == Compute::CpuRef) + auto it = m_WorkloadFactories.find(layer.GetBackendId()); + if (it == m_WorkloadFactories.end()) { - workloadFactory = &m_CpuRef; + throw RuntimeException( + boost::str( + boost::format("No workload factory for %1% to be used for layer: %2%") + % layer.GetBackendId().Get() + % layer.GetNameStr()), + CHECK_LOCATION()); } + workloadFactory = it->second.get(); + BOOST_ASSERT_MSG(workloadFactory, "No workload factory"); std::string reasonIfUnsupported; BOOST_ASSERT_MSG(IWorkloadFactory::IsLayerSupported(layer, {}, reasonIfUnsupported), - "Factory does not support layer"); + "Factory does not support layer"); boost::ignore_unused(reasonIfUnsupported); - return *workloadFactory; } @@ -408,9 +416,10 @@ void LoadedNetwork::AllocateWorkingMemory() { return; } - m_CpuRef.Acquire(); - m_CpuAcc.Acquire(); - m_GpuAcc.Acquire(); + for (auto&& workloadFactory : m_WorkloadFactories) + { + workloadFactory.second->Acquire(); + } m_IsWorkingMemAllocated = true; } @@ -422,9 +431,10 @@ void LoadedNetwork::FreeWorkingMemory() return; } // Informs the memory managers to release memory in it's respective memory group - m_CpuRef.Release(); - m_CpuAcc.Release(); - m_GpuAcc.Release(); + for (auto&& workloadFactory : m_WorkloadFactories) + { + workloadFactory.second->Release(); + } m_IsWorkingMemAllocated = false; } |