diff options
Diffstat (limited to 'src/armnn/AsyncExecutionCallback.cpp')
-rw-r--r-- | src/armnn/AsyncExecutionCallback.cpp | 38 |
1 files changed, 24 insertions, 14 deletions
diff --git a/src/armnn/AsyncExecutionCallback.cpp b/src/armnn/AsyncExecutionCallback.cpp index c44808918d..2973e2d891 100644 --- a/src/armnn/AsyncExecutionCallback.cpp +++ b/src/armnn/AsyncExecutionCallback.cpp @@ -15,43 +15,53 @@ void AsyncExecutionCallback::Notify(armnn::Status status, InferenceTimingPair ti { { std::lock_guard<std::mutex> hold(m_Mutex); - if (m_Notified) - { - return; - } // store results and mark as notified m_Status = status; m_StartTime = timeTaken.first; m_EndTime = timeTaken.second; - m_Notified = true; + m_NotificationQueue.push(m_InferenceId); } m_Condition.notify_all(); } -void AsyncExecutionCallback::Wait() const -{ - std::unique_lock<std::mutex> lock(m_Mutex); - m_Condition.wait(lock, [this] { return m_Notified; }); -} - armnn::Status AsyncExecutionCallback::GetStatus() const { - Wait(); return m_Status; } HighResolutionClock AsyncExecutionCallback::GetStartTime() const { - Wait(); return m_StartTime; } HighResolutionClock AsyncExecutionCallback::GetEndTime() const { - Wait(); return m_EndTime; } +std::shared_ptr<AsyncExecutionCallback> AsyncCallbackManager::GetNewCallback() +{ + auto cb = std::make_unique<AsyncExecutionCallback>(m_NotificationQueue, m_Mutex, m_Condition); + InferenceId id = cb->GetInferenceId(); + m_Callbacks.insert({id, std::move(cb)}); + + return m_Callbacks.at(id); +} + +std::shared_ptr<AsyncExecutionCallback> AsyncCallbackManager::GetNotifiedCallback() +{ + std::unique_lock<std::mutex> lock(m_Mutex); + + m_Condition.wait(lock, [this] { return !m_NotificationQueue.empty(); }); + + InferenceId id = m_NotificationQueue.front(); + m_NotificationQueue.pop(); + + std::shared_ptr<AsyncExecutionCallback> callback = m_Callbacks.at(id); + m_Callbacks.erase(id); + return callback; +} + } // namespace experimental } // namespace armnn
\ No newline at end of file |