aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/AsyncExecutionCallback.hpp
diff options
context:
space:
mode:
authorFinn Williams <Finn.Williams@arm.com>2021-06-09 17:07:33 +0100
committerFinn Williams <Finn.Williams@arm.com>2021-06-23 17:14:53 +0100
commitf364d5391b08e9071cd965f5765385ec9156b652 (patch)
tree1ea93ed574a3eb51f5a1f4bb08dc1ad18aa1c6a2 /src/armnn/AsyncExecutionCallback.hpp
parent7a00eaa6ecf121623823b1951c0e6c9093271adf (diff)
downloadarmnn-f364d5391b08e9071cd965f5765385ec9156b652.tar.gz
IVGCVSW-6062 Rework the async threadpool
!android-nn-driver:5802 * Extract the threadpool from LoadedNetwork/Runtime * Refactor the threadpool to be handle multiple networks * Trim IAsyncExecutionCallback and add an InferenceId to AsyncExecutionCallback * Add AsyncCallbackManager class Signed-off-by: Finn Williams <Finn.Williams@arm.com> Change-Id: I36aa2ad29c16bc10ee0706adfeb6b27f60012afb
Diffstat (limited to 'src/armnn/AsyncExecutionCallback.hpp')
-rw-r--r--src/armnn/AsyncExecutionCallback.hpp50
1 files changed, 43 insertions, 7 deletions
diff --git a/src/armnn/AsyncExecutionCallback.hpp b/src/armnn/AsyncExecutionCallback.hpp
index c17b839748..2ff73b3efb 100644
--- a/src/armnn/AsyncExecutionCallback.hpp
+++ b/src/armnn/AsyncExecutionCallback.hpp
@@ -6,11 +6,14 @@
#pragma once
#include <armnn/IAsyncExecutionCallback.hpp>
+#include <armnn/IWorkingMemHandle.hpp>
#include <armnn/Types.hpp>
-#include <condition_variable>
+#include <condition_variable>
#include <mutex>
#include <thread>
+#include <queue>
+#include <unordered_map>
namespace armnn
{
@@ -18,29 +21,62 @@ namespace armnn
namespace experimental
{
+using InferenceId = uint64_t;
class AsyncExecutionCallback final : public IAsyncExecutionCallback
{
+private:
+ static InferenceId nextID;
+
public:
- AsyncExecutionCallback()
+ AsyncExecutionCallback(std::queue<InferenceId>& notificationQueue,
+ std::mutex& mutex,
+ std::condition_variable& condition)
+ : m_NotificationQueue(notificationQueue)
+ , m_Mutex(mutex)
+ , m_Condition(condition)
+ , m_InferenceId(++nextID)
{}
+
~AsyncExecutionCallback()
{}
void Notify(armnn::Status status, InferenceTimingPair timeTaken);
- void Wait() const;
+
+ InferenceId GetInferenceId()
+ {
+ return m_InferenceId;
+ }
armnn::Status GetStatus() const;
HighResolutionClock GetStartTime() const;
HighResolutionClock GetEndTime() const;
private:
- mutable std::mutex m_Mutex;
- mutable std::condition_variable m_Condition;
+ std::queue<InferenceId>& m_NotificationQueue;
+ std::mutex& m_Mutex;
+ std::condition_variable& m_Condition;
HighResolutionClock m_StartTime;
HighResolutionClock m_EndTime;
- armnn::Status m_Status = Status::Failure;
- bool m_Notified = false;
+ armnn::Status m_Status = Status::Failure;
+ InferenceId m_InferenceId;
+};
+InferenceId AsyncExecutionCallback::nextID = 0u;
+
+// Manager to create and monitor AsyncExecutionCallbacks
+// GetNewCallback will create a callback for use in Threadpool::Schedule
+// GetNotifiedCallback will return the first callback to be notified (finished execution)
+class AsyncCallbackManager
+{
+public:
+ std::shared_ptr<AsyncExecutionCallback> GetNewCallback();
+ std::shared_ptr<AsyncExecutionCallback> GetNotifiedCallback();
+
+private:
+ std::mutex m_Mutex;
+ std::condition_variable m_Condition;
+ std::unordered_map<InferenceId, std::shared_ptr<AsyncExecutionCallback>> m_Callbacks;
+ std::queue<InferenceId> m_NotificationQueue;
};
} // namespace experimental