aboutsummaryrefslogtreecommitdiff
path: root/include/armnn/Threadpool.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'include/armnn/Threadpool.hpp')
-rw-r--r--include/armnn/Threadpool.hpp78
1 files changed, 78 insertions, 0 deletions
diff --git a/include/armnn/Threadpool.hpp b/include/armnn/Threadpool.hpp
new file mode 100644
index 0000000000..e2458dbb65
--- /dev/null
+++ b/include/armnn/Threadpool.hpp
@@ -0,0 +1,78 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <armnn/Tensor.hpp>
+#include <armnn/Types.hpp>
+
+#include "INetwork.hpp"
+#include "IRuntime.hpp"
+
+#include <thread>
+#include <mutex>
+#include <condition_variable>
+#include <unordered_map>
+#include <queue>
+
+namespace armnn
+{
+namespace experimental
+{
+class Threadpool
+{
+public:
+ Threadpool(std::size_t numThreads,
+ IRuntime* runtimePtr,
+ std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles);
+
+ ~Threadpool()
+ {
+ TerminateThreadPool();
+ }
+
+ void LoadMemHandles(std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles);
+ void UnloadMemHandles(NetworkId networkId);
+
+ /// Schedule an asynchronous execution on the loaded network
+ void Schedule(NetworkId networkId,
+ const InputTensors &inputTensors,
+ const OutputTensors &outputTensors,
+ const QosExecPriority priority,
+ std::shared_ptr<IAsyncExecutionCallback> cb);
+
+ void TerminateThreadPool() noexcept;
+
+private:
+ using ExecutionTuple = std::tuple<NetworkId,
+ InputTensors,
+ OutputTensors,
+ std::shared_ptr<IAsyncExecutionCallback>>;
+
+ using ExecutionQueue = std::queue<std::shared_ptr<ExecutionTuple>>;
+
+ void ProcessExecPriorities(uint32_t index);
+
+ IRuntime* m_RuntimePtr;
+
+ ExecutionQueue m_HighPriorityQueue;
+ ExecutionQueue m_MediumPriorityQueue;
+ ExecutionQueue m_LowPriorityQueue;
+
+ // Condition Variables require mutex which will guard the shared state.
+ // Has an event happened? Stop signal for example
+ std::condition_variable m_ThreadPoolEvent;
+ std::mutex m_ThreadPoolMutex;
+
+ // The shared state for conditional variable
+ bool m_TerminatePool = false;
+
+ std::unordered_map<NetworkId, std::vector<std::shared_ptr<IWorkingMemHandle>>> m_WorkingMemHandleMap;
+ std::vector<std::unique_ptr<std::thread>> m_Threads;
+};
+
+} // namespace experimental
+
+} // namespace armnn