diff options
Diffstat (limited to 'src/profiling/CommandThread.cpp')
-rw-r--r-- | src/profiling/CommandThread.cpp | 108 |
1 files changed, 49 insertions, 59 deletions
diff --git a/src/profiling/CommandThread.cpp b/src/profiling/CommandThread.cpp index bd4aa96c7c..320e4bcf5c 100644 --- a/src/profiling/CommandThread.cpp +++ b/src/profiling/CommandThread.cpp @@ -12,86 +12,76 @@ namespace armnn namespace profiling { -CommandThread::CommandThread(uint32_t timeout, - bool stopAfterTimeout, - CommandHandlerRegistry& commandHandlerRegistry, - PacketVersionResolver& packetVersionResolver, - IProfilingConnection& socketProfilingConnection) - : m_Timeout(timeout) - , m_StopAfterTimeout(stopAfterTimeout) - , m_IsRunning(false) - , m_CommandHandlerRegistry(commandHandlerRegistry) - , m_PacketVersionResolver(packetVersionResolver) - , m_SocketProfilingConnection(socketProfilingConnection) -{}; - -void CommandThread::WaitForPacket() +void CommandThread::Start() { - do { - try - { - Packet packet = m_SocketProfilingConnection.ReadPacket(m_Timeout); - Version version = m_PacketVersionResolver.ResolvePacketVersion(packet.GetPacketId()); - - CommandHandlerFunctor* commandHandlerFunctor = - m_CommandHandlerRegistry.GetFunctor(packet.GetPacketId(), version.GetEncodedValue()); - commandHandlerFunctor->operator()(packet); - } - catch(const armnn::TimeoutException&) - { - if(m_StopAfterTimeout) - { - m_IsRunning.store(false, std::memory_order_relaxed); - return; - } - } - catch(...) - { - //might want to differentiate the errors more - m_IsRunning.store(false, std::memory_order_relaxed); - return; - } - - } while(m_KeepRunning.load(std::memory_order_relaxed)); + if (IsRunning()) + { + return; + } - m_IsRunning.store(false, std::memory_order_relaxed); + m_IsRunning.store(true, std::memory_order_relaxed); + m_KeepRunning.store(true, std::memory_order_relaxed); + m_CommandThread = std::thread(&CommandThread::WaitForPacket, this); } -void CommandThread::Start() +void CommandThread::Stop() { - if (!m_CommandThread.joinable() && !IsRunning()) + m_KeepRunning.store(false, std::memory_order_relaxed); + + if (m_CommandThread.joinable()) { - m_IsRunning.store(true, std::memory_order_relaxed); - m_KeepRunning.store(true, std::memory_order_relaxed); - m_CommandThread = std::thread(&CommandThread::WaitForPacket, this); + m_CommandThread.join(); } } -void CommandThread::Stop() +bool CommandThread::IsRunning() const { - m_KeepRunning.store(false, std::memory_order_relaxed); + return m_IsRunning.load(std::memory_order_relaxed); } -void CommandThread::Join() +void CommandThread::SetTimeout(uint32_t timeout) { - m_CommandThread.join(); + m_Timeout.store(timeout, std::memory_order_relaxed); } -bool CommandThread::IsRunning() const +void CommandThread::SetStopAfterTimeout(bool stopAfterTimeout) { - return m_IsRunning.load(std::memory_order_relaxed); + m_StopAfterTimeout.store(stopAfterTimeout, std::memory_order_relaxed); } -bool CommandThread::StopAfterTimeout(bool stopAfterTimeout) +void CommandThread::WaitForPacket() { - if (!IsRunning()) + do { - m_StopAfterTimeout = stopAfterTimeout; - return true; + try + { + Packet packet = m_SocketProfilingConnection.ReadPacket(m_Timeout); + Version version = m_PacketVersionResolver.ResolvePacketVersion(packet.GetPacketId()); + + CommandHandlerFunctor* commandHandlerFunctor = + m_CommandHandlerRegistry.GetFunctor(packet.GetPacketId(), version.GetEncodedValue()); + BOOST_ASSERT(commandHandlerFunctor); + commandHandlerFunctor->operator()(packet); + } + catch (const armnn::TimeoutException&) + { + if (m_StopAfterTimeout) + { + m_KeepRunning.store(false, std::memory_order_relaxed); + } + } + catch (...) + { + // Might want to differentiate the errors more + m_KeepRunning.store(false, std::memory_order_relaxed); + } + } - return false; + while (m_KeepRunning.load(std::memory_order_relaxed)); + + m_IsRunning.store(false, std::memory_order_relaxed); } -}//namespace profiling +} // namespace profiling -}//namespace armnn +} // namespace armnn |