From f364d5391b08e9071cd965f5765385ec9156b652 Mon Sep 17 00:00:00 2001
From: Finn Williams <Finn.Williams@arm.com>
Date: Wed, 9 Jun 2021 17:07:33 +0100
Subject: IVGCVSW-6062 Rework the async threadpool

!android-nn-driver:5802

 * Extract the threadpool from LoadedNetwork/Runtime
 * Refactor the threadpool to handle multiple networks
 * Trim IAsyncExecutionCallback and add an InferenceId to AsyncExecutionCallback
 * Add AsyncCallbackManager class

Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Change-Id: I36aa2ad29c16bc10ee0706adfeb6b27f60012afb
---
 src/armnn/AsyncExecutionCallback.cpp |  38 ++++---
 src/armnn/AsyncExecutionCallback.hpp |  50 +++++++--
 src/armnn/LoadedNetwork.cpp          | 159 +--------------------------
 src/armnn/LoadedNetwork.hpp          |  42 +-------
 src/armnn/Runtime.cpp                |  40 +------
 src/armnn/Runtime.hpp                |  10 --
 src/armnn/Threadpool.cpp             | 206 +++++++++++++++++++++++++++++++++++
 src/armnn/WorkingMemHandle.cpp       |   3 +-
 src/armnn/WorkingMemHandle.hpp       |   7 +-
 9 files changed, 282 insertions(+), 273 deletions(-)
 create mode 100644 src/armnn/Threadpool.cpp

(limited to 'src/armnn')

diff --git a/src/armnn/AsyncExecutionCallback.cpp b/src/armnn/AsyncExecutionCallback.cpp
index c44808918d..2973e2d891 100644
--- a/src/armnn/AsyncExecutionCallback.cpp
+++ b/src/armnn/AsyncExecutionCallback.cpp
@@ -15,43 +15,53 @@ void AsyncExecutionCallback::Notify(armnn::Status status, InferenceTimingPair ti
 {
     {
         std::lock_guard<std::mutex> hold(m_Mutex);
-        if (m_Notified)
-        {
-            return;
-        }
         // store results and mark as notified
         m_Status    = status;
         m_StartTime = timeTaken.first;
         m_EndTime   = timeTaken.second;
-        m_Notified  = true;
+        m_NotificationQueue.push(m_InferenceId);
     }
     m_Condition.notify_all();
 }
 
-void AsyncExecutionCallback::Wait() const
-{
-    std::unique_lock<std::mutex> lock(m_Mutex);
-    m_Condition.wait(lock, [this] { return m_Notified; });
-}
-
 armnn::Status AsyncExecutionCallback::GetStatus() const
 {
-    Wait();
     return m_Status;
 }
 
 HighResolutionClock AsyncExecutionCallback::GetStartTime() const
 {
-    Wait();
     return m_StartTime;
 }
 
 HighResolutionClock AsyncExecutionCallback::GetEndTime() const
 {
-    Wait();
     return m_EndTime;
 }
 
+std::shared_ptr<AsyncExecutionCallback> AsyncCallbackManager::GetNewCallback()
+{
+    auto cb = std::make_unique<AsyncExecutionCallback>(m_NotificationQueue, m_Mutex, m_Condition);
+    InferenceId id = cb->GetInferenceId();
+    m_Callbacks.insert({id, std::move(cb)});
+
+    return m_Callbacks.at(id);
+}
+
+std::shared_ptr<AsyncExecutionCallback> AsyncCallbackManager::GetNotifiedCallback()
+{
+    std::unique_lock<std::mutex> lock(m_Mutex);
+
+    m_Condition.wait(lock, [this] { return !m_NotificationQueue.empty(); });
+
+    InferenceId id = m_NotificationQueue.front();
+    m_NotificationQueue.pop();
+
+    std::shared_ptr<AsyncExecutionCallback> callback = m_Callbacks.at(id);
+    m_Callbacks.erase(id);
+    return callback;
+}
+
 } // namespace experimental
 } // namespace armnn
\ No newline at end of file
diff --git a/src/armnn/AsyncExecutionCallback.hpp b/src/armnn/AsyncExecutionCallback.hpp
index c17b839748..2ff73b3efb 100644
--- a/src/armnn/AsyncExecutionCallback.hpp
+++ b/src/armnn/AsyncExecutionCallback.hpp
@@ -6,11 +6,14 @@
 #pragma once
 
 #include <armnn/IAsyncExecutionCallback.hpp>
+#include <armnn/Types.hpp>
 
 #include <condition_variable>
-#include
 #include <mutex>
 #include <thread>
+#include <queue>
+#include <unordered_map>
+
 
 namespace armnn
 {
@@ -18,29 +21,62 @@ namespace experimental
 {
 
+using InferenceId = uint64_t;
 class AsyncExecutionCallback final : public IAsyncExecutionCallback
 {
+private:
+    static InferenceId nextID;
+
 public:
-    AsyncExecutionCallback()
+    AsyncExecutionCallback(std::queue<InferenceId>& notificationQueue,
+                           std::mutex& mutex,
+                           std::condition_variable& condition)
+        : m_NotificationQueue(notificationQueue)
+        , m_Mutex(mutex)
+        , m_Condition(condition)
+        , m_InferenceId(++nextID) {}
+
     ~AsyncExecutionCallback() {}
 
     void Notify(armnn::Status status, InferenceTimingPair timeTaken);
 
-    void Wait() const;
+
+    InferenceId GetInferenceId()
+    {
+        return m_InferenceId;
+    }
 
     armnn::Status GetStatus() const;
     HighResolutionClock GetStartTime() const;
     HighResolutionClock GetEndTime() const;
 
 private:
-    mutable std::mutex m_Mutex;
-    mutable std::condition_variable m_Condition;
+    std::queue<InferenceId>& m_NotificationQueue;
+    std::mutex& m_Mutex;
+    std::condition_variable& m_Condition;
 
     HighResolutionClock m_StartTime;
     HighResolutionClock m_EndTime;
-    armnn::Status m_Status = Status::Failure;
-    bool m_Notified = false;
-};
+    armnn::Status m_Status = Status::Failure;
+    InferenceId m_InferenceId;
+};
+InferenceId AsyncExecutionCallback::nextID = 0u;
+
+// Manager to create and monitor AsyncExecutionCallbacks
+// GetNewCallback will create a callback for use in Threadpool::Schedule
+// GetNotifiedCallback will return the first callback to be notified (finished execution)
+class AsyncCallbackManager
+{
+public:
+    std::shared_ptr<AsyncExecutionCallback> GetNewCallback();
+    std::shared_ptr<AsyncExecutionCallback> GetNotifiedCallback();
+
+private:
+    std::mutex m_Mutex;
+    std::condition_variable m_Condition;
+    std::unordered_map<InferenceId, std::shared_ptr<AsyncExecutionCallback>> m_Callbacks;
+    std::queue<InferenceId> m_NotificationQueue;
+};
 
 } // namespace experimental
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp
index 17cc8dcc23..13beb13a07 100644
--- a/src/armnn/LoadedNetwork.cpp
+++ b/src/armnn/LoadedNetwork.cpp
@@ -6,7 +6,6 @@
 #include "LoadedNetwork.hpp"
 #include "Layer.hpp"
 #include "Graph.hpp"
-#include "Network.hpp"
 #include <Processes.hpp>
 #include "Profiling.hpp"
 #include "HeapProfiling.hpp"
@@ -22,7 +21,6 @@
 #include <backendsCommon/MemSyncWorkload.hpp>
 #include <fmt/format.h>
-#include <armnn/utility/Timer.hpp>
 
 namespace armnn
 {
@@ -83,8 +81,7 @@ void AddWorkloadStructure(std::unique_ptr<TimelineUtilityMethods>& timelineUtils
 std::unique_ptr<LoadedNetwork> LoadedNetwork::MakeLoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
                                                                 std::string& errorMessage,
                                                                 const INetworkProperties& networkProperties,
-                                                                profiling::ProfilingService& profilingService,
-                                                                const NetworkId networkIdOut)
+                                                                profiling::ProfilingService& profilingService)
 {
     std::unique_ptr<LoadedNetwork> loadedNetwork;
 
@@ -98,7 +95,7 @@ std::unique_ptr<LoadedNetwork> LoadedNetwork::MakeLoadedNetwork(std::unique_ptr<
 
     try
     {
-        loadedNetwork.reset(new LoadedNetwork(std::move(net), networkProperties, profilingService, networkIdOut));
+        loadedNetwork.reset(new LoadedNetwork(std::move(net), networkProperties, profilingService));
     }
     catch (const armnn::RuntimeException& error)
     {
@@ -118,11 +115,9 @@
 
 LoadedNetwork::LoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
                              const INetworkProperties& networkProperties,
-                             profiling::ProfilingService& profilingService,
-                             const NetworkId networkId) :
+                             profiling::ProfilingService& profilingService) :
     m_OptimizedNetwork(std::move(net)),
     m_NetworkProperties(networkProperties),
-    m_NetworkId(networkId),
     m_TensorHandleFactoryRegistry(),
     m_ProfilingService(profilingService)
 {
@@ -304,13 +299,6 @@
     {
         AllocateAndExecuteConstantWorkloads();
     }
-
-    // Create the thread pool which will have working memory handles assigned to each thread
-    // Should occur last so all factories and constant layer tensor handles are created
-    if (m_NetworkProperties.m_NumThreads > 0 && networkProperties.m_AsyncEnabled)
-    {
-        CreateThreadPool(m_NetworkProperties.m_NumThreads);
-    }
 }
 
 void LoadedNetwork::AllocateAndExecuteConstantWorkloads()
@@ -856,147 +844,6 @@ bool LoadedNetwork::Execute(std::unique_ptr<TimelineUtilityMethods>& timelineUti
     return success;
 }
 
-void LoadedNetwork::CreateThreadPool(std::size_t numThreads)
-{
-
-    for (auto i = 0u; i < numThreads; ++i)
-    {
-        std::unique_ptr<IWorkingMemHandle> workingMemHandle = CreateWorkingMemHandle(m_NetworkId);
-        m_Threads.emplace_back(
-            std::make_unique<std::thread>(
-                &LoadedNetwork::ProcessExecPriorities,
-                this,
-                std::move(workingMemHandle)
-            )
-        );
-    }
-}
-
-void LoadedNetwork::TerminateThreadPool() noexcept
-{
-    {
-        std::unique_lock<std::mutex> threadPoolLock(m_ThreadPoolMutex);
-        m_TerminatePool = true;
-    }
-
-    m_ThreadPoolEvent.notify_all();
-
-    for (auto &thread : m_Threads)
-    {
-        thread->join();
-    }
-}
-
-void LoadedNetwork::Schedule(const InputTensors& inputTensors,
-                             const OutputTensors& outputTensors,
-                             const QosExecPriority priority,
-                             std::shared_ptr<IAsyncExecutionCallback> cb)
-{
-    // Group execution parameters so that they can be easily added to the queue
-    ExecutionTuple groupExecParams = std::make_tuple(inputTensors, outputTensors, cb);
-    std::shared_ptr<ExecutionTuple> operation = std::make_shared<ExecutionTuple>(groupExecParams);
-
-    // Add a message to the queue and notify the request thread
-    std::unique_lock<std::mutex> lock(m_ThreadPoolMutex);
-    switch (priority) {
-        case QosExecPriority::High:
-            m_HighPriorityQueue.push(operation);
-            break;
-        case QosExecPriority::Low:
-            m_LowPriorityQueue.push(operation);
-            break;
-        case QosExecPriority::Medium:
-        default:
-            m_MediumPriorityQueue.push(operation);
-    }
-    m_ThreadPoolEvent.notify_one();
-}
-
-void LoadedNetwork::ProcessExecPriorities(std::unique_ptr<IWorkingMemHandle> workingMemHandle)
-{
-    int expireRate = EXPIRE_RATE;
-    int highPriorityCount = 0;
-    int mediumPriorityCount = 0;
-
-    IWorkingMemHandle& workingMemHandleRef = *workingMemHandle.get();
-
-    while (true)
-    {
-        std::shared_ptr<ExecutionTuple> currentExecInProgress(nullptr);
-        {
-            // Wait for a message to be added to the queue
-            // This is in a separate scope to minimise the lifetime of the lock
-            std::unique_lock<std::mutex> lock(m_ThreadPoolMutex);
-
-            m_ThreadPoolEvent.wait(lock,
-                                   [=] {
-                                       return m_TerminatePool || !m_HighPriorityQueue.empty() ||
-                                              !m_MediumPriorityQueue.empty() || !m_LowPriorityQueue.empty();
-                                   });
-
-            if (m_TerminatePool && m_HighPriorityQueue.empty() && m_MediumPriorityQueue.empty() &&
-                m_LowPriorityQueue.empty())
-            {
-                break;
-            }
-
-            // Get the message to process from the front of each queue based on priority from high to low
-            // Get high priority first if it does not exceed the expire rate
-            if (!m_HighPriorityQueue.empty() && highPriorityCount < expireRate)
-            {
-                currentExecInProgress = m_HighPriorityQueue.front();
-                m_HighPriorityQueue.pop();
-                highPriorityCount += 1;
-            }
-            // If high priority queue is empty or the count exceeds the expire rate, get medium priority message
-            else if (!m_MediumPriorityQueue.empty() && mediumPriorityCount < expireRate)
-            {
-                currentExecInProgress = m_MediumPriorityQueue.front();
-                m_MediumPriorityQueue.pop();
-                mediumPriorityCount += 1;
-                // Reset high priority count
-                highPriorityCount = 0;
-            }
-            // If medium priority queue is empty or the count exceeds the expire rate, get low priority message
-            else if (!m_LowPriorityQueue.empty())
-            {
-                currentExecInProgress = m_LowPriorityQueue.front();
-                m_LowPriorityQueue.pop();
-                // Reset high and medium priority count
-                highPriorityCount = 0;
-                mediumPriorityCount = 0;
-            }
-            else
-            {
-                // Reset high and medium priority count
-                highPriorityCount = 0;
-                mediumPriorityCount = 0;
-                continue;
-            }
-        }
-
-        // invoke the asynchronous execution method
-        auto inputTensors  = std::get<0>(*currentExecInProgress);
-        auto outputTensors = std::get<1>(*currentExecInProgress);
-        auto cb            = std::get<2>(*currentExecInProgress);
-
-        // Get time at start of inference
-        HighResolutionClock startTime = armnn::GetTimeNow();
-
-        try // executing the inference
-        {
-            // Execute and populate the time at end of inference in the callback
-            Execute(inputTensors, outputTensors, workingMemHandleRef) == Status::Success ?
-                cb->Notify(Status::Success, std::make_pair(startTime, armnn::GetTimeNow())) :
-                cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow()));
-        }
-        catch (const RuntimeException& error)
-        {
-            cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow()));
-        }
-    }
-}
-
 void LoadedNetwork::EnqueueInput(const BindableLayer& layer,
                                  const ConstTensor& inputTensor,
                                  WorkingMemHandle& context)
diff --git a/src/armnn/LoadedNetwork.hpp b/src/armnn/LoadedNetwork.hpp
index c85e82bbdd..360ad91170 100644
--- a/src/armnn/LoadedNetwork.hpp
+++ b/src/armnn/LoadedNetwork.hpp
@@ -37,16 +37,9 @@ class LoadedNetwork
 public:
     using WorkloadQueue = std::vector<std::unique_ptr<IWorkload>>;
 
-    using ExecutionTuple = std::tuple<InputTensors,
-                                      OutputTensors,
-                                      std::shared_ptr<IAsyncExecutionCallback>>;
-
-    using ExecutionQueue = std::queue<std::shared_ptr<ExecutionTuple>>;
-
     ~LoadedNetwork()
     {
         FreeWorkingMemory();
-        TerminateThreadPool();
     }
 
     /// Create a new unique WorkingMemHandle object. Create multiple handles if you wish to have
@@ -64,17 +57,10 @@
                    const OutputTensors& outputTensors,
                    IWorkingMemHandle& workingMemHandle);
 
-    /// Schedule an asynchronous execution on the loaded network
-    void Schedule(const InputTensors& inputTensors,
-                  const OutputTensors& outputTensors,
-                  const QosExecPriority priority,
-                  std::shared_ptr<IAsyncExecutionCallback> cb);
-
     static std::unique_ptr<LoadedNetwork> MakeLoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
                                                             std::string& errorMessage,
                                                             const INetworkProperties& networkProperties,
-                                                            profiling::ProfilingService& profilingService,
-                                                            const NetworkId networkIdOut);
+                                                            profiling::ProfilingService& profilingService);
 
     // NOTE we return by reference as the purpose of this method is only to provide
     // access to the private m_Profiler and in theory we should not need to increment
@@ -108,8 +94,7 @@
 
     LoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
                   const INetworkProperties& networkProperties,
-                  profiling::ProfilingService& profilingService,
-                  const NetworkId networkIdOut);
+                  profiling::ProfilingService& profilingService);
 
     void EnqueueInput(const BindableLayer& layer, ITensorHandle* tensorHandle, const TensorInfo& tensorInfo);
@@ -119,15 +104,9 @@
 
     void EnqueueOutput(const BindableLayer& layer, const Tensor& outputTensor, WorkingMemHandle& handle);
 
-    void ProcessExecPriorities(std::unique_ptr<IWorkingMemHandle> workingMemHandle);
-
     bool Execute(std::unique_ptr<TimelineUtilityMethods>& timelineUtils,
                  profiling::ProfilingGuid inferenceGuid);
 
-    void CreateThreadPool(std::size_t numThreads);
-
-    void TerminateThreadPool() noexcept;
-
     const IWorkloadFactory& GetWorkloadFactory(const Layer& layer) const;
 
     using BackendPtrMap = std::unordered_map<BackendId, IBackendInternalUniquePtr>;
@@ -146,25 +125,8 @@
 
     bool m_IsWorkingMemAllocated = false;
 
-    std::vector<std::unique_ptr<std::thread>> m_Threads;
-    std::stack<IWorkingMemHandle*> m_WorkingMemHandles;
-
-    ExecutionQueue m_HighPriorityQueue;
-    ExecutionQueue m_MediumPriorityQueue;
-    ExecutionQueue m_LowPriorityQueue;
-
-    // Condition Variables require mutex which will guard the shared state.
-    // Has an event happened? Stop signal for example
-    std::condition_variable m_ThreadPoolEvent;
-    std::mutex m_ThreadPoolMutex;
-
-    // The shared state for conditional variable
-    bool m_TerminatePool = false;
-
     INetworkProperties m_NetworkProperties;
 
-    const NetworkId m_NetworkId;
-
     TensorHandleFactoryRegistry m_TensorHandleFactoryRegistry;
 
     profiling::ProfilingService& m_ProfilingService;
diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp
index 374064e408..f16d186191 100644
--- a/src/armnn/Runtime.cpp
+++ b/src/armnn/Runtime.cpp
@@ -89,15 +89,6 @@ Status IRuntime::Execute(IWorkingMemHandle& workingMemHandle,
     return pRuntimeImpl->Execute(workingMemHandle, inputTensors, outputTensors);
 }
 
-void IRuntime::Schedule(NetworkId networkId,
-                        const InputTensors& inputTensors,
-                        const OutputTensors& outputTensors,
-                        const QosExecPriority priority,
-                        std::shared_ptr<IAsyncExecutionCallback> cb)
-{
-    pRuntimeImpl->Schedule(networkId, inputTensors, outputTensors, priority, cb);
-}
-
 Status IRuntime::UnloadNetwork(NetworkId networkId)
 {
     return pRuntimeImpl->UnloadNetwork(networkId);
@@ -160,8 +151,7 @@ Status RuntimeImpl::LoadNetwork(NetworkId& networkIdOut,
        std::unique_ptr<IOptimizedNetwork>(rawNetwork),
        errorMessage,
        networkProperties,
-       m_ProfilingService,
-       networkIdOut);
+       m_ProfilingService);
 
     if (!loadedNetwork)
     {
@@ -460,34 +450,6 @@
     return loadedNetwork->Execute(inputTensors, outputTensors, iWorkingMemHandle);
 }
 
-void RuntimeImpl::Schedule(NetworkId networkId,
-                           const InputTensors& inputTensors,
-                           const OutputTensors& outputTensors,
-                           const QosExecPriority priority,
-                           std::shared_ptr<IAsyncExecutionCallback> callback)
-{
-    LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
-
-    if (!loadedNetwork)
-    {
-        throw armnn::Exception(
-            "Network with ID of " + std::to_string(networkId) + " does not exist \n"
-        );
-    }
-    if (!loadedNetwork->IsAsyncEnabled())
-    {
-        throw armnn::Exception(
-            "Attempting to schedule Network " + std::to_string(networkId) + " when it is not async enabled \n"
-        );
-    }
-
-    ProfilerManager::GetInstance().RegisterProfiler(loadedNetwork->GetProfiler().get());
-
-    ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "Schedule");
-
-    loadedNetwork->Schedule(inputTensors, outputTensors, priority, callback);
-}
-
 /// Create a new unique WorkingMemHandle object. Create multiple handles if you wish to have
 /// overlapped Execution by calling this function from different threads.
 std::unique_ptr<IWorkingMemHandle> RuntimeImpl::CreateWorkingMemHandle(NetworkId networkId)
diff --git a/src/armnn/Runtime.hpp b/src/armnn/Runtime.hpp
index 55a4accf67..7a80acd73e 100644
--- a/src/armnn/Runtime.hpp
+++ b/src/armnn/Runtime.hpp
@@ -60,16 +60,6 @@ public:
                    const InputTensors& inputTensors,
                    const OutputTensors& outputTensors);
 
-    /// This is an experimental function.
-    /// Schedule a thread safe execution by taking the input tensors and an execution priority for Quality of Service.
-    /// The output tensors will then be filled and the callback object will notify that the execution has either
-    /// succeeded or failed.
-    void Schedule(NetworkId networkId,
-                  const InputTensors& inputTensors,
-                  const OutputTensors& outputTensors,
-                  const QosExecPriority priority,
-                  std::shared_ptr<IAsyncExecutionCallback> callback);
-
     /// This is an experimental function.
     /// Evaluates a network using input in inputTensors and outputs filled into outputTensors.
     /// This function performs a thread safe execution of the network. Returns once execution is complete.
diff --git a/src/armnn/Threadpool.cpp b/src/armnn/Threadpool.cpp
new file mode 100644
index 0000000000..a23c1e2339
--- /dev/null
+++ b/src/armnn/Threadpool.cpp
@@ -0,0 +1,206 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+
+#include <armnn/Threadpool.hpp>
+
+#include <armnn/utility/Timer.hpp>
+
+namespace armnn
+{
+namespace experimental
+{
+
+Threadpool::Threadpool(std::size_t numThreads,
+                       IRuntime* runtimePtr,
+                       std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles)
+    : m_RuntimePtr(runtimePtr)
+{
+    for (auto i = 0u; i < numThreads; ++i)
+    {
+        m_Threads.emplace_back(std::make_unique<std::thread>(&Threadpool::ProcessExecPriorities, this, i));
+    }
+
+    LoadMemHandles(memHandles);
+}
+
+void Threadpool::LoadMemHandles(std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles)
+{
+    if (memHandles.size() == 0)
+    {
+        throw armnn::RuntimeException("Threadpool::LoadMemHandles: Size of memHandles vector must be greater than 0");
+    }
+
+    if (memHandles.size() != m_Threads.size())
+    {
+        throw armnn::RuntimeException(
+            "Threadpool::LoadMemHandles: Size of memHandles vector must match the number of threads");
+    }
+
+    NetworkId networkId = memHandles[0]->GetNetworkId();
+    for (uint32_t i = 1; i < memHandles.size(); ++i)
+    {
+        if (networkId != memHandles[i]->GetNetworkId())
+        {
+            throw armnn::RuntimeException(
+                "Threadpool::LoadMemHandles: All network ids must be identical in memHandles");
+        }
+    }
+
+    std::pair<NetworkId, std::vector<std::shared_ptr<IWorkingMemHandle>>> pair {networkId, memHandles};
+
+    m_WorkingMemHandleMap.insert(pair);
+}
+
+void Threadpool::UnloadMemHandles(NetworkId networkId)
+{
+    if (m_WorkingMemHandleMap.find(networkId) != m_WorkingMemHandleMap.end())
+    {
+        m_WorkingMemHandleMap.erase(networkId);
+    }
+    else
+    {
+        throw armnn::RuntimeException("Threadpool::UnloadMemHandles: Unknown NetworkId");
+    }
+}
+
+void Threadpool::Schedule(NetworkId networkId,
+                          const InputTensors& inputTensors,
+                          const OutputTensors& outputTensors,
+                          const QosExecPriority priority,
+                          std::shared_ptr<IAsyncExecutionCallback> cb)
+{
+    if (m_WorkingMemHandleMap.find(networkId) == m_WorkingMemHandleMap.end())
+    {
+        throw armnn::RuntimeException("Threadpool::Schedule: Unknown NetworkId");
+    }
+
+    // Group execution parameters so that they can be easily added to the queue
+    ExecutionTuple groupExecParams = std::make_tuple(networkId, inputTensors, outputTensors, cb);
+
+    std::shared_ptr<ExecutionTuple> operation = std::make_shared<ExecutionTuple>(groupExecParams);
+
+    // Add a message to the queue and notify the request thread
+    std::unique_lock<std::mutex> lock(m_ThreadPoolMutex);
+    switch (priority)
+    {
+        case QosExecPriority::High:
+            m_HighPriorityQueue.push(operation);
+            break;
+        case QosExecPriority::Low:
+            m_LowPriorityQueue.push(operation);
+            break;
+        case QosExecPriority::Medium:
+        default:
+            m_MediumPriorityQueue.push(operation);
+    }
+    m_ThreadPoolEvent.notify_one();
+}
+
+void Threadpool::TerminateThreadPool() noexcept
+{
+    {
+        std::unique_lock<std::mutex> threadPoolLock(m_ThreadPoolMutex);
+        m_TerminatePool = true;
+    }
+
+    m_ThreadPoolEvent.notify_all();
+
+    for (auto &thread : m_Threads)
+    {
+        thread->join();
+    }
+}
+
+void Threadpool::ProcessExecPriorities(uint32_t index)
+{
+    int expireRate = EXPIRE_RATE;
+    int highPriorityCount = 0;
+    int mediumPriorityCount = 0;
+
+    while (true)
+    {
+        std::shared_ptr<ExecutionTuple> currentExecInProgress(nullptr);
+        {
+            // Wait for a message to be added to the queue
+            // This is in a separate scope to minimise the lifetime of the lock
+            std::unique_lock<std::mutex> lock(m_ThreadPoolMutex);
+
+            m_ThreadPoolEvent.wait(lock,
+                                   [=]
+                                   {
+                                       return m_TerminatePool || !m_HighPriorityQueue.empty() ||
+                                              !m_MediumPriorityQueue.empty() || !m_LowPriorityQueue.empty();
+                                   });
+
+            if (m_TerminatePool && m_HighPriorityQueue.empty() && m_MediumPriorityQueue.empty() &&
+                m_LowPriorityQueue.empty())
+            {
+                break;
+            }
+
+            // Get the message to process from the front of each queue based on priority from high to low
+            // Get high priority first if it does not exceed the expire rate
+            if (!m_HighPriorityQueue.empty() && highPriorityCount < expireRate)
+            {
+                currentExecInProgress = m_HighPriorityQueue.front();
+                m_HighPriorityQueue.pop();
+                highPriorityCount += 1;
+            }
+            // If high priority queue is empty or the count exceeds the expire rate, get medium priority message
+            else if (!m_MediumPriorityQueue.empty() && mediumPriorityCount < expireRate)
+            {
+                currentExecInProgress = m_MediumPriorityQueue.front();
+                m_MediumPriorityQueue.pop();
+                mediumPriorityCount += 1;
+                // Reset high priority count
+                highPriorityCount = 0;
+            }
+            // If medium priority queue is empty or the count exceeds the expire rate, get low priority message
+            else if (!m_LowPriorityQueue.empty())
+            {
+                currentExecInProgress = m_LowPriorityQueue.front();
+                m_LowPriorityQueue.pop();
+                // Reset high and medium priority count
+                highPriorityCount = 0;
+                mediumPriorityCount = 0;
+            }
+            else
+            {
+                // Reset high and medium priority count
+                highPriorityCount = 0;
+                mediumPriorityCount = 0;
+                continue;
+            }
+        }
+
+        // invoke the asynchronous execution method
+        auto networkId     = std::get<0>(*currentExecInProgress);
+        auto inputTensors  = std::get<1>(*currentExecInProgress);
+        auto outputTensors = std::get<2>(*currentExecInProgress);
+        auto cb            = std::get<3>(*currentExecInProgress);
+
+        // Get time at start of inference
+        HighResolutionClock startTime = armnn::GetTimeNow();
+
+        try // executing the inference
+        {
+            IWorkingMemHandle& memHandle = *(m_WorkingMemHandleMap.at(networkId))[index];
+
+            // Execute and populate the time at end of inference in the callback
+            m_RuntimePtr->Execute(memHandle, inputTensors, outputTensors) == Status::Success ?
+                cb->Notify(Status::Success, std::make_pair(startTime, armnn::GetTimeNow())) :
+                cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow()));
+        }
+        catch (const RuntimeException &error)
+        {
+            cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow()));
+        }
+    }
+}
+
+} // namespace experimental
+
+} // namespace armnn
diff --git a/src/armnn/WorkingMemHandle.cpp b/src/armnn/WorkingMemHandle.cpp
index 94d796eced..1dcaa3853a 100644
--- a/src/armnn/WorkingMemHandle.cpp
+++ b/src/armnn/WorkingMemHandle.cpp
@@ -26,8 +26,7 @@ WorkingMemHandle::WorkingMemHandle(
     m_MemoryManagers(memoryManagers),
     m_OwnedTensorHandles(std::move(ownedTensorHandles)),
     m_IsAllocated(false),
-    m_Mutex(),
-    m_InferenceId(profiling::ProfilingService::GetNextGuid())
+    m_Mutex()
 {
 }
 
diff --git a/src/armnn/WorkingMemHandle.hpp b/src/armnn/WorkingMemHandle.hpp
index 5ccb2b2342..5e3fd66299 100644
--- a/src/armnn/WorkingMemHandle.hpp
+++ b/src/armnn/WorkingMemHandle.hpp
@@ -13,6 +13,7 @@
 #include <unordered_map>
 #include <mutex>
+#include
 
 namespace armnn
 {
 
@@ -38,10 +39,7 @@ public:
         return m_NetworkId;
     }
 
-    profiling::ProfilingGuid GetInferenceId() override
-    {
-        return m_InferenceId;
-    }
+
+
     /// Allocate the backing memory required for execution. If this is not called, then allocation will be
     /// deferred to execution time. The mutex must be locked.
@@ -92,7 +90,6 @@ private:
     bool m_IsAllocated;
     std::mutex m_Mutex;
-    profiling::ProfilingGuid m_InferenceId;
 };
 
 } // end experimental namespace
-- 
cgit v1.2.1
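
Usage note: the pieces above fit together roughly as follows. This is a minimal sketch written against the signatures introduced in this patch, not code from the change itself. The include paths, the helper name RunInferences, and the prepared runtime/network/tensor objects are illustrative assumptions (the network is assumed to have been loaded with async-enabled INetworkProperties, and the Threadpool class declaration lives in a header outside this diff, which is limited to src/armnn).

    // Hypothetical driver code for the reworked async API.
    #include <armnn/Threadpool.hpp>          // assumed location of the Threadpool declaration
    #include <AsyncExecutionCallback.hpp>    // internal src/armnn header providing AsyncCallbackManager

    #include <memory>
    #include <vector>

    using namespace armnn;
    using namespace armnn::experimental;

    void RunInferences(IRuntime* runtime,
                       NetworkId networkId,
                       const InputTensors& inputTensors,
                       const OutputTensors& outputTensors)
    {
        constexpr std::size_t numThreads = 3;

        // One working-memory handle per pool thread; LoadMemHandles requires
        // every handle in the vector to belong to the same network.
        std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles;
        for (std::size_t i = 0; i < numThreads; ++i)
        {
            memHandles.emplace_back(runtime->CreateWorkingMemHandle(networkId));
        }

        Threadpool threadpool(numThreads, runtime, memHandles);
        AsyncCallbackManager callbackManager;

        // Schedule() enqueues work and returns immediately; each inference is
        // tracked by a callback carrying a unique InferenceId.
        threadpool.Schedule(networkId, inputTensors, outputTensors,
                            QosExecPriority::High, callbackManager.GetNewCallback());
        threadpool.Schedule(networkId, inputTensors, outputTensors,
                            QosExecPriority::Medium, callbackManager.GetNewCallback());
        threadpool.Schedule(networkId, inputTensors, outputTensors,
                            QosExecPriority::Low, callbackManager.GetNewCallback());

        // GetNotifiedCallback() blocks and returns callbacks in completion order.
        for (std::size_t i = 0; i < numThreads; ++i)
        {
            std::shared_ptr<AsyncExecutionCallback> cb = callbackManager.GetNotifiedCallback();
            if (cb->GetStatus() != Status::Success)
            {
                // Handle the failed inference; timing is available through
                // cb->GetStartTime() and cb->GetEndTime().
            }
        }

        // Join the worker threads before the pool goes out of scope.
        threadpool.TerminateThreadPool();
    }

The callback flow replaces the old blocking Wait() on each callback: AsyncExecutionCallback::Notify() pushes the finished InferenceId onto the manager's shared queue, so a single consumer can drain completions from many concurrent inferences, and even from multiple loaded networks, in the order they finish.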