diff options
Diffstat (limited to 'src/armnn/Threadpool.cpp')
-rw-r--r-- | src/armnn/Threadpool.cpp | 206 |
1 files changed, 206 insertions, 0 deletions
diff --git a/src/armnn/Threadpool.cpp b/src/armnn/Threadpool.cpp new file mode 100644 index 0000000000..a23c1e2339 --- /dev/null +++ b/src/armnn/Threadpool.cpp @@ -0,0 +1,206 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + + +#include <armnn/Threadpool.hpp> + +#include <armnn/utility/Timer.hpp> + +namespace armnn +{ +namespace experimental +{ + +Threadpool::Threadpool(std::size_t numThreads, + IRuntime* runtimePtr, + std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles) + : m_RuntimePtr(runtimePtr) +{ + for (auto i = 0u; i < numThreads; ++i) + { + m_Threads.emplace_back(std::make_unique<std::thread>(&Threadpool::ProcessExecPriorities, this, i)); + } + + LoadMemHandles(memHandles); +} + +void Threadpool::LoadMemHandles(std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles) +{ + if (memHandles.size() == 0) + { + throw armnn::RuntimeException("Threadpool::UnloadMemHandles: Size of memHandles vector must be greater than 0"); + } + + if (memHandles.size() != m_Threads.size()) + { + throw armnn::RuntimeException( + "Threadpool::UnloadMemHandles: Size of memHandles vector must match the number of threads"); + } + + NetworkId networkId = memHandles[0]->GetNetworkId(); + for (uint32_t i = 1; i < memHandles.size(); ++i) + { + if (networkId != memHandles[i]->GetNetworkId()) + { + throw armnn::RuntimeException( + "Threadpool::UnloadMemHandles: All network ids must be identical in memHandles"); + } + } + + std::pair<NetworkId, std::vector<std::shared_ptr<IWorkingMemHandle>>> pair {networkId, memHandles}; + + m_WorkingMemHandleMap.insert(pair); +} + +void Threadpool::UnloadMemHandles(NetworkId networkId) +{ + if (m_WorkingMemHandleMap.find(networkId) != m_WorkingMemHandleMap.end()) + { + m_WorkingMemHandleMap.erase(networkId); + } + else + { + throw armnn::RuntimeException("Threadpool::UnloadMemHandles: Unknown NetworkId"); + } +} + +void Threadpool::Schedule(NetworkId networkId, + const InputTensors& inputTensors, + const OutputTensors& outputTensors, + const QosExecPriority priority, + std::shared_ptr<IAsyncExecutionCallback> cb) +{ + if (m_WorkingMemHandleMap.find(networkId) == m_WorkingMemHandleMap.end()) + { + throw armnn::RuntimeException("Threadpool::UnloadMemHandles: Unknown NetworkId"); + } + + // Group execution parameters so that they can be easily added to the queue + ExecutionTuple groupExecParams = std::make_tuple(networkId, inputTensors, outputTensors, cb); + + std::shared_ptr<ExecutionTuple> operation = std::make_shared<ExecutionTuple>(groupExecParams); + + // Add a message to the queue and notify the request thread + std::unique_lock<std::mutex> lock(m_ThreadPoolMutex); + switch (priority) + { + case QosExecPriority::High: + m_HighPriorityQueue.push(operation); + break; + case QosExecPriority::Low: + m_LowPriorityQueue.push(operation); + break; + case QosExecPriority::Medium: + default: + m_MediumPriorityQueue.push(operation); + } + m_ThreadPoolEvent.notify_one(); +} + +void Threadpool::TerminateThreadPool() noexcept +{ + { + std::unique_lock<std::mutex> threadPoolLock(m_ThreadPoolMutex); + m_TerminatePool = true; + } + + m_ThreadPoolEvent.notify_all(); + + for (auto &thread : m_Threads) + { + thread->join(); + } +} + +void Threadpool::ProcessExecPriorities(uint32_t index) +{ + int expireRate = EXPIRE_RATE; + int highPriorityCount = 0; + int mediumPriorityCount = 0; + + while (true) + { + std::shared_ptr<ExecutionTuple> currentExecInProgress(nullptr); + { + // Wait for a message to be added to the queue + // This is in a separate scope to minimise the lifetime of the lock + std::unique_lock<std::mutex> lock(m_ThreadPoolMutex); + + m_ThreadPoolEvent.wait(lock, + [=] + { + return m_TerminatePool || !m_HighPriorityQueue.empty() || + !m_MediumPriorityQueue.empty() || !m_LowPriorityQueue.empty(); + }); + + if (m_TerminatePool && m_HighPriorityQueue.empty() && m_MediumPriorityQueue.empty() && + m_LowPriorityQueue.empty()) + { + break; + } + + // Get the message to process from the front of each queue based on priority from high to low + // Get high priority first if it does not exceed the expire rate + if (!m_HighPriorityQueue.empty() && highPriorityCount < expireRate) + { + currentExecInProgress = m_HighPriorityQueue.front(); + m_HighPriorityQueue.pop(); + highPriorityCount += 1; + } + // If high priority queue is empty or the count exceeds the expire rate, get medium priority message + else if (!m_MediumPriorityQueue.empty() && mediumPriorityCount < expireRate) + { + currentExecInProgress = m_MediumPriorityQueue.front(); + m_MediumPriorityQueue.pop(); + mediumPriorityCount += 1; + // Reset high priority count + highPriorityCount = 0; + } + // If medium priority queue is empty or the count exceeds the expire rate, get low priority message + else if (!m_LowPriorityQueue.empty()) + { + currentExecInProgress = m_LowPriorityQueue.front(); + m_LowPriorityQueue.pop(); + // Reset high and medium priority count + highPriorityCount = 0; + mediumPriorityCount = 0; + } + else + { + // Reset high and medium priority count + highPriorityCount = 0; + mediumPriorityCount = 0; + continue; + } + } + + // invoke the asynchronous execution method + auto networkId = std::get<0>(*currentExecInProgress); + auto inputTensors = std::get<1>(*currentExecInProgress); + auto outputTensors = std::get<2>(*currentExecInProgress); + auto cb = std::get<3>(*currentExecInProgress); + + // Get time at start of inference + HighResolutionClock startTime = armnn::GetTimeNow(); + + try // executing the inference + { + IWorkingMemHandle& memHandle = *(m_WorkingMemHandleMap.at(networkId))[index]; + + // Execute and populate the time at end of inference in the callback + m_RuntimePtr->Execute(memHandle, inputTensors, outputTensors) == Status::Success ? + cb->Notify(Status::Success, std::make_pair(startTime, armnn::GetTimeNow())) : + cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow())); + } + catch (const RuntimeException &error) + { + cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow())); + } + } +} + +} // namespace experimental + +} // namespace armnn |