diff options
author | Finn Williams <Finn.Williams@arm.com> | 2021-06-09 17:07:33 +0100 |
---|---|---|
committer | Finn Williams <Finn.Williams@arm.com> | 2021-06-23 17:14:53 +0100 |
commit | f364d5391b08e9071cd965f5765385ec9156b652 (patch) | |
tree | 1ea93ed574a3eb51f5a1f4bb08dc1ad18aa1c6a2 /src/armnn/LoadedNetwork.cpp | |
parent | 7a00eaa6ecf121623823b1951c0e6c9093271adf (diff) | |
download | armnn-f364d5391b08e9071cd965f5765385ec9156b652.tar.gz |
IVGCVSW-6062 Rework the async threadpool
!android-nn-driver:5802
* Extract the threadpool from LoadedNetwork/Runtime
* Refactor the threadpool to handle multiple networks
* Trim IAsyncExecutionCallback and add an InferenceId to AsyncExecutionCallback
* Add AsyncCallbackManager class
Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Change-Id: I36aa2ad29c16bc10ee0706adfeb6b27f60012afb
Diffstat (limited to 'src/armnn/LoadedNetwork.cpp')
-rw-r--r-- | src/armnn/LoadedNetwork.cpp | 159 |
1 file changed, 3 insertions(+), 156 deletions(-)
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp index 17cc8dcc23..13beb13a07 100644 --- a/src/armnn/LoadedNetwork.cpp +++ b/src/armnn/LoadedNetwork.cpp @@ -6,7 +6,6 @@ #include "LoadedNetwork.hpp" #include "Layer.hpp" #include "Graph.hpp" -#include "Network.hpp" #include <Processes.hpp> #include "Profiling.hpp" #include "HeapProfiling.hpp" @@ -22,7 +21,6 @@ #include <backendsCommon/MemSyncWorkload.hpp> #include <fmt/format.h> -#include <armnn/utility/Timer.hpp> namespace armnn { @@ -83,8 +81,7 @@ void AddWorkloadStructure(std::unique_ptr<TimelineUtilityMethods>& timelineUtils std::unique_ptr<LoadedNetwork> LoadedNetwork::MakeLoadedNetwork(std::unique_ptr<IOptimizedNetwork> net, std::string& errorMessage, const INetworkProperties& networkProperties, - profiling::ProfilingService& profilingService, - const NetworkId networkIdOut) + profiling::ProfilingService& profilingService) { std::unique_ptr<LoadedNetwork> loadedNetwork; @@ -98,7 +95,7 @@ std::unique_ptr<LoadedNetwork> LoadedNetwork::MakeLoadedNetwork(std::unique_ptr< try { - loadedNetwork.reset(new LoadedNetwork(std::move(net), networkProperties, profilingService, networkIdOut)); + loadedNetwork.reset(new LoadedNetwork(std::move(net), networkProperties, profilingService)); } catch (const armnn::RuntimeException& error) { @@ -118,11 +115,9 @@ std::unique_ptr<LoadedNetwork> LoadedNetwork::MakeLoadedNetwork(std::unique_ptr< LoadedNetwork::LoadedNetwork(std::unique_ptr<IOptimizedNetwork> net, const INetworkProperties& networkProperties, - profiling::ProfilingService& profilingService, - const NetworkId networkId) : + profiling::ProfilingService& profilingService) : m_OptimizedNetwork(std::move(net)), m_NetworkProperties(networkProperties), - m_NetworkId(networkId), m_TensorHandleFactoryRegistry(), m_ProfilingService(profilingService) { @@ -304,13 +299,6 @@ LoadedNetwork::LoadedNetwork(std::unique_ptr<IOptimizedNetwork> net, { AllocateAndExecuteConstantWorkloads(); } - - // Create the thread 
pool which will have working memory handles assigned to each thread - // Should occur last so all factories and constant layer tensor handles are created - if (m_NetworkProperties.m_NumThreads > 0 && networkProperties.m_AsyncEnabled) - { - CreateThreadPool(m_NetworkProperties.m_NumThreads); - } } void LoadedNetwork::AllocateAndExecuteConstantWorkloads() @@ -856,147 +844,6 @@ bool LoadedNetwork::Execute(std::unique_ptr<TimelineUtilityMethods>& timelineUti return success; } -void LoadedNetwork::CreateThreadPool(std::size_t numThreads) -{ - - for (auto i = 0u; i < numThreads; ++i) - { - std::unique_ptr<IWorkingMemHandle> workingMemHandle = CreateWorkingMemHandle(m_NetworkId); - m_Threads.emplace_back( - std::make_unique<std::thread>( - &LoadedNetwork::ProcessExecPriorities, - this, - std::move(workingMemHandle) - ) - ); - } -} - -void LoadedNetwork::TerminateThreadPool() noexcept -{ - { - std::unique_lock<std::mutex> threadPoolLock(m_ThreadPoolMutex); - m_TerminatePool = true; - } - - m_ThreadPoolEvent.notify_all(); - - for (auto &thread : m_Threads) - { - thread->join(); - } -} - -void LoadedNetwork::Schedule(const InputTensors& inputTensors, - const OutputTensors& outputTensors, - const QosExecPriority priority, - std::shared_ptr<IAsyncExecutionCallback> cb) -{ - // Group execution parameters so that they can be easily added to the queue - ExecutionTuple groupExecParams = std::make_tuple(inputTensors, outputTensors, cb); - std::shared_ptr<ExecutionTuple> operation = make_shared<ExecutionTuple>(groupExecParams); - - // Add a message to the queue and notify the request thread - std::unique_lock<std::mutex> lock(m_ThreadPoolMutex); - switch (priority) { - case QosExecPriority::High: - m_HighPriorityQueue.push(operation); - break; - case QosExecPriority::Low: - m_LowPriorityQueue.push(operation); - break; - case QosExecPriority::Medium: - default: - m_MediumPriorityQueue.push(operation); - } - m_ThreadPoolEvent.notify_one(); -} - -void 
LoadedNetwork::ProcessExecPriorities(std::unique_ptr<IWorkingMemHandle> workingMemHandle) -{ - int expireRate = EXPIRE_RATE; - int highPriorityCount = 0; - int mediumPriorityCount = 0; - - IWorkingMemHandle& workingMemHandleRef = *workingMemHandle.get(); - - while (true) - { - std::shared_ptr<ExecutionTuple> currentExecInProgress(nullptr); - { - // Wait for a message to be added to the queue - // This is in a separate scope to minimise the lifetime of the lock - std::unique_lock<std::mutex> lock(m_ThreadPoolMutex); - - m_ThreadPoolEvent.wait(lock, - [=] { - return m_TerminatePool || !m_HighPriorityQueue.empty() || - !m_MediumPriorityQueue.empty() || !m_LowPriorityQueue.empty(); - }); - - if (m_TerminatePool && m_HighPriorityQueue.empty() && m_MediumPriorityQueue.empty() && - m_LowPriorityQueue.empty()) - { - break; - } - - // Get the message to process from the front of each queue based on priority from high to low - // Get high priority first if it does not exceed the expire rate - if (!m_HighPriorityQueue.empty() && highPriorityCount < expireRate) - { - currentExecInProgress = m_HighPriorityQueue.front(); - m_HighPriorityQueue.pop(); - highPriorityCount += 1; - } - // If high priority queue is empty or the count exceeds the expire rate, get medium priority message - else if (!m_MediumPriorityQueue.empty() && mediumPriorityCount < expireRate) - { - currentExecInProgress = m_MediumPriorityQueue.front(); - m_MediumPriorityQueue.pop(); - mediumPriorityCount += 1; - // Reset high priority count - highPriorityCount = 0; - } - // If medium priority queue is empty or the count exceeds the expire rate, get low priority message - else if (!m_LowPriorityQueue.empty()) - { - currentExecInProgress = m_LowPriorityQueue.front(); - m_LowPriorityQueue.pop(); - // Reset high and medium priority count - highPriorityCount = 0; - mediumPriorityCount = 0; - } - else - { - // Reset high and medium priority count - highPriorityCount = 0; - mediumPriorityCount = 0; - continue; - } - } - 
- // invoke the asynchronous execution method - auto inputTensors = std::get<0>(*currentExecInProgress); - auto outputTensors = std::get<1>(*currentExecInProgress); - auto cb = std::get<2>(*currentExecInProgress); - - // Get time at start of inference - HighResolutionClock startTime = armnn::GetTimeNow(); - - try // executing the inference - { - // Execute and populate the time at end of inference in the callback - Execute(inputTensors, outputTensors, workingMemHandleRef) == Status::Success ? - cb->Notify(Status::Success, std::make_pair(startTime, armnn::GetTimeNow())) : - cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow())); - } - catch (const RuntimeException& error) - { - cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow())); - } - } -} - void LoadedNetwork::EnqueueInput(const BindableLayer& layer, const ConstTensor& inputTensor, WorkingMemHandle& context) |