aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatteo Martincigh <matteo.martincigh@arm.com>2019-10-04 14:40:04 +0100
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-10-07 10:08:27 +0000
commit88813936232bc47fc7768800c6895191585570e8 (patch)
treedfe9a615abd585ad9489afacc017ae40270486cc
parenta84edee4702c112a6e004b1987acc11144e2d6dd (diff)
downloadarmnn-88813936232bc47fc7768800c6895191585570e8.tar.gz
IVGCVSW-3937 Refactor the command thread
* Integrated the Join method into Stop * Updated the unit tests accordingly * General code refactoring Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com> Change-Id: If8537e77b3d3ff2b780f58a07df01191a91d83d2
-rw-r--r--src/profiling/CommandThread.cpp108
-rw-r--r--src/profiling/CommandThread.hpp26
-rw-r--r--src/profiling/test/ProfilingTests.cpp13
3 files changed, 75 insertions, 72 deletions
diff --git a/src/profiling/CommandThread.cpp b/src/profiling/CommandThread.cpp
index bd4aa96c7c..320e4bcf5c 100644
--- a/src/profiling/CommandThread.cpp
+++ b/src/profiling/CommandThread.cpp
@@ -12,86 +12,76 @@ namespace armnn
namespace profiling
{
-CommandThread::CommandThread(uint32_t timeout,
- bool stopAfterTimeout,
- CommandHandlerRegistry& commandHandlerRegistry,
- PacketVersionResolver& packetVersionResolver,
- IProfilingConnection& socketProfilingConnection)
- : m_Timeout(timeout)
- , m_StopAfterTimeout(stopAfterTimeout)
- , m_IsRunning(false)
- , m_CommandHandlerRegistry(commandHandlerRegistry)
- , m_PacketVersionResolver(packetVersionResolver)
- , m_SocketProfilingConnection(socketProfilingConnection)
-{};
-
-void CommandThread::WaitForPacket()
+void CommandThread::Start()
{
- do {
- try
- {
- Packet packet = m_SocketProfilingConnection.ReadPacket(m_Timeout);
- Version version = m_PacketVersionResolver.ResolvePacketVersion(packet.GetPacketId());
-
- CommandHandlerFunctor* commandHandlerFunctor =
- m_CommandHandlerRegistry.GetFunctor(packet.GetPacketId(), version.GetEncodedValue());
- commandHandlerFunctor->operator()(packet);
- }
- catch(const armnn::TimeoutException&)
- {
- if(m_StopAfterTimeout)
- {
- m_IsRunning.store(false, std::memory_order_relaxed);
- return;
- }
- }
- catch(...)
- {
- //might want to differentiate the errors more
- m_IsRunning.store(false, std::memory_order_relaxed);
- return;
- }
-
- } while(m_KeepRunning.load(std::memory_order_relaxed));
+ if (IsRunning())
+ {
+ return;
+ }
- m_IsRunning.store(false, std::memory_order_relaxed);
+ m_IsRunning.store(true, std::memory_order_relaxed);
+ m_KeepRunning.store(true, std::memory_order_relaxed);
+ m_CommandThread = std::thread(&CommandThread::WaitForPacket, this);
}
-void CommandThread::Start()
+void CommandThread::Stop()
{
- if (!m_CommandThread.joinable() && !IsRunning())
+ m_KeepRunning.store(false, std::memory_order_relaxed);
+
+ if (m_CommandThread.joinable())
{
- m_IsRunning.store(true, std::memory_order_relaxed);
- m_KeepRunning.store(true, std::memory_order_relaxed);
- m_CommandThread = std::thread(&CommandThread::WaitForPacket, this);
+ m_CommandThread.join();
}
}
-void CommandThread::Stop()
+bool CommandThread::IsRunning() const
{
- m_KeepRunning.store(false, std::memory_order_relaxed);
+ return m_IsRunning.load(std::memory_order_relaxed);
}
-void CommandThread::Join()
+void CommandThread::SetTimeout(uint32_t timeout)
{
- m_CommandThread.join();
+ m_Timeout.store(timeout, std::memory_order_relaxed);
}
-bool CommandThread::IsRunning() const
+void CommandThread::SetStopAfterTimeout(bool stopAfterTimeout)
{
- return m_IsRunning.load(std::memory_order_relaxed);
+ m_StopAfterTimeout.store(stopAfterTimeout, std::memory_order_relaxed);
}
-bool CommandThread::StopAfterTimeout(bool stopAfterTimeout)
+void CommandThread::WaitForPacket()
{
- if (!IsRunning())
+ do
{
- m_StopAfterTimeout = stopAfterTimeout;
- return true;
+ try
+ {
+ Packet packet = m_SocketProfilingConnection.ReadPacket(m_Timeout);
+ Version version = m_PacketVersionResolver.ResolvePacketVersion(packet.GetPacketId());
+
+ CommandHandlerFunctor* commandHandlerFunctor =
+ m_CommandHandlerRegistry.GetFunctor(packet.GetPacketId(), version.GetEncodedValue());
+ BOOST_ASSERT(commandHandlerFunctor);
+ commandHandlerFunctor->operator()(packet);
+ }
+ catch (const armnn::TimeoutException&)
+ {
+ if (m_StopAfterTimeout)
+ {
+ m_KeepRunning.store(false, std::memory_order_relaxed);
+ }
+ }
+ catch (...)
+ {
+ // Might want to differentiate the errors more
+ m_KeepRunning.store(false, std::memory_order_relaxed);
+ }
+
}
- return false;
+ while (m_KeepRunning.load(std::memory_order_relaxed));
+
+ m_IsRunning.store(false, std::memory_order_relaxed);
}
-}//namespace profiling
+} // namespace profiling
-}//namespace armnn
+} // namespace armnn
diff --git a/src/profiling/CommandThread.hpp b/src/profiling/CommandThread.hpp
index 6237cd2914..0456ba4372 100644
--- a/src/profiling/CommandThread.hpp
+++ b/src/profiling/CommandThread.hpp
@@ -26,19 +26,31 @@ public:
bool stopAfterTimeout,
CommandHandlerRegistry& commandHandlerRegistry,
PacketVersionResolver& packetVersionResolver,
- IProfilingConnection& socketProfilingConnection);
+ IProfilingConnection& socketProfilingConnection)
+ : m_Timeout(timeout)
+ , m_StopAfterTimeout(stopAfterTimeout)
+ , m_IsRunning(false)
+ , m_KeepRunning(false)
+ , m_CommandThread()
+ , m_CommandHandlerRegistry(commandHandlerRegistry)
+ , m_PacketVersionResolver(packetVersionResolver)
+ , m_SocketProfilingConnection(socketProfilingConnection)
+ {}
+ ~CommandThread() { Stop(); }
void Start();
void Stop();
- void Join();
+
bool IsRunning() const;
- bool StopAfterTimeout(bool StopAfterTimeout);
+
+ void SetTimeout(uint32_t timeout);
+ void SetStopAfterTimeout(bool stopAfterTimeout);
private:
void WaitForPacket();
- uint32_t m_Timeout;
- bool m_StopAfterTimeout;
+ std::atomic<uint32_t> m_Timeout;
+ std::atomic<bool> m_StopAfterTimeout;
std::atomic<bool> m_IsRunning;
std::atomic<bool> m_KeepRunning;
std::thread m_CommandThread;
@@ -48,6 +60,6 @@ private:
IProfilingConnection& m_SocketProfilingConnection;
};
-}//namespace profiling
+} // namespace profiling
-}//namespace armnn \ No newline at end of file
+} // namespace armnn
diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp
index d14791c43d..9dd7cd3d64 100644
--- a/src/profiling/test/ProfilingTests.cpp
+++ b/src/profiling/test/ProfilingTests.cpp
@@ -174,7 +174,6 @@ BOOST_AUTO_TEST_CASE(CheckCommandThread)
commandThread0.Start();
commandThread0.Stop();
- commandThread0.Join();
BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::Active);
@@ -188,11 +187,15 @@ BOOST_AUTO_TEST_CASE(CheckCommandThread)
testProfilingConnectionTimeOutError);
commandThread1.Start();
- commandThread1.Join();
+
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+ BOOST_CHECK(!commandThread1.IsRunning());
+ commandThread1.Stop();
BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::WaitingForAck);
//now commandThread1 should persist after a timeout
- commandThread1.StopAfterTimeout(false);
+ commandThread1.SetStopAfterTimeout(false);
commandThread1.Start();
for (int i = 0; i < 100; i++)
@@ -208,11 +211,9 @@ BOOST_AUTO_TEST_CASE(CheckCommandThread)
}
commandThread1.Stop();
- commandThread1.Join();
BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::Active);
-
CommandThread commandThread2(1,
false,
commandHandlerRegistry,
@@ -226,13 +227,13 @@ BOOST_AUTO_TEST_CASE(CheckCommandThread)
if (!commandThread2.IsRunning())
{
//commandThread2 should stop once it encounters a non timing error
- commandThread2.Join();
return;
}
std::this_thread::sleep_for(std::chrono::milliseconds(5));
}
BOOST_ERROR("commandThread2 has failed to stop");
+ commandThread2.Stop();
}
BOOST_AUTO_TEST_CASE(CheckEncodeVersion)