aboutsummaryrefslogtreecommitdiff
path: root/src/armnn
diff options
context:
space:
mode:
authorKeith Davis <keith.davis@arm.com>2021-04-22 10:10:34 +0100
committerfinn.williams <finn.williams@arm.com>2021-05-06 19:39:39 +0000
commite813d67f86df41a238ff79b5c554ef5027f56576 (patch)
tree54c2145a78297d61e6e94676729cb2468490ade3 /src/armnn
parentd905decd256558bbee165e636ce4242ac3b9c917 (diff)
downloadarmnn-e813d67f86df41a238ff79b5c554ef5027f56576.tar.gz
IVGCVSW-5813 Add Async Queue to IRuntime
Signed-off-by: Keith Davis <keith.davis@arm.com> Change-Id: Icc0d131c8ee2e9748e2f14762a75962b39c10f9d
Diffstat (limited to 'src/armnn')
-rw-r--r--src/armnn/AsyncExecutionCallback.cpp57
-rw-r--r--src/armnn/AsyncExecutionCallback.hpp48
-rw-r--r--src/armnn/LoadedNetwork.cpp160
-rw-r--r--src/armnn/LoadedNetwork.hpp75
-rw-r--r--src/armnn/Runtime.cpp48
-rw-r--r--src/armnn/Runtime.hpp14
6 files changed, 375 insertions, 27 deletions
diff --git a/src/armnn/AsyncExecutionCallback.cpp b/src/armnn/AsyncExecutionCallback.cpp
new file mode 100644
index 0000000000..c44808918d
--- /dev/null
+++ b/src/armnn/AsyncExecutionCallback.cpp
@@ -0,0 +1,57 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include <AsyncExecutionCallback.hpp>
+
+namespace armnn
+{
+
+namespace experimental
+{
+
+/// Records the outcome of an execution and wakes every thread blocked in Wait().
+/// Only the first call has an effect; later calls return without touching the stored result
+/// and without signalling the condition variable.
+void AsyncExecutionCallback::Notify(armnn::Status status, InferenceTimingPair timeTaken)
+{
+    std::unique_lock<std::mutex> guard(m_Mutex);
+    if (!m_Notified)
+    {
+        // Publish the result and flag it as available while the mutex is held
+        m_Status    = status;
+        m_StartTime = timeTaken.first;
+        m_EndTime   = timeTaken.second;
+        m_Notified  = true;
+
+        // Release the mutex before signalling, exactly as the scoped lock did
+        guard.unlock();
+        m_Condition.notify_all();
+    }
+}
+
+/// Blocks the calling thread until Notify() has marked this callback as notified.
+/// Returns immediately if the notification has already happened.
+void AsyncExecutionCallback::Wait() const
+{
+    std::unique_lock<std::mutex> guard(m_Mutex);
+    // Re-check the flag after every wake-up so a spurious wake-up cannot release the caller early
+    while (!m_Notified)
+    {
+        m_Condition.wait(guard);
+    }
+}
+
+/// Waits until Notify() has been called, then returns the status it stored.
+armnn::Status AsyncExecutionCallback::GetStatus() const
+{
+    this->Wait();
+    const armnn::Status storedStatus = m_Status;
+    return storedStatus;
+}
+
+/// Waits until Notify() has been called, then returns the first element of the timing pair it stored.
+HighResolutionClock AsyncExecutionCallback::GetStartTime() const
+{
+    this->Wait();
+    const HighResolutionClock storedStart = m_StartTime;
+    return storedStart;
+}
+
+/// Waits until Notify() has been called, then returns the second element of the timing pair it stored.
+HighResolutionClock AsyncExecutionCallback::GetEndTime() const
+{
+    this->Wait();
+    const HighResolutionClock storedEnd = m_EndTime;
+    return storedEnd;
+}
+
+} // namespace experimental
+
+} // namespace armnn \ No newline at end of file
diff --git a/src/armnn/AsyncExecutionCallback.hpp b/src/armnn/AsyncExecutionCallback.hpp
new file mode 100644
index 0000000000..c17b839748
--- /dev/null
+++ b/src/armnn/AsyncExecutionCallback.hpp
@@ -0,0 +1,48 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <armnn/IAsyncExecutionCallback.hpp>
+#include <armnn/Types.hpp>
+#include <condition_variable>
+
+#include <mutex>
+#include <thread>
+
+namespace armnn
+{
+
+namespace experimental
+{
+
+/// Callback object that stores the status and timing of a single execution and lets other
+/// threads block until that result has been published through Notify().
+class AsyncExecutionCallback final : public IAsyncExecutionCallback
+{
+public:
+    AsyncExecutionCallback() = default;
+    ~AsyncExecutionCallback() = default;
+
+    /// Stores the status and timing pair and wakes all waiters; only the first call takes effect.
+    void Notify(armnn::Status status, InferenceTimingPair timeTaken);
+
+    /// Blocks until Notify() has been called.
+    void Wait() const;
+
+    /// Each getter waits for the notification before returning the stored value.
+    armnn::Status GetStatus() const;
+    HighResolutionClock GetStartTime() const;
+    HighResolutionClock GetEndTime() const;
+
+private:
+    // Mutable so that the const Wait() can lock and block on them
+    mutable std::mutex              m_Mutex;
+    mutable std::condition_variable m_Condition;
+
+    // Result of the execution; written once by Notify()
+    HighResolutionClock m_StartTime;
+    HighResolutionClock m_EndTime;
+    armnn::Status       m_Status = Status::Failure;
+
+    // True once Notify() has stored a result
+    bool m_Notified = false;
+};
+
+} // namespace experimental
+
+} // namespace armnn \ No newline at end of file
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp
index 46eb9883fb..67de00f0f3 100644
--- a/src/armnn/LoadedNetwork.cpp
+++ b/src/armnn/LoadedNetwork.cpp
@@ -24,6 +24,7 @@
#include <LabelsAndEventClasses.hpp>
#include <fmt/format.h>
+#include <armnn/utility/Timer.hpp>
namespace armnn
{
@@ -84,7 +85,8 @@ void AddWorkloadStructure(std::unique_ptr<TimelineUtilityMethods>& timelineUtils
std::unique_ptr<LoadedNetwork> LoadedNetwork::MakeLoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
std::string& errorMessage,
const INetworkProperties& networkProperties,
- profiling::ProfilingService& profilingService)
+ profiling::ProfilingService& profilingService,
+ const NetworkId networkIdOut)
{
std::unique_ptr<LoadedNetwork> loadedNetwork;
@@ -98,7 +100,7 @@ std::unique_ptr<LoadedNetwork> LoadedNetwork::MakeLoadedNetwork(std::unique_ptr<
try
{
- loadedNetwork.reset(new LoadedNetwork(std::move(net), networkProperties, profilingService));
+ loadedNetwork.reset(new LoadedNetwork(std::move(net), networkProperties, profilingService, networkIdOut));
}
catch (const armnn::RuntimeException& error)
{
@@ -118,9 +120,11 @@ std::unique_ptr<LoadedNetwork> LoadedNetwork::MakeLoadedNetwork(std::unique_ptr<
LoadedNetwork::LoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
const INetworkProperties& networkProperties,
- profiling::ProfilingService& profilingService) :
+ profiling::ProfilingService& profilingService,
+ const NetworkId networkId) :
m_OptimizedNetwork(std::move(net)),
m_NetworkProperties(networkProperties),
+ m_NetworkId(networkId),
m_TensorHandleFactoryRegistry(),
m_ProfilingService(profilingService)
{
@@ -161,6 +165,14 @@ LoadedNetwork::LoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
}
}
}
+
+ // Create the thread pool which will have working memory handles assigned to each thread
+ // Should occur after factories are registered so that the WorkingMemHandles can be created
+ if (m_NetworkProperties.m_NumThreads > 0 && networkProperties.m_AsyncEnabled)
+ {
+ CreateThreadPool(m_NetworkProperties.m_NumThreads);
+ }
+
if (!networkProperties.m_AsyncEnabled)
{
for (auto &&layer : order)
@@ -846,6 +858,147 @@ bool LoadedNetwork::Execute(std::unique_ptr<TimelineUtilityMethods>& timelineUti
return success;
}
+/// Starts numThreads worker threads. Each worker runs ProcessExecPriorities() and owns a
+/// working memory handle obtained from CreateWorkingMemHandle(m_NetworkId).
+///
+/// If creating a handle or starting a thread throws, the workers that are already running are
+/// stopped and joined before the exception is re-thrown. Without this, the joinable std::thread
+/// objects held in m_Threads would be destroyed during unwinding, which calls std::terminate.
+void LoadedNetwork::CreateThreadPool(std::size_t numThreads)
+{
+    // Reserve up front so emplace_back cannot throw after a thread has already been started
+    m_Threads.reserve(m_Threads.size() + numThreads);
+
+    try
+    {
+        for (std::size_t i = 0; i < numThreads; ++i)
+        {
+            std::unique_ptr<IWorkingMemHandle> workingMemHandle = CreateWorkingMemHandle(m_NetworkId);
+            m_Threads.emplace_back(
+                std::make_unique<std::thread>(
+                    &LoadedNetwork::ProcessExecPriorities,
+                    this,
+                    std::move(workingMemHandle)));
+        }
+    }
+    catch (...)
+    {
+        // Stop and join whatever was started, then let the caller see the original error
+        TerminateThreadPool();
+        throw;
+    }
+}
+
+/// Asks every worker thread to stop and waits for them to finish.
+/// Workers leave their loop only once the termination flag is set and all three priority
+/// queues are empty, so requests that are already queued are still processed.
+///
+/// Safe to call more than once: threads that were already joined are skipped. This matters
+/// because the function is noexcept and std::thread::join() throws on a non-joinable thread,
+/// which would otherwise end the process.
+void LoadedNetwork::TerminateThreadPool() noexcept
+{
+    {
+        // The flag is part of the wait predicate, so it must be changed under the mutex
+        std::lock_guard<std::mutex> threadPoolLock(m_ThreadPoolMutex);
+        m_TerminatePool = true;
+    }
+
+    // Wake every worker so each one can observe the flag
+    m_ThreadPoolEvent.notify_all();
+
+    for (auto& thread : m_Threads)
+    {
+        if (thread && thread->joinable())
+        {
+            thread->join();
+        }
+    }
+}
+
+/// Queues an execution request for the worker threads.
+///
+/// @param inputTensors  Input tensors; copied into the queued request.
+/// @param outputTensors Output tensors; copied into the queued request.
+/// @param priority      Selects the queue. High and Low go to their own queues; Medium and
+///                      any other value go to the medium queue.
+/// @param cb            Callback stored with the request for the worker that executes it.
+///                      NOTE(review): a null cb is not rejected here - confirm the consumer
+///                      of the queue tolerates it.
+void LoadedNetwork::Schedule(const InputTensors& inputTensors,
+                             const OutputTensors& outputTensors,
+                             const QosExecPriority priority,
+                             std::shared_ptr<IAsyncExecutionCallback> cb)
+{
+    // Build the request directly inside the shared tuple so the tensor vectors are copied once,
+    // and do it before taking the lock to keep the critical section short
+    std::shared_ptr<ExecutionTuple> operation =
+        std::make_shared<ExecutionTuple>(inputTensors, outputTensors, std::move(cb));
+
+    {
+        std::lock_guard<std::mutex> lock(m_ThreadPoolMutex);
+        switch (priority)
+        {
+            case QosExecPriority::High:
+                m_HighPriorityQueue.push(std::move(operation));
+                break;
+            case QosExecPriority::Low:
+                m_LowPriorityQueue.push(std::move(operation));
+                break;
+            case QosExecPriority::Medium:
+            default:
+                m_MediumPriorityQueue.push(std::move(operation));
+                break;
+        }
+    }
+
+    // Signal after releasing the mutex so the woken worker does not immediately block on it
+    m_ThreadPoolEvent.notify_one();
+}
+
+/// Body of each worker thread: repeatedly takes one queued request, executes it with this
+/// thread's working memory handle and reports the result through the request's callback.
+///
+/// Queue selection: the high priority queue is served first, but after EXPIRE_RATE consecutive
+/// high priority requests a medium priority request is taken, and after EXPIRE_RATE medium
+/// priority requests a low priority request is taken, so lower priorities are not starved.
+///
+/// The loop ends once m_TerminatePool is set and all three queues are empty.
+///
+/// Any exception thrown while executing a request is reported as Status::Failure through the
+/// callback instead of leaving the thread, which would terminate the process and leave the
+/// caller waiting on the callback forever.
+///
+/// @param workingMemHandle Working memory handle owned by this worker for its whole lifetime.
+void LoadedNetwork::ProcessExecPriorities(std::unique_ptr<IWorkingMemHandle> workingMemHandle)
+{
+    const int expireRate = EXPIRE_RATE;
+    int highPriorityCount = 0;
+    int mediumPriorityCount = 0;
+
+    IWorkingMemHandle& workingMemHandleRef = *workingMemHandle;
+
+    while (true)
+    {
+        std::shared_ptr<ExecutionTuple> currentExecInProgress;
+        {
+            // Wait for a message to be added to the queue
+            // This is in a separate scope to minimise the lifetime of the lock
+            std::unique_lock<std::mutex> lock(m_ThreadPoolMutex);
+
+            m_ThreadPoolEvent.wait(lock,
+                                   [this]
+                                   {
+                                       return m_TerminatePool || !m_HighPriorityQueue.empty() ||
+                                              !m_MediumPriorityQueue.empty() || !m_LowPriorityQueue.empty();
+                                   });
+
+            // Only stop once everything that was queued has been processed
+            if (m_TerminatePool && m_HighPriorityQueue.empty() && m_MediumPriorityQueue.empty() &&
+                m_LowPriorityQueue.empty())
+            {
+                break;
+            }
+
+            // Get the message to process from the front of each queue based on priority from high to low
+            // Get high priority first if it does not exceed the expire rate
+            if (!m_HighPriorityQueue.empty() && highPriorityCount < expireRate)
+            {
+                currentExecInProgress = m_HighPriorityQueue.front();
+                m_HighPriorityQueue.pop();
+                highPriorityCount += 1;
+            }
+            // If high priority queue is empty or the count exceeds the expire rate, get medium priority message
+            else if (!m_MediumPriorityQueue.empty() && mediumPriorityCount < expireRate)
+            {
+                currentExecInProgress = m_MediumPriorityQueue.front();
+                m_MediumPriorityQueue.pop();
+                mediumPriorityCount += 1;
+                // Reset high priority count
+                highPriorityCount = 0;
+            }
+            // If medium priority queue is empty or the count exceeds the expire rate, get low priority message
+            else if (!m_LowPriorityQueue.empty())
+            {
+                currentExecInProgress = m_LowPriorityQueue.front();
+                m_LowPriorityQueue.pop();
+                // Reset high and medium priority count
+                highPriorityCount = 0;
+                mediumPriorityCount = 0;
+            }
+            else
+            {
+                // Nothing was eligible only because of the expire counters: reset them and retry
+                highPriorityCount = 0;
+                mediumPriorityCount = 0;
+                continue;
+            }
+        }
+
+        // Borrow the request's contents; currentExecInProgress keeps them alive for this iteration
+        const auto& inputTensors  = std::get<0>(*currentExecInProgress);
+        const auto& outputTensors = std::get<1>(*currentExecInProgress);
+        const auto& cb            = std::get<2>(*currentExecInProgress);
+
+        // Get time at start of inference
+        const HighResolutionClock startTime = armnn::GetTimeNow();
+
+        // Reports the result together with the start time and the time of the report
+        const auto notify = [&cb, &startTime](Status status)
+        {
+            if (cb)
+            {
+                cb->Notify(status, std::make_pair(startTime, armnn::GetTimeNow()));
+            }
+        };
+
+        try // executing the inference
+        {
+            const bool succeeded = Execute(inputTensors, outputTensors, workingMemHandleRef) == Status::Success;
+            notify(succeeded ? Status::Success : Status::Failure);
+        }
+        catch (...)
+        {
+            // A thread entry point must not leak exceptions; report the failure to the caller instead
+            notify(Status::Failure);
+        }
+    }
+}
+
void LoadedNetwork::EnqueueInput(const BindableLayer& layer,
const ConstTensor& inputTensor,
WorkingMemHandle& context)
@@ -1096,6 +1249,7 @@ Status LoadedNetwork::Execute(const InputTensors& inputTensors,
EnqueueOutput(*outputLayer, GetOutputTensor(outputLayer->GetBindingId(), outputTensors), workingMemHandle);
}
}
+
return executionSucceeded ? Status::Success : Status::Failure;
}
diff --git a/src/armnn/LoadedNetwork.hpp b/src/armnn/LoadedNetwork.hpp
index 51092c744e..b5474db294 100644
--- a/src/armnn/LoadedNetwork.hpp
+++ b/src/armnn/LoadedNetwork.hpp
@@ -19,13 +19,14 @@
#include <TimelineUtilityMethods.hpp>
#include <mutex>
+#include <condition_variable>
#include <unordered_map>
namespace cl
{
- class Context;
- class CommandQueue;
- class Device;
+class Context;
+class CommandQueue;
+class Device;
}
namespace armnn
@@ -34,8 +35,19 @@ namespace armnn
class LoadedNetwork
{
public:
- using WorkloadQueue = std::vector< std::unique_ptr<IWorkload> >;
- ~LoadedNetwork(){ FreeWorkingMemory(); }
+ using WorkloadQueue = std::vector<std::unique_ptr<IWorkload>>;
+
+ using ExecutionTuple = std::tuple<InputTensors,
+ OutputTensors,
+ std::shared_ptr<IAsyncExecutionCallback>>;
+
+ using ExecutionQueue = std::queue<std::shared_ptr<ExecutionTuple>>;
+
+ /// Stops and joins the worker threads before the working memory is released, so that no
+ /// worker can still be executing a queued request while memory is being freed.
+ ~LoadedNetwork()
+ {
+     TerminateThreadPool();
+     FreeWorkingMemory();
+ }
/// Create a new unique WorkingMemHandle object. Create multiple handles if you wish to have
/// overlapped Execution by calling this function from different threads.
@@ -44,16 +56,25 @@ public:
TensorInfo GetInputTensorInfo(LayerBindingId layerId) const;
TensorInfo GetOutputTensorInfo(LayerBindingId layerId) const;
+ /// Single thread execution of the loaded network
Status EnqueueWorkload(const InputTensors& inputTensors, const OutputTensors& outputTensors);
+ /// Thread safe execution of the loaded network
Status Execute(const InputTensors& inputTensors,
const OutputTensors& outputTensors,
IWorkingMemHandle& workingMemHandle);
+ /// Schedule an asynchronous execution on the loaded network
+ void Schedule(const InputTensors& inputTensors,
+ const OutputTensors& outputTensors,
+ const QosExecPriority priority,
+ std::shared_ptr<IAsyncExecutionCallback> cb);
+
static std::unique_ptr<LoadedNetwork> MakeLoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
- std::string & errorMessage,
+ std::string& errorMessage,
const INetworkProperties& networkProperties,
- profiling::ProfilingService& profilingService);
+ profiling::ProfilingService& profilingService,
+ const NetworkId networkIdOut);
// NOTE we return by reference as the purpose of this method is only to provide
// access to the private m_Profiler and in theory we should not need to increment
@@ -87,7 +108,8 @@ private:
LoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
const INetworkProperties& networkProperties,
- profiling::ProfilingService& profilingService);
+ profiling::ProfilingService& profilingService,
+ const NetworkId networkIdOut);
void EnqueueInput(const BindableLayer& layer, ITensorHandle* tensorHandle, const TensorInfo& tensorInfo);
@@ -97,9 +119,15 @@ private:
void EnqueueOutput(const BindableLayer& layer, const Tensor& outputTensor, WorkingMemHandle& handle);
+ void ProcessExecPriorities(std::unique_ptr<IWorkingMemHandle> workingMemHandle);
+
bool Execute(std::unique_ptr<profiling::TimelineUtilityMethods>& timelineUtils,
profiling::ProfilingGuid inferenceGuid);
+ void CreateThreadPool(std::size_t numThreads);
+
+ void TerminateThreadPool() noexcept;
+
const IWorkloadFactory& GetWorkloadFactory(const Layer& layer) const;
using BackendPtrMap = std::unordered_map<BackendId, IBackendInternalUniquePtr>;
@@ -108,19 +136,38 @@ private:
WorkloadFactoryMap m_WorkloadFactories;
std::unique_ptr<IOptimizedNetwork> m_OptimizedNetwork;
- WorkloadQueue m_InputQueue;
- WorkloadQueue m_WorkloadQueue;
- WorkloadQueue m_OutputQueue;
- std::shared_ptr<IProfiler> m_Profiler;
+ std::shared_ptr<IProfiler> m_Profiler;
+
+ WorkloadQueue m_InputQueue;
+ WorkloadQueue m_WorkloadQueue;
+ WorkloadQueue m_OutputQueue;
mutable std::mutex m_WorkingMemMutex;
- bool m_IsWorkingMemAllocated=false;
+ bool m_IsWorkingMemAllocated = false;
+
+ std::vector<std::unique_ptr<std::thread>> m_Threads;
+ std::stack<IWorkingMemHandle> m_WorkingMemHandles;
+
+ ExecutionQueue m_HighPriorityQueue;
+ ExecutionQueue m_MediumPriorityQueue;
+ ExecutionQueue m_LowPriorityQueue;
+
+ // Condition Variables require mutex which will guard the shared state.
+ // Has an event happened? Stop signal for example
+ std::condition_variable m_ThreadPoolEvent;
+ std::mutex m_ThreadPoolMutex;
+
+ // The shared state for conditional variable
+ bool m_TerminatePool = false;
+
INetworkProperties m_NetworkProperties;
+ const NetworkId m_NetworkId;
+
TensorHandleFactoryRegistry m_TensorHandleFactoryRegistry;
- profiling::ProfilingService& m_ProfilingService;
+ profiling::ProfilingService& m_ProfilingService;
};
}
diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp
index 1dd86a61ce..e04cf5ddaf 100644
--- a/src/armnn/Runtime.cpp
+++ b/src/armnn/Runtime.cpp
@@ -88,6 +88,15 @@ Status IRuntime::Execute(IWorkingMemHandle& workingMemHandle,
return pRuntimeImpl->Execute(workingMemHandle, inputTensors, outputTensors);
}
+/// Public entry point for asynchronous scheduling; forwards every argument unchanged to the
+/// runtime implementation object.
+void IRuntime::Schedule(NetworkId networkId,
+                        const InputTensors& inputTensors,
+                        const OutputTensors& outputTensors,
+                        const QosExecPriority priority,
+                        std::shared_ptr<IAsyncExecutionCallback> cb)
+{
+    auto& runtimeImpl = *pRuntimeImpl;
+    runtimeImpl.Schedule(networkId, inputTensors, outputTensors, priority, cb);
+}
+
Status IRuntime::UnloadNetwork(NetworkId networkId)
{
return pRuntimeImpl->UnloadNetwork(networkId);
@@ -150,7 +159,8 @@ Status RuntimeImpl::LoadNetwork(NetworkId& networkIdOut,
std::unique_ptr<IOptimizedNetwork>(rawNetwork),
errorMessage,
networkProperties,
- m_ProfilingService);
+ m_ProfilingService,
+ networkIdOut);
if (!loadedNetwork)
{
@@ -439,24 +449,42 @@ Status RuntimeImpl::Execute(IWorkingMemHandle& iWorkingMemHandle,
}
if (!loadedNetwork->IsAsyncEnabled())
{
- ARMNN_LOG(error) << "Network " << networkId << " is not async enabled.\n";
+ ARMNN_LOG(error) << "Attempting execute " << networkId << " when it is not async enabled.\n";
return Status::Failure;
}
ProfilerManager::GetInstance().RegisterProfiler(loadedNetwork->GetProfiler().get());
ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "Execute");
- static thread_local NetworkId lastId = networkId;
- if (lastId != networkId)
+ return loadedNetwork->Execute(inputTensors, outputTensors, iWorkingMemHandle);
+}
+
+/// Schedules an asynchronous execution on the loaded network identified by networkId.
+/// Throws armnn::Exception if no network is loaded under that id or if the network is not
+/// async enabled. Otherwise registers the network's profiler with the ProfilerManager, opens a
+/// "Schedule" profiling event and forwards the request to LoadedNetwork::Schedule.
+void RuntimeImpl::Schedule(NetworkId networkId,
+ const InputTensors& inputTensors,
+ const OutputTensors& outputTensors,
+ const QosExecPriority priority,
+ std::shared_ptr<IAsyncExecutionCallback> callback)
+{
+ LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
+
+ // A null pointer means no network is loaded under this id
+ if (!loadedNetwork)
{
- LoadedNetworkFuncSafe(lastId, [](LoadedNetwork* network)
- {
- network->FreeWorkingMemory();
- });
+ throw armnn::Exception(
+ "Network with ID of " + std::to_string(networkId) + " does not exist \n"
+ );
+ }
+ // Scheduling is rejected for networks that are not async enabled
+ if (!loadedNetwork->IsAsyncEnabled())
+ {
+ throw armnn::Exception(
+ "Attempting to schedule Network " + std::to_string(networkId) + " when it is not async enabled \n"
+ );
}
- lastId=networkId;
- return loadedNetwork->Execute(inputTensors, outputTensors, iWorkingMemHandle);
+ ProfilerManager::GetInstance().RegisterProfiler(loadedNetwork->GetProfiler().get());
+
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "Schedule");
+
+ // Hand the request to the network; this call returns no status
+ loadedNetwork->Schedule(inputTensors, outputTensors, priority, callback);
}
/// Create a new unique WorkingMemHandle object. Create multiple handles if you wish to have
diff --git a/src/armnn/Runtime.hpp b/src/armnn/Runtime.hpp
index da5445383f..55a4accf67 100644
--- a/src/armnn/Runtime.hpp
+++ b/src/armnn/Runtime.hpp
@@ -60,6 +60,20 @@ public:
const InputTensors& inputTensors,
const OutputTensors& outputTensors);
+ /// This is an experimental function.
+ /// Schedule a thread safe execution by taking the input tensors and an execution priority for Quality of Service.
+ /// The output tensors will then be filled and the callback object will notify that the execution has either
+ /// succeeded or failed.
+ void Schedule(NetworkId networkId,
+ const InputTensors& inputTensors,
+ const OutputTensors& outputTensors,
+ const QosExecPriority priority,
+ std::shared_ptr<IAsyncExecutionCallback> callback);
+
+ /// This is an experimental function.
+ /// Evaluates a network using input in inputTensors and outputs filled into outputTensors.
+ /// This function performs a thread safe execution of the network. Returns once execution is complete.
+ /// Will block until this and any other thread using the same workingMem object completes.
Status Execute(IWorkingMemHandle& workingMemHandle,
const InputTensors& inputTensors,
const OutputTensors& outputTensors);