aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/LoadedNetwork.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/LoadedNetwork.hpp')
-rw-r--r--src/armnn/LoadedNetwork.hpp75
1 files changed, 61 insertions, 14 deletions
diff --git a/src/armnn/LoadedNetwork.hpp b/src/armnn/LoadedNetwork.hpp
index 51092c744e..b5474db294 100644
--- a/src/armnn/LoadedNetwork.hpp
+++ b/src/armnn/LoadedNetwork.hpp
@@ -19,13 +19,14 @@
#include <TimelineUtilityMethods.hpp>
#include <mutex>
+#include <condition_variable>
#include <unordered_map>
namespace cl
{
- class Context;
- class CommandQueue;
- class Device;
+class Context;
+class CommandQueue;
+class Device;
}
namespace armnn
@@ -34,8 +35,19 @@ namespace armnn
class LoadedNetwork
{
public:
- using WorkloadQueue = std::vector< std::unique_ptr<IWorkload> >;
- ~LoadedNetwork(){ FreeWorkingMemory(); }
+ using WorkloadQueue = std::vector<std::unique_ptr<IWorkload>>;
+
+ using ExecutionTuple = std::tuple<InputTensors,
+ OutputTensors,
+ std::shared_ptr<IAsyncExecutionCallback>>;
+
+ using ExecutionQueue = std::queue<std::shared_ptr<ExecutionTuple>>;
+
+ ~LoadedNetwork()
+ {
+ FreeWorkingMemory();
+ TerminateThreadPool();
+ }
/// Create a new unique WorkingMemHandle object. Create multiple handles if you wish to have
/// overlapped Execution by calling this function from different threads.
@@ -44,16 +56,25 @@ public:
TensorInfo GetInputTensorInfo(LayerBindingId layerId) const;
TensorInfo GetOutputTensorInfo(LayerBindingId layerId) const;
+ /// Single-threaded execution of the loaded network
Status EnqueueWorkload(const InputTensors& inputTensors, const OutputTensors& outputTensors);
+ /// Thread-safe execution of the loaded network
Status Execute(const InputTensors& inputTensors,
const OutputTensors& outputTensors,
IWorkingMemHandle& workingMemHandle);
+ /// Schedule an asynchronous execution on the loaded network
+ void Schedule(const InputTensors& inputTensors,
+ const OutputTensors& outputTensors,
+ const QosExecPriority priority,
+ std::shared_ptr<IAsyncExecutionCallback> cb);
+
static std::unique_ptr<LoadedNetwork> MakeLoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
- std::string & errorMessage,
+ std::string& errorMessage,
const INetworkProperties& networkProperties,
- profiling::ProfilingService& profilingService);
+ profiling::ProfilingService& profilingService,
+ const NetworkId networkIdOut);
// NOTE we return by reference as the purpose of this method is only to provide
// access to the private m_Profiler and in theory we should not need to increment
@@ -87,7 +108,8 @@ private:
LoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
const INetworkProperties& networkProperties,
- profiling::ProfilingService& profilingService);
+ profiling::ProfilingService& profilingService,
+ const NetworkId networkIdOut);
void EnqueueInput(const BindableLayer& layer, ITensorHandle* tensorHandle, const TensorInfo& tensorInfo);
@@ -97,9 +119,15 @@ private:
void EnqueueOutput(const BindableLayer& layer, const Tensor& outputTensor, WorkingMemHandle& handle);
+ void ProcessExecPriorities(std::unique_ptr<IWorkingMemHandle> workingMemHandle);
+
bool Execute(std::unique_ptr<profiling::TimelineUtilityMethods>& timelineUtils,
profiling::ProfilingGuid inferenceGuid);
+ void CreateThreadPool(std::size_t numThreads);
+
+ void TerminateThreadPool() noexcept;
+
const IWorkloadFactory& GetWorkloadFactory(const Layer& layer) const;
using BackendPtrMap = std::unordered_map<BackendId, IBackendInternalUniquePtr>;
@@ -108,19 +136,38 @@ private:
WorkloadFactoryMap m_WorkloadFactories;
std::unique_ptr<IOptimizedNetwork> m_OptimizedNetwork;
- WorkloadQueue m_InputQueue;
- WorkloadQueue m_WorkloadQueue;
- WorkloadQueue m_OutputQueue;
- std::shared_ptr<IProfiler> m_Profiler;
+ std::shared_ptr<IProfiler> m_Profiler;
+
+ WorkloadQueue m_InputQueue;
+ WorkloadQueue m_WorkloadQueue;
+ WorkloadQueue m_OutputQueue;
mutable std::mutex m_WorkingMemMutex;
- bool m_IsWorkingMemAllocated=false;
+ bool m_IsWorkingMemAllocated = false;
+
+ std::vector<std::unique_ptr<std::thread>> m_Threads;
+ std::stack<IWorkingMemHandle> m_WorkingMemHandles;
+
+ ExecutionQueue m_HighPriorityQueue;
+ ExecutionQueue m_MediumPriorityQueue;
+ ExecutionQueue m_LowPriorityQueue;
+
+ // A condition variable requires a mutex, which guards the shared state.
+ // Used to signal that an event has happened, for example a stop signal.
+ std::condition_variable m_ThreadPoolEvent;
+ std::mutex m_ThreadPoolMutex;
+
+ // The shared state for the condition variable
+ bool m_TerminatePool = false;
+
INetworkProperties m_NetworkProperties;
+ const NetworkId m_NetworkId;
+
TensorHandleFactoryRegistry m_TensorHandleFactoryRegistry;
- profiling::ProfilingService& m_ProfilingService;
+ profiling::ProfilingService& m_ProfilingService;
};
}