diff options
author | Keith Davis <keith.davis@arm.com> | 2021-04-22 10:10:34 +0100 |
---|---|---|
committer | finn.williams <finn.williams@arm.com> | 2021-05-06 19:39:39 +0000 |
commit | e813d67f86df41a238ff79b5c554ef5027f56576 (patch) | |
tree | 54c2145a78297d61e6e94676729cb2468490ade3 /src/armnn | |
parent | d905decd256558bbee165e636ce4242ac3b9c917 (diff) | |
download | armnn-e813d67f86df41a238ff79b5c554ef5027f56576.tar.gz |
IVGCVSW-5813 Add Async Queue to IRuntime
Signed-off-by: Keith Davis <keith.davis@arm.com>
Change-Id: Icc0d131c8ee2e9748e2f14762a75962b39c10f9d
Diffstat (limited to 'src/armnn')
-rw-r--r-- | src/armnn/AsyncExecutionCallback.cpp | 57 | ||||
-rw-r--r-- | src/armnn/AsyncExecutionCallback.hpp | 48 | ||||
-rw-r--r-- | src/armnn/LoadedNetwork.cpp | 160 | ||||
-rw-r--r-- | src/armnn/LoadedNetwork.hpp | 75 | ||||
-rw-r--r-- | src/armnn/Runtime.cpp | 48 | ||||
-rw-r--r-- | src/armnn/Runtime.hpp | 14 |
6 files changed, 375 insertions, 27 deletions
diff --git a/src/armnn/AsyncExecutionCallback.cpp b/src/armnn/AsyncExecutionCallback.cpp new file mode 100644 index 0000000000..c44808918d --- /dev/null +++ b/src/armnn/AsyncExecutionCallback.cpp @@ -0,0 +1,57 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include <AsyncExecutionCallback.hpp> + +namespace armnn +{ + +namespace experimental +{ + +void AsyncExecutionCallback::Notify(armnn::Status status, InferenceTimingPair timeTaken) +{ + { + std::lock_guard<std::mutex> hold(m_Mutex); + if (m_Notified) + { + return; + } + // store results and mark as notified + m_Status = status; + m_StartTime = timeTaken.first; + m_EndTime = timeTaken.second; + m_Notified = true; + } + m_Condition.notify_all(); +} + +void AsyncExecutionCallback::Wait() const +{ + std::unique_lock<std::mutex> lock(m_Mutex); + m_Condition.wait(lock, [this] { return m_Notified; }); +} + +armnn::Status AsyncExecutionCallback::GetStatus() const +{ + Wait(); + return m_Status; +} + +HighResolutionClock AsyncExecutionCallback::GetStartTime() const +{ + Wait(); + return m_StartTime; +} + +HighResolutionClock AsyncExecutionCallback::GetEndTime() const +{ + Wait(); + return m_EndTime; +} + +} // namespace experimental + +} // namespace armnn
\ No newline at end of file diff --git a/src/armnn/AsyncExecutionCallback.hpp b/src/armnn/AsyncExecutionCallback.hpp new file mode 100644 index 0000000000..c17b839748 --- /dev/null +++ b/src/armnn/AsyncExecutionCallback.hpp @@ -0,0 +1,48 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <armnn/IAsyncExecutionCallback.hpp> +#include <armnn/Types.hpp> +#include <condition_variable> + +#include <mutex> +#include <thread> + +namespace armnn +{ + +namespace experimental +{ + +class AsyncExecutionCallback final : public IAsyncExecutionCallback +{ +public: + AsyncExecutionCallback() + {} + ~AsyncExecutionCallback() + {} + + void Notify(armnn::Status status, InferenceTimingPair timeTaken); + void Wait() const; + + armnn::Status GetStatus() const; + HighResolutionClock GetStartTime() const; + HighResolutionClock GetEndTime() const; + +private: + mutable std::mutex m_Mutex; + mutable std::condition_variable m_Condition; + + HighResolutionClock m_StartTime; + HighResolutionClock m_EndTime; + armnn::Status m_Status = Status::Failure; + bool m_Notified = false; +}; + +} // namespace experimental + +} // namespace armnn
\ No newline at end of file diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp index 46eb9883fb..67de00f0f3 100644 --- a/src/armnn/LoadedNetwork.cpp +++ b/src/armnn/LoadedNetwork.cpp @@ -24,6 +24,7 @@ #include <LabelsAndEventClasses.hpp> #include <fmt/format.h> +#include <armnn/utility/Timer.hpp> namespace armnn { @@ -84,7 +85,8 @@ void AddWorkloadStructure(std::unique_ptr<TimelineUtilityMethods>& timelineUtils std::unique_ptr<LoadedNetwork> LoadedNetwork::MakeLoadedNetwork(std::unique_ptr<IOptimizedNetwork> net, std::string& errorMessage, const INetworkProperties& networkProperties, - profiling::ProfilingService& profilingService) + profiling::ProfilingService& profilingService, + const NetworkId networkIdOut) { std::unique_ptr<LoadedNetwork> loadedNetwork; @@ -98,7 +100,7 @@ std::unique_ptr<LoadedNetwork> LoadedNetwork::MakeLoadedNetwork(std::unique_ptr< try { - loadedNetwork.reset(new LoadedNetwork(std::move(net), networkProperties, profilingService)); + loadedNetwork.reset(new LoadedNetwork(std::move(net), networkProperties, profilingService, networkIdOut)); } catch (const armnn::RuntimeException& error) { @@ -118,9 +120,11 @@ std::unique_ptr<LoadedNetwork> LoadedNetwork::MakeLoadedNetwork(std::unique_ptr< LoadedNetwork::LoadedNetwork(std::unique_ptr<IOptimizedNetwork> net, const INetworkProperties& networkProperties, - profiling::ProfilingService& profilingService) : + profiling::ProfilingService& profilingService, + const NetworkId networkId) : m_OptimizedNetwork(std::move(net)), m_NetworkProperties(networkProperties), + m_NetworkId(networkId), m_TensorHandleFactoryRegistry(), m_ProfilingService(profilingService) { @@ -161,6 +165,14 @@ LoadedNetwork::LoadedNetwork(std::unique_ptr<IOptimizedNetwork> net, } } } + + // Create the thread pool which will have working memory handles assigned to each thread + // Should occur after factories are registered so thet the WorkingMemHandles can be created + if (m_NetworkProperties.m_NumThreads > 0 && networkProperties.m_AsyncEnabled) + { + CreateThreadPool(m_NetworkProperties.m_NumThreads); + } + if (!networkProperties.m_AsyncEnabled) { for (auto &&layer : order) @@ -846,6 +858,147 @@ bool LoadedNetwork::Execute(std::unique_ptr<TimelineUtilityMethods>& timelineUti return success; } +void LoadedNetwork::CreateThreadPool(std::size_t numThreads) +{ + + for (auto i = 0u; i < numThreads; ++i) + { + std::unique_ptr<IWorkingMemHandle> workingMemHandle = CreateWorkingMemHandle(m_NetworkId); + m_Threads.emplace_back( + std::make_unique<std::thread>( + &LoadedNetwork::ProcessExecPriorities, + this, + std::move(workingMemHandle) + ) + ); + } +} + +void LoadedNetwork::TerminateThreadPool() noexcept +{ + { + std::unique_lock<std::mutex> threadPoolLock(m_ThreadPoolMutex); + m_TerminatePool = true; + } + + m_ThreadPoolEvent.notify_all(); + + for (auto &thread : m_Threads) + { + thread->join(); + } +} + +void LoadedNetwork::Schedule(const InputTensors& inputTensors, + const OutputTensors& outputTensors, + const QosExecPriority priority, + std::shared_ptr<IAsyncExecutionCallback> cb) +{ + // Group execution parameters so that they can be easily added to the queue + ExecutionTuple groupExecParams = std::make_tuple(inputTensors, outputTensors, cb); + std::shared_ptr<ExecutionTuple> operation = make_shared<ExecutionTuple>(groupExecParams); + + // Add a message to the queue and notify the request thread + std::unique_lock<std::mutex> lock(m_ThreadPoolMutex); + switch (priority) { + case QosExecPriority::High: + m_HighPriorityQueue.push(operation); + break; + case QosExecPriority::Low: + m_LowPriorityQueue.push(operation); + break; + case QosExecPriority::Medium: + default: + m_MediumPriorityQueue.push(operation); + } + m_ThreadPoolEvent.notify_one(); +} + +void LoadedNetwork::ProcessExecPriorities(std::unique_ptr<IWorkingMemHandle> workingMemHandle) +{ + int expireRate = EXPIRE_RATE; + int highPriorityCount = 0; + int mediumPriorityCount = 0; + + IWorkingMemHandle& workingMemHandleRef = *workingMemHandle.get(); + + while (true) + { + std::shared_ptr<ExecutionTuple> currentExecInProgress(nullptr); + { + // Wait for a message to be added to the queue + // This is in a separate scope to minimise the lifetime of the lock + std::unique_lock<std::mutex> lock(m_ThreadPoolMutex); + + m_ThreadPoolEvent.wait(lock, + [=] { + return m_TerminatePool || !m_HighPriorityQueue.empty() || + !m_MediumPriorityQueue.empty() || !m_LowPriorityQueue.empty(); + }); + + if (m_TerminatePool && m_HighPriorityQueue.empty() && m_MediumPriorityQueue.empty() && + m_LowPriorityQueue.empty()) + { + break; + } + + // Get the message to process from the front of each queue based on priority from high to low + // Get high priority first if it does not exceed the expire rate + if (!m_HighPriorityQueue.empty() && highPriorityCount < expireRate) + { + currentExecInProgress = m_HighPriorityQueue.front(); + m_HighPriorityQueue.pop(); + highPriorityCount += 1; + } + // If high priority queue is empty or the count exceeds the expire rate, get medium priority message + else if (!m_MediumPriorityQueue.empty() && mediumPriorityCount < expireRate) + { + currentExecInProgress = m_MediumPriorityQueue.front(); + m_MediumPriorityQueue.pop(); + mediumPriorityCount += 1; + // Reset high priority count + highPriorityCount = 0; + } + // If medium priority queue is empty or the count exceeds the expire rate, get low priority message + else if (!m_LowPriorityQueue.empty()) + { + currentExecInProgress = m_LowPriorityQueue.front(); + m_LowPriorityQueue.pop(); + // Reset high and medium priority count + highPriorityCount = 0; + mediumPriorityCount = 0; + } + else + { + // Reset high and medium priority count + highPriorityCount = 0; + mediumPriorityCount = 0; + continue; + } + } + + // invoke the asynchronous execution method + auto inputTensors = std::get<0>(*currentExecInProgress); + auto outputTensors = std::get<1>(*currentExecInProgress); + auto cb = std::get<2>(*currentExecInProgress); + + // Get time at start of inference + HighResolutionClock startTime = armnn::GetTimeNow(); + + try // executing the inference + { + // Execute and populate the time at end of inference in the callback + Execute(inputTensors, outputTensors, workingMemHandleRef) == Status::Success ? + cb->Notify(Status::Success, std::make_pair(startTime, armnn::GetTimeNow())) : + cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow())); + } + catch (const RuntimeException& error) + { + cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow())); + } + } +} + void LoadedNetwork::EnqueueInput(const BindableLayer& layer, const ConstTensor& inputTensor, WorkingMemHandle& context) @@ -1096,6 +1249,7 @@ Status LoadedNetwork::Execute(const InputTensors& inputTensors, EnqueueOutput(*outputLayer, GetOutputTensor(outputLayer->GetBindingId(), outputTensors), workingMemHandle); } } + return executionSucceeded ? Status::Success : Status::Failure; } diff --git a/src/armnn/LoadedNetwork.hpp b/src/armnn/LoadedNetwork.hpp index 51092c744e..b5474db294 100644 --- a/src/armnn/LoadedNetwork.hpp +++ b/src/armnn/LoadedNetwork.hpp @@ -19,13 +19,14 @@ #include <TimelineUtilityMethods.hpp> #include <mutex> +#include <condition_variable> #include <unordered_map> namespace cl { - class Context; - class CommandQueue; - class Device; +class Context; +class CommandQueue; +class Device; } namespace armnn @@ -34,8 +35,19 @@ namespace armnn class LoadedNetwork { public: - using WorkloadQueue = std::vector< std::unique_ptr<IWorkload> >; - ~LoadedNetwork(){ FreeWorkingMemory(); } + using WorkloadQueue = std::vector<std::unique_ptr<IWorkload>>; + + using ExecutionTuple = std::tuple<InputTensors, + OutputTensors, + std::shared_ptr<IAsyncExecutionCallback>>; + + using ExecutionQueue = std::queue<std::shared_ptr<ExecutionTuple>>; + + ~LoadedNetwork() + { + FreeWorkingMemory(); + TerminateThreadPool(); + } /// Create a new unique WorkingMemHandle object. Create multiple handles if you wish to have /// overlapped Execution by calling this function from different threads. @@ -44,16 +56,25 @@ public: TensorInfo GetInputTensorInfo(LayerBindingId layerId) const; TensorInfo GetOutputTensorInfo(LayerBindingId layerId) const; + /// Single thread execution of the loaded network Status EnqueueWorkload(const InputTensors& inputTensors, const OutputTensors& outputTensors); + /// Thread safe execution of the loaded network Status Execute(const InputTensors& inputTensors, const OutputTensors& outputTensors, IWorkingMemHandle& workingMemHandle); + /// Schedule an asynchronous execution on the loaded network + void Schedule(const InputTensors& inputTensors, + const OutputTensors& outputTensors, + const QosExecPriority priority, + std::shared_ptr<IAsyncExecutionCallback> cb); + static std::unique_ptr<LoadedNetwork> MakeLoadedNetwork(std::unique_ptr<IOptimizedNetwork> net, - std::string & errorMessage, + std::string& errorMessage, const INetworkProperties& networkProperties, - profiling::ProfilingService& profilingService); + profiling::ProfilingService& profilingService, + const NetworkId networkIdOut); // NOTE we return by reference as the purpose of this method is only to provide // access to the private m_Profiler and in theory we should not need to increment @@ -87,7 +108,8 @@ private: LoadedNetwork(std::unique_ptr<IOptimizedNetwork> net, const INetworkProperties& networkProperties, - profiling::ProfilingService& profilingService); + profiling::ProfilingService& profilingService, + const NetworkId networkIdOut); void EnqueueInput(const BindableLayer& layer, ITensorHandle* tensorHandle, const TensorInfo& tensorInfo); @@ -97,9 +119,15 @@ private: void EnqueueOutput(const BindableLayer& layer, const Tensor& outputTensor, WorkingMemHandle& handle); + void ProcessExecPriorities(std::unique_ptr<IWorkingMemHandle> workingMemHandle); + bool Execute(std::unique_ptr<profiling::TimelineUtilityMethods>& timelineUtils, profiling::ProfilingGuid inferenceGuid); + void CreateThreadPool(std::size_t numThreads); + + void TerminateThreadPool() noexcept; + const IWorkloadFactory& GetWorkloadFactory(const Layer& layer) const; using BackendPtrMap = std::unordered_map<BackendId, IBackendInternalUniquePtr>; @@ -108,19 +136,38 @@ private: WorkloadFactoryMap m_WorkloadFactories; std::unique_ptr<IOptimizedNetwork> m_OptimizedNetwork; - WorkloadQueue m_InputQueue; - WorkloadQueue m_WorkloadQueue; - WorkloadQueue m_OutputQueue; - std::shared_ptr<IProfiler> m_Profiler; + std::shared_ptr<IProfiler> m_Profiler; + + WorkloadQueue m_InputQueue; + WorkloadQueue m_WorkloadQueue; + WorkloadQueue m_OutputQueue; mutable std::mutex m_WorkingMemMutex; - bool m_IsWorkingMemAllocated=false; + bool m_IsWorkingMemAllocated = false; + + std::vector<std::unique_ptr<std::thread>> m_Threads; + std::stack<IWorkingMemHandle> m_WorkingMemHandles; + + ExecutionQueue m_HighPriorityQueue; + ExecutionQueue m_MediumPriorityQueue; + ExecutionQueue m_LowPriorityQueue; + + // Condition Variables require mutex which will guard the shared state. + // Has an event happened? Stop signal for example + std::condition_variable m_ThreadPoolEvent; + std::mutex m_ThreadPoolMutex; + + // The shared state for conditional variable + bool m_TerminatePool = false; + INetworkProperties m_NetworkProperties; + const NetworkId m_NetworkId; + TensorHandleFactoryRegistry m_TensorHandleFactoryRegistry; - profiling::ProfilingService& m_ProfilingService; + profiling::ProfilingService& m_ProfilingService; }; } diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp index 1dd86a61ce..e04cf5ddaf 100644 --- a/src/armnn/Runtime.cpp +++ b/src/armnn/Runtime.cpp @@ -88,6 +88,15 @@ Status IRuntime::Execute(IWorkingMemHandle& workingMemHandle, return pRuntimeImpl->Execute(workingMemHandle, inputTensors, outputTensors); } +void IRuntime::Schedule(NetworkId networkId, + const InputTensors& inputTensors, + const OutputTensors& outputTensors, + const QosExecPriority priority, + std::shared_ptr<IAsyncExecutionCallback> cb) +{ + pRuntimeImpl->Schedule(networkId, inputTensors, outputTensors, priority, cb); +} + Status IRuntime::UnloadNetwork(NetworkId networkId) { return pRuntimeImpl->UnloadNetwork(networkId); @@ -150,7 +159,8 @@ Status RuntimeImpl::LoadNetwork(NetworkId& networkIdOut, std::unique_ptr<IOptimizedNetwork>(rawNetwork), errorMessage, networkProperties, - m_ProfilingService); + m_ProfilingService, + networkIdOut); if (!loadedNetwork) { @@ -439,24 +449,42 @@ Status RuntimeImpl::Execute(IWorkingMemHandle& iWorkingMemHandle, } if (!loadedNetwork->IsAsyncEnabled()) { - ARMNN_LOG(error) << "Network " << networkId << " is not async enabled.\n"; + ARMNN_LOG(error) << "Attempting execute " << networkId << " when it is not async enabled.\n"; return Status::Failure; } ProfilerManager::GetInstance().RegisterProfiler(loadedNetwork->GetProfiler().get()); ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "Execute"); - static thread_local NetworkId lastId = networkId; - if (lastId != networkId) + return loadedNetwork->Execute(inputTensors, outputTensors, iWorkingMemHandle); +} + +void RuntimeImpl::Schedule(NetworkId networkId, + const InputTensors& inputTensors, + const OutputTensors& outputTensors, + const QosExecPriority priority, + std::shared_ptr<IAsyncExecutionCallback> callback) +{ + LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId); + + if (!loadedNetwork) { - LoadedNetworkFuncSafe(lastId, [](LoadedNetwork* network) - { - network->FreeWorkingMemory(); - }); + throw armnn::Exception( + "Network with ID of " + std::to_string(networkId) + " does not exist \n" + ); + } + if (!loadedNetwork->IsAsyncEnabled()) + { + throw armnn::Exception( + "Attempting to schedule Network " + std::to_string(networkId) + " when it is not async enabled \n" + ); } - lastId=networkId; - return loadedNetwork->Execute(inputTensors, outputTensors, iWorkingMemHandle); + ProfilerManager::GetInstance().RegisterProfiler(loadedNetwork->GetProfiler().get()); + + ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "Schedule"); + + loadedNetwork->Schedule(inputTensors, outputTensors, priority, callback); } /// Create a new unique WorkingMemHandle object. Create multiple handles if you wish to have diff --git a/src/armnn/Runtime.hpp b/src/armnn/Runtime.hpp index da5445383f..55a4accf67 100644 --- a/src/armnn/Runtime.hpp +++ b/src/armnn/Runtime.hpp @@ -60,6 +60,20 @@ public: const InputTensors& inputTensors, const OutputTensors& outputTensors); + /// This is an experimental function. + /// Schedule a thread safe execution by taking the input tensors and an execution priority for Quality of Service. + /// The output tensors will then be filled and the callback object will notify that the execution has either + /// succeeded or failed. + void Schedule(NetworkId networkId, + const InputTensors& inputTensors, + const OutputTensors& outputTensors, + const QosExecPriority priority, + std::shared_ptr<IAsyncExecutionCallback> callback); + + /// This is an experimental function. + /// Evaluates a network using input in inputTensors and outputs filled into outputTensors. + /// This function performs a thread safe execution of the network. Returns once execution is complete. + /// Will block until this and any other thread using the same workingMem object completes. Status Execute(IWorkingMemHandle& workingMemHandle, const InputTensors& inputTensors, const OutputTensors& outputTensors); |