From f364d5391b08e9071cd965f5765385ec9156b652 Mon Sep 17 00:00:00 2001
From: Finn Williams
Date: Wed, 9 Jun 2021 17:07:33 +0100
Subject: IVGCVSW-6062 Rework the async threadpool

!android-nn-driver:5802

 * Extract the threadpool from LoadedNetwork/Runtime
 * Refactor the threadpool to handle multiple networks
 * Trim IAsyncExecutionCallback and add an InferenceId to AsyncExecutionCallback
 * Add AsyncCallbackManager class

Signed-off-by: Finn Williams
Change-Id: I36aa2ad29c16bc10ee0706adfeb6b27f60012afb
---
 Android.mk                                         |   1 +
 CMakeLists.txt                                     |   2 +
 include/armnn/IAsyncExecutionCallback.hpp          |  13 --
 include/armnn/IRuntime.hpp                         |  31 ++--
 include/armnn/IWorkingMemHandle.hpp                |   3 -
 include/armnn/Threadpool.hpp                       |  78 ++++++++
 src/armnn/AsyncExecutionCallback.cpp               |  38 ++--
 src/armnn/AsyncExecutionCallback.hpp               |  50 ++++-
 src/armnn/LoadedNetwork.cpp                        | 159 +---------------
 src/armnn/LoadedNetwork.hpp                        |  42 +----
 src/armnn/Runtime.cpp                              |  40 +---
 src/armnn/Runtime.hpp                              |  10 -
 src/armnn/Threadpool.cpp                           | 206 +++++++++++++++++++++
 src/armnn/WorkingMemHandle.cpp                     |   3 +-
 src/armnn/WorkingMemHandle.hpp                     |   7 +-
 .../test/StridedSliceAsyncEndToEndTest.hpp         |  31 ++--
 tests/ExecuteNetwork/ExecuteNetwork.cpp            |  27 ++-
 tests/InferenceModel.hpp                           |  38 ++--
 18 files changed, 434 insertions(+), 345 deletions(-)
 create mode 100644 include/armnn/Threadpool.hpp
 create mode 100644 src/armnn/Threadpool.cpp

diff --git a/Android.mk b/Android.mk
index aec32699b0..f3e4f40534 100644
--- a/Android.mk
+++ b/Android.mk
@@ -132,6 +132,7 @@ LOCAL_SRC_FILES := \
         src/armnn/SubgraphView.cpp \
         src/armnn/SubgraphViewSelector.cpp \
         src/armnn/Tensor.cpp \
+        src/armnn/Threadpool.cpp \
         src/armnn/TypesUtils.cpp \
         src/armnn/Utils.cpp \
         src/armnn/WallClockTimer.cpp \
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 426c089bd6..78c2f17e6d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -180,6 +180,7 @@ list(APPEND armnn_sources
     include/armnn/QuantizedLstmParams.hpp
     include/armnn/Tensor.hpp
     include/armnn/TensorFwd.hpp
+    include/armnn/Threadpool.hpp
     include/armnn/Types.hpp
     include/armnn/TypesUtils.hpp
     include/armnn/Utils.hpp
@@ -387,6 +388,7 @@ list(APPEND armnn_sources
     src/armnn/SubgraphViewSelector.cpp
     src/armnn/SubgraphViewSelector.hpp
     src/armnn/Tensor.cpp
+    src/armnn/Threadpool.cpp
     src/armnn/TypesUtils.cpp
     src/armnn/Utils.cpp
     src/armnn/WallClockTimer.cpp
diff --git a/include/armnn/IAsyncExecutionCallback.hpp b/include/armnn/IAsyncExecutionCallback.hpp
index 045ec4581f..3e0cacccee 100644
--- a/include/armnn/IAsyncExecutionCallback.hpp
+++ b/include/armnn/IAsyncExecutionCallback.hpp
@@ -23,19 +23,6 @@ public:

     // Notify the AsyncExecutionCallback object of the armnn execution status
     virtual void Notify(armnn::Status status, InferenceTimingPair timeTaken) = 0;
-
-    // Block the calling thread until the AsyncExecutionCallback object allows it to proceed
-    virtual void Wait() const = 0;
-
-    // Retrieve the ArmNN Status from the AsyncExecutionCallback that has been notified
-    virtual armnn::Status GetStatus() const = 0;
-
-    // Retrieve the start time before executing the inference
-    virtual HighResolutionClock GetStartTime() const = 0;
-
-    // Retrieve the time after executing the inference
-    virtual HighResolutionClock GetEndTime() const = 0;
-
 };

 } // experimental
diff --git a/include/armnn/IRuntime.hpp b/include/armnn/IRuntime.hpp
index bfc13c9c01..f88b6b664f 100644
--- a/include/armnn/IRuntime.hpp
+++ b/include/armnn/IRuntime.hpp
@@ -32,24 +32,34 @@ struct INetworkProperties
     ARMNN_DEPRECATED_MSG("Please use INetworkProperties constructor with MemorySource
argument") INetworkProperties(bool importEnabled = false, bool exportEnabled = false, - bool asyncEnabled = false, - size_t numThreads = 1) + bool asyncEnabled = false) : m_ImportEnabled(importEnabled) , m_ExportEnabled(exportEnabled) , m_AsyncEnabled(asyncEnabled) - , m_NumThreads(numThreads) , m_InputSource(m_ImportEnabled ? MemorySource::Malloc : MemorySource::Undefined) , m_OutputSource(m_ExportEnabled ? MemorySource::Malloc : MemorySource::Undefined) {} + ARMNN_DEPRECATED_MSG("Please use INetworkProperties constructor without numThreads argument") INetworkProperties(bool asyncEnabled, MemorySource m_InputSource, MemorySource m_OutputSource, - size_t numThreads = 1) + size_t numThreads) + : m_ImportEnabled(m_InputSource != MemorySource::Undefined) + , m_ExportEnabled(m_OutputSource != MemorySource::Undefined) + , m_AsyncEnabled(asyncEnabled) + , m_InputSource(m_InputSource) + , m_OutputSource(m_OutputSource) + { + armnn::IgnoreUnused(numThreads); + } + + INetworkProperties(bool asyncEnabled, + MemorySource m_InputSource, + MemorySource m_OutputSource) : m_ImportEnabled(m_InputSource != MemorySource::Undefined) , m_ExportEnabled(m_OutputSource != MemorySource::Undefined) , m_AsyncEnabled(asyncEnabled) - , m_NumThreads(numThreads) , m_InputSource(m_InputSource) , m_OutputSource(m_OutputSource) {} @@ -60,7 +70,6 @@ struct INetworkProperties const bool m_ExportEnabled; const bool m_AsyncEnabled; - const size_t m_NumThreads; const MemorySource m_InputSource; const MemorySource m_OutputSource; @@ -191,16 +200,6 @@ public: const InputTensors& inputTensors, const OutputTensors& outputTensors); - /// This is an experimental function - /// Schedule a thread safe execution by taking the input tensors and an execution priority for Quality of Service. - /// The output tensors will then be filled and the callback object will notify that the execution has either - /// succeeded or failed. - void Schedule(NetworkId networkId, - const InputTensors& inputTensors, - const OutputTensors& outputTensors, - const QosExecPriority priority, - std::shared_ptr callback); - /// Unloads a network from the IRuntime. /// At the moment this only removes the network from the m_Impl->m_Network. /// This might need more work in the future to be AndroidNN compliant. diff --git a/include/armnn/IWorkingMemHandle.hpp b/include/armnn/IWorkingMemHandle.hpp index 6fb2f9fe5f..171fa3d81c 100644 --- a/include/armnn/IWorkingMemHandle.hpp +++ b/include/armnn/IWorkingMemHandle.hpp @@ -25,9 +25,6 @@ public: /// Returns the NetworkId of the Network that this IWorkingMemHandle works with. virtual NetworkId GetNetworkId() = 0; - /// Returns the InferenceId of the Inference that this IWorkingMemHandle works with. - virtual profiling::ProfilingGuid GetInferenceId() = 0; - /// Allocate the backing memory required for execution. If this is not called, then allocation will be /// deferred to execution time. The mutex must be locked. virtual void Allocate() = 0; diff --git a/include/armnn/Threadpool.hpp b/include/armnn/Threadpool.hpp new file mode 100644 index 0000000000..e2458dbb65 --- /dev/null +++ b/include/armnn/Threadpool.hpp @@ -0,0 +1,78 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. 
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <armnn/IAsyncExecutionCallback.hpp>
+#include <armnn/IWorkingMemHandle.hpp>
+
+#include "INetwork.hpp"
+#include "IRuntime.hpp"
+
+#include <condition_variable>
+#include <queue>
+#include <mutex>
+#include <thread>
+#include <unordered_map>
+
+namespace armnn
+{
+namespace experimental
+{
+class Threadpool
+{
+public:
+    Threadpool(std::size_t numThreads,
+               IRuntime* runtimePtr,
+               std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles);
+
+    ~Threadpool()
+    {
+        TerminateThreadPool();
+    }
+
+    void LoadMemHandles(std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles);
+    void UnloadMemHandles(NetworkId networkId);
+
+    /// Schedule an asynchronous execution on the loaded network
+    void Schedule(NetworkId networkId,
+                  const InputTensors &inputTensors,
+                  const OutputTensors &outputTensors,
+                  const QosExecPriority priority,
+                  std::shared_ptr<IAsyncExecutionCallback> cb);
+
+    void TerminateThreadPool() noexcept;
+
+private:
+    using ExecutionTuple = std::tuple<NetworkId,
+                                      InputTensors,
+                                      OutputTensors,
+                                      std::shared_ptr<IAsyncExecutionCallback>>;
+
+    using ExecutionQueue = std::queue<std::shared_ptr<ExecutionTuple>>;
+
+    void ProcessExecPriorities(uint32_t index);
+
+    IRuntime* m_RuntimePtr;
+
+    ExecutionQueue m_HighPriorityQueue;
+    ExecutionQueue m_MediumPriorityQueue;
+    ExecutionQueue m_LowPriorityQueue;
+
+    // Condition Variables require mutex which will guard the shared state.
+    // Has an event happened? Stop signal for example
+    std::condition_variable m_ThreadPoolEvent;
+    std::mutex m_ThreadPoolMutex;
+
+    // The shared state for conditional variable
+    bool m_TerminatePool = false;
+
+    std::unordered_map<NetworkId, std::vector<std::shared_ptr<IWorkingMemHandle>>> m_WorkingMemHandleMap;
+    std::vector<std::unique_ptr<std::thread>> m_Threads;
+};
+
+} // namespace experimental
+
+} // namespace armnn
diff --git a/src/armnn/AsyncExecutionCallback.cpp b/src/armnn/AsyncExecutionCallback.cpp
index c44808918d..2973e2d891 100644
--- a/src/armnn/AsyncExecutionCallback.cpp
+++ b/src/armnn/AsyncExecutionCallback.cpp
@@ -15,43 +15,53 @@ void AsyncExecutionCallback::Notify(armnn::Status status, InferenceTimingPair ti
 {
     {
         std::lock_guard<std::mutex> hold(m_Mutex);
-        if (m_Notified)
-        {
-            return;
-        }
         // store results and mark as notified
         m_Status = status;
         m_StartTime = timeTaken.first;
         m_EndTime = timeTaken.second;
-        m_Notified = true;
+        m_NotificationQueue.push(m_InferenceId);
     }
     m_Condition.notify_all();
 }

-void AsyncExecutionCallback::Wait() const
-{
-    std::unique_lock<std::mutex> lock(m_Mutex);
-    m_Condition.wait(lock, [this] { return m_Notified; });
-}
-
 armnn::Status AsyncExecutionCallback::GetStatus() const
 {
-    Wait();
     return m_Status;
 }

 HighResolutionClock AsyncExecutionCallback::GetStartTime() const
 {
-    Wait();
     return m_StartTime;
 }

 HighResolutionClock AsyncExecutionCallback::GetEndTime() const
 {
-    Wait();
     return m_EndTime;
 }

+std::shared_ptr<AsyncExecutionCallback> AsyncCallbackManager::GetNewCallback()
+{
+    auto cb = std::make_unique<AsyncExecutionCallback>(m_NotificationQueue, m_Mutex, m_Condition);
+    InferenceId id = cb->GetInferenceId();
+    m_Callbacks.insert({id, std::move(cb)});
+
+    return m_Callbacks.at(id);
+}
+
+std::shared_ptr<AsyncExecutionCallback> AsyncCallbackManager::GetNotifiedCallback()
+{
+    std::unique_lock<std::mutex> lock(m_Mutex);
+
+    m_Condition.wait(lock, [this] { return !m_NotificationQueue.empty(); });
+
+    InferenceId id = m_NotificationQueue.front();
+    m_NotificationQueue.pop();
+
+    std::shared_ptr<AsyncExecutionCallback> callback = m_Callbacks.at(id);
+    m_Callbacks.erase(id);
+    return callback;
+}
+
 } // namespace experimental
 } // namespace armnn
\ No newline at end of file
diff --git a/src/armnn/AsyncExecutionCallback.hpp b/src/armnn/AsyncExecutionCallback.hpp
index c17b839748..2ff73b3efb 100644
--- a/src/armnn/AsyncExecutionCallback.hpp
+++ b/src/armnn/AsyncExecutionCallback.hpp
@@ -6,11 +6,14 @@
 #pragma once

 #include
+#include
 #include
-#include
 #include
 #include
+#include
+#include

 namespace armnn
namespace experimental { +using InferenceId = uint64_t; class AsyncExecutionCallback final : public IAsyncExecutionCallback { +private: + static InferenceId nextID; + public: - AsyncExecutionCallback() + AsyncExecutionCallback(std::queue& notificationQueue, + std::mutex& mutex, + std::condition_variable& condition) + : m_NotificationQueue(notificationQueue) + , m_Mutex(mutex) + , m_Condition(condition) + , m_InferenceId(++nextID) {} + ~AsyncExecutionCallback() {} void Notify(armnn::Status status, InferenceTimingPair timeTaken); - void Wait() const; + + InferenceId GetInferenceId() + { + return m_InferenceId; + } armnn::Status GetStatus() const; HighResolutionClock GetStartTime() const; HighResolutionClock GetEndTime() const; private: - mutable std::mutex m_Mutex; - mutable std::condition_variable m_Condition; + std::queue& m_NotificationQueue; + std::mutex& m_Mutex; + std::condition_variable& m_Condition; HighResolutionClock m_StartTime; HighResolutionClock m_EndTime; - armnn::Status m_Status = Status::Failure; - bool m_Notified = false; + armnn::Status m_Status = Status::Failure; + InferenceId m_InferenceId; +}; +InferenceId AsyncExecutionCallback::nextID = 0u; + +// Manager to create and monitor AsyncExecutionCallbacks +// GetNewCallback will create a callback for use in Threadpool::Schedule +// GetNotifiedCallback will return the first callback to be notified (finished execution) +class AsyncCallbackManager +{ +public: + std::shared_ptr GetNewCallback(); + std::shared_ptr GetNotifiedCallback(); + +private: + std::mutex m_Mutex; + std::condition_variable m_Condition; + std::unordered_map> m_Callbacks; + std::queue m_NotificationQueue; }; } // namespace experimental diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp index 17cc8dcc23..13beb13a07 100644 --- a/src/armnn/LoadedNetwork.cpp +++ b/src/armnn/LoadedNetwork.cpp @@ -6,7 +6,6 @@ #include "LoadedNetwork.hpp" #include "Layer.hpp" #include "Graph.hpp" -#include "Network.hpp" #include #include "Profiling.hpp" #include "HeapProfiling.hpp" @@ -22,7 +21,6 @@ #include #include -#include namespace armnn { @@ -83,8 +81,7 @@ void AddWorkloadStructure(std::unique_ptr& timelineUtils std::unique_ptr LoadedNetwork::MakeLoadedNetwork(std::unique_ptr net, std::string& errorMessage, const INetworkProperties& networkProperties, - profiling::ProfilingService& profilingService, - const NetworkId networkIdOut) + profiling::ProfilingService& profilingService) { std::unique_ptr loadedNetwork; @@ -98,7 +95,7 @@ std::unique_ptr LoadedNetwork::MakeLoadedNetwork(std::unique_ptr< try { - loadedNetwork.reset(new LoadedNetwork(std::move(net), networkProperties, profilingService, networkIdOut)); + loadedNetwork.reset(new LoadedNetwork(std::move(net), networkProperties, profilingService)); } catch (const armnn::RuntimeException& error) { @@ -118,11 +115,9 @@ std::unique_ptr LoadedNetwork::MakeLoadedNetwork(std::unique_ptr< LoadedNetwork::LoadedNetwork(std::unique_ptr net, const INetworkProperties& networkProperties, - profiling::ProfilingService& profilingService, - const NetworkId networkId) : + profiling::ProfilingService& profilingService) : m_OptimizedNetwork(std::move(net)), m_NetworkProperties(networkProperties), - m_NetworkId(networkId), m_TensorHandleFactoryRegistry(), m_ProfilingService(profilingService) { @@ -304,13 +299,6 @@ LoadedNetwork::LoadedNetwork(std::unique_ptr net, { AllocateAndExecuteConstantWorkloads(); } - - // Create the thread pool which will have working memory handles assigned to each thread - // Should occur last so 
all factories and constant layer tensor handles are created - if (m_NetworkProperties.m_NumThreads > 0 && networkProperties.m_AsyncEnabled) - { - CreateThreadPool(m_NetworkProperties.m_NumThreads); - } } void LoadedNetwork::AllocateAndExecuteConstantWorkloads() @@ -856,147 +844,6 @@ bool LoadedNetwork::Execute(std::unique_ptr& timelineUti return success; } -void LoadedNetwork::CreateThreadPool(std::size_t numThreads) -{ - - for (auto i = 0u; i < numThreads; ++i) - { - std::unique_ptr workingMemHandle = CreateWorkingMemHandle(m_NetworkId); - m_Threads.emplace_back( - std::make_unique( - &LoadedNetwork::ProcessExecPriorities, - this, - std::move(workingMemHandle) - ) - ); - } -} - -void LoadedNetwork::TerminateThreadPool() noexcept -{ - { - std::unique_lock threadPoolLock(m_ThreadPoolMutex); - m_TerminatePool = true; - } - - m_ThreadPoolEvent.notify_all(); - - for (auto &thread : m_Threads) - { - thread->join(); - } -} - -void LoadedNetwork::Schedule(const InputTensors& inputTensors, - const OutputTensors& outputTensors, - const QosExecPriority priority, - std::shared_ptr cb) -{ - // Group execution parameters so that they can be easily added to the queue - ExecutionTuple groupExecParams = std::make_tuple(inputTensors, outputTensors, cb); - std::shared_ptr operation = make_shared(groupExecParams); - - // Add a message to the queue and notify the request thread - std::unique_lock lock(m_ThreadPoolMutex); - switch (priority) { - case QosExecPriority::High: - m_HighPriorityQueue.push(operation); - break; - case QosExecPriority::Low: - m_LowPriorityQueue.push(operation); - break; - case QosExecPriority::Medium: - default: - m_MediumPriorityQueue.push(operation); - } - m_ThreadPoolEvent.notify_one(); -} - -void LoadedNetwork::ProcessExecPriorities(std::unique_ptr workingMemHandle) -{ - int expireRate = EXPIRE_RATE; - int highPriorityCount = 0; - int mediumPriorityCount = 0; - - IWorkingMemHandle& workingMemHandleRef = *workingMemHandle.get(); - - while (true) - { - std::shared_ptr currentExecInProgress(nullptr); - { - // Wait for a message to be added to the queue - // This is in a separate scope to minimise the lifetime of the lock - std::unique_lock lock(m_ThreadPoolMutex); - - m_ThreadPoolEvent.wait(lock, - [=] { - return m_TerminatePool || !m_HighPriorityQueue.empty() || - !m_MediumPriorityQueue.empty() || !m_LowPriorityQueue.empty(); - }); - - if (m_TerminatePool && m_HighPriorityQueue.empty() && m_MediumPriorityQueue.empty() && - m_LowPriorityQueue.empty()) - { - break; - } - - // Get the message to process from the front of each queue based on priority from high to low - // Get high priority first if it does not exceed the expire rate - if (!m_HighPriorityQueue.empty() && highPriorityCount < expireRate) - { - currentExecInProgress = m_HighPriorityQueue.front(); - m_HighPriorityQueue.pop(); - highPriorityCount += 1; - } - // If high priority queue is empty or the count exceeds the expire rate, get medium priority message - else if (!m_MediumPriorityQueue.empty() && mediumPriorityCount < expireRate) - { - currentExecInProgress = m_MediumPriorityQueue.front(); - m_MediumPriorityQueue.pop(); - mediumPriorityCount += 1; - // Reset high priority count - highPriorityCount = 0; - } - // If medium priority queue is empty or the count exceeds the expire rate, get low priority message - else if (!m_LowPriorityQueue.empty()) - { - currentExecInProgress = m_LowPriorityQueue.front(); - m_LowPriorityQueue.pop(); - // Reset high and medium priority count - highPriorityCount = 0; - mediumPriorityCount = 0; 
- } - else - { - // Reset high and medium priority count - highPriorityCount = 0; - mediumPriorityCount = 0; - continue; - } - } - - // invoke the asynchronous execution method - auto inputTensors = std::get<0>(*currentExecInProgress); - auto outputTensors = std::get<1>(*currentExecInProgress); - auto cb = std::get<2>(*currentExecInProgress); - - // Get time at start of inference - HighResolutionClock startTime = armnn::GetTimeNow(); - - try // executing the inference - { - // Execute and populate the time at end of inference in the callback - Execute(inputTensors, outputTensors, workingMemHandleRef) == Status::Success ? - cb->Notify(Status::Success, std::make_pair(startTime, armnn::GetTimeNow())) : - cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow())); - } - catch (const RuntimeException& error) - { - cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow())); - } - } -} - void LoadedNetwork::EnqueueInput(const BindableLayer& layer, const ConstTensor& inputTensor, WorkingMemHandle& context) diff --git a/src/armnn/LoadedNetwork.hpp b/src/armnn/LoadedNetwork.hpp index c85e82bbdd..360ad91170 100644 --- a/src/armnn/LoadedNetwork.hpp +++ b/src/armnn/LoadedNetwork.hpp @@ -37,16 +37,9 @@ class LoadedNetwork public: using WorkloadQueue = std::vector>; - using ExecutionTuple = std::tuple>; - - using ExecutionQueue = std::queue>; - ~LoadedNetwork() { FreeWorkingMemory(); - TerminateThreadPool(); } /// Create a new unique WorkingMemHandle object. Create multiple handles if you wish to have @@ -64,17 +57,10 @@ public: const OutputTensors& outputTensors, IWorkingMemHandle& workingMemHandle); - /// Schedule an asynchronous execution on the loaded network - void Schedule(const InputTensors& inputTensors, - const OutputTensors& outputTensors, - const QosExecPriority priority, - std::shared_ptr cb); - static std::unique_ptr MakeLoadedNetwork(std::unique_ptr net, std::string& errorMessage, const INetworkProperties& networkProperties, - profiling::ProfilingService& profilingService, - const NetworkId networkIdOut); + profiling::ProfilingService& profilingService); // NOTE we return by reference as the purpose of this method is only to provide // access to the private m_Profiler and in theory we should not need to increment @@ -108,8 +94,7 @@ private: LoadedNetwork(std::unique_ptr net, const INetworkProperties& networkProperties, - profiling::ProfilingService& profilingService, - const NetworkId networkIdOut); + profiling::ProfilingService& profilingService); void EnqueueInput(const BindableLayer& layer, ITensorHandle* tensorHandle, const TensorInfo& tensorInfo); @@ -119,15 +104,9 @@ private: void EnqueueOutput(const BindableLayer& layer, const Tensor& outputTensor, WorkingMemHandle& handle); - void ProcessExecPriorities(std::unique_ptr workingMemHandle); - bool Execute(std::unique_ptr& timelineUtils, profiling::ProfilingGuid inferenceGuid); - void CreateThreadPool(std::size_t numThreads); - - void TerminateThreadPool() noexcept; - const IWorkloadFactory& GetWorkloadFactory(const Layer& layer) const; using BackendPtrMap = std::unordered_map; @@ -146,25 +125,8 @@ private: bool m_IsWorkingMemAllocated = false; - std::vector> m_Threads; - std::stack m_WorkingMemHandles; - - ExecutionQueue m_HighPriorityQueue; - ExecutionQueue m_MediumPriorityQueue; - ExecutionQueue m_LowPriorityQueue; - - // Condition Variables require mutex which will guard the shared state. - // Has an event happened? 
Stop signal for example - std::condition_variable m_ThreadPoolEvent; - std::mutex m_ThreadPoolMutex; - - // The shared state for conditional variable - bool m_TerminatePool = false; - INetworkProperties m_NetworkProperties; - const NetworkId m_NetworkId; - TensorHandleFactoryRegistry m_TensorHandleFactoryRegistry; profiling::ProfilingService& m_ProfilingService; diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp index 374064e408..f16d186191 100644 --- a/src/armnn/Runtime.cpp +++ b/src/armnn/Runtime.cpp @@ -89,15 +89,6 @@ Status IRuntime::Execute(IWorkingMemHandle& workingMemHandle, return pRuntimeImpl->Execute(workingMemHandle, inputTensors, outputTensors); } -void IRuntime::Schedule(NetworkId networkId, - const InputTensors& inputTensors, - const OutputTensors& outputTensors, - const QosExecPriority priority, - std::shared_ptr cb) -{ - pRuntimeImpl->Schedule(networkId, inputTensors, outputTensors, priority, cb); -} - Status IRuntime::UnloadNetwork(NetworkId networkId) { return pRuntimeImpl->UnloadNetwork(networkId); @@ -160,8 +151,7 @@ Status RuntimeImpl::LoadNetwork(NetworkId& networkIdOut, std::unique_ptr(rawNetwork), errorMessage, networkProperties, - m_ProfilingService, - networkIdOut); + m_ProfilingService); if (!loadedNetwork) { @@ -460,34 +450,6 @@ Status RuntimeImpl::Execute(IWorkingMemHandle& iWorkingMemHandle, return loadedNetwork->Execute(inputTensors, outputTensors, iWorkingMemHandle); } -void RuntimeImpl::Schedule(NetworkId networkId, - const InputTensors& inputTensors, - const OutputTensors& outputTensors, - const QosExecPriority priority, - std::shared_ptr callback) -{ - LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId); - - if (!loadedNetwork) - { - throw armnn::Exception( - "Network with ID of " + std::to_string(networkId) + " does not exist \n" - ); - } - if (!loadedNetwork->IsAsyncEnabled()) - { - throw armnn::Exception( - "Attempting to schedule Network " + std::to_string(networkId) + " when it is not async enabled \n" - ); - } - - ProfilerManager::GetInstance().RegisterProfiler(loadedNetwork->GetProfiler().get()); - - ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "Schedule"); - - loadedNetwork->Schedule(inputTensors, outputTensors, priority, callback); -} - /// Create a new unique WorkingMemHandle object. Create multiple handles if you wish to have /// overlapped Execution by calling this function from different threads. std::unique_ptr RuntimeImpl::CreateWorkingMemHandle(NetworkId networkId) diff --git a/src/armnn/Runtime.hpp b/src/armnn/Runtime.hpp index 55a4accf67..7a80acd73e 100644 --- a/src/armnn/Runtime.hpp +++ b/src/armnn/Runtime.hpp @@ -60,16 +60,6 @@ public: const InputTensors& inputTensors, const OutputTensors& outputTensors); - /// This is an experimental function. - /// Schedule a thread safe execution by taking the input tensors and an execution priority for Quality of Service. - /// The output tensors will then be filled and the callback object will notify that the execution has either - /// succeeded or failed. - void Schedule(NetworkId networkId, - const InputTensors& inputTensors, - const OutputTensors& outputTensors, - const QosExecPriority priority, - std::shared_ptr callback); - /// This is an experimental function. /// Evaluates a network using input in inputTensors and outputs filled into outputTensors. /// This function performs a thread safe execution of the network. Returns once execution is complete. 
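With Schedule() removed from IRuntime and RuntimeImpl, asynchronous inferences are now driven through the standalone armnn::experimental::Threadpool introduced below. A minimal usage sketch under stated assumptions: the optimized network and tensors are prepared elsewhere, and AsyncCallbackManager comes from the private src/armnn/AsyncExecutionCallback.hpp header, which is on the include path for the in-tree tests but is not part of the public API:

// Sketch only: drive one async inference through the new Threadpool.
// Assumes optNet, inputTensors and outputTensors were built elsewhere.
#include <armnn/IRuntime.hpp>
#include <armnn/Threadpool.hpp>
#include <AsyncExecutionCallback.hpp> // private header, as included by the in-tree tests

using namespace armnn;
using namespace armnn::experimental;

void RunOnThreadpool(IRuntimePtr& runtime, IOptimizedNetworkPtr optNet,
                     const InputTensors& inputTensors, const OutputTensors& outputTensors)
{
    const std::size_t numThreads = 3;
    NetworkId networkId{};
    std::string errorMessage;

    // Async networks no longer carry a thread count; the pool is external.
    INetworkProperties networkProperties(true, MemorySource::Undefined, MemorySource::Undefined);
    runtime->LoadNetwork(networkId, std::move(optNet), errorMessage, networkProperties);

    // One working memory handle per pool thread, all for the same network.
    std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles;
    for (std::size_t i = 0; i < numThreads; ++i)
    {
        memHandles.emplace_back(runtime->CreateWorkingMemHandle(networkId));
    }

    Threadpool threadpool(numThreads, runtime.get(), memHandles);
    AsyncCallbackManager callbackManager;

    threadpool.Schedule(networkId, inputTensors, outputTensors,
                        QosExecPriority::Medium, callbackManager.GetNewCallback());

    // Blocks until a scheduled inference notifies its callback.
    auto cb = callbackManager.GetNotifiedCallback();
    if (cb->GetStatus() != Status::Success)
    {
        // handle the failed inference
    }
}

CreateWorkingMemHandle() returns unique ownership, which converts to the shared_ptr<IWorkingMemHandle> elements the pool stores; the same handles are reused for every inference scheduled against that network.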
diff --git a/src/armnn/Threadpool.cpp b/src/armnn/Threadpool.cpp new file mode 100644 index 0000000000..a23c1e2339 --- /dev/null +++ b/src/armnn/Threadpool.cpp @@ -0,0 +1,206 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + + +#include + +#include + +namespace armnn +{ +namespace experimental +{ + +Threadpool::Threadpool(std::size_t numThreads, + IRuntime* runtimePtr, + std::vector> memHandles) + : m_RuntimePtr(runtimePtr) +{ + for (auto i = 0u; i < numThreads; ++i) + { + m_Threads.emplace_back(std::make_unique(&Threadpool::ProcessExecPriorities, this, i)); + } + + LoadMemHandles(memHandles); +} + +void Threadpool::LoadMemHandles(std::vector> memHandles) +{ + if (memHandles.size() == 0) + { + throw armnn::RuntimeException("Threadpool::UnloadMemHandles: Size of memHandles vector must be greater than 0"); + } + + if (memHandles.size() != m_Threads.size()) + { + throw armnn::RuntimeException( + "Threadpool::UnloadMemHandles: Size of memHandles vector must match the number of threads"); + } + + NetworkId networkId = memHandles[0]->GetNetworkId(); + for (uint32_t i = 1; i < memHandles.size(); ++i) + { + if (networkId != memHandles[i]->GetNetworkId()) + { + throw armnn::RuntimeException( + "Threadpool::UnloadMemHandles: All network ids must be identical in memHandles"); + } + } + + std::pair>> pair {networkId, memHandles}; + + m_WorkingMemHandleMap.insert(pair); +} + +void Threadpool::UnloadMemHandles(NetworkId networkId) +{ + if (m_WorkingMemHandleMap.find(networkId) != m_WorkingMemHandleMap.end()) + { + m_WorkingMemHandleMap.erase(networkId); + } + else + { + throw armnn::RuntimeException("Threadpool::UnloadMemHandles: Unknown NetworkId"); + } +} + +void Threadpool::Schedule(NetworkId networkId, + const InputTensors& inputTensors, + const OutputTensors& outputTensors, + const QosExecPriority priority, + std::shared_ptr cb) +{ + if (m_WorkingMemHandleMap.find(networkId) == m_WorkingMemHandleMap.end()) + { + throw armnn::RuntimeException("Threadpool::UnloadMemHandles: Unknown NetworkId"); + } + + // Group execution parameters so that they can be easily added to the queue + ExecutionTuple groupExecParams = std::make_tuple(networkId, inputTensors, outputTensors, cb); + + std::shared_ptr operation = std::make_shared(groupExecParams); + + // Add a message to the queue and notify the request thread + std::unique_lock lock(m_ThreadPoolMutex); + switch (priority) + { + case QosExecPriority::High: + m_HighPriorityQueue.push(operation); + break; + case QosExecPriority::Low: + m_LowPriorityQueue.push(operation); + break; + case QosExecPriority::Medium: + default: + m_MediumPriorityQueue.push(operation); + } + m_ThreadPoolEvent.notify_one(); +} + +void Threadpool::TerminateThreadPool() noexcept +{ + { + std::unique_lock threadPoolLock(m_ThreadPoolMutex); + m_TerminatePool = true; + } + + m_ThreadPoolEvent.notify_all(); + + for (auto &thread : m_Threads) + { + thread->join(); + } +} + +void Threadpool::ProcessExecPriorities(uint32_t index) +{ + int expireRate = EXPIRE_RATE; + int highPriorityCount = 0; + int mediumPriorityCount = 0; + + while (true) + { + std::shared_ptr currentExecInProgress(nullptr); + { + // Wait for a message to be added to the queue + // This is in a separate scope to minimise the lifetime of the lock + std::unique_lock lock(m_ThreadPoolMutex); + + m_ThreadPoolEvent.wait(lock, + [=] + { + return m_TerminatePool || !m_HighPriorityQueue.empty() || + !m_MediumPriorityQueue.empty() || !m_LowPriorityQueue.empty(); + }); + + if 
(m_TerminatePool && m_HighPriorityQueue.empty() && m_MediumPriorityQueue.empty() && + m_LowPriorityQueue.empty()) + { + break; + } + + // Get the message to process from the front of each queue based on priority from high to low + // Get high priority first if it does not exceed the expire rate + if (!m_HighPriorityQueue.empty() && highPriorityCount < expireRate) + { + currentExecInProgress = m_HighPriorityQueue.front(); + m_HighPriorityQueue.pop(); + highPriorityCount += 1; + } + // If high priority queue is empty or the count exceeds the expire rate, get medium priority message + else if (!m_MediumPriorityQueue.empty() && mediumPriorityCount < expireRate) + { + currentExecInProgress = m_MediumPriorityQueue.front(); + m_MediumPriorityQueue.pop(); + mediumPriorityCount += 1; + // Reset high priority count + highPriorityCount = 0; + } + // If medium priority queue is empty or the count exceeds the expire rate, get low priority message + else if (!m_LowPriorityQueue.empty()) + { + currentExecInProgress = m_LowPriorityQueue.front(); + m_LowPriorityQueue.pop(); + // Reset high and medium priority count + highPriorityCount = 0; + mediumPriorityCount = 0; + } + else + { + // Reset high and medium priority count + highPriorityCount = 0; + mediumPriorityCount = 0; + continue; + } + } + + // invoke the asynchronous execution method + auto networkId = std::get<0>(*currentExecInProgress); + auto inputTensors = std::get<1>(*currentExecInProgress); + auto outputTensors = std::get<2>(*currentExecInProgress); + auto cb = std::get<3>(*currentExecInProgress); + + // Get time at start of inference + HighResolutionClock startTime = armnn::GetTimeNow(); + + try // executing the inference + { + IWorkingMemHandle& memHandle = *(m_WorkingMemHandleMap.at(networkId))[index]; + + // Execute and populate the time at end of inference in the callback + m_RuntimePtr->Execute(memHandle, inputTensors, outputTensors) == Status::Success ? + cb->Notify(Status::Success, std::make_pair(startTime, armnn::GetTimeNow())) : + cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow())); + } + catch (const RuntimeException &error) + { + cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow())); + } + } +} + +} // namespace experimental + +} // namespace armnn diff --git a/src/armnn/WorkingMemHandle.cpp b/src/armnn/WorkingMemHandle.cpp index 94d796eced..1dcaa3853a 100644 --- a/src/armnn/WorkingMemHandle.cpp +++ b/src/armnn/WorkingMemHandle.cpp @@ -26,8 +26,7 @@ WorkingMemHandle::WorkingMemHandle( m_MemoryManagers(memoryManagers), m_OwnedTensorHandles(std::move(ownedTensorHandles)), m_IsAllocated(false), - m_Mutex(), - m_InferenceId(profiling::ProfilingService::GetNextGuid()) + m_Mutex() { } diff --git a/src/armnn/WorkingMemHandle.hpp b/src/armnn/WorkingMemHandle.hpp index 5ccb2b2342..5e3fd66299 100644 --- a/src/armnn/WorkingMemHandle.hpp +++ b/src/armnn/WorkingMemHandle.hpp @@ -13,6 +13,7 @@ #include #include +#include namespace armnn { @@ -38,10 +39,7 @@ public: return m_NetworkId; } - profiling::ProfilingGuid GetInferenceId() override - { - return m_InferenceId; - } + /// Allocate the backing memory required for execution. If this is not called, then allocation will be /// deferred to execution time. The mutex must be locked. 
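Note that the inference id removed from the working memory handle above now lives on the callback instead, so completed work is matched back to its request via AsyncExecutionCallback::GetInferenceId(), as the ExecuteNetwork changes further down illustrate. Continuing the sketch shown after the Runtime.hpp changes (the helper name and containers are illustrative):

// Sketch only: match results to requests by the callback's InferenceId now
// that the working memory handle no longer carries one.
#include <armnn/utility/IgnoreUnused.hpp>
#include <unordered_map>

void CollectResults(Threadpool& threadpool, AsyncCallbackManager& callbackManager, NetworkId networkId,
                    const std::vector<InputTensors>& inputs, std::vector<OutputTensors>& outputs)
{
    std::unordered_map<InferenceId, std::size_t> requestIndexById;

    for (std::size_t i = 0; i < inputs.size(); ++i)
    {
        std::shared_ptr<AsyncExecutionCallback> cb = callbackManager.GetNewCallback();
        requestIndexById[cb->GetInferenceId()] = i; // remember which request this callback belongs to
        threadpool.Schedule(networkId, inputs[i], outputs[i], QosExecPriority::Medium, cb);
    }

    for (std::size_t i = 0; i < inputs.size(); ++i)
    {
        auto cb = callbackManager.GetNotifiedCallback(); // completion order, not submission order
        OutputTensors& result = outputs.at(requestIndexById.at(cb->GetInferenceId()));
        // result holds that request's output tensors; cb->GetStartTime() and
        // cb->GetEndTime() bound the measured inference window.
        armnn::IgnoreUnused(result);
    }
}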
@@ -92,7 +90,6 @@ private:
     bool m_IsAllocated;
     std::mutex m_Mutex;
-    profiling::ProfilingGuid m_InferenceId;
 };

 } // end experimental namespace
diff --git a/src/backends/backendsCommon/test/StridedSliceAsyncEndToEndTest.hpp b/src/backends/backendsCommon/test/StridedSliceAsyncEndToEndTest.hpp
index a552a6a7da..764983f3b9 100644
--- a/src/backends/backendsCommon/test/StridedSliceAsyncEndToEndTest.hpp
+++ b/src/backends/backendsCommon/test/StridedSliceAsyncEndToEndTest.hpp
@@ -9,6 +9,7 @@
 #include
 #include
+#include <armnn/Threadpool.hpp>
 #include
 #include

@@ -137,7 +138,7 @@ void AsyncEndToEndTestImpl(INetworkPtr network,

     std::string errorMessage;

-    const INetworkProperties networkProperties(true, MemorySource::Undefined, MemorySource::Undefined, numThreads);
+    const INetworkProperties networkProperties(true, MemorySource::Undefined, MemorySource::Undefined);

     runtime->LoadNetwork(networkId, std::move(optNet), errorMessage, networkProperties);

@@ -172,30 +173,32 @@ void AsyncEndToEndTestImpl(INetworkPtr network,
     }
     else
     {
-        std::vector<IAsyncExecutionCallbackPtr> callbacks;
+        std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles;

-        // Create 1000 callbacks that will be checked post scheduling
-        for (size_t i = 0; i < 1000; ++i)
+        for (size_t i = 0; i < numThreads; ++i)
         {
-            callbacks.emplace_back(std::make_shared<AsyncExecutionCallback>());
+            memHandles.emplace_back(runtime->CreateWorkingMemHandle(networkId));
         }

+        Threadpool threadpool(numThreads, runtime.get(), memHandles);
+        AsyncCallbackManager callbackManager;
+
         // For the asynchronous execution, we are adding a pool of working memory handles (1 per thread) in the
         // LoadedNetwork with each scheduled inference having a specific priority
-        for (IAsyncExecutionCallbackPtr cb : callbacks)
+        for (size_t i = 0; i < 1000; ++i)
         {
-            runtime->Schedule(networkId,
-                              inputTensors,
-                              outputTensors,
-                              static_cast<QosExecPriority>(rand()%3),
-                              cb);
+            threadpool.Schedule(networkId,
+                                inputTensors,
+                                outputTensors,
+                                static_cast<QosExecPriority>(rand()%3),
+                                callbackManager.GetNewCallback());
         }

         // Wait until the execution signals a notify
-        for (IAsyncExecutionCallbackPtr cb : callbacks)
+        for (size_t i = 0; i < 1000; ++i)
         {
-            cb->Wait();
-
+            auto cb = callbackManager.GetNotifiedCallback();
+
             // Checks the results.
             CHECK(cb->GetStatus() == Status::Success);
         }
diff --git a/tests/ExecuteNetwork/ExecuteNetwork.cpp b/tests/ExecuteNetwork/ExecuteNetwork.cpp
index e8d5b1860c..48577c9990 100644
--- a/tests/ExecuteNetwork/ExecuteNetwork.cpp
+++ b/tests/ExecuteNetwork/ExecuteNetwork.cpp
@@ -445,14 +445,8 @@ int MainImpl(const ExecuteNetworkParams& params,
     try
     {
         ARMNN_LOG(info) << "Asynchronous execution with Arm NN thread pool...
\n"; - std::vector callbacks; - - // Create callbacks that will be checked post scheduling - for (size_t i = 0; i < params.m_SimultaneousIterations; ++i) - { - // Point to ArmNN example implementation of AsyncExecutionCallback - callbacks.emplace_back(std::make_shared()); - } + armnn::AsyncCallbackManager callbackManager; + std::unordered_map&> inferenceOutputMap; // Declare the latest and earliest inference times here to be used when calculating overall time std::chrono::high_resolution_clock::time_point earliestStartTime; @@ -461,15 +455,19 @@ int MainImpl(const ExecuteNetworkParams& params, // For the asynchronous execution, we are adding a pool of working memory handles (1 per thread) in the // LoadedNetwork with each scheduled inference having a specific priority - for (size_t i = 0; i < callbacks.size(); ++i) + for (size_t i = 0; i < params.m_SimultaneousIterations; ++i) { - model.RunAsync(inputs[i], outputs[i], callbacks[i]); + std::shared_ptr cb = callbackManager.GetNewCallback(); + inferenceOutputMap.insert({cb->GetInferenceId(), outputs[i]}); + model.RunAsync(inputs[i], outputs[i], cb); } // Check the results unsigned int j = 0; - for (armnn::experimental::IAsyncExecutionCallbackPtr cb : callbacks) + for (size_t iteration = 0; iteration < params.m_SimultaneousIterations; ++iteration) { + auto cb = callbackManager.GetNotifiedCallback(); + // Get the results auto endTime = time_point_cast(cb->GetEndTime()); auto startTime = time_point_cast(cb->GetStartTime()); @@ -507,7 +505,7 @@ int MainImpl(const ExecuteNetworkParams& params, infoOut, outputTensorFile, params.m_DequantizeOutput); - mapbox::util::apply_visitor(printer, outputs[j][i]); + mapbox::util::apply_visitor(printer, inferenceOutputMap.at(cb->GetInferenceId())[i]); } ARMNN_LOG(info) << "\nInference time: " << std::setprecision(2) @@ -549,7 +547,7 @@ int MainImpl(const ExecuteNetworkParams& params, try { ARMNN_LOG(info) << "Asynchronous Execution with std::launch:async... 
\n"; - std::vector>>> inferenceResults; inferenceResults.reserve(params.m_SimultaneousIterations); @@ -567,9 +565,10 @@ int MainImpl(const ExecuteNetworkParams& params, for (unsigned int i = 0; i < params.m_SimultaneousIterations; ++i) { armnn::experimental::IWorkingMemHandle& workingMemHandleRef = *workingMemHandles[i].get(); + inferenceResults.push_back(std::async( std::launch::async, [&model, &workingMemHandleRef, &inputs, &outputs, i]() { - return model.RunAsync(workingMemHandleRef, inputs[i], outputs[i]); + return model.RunAsync(workingMemHandleRef, inputs[i], outputs[i], i); } )); } diff --git a/tests/InferenceModel.hpp b/tests/InferenceModel.hpp index 9d6096a3eb..3eb1e6a9e7 100644 --- a/tests/InferenceModel.hpp +++ b/tests/InferenceModel.hpp @@ -6,6 +6,7 @@ #pragma once #include +#include #include #include #include @@ -415,7 +416,7 @@ public: armnn::IRuntime::CreationOptions options; options.m_EnableGpuProfiling = m_EnableProfiling; options.m_DynamicBackendsPath = m_DynamicBackendsPath; - m_Runtime = std::move(armnn::IRuntime::Create(options)); + m_Runtime = armnn::IRuntime::Create(options); } std::string invalidBackends; @@ -484,13 +485,25 @@ public: const auto loading_start_time = armnn::GetTimeNow(); armnn::INetworkProperties networkProperties(params.m_AsyncEnabled, armnn::MemorySource::Undefined, - armnn::MemorySource::Undefined, - params.m_ThreadPoolSize); + armnn::MemorySource::Undefined); std::string errorMessage; ret = m_Runtime->LoadNetwork(m_NetworkIdentifier, std::move(optNet), errorMessage, networkProperties); ARMNN_LOG(info) << "Network loading time: " << std::setprecision(2) << std::fixed << armnn::GetTimeDuration(loading_start_time).count() << " ms\n"; + + if (params.m_AsyncEnabled && params.m_ThreadPoolSize > 0) + { + std::vector> memHandles; + for (size_t i = 0; i < params.m_ThreadPoolSize; ++i) + { + memHandles.emplace_back(m_Runtime->CreateWorkingMemHandle(m_NetworkIdentifier)); + } + + m_Threadpool = std::make_unique(params.m_ThreadPoolSize, + m_Runtime.get(), + memHandles); + } } if (ret == armnn::Status::Failure) @@ -579,10 +592,11 @@ public: } } - std::tuple> RunAsync( + std::tuple> RunAsync( armnn::experimental::IWorkingMemHandle& workingMemHandleRef, const std::vector& inputContainers, - std::vector& outputContainers) + std::vector& outputContainers, + unsigned int inferenceID) { for (unsigned int i = 0; i < outputContainers.size(); ++i) { @@ -614,7 +628,6 @@ public: armnn::Status ret = m_Runtime->Execute(workingMemHandleRef, MakeInputTensors(inputContainers), MakeOutputTensors(outputContainers)); - auto inferenceID = workingMemHandleRef.GetInferenceId(); const auto duration = armnn::GetTimeDuration(start_time); @@ -638,7 +651,7 @@ public: void RunAsync(const std::vector& inputContainers, std::vector& outputContainers, - armnn::experimental::IAsyncExecutionCallbackPtr cb) + std::shared_ptr cb) { for (unsigned int i = 0; i < outputContainers.size(); ++i) { @@ -664,11 +677,11 @@ public: profiler->EnableProfiling(m_EnableProfiling); } - m_Runtime->Schedule(m_NetworkIdentifier, - MakeInputTensors(inputContainers), - MakeOutputTensors(outputContainers), - armnn::QosExecPriority::Medium, - cb); + m_Threadpool->Schedule(m_NetworkIdentifier, + MakeInputTensors(inputContainers), + MakeOutputTensors(outputContainers), + armnn::QosExecPriority::Medium, + cb); // if profiling is enabled print out the results if (profiler && profiler->IsProfilingEnabled()) @@ -731,6 +744,7 @@ public: private: armnn::NetworkId m_NetworkIdentifier; std::shared_ptr m_Runtime; + 
std::unique_ptr m_Threadpool; std::vector m_InputBindings; std::vector m_OutputBindings; -- cgit v1.2.1
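Taken as a whole, Threadpool::ProcessExecPriorities above implements priority scheduling with starvation protection: each worker drains the highest non-empty queue, but after EXPIRE_RATE consecutive picks from one level it yields to the level below, and the counters reset whenever a lower level is served. A distilled, self-contained sketch of just that selection policy (the payload type and the EXPIRE_RATE value of 3 are illustrative; Arm NN defines the real constant elsewhere):

// Sketch only: the queue-selection policy of ProcessExecPriorities in isolation.
#include <iostream>
#include <queue>

constexpr int EXPIRE_RATE = 3; // illustrative value

struct PrioritySelector
{
    std::queue<int> high, medium, low; // job payload simplified to an int id
    int highCount = 0;
    int mediumCount = 0;

    // Returns false only when all three queues are empty.
    bool Next(int& job)
    {
        while (true)
        {
            if (!high.empty() && highCount < EXPIRE_RATE)
            {
                job = high.front(); high.pop();
                ++highCount;                  // consecutive high picks counted
                return true;
            }
            if (!medium.empty() && mediumCount < EXPIRE_RATE)
            {
                job = medium.front(); medium.pop();
                ++mediumCount;
                highCount = 0;                // high may be served again afterwards
                return true;
            }
            if (!low.empty())
            {
                job = low.front(); low.pop();
                highCount = 0;                // both counters reset after a low pick
                mediumCount = 0;
                return true;
            }
            if (high.empty() && medium.empty() && low.empty())
            {
                return false;                 // the real loop blocks on the condition variable here
            }
            highCount = 0;                    // counters expired: reset and rescan
            mediumCount = 0;
        }
    }
};

int main()
{
    PrioritySelector s;
    for (int i = 0; i < 5; ++i) { s.high.push(i); } // five high-priority jobs
    s.low.push(100);                                // one low-priority job
    int job = 0;
    while (s.Next(job)) { std::cout << job << ' '; } // prints: 0 1 2 100 3 4
}

In the real pool the empty case blocks on m_ThreadPoolEvent instead of returning, and TerminateThreadPool() wakes every worker so the wait predicate can observe m_TerminatePool and exit.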