path: root/src/armnn/Runtime.cpp
author     surmeh01 <surabhi.mehta@arm.com>    2018-05-18 16:31:43 +0100
committer  telsoa01 <telmo.soares@arm.com>     2018-05-23 13:09:07 +0100
commit     3537c2ca7ebf31c1673b9ec2bb0c17b0406bbae0 (patch)
tree       5950603ad78ec3fe56fb31ddc7f4d52a19f5bc60 /src/armnn/Runtime.cpp
parent     bceff2fb3fc68bb0aa88b886900c34b77340c826 (diff)
download   armnn-3537c2ca7ebf31c1673b9ec2bb0c17b0406bbae0.tar.gz
Release 18.05
Diffstat (limited to 'src/armnn/Runtime.cpp')
-rw-r--r--   src/armnn/Runtime.cpp   93
1 file changed, 68 insertions(+), 25 deletions(-)
diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp
index e0d6a9add0..0ca3446e1b 100644
--- a/src/armnn/Runtime.cpp
+++ b/src/armnn/Runtime.cpp
@@ -6,6 +6,8 @@
#include "armnn/Version.hpp"
+#include <iostream>
+
#ifdef ARMCOMPUTECL_ENABLED
#include <arm_compute/core/CL/OpenCL.h>
#include <arm_compute/core/CL/CLKernelLibrary.h>
@@ -46,13 +48,15 @@ Status Runtime::LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr inNetw
IOptimizedNetwork* rawNetwork = inNetwork.release();
unique_ptr<LoadedNetwork> loadedNetwork = LoadedNetwork::MakeLoadedNetwork(
std::unique_ptr<OptimizedNetwork>(boost::polymorphic_downcast<OptimizedNetwork*>(rawNetwork)),
- m_WorkloadFactories);
+ m_UseCpuRefAsFallback);
if (!loadedNetwork)
{
return Status::Failure;
}
+ std::lock_guard<std::mutex> lockGuard(m_Mutex);
+
networkIdOut = GenerateNetworkId();
// store the network
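
For context on the locking added to LoadNetwork above: a single mutex guards both ID generation and the map of loaded networks, so concurrent Load/Unload calls cannot race. Below is a minimal standalone sketch of that registration pattern, assuming illustrative stand-in names (Registry, FakeNetwork, m_NextId) that are not part of the ArmNN API.

#include <map>
#include <memory>
#include <mutex>

struct FakeNetwork {};   // illustrative stand-in for LoadedNetwork
using NetworkId = int;

class Registry
{
public:
    // Registers a network under a freshly generated ID. The lock protects
    // both the ID counter and the map against concurrent Load/Unload calls.
    NetworkId Add(std::unique_ptr<FakeNetwork> net)
    {
        std::lock_guard<std::mutex> lockGuard(m_Mutex);
        const NetworkId id = m_NextId++;
        m_Networks[id] = std::move(net);
        return id;
    }

private:
    mutable std::mutex m_Mutex;
    NetworkId m_NextId = 0;
    std::map<NetworkId, std::unique_ptr<FakeNetwork>> m_Networks;
};
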
@@ -66,9 +70,22 @@ Status Runtime::UnloadNetwork(NetworkId networkId)
#ifdef ARMCOMPUTECL_ENABLED
if (arm_compute::CLScheduler::get().context()() != NULL)
{
- arm_compute::CLScheduler::get().sync();
+ // wait for all queued CL requests to finish before unloading the network they may be using
+ try
+ {
+ // Coverity fix: arm_compute::CLScheduler::sync() may throw an exception of type cl::Error.
+ arm_compute::CLScheduler::get().sync();
+ }
+ catch (const cl::Error&)
+ {
+ BOOST_LOG_TRIVIAL(warning) << "WARNING: Runtime::UnloadNetwork(): an error occurred while waiting for "
+ "the queued CL requests to finish";
+ return Status::Failure;
+ }
}
#endif
+ std::lock_guard<std::mutex> lockGuard(m_Mutex);
+
if (m_LoadedNetworks.erase(networkId) == 0)
{
BOOST_LOG_TRIVIAL(warning) << "WARNING: Runtime::UnloadNetwork(): " << networkId << " not found!";
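
The try/catch added around CLScheduler::get().sync() above converts a device-side exception into a Status return instead of letting it propagate out of UnloadNetwork. A minimal sketch of that pattern follows, using placeholder names (DeviceError, SyncDevice, WaitForDevice) rather than the real OpenCL and Compute Library types.

#include <iostream>
#include <stdexcept>

enum class Status { Success, Failure };

// Placeholders standing in for cl::Error and arm_compute::CLScheduler::get().sync().
struct DeviceError : std::runtime_error { using std::runtime_error::runtime_error; };
void SyncDevice() { /* may throw DeviceError */ }

Status WaitForDevice()
{
    try
    {
        // Wait for queued device work to finish before releasing resources it may still use.
        SyncDevice();
    }
    catch (const DeviceError&)
    {
        std::cerr << "WARNING: an error occurred while waiting for queued device work to finish\n";
        return Status::Failure;
    }
    return Status::Success;
}
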
@@ -77,7 +94,8 @@ Status Runtime::UnloadNetwork(NetworkId networkId)
#ifdef ARMCOMPUTECL_ENABLED
if (arm_compute::CLScheduler::get().context()() != NULL && m_LoadedNetworks.empty())
{
- m_WorkloadFactories.m_GpuAcc.get()->LoadOpenClRuntime();
+ // There are no loaded networks left, so clear the CL cache to free up memory
+ m_ClContextControl.ClearClCache();
}
#endif
BOOST_LOG_TRIVIAL(debug) << "Runtime::UnloadNetwork(): Unloaded network with ID: " << networkId;
@@ -85,56 +103,81 @@ Status Runtime::UnloadNetwork(NetworkId networkId)
}
Runtime::Runtime(const CreationOptions& options)
-: m_NetworkIdCounter(0)
+ : m_ClContextControl(options.m_ClTunedParameters)
+ , m_NetworkIdCounter(0)
{
BOOST_LOG_TRIVIAL(info) << "ArmNN v" << ARMNN_VERSION << "\n";
BOOST_LOG_TRIVIAL(info) << "Using compute device: " << options.m_DefaultComputeDevice << "\n";
m_DeviceSpec.DefaultComputeDevice = options.m_DefaultComputeDevice;
- // If useCpuRefAsFallback is false, the reference workload factory will be prevented from creating
- // operation workloads, unless the default compute device is precisely the reference backend.
- m_WorkloadFactories.m_CpuRef = make_shared<RefWorkloadFactory>(
- options.m_DefaultComputeDevice == Compute::CpuRef ? true : options.m_UseCpuRefAsFallback);
- m_WorkloadFactories.m_CpuAcc = make_shared<NeonWorkloadFactory>();
- m_WorkloadFactories.m_GpuAcc = make_shared<ClWorkloadFactory>(options.m_ClTunedParameters);
-
- if (options.m_DefaultComputeDevice == Compute::GpuAcc)
- {
- m_WorkloadFactories.m_GpuAcc.get()->LoadOpenClRuntime();
- }
+ // If useCpuRefAsFallback is false, the reference workload factory will be prevented from creating
+ // operation workloads, unless the default compute device is precisely the reference backend.
+ // This option is passed to the LoadedNetwork, which owns the workload factories.
+ m_UseCpuRefAsFallback = options.m_DefaultComputeDevice == Compute::CpuRef || options.m_UseCpuRefAsFallback;
}
Runtime::~Runtime()
{
std::vector<int> networkIDs;
- std::transform(m_LoadedNetworks.begin(), m_LoadedNetworks.end(),
- std::back_inserter(networkIDs),
- [](const auto &pair) { return pair.first; });
+ try
+ {
+ // Coverity fix: The following code may throw an exception of type std::length_error.
+ std::transform(m_LoadedNetworks.begin(), m_LoadedNetworks.end(),
+ std::back_inserter(networkIDs),
+ [](const auto &pair) { return pair.first; });
+ }
+ catch (const std::exception& e)
+ {
+ // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
+ // exception of type std::length_error.
+ // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
+ std::cerr << "WARNING: An error has occurred when getting the IDs of the networks to unload: " << e.what()
+ << "\nSome of the loaded networks may not be unloaded" << std::endl;
+ }
+ // We then proceed to unload all the networks whose IDs have been appended to the list
+ // up to the point the exception was thrown (if any).
for (auto networkID : networkIDs)
{
- UnloadNetwork(networkID);
+ try
+ {
+ // Coverity fix: UnloadNetwork() may throw an exception of type std::length_error,
+ // boost::log::v2s_mt_posix::odr_violation or boost::log::v2s_mt_posix::system_error
+ UnloadNetwork(networkID);
+ }
+ catch (const std::exception& e)
+ {
+ // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
+ // exception of type std::length_error.
+ // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
+ std::cerr << "WARNING: An error has occurred when unloading network " << networkID << ": " << e.what()
+ << std::endl;
+ }
}
}
+LoadedNetwork* Runtime::GetLoadedNetworkPtr(NetworkId networkId) const
+{
+ std::lock_guard<std::mutex> lockGuard(m_Mutex);
+ return m_LoadedNetworks.at(networkId).get();
+}
+
TensorInfo Runtime::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
{
- LoadedNetwork* net = m_LoadedNetworks.at(networkId).get();
- return net->GetInputTensorInfo(layerId);
+ return GetLoadedNetworkPtr(networkId)->GetInputTensorInfo(layerId);
}
TensorInfo Runtime::GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
{
- const LoadedNetwork* net = m_LoadedNetworks.at(networkId).get();
- return net->GetOutputTensorInfo(layerId);
+ return GetLoadedNetworkPtr(networkId)->GetOutputTensorInfo(layerId);
}
Status Runtime::EnqueueWorkload(NetworkId networkId,
const InputTensors& inputTensors,
const OutputTensors& outputTensors)
{
- LoadedNetwork* loadedNetwork = m_LoadedNetworks.at(networkId).get();
- return loadedNetwork->EnqueueWorkload(inputTensors, outputTensors, m_WorkloadFactories);
+ LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
+ return loadedNetwork->EnqueueWorkload(inputTensors, outputTensors);
}
}
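
The destructor above deliberately keeps unloading the remaining networks even if one UnloadNetwork call throws, since an exception escaping a destructor would terminate the process. A minimal sketch of that teardown pattern, with UnloadById as a hypothetical stand-in for Runtime::UnloadNetwork:

#include <iostream>
#include <vector>

// Hypothetical stand-in for Runtime::UnloadNetwork(); assumed to possibly throw.
void UnloadById(int networkId) { /* release resources owned by the network */ }

// Unloads every collected network ID; one failing unload does not stop the others.
void UnloadAll(const std::vector<int>& networkIDs)
{
    for (int networkId : networkIDs)
    {
        try
        {
            UnloadById(networkId);
        }
        catch (const std::exception& e)
        {
            // Report to stderr rather than rethrowing: an exception escaping a
            // destructor-style cleanup path would call std::terminate().
            std::cerr << "WARNING: failed to unload network " << networkId
                      << ": " << e.what() << std::endl;
        }
    }
}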