aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CMakeLists.txt2
-rw-r--r--include/armnn/Exceptions.hpp5
-rw-r--r--src/profiling/CommandThread.cpp97
-rw-r--r--src/profiling/CommandThread.hpp53
-rw-r--r--src/profiling/SocketProfilingConnection.cpp2
-rw-r--r--src/profiling/test/ProfilingTests.cpp145
6 files changed, 303 insertions, 1 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 90eb0328dd..3da7e8bcfa 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -428,6 +428,8 @@ list(APPEND armnn_sources
src/profiling/CommandHandlerKey.hpp
src/profiling/CommandHandlerRegistry.cpp
src/profiling/CommandHandlerRegistry.hpp
+ src/profiling/CommandThread.cpp
+ src/profiling/CommandThread.hpp
src/profiling/ConnectionAcknowledgedCommandHandler.cpp
src/profiling/ConnectionAcknowledgedCommandHandler.hpp
src/profiling/CounterDirectory.cpp
diff --git a/include/armnn/Exceptions.hpp b/include/armnn/Exceptions.hpp
index f8e0b430a6..e21e974fc7 100644
--- a/include/armnn/Exceptions.hpp
+++ b/include/armnn/Exceptions.hpp
@@ -125,6 +125,11 @@ class MemoryExportException : public Exception
using Exception::Exception;
};
+class TimeoutException : public Exception
+{
+ using Exception::Exception;
+};
+
template <typename ExceptionType>
void ConditionalThrow(bool condition, const std::string& message)
{
diff --git a/src/profiling/CommandThread.cpp b/src/profiling/CommandThread.cpp
new file mode 100644
index 0000000000..4cd622c477
--- /dev/null
+++ b/src/profiling/CommandThread.cpp
@@ -0,0 +1,97 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include <atomic>
+#include "CommandThread.hpp"
+
+namespace armnn
+{
+
+namespace profiling
+{
+
+CommandThread::CommandThread(uint32_t timeout,
+ bool stopAfterTimeout,
+ CommandHandlerRegistry& commandHandlerRegistry,
+ PacketVersionResolver& packetVersionResolver,
+ IProfilingConnection& socketProfilingConnection)
+ : m_Timeout(timeout)
+ , m_StopAfterTimeout(stopAfterTimeout)
+ , m_IsRunning(false)
+ , m_CommandHandlerRegistry(commandHandlerRegistry)
+ , m_PacketVersionResolver(packetVersionResolver)
+ , m_SocketProfilingConnection(socketProfilingConnection)
+{};
+
+void CommandThread::WaitForPacket()
+{
+ do {
+ try
+ {
+ Packet packet = m_SocketProfilingConnection.ReadPacket(m_Timeout);
+ Version version = m_PacketVersionResolver.ResolvePacketVersion(packet.GetPacketId());
+
+ CommandHandlerFunctor* commandHandlerFunctor =
+ m_CommandHandlerRegistry.GetFunctor(packet.GetPacketId(), version.GetEncodedValue());
+ commandHandlerFunctor->operator()(packet);
+ }
+ catch(armnn::TimeoutException)
+ {
+ if(m_StopAfterTimeout)
+ {
+ m_IsRunning.store(false, std::memory_order_relaxed);
+ return;
+ }
+ }
+ catch(...)
+ {
+ //might want to differentiate the errors more
+ m_IsRunning.store(false, std::memory_order_relaxed);
+ return;
+ }
+
+ } while(m_KeepRunning.load(std::memory_order_relaxed));
+
+ m_IsRunning.store(false, std::memory_order_relaxed);
+}
+
+void CommandThread::Start()
+{
+ if (!m_CommandThread.joinable() && !IsRunning())
+ {
+ m_IsRunning.store(true, std::memory_order_relaxed);
+ m_KeepRunning.store(true, std::memory_order_relaxed);
+ m_CommandThread = std::thread(&CommandThread::WaitForPacket, this);
+ }
+}
+
+void CommandThread::Stop()
+{
+ m_KeepRunning.store(false, std::memory_order_relaxed);
+}
+
+void CommandThread::Join()
+{
+ m_CommandThread.join();
+}
+
+bool CommandThread::IsRunning() const
+{
+ return m_IsRunning.load(std::memory_order_relaxed);
+}
+
+bool CommandThread::StopAfterTimeout(bool stopAfterTimeout)
+{
+ if (!IsRunning())
+ {
+ m_StopAfterTimeout = stopAfterTimeout;
+ return true;
+ }
+ return false;
+}
+
+}//namespace profiling
+
+}//namespace armnn \ No newline at end of file
diff --git a/src/profiling/CommandThread.hpp b/src/profiling/CommandThread.hpp
new file mode 100644
index 0000000000..6237cd2914
--- /dev/null
+++ b/src/profiling/CommandThread.hpp
@@ -0,0 +1,53 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "CommandHandlerRegistry.hpp"
+#include "IProfilingConnection.hpp"
+#include "PacketVersionResolver.hpp"
+#include "ProfilingService.hpp"
+
+#include <atomic>
+#include <thread>
+
+namespace armnn
+{
+
+namespace profiling
+{
+
+class CommandThread
+{
+public:
+ CommandThread(uint32_t timeout,
+ bool stopAfterTimeout,
+ CommandHandlerRegistry& commandHandlerRegistry,
+ PacketVersionResolver& packetVersionResolver,
+ IProfilingConnection& socketProfilingConnection);
+
+ void Start();
+ void Stop();
+ void Join();
+ bool IsRunning() const;
+ bool StopAfterTimeout(bool StopAfterTimeout);
+
+private:
+ void WaitForPacket();
+
+ uint32_t m_Timeout;
+ bool m_StopAfterTimeout;
+ std::atomic<bool> m_IsRunning;
+ std::atomic<bool> m_KeepRunning;
+ std::thread m_CommandThread;
+
+ CommandHandlerRegistry& m_CommandHandlerRegistry;
+ PacketVersionResolver& m_PacketVersionResolver;
+ IProfilingConnection& m_SocketProfilingConnection;
+};
+
+}//namespace profiling
+
+}//namespace armnn \ No newline at end of file
diff --git a/src/profiling/SocketProfilingConnection.cpp b/src/profiling/SocketProfilingConnection.cpp
index 188ca23e12..91d57cc9bd 100644
--- a/src/profiling/SocketProfilingConnection.cpp
+++ b/src/profiling/SocketProfilingConnection.cpp
@@ -135,7 +135,7 @@ Packet SocketProfilingConnection::ReadPacket(uint32_t timeout)
}
else // it's 0 so a timeout.
{
- throw armnn::Exception(": Timeout while reading from socket.");
+ throw armnn::TimeoutException(": Timeout while reading from socket.");
}
}
diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp
index 32a41f37c2..48723dbc34 100644
--- a/src/profiling/test/ProfilingTests.cpp
+++ b/src/profiling/test/ProfilingTests.cpp
@@ -4,6 +4,7 @@
//
#include "SendCounterPacketTests.hpp"
+#include "../CommandThread.hpp"
#include <CommandHandlerKey.hpp>
#include <CommandHandlerFunctor.hpp>
@@ -87,6 +88,150 @@ BOOST_AUTO_TEST_CASE(CheckCommandHandlerKeyComparisons)
BOOST_CHECK(vect == expectedVect);
}
+class TestProfilingConnectionBase :public IProfilingConnection
+{
+public:
+ TestProfilingConnectionBase() = default;
+ ~TestProfilingConnectionBase() = default;
+
+ bool IsOpen()
+ {
+ return true;
+ }
+
+ void Close(){}
+
+ bool WritePacket(const char* buffer, uint32_t length)
+ {
+ return false;
+ }
+
+ Packet ReadPacket(uint32_t timeout)
+ {
+ std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
+ std::unique_ptr<char[]> packetData;
+ //Return connection acknowledged packet
+ return {65536 ,0 , packetData};
+ }
+};
+
+class TestProfilingConnectionTimeoutError :public TestProfilingConnectionBase
+{
+ int readRequests = 0;
+public:
+ Packet ReadPacket(uint32_t timeout) {
+ if (readRequests < 3)
+ {
+ readRequests++;
+ throw armnn::TimeoutException(": Simulate a timeout");
+ }
+ std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
+ std::unique_ptr<char[]> packetData;
+ //Return connection acknowledged packet after three timeouts
+ return {65536 ,0 , packetData};
+ }
+};
+
+class TestProfilingConnectionArmnnError :public TestProfilingConnectionBase
+{
+public:
+
+ Packet ReadPacket(uint32_t timeout)
+ {
+ std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
+ throw armnn::Exception(": Simulate a non timeout error");
+ }
+};
+
+BOOST_AUTO_TEST_CASE(CheckCommandThread)
+{
+ PacketVersionResolver packetVersionResolver;
+ ProfilingStateMachine profilingStateMachine;
+
+ TestProfilingConnectionBase testProfilingConnectionBase;
+ TestProfilingConnectionTimeoutError testProfilingConnectionTimeOutError;
+ TestProfilingConnectionArmnnError testProfilingConnectionArmnnError;
+
+ ConnectionAcknowledgedCommandHandler connectionAcknowledgedCommandHandler(1, 4194304, profilingStateMachine);
+ CommandHandlerRegistry commandHandlerRegistry;
+
+ commandHandlerRegistry.RegisterFunctor(&connectionAcknowledgedCommandHandler, 1, 4194304);
+
+ profilingStateMachine.TransitionToState(ProfilingState::NotConnected);
+ profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck);
+
+ CommandThread commandThread0(1,
+ true,
+ commandHandlerRegistry,
+ packetVersionResolver,
+ testProfilingConnectionBase);
+
+ commandThread0.Start();
+ commandThread0.Start();
+ commandThread0.Start();
+
+ commandThread0.Stop();
+ commandThread0.Join();
+
+ BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::Active);
+
+ profilingStateMachine.TransitionToState(ProfilingState::NotConnected);
+ profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck);
+ //commandThread1 should give up after one timeout
+ CommandThread commandThread1(1,
+ true,
+ commandHandlerRegistry,
+ packetVersionResolver,
+ testProfilingConnectionTimeOutError);
+
+ commandThread1.Start();
+ commandThread1.Join();
+
+ BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::WaitingForAck);
+ //now commandThread1 should persist after a timeout
+ commandThread1.StopAfterTimeout(false);
+ commandThread1.Start();
+
+ for (int i = 0; i < 100; i++)
+ {
+ if (profilingStateMachine.GetCurrentState() == ProfilingState::Active)
+ {
+ break;
+ }
+ else
+ {
+ std::this_thread::sleep_for(std::chrono::milliseconds(5));
+ }
+ }
+
+ commandThread1.Stop();
+ commandThread1.Join();
+
+ BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::Active);
+
+
+ CommandThread commandThread2(1,
+ false,
+ commandHandlerRegistry,
+ packetVersionResolver,
+ testProfilingConnectionArmnnError);
+
+ commandThread2.Start();
+
+ for (int i = 0; i < 100; i++)
+ {
+ if (!commandThread2.IsRunning())
+ {
+ //commandThread2 should stop once it encounters a non timing error
+ commandThread2.Join();
+ return;
+ }
+ std::this_thread::sleep_for(std::chrono::milliseconds(5));
+ }
+
+ BOOST_ERROR("commandThread2 has failed to stop");
+}
+
BOOST_AUTO_TEST_CASE(CheckEncodeVersion)
{
Version version1(12);