aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Runtime.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/Runtime.cpp')
-rw-r--r--src/armnn/Runtime.cpp82
1 files changed, 49 insertions, 33 deletions
diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp
index 37e25a7fb6..09be92c709 100644
--- a/src/armnn/Runtime.cpp
+++ b/src/armnn/Runtime.cpp
@@ -6,15 +6,10 @@
#include <armnn/Version.hpp>
#include <backendsCommon/BackendRegistry.hpp>
+#include <backendsCommon/IBackendContext.hpp>
#include <iostream>
-#ifdef ARMCOMPUTECL_ENABLED
-#include <arm_compute/core/CL/OpenCL.h>
-#include <arm_compute/core/CL/CLKernelLibrary.h>
-#include <arm_compute/runtime/CL/CLScheduler.h>
-#endif
-
#include <boost/log/trivial.hpp>
#include <boost/polymorphic_cast.hpp>
@@ -55,6 +50,14 @@ Status Runtime::LoadNetwork(NetworkId& networkIdOut,
std::string & errorMessage)
{
IOptimizedNetwork* rawNetwork = inNetwork.release();
+
+ networkIdOut = GenerateNetworkId();
+
+ for (auto&& context : m_BackendContexts)
+ {
+ context.second->BeforeLoadNetwork(networkIdOut);
+ }
+
unique_ptr<LoadedNetwork> loadedNetwork = LoadedNetwork::MakeLoadedNetwork(
std::unique_ptr<OptimizedNetwork>(boost::polymorphic_downcast<OptimizedNetwork*>(rawNetwork)),
errorMessage);
@@ -64,8 +67,6 @@ Status Runtime::LoadNetwork(NetworkId& networkIdOut,
return Status::Failure;
}
- networkIdOut = GenerateNetworkId();
-
{
std::lock_guard<std::mutex> lockGuard(m_Mutex);
@@ -73,28 +74,28 @@ Status Runtime::LoadNetwork(NetworkId& networkIdOut,
m_LoadedNetworks[networkIdOut] = std::move(loadedNetwork);
}
+ for (auto&& context : m_BackendContexts)
+ {
+ context.second->AfterLoadNetwork(networkIdOut);
+ }
+
return Status::Success;
}
Status Runtime::UnloadNetwork(NetworkId networkId)
{
-#ifdef ARMCOMPUTECL_ENABLED
- if (arm_compute::CLScheduler::get().context()() != NULL)
+ bool unloadOk = true;
+ for (auto&& context : m_BackendContexts)
{
- // Waits for all queued CL requests to finish before unloading the network they may be using.
- try
- {
- // Coverity fix: arm_compute::CLScheduler::sync() may throw an exception of type cl::Error.
- arm_compute::CLScheduler::get().sync();
- }
- catch (const cl::Error&)
- {
- BOOST_LOG_TRIVIAL(warning) << "WARNING: Runtime::UnloadNetwork(): an error occurred while waiting for "
- "the queued CL requests to finish";
- return Status::Failure;
- }
+ unloadOk &= context.second->BeforeUnloadNetwork(networkId);
+ }
+
+ if (!unloadOk)
+ {
+ BOOST_LOG_TRIVIAL(warning) << "Runtime::UnloadNetwork(): failed to unload "
+ "network with ID:" << networkId << " because BeforeUnloadNetwork failed";
+ return Status::Failure;
}
-#endif
{
std::lock_guard<std::mutex> lockGuard(m_Mutex);
@@ -104,14 +105,11 @@ Status Runtime::UnloadNetwork(NetworkId networkId)
BOOST_LOG_TRIVIAL(warning) << "WARNING: Runtime::UnloadNetwork(): " << networkId << " not found!";
return Status::Failure;
}
+ }
-#ifdef ARMCOMPUTECL_ENABLED
- if (arm_compute::CLScheduler::get().context()() != NULL && m_LoadedNetworks.empty())
- {
- // There are no loaded networks left, so clear the CL cache to free up memory
- m_ClContextControl.ClearClCache();
- }
-#endif
+ for (auto&& context : m_BackendContexts)
+ {
+ context.second->AfterUnloadNetwork(networkId);
}
BOOST_LOG_TRIVIAL(debug) << "Runtime::UnloadNetwork(): Unloaded network with ID: " << networkId;
@@ -131,12 +129,30 @@ const std::shared_ptr<IProfiler> Runtime::GetProfiler(NetworkId networkId) const
}
Runtime::Runtime(const CreationOptions& options)
- : m_ClContextControl(options.m_GpuAccTunedParameters.get(),
- options.m_EnableGpuProfiling)
- , m_NetworkIdCounter(0)
+ : m_NetworkIdCounter(0)
, m_DeviceSpec{BackendRegistryInstance().GetBackendIds()}
{
BOOST_LOG_TRIVIAL(info) << "ArmNN v" << ARMNN_VERSION << "\n";
+
+ for (const auto& id : BackendRegistryInstance().GetBackendIds())
+ {
+ // Store backend contexts for the supported ones
+ if (m_DeviceSpec.GetSupportedBackends().count(id) > 0)
+ {
+ auto factoryFun = BackendRegistryInstance().GetFactory(id);
+ auto backend = factoryFun();
+ BOOST_ASSERT(backend.get() != nullptr);
+
+ auto context = backend->CreateBackendContext(options);
+
+ // backends are allowed to return nullptrs if they
+ // don't wish to create a backend specific context
+ if (context)
+ {
+ m_BackendContexts.emplace(std::make_pair(id, std::move(context)));
+ }
+ }
+ }
}
Runtime::~Runtime()