author    Finn Williams <Finn.Williams@arm.com>    2021-06-09 17:07:33 +0100
committer Finn Williams <Finn.Williams@arm.com>    2021-06-23 17:14:53 +0100
commit    f364d5391b08e9071cd965f5765385ec9156b652 (patch)
tree      1ea93ed574a3eb51f5a1f4bb08dc1ad18aa1c6a2 /src
parent    7a00eaa6ecf121623823b1951c0e6c9093271adf (diff)
download  armnn-f364d5391b08e9071cd965f5765385ec9156b652.tar.gz
IVGCVSW-6062 Rework the async threadpool
!android-nn-driver:5802

 * Extract the threadpool from LoadedNetwork/Runtime
 * Refactor the threadpool to handle multiple networks
 * Trim IAsyncExecutionCallback and add an InferenceId to AsyncExecutionCallback
 * Add AsyncCallbackManager class

Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Change-Id: I36aa2ad29c16bc10ee0706adfeb6b27f60012afb
Diffstat (limited to 'src')
-rw-r--r--src/armnn/AsyncExecutionCallback.cpp38
-rw-r--r--src/armnn/AsyncExecutionCallback.hpp50
-rw-r--r--src/armnn/LoadedNetwork.cpp159
-rw-r--r--src/armnn/LoadedNetwork.hpp42
-rw-r--r--src/armnn/Runtime.cpp40
-rw-r--r--src/armnn/Runtime.hpp10
-rw-r--r--src/armnn/Threadpool.cpp206
-rw-r--r--src/armnn/WorkingMemHandle.cpp3
-rw-r--r--src/armnn/WorkingMemHandle.hpp7
-rw-r--r--src/backends/backendsCommon/test/StridedSliceAsyncEndToEndTest.hpp31
10 files changed, 299 insertions, 287 deletions
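With this change a caller builds the pool and the callback bookkeeping explicitly, instead of passing a thread count into INetworkProperties. A minimal sketch of the reworked flow, assuming a runtime with an async-enabled network already loaded; runtime, networkId, inputTensors, outputTensors and numThreads are placeholders, the API calls are the ones introduced in this diff:

#include <armnn/Threadpool.hpp>
#include <AsyncExecutionCallback.hpp>
#include <cassert>

using namespace armnn;
using namespace armnn::experimental;

// One working memory handle per pool thread, all created for the same network.
std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles;
for (size_t i = 0; i < numThreads; ++i)
{
    memHandles.emplace_back(runtime->CreateWorkingMemHandle(networkId));
}

Threadpool threadpool(numThreads, runtime.get(), memHandles);
AsyncCallbackManager callbackManager;

// Schedule returns immediately; a worker thread notifies the callback on completion.
threadpool.Schedule(networkId, inputTensors, outputTensors,
                    QosExecPriority::Medium, callbackManager.GetNewCallback());

// Blocks until some inference has finished; callbacks come back in completion order.
auto cb = callbackManager.GetNotifiedCallback();
assert(cb->GetStatus() == Status::Success);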
diff --git a/src/armnn/AsyncExecutionCallback.cpp b/src/armnn/AsyncExecutionCallback.cpp
index c44808918d..2973e2d891 100644
--- a/src/armnn/AsyncExecutionCallback.cpp
+++ b/src/armnn/AsyncExecutionCallback.cpp
@@ -15,43 +15,53 @@ void AsyncExecutionCallback::Notify(armnn::Status status, InferenceTimingPair ti
{
{
std::lock_guard<std::mutex> hold(m_Mutex);
- if (m_Notified)
- {
- return;
- }
// store results and mark as notified
m_Status = status;
m_StartTime = timeTaken.first;
m_EndTime = timeTaken.second;
- m_Notified = true;
+ m_NotificationQueue.push(m_InferenceId);
}
m_Condition.notify_all();
}
-void AsyncExecutionCallback::Wait() const
-{
- std::unique_lock<std::mutex> lock(m_Mutex);
- m_Condition.wait(lock, [this] { return m_Notified; });
-}
-
armnn::Status AsyncExecutionCallback::GetStatus() const
{
- Wait();
return m_Status;
}
HighResolutionClock AsyncExecutionCallback::GetStartTime() const
{
- Wait();
return m_StartTime;
}
HighResolutionClock AsyncExecutionCallback::GetEndTime() const
{
- Wait();
return m_EndTime;
}
+std::shared_ptr<AsyncExecutionCallback> AsyncCallbackManager::GetNewCallback()
+{
+ auto cb = std::make_unique<AsyncExecutionCallback>(m_NotificationQueue, m_Mutex, m_Condition);
+ InferenceId id = cb->GetInferenceId();
+ m_Callbacks.insert({id, std::move(cb)});
+
+ return m_Callbacks.at(id);
+}
+
+std::shared_ptr<AsyncExecutionCallback> AsyncCallbackManager::GetNotifiedCallback()
+{
+ std::unique_lock<std::mutex> lock(m_Mutex);
+
+ m_Condition.wait(lock, [this] { return !m_NotificationQueue.empty(); });
+
+ InferenceId id = m_NotificationQueue.front();
+ m_NotificationQueue.pop();
+
+ std::shared_ptr<AsyncExecutionCallback> callback = m_Callbacks.at(id);
+ m_Callbacks.erase(id);
+ return callback;
+}
+
} // namespace experimental
} // namespace armnn
\ No newline at end of file
diff --git a/src/armnn/AsyncExecutionCallback.hpp b/src/armnn/AsyncExecutionCallback.hpp
index c17b839748..2ff73b3efb 100644
--- a/src/armnn/AsyncExecutionCallback.hpp
+++ b/src/armnn/AsyncExecutionCallback.hpp
@@ -6,11 +6,14 @@
#pragma once
#include <armnn/IAsyncExecutionCallback.hpp>
+#include <armnn/IWorkingMemHandle.hpp>
#include <armnn/Types.hpp>
-#include <condition_variable>
+#include <condition_variable>
#include <mutex>
#include <thread>
+#include <queue>
+#include <unordered_map>
namespace armnn
{
@@ -18,29 +21,62 @@ namespace armnn
namespace experimental
{
+using InferenceId = uint64_t;
class AsyncExecutionCallback final : public IAsyncExecutionCallback
{
+private:
+ static InferenceId nextID;
+
public:
- AsyncExecutionCallback()
+ AsyncExecutionCallback(std::queue<InferenceId>& notificationQueue,
+ std::mutex& mutex,
+ std::condition_variable& condition)
+ : m_NotificationQueue(notificationQueue)
+ , m_Mutex(mutex)
+ , m_Condition(condition)
+ , m_InferenceId(++nextID)
{}
+
~AsyncExecutionCallback()
{}
void Notify(armnn::Status status, InferenceTimingPair timeTaken);
- void Wait() const;
+
+ InferenceId GetInferenceId()
+ {
+ return m_InferenceId;
+ }
armnn::Status GetStatus() const;
HighResolutionClock GetStartTime() const;
HighResolutionClock GetEndTime() const;
private:
- mutable std::mutex m_Mutex;
- mutable std::condition_variable m_Condition;
+ std::queue<InferenceId>& m_NotificationQueue;
+ std::mutex& m_Mutex;
+ std::condition_variable& m_Condition;
HighResolutionClock m_StartTime;
HighResolutionClock m_EndTime;
- armnn::Status m_Status = Status::Failure;
- bool m_Notified = false;
+ armnn::Status m_Status = Status::Failure;
+ InferenceId m_InferenceId;
+};
+InferenceId AsyncExecutionCallback::nextID = 0u;
+
+// Manager to create and monitor AsyncExecutionCallbacks
+// GetNewCallback will create a callback for use in Threadpool::Schedule
+// GetNotifiedCallback will return the first callback to be notified (finished execution)
+class AsyncCallbackManager
+{
+public:
+ std::shared_ptr<AsyncExecutionCallback> GetNewCallback();
+ std::shared_ptr<AsyncExecutionCallback> GetNotifiedCallback();
+
+private:
+ std::mutex m_Mutex;
+ std::condition_variable m_Condition;
+ std::unordered_map<InferenceId, std::shared_ptr<AsyncExecutionCallback>> m_Callbacks;
+ std::queue<InferenceId> m_NotificationQueue;
};
} // namespace experimental
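Note that GetStatus(), GetStartTime() and GetEndTime() no longer block: the Wait() they used to call has been removed, so their results are only meaningful once the callback has been handed back by AsyncCallbackManager::GetNotifiedCallback(). A single-threaded sketch of the handshake, with the Notify() call standing in for what a Threadpool worker does after an inference:

#include <AsyncExecutionCallback.hpp>
#include <armnn/utility/Timer.hpp>
#include <cassert>

using namespace armnn::experimental;

AsyncCallbackManager manager;
std::shared_ptr<AsyncExecutionCallback> cb = manager.GetNewCallback();

// Notify() stores the status and timings, then pushes the callback's
// InferenceId onto the manager's shared notification queue and wakes waiters.
cb->Notify(armnn::Status::Success,
           std::make_pair(armnn::GetTimeNow(), armnn::GetTimeNow()));

// Blocks until the notification queue is non-empty, then returns the finished
// callback and drops it from the manager's map.
std::shared_ptr<AsyncExecutionCallback> done = manager.GetNotifiedCallback();
assert(done->GetStatus() == armnn::Status::Success);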
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp
index 17cc8dcc23..13beb13a07 100644
--- a/src/armnn/LoadedNetwork.cpp
+++ b/src/armnn/LoadedNetwork.cpp
@@ -6,7 +6,6 @@
#include "LoadedNetwork.hpp"
#include "Layer.hpp"
#include "Graph.hpp"
-#include "Network.hpp"
#include <Processes.hpp>
#include "Profiling.hpp"
#include "HeapProfiling.hpp"
@@ -22,7 +21,6 @@
#include <backendsCommon/MemSyncWorkload.hpp>
#include <fmt/format.h>
-#include <armnn/utility/Timer.hpp>
namespace armnn
{
@@ -83,8 +81,7 @@ void AddWorkloadStructure(std::unique_ptr<TimelineUtilityMethods>& timelineUtils
std::unique_ptr<LoadedNetwork> LoadedNetwork::MakeLoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
std::string& errorMessage,
const INetworkProperties& networkProperties,
- profiling::ProfilingService& profilingService,
- const NetworkId networkIdOut)
+ profiling::ProfilingService& profilingService)
{
std::unique_ptr<LoadedNetwork> loadedNetwork;
@@ -98,7 +95,7 @@ std::unique_ptr<LoadedNetwork> LoadedNetwork::MakeLoadedNetwork(std::unique_ptr<
try
{
- loadedNetwork.reset(new LoadedNetwork(std::move(net), networkProperties, profilingService, networkIdOut));
+ loadedNetwork.reset(new LoadedNetwork(std::move(net), networkProperties, profilingService));
}
catch (const armnn::RuntimeException& error)
{
@@ -118,11 +115,9 @@ std::unique_ptr<LoadedNetwork> LoadedNetwork::MakeLoadedNetwork(std::unique_ptr<
LoadedNetwork::LoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
const INetworkProperties& networkProperties,
- profiling::ProfilingService& profilingService,
- const NetworkId networkId) :
+ profiling::ProfilingService& profilingService) :
m_OptimizedNetwork(std::move(net)),
m_NetworkProperties(networkProperties),
- m_NetworkId(networkId),
m_TensorHandleFactoryRegistry(),
m_ProfilingService(profilingService)
{
@@ -304,13 +299,6 @@ LoadedNetwork::LoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
{
AllocateAndExecuteConstantWorkloads();
}
-
- // Create the thread pool which will have working memory handles assigned to each thread
- // Should occur last so all factories and constant layer tensor handles are created
- if (m_NetworkProperties.m_NumThreads > 0 && networkProperties.m_AsyncEnabled)
- {
- CreateThreadPool(m_NetworkProperties.m_NumThreads);
- }
}
void LoadedNetwork::AllocateAndExecuteConstantWorkloads()
@@ -856,147 +844,6 @@ bool LoadedNetwork::Execute(std::unique_ptr<TimelineUtilityMethods>& timelineUti
return success;
}
-void LoadedNetwork::CreateThreadPool(std::size_t numThreads)
-{
-
- for (auto i = 0u; i < numThreads; ++i)
- {
- std::unique_ptr<IWorkingMemHandle> workingMemHandle = CreateWorkingMemHandle(m_NetworkId);
- m_Threads.emplace_back(
- std::make_unique<std::thread>(
- &LoadedNetwork::ProcessExecPriorities,
- this,
- std::move(workingMemHandle)
- )
- );
- }
-}
-
-void LoadedNetwork::TerminateThreadPool() noexcept
-{
- {
- std::unique_lock<std::mutex> threadPoolLock(m_ThreadPoolMutex);
- m_TerminatePool = true;
- }
-
- m_ThreadPoolEvent.notify_all();
-
- for (auto &thread : m_Threads)
- {
- thread->join();
- }
-}
-
-void LoadedNetwork::Schedule(const InputTensors& inputTensors,
- const OutputTensors& outputTensors,
- const QosExecPriority priority,
- std::shared_ptr<IAsyncExecutionCallback> cb)
-{
- // Group execution parameters so that they can be easily added to the queue
- ExecutionTuple groupExecParams = std::make_tuple(inputTensors, outputTensors, cb);
- std::shared_ptr<ExecutionTuple> operation = make_shared<ExecutionTuple>(groupExecParams);
-
- // Add a message to the queue and notify the request thread
- std::unique_lock<std::mutex> lock(m_ThreadPoolMutex);
- switch (priority) {
- case QosExecPriority::High:
- m_HighPriorityQueue.push(operation);
- break;
- case QosExecPriority::Low:
- m_LowPriorityQueue.push(operation);
- break;
- case QosExecPriority::Medium:
- default:
- m_MediumPriorityQueue.push(operation);
- }
- m_ThreadPoolEvent.notify_one();
-}
-
-void LoadedNetwork::ProcessExecPriorities(std::unique_ptr<IWorkingMemHandle> workingMemHandle)
-{
- int expireRate = EXPIRE_RATE;
- int highPriorityCount = 0;
- int mediumPriorityCount = 0;
-
- IWorkingMemHandle& workingMemHandleRef = *workingMemHandle.get();
-
- while (true)
- {
- std::shared_ptr<ExecutionTuple> currentExecInProgress(nullptr);
- {
- // Wait for a message to be added to the queue
- // This is in a separate scope to minimise the lifetime of the lock
- std::unique_lock<std::mutex> lock(m_ThreadPoolMutex);
-
- m_ThreadPoolEvent.wait(lock,
- [=] {
- return m_TerminatePool || !m_HighPriorityQueue.empty() ||
- !m_MediumPriorityQueue.empty() || !m_LowPriorityQueue.empty();
- });
-
- if (m_TerminatePool && m_HighPriorityQueue.empty() && m_MediumPriorityQueue.empty() &&
- m_LowPriorityQueue.empty())
- {
- break;
- }
-
- // Get the message to process from the front of each queue based on priority from high to low
- // Get high priority first if it does not exceed the expire rate
- if (!m_HighPriorityQueue.empty() && highPriorityCount < expireRate)
- {
- currentExecInProgress = m_HighPriorityQueue.front();
- m_HighPriorityQueue.pop();
- highPriorityCount += 1;
- }
- // If high priority queue is empty or the count exceeds the expire rate, get medium priority message
- else if (!m_MediumPriorityQueue.empty() && mediumPriorityCount < expireRate)
- {
- currentExecInProgress = m_MediumPriorityQueue.front();
- m_MediumPriorityQueue.pop();
- mediumPriorityCount += 1;
- // Reset high priority count
- highPriorityCount = 0;
- }
- // If medium priority queue is empty or the count exceeds the expire rate, get low priority message
- else if (!m_LowPriorityQueue.empty())
- {
- currentExecInProgress = m_LowPriorityQueue.front();
- m_LowPriorityQueue.pop();
- // Reset high and medium priority count
- highPriorityCount = 0;
- mediumPriorityCount = 0;
- }
- else
- {
- // Reset high and medium priority count
- highPriorityCount = 0;
- mediumPriorityCount = 0;
- continue;
- }
- }
-
- // invoke the asynchronous execution method
- auto inputTensors = std::get<0>(*currentExecInProgress);
- auto outputTensors = std::get<1>(*currentExecInProgress);
- auto cb = std::get<2>(*currentExecInProgress);
-
- // Get time at start of inference
- HighResolutionClock startTime = armnn::GetTimeNow();
-
- try // executing the inference
- {
- // Execute and populate the time at end of inference in the callback
- Execute(inputTensors, outputTensors, workingMemHandleRef) == Status::Success ?
- cb->Notify(Status::Success, std::make_pair(startTime, armnn::GetTimeNow())) :
- cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow()));
- }
- catch (const RuntimeException& error)
- {
- cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow()));
- }
- }
-}
-
void LoadedNetwork::EnqueueInput(const BindableLayer& layer,
const ConstTensor& inputTensor,
WorkingMemHandle& context)
diff --git a/src/armnn/LoadedNetwork.hpp b/src/armnn/LoadedNetwork.hpp
index c85e82bbdd..360ad91170 100644
--- a/src/armnn/LoadedNetwork.hpp
+++ b/src/armnn/LoadedNetwork.hpp
@@ -37,16 +37,9 @@ class LoadedNetwork
public:
using WorkloadQueue = std::vector<std::unique_ptr<IWorkload>>;
- using ExecutionTuple = std::tuple<InputTensors,
- OutputTensors,
- std::shared_ptr<IAsyncExecutionCallback>>;
-
- using ExecutionQueue = std::queue<std::shared_ptr<ExecutionTuple>>;
-
~LoadedNetwork()
{
FreeWorkingMemory();
- TerminateThreadPool();
}
/// Create a new unique WorkingMemHandle object. Create multiple handles if you wish to have
@@ -64,17 +57,10 @@ public:
const OutputTensors& outputTensors,
IWorkingMemHandle& workingMemHandle);
- /// Schedule an asynchronous execution on the loaded network
- void Schedule(const InputTensors& inputTensors,
- const OutputTensors& outputTensors,
- const QosExecPriority priority,
- std::shared_ptr<IAsyncExecutionCallback> cb);
-
static std::unique_ptr<LoadedNetwork> MakeLoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
std::string& errorMessage,
const INetworkProperties& networkProperties,
- profiling::ProfilingService& profilingService,
- const NetworkId networkIdOut);
+ profiling::ProfilingService& profilingService);
// NOTE we return by reference as the purpose of this method is only to provide
// access to the private m_Profiler and in theory we should not need to increment
@@ -108,8 +94,7 @@ private:
LoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
const INetworkProperties& networkProperties,
- profiling::ProfilingService& profilingService,
- const NetworkId networkIdOut);
+ profiling::ProfilingService& profilingService);
void EnqueueInput(const BindableLayer& layer, ITensorHandle* tensorHandle, const TensorInfo& tensorInfo);
@@ -119,15 +104,9 @@ private:
void EnqueueOutput(const BindableLayer& layer, const Tensor& outputTensor, WorkingMemHandle& handle);
- void ProcessExecPriorities(std::unique_ptr<IWorkingMemHandle> workingMemHandle);
-
bool Execute(std::unique_ptr<profiling::TimelineUtilityMethods>& timelineUtils,
profiling::ProfilingGuid inferenceGuid);
- void CreateThreadPool(std::size_t numThreads);
-
- void TerminateThreadPool() noexcept;
-
const IWorkloadFactory& GetWorkloadFactory(const Layer& layer) const;
using BackendPtrMap = std::unordered_map<BackendId, IBackendInternalUniquePtr>;
@@ -146,25 +125,8 @@ private:
bool m_IsWorkingMemAllocated = false;
- std::vector<std::unique_ptr<std::thread>> m_Threads;
- std::stack<IWorkingMemHandle> m_WorkingMemHandles;
-
- ExecutionQueue m_HighPriorityQueue;
- ExecutionQueue m_MediumPriorityQueue;
- ExecutionQueue m_LowPriorityQueue;
-
- // Condition Variables require mutex which will guard the shared state.
- // Has an event happened? Stop signal for example
- std::condition_variable m_ThreadPoolEvent;
- std::mutex m_ThreadPoolMutex;
-
- // The shared state for conditional variable
- bool m_TerminatePool = false;
-
INetworkProperties m_NetworkProperties;
- const NetworkId m_NetworkId;
-
TensorHandleFactoryRegistry m_TensorHandleFactoryRegistry;
profiling::ProfilingService& m_ProfilingService;
diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp
index 374064e408..f16d186191 100644
--- a/src/armnn/Runtime.cpp
+++ b/src/armnn/Runtime.cpp
@@ -89,15 +89,6 @@ Status IRuntime::Execute(IWorkingMemHandle& workingMemHandle,
return pRuntimeImpl->Execute(workingMemHandle, inputTensors, outputTensors);
}
-void IRuntime::Schedule(NetworkId networkId,
- const InputTensors& inputTensors,
- const OutputTensors& outputTensors,
- const QosExecPriority priority,
- std::shared_ptr<IAsyncExecutionCallback> cb)
-{
- pRuntimeImpl->Schedule(networkId, inputTensors, outputTensors, priority, cb);
-}
-
Status IRuntime::UnloadNetwork(NetworkId networkId)
{
return pRuntimeImpl->UnloadNetwork(networkId);
@@ -160,8 +151,7 @@ Status RuntimeImpl::LoadNetwork(NetworkId& networkIdOut,
std::unique_ptr<IOptimizedNetwork>(rawNetwork),
errorMessage,
networkProperties,
- m_ProfilingService,
- networkIdOut);
+ m_ProfilingService);
if (!loadedNetwork)
{
@@ -460,34 +450,6 @@ Status RuntimeImpl::Execute(IWorkingMemHandle& iWorkingMemHandle,
return loadedNetwork->Execute(inputTensors, outputTensors, iWorkingMemHandle);
}
-void RuntimeImpl::Schedule(NetworkId networkId,
- const InputTensors& inputTensors,
- const OutputTensors& outputTensors,
- const QosExecPriority priority,
- std::shared_ptr<IAsyncExecutionCallback> callback)
-{
- LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
-
- if (!loadedNetwork)
- {
- throw armnn::Exception(
- "Network with ID of " + std::to_string(networkId) + " does not exist \n"
- );
- }
- if (!loadedNetwork->IsAsyncEnabled())
- {
- throw armnn::Exception(
- "Attempting to schedule Network " + std::to_string(networkId) + " when it is not async enabled \n"
- );
- }
-
- ProfilerManager::GetInstance().RegisterProfiler(loadedNetwork->GetProfiler().get());
-
- ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "Schedule");
-
- loadedNetwork->Schedule(inputTensors, outputTensors, priority, callback);
-}
-
/// Create a new unique WorkingMemHandle object. Create multiple handles if you wish to have
/// overlapped Execution by calling this function from different threads.
std::unique_ptr<IWorkingMemHandle> RuntimeImpl::CreateWorkingMemHandle(NetworkId networkId)
diff --git a/src/armnn/Runtime.hpp b/src/armnn/Runtime.hpp
index 55a4accf67..7a80acd73e 100644
--- a/src/armnn/Runtime.hpp
+++ b/src/armnn/Runtime.hpp
@@ -61,16 +61,6 @@ public:
const OutputTensors& outputTensors);
/// This is an experimental function.
- /// Schedule a thread safe execution by taking the input tensors and an execution priority for Quality of Service.
- /// The output tensors will then be filled and the callback object will notify that the execution has either
- /// succeeded or failed.
- void Schedule(NetworkId networkId,
- const InputTensors& inputTensors,
- const OutputTensors& outputTensors,
- const QosExecPriority priority,
- std::shared_ptr<IAsyncExecutionCallback> callback);
-
- /// This is an experimental function.
/// Evaluates a network using input in inputTensors and outputs filled into outputTensors.
/// This function performs a thread safe execution of the network. Returns once execution is complete.
/// Will block until this and any other thread using the same workingMem object completes.
diff --git a/src/armnn/Threadpool.cpp b/src/armnn/Threadpool.cpp
new file mode 100644
index 0000000000..a23c1e2339
--- /dev/null
+++ b/src/armnn/Threadpool.cpp
@@ -0,0 +1,206 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+
+#include <armnn/Threadpool.hpp>
+
+#include <armnn/utility/Timer.hpp>
+
+namespace armnn
+{
+namespace experimental
+{
+
+Threadpool::Threadpool(std::size_t numThreads,
+ IRuntime* runtimePtr,
+ std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles)
+ : m_RuntimePtr(runtimePtr)
+{
+ for (auto i = 0u; i < numThreads; ++i)
+ {
+ m_Threads.emplace_back(std::make_unique<std::thread>(&Threadpool::ProcessExecPriorities, this, i));
+ }
+
+ LoadMemHandles(memHandles);
+}
+
+void Threadpool::LoadMemHandles(std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles)
+{
+ if (memHandles.size() == 0)
+ {
+ throw armnn::RuntimeException("Threadpool::LoadMemHandles: Size of memHandles vector must be greater than 0");
+ }
+
+ if (memHandles.size() != m_Threads.size())
+ {
+ throw armnn::RuntimeException(
+ "Threadpool::UnloadMemHandles: Size of memHandles vector must match the number of threads");
+ }
+
+ NetworkId networkId = memHandles[0]->GetNetworkId();
+ for (uint32_t i = 1; i < memHandles.size(); ++i)
+ {
+ if (networkId != memHandles[i]->GetNetworkId())
+ {
+ throw armnn::RuntimeException(
+ "Threadpool::UnloadMemHandles: All network ids must be identical in memHandles");
+ }
+ }
+
+ std::pair<NetworkId, std::vector<std::shared_ptr<IWorkingMemHandle>>> pair {networkId, memHandles};
+
+ m_WorkingMemHandleMap.insert(pair);
+}
+
+void Threadpool::UnloadMemHandles(NetworkId networkId)
+{
+ if (m_WorkingMemHandleMap.find(networkId) != m_WorkingMemHandleMap.end())
+ {
+ m_WorkingMemHandleMap.erase(networkId);
+ }
+ else
+ {
+ throw armnn::RuntimeException("Threadpool::UnloadMemHandles: Unknown NetworkId");
+ }
+}
+
+void Threadpool::Schedule(NetworkId networkId,
+ const InputTensors& inputTensors,
+ const OutputTensors& outputTensors,
+ const QosExecPriority priority,
+ std::shared_ptr<IAsyncExecutionCallback> cb)
+{
+ if (m_WorkingMemHandleMap.find(networkId) == m_WorkingMemHandleMap.end())
+ {
+ throw armnn::RuntimeException("Threadpool::Schedule: Unknown NetworkId");
+ }
+
+ // Group execution parameters so that they can be easily added to the queue
+ ExecutionTuple groupExecParams = std::make_tuple(networkId, inputTensors, outputTensors, cb);
+
+ std::shared_ptr<ExecutionTuple> operation = std::make_shared<ExecutionTuple>(groupExecParams);
+
+ // Add a message to the queue and notify the request thread
+ std::unique_lock<std::mutex> lock(m_ThreadPoolMutex);
+ switch (priority)
+ {
+ case QosExecPriority::High:
+ m_HighPriorityQueue.push(operation);
+ break;
+ case QosExecPriority::Low:
+ m_LowPriorityQueue.push(operation);
+ break;
+ case QosExecPriority::Medium:
+ default:
+ m_MediumPriorityQueue.push(operation);
+ }
+ m_ThreadPoolEvent.notify_one();
+}
+
+void Threadpool::TerminateThreadPool() noexcept
+{
+ {
+ std::unique_lock<std::mutex> threadPoolLock(m_ThreadPoolMutex);
+ m_TerminatePool = true;
+ }
+
+ m_ThreadPoolEvent.notify_all();
+
+ for (auto &thread : m_Threads)
+ {
+ thread->join();
+ }
+}
+
+void Threadpool::ProcessExecPriorities(uint32_t index)
+{
+ int expireRate = EXPIRE_RATE;
+ int highPriorityCount = 0;
+ int mediumPriorityCount = 0;
+
+ while (true)
+ {
+ std::shared_ptr<ExecutionTuple> currentExecInProgress(nullptr);
+ {
+ // Wait for a message to be added to the queue
+ // This is in a separate scope to minimise the lifetime of the lock
+ std::unique_lock<std::mutex> lock(m_ThreadPoolMutex);
+
+ m_ThreadPoolEvent.wait(lock,
+ [=]
+ {
+ return m_TerminatePool || !m_HighPriorityQueue.empty() ||
+ !m_MediumPriorityQueue.empty() || !m_LowPriorityQueue.empty();
+ });
+
+ if (m_TerminatePool && m_HighPriorityQueue.empty() && m_MediumPriorityQueue.empty() &&
+ m_LowPriorityQueue.empty())
+ {
+ break;
+ }
+
+ // Get the message to process from the front of each queue based on priority from high to low
+ // Get high priority first if it does not exceed the expire rate
+ if (!m_HighPriorityQueue.empty() && highPriorityCount < expireRate)
+ {
+ currentExecInProgress = m_HighPriorityQueue.front();
+ m_HighPriorityQueue.pop();
+ highPriorityCount += 1;
+ }
+ // If high priority queue is empty or the count exceeds the expire rate, get medium priority message
+ else if (!m_MediumPriorityQueue.empty() && mediumPriorityCount < expireRate)
+ {
+ currentExecInProgress = m_MediumPriorityQueue.front();
+ m_MediumPriorityQueue.pop();
+ mediumPriorityCount += 1;
+ // Reset high priority count
+ highPriorityCount = 0;
+ }
+ // If medium priority queue is empty or the count exceeds the expire rate, get low priority message
+ else if (!m_LowPriorityQueue.empty())
+ {
+ currentExecInProgress = m_LowPriorityQueue.front();
+ m_LowPriorityQueue.pop();
+ // Reset high and medium priority count
+ highPriorityCount = 0;
+ mediumPriorityCount = 0;
+ }
+ else
+ {
+ // Reset high and medium priority count
+ highPriorityCount = 0;
+ mediumPriorityCount = 0;
+ continue;
+ }
+ }
+
+ // invoke the asynchronous execution method
+ auto networkId = std::get<0>(*currentExecInProgress);
+ auto inputTensors = std::get<1>(*currentExecInProgress);
+ auto outputTensors = std::get<2>(*currentExecInProgress);
+ auto cb = std::get<3>(*currentExecInProgress);
+
+ // Get time at start of inference
+ HighResolutionClock startTime = armnn::GetTimeNow();
+
+ try // executing the inference
+ {
+ IWorkingMemHandle& memHandle = *(m_WorkingMemHandleMap.at(networkId))[index];
+
+ // Execute and populate the time at end of inference in the callback
+ m_RuntimePtr->Execute(memHandle, inputTensors, outputTensors) == Status::Success ?
+ cb->Notify(Status::Success, std::make_pair(startTime, armnn::GetTimeNow())) :
+ cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow()));
+ }
+ catch (const RuntimeException &error)
+ {
+ cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow()));
+ }
+ }
+}
+
+} // namespace experimental
+
+} // namespace armnn
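The scheduling logic carried over from LoadedNetwork is worth spelling out: each worker drains the high-priority queue first, but once it has run EXPIRE_RATE consecutive high-priority inferences it services a medium-priority one, and low-priority work runs when the other queues are empty or expired, so no priority class starves. A hedged usage sketch, assuming the threadpool and callbackManager set up as in the sketch after the diffstat; networkId and the tensors are placeholders:

// Mixed priorities, as in the end-to-end test below; High jobs overtake
// Medium/Low but cannot monopolise a worker beyond EXPIRE_RATE in a row.
for (size_t i = 0; i < 1000; ++i)
{
    threadpool.Schedule(networkId, inputTensors, outputTensors,
                        static_cast<QosExecPriority>(rand() % 3),
                        callbackManager.GetNewCallback());
}

// Drain completions; arrival order reflects when inferences finished,
// not the order in which they were scheduled.
for (size_t i = 0; i < 1000; ++i)
{
    auto cb = callbackManager.GetNotifiedCallback();
    assert(cb->GetStatus() == armnn::Status::Success);
}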
diff --git a/src/armnn/WorkingMemHandle.cpp b/src/armnn/WorkingMemHandle.cpp
index 94d796eced..1dcaa3853a 100644
--- a/src/armnn/WorkingMemHandle.cpp
+++ b/src/armnn/WorkingMemHandle.cpp
@@ -26,8 +26,7 @@ WorkingMemHandle::WorkingMemHandle(
m_MemoryManagers(memoryManagers),
m_OwnedTensorHandles(std::move(ownedTensorHandles)),
m_IsAllocated(false),
- m_Mutex(),
- m_InferenceId(profiling::ProfilingService::GetNextGuid())
+ m_Mutex()
{
}
diff --git a/src/armnn/WorkingMemHandle.hpp b/src/armnn/WorkingMemHandle.hpp
index 5ccb2b2342..5e3fd66299 100644
--- a/src/armnn/WorkingMemHandle.hpp
+++ b/src/armnn/WorkingMemHandle.hpp
@@ -13,6 +13,7 @@
#include <armnn/Tensor.hpp>
#include <unordered_map>
+#include <mutex>
namespace armnn
{
@@ -38,10 +39,7 @@ public:
return m_NetworkId;
}
- profiling::ProfilingGuid GetInferenceId() override
- {
- return m_InferenceId;
- }
+
/// Allocate the backing memory required for execution. If this is not called, then allocation will be
/// deferred to execution time. The mutex must be locked.
@@ -92,7 +90,6 @@ private:
bool m_IsAllocated;
std::mutex m_Mutex;
- profiling::ProfilingGuid m_InferenceId;
};
} // end experimental namespace
diff --git a/src/backends/backendsCommon/test/StridedSliceAsyncEndToEndTest.hpp b/src/backends/backendsCommon/test/StridedSliceAsyncEndToEndTest.hpp
index a552a6a7da..764983f3b9 100644
--- a/src/backends/backendsCommon/test/StridedSliceAsyncEndToEndTest.hpp
+++ b/src/backends/backendsCommon/test/StridedSliceAsyncEndToEndTest.hpp
@@ -9,6 +9,7 @@
#include <armnn/IWorkingMemHandle.hpp>
#include <armnn/INetwork.hpp>
+#include <armnn/Threadpool.hpp>
#include <armnn/IAsyncExecutionCallback.hpp>
#include <AsyncExecutionCallback.hpp>
@@ -137,7 +138,7 @@ void AsyncEndToEndTestImpl(INetworkPtr network,
std::string errorMessage;
- const INetworkProperties networkProperties(true, MemorySource::Undefined, MemorySource::Undefined, numThreads);
+ const INetworkProperties networkProperties(true, MemorySource::Undefined, MemorySource::Undefined);
runtime->LoadNetwork(networkId, std::move(optNet), errorMessage, networkProperties);
@@ -172,30 +173,32 @@ void AsyncEndToEndTestImpl(INetworkPtr network,
}
else
{
- std::vector<IAsyncExecutionCallbackPtr> callbacks;
+ std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles;
- // Create 1000 callbacks that will be checked post scheduling
- for (size_t i = 0; i < 1000; ++i)
+ for (size_t i = 0; i < numThreads; ++i)
{
- callbacks.emplace_back(std::make_shared<AsyncExecutionCallback>());
+ memHandles.emplace_back(runtime->CreateWorkingMemHandle(networkId));
}
+ Threadpool threadpool(numThreads, runtime.get(), memHandles);
+ AsyncCallbackManager callbackManager;
+
// For the asynchronous execution, we are adding a pool of working memory handles (1 per thread) to the
// Threadpool, with each scheduled inference having a specific priority
- for (IAsyncExecutionCallbackPtr cb : callbacks)
+ for (size_t i = 0; i < 1000; ++i)
{
- runtime->Schedule(networkId,
- inputTensors,
- outputTensors,
- static_cast<QosExecPriority>(rand()%3),
- cb);
+ threadpool.Schedule(networkId,
+ inputTensors,
+ outputTensors,
+ static_cast<QosExecPriority>(rand()%3),
+ callbackManager.GetNewCallback());
}
// Wait until each execution has signalled completion
- for (IAsyncExecutionCallbackPtr cb : callbacks)
+ for (size_t i = 0; i < 1000; ++i)
{
- cb->Wait();
-
+ auto cb = callbackManager.GetNotifiedCallback();
+
// Checks the results.
CHECK(cb->GetStatus() == Status::Success);
}