Diffstat (limited to 'src/armnn/LoadedNetwork.cpp')
 src/armnn/LoadedNetwork.cpp | 58
 1 file changed, 34 insertions(+), 24 deletions(-)
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp
index 7aa66d9b09..40137779f6 100644
--- a/src/armnn/LoadedNetwork.cpp
+++ b/src/armnn/LoadedNetwork.cpp
@@ -12,6 +12,7 @@
#include "HeapProfiling.hpp"
#include <backends/CpuTensorHandle.hpp>
+#include <backends/BackendRegistry.hpp>
#include <boost/polymorphic_cast.hpp>
#include <boost/assert.hpp>
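The new BackendRegistry.hpp include gives LoadedNetwork access to a registry that maps a backend identifier to a function that creates the backend. A minimal sketch of that registry pattern follows; the names here (Registry, IBackend, Factory) are illustrative stand-ins, not Arm NN's exact API.

#include <functional>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>

// Illustrative stand-in for a backend interface.
struct IBackend
{
    virtual ~IBackend() = default;
};

class Registry
{
public:
    using Factory = std::function<std::unique_ptr<IBackend>()>;

    // Associate a backend id with the function that creates it.
    void Register(const std::string& id, Factory f)
    {
        m_Factories[id] = std::move(f);
    }

    // Look up the creation function; an unknown id is an error.
    Factory GetFactory(const std::string& id) const
    {
        auto it = m_Factories.find(id);
        if (it == m_Factories.end())
        {
            throw std::out_of_range("No factory registered for backend: " + id);
        }
        return it->second;
    }

private:
    std::map<std::string, Factory> m_Factories;
};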
@@ -70,8 +71,7 @@ std::unique_ptr<LoadedNetwork> LoadedNetwork::MakeLoadedNetwork(std::unique_ptr<
}
LoadedNetwork::LoadedNetwork(std::unique_ptr<OptimizedNetwork> net)
- : m_CpuRef()
- , m_OptimizedNetwork(std::move(net))
+ : m_OptimizedNetwork(std::move(net))
, m_WorkingMemLock(m_WorkingMemMutex, std::defer_lock)
{
// Create a profiler and register it for the current thread.
@@ -79,12 +79,20 @@ LoadedNetwork::LoadedNetwork(std::unique_ptr<OptimizedNetwork> net)
ProfilerManager::GetInstance().RegisterProfiler(m_Profiler.get());
Graph& order = m_OptimizedNetwork->GetGraph().TopologicalSort();
- //First create tensor handlers.
+ //First create tensor handlers, backends and workload factories.
//Handlers are created before workloads are,
//because workload creation can modify some of the handlers
//(for example the splitter and merger layers).
for (auto&& layer : order)
{
+ auto const& backend = layer->GetBackendId();
+ if (m_Backends.count(backend) == 0)
+ {
+ auto createBackend = BackendRegistryInstance().GetFactory(backend);
+ auto it = m_Backends.emplace(std::make_pair(backend, createBackend()));
+ m_WorkloadFactories.emplace(std::make_pair(backend,
+ it.first->second->CreateWorkloadFactory()));
+ }
layer->CreateTensorHandles(m_OptimizedNetwork->GetGraph(), GetWorkloadFactory(*layer));
}
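The loop above creates each backend, and its workload factory, lazily the first time a layer assigned to that backend is seen. std::map::emplace returns an {iterator, bool} pair, and the iterator is reused to reach the freshly created backend. A self-contained sketch of this create-once pattern, with simplified stand-in types in place of the real backend and factory classes:

#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Simplified stand-ins for the backend and factory types in the diff.
struct WorkloadFactory {};
struct Backend
{
    std::unique_ptr<WorkloadFactory> CreateWorkloadFactory() const
    {
        return std::make_unique<WorkloadFactory>();
    }
};

int main()
{
    std::map<std::string, std::unique_ptr<Backend>> backends;
    std::map<std::string, std::unique_ptr<WorkloadFactory>> factories;

    // Backend ids of layers in topological order; two layers share a backend.
    std::vector<std::string> layerBackends = { "CpuRef", "CpuAcc", "CpuRef" };

    for (const auto& id : layerBackends)
    {
        if (backends.count(id) == 0) // create each backend exactly once
        {
            // emplace returns {iterator, bool}; reuse the iterator to
            // build the matching workload factory from the new backend.
            auto it = backends.emplace(id, std::make_unique<Backend>());
            factories.emplace(id, it.first->second->CreateWorkloadFactory());
        }
    }

    std::cout << backends.size() << " backends, "
              << factories.size() << " factories\n"; // "2 backends, 2 factories"
}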
@@ -126,9 +134,10 @@ LoadedNetwork::LoadedNetwork(std::unique_ptr<OptimizedNetwork> net)
m_OptimizedNetwork->GetGraph().AllocateDynamicBuffers();
// Finalize the workload factories before execution.
- m_CpuRef.Finalize();
- m_CpuAcc.Finalize();
- m_GpuAcc.Finalize();
+ for (auto&& workloadFactory : m_WorkloadFactories)
+ {
+ workloadFactory.second->Finalize();
+ }
}
TensorInfo LoadedNetwork::GetInputTensorInfo(LayerBindingId layerId) const
@@ -164,26 +173,25 @@ const IWorkloadFactory& LoadedNetwork::GetWorkloadFactory(const Layer& layer) co
{
const IWorkloadFactory* workloadFactory = nullptr;
- if (layer.GetBackendId() == Compute::CpuAcc)
- {
- workloadFactory = &m_CpuAcc;
- }
- else if (layer.GetBackendId() == Compute::GpuAcc)
- {
- workloadFactory = &m_GpuAcc;
- }
- else if (layer.GetBackendId() == Compute::CpuRef)
+ auto it = m_WorkloadFactories.find(layer.GetBackendId());
+ if (it == m_WorkloadFactories.end())
{
- workloadFactory = &m_CpuRef;
+ throw RuntimeException(
+ boost::str(
+ boost::format("No workload factory for %1% to be used for layer: %2%")
+ % layer.GetBackendId().Get()
+ % layer.GetNameStr()),
+ CHECK_LOCATION());
}
+ workloadFactory = it->second.get();
+
BOOST_ASSERT_MSG(workloadFactory, "No workload factory");
std::string reasonIfUnsupported;
BOOST_ASSERT_MSG(IWorkloadFactory::IsLayerSupported(layer, {}, reasonIfUnsupported),
- "Factory does not support layer");
+ "Factory does not support layer");
boost::ignore_unused(reasonIfUnsupported);
-
return *workloadFactory;
}
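GetWorkloadFactory now resolves the factory with a single map lookup and throws a descriptive exception when the backend was never instantiated, instead of testing each known backend in turn. A hedged sketch of the same lookup-or-throw pattern, substituting std::ostringstream and std::runtime_error for boost::format and Arm NN's RuntimeException:

#include <map>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>

// Stand-in for the factory interface held in the map.
struct IWorkloadFactory
{
    virtual ~IWorkloadFactory() = default;
};

const IWorkloadFactory& FindFactory(
    const std::map<std::string, std::unique_ptr<IWorkloadFactory>>& factories,
    const std::string& backendId,
    const std::string& layerName)
{
    auto it = factories.find(backendId);
    if (it == factories.end())
    {
        // Name both the missing backend and the layer that needed it,
        // mirroring the message built in the diff above.
        std::ostringstream msg;
        msg << "No workload factory for " << backendId
            << " to be used for layer: " << layerName;
        throw std::runtime_error(msg.str());
    }
    return *it->second; // dereference the unique_ptr held in the map
}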
@@ -408,9 +416,10 @@ void LoadedNetwork::AllocateWorkingMemory()
{
return;
}
- m_CpuRef.Acquire();
- m_CpuAcc.Acquire();
- m_GpuAcc.Acquire();
+ for (auto&& workloadFactory : m_WorkloadFactories)
+ {
+ workloadFactory.second->Acquire();
+ }
m_IsWorkingMemAllocated = true;
}
@@ -422,9 +431,10 @@ void LoadedNetwork::FreeWorkingMemory()
return;
}
// Inform the memory managers to release memory in their respective memory groups
- m_CpuRef.Release();
- m_CpuAcc.Release();
- m_GpuAcc.Release();
+ for (auto&& workloadFactory : m_WorkloadFactories)
+ {
+ workloadFactory.second->Release();
+ }
m_IsWorkingMemAllocated = false;
}
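AllocateWorkingMemory and FreeWorkingMemory now walk the same m_WorkloadFactories map instead of naming three hard-coded factories, mirroring the Finalize loop in the constructor. The sketch below shows this paired, idempotent acquire/release lifecycle under simplified stand-in types; the guard flag plays the role of m_IsWorkingMemAllocated in the diff.

#include <map>
#include <memory>
#include <string>

struct IWorkloadFactory
{
    virtual ~IWorkloadFactory() = default;
    virtual void Acquire() = 0; // grab working memory before inference
    virtual void Release() = 0; // hand working memory back afterwards
};

class Network
{
public:
    void AllocateWorkingMemory()
    {
        if (m_IsWorkingMemAllocated) { return; } // idempotent: acquire once
        for (auto& entry : m_Factories)
        {
            entry.second->Acquire();
        }
        m_IsWorkingMemAllocated = true;
    }

    void FreeWorkingMemory()
    {
        if (!m_IsWorkingMemAllocated) { return; } // nothing to release
        for (auto& entry : m_Factories)
        {
            entry.second->Release();
        }
        m_IsWorkingMemAllocated = false;
    }

private:
    std::map<std::string, std::unique_ptr<IWorkloadFactory>> m_Factories;
    bool m_IsWorkingMemAllocated = false;
};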