author    Finn Williams <Finn.Williams@arm.com>  2021-06-09 17:07:33 +0100
committer Finn Williams <Finn.Williams@arm.com>  2021-06-23 17:14:53 +0100
commit    f364d5391b08e9071cd965f5765385ec9156b652 (patch)
tree      1ea93ed574a3eb51f5a1f4bb08dc1ad18aa1c6a2
parent    7a00eaa6ecf121623823b1951c0e6c9093271adf (diff)
download  armnn-f364d5391b08e9071cd965f5765385ec9156b652.tar.gz
IVGCVSW-6062 Rework the async threadpool
!android-nn-driver:5802

 * Extract the threadpool from LoadedNetwork/Runtime
 * Refactor the threadpool to handle multiple networks
 * Trim IAsyncExecutionCallback and add an InferenceId to AsyncExecutionCallback
 * Add AsyncCallbackManager class

Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Change-Id: I36aa2ad29c16bc10ee0706adfeb6b27f60012afb
-rw-r--r--  Android.mk                                                           |   1
-rw-r--r--  CMakeLists.txt                                                       |   2
-rw-r--r--  include/armnn/IAsyncExecutionCallback.hpp                            |  13
-rw-r--r--  include/armnn/IRuntime.hpp                                           |  31
-rw-r--r--  include/armnn/IWorkingMemHandle.hpp                                  |   3
-rw-r--r--  include/armnn/Threadpool.hpp                                         |  78
-rw-r--r--  src/armnn/AsyncExecutionCallback.cpp                                 |  38
-rw-r--r--  src/armnn/AsyncExecutionCallback.hpp                                 |  50
-rw-r--r--  src/armnn/LoadedNetwork.cpp                                          | 159
-rw-r--r--  src/armnn/LoadedNetwork.hpp                                          |  42
-rw-r--r--  src/armnn/Runtime.cpp                                                |  40
-rw-r--r--  src/armnn/Runtime.hpp                                                |  10
-rw-r--r--  src/armnn/Threadpool.cpp                                             | 206
-rw-r--r--  src/armnn/WorkingMemHandle.cpp                                       |   3
-rw-r--r--  src/armnn/WorkingMemHandle.hpp                                       |   7
-rw-r--r--  src/backends/backendsCommon/test/StridedSliceAsyncEndToEndTest.hpp  |  31
-rw-r--r--  tests/ExecuteNetwork/ExecuteNetwork.cpp                              |  27
-rw-r--r--  tests/InferenceModel.hpp                                             |  38
18 files changed, 434 insertions, 345 deletions
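
The reworked API, pieced together from the changes below: the caller now owns a
Threadpool (one working memory handle per thread) and collects results through an
AsyncCallbackManager, instead of calling the removed IRuntime::Schedule. A minimal
sketch, assuming an already-optimised network 'optNet' and prepared 'inputTensors'
and 'outputTensors'; error handling is omitted:

    #include <armnn/ArmNN.hpp>
    #include <armnn/Threadpool.hpp>
    #include <AsyncExecutionCallback.hpp> // private src/armnn header, as used by the tests below

    using namespace armnn;
    using namespace armnn::experimental;

    IRuntimePtr runtime = IRuntime::Create(IRuntime::CreationOptions());

    // numThreads is gone from INetworkProperties (see the IRuntime.hpp hunks below)
    INetworkProperties networkProperties(true, MemorySource::Undefined, MemorySource::Undefined);
    NetworkId networkId;
    std::string errorMessage;
    runtime->LoadNetwork(networkId, std::move(optNet), errorMessage, networkProperties);

    // One working memory handle per pool thread
    const size_t numThreads = 3;
    std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles;
    for (size_t i = 0; i < numThreads; ++i)
    {
        memHandles.emplace_back(runtime->CreateWorkingMemHandle(networkId));
    }
    Threadpool threadpool(numThreads, runtime.get(), memHandles);

    // Schedule an inference, then block until some callback is notified
    AsyncCallbackManager callbackManager;
    threadpool.Schedule(networkId, inputTensors, outputTensors,
                        QosExecPriority::Medium, callbackManager.GetNewCallback());
    auto cb = callbackManager.GetNotifiedCallback(); // completion order, not submission order
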
diff --git a/Android.mk b/Android.mk
index aec32699b0..f3e4f40534 100644
--- a/Android.mk
+++ b/Android.mk
@@ -132,6 +132,7 @@ LOCAL_SRC_FILES := \
src/armnn/SubgraphView.cpp \
src/armnn/SubgraphViewSelector.cpp \
src/armnn/Tensor.cpp \
+ src/armnn/Threadpool.cpp \
src/armnn/TypesUtils.cpp \
src/armnn/Utils.cpp \
src/armnn/WallClockTimer.cpp \
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 426c089bd6..78c2f17e6d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -180,6 +180,7 @@ list(APPEND armnn_sources
include/armnn/QuantizedLstmParams.hpp
include/armnn/Tensor.hpp
include/armnn/TensorFwd.hpp
+ include/armnn/Threadpool.hpp
include/armnn/Types.hpp
include/armnn/TypesUtils.hpp
include/armnn/Utils.hpp
@@ -387,6 +388,7 @@ list(APPEND armnn_sources
src/armnn/SubgraphViewSelector.cpp
src/armnn/SubgraphViewSelector.hpp
src/armnn/Tensor.cpp
+ src/armnn/Threadpool.cpp
src/armnn/TypesUtils.cpp
src/armnn/Utils.cpp
src/armnn/WallClockTimer.cpp
diff --git a/include/armnn/IAsyncExecutionCallback.hpp b/include/armnn/IAsyncExecutionCallback.hpp
index 045ec4581f..3e0cacccee 100644
--- a/include/armnn/IAsyncExecutionCallback.hpp
+++ b/include/armnn/IAsyncExecutionCallback.hpp
@@ -23,19 +23,6 @@ public:
// Notify the AsyncExecutionCallback object of the armnn execution status
virtual void Notify(armnn::Status status, InferenceTimingPair timeTaken) = 0;
-
- // Block the calling thread until the AsyncExecutionCallback object allows it to proceed
- virtual void Wait() const = 0;
-
- // Retrieve the ArmNN Status from the AsyncExecutionCallback that has been notified
- virtual armnn::Status GetStatus() const = 0;
-
- // Retrieve the start time before executing the inference
- virtual HighResolutionClock GetStartTime() const = 0;
-
- // Retrieve the time after executing the inference
- virtual HighResolutionClock GetEndTime() const = 0;
-
};
} // experimental
diff --git a/include/armnn/IRuntime.hpp b/include/armnn/IRuntime.hpp
index bfc13c9c01..f88b6b664f 100644
--- a/include/armnn/IRuntime.hpp
+++ b/include/armnn/IRuntime.hpp
@@ -32,24 +32,34 @@ struct INetworkProperties
ARMNN_DEPRECATED_MSG("Please use INetworkProperties constructor with MemorySource argument")
INetworkProperties(bool importEnabled = false,
bool exportEnabled = false,
- bool asyncEnabled = false,
- size_t numThreads = 1)
+ bool asyncEnabled = false)
: m_ImportEnabled(importEnabled)
, m_ExportEnabled(exportEnabled)
, m_AsyncEnabled(asyncEnabled)
- , m_NumThreads(numThreads)
, m_InputSource(m_ImportEnabled ? MemorySource::Malloc : MemorySource::Undefined)
, m_OutputSource(m_ExportEnabled ? MemorySource::Malloc : MemorySource::Undefined)
{}
+ ARMNN_DEPRECATED_MSG("Please use INetworkProperties constructor without numThreads argument")
INetworkProperties(bool asyncEnabled,
MemorySource m_InputSource,
MemorySource m_OutputSource,
- size_t numThreads = 1)
+ size_t numThreads)
+ : m_ImportEnabled(m_InputSource != MemorySource::Undefined)
+ , m_ExportEnabled(m_OutputSource != MemorySource::Undefined)
+ , m_AsyncEnabled(asyncEnabled)
+ , m_InputSource(m_InputSource)
+ , m_OutputSource(m_OutputSource)
+ {
+ armnn::IgnoreUnused(numThreads);
+ }
+
+ INetworkProperties(bool asyncEnabled,
+ MemorySource m_InputSource,
+ MemorySource m_OutputSource)
: m_ImportEnabled(m_InputSource != MemorySource::Undefined)
, m_ExportEnabled(m_OutputSource != MemorySource::Undefined)
, m_AsyncEnabled(asyncEnabled)
- , m_NumThreads(numThreads)
, m_InputSource(m_InputSource)
, m_OutputSource(m_OutputSource)
{}
@@ -60,7 +70,6 @@ struct INetworkProperties
const bool m_ExportEnabled;
const bool m_AsyncEnabled;
- const size_t m_NumThreads;
const MemorySource m_InputSource;
const MemorySource m_OutputSource;
@@ -191,16 +200,6 @@ public:
const InputTensors& inputTensors,
const OutputTensors& outputTensors);
- /// This is an experimental function
- /// Schedule a thread safe execution by taking the input tensors and an execution priority for Quality of Service.
- /// The output tensors will then be filled and the callback object will notify that the execution has either
- /// succeeded or failed.
- void Schedule(NetworkId networkId,
- const InputTensors& inputTensors,
- const OutputTensors& outputTensors,
- const QosExecPriority priority,
- std::shared_ptr<IAsyncExecutionCallback> callback);
-
/// Unloads a network from the IRuntime.
/// At the moment this only removes the network from the m_Impl->m_Network.
/// This might need more work in the future to be AndroidNN compliant.
diff --git a/include/armnn/IWorkingMemHandle.hpp b/include/armnn/IWorkingMemHandle.hpp
index 6fb2f9fe5f..171fa3d81c 100644
--- a/include/armnn/IWorkingMemHandle.hpp
+++ b/include/armnn/IWorkingMemHandle.hpp
@@ -25,9 +25,6 @@ public:
/// Returns the NetworkId of the Network that this IWorkingMemHandle works with.
virtual NetworkId GetNetworkId() = 0;
- /// Returns the InferenceId of the Inference that this IWorkingMemHandle works with.
- virtual profiling::ProfilingGuid GetInferenceId() = 0;
-
/// Allocate the backing memory required for execution. If this is not called, then allocation will be
/// deferred to execution time. The mutex must be locked.
virtual void Allocate() = 0;
diff --git a/include/armnn/Threadpool.hpp b/include/armnn/Threadpool.hpp
new file mode 100644
index 0000000000..e2458dbb65
--- /dev/null
+++ b/include/armnn/Threadpool.hpp
@@ -0,0 +1,78 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <armnn/Tensor.hpp>
+#include <armnn/Types.hpp>
+
+#include "INetwork.hpp"
+#include "IRuntime.hpp"
+
+#include <thread>
+#include <mutex>
+#include <condition_variable>
+#include <unordered_map>
+#include <queue>
+
+namespace armnn
+{
+namespace experimental
+{
+class Threadpool
+{
+public:
+ Threadpool(std::size_t numThreads,
+ IRuntime* runtimePtr,
+ std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles);
+
+ ~Threadpool()
+ {
+ TerminateThreadPool();
+ }
+
+ void LoadMemHandles(std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles);
+ void UnloadMemHandles(NetworkId networkId);
+
+ /// Schedule an asynchronous execution on the loaded network
+ void Schedule(NetworkId networkId,
+ const InputTensors &inputTensors,
+ const OutputTensors &outputTensors,
+ const QosExecPriority priority,
+ std::shared_ptr<IAsyncExecutionCallback> cb);
+
+ void TerminateThreadPool() noexcept;
+
+private:
+ using ExecutionTuple = std::tuple<NetworkId,
+ InputTensors,
+ OutputTensors,
+ std::shared_ptr<IAsyncExecutionCallback>>;
+
+ using ExecutionQueue = std::queue<std::shared_ptr<ExecutionTuple>>;
+
+ void ProcessExecPriorities(uint32_t index);
+
+ IRuntime* m_RuntimePtr;
+
+ ExecutionQueue m_HighPriorityQueue;
+ ExecutionQueue m_MediumPriorityQueue;
+ ExecutionQueue m_LowPriorityQueue;
+
+ // Condition variables require a mutex to guard the shared state.
+ // Has an event happened? A stop signal, for example.
+ std::condition_variable m_ThreadPoolEvent;
+ std::mutex m_ThreadPoolMutex;
+
+ // The shared state for the condition variable
+ bool m_TerminatePool = false;
+
+ std::unordered_map<NetworkId, std::vector<std::shared_ptr<IWorkingMemHandle>>> m_WorkingMemHandleMap;
+ std::vector<std::unique_ptr<std::thread>> m_Threads;
+};
+
+} // namespace experimental
+
+} // namespace armnn
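
Because the pool is no longer tied to a single LoadedNetwork, several networks can
share the same worker threads; each network just has to contribute exactly one
working memory handle per thread. Continuing the sketch above ('netA'/'netB' and
the tensor/callback variables are placeholder names):

    std::vector<std::shared_ptr<IWorkingMemHandle>> handlesA, handlesB;
    for (size_t i = 0; i < numThreads; ++i)
    {
        handlesA.emplace_back(runtime->CreateWorkingMemHandle(netA));
        handlesB.emplace_back(runtime->CreateWorkingMemHandle(netB));
    }

    Threadpool pool(numThreads, runtime.get(), handlesA); // the constructor loads netA's handles
    pool.LoadMemHandles(handlesB);                        // netB joins the same pool

    pool.Schedule(netA, inputsA, outputsA, QosExecPriority::High,   cbA);
    pool.Schedule(netB, inputsB, outputsB, QosExecPriority::Medium, cbB);

    pool.UnloadMemHandles(netB);                          // netB leaves; netA keeps executing
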
diff --git a/src/armnn/AsyncExecutionCallback.cpp b/src/armnn/AsyncExecutionCallback.cpp
index c44808918d..2973e2d891 100644
--- a/src/armnn/AsyncExecutionCallback.cpp
+++ b/src/armnn/AsyncExecutionCallback.cpp
@@ -15,43 +15,53 @@ void AsyncExecutionCallback::Notify(armnn::Status status, InferenceTimingPair ti
{
{
std::lock_guard<std::mutex> hold(m_Mutex);
- if (m_Notified)
- {
- return;
- }
// store results and mark as notified
m_Status = status;
m_StartTime = timeTaken.first;
m_EndTime = timeTaken.second;
- m_Notified = true;
+ m_NotificationQueue.push(m_InferenceId);
}
m_Condition.notify_all();
}
-void AsyncExecutionCallback::Wait() const
-{
- std::unique_lock<std::mutex> lock(m_Mutex);
- m_Condition.wait(lock, [this] { return m_Notified; });
-}
-
armnn::Status AsyncExecutionCallback::GetStatus() const
{
- Wait();
return m_Status;
}
HighResolutionClock AsyncExecutionCallback::GetStartTime() const
{
- Wait();
return m_StartTime;
}
HighResolutionClock AsyncExecutionCallback::GetEndTime() const
{
- Wait();
return m_EndTime;
}
+std::shared_ptr<AsyncExecutionCallback> AsyncCallbackManager::GetNewCallback()
+{
+ auto cb = std::make_unique<AsyncExecutionCallback>(m_NotificationQueue, m_Mutex, m_Condition);
+ InferenceId id = cb->GetInferenceId();
+ m_Callbacks.insert({id, std::move(cb)});
+
+ return m_Callbacks.at(id);
+}
+
+std::shared_ptr<AsyncExecutionCallback> AsyncCallbackManager::GetNotifiedCallback()
+{
+ std::unique_lock<std::mutex> lock(m_Mutex);
+
+ m_Condition.wait(lock, [this] { return !m_NotificationQueue.empty(); });
+
+ InferenceId id = m_NotificationQueue.front();
+ m_NotificationQueue.pop();
+
+ std::shared_ptr<AsyncExecutionCallback> callback = m_Callbacks.at(id);
+ m_Callbacks.erase(id);
+ return callback;
+}
+
} // namespace experimental
} // namespace armnn
\ No newline at end of file
diff --git a/src/armnn/AsyncExecutionCallback.hpp b/src/armnn/AsyncExecutionCallback.hpp
index c17b839748..2ff73b3efb 100644
--- a/src/armnn/AsyncExecutionCallback.hpp
+++ b/src/armnn/AsyncExecutionCallback.hpp
@@ -6,11 +6,14 @@
#pragma once
#include <armnn/IAsyncExecutionCallback.hpp>
+#include <armnn/IWorkingMemHandle.hpp>
#include <armnn/Types.hpp>
-#include <condition_variable>
+#include <condition_variable>
#include <mutex>
#include <thread>
+#include <queue>
+#include <unordered_map>
namespace armnn
{
@@ -18,29 +21,62 @@ namespace armnn
namespace experimental
{
+using InferenceId = uint64_t;
class AsyncExecutionCallback final : public IAsyncExecutionCallback
{
+private:
+ static InferenceId nextID;
+
public:
- AsyncExecutionCallback()
+ AsyncExecutionCallback(std::queue<InferenceId>& notificationQueue,
+ std::mutex& mutex,
+ std::condition_variable& condition)
+ : m_NotificationQueue(notificationQueue)
+ , m_Mutex(mutex)
+ , m_Condition(condition)
+ , m_InferenceId(++nextID)
{}
+
~AsyncExecutionCallback()
{}
void Notify(armnn::Status status, InferenceTimingPair timeTaken);
- void Wait() const;
+
+ InferenceId GetInferenceId()
+ {
+ return m_InferenceId;
+ }
armnn::Status GetStatus() const;
HighResolutionClock GetStartTime() const;
HighResolutionClock GetEndTime() const;
private:
- mutable std::mutex m_Mutex;
- mutable std::condition_variable m_Condition;
+ std::queue<InferenceId>& m_NotificationQueue;
+ std::mutex& m_Mutex;
+ std::condition_variable& m_Condition;
HighResolutionClock m_StartTime;
HighResolutionClock m_EndTime;
- armnn::Status m_Status = Status::Failure;
- bool m_Notified = false;
+ armnn::Status m_Status = Status::Failure;
+ InferenceId m_InferenceId;
+};
+InferenceId AsyncExecutionCallback::nextID = 0u;
+
+// Manager to create and monitor AsyncExecutionCallbacks
+// GetNewCallback will create a callback for use in Threadpool::Schedule
+// GetNotifiedCallback will return the first callback to be notified (finished execution)
+class AsyncCallbackManager
+{
+public:
+ std::shared_ptr<AsyncExecutionCallback> GetNewCallback();
+ std::shared_ptr<AsyncExecutionCallback> GetNotifiedCallback();
+
+private:
+ std::mutex m_Mutex;
+ std::condition_variable m_Condition;
+ std::unordered_map<InferenceId, std::shared_ptr<AsyncExecutionCallback>> m_Callbacks;
+ std::queue<InferenceId> m_NotificationQueue;
};
} // namespace experimental
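
GetNotifiedCallback() hands callbacks back in completion order, which may differ
from submission order, so callers that need to match a result to its inputs can key
per-inference data by InferenceId before scheduling (the pattern ExecuteNetwork
adopts further down). A sketch with illustrative names:

    std::unordered_map<InferenceId, size_t> inferenceToIteration;

    for (size_t i = 0; i < iterations; ++i)
    {
        std::shared_ptr<AsyncExecutionCallback> cb = callbackManager.GetNewCallback();
        inferenceToIteration[cb->GetInferenceId()] = i;
        threadpool.Schedule(networkId, inputs[i], outputs[i], QosExecPriority::Medium, cb);
    }

    for (size_t i = 0; i < iterations; ++i)
    {
        auto cb = callbackManager.GetNotifiedCallback();  // blocks until the next inference finishes
        size_t iteration = inferenceToIteration.at(cb->GetInferenceId());
        // outputs[iteration] holds this callback's results; cb->GetStatus() reports success or failure
    }
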
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp
index 17cc8dcc23..13beb13a07 100644
--- a/src/armnn/LoadedNetwork.cpp
+++ b/src/armnn/LoadedNetwork.cpp
@@ -6,7 +6,6 @@
#include "LoadedNetwork.hpp"
#include "Layer.hpp"
#include "Graph.hpp"
-#include "Network.hpp"
#include <Processes.hpp>
#include "Profiling.hpp"
#include "HeapProfiling.hpp"
@@ -22,7 +21,6 @@
#include <backendsCommon/MemSyncWorkload.hpp>
#include <fmt/format.h>
-#include <armnn/utility/Timer.hpp>
namespace armnn
{
@@ -83,8 +81,7 @@ void AddWorkloadStructure(std::unique_ptr<TimelineUtilityMethods>& timelineUtils
std::unique_ptr<LoadedNetwork> LoadedNetwork::MakeLoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
std::string& errorMessage,
const INetworkProperties& networkProperties,
- profiling::ProfilingService& profilingService,
- const NetworkId networkIdOut)
+ profiling::ProfilingService& profilingService)
{
std::unique_ptr<LoadedNetwork> loadedNetwork;
@@ -98,7 +95,7 @@ std::unique_ptr<LoadedNetwork> LoadedNetwork::MakeLoadedNetwork(std::unique_ptr<
try
{
- loadedNetwork.reset(new LoadedNetwork(std::move(net), networkProperties, profilingService, networkIdOut));
+ loadedNetwork.reset(new LoadedNetwork(std::move(net), networkProperties, profilingService));
}
catch (const armnn::RuntimeException& error)
{
@@ -118,11 +115,9 @@ std::unique_ptr<LoadedNetwork> LoadedNetwork::MakeLoadedNetwork(std::unique_ptr<
LoadedNetwork::LoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
const INetworkProperties& networkProperties,
- profiling::ProfilingService& profilingService,
- const NetworkId networkId) :
+ profiling::ProfilingService& profilingService) :
m_OptimizedNetwork(std::move(net)),
m_NetworkProperties(networkProperties),
- m_NetworkId(networkId),
m_TensorHandleFactoryRegistry(),
m_ProfilingService(profilingService)
{
@@ -304,13 +299,6 @@ LoadedNetwork::LoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
{
AllocateAndExecuteConstantWorkloads();
}
-
- // Create the thread pool which will have working memory handles assigned to each thread
- // Should occur last so all factories and constant layer tensor handles are created
- if (m_NetworkProperties.m_NumThreads > 0 && networkProperties.m_AsyncEnabled)
- {
- CreateThreadPool(m_NetworkProperties.m_NumThreads);
- }
}
void LoadedNetwork::AllocateAndExecuteConstantWorkloads()
@@ -856,147 +844,6 @@ bool LoadedNetwork::Execute(std::unique_ptr<TimelineUtilityMethods>& timelineUti
return success;
}
-void LoadedNetwork::CreateThreadPool(std::size_t numThreads)
-{
-
- for (auto i = 0u; i < numThreads; ++i)
- {
- std::unique_ptr<IWorkingMemHandle> workingMemHandle = CreateWorkingMemHandle(m_NetworkId);
- m_Threads.emplace_back(
- std::make_unique<std::thread>(
- &LoadedNetwork::ProcessExecPriorities,
- this,
- std::move(workingMemHandle)
- )
- );
- }
-}
-
-void LoadedNetwork::TerminateThreadPool() noexcept
-{
- {
- std::unique_lock<std::mutex> threadPoolLock(m_ThreadPoolMutex);
- m_TerminatePool = true;
- }
-
- m_ThreadPoolEvent.notify_all();
-
- for (auto &thread : m_Threads)
- {
- thread->join();
- }
-}
-
-void LoadedNetwork::Schedule(const InputTensors& inputTensors,
- const OutputTensors& outputTensors,
- const QosExecPriority priority,
- std::shared_ptr<IAsyncExecutionCallback> cb)
-{
- // Group execution parameters so that they can be easily added to the queue
- ExecutionTuple groupExecParams = std::make_tuple(inputTensors, outputTensors, cb);
- std::shared_ptr<ExecutionTuple> operation = make_shared<ExecutionTuple>(groupExecParams);
-
- // Add a message to the queue and notify the request thread
- std::unique_lock<std::mutex> lock(m_ThreadPoolMutex);
- switch (priority) {
- case QosExecPriority::High:
- m_HighPriorityQueue.push(operation);
- break;
- case QosExecPriority::Low:
- m_LowPriorityQueue.push(operation);
- break;
- case QosExecPriority::Medium:
- default:
- m_MediumPriorityQueue.push(operation);
- }
- m_ThreadPoolEvent.notify_one();
-}
-
-void LoadedNetwork::ProcessExecPriorities(std::unique_ptr<IWorkingMemHandle> workingMemHandle)
-{
- int expireRate = EXPIRE_RATE;
- int highPriorityCount = 0;
- int mediumPriorityCount = 0;
-
- IWorkingMemHandle& workingMemHandleRef = *workingMemHandle.get();
-
- while (true)
- {
- std::shared_ptr<ExecutionTuple> currentExecInProgress(nullptr);
- {
- // Wait for a message to be added to the queue
- // This is in a separate scope to minimise the lifetime of the lock
- std::unique_lock<std::mutex> lock(m_ThreadPoolMutex);
-
- m_ThreadPoolEvent.wait(lock,
- [=] {
- return m_TerminatePool || !m_HighPriorityQueue.empty() ||
- !m_MediumPriorityQueue.empty() || !m_LowPriorityQueue.empty();
- });
-
- if (m_TerminatePool && m_HighPriorityQueue.empty() && m_MediumPriorityQueue.empty() &&
- m_LowPriorityQueue.empty())
- {
- break;
- }
-
- // Get the message to process from the front of each queue based on priority from high to low
- // Get high priority first if it does not exceed the expire rate
- if (!m_HighPriorityQueue.empty() && highPriorityCount < expireRate)
- {
- currentExecInProgress = m_HighPriorityQueue.front();
- m_HighPriorityQueue.pop();
- highPriorityCount += 1;
- }
- // If high priority queue is empty or the count exceeds the expire rate, get medium priority message
- else if (!m_MediumPriorityQueue.empty() && mediumPriorityCount < expireRate)
- {
- currentExecInProgress = m_MediumPriorityQueue.front();
- m_MediumPriorityQueue.pop();
- mediumPriorityCount += 1;
- // Reset high priority count
- highPriorityCount = 0;
- }
- // If medium priority queue is empty or the count exceeds the expire rate, get low priority message
- else if (!m_LowPriorityQueue.empty())
- {
- currentExecInProgress = m_LowPriorityQueue.front();
- m_LowPriorityQueue.pop();
- // Reset high and medium priority count
- highPriorityCount = 0;
- mediumPriorityCount = 0;
- }
- else
- {
- // Reset high and medium priority count
- highPriorityCount = 0;
- mediumPriorityCount = 0;
- continue;
- }
- }
-
- // invoke the asynchronous execution method
- auto inputTensors = std::get<0>(*currentExecInProgress);
- auto outputTensors = std::get<1>(*currentExecInProgress);
- auto cb = std::get<2>(*currentExecInProgress);
-
- // Get time at start of inference
- HighResolutionClock startTime = armnn::GetTimeNow();
-
- try // executing the inference
- {
- // Execute and populate the time at end of inference in the callback
- Execute(inputTensors, outputTensors, workingMemHandleRef) == Status::Success ?
- cb->Notify(Status::Success, std::make_pair(startTime, armnn::GetTimeNow())) :
- cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow()));
- }
- catch (const RuntimeException& error)
- {
- cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow()));
- }
- }
-}
-
void LoadedNetwork::EnqueueInput(const BindableLayer& layer,
const ConstTensor& inputTensor,
WorkingMemHandle& context)
diff --git a/src/armnn/LoadedNetwork.hpp b/src/armnn/LoadedNetwork.hpp
index c85e82bbdd..360ad91170 100644
--- a/src/armnn/LoadedNetwork.hpp
+++ b/src/armnn/LoadedNetwork.hpp
@@ -37,16 +37,9 @@ class LoadedNetwork
public:
using WorkloadQueue = std::vector<std::unique_ptr<IWorkload>>;
- using ExecutionTuple = std::tuple<InputTensors,
- OutputTensors,
- std::shared_ptr<IAsyncExecutionCallback>>;
-
- using ExecutionQueue = std::queue<std::shared_ptr<ExecutionTuple>>;
-
~LoadedNetwork()
{
FreeWorkingMemory();
- TerminateThreadPool();
}
/// Create a new unique WorkingMemHandle object. Create multiple handles if you wish to have
@@ -64,17 +57,10 @@ public:
const OutputTensors& outputTensors,
IWorkingMemHandle& workingMemHandle);
- /// Schedule an asynchronous execution on the loaded network
- void Schedule(const InputTensors& inputTensors,
- const OutputTensors& outputTensors,
- const QosExecPriority priority,
- std::shared_ptr<IAsyncExecutionCallback> cb);
-
static std::unique_ptr<LoadedNetwork> MakeLoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
std::string& errorMessage,
const INetworkProperties& networkProperties,
- profiling::ProfilingService& profilingService,
- const NetworkId networkIdOut);
+ profiling::ProfilingService& profilingService);
// NOTE we return by reference as the purpose of this method is only to provide
// access to the private m_Profiler and in theory we should not need to increment
@@ -108,8 +94,7 @@ private:
LoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
const INetworkProperties& networkProperties,
- profiling::ProfilingService& profilingService,
- const NetworkId networkIdOut);
+ profiling::ProfilingService& profilingService);
void EnqueueInput(const BindableLayer& layer, ITensorHandle* tensorHandle, const TensorInfo& tensorInfo);
@@ -119,15 +104,9 @@ private:
void EnqueueOutput(const BindableLayer& layer, const Tensor& outputTensor, WorkingMemHandle& handle);
- void ProcessExecPriorities(std::unique_ptr<IWorkingMemHandle> workingMemHandle);
-
bool Execute(std::unique_ptr<profiling::TimelineUtilityMethods>& timelineUtils,
profiling::ProfilingGuid inferenceGuid);
- void CreateThreadPool(std::size_t numThreads);
-
- void TerminateThreadPool() noexcept;
-
const IWorkloadFactory& GetWorkloadFactory(const Layer& layer) const;
using BackendPtrMap = std::unordered_map<BackendId, IBackendInternalUniquePtr>;
@@ -146,25 +125,8 @@ private:
bool m_IsWorkingMemAllocated = false;
- std::vector<std::unique_ptr<std::thread>> m_Threads;
- std::stack<IWorkingMemHandle> m_WorkingMemHandles;
-
- ExecutionQueue m_HighPriorityQueue;
- ExecutionQueue m_MediumPriorityQueue;
- ExecutionQueue m_LowPriorityQueue;
-
- // Condition Variables require mutex which will guard the shared state.
- // Has an event happened? Stop signal for example
- std::condition_variable m_ThreadPoolEvent;
- std::mutex m_ThreadPoolMutex;
-
- // The shared state for conditional variable
- bool m_TerminatePool = false;
-
INetworkProperties m_NetworkProperties;
- const NetworkId m_NetworkId;
-
TensorHandleFactoryRegistry m_TensorHandleFactoryRegistry;
profiling::ProfilingService& m_ProfilingService;
diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp
index 374064e408..f16d186191 100644
--- a/src/armnn/Runtime.cpp
+++ b/src/armnn/Runtime.cpp
@@ -89,15 +89,6 @@ Status IRuntime::Execute(IWorkingMemHandle& workingMemHandle,
return pRuntimeImpl->Execute(workingMemHandle, inputTensors, outputTensors);
}
-void IRuntime::Schedule(NetworkId networkId,
- const InputTensors& inputTensors,
- const OutputTensors& outputTensors,
- const QosExecPriority priority,
- std::shared_ptr<IAsyncExecutionCallback> cb)
-{
- pRuntimeImpl->Schedule(networkId, inputTensors, outputTensors, priority, cb);
-}
-
Status IRuntime::UnloadNetwork(NetworkId networkId)
{
return pRuntimeImpl->UnloadNetwork(networkId);
@@ -160,8 +151,7 @@ Status RuntimeImpl::LoadNetwork(NetworkId& networkIdOut,
std::unique_ptr<IOptimizedNetwork>(rawNetwork),
errorMessage,
networkProperties,
- m_ProfilingService,
- networkIdOut);
+ m_ProfilingService);
if (!loadedNetwork)
{
@@ -460,34 +450,6 @@ Status RuntimeImpl::Execute(IWorkingMemHandle& iWorkingMemHandle,
return loadedNetwork->Execute(inputTensors, outputTensors, iWorkingMemHandle);
}
-void RuntimeImpl::Schedule(NetworkId networkId,
- const InputTensors& inputTensors,
- const OutputTensors& outputTensors,
- const QosExecPriority priority,
- std::shared_ptr<IAsyncExecutionCallback> callback)
-{
- LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
-
- if (!loadedNetwork)
- {
- throw armnn::Exception(
- "Network with ID of " + std::to_string(networkId) + " does not exist \n"
- );
- }
- if (!loadedNetwork->IsAsyncEnabled())
- {
- throw armnn::Exception(
- "Attempting to schedule Network " + std::to_string(networkId) + " when it is not async enabled \n"
- );
- }
-
- ProfilerManager::GetInstance().RegisterProfiler(loadedNetwork->GetProfiler().get());
-
- ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "Schedule");
-
- loadedNetwork->Schedule(inputTensors, outputTensors, priority, callback);
-}
-
/// Create a new unique WorkingMemHandle object. Create multiple handles if you wish to have
/// overlapped Execution by calling this function from different threads.
std::unique_ptr<IWorkingMemHandle> RuntimeImpl::CreateWorkingMemHandle(NetworkId networkId)
diff --git a/src/armnn/Runtime.hpp b/src/armnn/Runtime.hpp
index 55a4accf67..7a80acd73e 100644
--- a/src/armnn/Runtime.hpp
+++ b/src/armnn/Runtime.hpp
@@ -61,16 +61,6 @@ public:
const OutputTensors& outputTensors);
/// This is an experimental function.
- /// Schedule a thread safe execution by taking the input tensors and an execution priority for Quality of Service.
- /// The output tensors will then be filled and the callback object will notify that the execution has either
- /// succeeded or failed.
- void Schedule(NetworkId networkId,
- const InputTensors& inputTensors,
- const OutputTensors& outputTensors,
- const QosExecPriority priority,
- std::shared_ptr<IAsyncExecutionCallback> callback);
-
- /// This is an experimental function.
/// Evaluates a network using input in inputTensors and outputs filled into outputTensors.
/// This function performs a thread safe execution of the network. Returns once execution is complete.
/// Will block until this and any other thread using the same workingMem object completes.
diff --git a/src/armnn/Threadpool.cpp b/src/armnn/Threadpool.cpp
new file mode 100644
index 0000000000..a23c1e2339
--- /dev/null
+++ b/src/armnn/Threadpool.cpp
@@ -0,0 +1,206 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+
+#include <armnn/Threadpool.hpp>
+
+#include <armnn/utility/Timer.hpp>
+
+namespace armnn
+{
+namespace experimental
+{
+
+Threadpool::Threadpool(std::size_t numThreads,
+ IRuntime* runtimePtr,
+ std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles)
+ : m_RuntimePtr(runtimePtr)
+{
+ for (auto i = 0u; i < numThreads; ++i)
+ {
+ m_Threads.emplace_back(std::make_unique<std::thread>(&Threadpool::ProcessExecPriorities, this, i));
+ }
+
+ LoadMemHandles(memHandles);
+}
+
+void Threadpool::LoadMemHandles(std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles)
+{
+ if (memHandles.size() == 0)
+ {
+ throw armnn::RuntimeException("Threadpool::UnloadMemHandles: Size of memHandles vector must be greater than 0");
+ }
+
+ if (memHandles.size() != m_Threads.size())
+ {
+ throw armnn::RuntimeException(
+ "Threadpool::UnloadMemHandles: Size of memHandles vector must match the number of threads");
+ }
+
+ NetworkId networkId = memHandles[0]->GetNetworkId();
+ for (uint32_t i = 1; i < memHandles.size(); ++i)
+ {
+ if (networkId != memHandles[i]->GetNetworkId())
+ {
+ throw armnn::RuntimeException(
+ "Threadpool::UnloadMemHandles: All network ids must be identical in memHandles");
+ }
+ }
+
+ std::pair<NetworkId, std::vector<std::shared_ptr<IWorkingMemHandle>>> pair {networkId, memHandles};
+
+ m_WorkingMemHandleMap.insert(pair);
+}
+
+void Threadpool::UnloadMemHandles(NetworkId networkId)
+{
+ if (m_WorkingMemHandleMap.find(networkId) != m_WorkingMemHandleMap.end())
+ {
+ m_WorkingMemHandleMap.erase(networkId);
+ }
+ else
+ {
+ throw armnn::RuntimeException("Threadpool::UnloadMemHandles: Unknown NetworkId");
+ }
+}
+
+void Threadpool::Schedule(NetworkId networkId,
+ const InputTensors& inputTensors,
+ const OutputTensors& outputTensors,
+ const QosExecPriority priority,
+ std::shared_ptr<IAsyncExecutionCallback> cb)
+{
+ if (m_WorkingMemHandleMap.find(networkId) == m_WorkingMemHandleMap.end())
+ {
+ throw armnn::RuntimeException("Threadpool::UnloadMemHandles: Unknown NetworkId");
+ }
+
+ // Group execution parameters so that they can be easily added to the queue
+ ExecutionTuple groupExecParams = std::make_tuple(networkId, inputTensors, outputTensors, cb);
+
+ std::shared_ptr<ExecutionTuple> operation = std::make_shared<ExecutionTuple>(groupExecParams);
+
+ // Add a message to the queue and notify the request thread
+ std::unique_lock<std::mutex> lock(m_ThreadPoolMutex);
+ switch (priority)
+ {
+ case QosExecPriority::High:
+ m_HighPriorityQueue.push(operation);
+ break;
+ case QosExecPriority::Low:
+ m_LowPriorityQueue.push(operation);
+ break;
+ case QosExecPriority::Medium:
+ default:
+ m_MediumPriorityQueue.push(operation);
+ }
+ m_ThreadPoolEvent.notify_one();
+}
+
+void Threadpool::TerminateThreadPool() noexcept
+{
+ {
+ std::unique_lock<std::mutex> threadPoolLock(m_ThreadPoolMutex);
+ m_TerminatePool = true;
+ }
+
+ m_ThreadPoolEvent.notify_all();
+
+ for (auto &thread : m_Threads)
+ {
+ thread->join();
+ }
+}
+
+void Threadpool::ProcessExecPriorities(uint32_t index)
+{
+ int expireRate = EXPIRE_RATE;
+ int highPriorityCount = 0;
+ int mediumPriorityCount = 0;
+
+ while (true)
+ {
+ std::shared_ptr<ExecutionTuple> currentExecInProgress(nullptr);
+ {
+ // Wait for a message to be added to the queue
+ // This is in a separate scope to minimise the lifetime of the lock
+ std::unique_lock<std::mutex> lock(m_ThreadPoolMutex);
+
+ m_ThreadPoolEvent.wait(lock,
+ [=]
+ {
+ return m_TerminatePool || !m_HighPriorityQueue.empty() ||
+ !m_MediumPriorityQueue.empty() || !m_LowPriorityQueue.empty();
+ });
+
+ if (m_TerminatePool && m_HighPriorityQueue.empty() && m_MediumPriorityQueue.empty() &&
+ m_LowPriorityQueue.empty())
+ {
+ break;
+ }
+
+ // Get the message to process from the front of each queue based on priority from high to low
+ // Get high priority first if it does not exceed the expire rate
+ if (!m_HighPriorityQueue.empty() && highPriorityCount < expireRate)
+ {
+ currentExecInProgress = m_HighPriorityQueue.front();
+ m_HighPriorityQueue.pop();
+ highPriorityCount += 1;
+ }
+ // If high priority queue is empty or the count exceeds the expire rate, get medium priority message
+ else if (!m_MediumPriorityQueue.empty() && mediumPriorityCount < expireRate)
+ {
+ currentExecInProgress = m_MediumPriorityQueue.front();
+ m_MediumPriorityQueue.pop();
+ mediumPriorityCount += 1;
+ // Reset high priority count
+ highPriorityCount = 0;
+ }
+ // If medium priority queue is empty or the count exceeds the expire rate, get low priority message
+ else if (!m_LowPriorityQueue.empty())
+ {
+ currentExecInProgress = m_LowPriorityQueue.front();
+ m_LowPriorityQueue.pop();
+ // Reset high and medium priority count
+ highPriorityCount = 0;
+ mediumPriorityCount = 0;
+ }
+ else
+ {
+ // Reset high and medium priority count
+ highPriorityCount = 0;
+ mediumPriorityCount = 0;
+ continue;
+ }
+ }
+
+ // invoke the asynchronous execution method
+ auto networkId = std::get<0>(*currentExecInProgress);
+ auto inputTensors = std::get<1>(*currentExecInProgress);
+ auto outputTensors = std::get<2>(*currentExecInProgress);
+ auto cb = std::get<3>(*currentExecInProgress);
+
+ // Get time at start of inference
+ HighResolutionClock startTime = armnn::GetTimeNow();
+
+ try // executing the inference
+ {
+ IWorkingMemHandle& memHandle = *(m_WorkingMemHandleMap.at(networkId))[index];
+
+ // Execute and populate the time at end of inference in the callback
+ m_RuntimePtr->Execute(memHandle, inputTensors, outputTensors) == Status::Success ?
+ cb->Notify(Status::Success, std::make_pair(startTime, armnn::GetTimeNow())) :
+ cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow()));
+ }
+ catch (const RuntimeException &error)
+ {
+ cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow()));
+ }
+ }
+}
+
+} // namespace experimental
+
+} // namespace armnn
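
The EXPIRE_RATE guard in ProcessExecPriorities bounds starvation rather than
enforcing strict priority: a worker drains at most EXPIRE_RATE high-priority jobs in
a row before servicing a waiting medium-priority one, and likewise for medium before
low. A toy model of just the selection logic (EXPIRE_RATE is assumed to be 3 here;
the real constant is defined elsewhere in Arm NN):

    enum class Pick { High, Medium, Low, None };

    // With all queues non-empty and expireRate == 3 this yields H,H,H,M,H,H,H,M,...
    // and Low runs only while High and Medium are empty or expired.
    Pick SelectNext(bool hasHigh, bool hasMedium, bool hasLow,
                    int& highCount, int& mediumCount, int expireRate = 3)
    {
        if (hasHigh && highCount < expireRate)
        {
            ++highCount;
            return Pick::High;
        }
        if (hasMedium && mediumCount < expireRate)
        {
            ++mediumCount;
            highCount = 0;        // a medium pick re-arms high priority
            return Pick::Medium;
        }
        if (hasLow)
        {
            highCount = 0;        // a low pick re-arms both
            mediumCount = 0;
            return Pick::Low;
        }
        highCount = 0;            // nothing runnable: reset and retry
        mediumCount = 0;
        return Pick::None;
    }
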
diff --git a/src/armnn/WorkingMemHandle.cpp b/src/armnn/WorkingMemHandle.cpp
index 94d796eced..1dcaa3853a 100644
--- a/src/armnn/WorkingMemHandle.cpp
+++ b/src/armnn/WorkingMemHandle.cpp
@@ -26,8 +26,7 @@ WorkingMemHandle::WorkingMemHandle(
m_MemoryManagers(memoryManagers),
m_OwnedTensorHandles(std::move(ownedTensorHandles)),
m_IsAllocated(false),
- m_Mutex(),
- m_InferenceId(profiling::ProfilingService::GetNextGuid())
+ m_Mutex()
{
}
diff --git a/src/armnn/WorkingMemHandle.hpp b/src/armnn/WorkingMemHandle.hpp
index 5ccb2b2342..5e3fd66299 100644
--- a/src/armnn/WorkingMemHandle.hpp
+++ b/src/armnn/WorkingMemHandle.hpp
@@ -13,6 +13,7 @@
#include <armnn/Tensor.hpp>
#include <unordered_map>
+#include <mutex>
namespace armnn
{
@@ -38,10 +39,7 @@ public:
return m_NetworkId;
}
- profiling::ProfilingGuid GetInferenceId() override
- {
- return m_InferenceId;
- }
+
/// Allocate the backing memory required for execution. If this is not called, then allocation will be
/// deferred to execution time. The mutex must be locked.
@@ -92,7 +90,6 @@ private:
bool m_IsAllocated;
std::mutex m_Mutex;
- profiling::ProfilingGuid m_InferenceId;
};
} // end experimental namespace
diff --git a/src/backends/backendsCommon/test/StridedSliceAsyncEndToEndTest.hpp b/src/backends/backendsCommon/test/StridedSliceAsyncEndToEndTest.hpp
index a552a6a7da..764983f3b9 100644
--- a/src/backends/backendsCommon/test/StridedSliceAsyncEndToEndTest.hpp
+++ b/src/backends/backendsCommon/test/StridedSliceAsyncEndToEndTest.hpp
@@ -9,6 +9,7 @@
#include <armnn/IWorkingMemHandle.hpp>
#include <armnn/INetwork.hpp>
+#include <armnn/Threadpool.hpp>
#include <armnn/IAsyncExecutionCallback.hpp>
#include <AsyncExecutionCallback.hpp>
@@ -137,7 +138,7 @@ void AsyncEndToEndTestImpl(INetworkPtr network,
std::string errorMessage;
- const INetworkProperties networkProperties(true, MemorySource::Undefined, MemorySource::Undefined, numThreads);
+ const INetworkProperties networkProperties(true, MemorySource::Undefined, MemorySource::Undefined);
runtime->LoadNetwork(networkId, std::move(optNet), errorMessage, networkProperties);
@@ -172,30 +173,32 @@ void AsyncEndToEndTestImpl(INetworkPtr network,
}
else
{
- std::vector<IAsyncExecutionCallbackPtr> callbacks;
+ std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles;
- // Create 1000 callbacks that will be checked post scheduling
- for (size_t i = 0; i < 1000; ++i)
+ for (size_t i = 0; i < numThreads; ++i)
{
- callbacks.emplace_back(std::make_shared<AsyncExecutionCallback>());
+ memHandles.emplace_back(runtime->CreateWorkingMemHandle(networkId));
}
+ Threadpool threadpool(numThreads, runtime.get(), memHandles);
+ AsyncCallbackManager callbackManager;
+
+ // For the asynchronous execution, we are adding a pool of working memory handles (1 per thread) in the
 // LoadedNetwork with each scheduled inference having a specific priority
- for (IAsyncExecutionCallbackPtr cb : callbacks)
+ for (size_t i = 0; i < 1000; ++i)
{
- runtime->Schedule(networkId,
- inputTensors,
- outputTensors,
- static_cast<QosExecPriority>(rand()%3),
- cb);
+ threadpool.Schedule(networkId,
+ inputTensors,
+ outputTensors,
+ static_cast<QosExecPriority>(rand()%3),
+ callbackManager.GetNewCallback());
}
// Wait until the execution signals a notify
- for (IAsyncExecutionCallbackPtr cb : callbacks)
+ for (size_t i = 0; i < 1000; ++i)
{
- cb->Wait();
-
+ auto cb = callbackManager.GetNotifiedCallback();
+
// Checks the results.
CHECK(cb->GetStatus() == Status::Success);
}
diff --git a/tests/ExecuteNetwork/ExecuteNetwork.cpp b/tests/ExecuteNetwork/ExecuteNetwork.cpp
index e8d5b1860c..48577c9990 100644
--- a/tests/ExecuteNetwork/ExecuteNetwork.cpp
+++ b/tests/ExecuteNetwork/ExecuteNetwork.cpp
@@ -445,14 +445,8 @@ int MainImpl(const ExecuteNetworkParams& params,
try
{
ARMNN_LOG(info) << "Asynchronous execution with Arm NN thread pool... \n";
- std::vector<armnn::experimental::IAsyncExecutionCallbackPtr> callbacks;
-
- // Create callbacks that will be checked post scheduling
- for (size_t i = 0; i < params.m_SimultaneousIterations; ++i)
- {
- // Point to ArmNN example implementation of AsyncExecutionCallback
- callbacks.emplace_back(std::make_shared<armnn::experimental::AsyncExecutionCallback>());
- }
+ armnn::experimental::AsyncCallbackManager callbackManager;
+ std::unordered_map<armnn::experimental::InferenceId, std::vector<TContainer>&> inferenceOutputMap;
// Declare the latest and earliest inference times here to be used when calculating overall time
std::chrono::high_resolution_clock::time_point earliestStartTime;
@@ -461,15 +455,19 @@ int MainImpl(const ExecuteNetworkParams& params,
// For the asynchronous execution, we are adding a pool of working memory handles (1 per thread) in the
// LoadedNetwork with each scheduled inference having a specific priority
- for (size_t i = 0; i < callbacks.size(); ++i)
+ for (size_t i = 0; i < params.m_SimultaneousIterations; ++i)
{
- model.RunAsync(inputs[i], outputs[i], callbacks[i]);
+ std::shared_ptr<armnn::AsyncExecutionCallback> cb = callbackManager.GetNewCallback();
+ inferenceOutputMap.insert({cb->GetInferenceId(), outputs[i]});
+ model.RunAsync(inputs[i], outputs[i], cb);
}
// Check the results
unsigned int j = 0;
- for (armnn::experimental::IAsyncExecutionCallbackPtr cb : callbacks)
+ for (size_t iteration = 0; iteration < params.m_SimultaneousIterations; ++iteration)
{
+ auto cb = callbackManager.GetNotifiedCallback();
+
// Get the results
auto endTime = time_point_cast<std::chrono::milliseconds>(cb->GetEndTime());
auto startTime = time_point_cast<std::chrono::milliseconds>(cb->GetStartTime());
@@ -507,7 +505,7 @@ int MainImpl(const ExecuteNetworkParams& params,
infoOut,
outputTensorFile,
params.m_DequantizeOutput);
- mapbox::util::apply_visitor(printer, outputs[j][i]);
+ mapbox::util::apply_visitor(printer, inferenceOutputMap.at(cb->GetInferenceId())[i]);
}
ARMNN_LOG(info) << "\nInference time: " << std::setprecision(2)
@@ -549,7 +547,7 @@ int MainImpl(const ExecuteNetworkParams& params,
try
{
ARMNN_LOG(info) << "Asynchronous Execution with std::launch:async... \n";
- std::vector<std::future<std::tuple<armnn::profiling::ProfilingGuid,
+ std::vector<std::future<std::tuple<unsigned int,
std::chrono::duration<double, std::milli>>>> inferenceResults;
inferenceResults.reserve(params.m_SimultaneousIterations);
@@ -567,9 +565,10 @@ int MainImpl(const ExecuteNetworkParams& params,
for (unsigned int i = 0; i < params.m_SimultaneousIterations; ++i)
{
armnn::experimental::IWorkingMemHandle& workingMemHandleRef = *workingMemHandles[i].get();
+
inferenceResults.push_back(std::async(
std::launch::async, [&model, &workingMemHandleRef, &inputs, &outputs, i]() {
- return model.RunAsync(workingMemHandleRef, inputs[i], outputs[i]);
+ return model.RunAsync(workingMemHandleRef, inputs[i], outputs[i], i);
}
));
}
diff --git a/tests/InferenceModel.hpp b/tests/InferenceModel.hpp
index 9d6096a3eb..3eb1e6a9e7 100644
--- a/tests/InferenceModel.hpp
+++ b/tests/InferenceModel.hpp
@@ -6,6 +6,7 @@
#pragma once
#include <armnn/ArmNN.hpp>
+#include <armnn/Threadpool.hpp>
#include <armnn/Logging.hpp>
#include <armnn/utility/Timer.hpp>
#include <armnn/BackendRegistry.hpp>
@@ -415,7 +416,7 @@ public:
armnn::IRuntime::CreationOptions options;
options.m_EnableGpuProfiling = m_EnableProfiling;
options.m_DynamicBackendsPath = m_DynamicBackendsPath;
- m_Runtime = std::move(armnn::IRuntime::Create(options));
+ m_Runtime = armnn::IRuntime::Create(options);
}
std::string invalidBackends;
@@ -484,13 +485,25 @@ public:
const auto loading_start_time = armnn::GetTimeNow();
armnn::INetworkProperties networkProperties(params.m_AsyncEnabled,
armnn::MemorySource::Undefined,
- armnn::MemorySource::Undefined,
- params.m_ThreadPoolSize);
+ armnn::MemorySource::Undefined);
std::string errorMessage;
ret = m_Runtime->LoadNetwork(m_NetworkIdentifier, std::move(optNet), errorMessage, networkProperties);
ARMNN_LOG(info) << "Network loading time: " << std::setprecision(2)
<< std::fixed << armnn::GetTimeDuration(loading_start_time).count() << " ms\n";
+
+ if (params.m_AsyncEnabled && params.m_ThreadPoolSize > 0)
+ {
+ std::vector<std::shared_ptr<armnn::IWorkingMemHandle>> memHandles;
+ for (size_t i = 0; i < params.m_ThreadPoolSize; ++i)
+ {
+ memHandles.emplace_back(m_Runtime->CreateWorkingMemHandle(m_NetworkIdentifier));
+ }
+
+ m_Threadpool = std::make_unique<armnn::Threadpool>(params.m_ThreadPoolSize,
+ m_Runtime.get(),
+ memHandles);
+ }
}
if (ret == armnn::Status::Failure)
@@ -579,10 +592,11 @@ public:
}
}
- std::tuple<armnn::profiling::ProfilingGuid, std::chrono::duration<double, std::milli>> RunAsync(
+ std::tuple<unsigned int, std::chrono::duration<double, std::milli>> RunAsync(
armnn::experimental::IWorkingMemHandle& workingMemHandleRef,
const std::vector<TContainer>& inputContainers,
- std::vector<TContainer>& outputContainers)
+ std::vector<TContainer>& outputContainers,
+ unsigned int inferenceID)
{
for (unsigned int i = 0; i < outputContainers.size(); ++i)
{
@@ -614,7 +628,6 @@ public:
armnn::Status ret = m_Runtime->Execute(workingMemHandleRef,
MakeInputTensors(inputContainers),
MakeOutputTensors(outputContainers));
- auto inferenceID = workingMemHandleRef.GetInferenceId();
const auto duration = armnn::GetTimeDuration(start_time);
@@ -638,7 +651,7 @@ public:
void RunAsync(const std::vector<TContainer>& inputContainers,
std::vector<TContainer>& outputContainers,
- armnn::experimental::IAsyncExecutionCallbackPtr cb)
+ std::shared_ptr<armnn::IAsyncExecutionCallback> cb)
{
for (unsigned int i = 0; i < outputContainers.size(); ++i)
{
@@ -664,11 +677,11 @@ public:
profiler->EnableProfiling(m_EnableProfiling);
}
- m_Runtime->Schedule(m_NetworkIdentifier,
- MakeInputTensors(inputContainers),
- MakeOutputTensors(outputContainers),
- armnn::QosExecPriority::Medium,
- cb);
+ m_Threadpool->Schedule(m_NetworkIdentifier,
+ MakeInputTensors(inputContainers),
+ MakeOutputTensors(outputContainers),
+ armnn::QosExecPriority::Medium,
+ cb);
// if profiling is enabled print out the results
if (profiler && profiler->IsProfilingEnabled())
@@ -731,6 +744,7 @@ public:
private:
armnn::NetworkId m_NetworkIdentifier;
std::shared_ptr<armnn::IRuntime> m_Runtime;
+ std::unique_ptr<armnn::Threadpool> m_Threadpool;
std::vector<armnn::BindingPointInfo> m_InputBindings;
std::vector<armnn::BindingPointInfo> m_OutputBindings;