diff options
Diffstat (limited to 'include/armnn/Threadpool.hpp')
-rw-r--r-- | include/armnn/Threadpool.hpp | 78 |
1 files changed, 78 insertions, 0 deletions
diff --git a/include/armnn/Threadpool.hpp b/include/armnn/Threadpool.hpp new file mode 100644 index 0000000000..e2458dbb65 --- /dev/null +++ b/include/armnn/Threadpool.hpp @@ -0,0 +1,78 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <armnn/Tensor.hpp> +#include <armnn/Types.hpp> + +#include "INetwork.hpp" +#include "IRuntime.hpp" + +#include <thread> +#include <mutex> +#include <condition_variable> +#include <unordered_map> +#include <queue> + +namespace armnn +{ +namespace experimental +{ +class Threadpool +{ +public: + Threadpool(std::size_t numThreads, + IRuntime* runtimePtr, + std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles); + + ~Threadpool() + { + TerminateThreadPool(); + } + + void LoadMemHandles(std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles); + void UnloadMemHandles(NetworkId networkId); + + /// Schedule an asynchronous execution on the loaded network + void Schedule(NetworkId networkId, + const InputTensors &inputTensors, + const OutputTensors &outputTensors, + const QosExecPriority priority, + std::shared_ptr<IAsyncExecutionCallback> cb); + + void TerminateThreadPool() noexcept; + +private: + using ExecutionTuple = std::tuple<NetworkId, + InputTensors, + OutputTensors, + std::shared_ptr<IAsyncExecutionCallback>>; + + using ExecutionQueue = std::queue<std::shared_ptr<ExecutionTuple>>; + + void ProcessExecPriorities(uint32_t index); + + IRuntime* m_RuntimePtr; + + ExecutionQueue m_HighPriorityQueue; + ExecutionQueue m_MediumPriorityQueue; + ExecutionQueue m_LowPriorityQueue; + + // Condition Variables require mutex which will guard the shared state. + // Has an event happened? Stop signal for example + std::condition_variable m_ThreadPoolEvent; + std::mutex m_ThreadPoolMutex; + + // The shared state for conditional variable + bool m_TerminatePool = false; + + std::unordered_map<NetworkId, std::vector<std::shared_ptr<IWorkingMemHandle>>> m_WorkingMemHandleMap; + std::vector<std::unique_ptr<std::thread>> m_Threads; +}; + +} // namespace experimental + +} // namespace armnn |