aboutsummaryrefslogtreecommitdiff
path: root/include
diff options
context:
space:
mode:
authorFinn Williams <Finn.Williams@arm.com>2021-06-09 17:07:33 +0100
committerFinn Williams <Finn.Williams@arm.com>2021-06-23 17:14:53 +0100
commitf364d5391b08e9071cd965f5765385ec9156b652 (patch)
tree1ea93ed574a3eb51f5a1f4bb08dc1ad18aa1c6a2 /include
parent7a00eaa6ecf121623823b1951c0e6c9093271adf (diff)
downloadarmnn-f364d5391b08e9071cd965f5765385ec9156b652.tar.gz
IVGCVSW-6062 Rework the async threadpool
!android-nn-driver:5802 * Extract the threadpool from LoadedNetwork/Runtime * Refactor the threadpool to handle multiple networks * Trim IAsyncExecutionCallback and add an InferenceId to AsyncExecutionCallback * Add AsyncCallbackManager class Signed-off-by: Finn Williams <Finn.Williams@arm.com> Change-Id: I36aa2ad29c16bc10ee0706adfeb6b27f60012afb
Diffstat (limited to 'include')
-rw-r--r--include/armnn/IAsyncExecutionCallback.hpp13
-rw-r--r--include/armnn/IRuntime.hpp31
-rw-r--r--include/armnn/IWorkingMemHandle.hpp3
-rw-r--r--include/armnn/Threadpool.hpp78
4 files changed, 93 insertions, 32 deletions
diff --git a/include/armnn/IAsyncExecutionCallback.hpp b/include/armnn/IAsyncExecutionCallback.hpp
index 045ec4581f..3e0cacccee 100644
--- a/include/armnn/IAsyncExecutionCallback.hpp
+++ b/include/armnn/IAsyncExecutionCallback.hpp
@@ -23,19 +23,6 @@ public:
// Notify the AsyncExecutionCallback object of the armnn execution status
virtual void Notify(armnn::Status status, InferenceTimingPair timeTaken) = 0;
-
- // Block the calling thread until the AsyncExecutionCallback object allows it to proceed
- virtual void Wait() const = 0;
-
- // Retrieve the ArmNN Status from the AsyncExecutionCallback that has been notified
- virtual armnn::Status GetStatus() const = 0;
-
- // Retrieve the start time before executing the inference
- virtual HighResolutionClock GetStartTime() const = 0;
-
- // Retrieve the time after executing the inference
- virtual HighResolutionClock GetEndTime() const = 0;
-
};
} // experimental
diff --git a/include/armnn/IRuntime.hpp b/include/armnn/IRuntime.hpp
index bfc13c9c01..f88b6b664f 100644
--- a/include/armnn/IRuntime.hpp
+++ b/include/armnn/IRuntime.hpp
@@ -32,24 +32,34 @@ struct INetworkProperties
ARMNN_DEPRECATED_MSG("Please use INetworkProperties constructor with MemorySource argument")
INetworkProperties(bool importEnabled = false,
bool exportEnabled = false,
- bool asyncEnabled = false,
- size_t numThreads = 1)
+ bool asyncEnabled = false)
: m_ImportEnabled(importEnabled)
, m_ExportEnabled(exportEnabled)
, m_AsyncEnabled(asyncEnabled)
- , m_NumThreads(numThreads)
, m_InputSource(m_ImportEnabled ? MemorySource::Malloc : MemorySource::Undefined)
, m_OutputSource(m_ExportEnabled ? MemorySource::Malloc : MemorySource::Undefined)
{}
+ ARMNN_DEPRECATED_MSG("Please use INetworkProperties constructor without numThreads argument")
INetworkProperties(bool asyncEnabled,
MemorySource m_InputSource,
MemorySource m_OutputSource,
- size_t numThreads = 1)
+ size_t numThreads)
+ : m_ImportEnabled(m_InputSource != MemorySource::Undefined)
+ , m_ExportEnabled(m_OutputSource != MemorySource::Undefined)
+ , m_AsyncEnabled(asyncEnabled)
+ , m_InputSource(m_InputSource)
+ , m_OutputSource(m_OutputSource)
+ {
+ armnn::IgnoreUnused(numThreads);
+ }
+
+ INetworkProperties(bool asyncEnabled,
+ MemorySource m_InputSource,
+ MemorySource m_OutputSource)
: m_ImportEnabled(m_InputSource != MemorySource::Undefined)
, m_ExportEnabled(m_OutputSource != MemorySource::Undefined)
, m_AsyncEnabled(asyncEnabled)
- , m_NumThreads(numThreads)
, m_InputSource(m_InputSource)
, m_OutputSource(m_OutputSource)
{}
@@ -60,7 +70,6 @@ struct INetworkProperties
const bool m_ExportEnabled;
const bool m_AsyncEnabled;
- const size_t m_NumThreads;
const MemorySource m_InputSource;
const MemorySource m_OutputSource;
@@ -191,16 +200,6 @@ public:
const InputTensors& inputTensors,
const OutputTensors& outputTensors);
- /// This is an experimental function
- /// Schedule a thread safe execution by taking the input tensors and an execution priority for Quality of Service.
- /// The output tensors will then be filled and the callback object will notify that the execution has either
- /// succeeded or failed.
- void Schedule(NetworkId networkId,
- const InputTensors& inputTensors,
- const OutputTensors& outputTensors,
- const QosExecPriority priority,
- std::shared_ptr<IAsyncExecutionCallback> callback);
-
/// Unloads a network from the IRuntime.
/// At the moment this only removes the network from the m_Impl->m_Network.
/// This might need more work in the future to be AndroidNN compliant.
diff --git a/include/armnn/IWorkingMemHandle.hpp b/include/armnn/IWorkingMemHandle.hpp
index 6fb2f9fe5f..171fa3d81c 100644
--- a/include/armnn/IWorkingMemHandle.hpp
+++ b/include/armnn/IWorkingMemHandle.hpp
@@ -25,9 +25,6 @@ public:
/// Returns the NetworkId of the Network that this IWorkingMemHandle works with.
virtual NetworkId GetNetworkId() = 0;
- /// Returns the InferenceId of the Inference that this IWorkingMemHandle works with.
- virtual profiling::ProfilingGuid GetInferenceId() = 0;
-
/// Allocate the backing memory required for execution. If this is not called, then allocation will be
/// deferred to execution time. The mutex must be locked.
virtual void Allocate() = 0;
diff --git a/include/armnn/Threadpool.hpp b/include/armnn/Threadpool.hpp
new file mode 100644
index 0000000000..e2458dbb65
--- /dev/null
+++ b/include/armnn/Threadpool.hpp
@@ -0,0 +1,78 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <armnn/Tensor.hpp>
+#include <armnn/Types.hpp>
+
+#include "INetwork.hpp"
+#include "IRuntime.hpp"
+
+#include <thread>
+#include <mutex>
+#include <condition_variable>
+#include <unordered_map>
+#include <queue>
+
+namespace armnn
+{
+namespace experimental
+{
+class Threadpool // Experimental pool that queues asynchronous executions per priority for loaded networks.
+{
+public:
+ Threadpool(std::size_t numThreads, // presumably the number of worker threads to start -- TODO confirm in the .cpp
+ IRuntime* runtimePtr, // raw pointer; NOTE(review): presumably non-owning and must outlive the pool -- confirm
+ std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles); // taken by value
+
+ ~Threadpool() // Destruction terminates the pool via TerminateThreadPool().
+ {
+ TerminateThreadPool();
+ }
+
+ void LoadMemHandles(std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles); // presumably adds handles to m_WorkingMemHandleMap -- confirm
+ void UnloadMemHandles(NetworkId networkId); // presumably drops the handles held for networkId -- confirm
+
+ /// Schedule an asynchronous execution on the loaded network identified by networkId; cb is presumably notified when it finishes (TODO confirm).
+ void Schedule(NetworkId networkId,
+ const InputTensors &inputTensors,
+ const OutputTensors &outputTensors,
+ const QosExecPriority priority, // presumably selects the high/medium/low queue -- TODO confirm
+ std::shared_ptr<IAsyncExecutionCallback> cb);
+
+ void TerminateThreadPool() noexcept; // declared noexcept; also invoked by the destructor
+
+private:
+ using ExecutionTuple = std::tuple<NetworkId, // one queued request: network id, inputs, outputs, callback
+ InputTensors,
+ OutputTensors,
+ std::shared_ptr<IAsyncExecutionCallback>>;
+
+ using ExecutionQueue = std::queue<std::shared_ptr<ExecutionTuple>>; // FIFO of requests held by shared_ptr
+
+ void ProcessExecPriorities(uint32_t index); // NOTE(review): presumably the worker-thread loop; meaning of index not visible here
+
+ IRuntime* m_RuntimePtr; // raw pointer to the runtime; ownership is not expressed by the type
+
+ ExecutionQueue m_HighPriorityQueue; // one queue per priority level (high / medium / low)
+ ExecutionQueue m_MediumPriorityQueue;
+ ExecutionQueue m_LowPriorityQueue;
+
+ // A condition variable needs a mutex, which guards the shared state it is used with.
+ // The condition variable signals that an event has happened, for example a stop signal.
+ std::condition_variable m_ThreadPoolEvent;
+ std::mutex m_ThreadPoolMutex;
+
+ // The shared state used with the condition variable; initialised to false.
+ bool m_TerminatePool = false;
+
+ std::unordered_map<NetworkId, std::vector<std::shared_ptr<IWorkingMemHandle>>> m_WorkingMemHandleMap; // handles grouped by NetworkId
+ std::vector<std::unique_ptr<std::thread>> m_Threads; // thread objects held by unique_ptr
+};
+
+} // namespace experimental
+
+} // namespace armnn