aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/LoadedNetwork.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/LoadedNetwork.cpp')
-rw-r--r--src/armnn/LoadedNetwork.cpp160
1 files changed, 157 insertions, 3 deletions
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp
index 46eb9883fb..67de00f0f3 100644
--- a/src/armnn/LoadedNetwork.cpp
+++ b/src/armnn/LoadedNetwork.cpp
@@ -24,6 +24,7 @@
#include <LabelsAndEventClasses.hpp>
#include <fmt/format.h>
+#include <armnn/utility/Timer.hpp>
namespace armnn
{
@@ -84,7 +85,8 @@ void AddWorkloadStructure(std::unique_ptr<TimelineUtilityMethods>& timelineUtils
std::unique_ptr<LoadedNetwork> LoadedNetwork::MakeLoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
std::string& errorMessage,
const INetworkProperties& networkProperties,
- profiling::ProfilingService& profilingService)
+ profiling::ProfilingService& profilingService,
+ const NetworkId networkIdOut)
{
std::unique_ptr<LoadedNetwork> loadedNetwork;
@@ -98,7 +100,7 @@ std::unique_ptr<LoadedNetwork> LoadedNetwork::MakeLoadedNetwork(std::unique_ptr<
try
{
- loadedNetwork.reset(new LoadedNetwork(std::move(net), networkProperties, profilingService));
+ loadedNetwork.reset(new LoadedNetwork(std::move(net), networkProperties, profilingService, networkIdOut));
}
catch (const armnn::RuntimeException& error)
{
@@ -118,9 +120,11 @@ std::unique_ptr<LoadedNetwork> LoadedNetwork::MakeLoadedNetwork(std::unique_ptr<
LoadedNetwork::LoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
const INetworkProperties& networkProperties,
- profiling::ProfilingService& profilingService) :
+ profiling::ProfilingService& profilingService,
+ const NetworkId networkId) :
m_OptimizedNetwork(std::move(net)),
m_NetworkProperties(networkProperties),
+ m_NetworkId(networkId),
m_TensorHandleFactoryRegistry(),
m_ProfilingService(profilingService)
{
@@ -161,6 +165,14 @@ LoadedNetwork::LoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
}
}
}
+
+ // Create the thread pool which will have working memory handles assigned to each thread
+ // Should occur after factories are registered so thet the WorkingMemHandles can be created
+ if (m_NetworkProperties.m_NumThreads > 0 && networkProperties.m_AsyncEnabled)
+ {
+ CreateThreadPool(m_NetworkProperties.m_NumThreads);
+ }
+
if (!networkProperties.m_AsyncEnabled)
{
for (auto &&layer : order)
@@ -846,6 +858,147 @@ bool LoadedNetwork::Execute(std::unique_ptr<TimelineUtilityMethods>& timelineUti
return success;
}
+void LoadedNetwork::CreateThreadPool(std::size_t numThreads)
+{
+
+ for (auto i = 0u; i < numThreads; ++i)
+ {
+ std::unique_ptr<IWorkingMemHandle> workingMemHandle = CreateWorkingMemHandle(m_NetworkId);
+ m_Threads.emplace_back(
+ std::make_unique<std::thread>(
+ &LoadedNetwork::ProcessExecPriorities,
+ this,
+ std::move(workingMemHandle)
+ )
+ );
+ }
+}
+
+void LoadedNetwork::TerminateThreadPool() noexcept
+{
+ {
+ std::unique_lock<std::mutex> threadPoolLock(m_ThreadPoolMutex);
+ m_TerminatePool = true;
+ }
+
+ m_ThreadPoolEvent.notify_all();
+
+ for (auto &thread : m_Threads)
+ {
+ thread->join();
+ }
+}
+
+void LoadedNetwork::Schedule(const InputTensors& inputTensors,
+ const OutputTensors& outputTensors,
+ const QosExecPriority priority,
+ std::shared_ptr<IAsyncExecutionCallback> cb)
+{
+ // Group execution parameters so that they can be easily added to the queue
+ ExecutionTuple groupExecParams = std::make_tuple(inputTensors, outputTensors, cb);
+ std::shared_ptr<ExecutionTuple> operation = make_shared<ExecutionTuple>(groupExecParams);
+
+ // Add a message to the queue and notify the request thread
+ std::unique_lock<std::mutex> lock(m_ThreadPoolMutex);
+ switch (priority) {
+ case QosExecPriority::High:
+ m_HighPriorityQueue.push(operation);
+ break;
+ case QosExecPriority::Low:
+ m_LowPriorityQueue.push(operation);
+ break;
+ case QosExecPriority::Medium:
+ default:
+ m_MediumPriorityQueue.push(operation);
+ }
+ m_ThreadPoolEvent.notify_one();
+}
+
+void LoadedNetwork::ProcessExecPriorities(std::unique_ptr<IWorkingMemHandle> workingMemHandle)
+{
+ int expireRate = EXPIRE_RATE;
+ int highPriorityCount = 0;
+ int mediumPriorityCount = 0;
+
+ IWorkingMemHandle& workingMemHandleRef = *workingMemHandle.get();
+
+ while (true)
+ {
+ std::shared_ptr<ExecutionTuple> currentExecInProgress(nullptr);
+ {
+ // Wait for a message to be added to the queue
+ // This is in a separate scope to minimise the lifetime of the lock
+ std::unique_lock<std::mutex> lock(m_ThreadPoolMutex);
+
+ m_ThreadPoolEvent.wait(lock,
+ [=] {
+ return m_TerminatePool || !m_HighPriorityQueue.empty() ||
+ !m_MediumPriorityQueue.empty() || !m_LowPriorityQueue.empty();
+ });
+
+ if (m_TerminatePool && m_HighPriorityQueue.empty() && m_MediumPriorityQueue.empty() &&
+ m_LowPriorityQueue.empty())
+ {
+ break;
+ }
+
+ // Get the message to process from the front of each queue based on priority from high to low
+ // Get high priority first if it does not exceed the expire rate
+ if (!m_HighPriorityQueue.empty() && highPriorityCount < expireRate)
+ {
+ currentExecInProgress = m_HighPriorityQueue.front();
+ m_HighPriorityQueue.pop();
+ highPriorityCount += 1;
+ }
+ // If high priority queue is empty or the count exceeds the expire rate, get medium priority message
+ else if (!m_MediumPriorityQueue.empty() && mediumPriorityCount < expireRate)
+ {
+ currentExecInProgress = m_MediumPriorityQueue.front();
+ m_MediumPriorityQueue.pop();
+ mediumPriorityCount += 1;
+ // Reset high priority count
+ highPriorityCount = 0;
+ }
+ // If medium priority queue is empty or the count exceeds the expire rate, get low priority message
+ else if (!m_LowPriorityQueue.empty())
+ {
+ currentExecInProgress = m_LowPriorityQueue.front();
+ m_LowPriorityQueue.pop();
+ // Reset high and medium priority count
+ highPriorityCount = 0;
+ mediumPriorityCount = 0;
+ }
+ else
+ {
+ // Reset high and medium priority count
+ highPriorityCount = 0;
+ mediumPriorityCount = 0;
+ continue;
+ }
+ }
+
+ // invoke the asynchronous execution method
+ auto inputTensors = std::get<0>(*currentExecInProgress);
+ auto outputTensors = std::get<1>(*currentExecInProgress);
+ auto cb = std::get<2>(*currentExecInProgress);
+
+ // Get time at start of inference
+ HighResolutionClock startTime = armnn::GetTimeNow();
+
+ try // executing the inference
+ {
+ // Execute and populate the time at end of inference in the callback
+ Execute(inputTensors, outputTensors, workingMemHandleRef) == Status::Success ?
+ cb->Notify(Status::Success, std::make_pair(startTime, armnn::GetTimeNow())) :
+ cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow()));
+ }
+ catch (const RuntimeException& error)
+ {
+ cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow()));
+ }
+ }
+}
+
void LoadedNetwork::EnqueueInput(const BindableLayer& layer,
const ConstTensor& inputTensor,
WorkingMemHandle& context)
@@ -1096,6 +1249,7 @@ Status LoadedNetwork::Execute(const InputTensors& inputTensors,
EnqueueOutput(*outputLayer, GetOutputTensor(outputLayer->GetBindingId(), outputTensors), workingMemHandle);
}
}
+
return executionSucceeded ? Status::Success : Status::Failure;
}