aboutsummaryrefslogtreecommitdiff
path: root/src/profiling/CommandThread.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/profiling/CommandThread.cpp')
-rw-r--r--src/profiling/CommandThread.cpp108
1 files changed, 49 insertions, 59 deletions
diff --git a/src/profiling/CommandThread.cpp b/src/profiling/CommandThread.cpp
index bd4aa96c7c..320e4bcf5c 100644
--- a/src/profiling/CommandThread.cpp
+++ b/src/profiling/CommandThread.cpp
@@ -12,86 +12,76 @@ namespace armnn
namespace profiling
{
-CommandThread::CommandThread(uint32_t timeout,
- bool stopAfterTimeout,
- CommandHandlerRegistry& commandHandlerRegistry,
- PacketVersionResolver& packetVersionResolver,
- IProfilingConnection& socketProfilingConnection)
- : m_Timeout(timeout)
- , m_StopAfterTimeout(stopAfterTimeout)
- , m_IsRunning(false)
- , m_CommandHandlerRegistry(commandHandlerRegistry)
- , m_PacketVersionResolver(packetVersionResolver)
- , m_SocketProfilingConnection(socketProfilingConnection)
-{};
-
-void CommandThread::WaitForPacket()
+void CommandThread::Start()
{
- do {
- try
- {
- Packet packet = m_SocketProfilingConnection.ReadPacket(m_Timeout);
- Version version = m_PacketVersionResolver.ResolvePacketVersion(packet.GetPacketId());
-
- CommandHandlerFunctor* commandHandlerFunctor =
- m_CommandHandlerRegistry.GetFunctor(packet.GetPacketId(), version.GetEncodedValue());
- commandHandlerFunctor->operator()(packet);
- }
- catch(const armnn::TimeoutException&)
- {
- if(m_StopAfterTimeout)
- {
- m_IsRunning.store(false, std::memory_order_relaxed);
- return;
- }
- }
- catch(...)
- {
- //might want to differentiate the errors more
- m_IsRunning.store(false, std::memory_order_relaxed);
- return;
- }
-
- } while(m_KeepRunning.load(std::memory_order_relaxed));
+ if (IsRunning())
+ {
+ return;
+ }
- m_IsRunning.store(false, std::memory_order_relaxed);
+ m_IsRunning.store(true, std::memory_order_relaxed);
+ m_KeepRunning.store(true, std::memory_order_relaxed);
+ m_CommandThread = std::thread(&CommandThread::WaitForPacket, this);
}
-void CommandThread::Start()
+void CommandThread::Stop()
{
- if (!m_CommandThread.joinable() && !IsRunning())
+ m_KeepRunning.store(false, std::memory_order_relaxed);
+
+ if (m_CommandThread.joinable())
{
- m_IsRunning.store(true, std::memory_order_relaxed);
- m_KeepRunning.store(true, std::memory_order_relaxed);
- m_CommandThread = std::thread(&CommandThread::WaitForPacket, this);
+ m_CommandThread.join();
}
}
-void CommandThread::Stop()
+bool CommandThread::IsRunning() const
{
- m_KeepRunning.store(false, std::memory_order_relaxed);
+ return m_IsRunning.load(std::memory_order_relaxed);
}
-void CommandThread::Join()
+void CommandThread::SetTimeout(uint32_t timeout)
{
- m_CommandThread.join();
+ m_Timeout.store(timeout, std::memory_order_relaxed);
}
-bool CommandThread::IsRunning() const
+void CommandThread::SetStopAfterTimeout(bool stopAfterTimeout)
{
- return m_IsRunning.load(std::memory_order_relaxed);
+ m_StopAfterTimeout.store(stopAfterTimeout, std::memory_order_relaxed);
}
-bool CommandThread::StopAfterTimeout(bool stopAfterTimeout)
+void CommandThread::WaitForPacket()
{
- if (!IsRunning())
+ do
{
- m_StopAfterTimeout = stopAfterTimeout;
- return true;
+ try
+ {
+ Packet packet = m_SocketProfilingConnection.ReadPacket(m_Timeout);
+ Version version = m_PacketVersionResolver.ResolvePacketVersion(packet.GetPacketId());
+
+ CommandHandlerFunctor* commandHandlerFunctor =
+ m_CommandHandlerRegistry.GetFunctor(packet.GetPacketId(), version.GetEncodedValue());
+ BOOST_ASSERT(commandHandlerFunctor);
+ commandHandlerFunctor->operator()(packet);
+ }
+ catch (const armnn::TimeoutException&)
+ {
+ if (m_StopAfterTimeout)
+ {
+ m_KeepRunning.store(false, std::memory_order_relaxed);
+ }
+ }
+ catch (...)
+ {
+ // Might want to differentiate the errors more
+ m_KeepRunning.store(false, std::memory_order_relaxed);
+ }
+
}
- return false;
+ while (m_KeepRunning.load(std::memory_order_relaxed));
+
+ m_IsRunning.store(false, std::memory_order_relaxed);
}
-}//namespace profiling
+} // namespace profiling
-}//namespace armnn
+} // namespace armnn