diff options
author | Finn Williams <Finn.Williams@arm.com> | 2021-06-09 17:07:33 +0100 |
---|---|---|
committer | Finn Williams <Finn.Williams@arm.com> | 2021-06-23 17:14:53 +0100 |
commit | f364d5391b08e9071cd965f5765385ec9156b652 (patch) | |
tree | 1ea93ed574a3eb51f5a1f4bb08dc1ad18aa1c6a2 /include | |
parent | 7a00eaa6ecf121623823b1951c0e6c9093271adf (diff) | |
download | armnn-f364d5391b08e9071cd965f5765385ec9156b652.tar.gz |
IVGCVSW-6062 Rework the async threadpool
!android-nn-driver:5802
* Extract the threadpool from LoadedNetwork/Runtime
* Refactor the threadpool to be handle multiple networks
* Trim IAsyncExecutionCallback and add an InferenceId to AsyncExecutionCallback
* Add AsyncCallbackManager class
Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Change-Id: I36aa2ad29c16bc10ee0706adfeb6b27f60012afb
Diffstat (limited to 'include')
-rw-r--r-- | include/armnn/IAsyncExecutionCallback.hpp | 13 | ||||
-rw-r--r-- | include/armnn/IRuntime.hpp | 31 | ||||
-rw-r--r-- | include/armnn/IWorkingMemHandle.hpp | 3 | ||||
-rw-r--r-- | include/armnn/Threadpool.hpp | 78 |
4 files changed, 93 insertions, 32 deletions
diff --git a/include/armnn/IAsyncExecutionCallback.hpp b/include/armnn/IAsyncExecutionCallback.hpp index 045ec4581f..3e0cacccee 100644 --- a/include/armnn/IAsyncExecutionCallback.hpp +++ b/include/armnn/IAsyncExecutionCallback.hpp @@ -23,19 +23,6 @@ public: // Notify the AsyncExecutionCallback object of the armnn execution status virtual void Notify(armnn::Status status, InferenceTimingPair timeTaken) = 0; - - // Block the calling thread until the AsyncExecutionCallback object allows it to proceed - virtual void Wait() const = 0; - - // Retrieve the ArmNN Status from the AsyncExecutionCallback that has been notified - virtual armnn::Status GetStatus() const = 0; - - // Retrieve the start time before executing the inference - virtual HighResolutionClock GetStartTime() const = 0; - - // Retrieve the time after executing the inference - virtual HighResolutionClock GetEndTime() const = 0; - }; } // experimental diff --git a/include/armnn/IRuntime.hpp b/include/armnn/IRuntime.hpp index bfc13c9c01..f88b6b664f 100644 --- a/include/armnn/IRuntime.hpp +++ b/include/armnn/IRuntime.hpp @@ -32,24 +32,34 @@ struct INetworkProperties ARMNN_DEPRECATED_MSG("Please use INetworkProperties constructor with MemorySource argument") INetworkProperties(bool importEnabled = false, bool exportEnabled = false, - bool asyncEnabled = false, - size_t numThreads = 1) + bool asyncEnabled = false) : m_ImportEnabled(importEnabled) , m_ExportEnabled(exportEnabled) , m_AsyncEnabled(asyncEnabled) - , m_NumThreads(numThreads) , m_InputSource(m_ImportEnabled ? MemorySource::Malloc : MemorySource::Undefined) , m_OutputSource(m_ExportEnabled ? MemorySource::Malloc : MemorySource::Undefined) {} + ARMNN_DEPRECATED_MSG("Please use INetworkProperties constructor without numThreads argument") INetworkProperties(bool asyncEnabled, MemorySource m_InputSource, MemorySource m_OutputSource, - size_t numThreads = 1) + size_t numThreads) + : m_ImportEnabled(m_InputSource != MemorySource::Undefined) + , m_ExportEnabled(m_OutputSource != MemorySource::Undefined) + , m_AsyncEnabled(asyncEnabled) + , m_InputSource(m_InputSource) + , m_OutputSource(m_OutputSource) + { + armnn::IgnoreUnused(numThreads); + } + + INetworkProperties(bool asyncEnabled, + MemorySource m_InputSource, + MemorySource m_OutputSource) : m_ImportEnabled(m_InputSource != MemorySource::Undefined) , m_ExportEnabled(m_OutputSource != MemorySource::Undefined) , m_AsyncEnabled(asyncEnabled) - , m_NumThreads(numThreads) , m_InputSource(m_InputSource) , m_OutputSource(m_OutputSource) {} @@ -60,7 +70,6 @@ struct INetworkProperties const bool m_ExportEnabled; const bool m_AsyncEnabled; - const size_t m_NumThreads; const MemorySource m_InputSource; const MemorySource m_OutputSource; @@ -191,16 +200,6 @@ public: const InputTensors& inputTensors, const OutputTensors& outputTensors); - /// This is an experimental function - /// Schedule a thread safe execution by taking the input tensors and an execution priority for Quality of Service. - /// The output tensors will then be filled and the callback object will notify that the execution has either - /// succeeded or failed. - void Schedule(NetworkId networkId, - const InputTensors& inputTensors, - const OutputTensors& outputTensors, - const QosExecPriority priority, - std::shared_ptr<IAsyncExecutionCallback> callback); - /// Unloads a network from the IRuntime. /// At the moment this only removes the network from the m_Impl->m_Network. /// This might need more work in the future to be AndroidNN compliant. diff --git a/include/armnn/IWorkingMemHandle.hpp b/include/armnn/IWorkingMemHandle.hpp index 6fb2f9fe5f..171fa3d81c 100644 --- a/include/armnn/IWorkingMemHandle.hpp +++ b/include/armnn/IWorkingMemHandle.hpp @@ -25,9 +25,6 @@ public: /// Returns the NetworkId of the Network that this IWorkingMemHandle works with. virtual NetworkId GetNetworkId() = 0; - /// Returns the InferenceId of the Inference that this IWorkingMemHandle works with. - virtual profiling::ProfilingGuid GetInferenceId() = 0; - /// Allocate the backing memory required for execution. If this is not called, then allocation will be /// deferred to execution time. The mutex must be locked. virtual void Allocate() = 0; diff --git a/include/armnn/Threadpool.hpp b/include/armnn/Threadpool.hpp new file mode 100644 index 0000000000..e2458dbb65 --- /dev/null +++ b/include/armnn/Threadpool.hpp @@ -0,0 +1,78 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <armnn/Tensor.hpp> +#include <armnn/Types.hpp> + +#include "INetwork.hpp" +#include "IRuntime.hpp" + +#include <thread> +#include <mutex> +#include <condition_variable> +#include <unordered_map> +#include <queue> + +namespace armnn +{ +namespace experimental +{ +class Threadpool +{ +public: + Threadpool(std::size_t numThreads, + IRuntime* runtimePtr, + std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles); + + ~Threadpool() + { + TerminateThreadPool(); + } + + void LoadMemHandles(std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles); + void UnloadMemHandles(NetworkId networkId); + + /// Schedule an asynchronous execution on the loaded network + void Schedule(NetworkId networkId, + const InputTensors &inputTensors, + const OutputTensors &outputTensors, + const QosExecPriority priority, + std::shared_ptr<IAsyncExecutionCallback> cb); + + void TerminateThreadPool() noexcept; + +private: + using ExecutionTuple = std::tuple<NetworkId, + InputTensors, + OutputTensors, + std::shared_ptr<IAsyncExecutionCallback>>; + + using ExecutionQueue = std::queue<std::shared_ptr<ExecutionTuple>>; + + void ProcessExecPriorities(uint32_t index); + + IRuntime* m_RuntimePtr; + + ExecutionQueue m_HighPriorityQueue; + ExecutionQueue m_MediumPriorityQueue; + ExecutionQueue m_LowPriorityQueue; + + // Condition Variables require mutex which will guard the shared state. + // Has an event happened? Stop signal for example + std::condition_variable m_ThreadPoolEvent; + std::mutex m_ThreadPoolMutex; + + // The shared state for conditional variable + bool m_TerminatePool = false; + + std::unordered_map<NetworkId, std::vector<std::shared_ptr<IWorkingMemHandle>>> m_WorkingMemHandleMap; + std::vector<std::unique_ptr<std::thread>> m_Threads; +}; + +} // namespace experimental + +} // namespace armnn |