aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatteo Martincigh <matteo.martincigh@arm.com>2019-10-09 16:47:04 +0100
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-10-09 17:44:11 +0100
commitd0613b56cea7eba0604e0548bddffd773a4eb554 (patch)
tree18e5a28c346018340910c456eedd56717ab01c9c
parent09ca49cdcfbe377da979a19df9bcdb7cbffc7b50 (diff)
downloadarmnn-d0613b56cea7eba0604e0548bddffd773a4eb554.tar.gz
IVGCVSW-3937 Improve the Connection Acknowledged Handler
* The Connection Acknowledged Handler should report an error is it's called while in a wrong state * Stopping the threads in the ProfilingService before having to start them again * Updated the unit tests to check the changes * Removed unnecessary Packet.cpp file * Fixed memory leak Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com> Change-Id: I8c4d33b4d97994df86fe6c9f8c659f880ec64c16
-rw-r--r--Android.mk1
-rw-r--r--CMakeLists.txt3
-rw-r--r--src/profiling/CommandHandler.cpp23
-rw-r--r--src/profiling/ConnectionAcknowledgedCommandHandler.cpp35
-rw-r--r--src/profiling/Packet.cpp51
-rw-r--r--src/profiling/Packet.hpp39
-rw-r--r--src/profiling/ProfilingService.cpp28
-rw-r--r--src/profiling/ProfilingService.hpp6
-rw-r--r--src/profiling/RequestCounterDirectoryCommandHandler.cpp4
-rw-r--r--src/profiling/SendCounterPacket.cpp2
-rw-r--r--src/profiling/test/ProfilingTests.cpp239
-rw-r--r--src/profiling/test/ProfilingTests.hpp200
-rw-r--r--src/profiling/test/SendCounterPacketTests.cpp6
-rw-r--r--src/profiling/test/SendCounterPacketTests.hpp49
14 files changed, 381 insertions, 305 deletions
diff --git a/Android.mk b/Android.mk
index fcbab689ce..108e01107a 100644
--- a/Android.mk
+++ b/Android.mk
@@ -179,7 +179,6 @@ LOCAL_SRC_FILES := \
src/profiling/CounterDirectory.cpp \
src/profiling/Holder.cpp \
src/profiling/PacketBuffer.cpp \
- src/profiling/Packet.cpp \
src/profiling/PacketVersionResolver.cpp \
src/profiling/PeriodicCounterCapture.cpp \
src/profiling/PeriodicCounterSelectionCommandHandler.cpp \
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3b27d05be5..a4c8fc980f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -451,7 +451,6 @@ list(APPEND armnn_sources
src/profiling/IPeriodicCounterCapture.hpp
src/profiling/IProfilingConnection.hpp
src/profiling/IProfilingConnectionFactory.hpp
- src/profiling/Packet.cpp
src/profiling/Packet.hpp
src/profiling/PacketBuffer.cpp
src/profiling/PacketBuffer.hpp
@@ -599,7 +598,9 @@ if(BUILD_UNIT_TESTS)
src/profiling/test/BufferTests.cpp
src/profiling/test/ProfilingConnectionDumpToFileDecoratorTests.cpp
src/profiling/test/ProfilingTests.cpp
+ src/profiling/test/ProfilingTests.hpp
src/profiling/test/SendCounterPacketTests.cpp
+ src/profiling/test/SendCounterPacketTests.hpp
src/profiling/test/TimelinePacketTests.cpp
)
diff --git a/src/profiling/CommandHandler.cpp b/src/profiling/CommandHandler.cpp
index 86fa2571df..cc68dcf74d 100644
--- a/src/profiling/CommandHandler.cpp
+++ b/src/profiling/CommandHandler.cpp
@@ -5,6 +5,8 @@
#include "CommandHandler.hpp"
+#include <boost/log/trivial.hpp>
+
namespace armnn
{
@@ -39,7 +41,14 @@ void CommandHandler::HandleCommands(IProfilingConnection& profilingConnection)
{
try
{
- Packet packet = profilingConnection.ReadPacket(m_Timeout);
+ Packet packet = profilingConnection.ReadPacket(m_Timeout.load());
+
+ if (packet.IsEmpty())
+ {
+ // Nothing to do, continue
+ continue;
+ }
+
Version version = m_PacketVersionResolver.ResolvePacketVersion(packet.GetPacketId());
CommandHandlerFunctor* commandHandlerFunctor =
@@ -49,19 +58,15 @@ void CommandHandler::HandleCommands(IProfilingConnection& profilingConnection)
}
catch (const armnn::TimeoutException&)
{
- if (m_StopAfterTimeout)
+ if (m_StopAfterTimeout.load())
{
- m_KeepRunning.store(false, std::memory_order_relaxed);
+ m_KeepRunning.store(false);
}
}
catch (const Exception& e)
{
- // Log the error
- BOOST_LOG_TRIVIAL(warning) << "An error has occurred when handling a command: "
- << e.what();
-
- // Might want to differentiate the errors more
- m_KeepRunning.store(false);
+ // Log the error and continue
+ BOOST_LOG_TRIVIAL(warning) << "An error has occurred when handling a command: " << e.what() << std::endl;
}
}
while (m_KeepRunning.load());
diff --git a/src/profiling/ConnectionAcknowledgedCommandHandler.cpp b/src/profiling/ConnectionAcknowledgedCommandHandler.cpp
index f90b601b7e..9d2d1a2bd2 100644
--- a/src/profiling/ConnectionAcknowledgedCommandHandler.cpp
+++ b/src/profiling/ConnectionAcknowledgedCommandHandler.cpp
@@ -7,6 +7,8 @@
#include <armnn/Exceptions.hpp>
+#include <boost/format.hpp>
+
namespace armnn
{
@@ -15,15 +17,34 @@ namespace profiling
void ConnectionAcknowledgedCommandHandler::operator()(const Packet& packet)
{
- if (!(packet.GetPacketFamily() == 0u && packet.GetPacketId() == 1u))
+ ProfilingState currentState = m_StateMachine.GetCurrentState();
+ switch (currentState)
{
- throw armnn::InvalidArgumentException(std::string("Expected Packet family = 0, id = 1 but received family = ")
- + std::to_string(packet.GetPacketFamily())
- + " id = " + std::to_string(packet.GetPacketId()));
+ case ProfilingState::Uninitialised:
+ case ProfilingState::NotConnected:
+ throw RuntimeException(boost::str(boost::format("Connection Acknowledged Handler invoked while in an "
+ "wrong state: %1%")
+ % GetProfilingStateName(currentState)));
+ case ProfilingState::WaitingForAck:
+ // Process the packet
+ if (!(packet.GetPacketFamily() == 0u && packet.GetPacketId() == 1u))
+ {
+ throw armnn::InvalidArgumentException(boost::str(boost::format("Expected Packet family = 0, id = 1 but "
+ "received family = %1%, id = %2%")
+ % packet.GetPacketFamily()
+ % packet.GetPacketId()));
+ }
+
+ // Once a Connection Acknowledged packet has been received, move to the Active state immediately
+ m_StateMachine.TransitionToState(ProfilingState::Active);
+
+ break;
+ case ProfilingState::Active:
+ return; // NOP
+ default:
+ throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1%")
+ % static_cast<int>(currentState)));
}
-
- // Once a Connection Acknowledged packet has been received, move to the Active state immediately
- m_StateMachine.TransitionToState(ProfilingState::Active);
}
} // namespace profiling
diff --git a/src/profiling/Packet.cpp b/src/profiling/Packet.cpp
deleted file mode 100644
index 4cfa42bbc9..0000000000
--- a/src/profiling/Packet.cpp
+++ /dev/null
@@ -1,51 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "Packet.hpp"
-
-namespace armnn
-{
-
-namespace profiling
-{
-
-std::uint32_t Packet::GetHeader() const
-{
- return m_Header;
-}
-
-std::uint32_t Packet::GetPacketFamily() const
-{
- return m_PacketFamily;
-}
-
-std::uint32_t Packet::GetPacketId() const
-{
- return m_PacketId;
-}
-
-std::uint32_t Packet::GetLength() const
-{
- return m_Length;
-}
-
-const char* const Packet::GetData() const
-{
- return m_Data.get();
-}
-
-std::uint32_t Packet::GetPacketClass() const
-{
- return (m_PacketId >> 3);
-}
-
-std::uint32_t Packet::GetPacketType() const
-{
- return (m_PacketId & 7);
-}
-
-} // namespace profiling
-
-} // namespace armnn
diff --git a/src/profiling/Packet.hpp b/src/profiling/Packet.hpp
index 2aae14b741..fae368b64e 100644
--- a/src/profiling/Packet.hpp
+++ b/src/profiling/Packet.hpp
@@ -2,11 +2,12 @@
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
+
#pragma once
#include <armnn/Exceptions.hpp>
-#include <boost/log/trivial.hpp>
+#include <memory>
namespace armnn
{
@@ -46,26 +47,32 @@ public:
}
}
- Packet(Packet&& other) :
- m_Header(other.m_Header),
- m_PacketFamily(other.m_PacketFamily),
- m_PacketId(other.m_PacketId),
- m_Length(other.m_Length),
- m_Data(std::move(other.m_Data))
- {}
+ Packet(Packet&& other)
+ : m_Header(other.m_Header)
+ , m_PacketFamily(other.m_PacketFamily)
+ , m_PacketId(other.m_PacketId)
+ , m_Length(other.m_Length)
+ , m_Data(std::move(other.m_Data))
+ {
+ other.m_Header = 0;
+ other.m_PacketFamily = 0;
+ other.m_PacketId = 0;
+ other.m_Length = 0;
+ }
+
+ ~Packet() = default;
Packet(const Packet& other) = delete;
Packet& operator=(const Packet&) = delete;
Packet& operator=(Packet&&) = default;
- uint32_t GetHeader() const;
- uint32_t GetPacketFamily() const;
- uint32_t GetPacketId() const;
- uint32_t GetLength() const;
- const char* const GetData() const;
-
- uint32_t GetPacketClass() const;
- uint32_t GetPacketType() const;
+ uint32_t GetHeader() const { return m_Header; }
+ uint32_t GetPacketFamily() const { return m_PacketFamily; }
+ uint32_t GetPacketId() const { return m_PacketId; }
+ uint32_t GetPacketClass() const { return m_PacketId >> 3; }
+ uint32_t GetPacketType() const { return m_PacketId & 7; }
+ uint32_t GetLength() const { return m_Length; }
+ const char* const GetData() const { return m_Data.get(); }
bool IsEmpty() { return m_Header == 0 && m_Length == 0; }
diff --git a/src/profiling/ProfilingService.cpp b/src/profiling/ProfilingService.cpp
index 19cf9cb58e..693f8337db 100644
--- a/src/profiling/ProfilingService.cpp
+++ b/src/profiling/ProfilingService.cpp
@@ -47,7 +47,11 @@ void ProfilingService::Update()
m_StateMachine.TransitionToState(ProfilingState::NotConnected);
break;
case ProfilingState::NotConnected:
- BOOST_ASSERT(m_ProfilingConnectionFactory);
+ // Stop the command thread (if running)
+ m_CommandHandler.Stop();
+
+ // Stop the send thread (if running)
+ m_SendCounterPacket.Stop(false);
// Reset any existing profiling connection
m_ProfilingConnection.reset();
@@ -55,13 +59,13 @@ void ProfilingService::Update()
try
{
// Setup the profiling connection
- //m_ProfilingConnection = m_ProfilingConnectionFactory.GetProfilingConnection(m_Options);
+ BOOST_ASSERT(m_ProfilingConnectionFactory);
m_ProfilingConnection = m_ProfilingConnectionFactory->GetProfilingConnection(m_Options);
}
catch (const Exception& e)
{
BOOST_LOG_TRIVIAL(warning) << "An error has occurred when creating the profiling connection: "
- << e.what();
+ << e.what() << std::endl;
}
// Move to the next state
@@ -229,13 +233,23 @@ void ProfilingService::InitializeCounterValue(uint16_t counterUid)
void ProfilingService::Reset()
{
// Reset the profiling service
- m_CounterDirectory.Clear();
+
+ // The order in which we reset/stop the components is not trivial!
+
+ // First stop the threads (Command Handler first)...
+ m_CommandHandler.Stop();
+ m_SendCounterPacket.Stop(false);
+
+ // ...then destroy the profiling connection...
m_ProfilingConnection.reset();
- m_StateMachine.Reset();
+
+ // ...then delete all the counter data and configuration...
m_CounterIndex.clear();
m_CounterValues.clear();
- m_CommandHandler.Stop();
- m_SendCounterPacket.Stop(false);
+ m_CounterDirectory.Clear();
+
+ // ...finally reset the profiling state machine
+ m_StateMachine.Reset();
}
} // namespace profiling
diff --git a/src/profiling/ProfilingService.hpp b/src/profiling/ProfilingService.hpp
index 50a938e33d..edeb6bde90 100644
--- a/src/profiling/ProfilingService.hpp
+++ b/src/profiling/ProfilingService.hpp
@@ -109,7 +109,7 @@ protected:
}
~ProfilingService() = default;
- // Protected method for testing
+ // Protected methods for testing
void SwapProfilingConnectionFactory(ProfilingService& instance,
IProfilingConnectionFactory* other,
IProfilingConnectionFactory*& backup)
@@ -120,6 +120,10 @@ protected:
backup = instance.m_ProfilingConnectionFactory.release();
instance.m_ProfilingConnectionFactory.reset(other);
}
+ IProfilingConnection* GetProfilingConnection(ProfilingService& instance)
+ {
+ return instance.m_ProfilingConnection.get();
+ }
};
} // namespace profiling
diff --git a/src/profiling/RequestCounterDirectoryCommandHandler.cpp b/src/profiling/RequestCounterDirectoryCommandHandler.cpp
index f186add357..0fdcf10de4 100644
--- a/src/profiling/RequestCounterDirectoryCommandHandler.cpp
+++ b/src/profiling/RequestCounterDirectoryCommandHandler.cpp
@@ -5,6 +5,8 @@
#include "RequestCounterDirectoryCommandHandler.hpp"
+#include <boost/assert.hpp>
+
namespace armnn
{
@@ -21,4 +23,4 @@ void RequestCounterDirectoryCommandHandler::operator()(const Packet& packet)
} // namespace profiling
-} // namespace armnn \ No newline at end of file
+} // namespace armnn
diff --git a/src/profiling/SendCounterPacket.cpp b/src/profiling/SendCounterPacket.cpp
index b9f2b187b7..e48da3ed7c 100644
--- a/src/profiling/SendCounterPacket.cpp
+++ b/src/profiling/SendCounterPacket.cpp
@@ -945,7 +945,7 @@ void SendCounterPacket::Stop(bool rethrowSendThreadExceptions)
// Exception handling lock scope - Begin
{
// Lock the mutex to handle any exception coming from the send thread
- std::unique_lock<std::mutex> lock(m_WaitMutex);
+ std::lock_guard<std::mutex> lock(m_WaitMutex);
// Check if there's an exception to rethrow
if (m_SendThreadException)
diff --git a/src/profiling/test/ProfilingTests.cpp b/src/profiling/test/ProfilingTests.cpp
index de92fb9eb0..80d99dd7ab 100644
--- a/src/profiling/test/ProfilingTests.cpp
+++ b/src/profiling/test/ProfilingTests.cpp
@@ -3,11 +3,10 @@
// SPDX-License-Identifier: MIT
//
-#include "SendCounterPacketTests.hpp"
+#include "ProfilingTests.hpp"
#include <CommandHandler.hpp>
#include <CommandHandlerKey.hpp>
-#include <CommandHandlerFunctor.hpp>
#include <CommandHandlerRegistry.hpp>
#include <ConnectionAcknowledgedCommandHandler.hpp>
#include <CounterDirectory.hpp>
@@ -19,7 +18,6 @@
#include <PeriodicCounterCapture.hpp>
#include <PeriodicCounterSelectionCommandHandler.hpp>
#include <ProfilingStateMachine.hpp>
-#include <ProfilingService.hpp>
#include <ProfilingUtils.hpp>
#include <RequestCounterDirectoryCommandHandler.hpp>
#include <Runtime.hpp>
@@ -27,21 +25,16 @@
#include <armnn/Conversion.hpp>
-#include <Logging.hpp>
#include <armnn/Utils.hpp>
#include <boost/algorithm/string.hpp>
#include <boost/numeric/conversion/cast.hpp>
-#include <boost/test/unit_test.hpp>
#include <cstdint>
#include <cstring>
-#include <iostream>
#include <limits>
#include <map>
#include <random>
-#include <thread>
-#include <chrono>
using namespace armnn::profiling;
@@ -94,59 +87,6 @@ BOOST_AUTO_TEST_CASE(CheckCommandHandlerKeyComparisons)
BOOST_CHECK(vect == expectedVect);
}
-class TestProfilingConnectionBase :public IProfilingConnection
-{
-public:
- TestProfilingConnectionBase() = default;
- ~TestProfilingConnectionBase() = default;
-
- bool IsOpen() const override { return true; }
-
- void Close() override {}
-
- bool WritePacket(const unsigned char* buffer, uint32_t length) override { return false; }
-
- Packet ReadPacket(uint32_t timeout) override
- {
- std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
- std::unique_ptr<char[]> packetData;
-
- // Return connection acknowledged packet
- return { 65536, 0, packetData };
- }
-};
-
-class TestProfilingConnectionTimeoutError : public TestProfilingConnectionBase
-{
-public:
- Packet ReadPacket(uint32_t timeout) {
- if (readRequests < 3)
- {
- readRequests++;
- throw armnn::TimeoutException("Simulate a timeout");
- }
- std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
- std::unique_ptr<char[]> packetData;
-
- // Return connection acknowledged packet after three timeouts
- return { 65536, 0, packetData };
- }
-
-private:
- int readRequests = 0;
-};
-
-class TestProfilingConnectionArmnnError :public TestProfilingConnectionBase
-{
-public:
-
- Packet ReadPacket(uint32_t timeout)
- {
- std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
- throw armnn::Exception(": Simulate a non timeout error");
- }
-};
-
BOOST_AUTO_TEST_CASE(CheckCommandHandler)
{
PacketVersionResolver packetVersionResolver;
@@ -180,7 +120,7 @@ BOOST_AUTO_TEST_CASE(CheckCommandHandler)
profilingStateMachine.TransitionToState(ProfilingState::NotConnected);
profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck);
// commandHandler1 should give up after one timeout
- CommandHandler commandHandler1(1,
+ CommandHandler commandHandler1(10,
true,
commandHandlerRegistry,
packetVersionResolver);
@@ -204,32 +144,24 @@ BOOST_AUTO_TEST_CASE(CheckCommandHandler)
break;
}
- std::this_thread::sleep_for(std::chrono::milliseconds(5));
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
commandHandler1.Stop();
BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::Active);
- CommandHandler commandHandler2(1,
+ CommandHandler commandHandler2(100,
false,
commandHandlerRegistry,
packetVersionResolver);
commandHandler2.Start(testProfilingConnectionArmnnError);
- for (int i = 0; i < 100; i++)
- {
- if (!commandHandler2.IsRunning())
- {
- // commandHandler2 should stop once it encounters a non timing error
- return;
- }
-
- std::this_thread::sleep_for(std::chrono::milliseconds(5));
- }
+ // commandHandler2 should not stop once it encounters a non timing error
+ std::this_thread::sleep_for(std::chrono::milliseconds(500));
- BOOST_ERROR("commandHandler2 has failed to stop");
+ BOOST_CHECK(commandHandler2.IsRunning());
commandHandler2.Stop();
}
@@ -300,33 +232,6 @@ BOOST_AUTO_TEST_CASE(CheckPacketClass)
BOOST_CHECK(packetTest4.GetPacketClass() == 5);
}
-// Create Derived Classes
-class TestFunctorA : public CommandHandlerFunctor
-{
-public:
- using CommandHandlerFunctor::CommandHandlerFunctor;
-
- int GetCount() { return m_Count; }
-
- void operator()(const Packet& packet) override
- {
- m_Count++;
- }
-
-private:
- int m_Count = 0;
-};
-
-class TestFunctorB : public TestFunctorA
-{
- using TestFunctorA::TestFunctorA;
-};
-
-class TestFunctorC : public TestFunctorA
-{
- using TestFunctorA::TestFunctorA;
-};
-
BOOST_AUTO_TEST_CASE(CheckCommandHandlerFunctor)
{
// Hard code the version as it will be the same during a single profiling session
@@ -455,6 +360,7 @@ BOOST_AUTO_TEST_CASE(CheckPacketVersionResolver)
BOOST_TEST(resolvedVersion == expectedVersion);
}
}
+
void ProfilingCurrentStateThreadImpl(ProfilingStateMachine& states)
{
ProfilingState newState = ProfilingState::NotConnected;
@@ -664,32 +570,6 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceDisabled)
BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
}
-struct LogLevelSwapper
-{
-public:
- LogLevelSwapper(armnn::LogSeverity severity)
- {
- // Set the new log level
- armnn::ConfigureLogging(true, true, severity);
- }
- ~LogLevelSwapper()
- {
- // The default log level for unit tests is "Fatal"
- armnn::ConfigureLogging(true, true, armnn::LogSeverity::Fatal);
- }
-};
-
-struct CoutRedirect
-{
-public:
- CoutRedirect(std::streambuf* newStreamBuffer)
- : old(std::cout.rdbuf(newStreamBuffer)) {}
- ~CoutRedirect() { std::cout.rdbuf(old); }
-
-private:
- std::streambuf* old;
-};
-
BOOST_AUTO_TEST_CASE(CheckProfilingServiceEnabled)
{
// Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output
@@ -705,7 +585,7 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceEnabled)
// Redirect the output to a local stream so that we can parse the warning message
std::stringstream ss;
- CoutRedirect coutRedirect(ss.rdbuf());
+ StreamRedirector streamRedirector(std::cout, ss.rdbuf());
profilingService.Update();
BOOST_CHECK(boost::contains(ss.str(), "Cannot connect to stream socket: Connection refused"));
}
@@ -729,7 +609,7 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceEnabledRuntime)
// Redirect the output to a local stream so that we can parse the warning message
std::stringstream ss;
- CoutRedirect coutRedirect(ss.rdbuf());
+ StreamRedirector streamRedirector(std::cout, ss.rdbuf());
profilingService.Update();
BOOST_CHECK(boost::contains(ss.str(), "Cannot connect to stream socket: Connection refused"));
}
@@ -1949,16 +1829,18 @@ BOOST_AUTO_TEST_CASE(CheckConnectionAcknowledged)
profilingState.TransitionToState(ProfilingState::WaitingForAck);
BOOST_CHECK(profilingState.GetCurrentState() == ProfilingState::WaitingForAck);
// command handler received packet on ProfilingState::WaitingForAck
- commandHandler(packetA);
+ BOOST_CHECK_NO_THROW(commandHandler(packetA));
BOOST_CHECK(profilingState.GetCurrentState() == ProfilingState::Active);
// command handler received packet on ProfilingState::Active
- commandHandler(packetA);
+ BOOST_CHECK_NO_THROW(commandHandler(packetA));
BOOST_CHECK(profilingState.GetCurrentState() == ProfilingState::Active);
// command handler received different packet
const uint32_t differentPacketId = 0x40000;
Packet packetB(differentPacketId, dataLength1, uniqueData1);
+ profilingState.TransitionToState(ProfilingState::NotConnected);
+ profilingState.TransitionToState(ProfilingState::WaitingForAck);
ConnectionAcknowledgedCommandHandler differentCommandHandler(differentPacketId, version, profilingState);
BOOST_CHECK_THROW(differentCommandHandler(packetB), armnn::Exception);
}
@@ -2333,62 +2215,17 @@ BOOST_AUTO_TEST_CASE(RequestCounterDirectoryCommandHandlerTest1)
BOOST_TEST(categoryRecordOffset == 44);
}
-class MockProfilingConnectionFactory : public IProfilingConnectionFactory
-{
-public:
- MockProfilingConnectionFactory()
- : m_MockProfilingConnection(new MockProfilingConnection())
- {}
-
- IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const override
- {
- return std::unique_ptr<MockProfilingConnection>(m_MockProfilingConnection);
- }
-
- MockProfilingConnection* GetMockProfilingConnection() { return m_MockProfilingConnection; }
-
-private:
- MockProfilingConnection* m_MockProfilingConnection;
-};
-
-class SwapProfilingConnectionFactoryHelper : public ProfilingService
-{
-public:
- SwapProfilingConnectionFactoryHelper()
- : ProfilingService()
- , m_MockProfilingConnectionFactory(new MockProfilingConnectionFactory())
- , m_BackupProfilingConnectionFactory(nullptr)
- {
- SwapProfilingConnectionFactory(ProfilingService::Instance(),
- m_MockProfilingConnectionFactory.get(),
- m_BackupProfilingConnectionFactory);
- }
- ~SwapProfilingConnectionFactoryHelper()
- {
- IProfilingConnectionFactory* temp = nullptr;
- SwapProfilingConnectionFactory(ProfilingService::Instance(),
- m_BackupProfilingConnectionFactory,
- temp);
- }
-
- IProfilingConnectionFactory* GetMockProfilingConnectionFactory() { return m_MockProfilingConnectionFactory.get(); }
-
-private:
- IProfilingConnectionFactoryPtr m_MockProfilingConnectionFactory;
- IProfilingConnectionFactory* m_BackupProfilingConnectionFactory;
-};
-
BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadConnectionAcknowledgedPacket)
{
// Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output
LogLevelSwapper logLevelSwapper(armnn::LogSeverity::Warning);
+ // Swap the profiling connection factory in the profiling service instance with our mock one
SwapProfilingConnectionFactoryHelper helper;
- MockProfilingConnectionFactory* mockProfilingConnectionFactory =
- boost::polymorphic_downcast<MockProfilingConnectionFactory*>(helper.GetMockProfilingConnectionFactory());
- BOOST_CHECK(mockProfilingConnectionFactory);
- MockProfilingConnection* mockProfilingConnection = mockProfilingConnectionFactory->GetMockProfilingConnection();
- BOOST_CHECK(mockProfilingConnection);
+
+ // Redirect the standard output to a local stream so that we can parse the warning message
+ std::stringstream ss;
+ StreamRedirector streamRedirector(std::cout, ss.rdbuf());
// Calculate the size of a Stream Metadata packet
std::string processName = GetProcessName().substr(0, 60);
@@ -2408,15 +2245,15 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadConnectionAcknowledgedPacket)
BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
profilingService.Update();
- // Redirect the output to a local stream so that we can parse the warning message
- std::stringstream ss;
- CoutRedirect coutRedirect(ss.rdbuf());
-
// Wait for a bit to make sure that we get the packet
std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ // Get the mock profiling connection
+ MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
+ BOOST_CHECK(mockProfilingConnection);
+
// Check that the mock profiling connection contains one Stream Metadata packet
- const std::vector<uint32_t>& writtenData = mockProfilingConnection->GetWrittenData();
+ const std::vector<uint32_t> writtenData = mockProfilingConnection->GetWrittenData();
BOOST_TEST(writtenData.size() == 1);
BOOST_TEST(writtenData[0] == streamMetadataPacketsize);
@@ -2433,7 +2270,7 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadConnectionAcknowledgedPacket)
uint32_t header = ((packetFamily & 0x0000003F) << 26) |
((packetId & 0x000003FF) << 16);
- // Connection Acknowledged Packet
+ // Create the Connection Acknowledged Packet
Packet connectionAcknowledgedPacket(header);
// Write the packet to the mock profiling connection
@@ -2441,23 +2278,23 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadConnectionAcknowledgedPacket)
// Wait for a bit (must at least be the delay value of the mock profiling connection) to make sure that
// the Connection Acknowledged packet gets processed by the profiling service
- std::this_thread::sleep_for(std::chrono::seconds(1));
+ std::this_thread::sleep_for(std::chrono::seconds(2));
// Check that the expected error has occurred and logged to the standard output
BOOST_CHECK(boost::contains(ss.str(), "Functor with requested PacketId=37 and Version=4194304 does not exist"));
// The Connection Acknowledged Command Handler should not have updated the profiling state
BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
+
+ // Reset the profiling service to stop any running thread
+ options.m_EnableProfiling = false;
+ profilingService.ResetExternalProfilingOptions(options, true);
}
BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodConnectionAcknowledgedPacket)
{
+ // Swap the profiling connection factory in the profiling service instance with our mock one
SwapProfilingConnectionFactoryHelper helper;
- MockProfilingConnectionFactory* mockProfilingConnectionFactory =
- boost::polymorphic_downcast<MockProfilingConnectionFactory*>(helper.GetMockProfilingConnectionFactory());
- BOOST_CHECK(mockProfilingConnectionFactory);
- MockProfilingConnection* mockProfilingConnection = mockProfilingConnectionFactory->GetMockProfilingConnection();
- BOOST_CHECK(mockProfilingConnection);
// Calculate the size of a Stream Metadata packet
std::string processName = GetProcessName().substr(0, 60);
@@ -2480,8 +2317,12 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodConnectionAcknowledgedPacket)
// Wait for a bit to make sure that we get the packet
std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ // Get the mock profiling connection
+ MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
+ BOOST_CHECK(mockProfilingConnection);
+
// Check that the mock profiling connection contains one Stream Metadata packet
- const std::vector<uint32_t>& writtenData = mockProfilingConnection->GetWrittenData();
+ const std::vector<uint32_t> writtenData = mockProfilingConnection->GetWrittenData();
BOOST_TEST(writtenData.size() == 1);
BOOST_TEST(writtenData[0] == streamMetadataPacketsize);
@@ -2498,7 +2339,7 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodConnectionAcknowledgedPacket)
uint32_t header = ((packetFamily & 0x0000003F) << 26) |
((packetId & 0x000003FF) << 16);
- // Connection Acknowledged Packet
+ // Create the Connection Acknowledged Packet
Packet connectionAcknowledgedPacket(header);
// Write the packet to the mock profiling connection
@@ -2506,10 +2347,14 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodConnectionAcknowledgedPacket)
// Wait for a bit (must at least be the delay value of the mock profiling connection) to make sure that
// the Connection Acknowledged packet gets processed by the profiling service
- std::this_thread::sleep_for(std::chrono::seconds(1));
+ std::this_thread::sleep_for(std::chrono::seconds(2));
// The Connection Acknowledged Command Handler should have updated the profiling state accordingly
BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
+
+ // Reset the profiling service to stop any running thread
+ options.m_EnableProfiling = false;
+ profilingService.ResetExternalProfilingOptions(options, true);
}
BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/profiling/test/ProfilingTests.hpp b/src/profiling/test/ProfilingTests.hpp
new file mode 100644
index 0000000000..3e6cf63efe
--- /dev/null
+++ b/src/profiling/test/ProfilingTests.hpp
@@ -0,0 +1,200 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "SendCounterPacketTests.hpp"
+
+#include <CommandHandlerFunctor.hpp>
+#include <IProfilingConnection.hpp>
+#include <IProfilingConnectionFactory.hpp>
+#include <Logging.hpp>
+#include <ProfilingService.hpp>
+
+#include <boost/test/unit_test.hpp>
+
+#include <chrono>
+#include <iostream>
+#include <thread>
+
+namespace armnn
+{
+
+namespace profiling
+{
+
+struct LogLevelSwapper
+{
+public:
+ LogLevelSwapper(armnn::LogSeverity severity)
+ {
+ // Set the new log level
+ armnn::ConfigureLogging(true, true, severity);
+ }
+ ~LogLevelSwapper()
+ {
+ // The default log level for unit tests is "Fatal"
+ armnn::ConfigureLogging(true, true, armnn::LogSeverity::Fatal);
+ }
+};
+
+struct CoutRedirect
+{
+public:
+ CoutRedirect(std::streambuf* newStreamBuffer)
+ : m_Old(std::cout.rdbuf(newStreamBuffer)) {}
+ ~CoutRedirect() { std::cout.rdbuf(m_Old); }
+
+private:
+ std::streambuf* m_Old;
+};
+
+struct StreamRedirector
+{
+public:
+ StreamRedirector(std::ostream& stream, std::streambuf* newStreamBuffer)
+ : m_Stream(stream)
+ , m_BackupBuffer(m_Stream.rdbuf(newStreamBuffer))
+ {}
+ ~StreamRedirector() { m_Stream.rdbuf(m_BackupBuffer); }
+
+private:
+ std::ostream& m_Stream;
+ std::streambuf* m_BackupBuffer;
+};
+
+class TestProfilingConnectionBase : public IProfilingConnection
+{
+public:
+ TestProfilingConnectionBase() = default;
+ ~TestProfilingConnectionBase() = default;
+
+ bool IsOpen() const override { return true; }
+
+ void Close() override {}
+
+ bool WritePacket(const unsigned char* buffer, uint32_t length) override { return false; }
+
+ Packet ReadPacket(uint32_t timeout) override
+ {
+ std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
+
+ // Return connection acknowledged packet
+ std::unique_ptr<char[]> packetData;
+ return Packet(65536, 0, packetData);
+ }
+};
+
+class TestProfilingConnectionTimeoutError : public TestProfilingConnectionBase
+{
+public:
+ TestProfilingConnectionTimeoutError()
+ : m_ReadRequests(0)
+ {}
+
+ Packet ReadPacket(uint32_t timeout) override
+ {
+ std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
+
+ if (m_ReadRequests < 3)
+ {
+ m_ReadRequests++;
+ throw armnn::TimeoutException("Simulate a timeout error\n");
+ }
+
+ // Return connection acknowledged packet after three timeouts
+ std::unique_ptr<char[]> packetData;
+ return Packet(65536, 0, packetData);
+ }
+
+private:
+ int m_ReadRequests;
+};
+
+class TestProfilingConnectionArmnnError : public TestProfilingConnectionBase
+{
+public:
+ Packet ReadPacket(uint32_t timeout) override
+ {
+ std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
+
+ throw armnn::Exception("Simulate a non-timeout error");
+ }
+};
+
+class TestFunctorA : public CommandHandlerFunctor
+{
+public:
+ using CommandHandlerFunctor::CommandHandlerFunctor;
+
+ int GetCount() { return m_Count; }
+
+ void operator()(const Packet& packet) override
+ {
+ m_Count++;
+ }
+
+private:
+ int m_Count = 0;
+};
+
+class TestFunctorB : public TestFunctorA
+{
+ using TestFunctorA::TestFunctorA;
+};
+
+class TestFunctorC : public TestFunctorA
+{
+ using TestFunctorA::TestFunctorA;
+};
+
+class MockProfilingConnectionFactory : public IProfilingConnectionFactory
+{
+public:
+ IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const override
+ {
+ return std::make_unique<MockProfilingConnection>();
+ }
+};
+
+class SwapProfilingConnectionFactoryHelper : public ProfilingService
+{
+public:
+ using MockProfilingConnectionFactoryPtr = std::unique_ptr<MockProfilingConnectionFactory>;
+
+ SwapProfilingConnectionFactoryHelper()
+ : ProfilingService()
+ , m_MockProfilingConnectionFactory(new MockProfilingConnectionFactory())
+ , m_BackupProfilingConnectionFactory(nullptr)
+ {
+ BOOST_CHECK(m_MockProfilingConnectionFactory);
+ SwapProfilingConnectionFactory(ProfilingService::Instance(),
+ m_MockProfilingConnectionFactory.get(),
+ m_BackupProfilingConnectionFactory);
+ BOOST_CHECK(m_BackupProfilingConnectionFactory);
+ }
+ ~SwapProfilingConnectionFactoryHelper()
+ {
+ BOOST_CHECK(m_BackupProfilingConnectionFactory);
+ IProfilingConnectionFactory* temp = nullptr;
+ SwapProfilingConnectionFactory(ProfilingService::Instance(),
+ m_BackupProfilingConnectionFactory,
+ temp);
+ }
+
+ MockProfilingConnection* GetMockProfilingConnection()
+ {
+ IProfilingConnection* profilingConnection = GetProfilingConnection(ProfilingService::Instance());
+ return boost::polymorphic_downcast<MockProfilingConnection*>(profilingConnection);
+ }
+
+private:
+ MockProfilingConnectionFactoryPtr m_MockProfilingConnectionFactory;
+ IProfilingConnectionFactory* m_BackupProfilingConnectionFactory;
+};
+
+} // namespace profiling
+
+} // namespace armnn
diff --git a/src/profiling/test/SendCounterPacketTests.cpp b/src/profiling/test/SendCounterPacketTests.cpp
index 1216420383..00dad38078 100644
--- a/src/profiling/test/SendCounterPacketTests.cpp
+++ b/src/profiling/test/SendCounterPacketTests.cpp
@@ -2322,7 +2322,7 @@ BOOST_AUTO_TEST_CASE(SendThreadBufferTest1)
BOOST_TEST(reservedBuffer.get());
// Check that data was actually written to the profiling connection in any order
- const std::vector<uint32_t>& writtenData = mockProfilingConnection.GetWrittenData();
+ const std::vector<uint32_t> writtenData = mockProfilingConnection.GetWrittenData();
BOOST_TEST(writtenData.size() == 3);
bool foundStreamMetaDataPacket =
std::find(writtenData.begin(), writtenData.end(), streamMetadataPacketsize) != writtenData.end();
@@ -2391,7 +2391,7 @@ BOOST_AUTO_TEST_CASE(SendThreadSendStreamMetadataPacket3)
BOOST_CHECK_NO_THROW(sendCounterPacket.Stop());
// Check that the buffer contains one Stream Metadata packet
- const std::vector<uint32_t>& writtenData = mockProfilingConnection.GetWrittenData();
+ const std::vector<uint32_t> writtenData = mockProfilingConnection.GetWrittenData();
BOOST_TEST(writtenData.size() == 1);
BOOST_TEST(writtenData[0] == streamMetadataPacketsize);
}
@@ -2420,7 +2420,7 @@ BOOST_AUTO_TEST_CASE(SendThreadSendStreamMetadataPacket4)
BOOST_TEST((profilingStateMachine.GetCurrentState() == ProfilingState::WaitingForAck));
// Check that the buffer contains one Stream Metadata packet
- const std::vector<uint32_t>& writtenData = mockProfilingConnection.GetWrittenData();
+ const std::vector<uint32_t> writtenData = mockProfilingConnection.GetWrittenData();
BOOST_TEST(writtenData.size() == 1);
BOOST_TEST(writtenData[0] == streamMetadataPacketsize);
diff --git a/src/profiling/test/SendCounterPacketTests.hpp b/src/profiling/test/SendCounterPacketTests.hpp
index 48bab025dd..871ca74124 100644
--- a/src/profiling/test/SendCounterPacketTests.hpp
+++ b/src/profiling/test/SendCounterPacketTests.hpp
@@ -12,6 +12,7 @@
#include <armnn/Optional.hpp>
#include <armnn/Conversion.hpp>
+#include <boost/assert.hpp>
#include <boost/numeric/conversion/cast.hpp>
namespace armnn
@@ -19,6 +20,7 @@ namespace armnn
namespace profiling
{
+
class MockProfilingConnection : public IProfilingConnection
{
public:
@@ -28,9 +30,19 @@ public:
, m_Packet()
{}
- bool IsOpen() const override { return m_IsOpen; }
+ bool IsOpen() const override
+ {
+ std::lock_guard<std::mutex> lock(m_Mutex);
+
+ return m_IsOpen;
+ }
+
+ void Close() override
+ {
+ std::lock_guard<std::mutex> lock(m_Mutex);
- void Close() override { m_IsOpen = false; }
+ m_IsOpen = false;
+ }
bool WritePacket(const unsigned char* buffer, uint32_t length) override
{
@@ -39,11 +51,15 @@ public:
return false;
}
+ std::lock_guard<std::mutex> lock(m_Mutex);
+
m_WrittenData.push_back(length);
return true;
}
bool WritePacket(Packet&& packet)
{
+ std::lock_guard<std::mutex> lock(m_Mutex);
+
m_Packet = std::move(packet);
return true;
}
@@ -51,19 +67,32 @@ public:
Packet ReadPacket(uint32_t timeout) override
{
// Simulate a delay in the reading process
- std::this_thread::sleep_for(std::chrono::milliseconds(500));
+ std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
+
+ std::lock_guard<std::mutex> lock(m_Mutex);
return std::move(m_Packet);
}
- const std::vector<uint32_t>& GetWrittenData() const { return m_WrittenData; }
+ const std::vector<uint32_t> GetWrittenData() const
+ {
+ std::lock_guard<std::mutex> lock(m_Mutex);
+
+ return m_WrittenData;
+ }
+
+ void Clear()
+ {
+ std::lock_guard<std::mutex> lock(m_Mutex);
- void Clear() { m_WrittenData.clear(); }
+ m_WrittenData.clear();
+ }
private:
bool m_IsOpen;
std::vector<uint32_t> m_WrittenData;
Packet m_Packet;
+ mutable std::mutex m_Mutex;
};
class MockPacketBuffer : public IPacketBuffer
@@ -162,7 +191,7 @@ public:
IPacketBufferPtr Reserve(unsigned int requestedSize, unsigned int& reservedSize) override
{
- std::unique_lock<std::mutex> lock(m_Mutex);
+ std::lock_guard<std::mutex> lock(m_Mutex);
reservedSize = 0;
if (requestedSize > m_MaxBufferSize)
@@ -176,7 +205,7 @@ public:
void Commit(IPacketBufferPtr& packetBuffer, unsigned int size) override
{
- std::unique_lock<std::mutex> lock(m_Mutex);
+ std::lock_guard<std::mutex> lock(m_Mutex);
packetBuffer->Commit(size);
m_BufferList.push_back(std::move(packetBuffer));
@@ -185,14 +214,14 @@ public:
void Release(IPacketBufferPtr& packetBuffer) override
{
- std::unique_lock<std::mutex> lock(m_Mutex);
+ std::lock_guard<std::mutex> lock(m_Mutex);
packetBuffer->Release();
}
IPacketBufferPtr GetReadableBuffer() override
{
- std::unique_lock<std::mutex> lock(m_Mutex);
+ std::lock_guard<std::mutex> lock(m_Mutex);
if (m_BufferList.empty())
{
@@ -206,7 +235,7 @@ public:
void MarkRead(IPacketBufferPtr& packetBuffer) override
{
- std::unique_lock<std::mutex> lock(m_Mutex);
+ std::lock_guard<std::mutex> lock(m_Mutex);
m_ReadSize += packetBuffer->GetSize();
packetBuffer->MarkRead();