diff options
author | Keith Davis <keith.davis@arm.com> | 2021-04-22 10:10:34 +0100 |
---|---|---|
committer | finn.williams <finn.williams@arm.com> | 2021-05-06 19:39:39 +0000 |
commit | e813d67f86df41a238ff79b5c554ef5027f56576 (patch) | |
tree | 54c2145a78297d61e6e94676729cb2468490ade3 /src/armnn/LoadedNetwork.hpp | |
parent | d905decd256558bbee165e636ce4242ac3b9c917 (diff) | |
download | armnn-e813d67f86df41a238ff79b5c554ef5027f56576.tar.gz |
IVGCVSW-5813 Add Async Queue to IRuntime
Signed-off-by: Keith Davis <keith.davis@arm.com>
Change-Id: Icc0d131c8ee2e9748e2f14762a75962b39c10f9d
Diffstat (limited to 'src/armnn/LoadedNetwork.hpp')
-rw-r--r-- | src/armnn/LoadedNetwork.hpp | 75 |
1 file changed, 61 insertions, 14 deletions
diff --git a/src/armnn/LoadedNetwork.hpp b/src/armnn/LoadedNetwork.hpp index 51092c744e..b5474db294 100644 --- a/src/armnn/LoadedNetwork.hpp +++ b/src/armnn/LoadedNetwork.hpp @@ -19,13 +19,14 @@ #include <TimelineUtilityMethods.hpp> #include <mutex> +#include <condition_variable> #include <unordered_map> namespace cl { - class Context; - class CommandQueue; - class Device; +class Context; +class CommandQueue; +class Device; } namespace armnn @@ -34,8 +35,19 @@ namespace armnn class LoadedNetwork { public: - using WorkloadQueue = std::vector< std::unique_ptr<IWorkload> >; - ~LoadedNetwork(){ FreeWorkingMemory(); } + using WorkloadQueue = std::vector<std::unique_ptr<IWorkload>>; + + using ExecutionTuple = std::tuple<InputTensors, + OutputTensors, + std::shared_ptr<IAsyncExecutionCallback>>; + + using ExecutionQueue = std::queue<std::shared_ptr<ExecutionTuple>>; + + ~LoadedNetwork() + { + FreeWorkingMemory(); + TerminateThreadPool(); + } /// Create a new unique WorkingMemHandle object. Create multiple handles if you wish to have /// overlapped Execution by calling this function from different threads. 
@@ -44,16 +56,25 @@ public: TensorInfo GetInputTensorInfo(LayerBindingId layerId) const; TensorInfo GetOutputTensorInfo(LayerBindingId layerId) const; + /// Single thread execution of the loaded network Status EnqueueWorkload(const InputTensors& inputTensors, const OutputTensors& outputTensors); + /// Thread safe execution of the loaded network Status Execute(const InputTensors& inputTensors, const OutputTensors& outputTensors, IWorkingMemHandle& workingMemHandle); + /// Schedule an asynchronous execution on the loaded network + void Schedule(const InputTensors& inputTensors, + const OutputTensors& outputTensors, + const QosExecPriority priority, + std::shared_ptr<IAsyncExecutionCallback> cb); + static std::unique_ptr<LoadedNetwork> MakeLoadedNetwork(std::unique_ptr<IOptimizedNetwork> net, - std::string & errorMessage, + std::string& errorMessage, const INetworkProperties& networkProperties, - profiling::ProfilingService& profilingService); + profiling::ProfilingService& profilingService, + const NetworkId networkIdOut); // NOTE we return by reference as the purpose of this method is only to provide // access to the private m_Profiler and in theory we should not need to increment @@ -87,7 +108,8 @@ private: LoadedNetwork(std::unique_ptr<IOptimizedNetwork> net, const INetworkProperties& networkProperties, - profiling::ProfilingService& profilingService); + profiling::ProfilingService& profilingService, + const NetworkId networkIdOut); void EnqueueInput(const BindableLayer& layer, ITensorHandle* tensorHandle, const TensorInfo& tensorInfo); @@ -97,9 +119,15 @@ private: void EnqueueOutput(const BindableLayer& layer, const Tensor& outputTensor, WorkingMemHandle& handle); + void ProcessExecPriorities(std::unique_ptr<IWorkingMemHandle> workingMemHandle); + bool Execute(std::unique_ptr<profiling::TimelineUtilityMethods>& timelineUtils, profiling::ProfilingGuid inferenceGuid); + void CreateThreadPool(std::size_t numThreads); + + void TerminateThreadPool() noexcept; + const 
IWorkloadFactory& GetWorkloadFactory(const Layer& layer) const; using BackendPtrMap = std::unordered_map<BackendId, IBackendInternalUniquePtr>; @@ -108,19 +136,38 @@ private: WorkloadFactoryMap m_WorkloadFactories; std::unique_ptr<IOptimizedNetwork> m_OptimizedNetwork; - WorkloadQueue m_InputQueue; - WorkloadQueue m_WorkloadQueue; - WorkloadQueue m_OutputQueue; - std::shared_ptr<IProfiler> m_Profiler; + std::shared_ptr<IProfiler> m_Profiler; + + WorkloadQueue m_InputQueue; + WorkloadQueue m_WorkloadQueue; + WorkloadQueue m_OutputQueue; mutable std::mutex m_WorkingMemMutex; - bool m_IsWorkingMemAllocated=false; + bool m_IsWorkingMemAllocated = false; + + std::vector<std::unique_ptr<std::thread>> m_Threads; + std::stack<IWorkingMemHandle> m_WorkingMemHandles; + + ExecutionQueue m_HighPriorityQueue; + ExecutionQueue m_MediumPriorityQueue; + ExecutionQueue m_LowPriorityQueue; + + // Condition Variables require mutex which will guard the shared state. + // Has an event happened? Stop signal for example + std::condition_variable m_ThreadPoolEvent; + std::mutex m_ThreadPoolMutex; + + // The shared state for conditional variable + bool m_TerminatePool = false; + INetworkProperties m_NetworkProperties; + const NetworkId m_NetworkId; + TensorHandleFactoryRegistry m_TensorHandleFactoryRegistry; - profiling::ProfilingService& m_ProfilingService; + profiling::ProfilingService& m_ProfilingService; }; } |