diff options
author | Finn Williams <Finn.Williams@arm.com> | 2021-06-09 17:07:33 +0100 |
---|---|---|
committer | Finn Williams <Finn.Williams@arm.com> | 2021-06-23 17:14:53 +0100 |
commit | f364d5391b08e9071cd965f5765385ec9156b652 (patch) | |
tree | 1ea93ed574a3eb51f5a1f4bb08dc1ad18aa1c6a2 /src/armnn/LoadedNetwork.hpp | |
parent | 7a00eaa6ecf121623823b1951c0e6c9093271adf (diff) | |
download | armnn-f364d5391b08e9071cd965f5765385ec9156b652.tar.gz |
IVGCVSW-6062 Rework the async threadpool
!android-nn-driver:5802
* Extract the threadpool from LoadedNetwork/Runtime
* Refactor the threadpool to handle multiple networks
* Trim IAsyncExecutionCallback and add an InferenceId to AsyncExecutionCallback
* Add AsyncCallbackManager class
Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Change-Id: I36aa2ad29c16bc10ee0706adfeb6b27f60012afb
Diffstat (limited to 'src/armnn/LoadedNetwork.hpp')
-rw-r--r-- | src/armnn/LoadedNetwork.hpp | 42 |
1 files changed, 2 insertions, 40 deletions
diff --git a/src/armnn/LoadedNetwork.hpp b/src/armnn/LoadedNetwork.hpp index c85e82bbdd..360ad91170 100644 --- a/src/armnn/LoadedNetwork.hpp +++ b/src/armnn/LoadedNetwork.hpp @@ -37,16 +37,9 @@ class LoadedNetwork public: using WorkloadQueue = std::vector<std::unique_ptr<IWorkload>>; - using ExecutionTuple = std::tuple<InputTensors, - OutputTensors, - std::shared_ptr<IAsyncExecutionCallback>>; - - using ExecutionQueue = std::queue<std::shared_ptr<ExecutionTuple>>; - ~LoadedNetwork() { FreeWorkingMemory(); - TerminateThreadPool(); } /// Create a new unique WorkingMemHandle object. Create multiple handles if you wish to have @@ -64,17 +57,10 @@ public: const OutputTensors& outputTensors, IWorkingMemHandle& workingMemHandle); - /// Schedule an asynchronous execution on the loaded network - void Schedule(const InputTensors& inputTensors, - const OutputTensors& outputTensors, - const QosExecPriority priority, - std::shared_ptr<IAsyncExecutionCallback> cb); - static std::unique_ptr<LoadedNetwork> MakeLoadedNetwork(std::unique_ptr<IOptimizedNetwork> net, std::string& errorMessage, const INetworkProperties& networkProperties, - profiling::ProfilingService& profilingService, - const NetworkId networkIdOut); + profiling::ProfilingService& profilingService); // NOTE we return by reference as the purpose of this method is only to provide // access to the private m_Profiler and in theory we should not need to increment @@ -108,8 +94,7 @@ private: LoadedNetwork(std::unique_ptr<IOptimizedNetwork> net, const INetworkProperties& networkProperties, - profiling::ProfilingService& profilingService, - const NetworkId networkIdOut); + profiling::ProfilingService& profilingService); void EnqueueInput(const BindableLayer& layer, ITensorHandle* tensorHandle, const TensorInfo& tensorInfo); @@ -119,15 +104,9 @@ private: void EnqueueOutput(const BindableLayer& layer, const Tensor& outputTensor, WorkingMemHandle& handle); - void ProcessExecPriorities(std::unique_ptr<IWorkingMemHandle> 
workingMemHandle); - bool Execute(std::unique_ptr<profiling::TimelineUtilityMethods>& timelineUtils, profiling::ProfilingGuid inferenceGuid); - void CreateThreadPool(std::size_t numThreads); - - void TerminateThreadPool() noexcept; - const IWorkloadFactory& GetWorkloadFactory(const Layer& layer) const; using BackendPtrMap = std::unordered_map<BackendId, IBackendInternalUniquePtr>; @@ -146,25 +125,8 @@ private: bool m_IsWorkingMemAllocated = false; - std::vector<std::unique_ptr<std::thread>> m_Threads; - std::stack<IWorkingMemHandle> m_WorkingMemHandles; - - ExecutionQueue m_HighPriorityQueue; - ExecutionQueue m_MediumPriorityQueue; - ExecutionQueue m_LowPriorityQueue; - - // Condition Variables require mutex which will guard the shared state. - // Has an event happened? Stop signal for example - std::condition_variable m_ThreadPoolEvent; - std::mutex m_ThreadPoolMutex; - - // The shared state for conditional variable - bool m_TerminatePool = false; - INetworkProperties m_NetworkProperties; - const NetworkId m_NetworkId; - TensorHandleFactoryRegistry m_TensorHandleFactoryRegistry; profiling::ProfilingService& m_ProfilingService; |