diff options
author | Keith Davis <keith.davis@arm.com> | 2021-04-22 10:10:34 +0100 |
---|---|---|
committer | finn.williams <finn.williams@arm.com> | 2021-05-06 19:39:39 +0000 |
commit | e813d67f86df41a238ff79b5c554ef5027f56576 (patch) | |
tree | 54c2145a78297d61e6e94676729cb2468490ade3 /src/armnn/LoadedNetwork.cpp | |
parent | d905decd256558bbee165e636ce4242ac3b9c917 (diff) | |
download | armnn-e813d67f86df41a238ff79b5c554ef5027f56576.tar.gz |
IVGCVSW-5813 Add Async Queue to IRuntime
Signed-off-by: Keith Davis <keith.davis@arm.com>
Change-Id: Icc0d131c8ee2e9748e2f14762a75962b39c10f9d
Diffstat (limited to 'src/armnn/LoadedNetwork.cpp')
-rw-r--r-- | src/armnn/LoadedNetwork.cpp | 160 |
1 file changed, 157 insertions, 3 deletions
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp index 46eb9883fb..67de00f0f3 100644 --- a/src/armnn/LoadedNetwork.cpp +++ b/src/armnn/LoadedNetwork.cpp @@ -24,6 +24,7 @@ #include <LabelsAndEventClasses.hpp> #include <fmt/format.h> +#include <armnn/utility/Timer.hpp> namespace armnn { @@ -84,7 +85,8 @@ void AddWorkloadStructure(std::unique_ptr<TimelineUtilityMethods>& timelineUtils std::unique_ptr<LoadedNetwork> LoadedNetwork::MakeLoadedNetwork(std::unique_ptr<IOptimizedNetwork> net, std::string& errorMessage, const INetworkProperties& networkProperties, - profiling::ProfilingService& profilingService) + profiling::ProfilingService& profilingService, + const NetworkId networkIdOut) { std::unique_ptr<LoadedNetwork> loadedNetwork; @@ -98,7 +100,7 @@ std::unique_ptr<LoadedNetwork> LoadedNetwork::MakeLoadedNetwork(std::unique_ptr< try { - loadedNetwork.reset(new LoadedNetwork(std::move(net), networkProperties, profilingService)); + loadedNetwork.reset(new LoadedNetwork(std::move(net), networkProperties, profilingService, networkIdOut)); } catch (const armnn::RuntimeException& error) { @@ -118,9 +120,11 @@ std::unique_ptr<LoadedNetwork> LoadedNetwork::MakeLoadedNetwork(std::unique_ptr< LoadedNetwork::LoadedNetwork(std::unique_ptr<IOptimizedNetwork> net, const INetworkProperties& networkProperties, - profiling::ProfilingService& profilingService) : + profiling::ProfilingService& profilingService, + const NetworkId networkId) : m_OptimizedNetwork(std::move(net)), m_NetworkProperties(networkProperties), + m_NetworkId(networkId), m_TensorHandleFactoryRegistry(), m_ProfilingService(profilingService) { @@ -161,6 +165,14 @@ LoadedNetwork::LoadedNetwork(std::unique_ptr<IOptimizedNetwork> net, } } } + + // Create the thread pool which will have working memory handles assigned to each thread + // Should occur after factories are registered so that the WorkingMemHandles can be created + if (m_NetworkProperties.m_NumThreads > 0 && 
networkProperties.m_AsyncEnabled) + { + CreateThreadPool(m_NetworkProperties.m_NumThreads); + } + if (!networkProperties.m_AsyncEnabled) { for (auto &&layer : order) @@ -846,6 +858,147 @@ bool LoadedNetwork::Execute(std::unique_ptr<TimelineUtilityMethods>& timelineUti return success; } +void LoadedNetwork::CreateThreadPool(std::size_t numThreads) +{ + + for (auto i = 0u; i < numThreads; ++i) + { + std::unique_ptr<IWorkingMemHandle> workingMemHandle = CreateWorkingMemHandle(m_NetworkId); + m_Threads.emplace_back( + std::make_unique<std::thread>( + &LoadedNetwork::ProcessExecPriorities, + this, + std::move(workingMemHandle) + ) + ); + } +} + +void LoadedNetwork::TerminateThreadPool() noexcept +{ + { + std::unique_lock<std::mutex> threadPoolLock(m_ThreadPoolMutex); + m_TerminatePool = true; + } + + m_ThreadPoolEvent.notify_all(); + + for (auto &thread : m_Threads) + { + thread->join(); + } +} + +void LoadedNetwork::Schedule(const InputTensors& inputTensors, + const OutputTensors& outputTensors, + const QosExecPriority priority, + std::shared_ptr<IAsyncExecutionCallback> cb) +{ + // Group execution parameters so that they can be easily added to the queue + ExecutionTuple groupExecParams = std::make_tuple(inputTensors, outputTensors, cb); + std::shared_ptr<ExecutionTuple> operation = make_shared<ExecutionTuple>(groupExecParams); + + // Add a message to the queue and notify the request thread + std::unique_lock<std::mutex> lock(m_ThreadPoolMutex); + switch (priority) { + case QosExecPriority::High: + m_HighPriorityQueue.push(operation); + break; + case QosExecPriority::Low: + m_LowPriorityQueue.push(operation); + break; + case QosExecPriority::Medium: + default: + m_MediumPriorityQueue.push(operation); + } + m_ThreadPoolEvent.notify_one(); +} + +void LoadedNetwork::ProcessExecPriorities(std::unique_ptr<IWorkingMemHandle> workingMemHandle) +{ + int expireRate = EXPIRE_RATE; + int highPriorityCount = 0; + int mediumPriorityCount = 0; + + IWorkingMemHandle& 
workingMemHandleRef = *workingMemHandle.get(); + + while (true) + { + std::shared_ptr<ExecutionTuple> currentExecInProgress(nullptr); + { + // Wait for a message to be added to the queue + // This is in a separate scope to minimise the lifetime of the lock + std::unique_lock<std::mutex> lock(m_ThreadPoolMutex); + + m_ThreadPoolEvent.wait(lock, + [=] { + return m_TerminatePool || !m_HighPriorityQueue.empty() || + !m_MediumPriorityQueue.empty() || !m_LowPriorityQueue.empty(); + }); + + if (m_TerminatePool && m_HighPriorityQueue.empty() && m_MediumPriorityQueue.empty() && + m_LowPriorityQueue.empty()) + { + break; + } + + // Get the message to process from the front of each queue based on priority from high to low + // Get high priority first if it does not exceed the expire rate + if (!m_HighPriorityQueue.empty() && highPriorityCount < expireRate) + { + currentExecInProgress = m_HighPriorityQueue.front(); + m_HighPriorityQueue.pop(); + highPriorityCount += 1; + } + // If high priority queue is empty or the count exceeds the expire rate, get medium priority message + else if (!m_MediumPriorityQueue.empty() && mediumPriorityCount < expireRate) + { + currentExecInProgress = m_MediumPriorityQueue.front(); + m_MediumPriorityQueue.pop(); + mediumPriorityCount += 1; + // Reset high priority count + highPriorityCount = 0; + } + // If medium priority queue is empty or the count exceeds the expire rate, get low priority message + else if (!m_LowPriorityQueue.empty()) + { + currentExecInProgress = m_LowPriorityQueue.front(); + m_LowPriorityQueue.pop(); + // Reset high and medium priority count + highPriorityCount = 0; + mediumPriorityCount = 0; + } + else + { + // Reset high and medium priority count + highPriorityCount = 0; + mediumPriorityCount = 0; + continue; + } + } + + // invoke the asynchronous execution method + auto inputTensors = std::get<0>(*currentExecInProgress); + auto outputTensors = std::get<1>(*currentExecInProgress); + auto cb = 
std::get<2>(*currentExecInProgress); + + // Get time at start of inference + HighResolutionClock startTime = armnn::GetTimeNow(); + + try // executing the inference + { + // Execute and populate the time at end of inference in the callback + Execute(inputTensors, outputTensors, workingMemHandleRef) == Status::Success ? + cb->Notify(Status::Success, std::make_pair(startTime, armnn::GetTimeNow())) : + cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow())); + } + catch (const RuntimeException& error) + { + cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow())); + } + } +} + void LoadedNetwork::EnqueueInput(const BindableLayer& layer, const ConstTensor& inputTensor, WorkingMemHandle& context) @@ -1096,6 +1249,7 @@ Status LoadedNetwork::Execute(const InputTensors& inputTensors, EnqueueOutput(*outputLayer, GetOutputTensor(outputLayer->GetBindingId(), outputTensors), workingMemHandle); } } + return executionSucceeded ? Status::Success : Status::Failure; } |